Skip to content

Commit

Permalink
Merge pull request #3512 from slahirucd7/subOrgImpl
Browse files Browse the repository at this point in the history
Correction to get correct subscription policy from SubscriptionDataStore
  • Loading branch information
slahirucd7 authored May 16, 2024
2 parents aa68bcf + 4d61ccf commit 848ec6b
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 113 deletions.
4 changes: 2 additions & 2 deletions adapter/internal/discovery/xds/rate_limiter_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ func parseRateLimitUnitFromSubscriptionPolicy(name string) (rls_config.RateLimit
return rls_config.RateLimitUnit_SECOND, nil
case "min":
return rls_config.RateLimitUnit_MINUTE, nil
case "hours":
case "hour":
return rls_config.RateLimitUnit_HOUR, nil
case "days":
case "day":
return rls_config.RateLimitUnit_DAY, nil
default:
return rls_config.RateLimitUnit_UNKNOWN, fmt.Errorf("invalid rate limit unit %q", name)
Expand Down
4 changes: 2 additions & 2 deletions adapter/internal/discovery/xds/rate_limiter_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ func TestAddSubscriptionLevelRateLimitPolicy(t *testing.T) {
QuotaType: "requestCount",
RequestCount: &types.SubscriptionRequestCount{
RequestCount: 300,
TimeUnit: "hours",
TimeUnit: "hour",
},
},
Organization: "org1",
Expand All @@ -620,7 +620,7 @@ func TestAddSubscriptionLevelRateLimitPolicy(t *testing.T) {
QuotaType: "eventCount",
RequestCount: &types.SubscriptionRequestCount{
RequestCount: 300,
TimeUnit: "hours",
TimeUnit: "hour",
},
},
Organization: "org1",
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,22 @@ public class ThrottlingPolicyRequestHandler extends RequestHandler {
@Override
public ResponsePayload handleRequest(String[] params, String requestType) throws Exception {
String policyName = null;
String organizationId = null;
ResponsePayload responsePayload;
if (params != null) {
for (String param : params) {
String[] keyVal = param.split("=");
if (AdminConstants.Parameters.NAME.equals(keyVal[0])) {
policyName = keyVal[1];
} else if (AdminConstants.Parameters.ORGANIZATION_ID.equals(keyVal[0])) {
organizationId = keyVal[1];
}
}
}
if (AdminConstants.APPLICATION_THROTTLING_POLICY_TYPE.equals(requestType)) {
responsePayload = getApplicationPolicies(policyName);
} else {
responsePayload = getSubscriptionPolicies(policyName);
responsePayload = getSubscriptionPolicies(organizationId, policyName);
}
return responsePayload;
}
Expand All @@ -59,8 +62,10 @@ private ResponsePayload getApplicationPolicies(String policyName) throws JsonPro
return AdminUtils.buildResponsePayload(applicationPolicyList, HttpResponseStatus.OK, false);
}

private ResponsePayload getSubscriptionPolicies(String policyName) throws JsonProcessingException {
List<SubscriptionPolicy> subscriptionPolicies = super.dataStore.getMatchingSubscriptionPolicies(policyName);
private ResponsePayload getSubscriptionPolicies(String organizationId, String policyName)
throws JsonProcessingException {
List<SubscriptionPolicy> subscriptionPolicies = super.dataStore.getMatchingSubscriptionPolicies(
organizationId, policyName);
SubscriptionPolicyList subscriptionPolicyList = AdminUtils.toSubscriptionPolicyList(subscriptionPolicies);
return AdminUtils.buildResponsePayload(subscriptionPolicyList, HttpResponseStatus.OK, false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,25 @@
*/
public class APIConstants {

//open API extensions
public static final String X_WSO2_BASE_PATH = "x-wso2-basepath";

public static final String GW_VHOST_PARAM = "vHost";
public static final String GW_BASE_PATH_PARAM = "basePath";
public static final String GW_RES_PATH_PARAM = "path";
public static final String GW_RES_METHOD_PARAM = "method";
public static final String GW_VERSION_PARAM = "version";
public static final String GW_API_NAME_PARAM = "name";
public static final String PROTOTYPED_LIFE_CYCLE_STATUS = "PROTOTYPED";
public static final String PUBLISHED_LIFE_CYCLE_STATUS = "PUBLISHED";
public static final String UNLIMITED_TIER = "Unlimited";
public static final String UNAUTHENTICATED_TIER = "Unauthenticated";
public static final String END_USER_ANONYMOUS = "anonymous";
public static final String ANONYMOUS_PREFIX = "anon:";
public static final String END_USER_UNKNOWN = "unknown";
public static final String GATEWAY_SIGNED_JWT_CACHE = "SignedJWTParseCache";
public static final String DEFAULT_ISSUER = "Resident Key Manager";
public static final String GATEWAY_PUBLIC_CERTIFICATE_ALIAS = "gateway_certificate_alias";
public static final String WSO2_PUBLIC_CERTIFICATE_ALIAS = "wso2carbon";
public static final String HTTPS_PROTOCOL = "https";
public static final String SUPER_TENANT_DOMAIN_NAME = "carbon.super";
public static final String BANDWIDTH_TYPE = "bandwidthVolume";
public static final String INTERNAL_WEB_APP_EP = "/internal/data/v1";
public static final String AUTHORIZATION_HEADER_DEFAULT = "Authorization";
public static final String AUTHORIZATION_BASIC = "Basic ";
public static final String AUTHORIZATION_BEARER = "Bearer ";
public static final String HEADER_TENANT = "xWSO2Tenant";
public static final String DEFAULT_VERSION_PREFIX = "_default_";
public static final String DEFAULT_WEBSOCKET_VERSION = "defaultVersion";
public static final String DELEM_COLON = ":";

public static final String API_KEY_TYPE_PRODUCTION = "PRODUCTION";
public static final String API_KEY_TYPE_SANDBOX = "SANDBOX";
public static final String TENANT_DOMAIN_SEPARATOR = "@";
public static final String DEFAULT_ENVIRONMENT_NAME = "Default";

public static final String AUTHORIZATION_HEADER_BASIC = "Basic";
Expand All @@ -68,14 +53,6 @@ public class APIConstants {
public static final String API_SECURITY_MUTUAL_SSL_MANDATORY = "mutualssl_mandatory";
public static final String API_SECURITY_OAUTH_BASIC_AUTH_API_KEY_MANDATORY = "oauth_basic_auth_api_key_mandatory";
public static final String WWW_AUTHENTICATE = "WWW-Authenticate";

public static final String BEGIN_CERTIFICATE_STRING = "-----BEGIN CERTIFICATE-----\n";
public static final String END_CERTIFICATE_STRING = "-----END CERTIFICATE-----";
public static final String BEGIN_PUBLIC_KEY_STRING = "-----BEGIN PUBLIC KEY-----\n";
public static final String END_PUBLIC_KEY_STRING = "-----END PUBLIC KEY-----";
public static final String OAUTH2_DEFAULT_SCOPE = "default";
public static final String EVENT_TYPE = "eventType";
public static final String EVENT_TIMESTAMP = "timestamp";
public static final String EVENT_PAYLOAD = "event";
public static final String EVENT_PAYLOAD_DATA = "payloadData";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.wso2.choreo.connect.enforcer.models;

import org.apache.commons.lang.StringUtils;
import org.wso2.choreo.connect.enforcer.common.CacheableEntity;
import org.wso2.choreo.connect.enforcer.constants.APIConstants;
import org.wso2.choreo.connect.enforcer.subscription.SubscriptionDataStoreUtil;
Expand Down Expand Up @@ -87,6 +88,9 @@ public void setTierName(String name) {
}

public String getOrganization() {
if (StringUtils.isEmpty(organization)) {
return APIConstants.SUPER_TENANT_DOMAIN_NAME;
}
return organization;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,16 @@ public void setGraphQLMaxDepth(int graphQLMaxDepth) {
}
@Override
public String getCacheKey() {

return PolicyType.SUBSCRIPTION + SubscriptionDataStoreUtil.getPolicyCacheKey(getName());

return PolicyType.SUBSCRIPTION + SubscriptionDataStoreUtil.DELEM_PERIOD + getOrganization() +
SubscriptionDataStoreUtil.DELEM_PERIOD + getName();
}

@Override
public String toString() {
return "SubscriptionPolicy [rateLimitCount=" + rateLimitCount + ", rateLimitTimeUnit=" + rateLimitTimeUnit
+ ", stopOnQuotaReach=" + stopOnQuotaReach + ", getId()=" + getId() + ", getQuotaType()="
+ getQuotaType() + ", isContentAware()=" + isContentAware() + ", getTenantId()=" + getTenantId()
+ ", getName()=" + getName() + "]";
+ ", getName()=" + getName() + ", getOrganization()=" + getOrganization() + "]";
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ private static void validate(APIKeyValidationInfoDTO infoDTO, SubscriptionDataSt
String apiTier = apiConfig.getTier();

ApplicationPolicy appPolicy = datastore.getApplicationPolicyByName(app.getPolicy());
SubscriptionPolicy subPolicy = datastore.getSubscriptionPolicyByName(sub.getPolicyId());
SubscriptionPolicy subPolicy = datastore.getSubscriptionPolicyByOrgIdAndName(apiConfig.getOrganizationId(),
sub.getPolicyId());
ApiPolicy apiPolicy = datastore.getApiPolicyByName(apiTier);
boolean isContentAware = false;
if (appPolicy.isContentAware() || subPolicy.isContentAware() || (apiPolicy != null && apiPolicy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,19 @@ public class JWTAuthenticator implements Authenticator {
private final boolean isGatewayTokenCacheEnabled;
private AbstractAPIMgtGatewayJWTGenerator jwtGenerator;
private static final Set<String> prodTokenNonProdAllowedOrgs = new HashSet<>();
private static final String orgList = System.getenv("CUSTOM_SUBSCRIPTION_POLICY_HANDLING_ORG");
private static Set<String> orgSet = new HashSet<>();

static {
if (System.getenv("PROD_TOKEN_NONPROD_ALLOWED_ORGS") != null) {
Collections.addAll(prodTokenNonProdAllowedOrgs,
System.getenv("PROD_TOKEN_NONPROD_ALLOWED_ORGS").split("\\s+"));
}
if (orgList != null) {
orgSet = Stream.of(orgList.trim().split("\\s*,\\s*"))
.collect(Collectors.toSet());
}
}
private static String orgList = System.getenv("CUSTOM_SUBSCRIPTION_POLICY_HANDLING_ORG");

public JWTAuthenticator() {
EnforcerConfig enforcerConfig = ConfigHolder.getInstance().getConfig();
Expand Down Expand Up @@ -329,27 +334,26 @@ public AuthenticationContext authenticate(RequestContext requestContext) throws
String subPolicyName = authenticationContext.getTier();
requestContext.addMetadataToMap("ratelimit:subscription", subscriptionId);
requestContext.addMetadataToMap("ratelimit:usage-policy", subPolicyName);
if (datastore.getSubscriptionPolicyByName(subPolicyName) != null &&
StringUtils.isNotEmpty(orgList)) {
SubscriptionPolicy subPolicy = datastore.getSubscriptionPolicyByName(subPolicyName);
Set<String> orgSet = Stream.of(orgList.trim().split("\\s*,\\s*"))
.collect(Collectors.toSet());
if (StringUtils.isNotEmpty(subPolicy.getOrganization()) &&
orgSet.contains(subPolicy.getOrganization()) || orgList.equals("*")) {
requestContext.addMetadataToMap("ratelimit:organization", subPolicy.getOrganization());
} else {
requestContext.addMetadataToMap("ratelimit:organization",
APIConstants.SUPER_TENANT_DOMAIN_NAME);
}

String matchedApiOrganizationId = requestContext.getMatchedAPI().getOrganizationId();
if (datastore.getSubscriptionPolicyByOrgIdAndName(matchedApiOrganizationId, subPolicyName)
!= null) {
SubscriptionPolicy subPolicy = datastore.getSubscriptionPolicyByOrgIdAndName
(matchedApiOrganizationId, subPolicyName);
String metaDataOrgId = StringUtils.isNotEmpty(orgList) &&
(orgSet.contains(subPolicy.getOrganization()) || orgList.equals("*")) ?
subPolicy.getOrganization() : APIConstants.SUPER_TENANT_DOMAIN_NAME;
log.debug("Subscription rate-limiting will be evaluated for the organization: " +
metaDataOrgId);
requestContext.addMetadataToMap("ratelimit:organization", metaDataOrgId);
} else {
requestContext.addMetadataToMap("ratelimit:organization",
APIConstants.SUPER_TENANT_DOMAIN_NAME);
}
if (log.isDebugEnabled()) {
log.debug("Organization ID: " +
requestContext.getMetadataMap().get("ratelimit:organization")
+ ", SubscriptionId: " + subscriptionId + ", SubscriptionPolicy: " + subPolicyName
+ " will be evaluated for subscription rate-limiting");
log.debug("Organization ID: " + matchedApiOrganizationId + ", SubscriptionId: "
+ subscriptionId + ", SubscriptionPolicy: " + subPolicyName +
" will be evaluated for subscription rate-limiting");
}
}
return authenticationContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public interface SubscriptionDataStore {
* @param policyName Name of the Throttling Policy
* @return Subscription Throttling Policy
*/
SubscriptionPolicy getSubscriptionPolicyByName(String policyName);
SubscriptionPolicy getSubscriptionPolicyByOrgIdAndName(String orgId, String policyName);

/**
* Gets Application Throttling Policy by the name and Tenant Id.
Expand Down Expand Up @@ -207,6 +207,6 @@ void addApplicationKeyMappings(
* @param policyName The name of the policy
* @return
*/
List<SubscriptionPolicy> getMatchingSubscriptionPolicies(String policyName);
List<SubscriptionPolicy> getMatchingSubscriptionPolicies(String organizationId, String policyName);

}
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ public enum PolicyType {
private Map<String, ApplicationPolicy> appPolicyMap;
private Map<String, Subscription> subscriptionMap;
private Map<String, Subscription> apiVersionRangeSubscriptionMap;
private String tenantDomain = APIConstants.SUPER_TENANT_DOMAIN_NAME;

SubscriptionDataStoreImpl() {
}
Expand Down Expand Up @@ -111,10 +110,11 @@ public API getApiByContextAndVersion(String uuid) {
}

@Override
public SubscriptionPolicy getSubscriptionPolicyByName(String policyName) {

public SubscriptionPolicy getSubscriptionPolicyByOrgIdAndName(String orgId, String policyName) {
String organizationId = StringUtils.isNotEmpty(orgId) ? APIConstants.SUPER_TENANT_DOMAIN_NAME : orgId;
String key = PolicyType.SUBSCRIPTION +
SubscriptionDataStoreUtil.getPolicyCacheKey(policyName);
SubscriptionDataStoreUtil.DELEM_PERIOD + organizationId +
SubscriptionDataStoreUtil.DELEM_PERIOD + policyName;
return subscriptionPolicyMap.get(key);
}

Expand Down Expand Up @@ -541,12 +541,12 @@ public List<ApplicationPolicy> getMatchingApplicationPolicies(String policyName)
}

@Override
public List<SubscriptionPolicy> getMatchingSubscriptionPolicies(String policyName) {
public List<SubscriptionPolicy> getMatchingSubscriptionPolicies(String organizationId, String policyName) {
List<SubscriptionPolicy> subscriptionPolicies = new ArrayList<>();
if (StringUtils.isEmpty(policyName)) {
subscriptionPolicies.addAll(this.subscriptionPolicyMap.values());
} else {
SubscriptionPolicy policy = this.getSubscriptionPolicyByName(policyName);
SubscriptionPolicy policy = this.getSubscriptionPolicyByOrgIdAndName(organizationId, policyName);
subscriptionPolicies.add(policy);
}
return subscriptionPolicies;
Expand Down

0 comments on commit 848ec6b

Please sign in to comment.