diff --git a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/handlers/security/APIKeyValidator.java b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/handlers/security/APIKeyValidator.java index 4b07281a8aeb..ebc507f2d28e 100644 --- a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/handlers/security/APIKeyValidator.java +++ b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/handlers/security/APIKeyValidator.java @@ -30,6 +30,8 @@ import org.apache.synapse.rest.RESTUtils; import org.apache.synapse.rest.Resource; import org.apache.synapse.rest.dispatch.RESTDispatcher; +import org.wso2.carbon.apimgt.api.APIDefinition; +import org.wso2.carbon.apimgt.api.APIManagementException; import org.wso2.carbon.apimgt.api.model.URITemplate; import org.wso2.carbon.apimgt.gateway.APIMgtGatewayConstants; import org.wso2.carbon.apimgt.gateway.MethodStats; @@ -39,6 +41,8 @@ import org.wso2.carbon.apimgt.impl.APIConstants; import org.wso2.carbon.apimgt.impl.APIManagerConfiguration; import org.wso2.carbon.apimgt.impl.caching.CacheProvider; +import org.wso2.carbon.apimgt.impl.definitions.OAS2Parser; +import org.wso2.carbon.apimgt.impl.definitions.OASParserUtil; import org.wso2.carbon.apimgt.impl.dto.APIInfoDTO; import org.wso2.carbon.apimgt.impl.dto.APIKeyValidationInfoDTO; import org.wso2.carbon.apimgt.impl.dto.ResourceInfoDTO; @@ -81,6 +85,8 @@ public class APIKeyValidator { protected Log log = LogFactory.getLog(getClass()); + private ArrayList uriTemplates = null; + public APIKeyValidator(AxisConfiguration axisConfig) { //check the client type from config String keyValidatorClientType = getKeyValidatorClientType(); @@ -294,7 +300,7 @@ public String getResourceAuthenticationScheme(MessageContext synCtx) throws APIS for (VerbInfoDTO verb : verbInfoList) { authType = verb.getAuthType(); if (authType == null) { - authType = StringUtils.capitalize(APIConstants.AUTH_TYPE_NONE.toLowerCase());; + authType = StringUtils.capitalize(APIConstants.AUTH_APPLICATION_OR_USER_LEVEL_TOKEN.toLowerCase()); } if (!StringUtils.capitalize(APIConstants.AUTH_TYPE_NONE.toLowerCase()).equals(authType)) { break; @@ -461,9 +467,9 @@ public List findMatchingVerb(MessageContext synCtx) throws Resource String apiType = (String) synCtx.getProperty(APIMgtGatewayConstants.API_TYPE); if (APIConstants.ApiTypes.PRODUCT_API.name().equalsIgnoreCase(apiType)) { - apiInfoDTO = doGetAPIProductInfo(apiContext, apiVersion); + apiInfoDTO = doGetAPIProductInfo(synCtx, apiContext, apiVersion); } else { - apiInfoDTO = doGetAPIInfo(apiContext, apiVersion); + apiInfoDTO = doGetAPIInfo(synCtx, apiContext, apiVersion); } if (Util.tracingEnabled()) { @@ -539,15 +545,15 @@ private boolean isResourcePathMatching(String resourceString, ResourceInfoDTO re } @MethodStats - private APIInfoDTO doGetAPIInfo(String context, String apiVersion) throws APISecurityException { - ArrayList uriTemplates = getAllURITemplates(context, apiVersion); + private APIInfoDTO doGetAPIInfo(MessageContext messageContext, String context, String apiVersion) throws APISecurityException { + ArrayList uriTemplates = getAllURITemplates(messageContext, context, apiVersion); return mapToAPIInfo(uriTemplates, context, apiVersion); } @MethodStats - private APIInfoDTO doGetAPIProductInfo(String context, String apiVersion) throws APISecurityException { - ArrayList uriTemplates = getAPIProductURITemplates(context, apiVersion); + private APIInfoDTO doGetAPIProductInfo(MessageContext messageContext, String context, String apiVersion) throws APISecurityException { + ArrayList uriTemplates = getAPIProductURITemplates(messageContext, context, apiVersion); return mapToAPIInfo(uriTemplates, context, apiVersion); } @@ -589,13 +595,14 @@ private APIInfoDTO mapToAPIInfo(ArrayList uriTemplates, String cont } /** + * @param messageContext The message context * @param context API context of API * @param apiVersion Version of API * @param requestPath Incoming request path * @param httpMethod http method of request * @return verbInfoDTO which contains throttling tier for given resource and verb+resource key */ - public VerbInfoDTO getVerbInfoDTOFromAPIData(String context, String apiVersion, String requestPath, String httpMethod) + public VerbInfoDTO getVerbInfoDTOFromAPIData(MessageContext messageContext, String context, String apiVersion, String requestPath, String httpMethod) throws APISecurityException { String cacheKey = context + ':' + apiVersion; @@ -604,7 +611,7 @@ public VerbInfoDTO getVerbInfoDTOFromAPIData(String context, String apiVersion, apiInfoDTO = (APIInfoDTO) getResourceCache().get(cacheKey); } if (apiInfoDTO == null) { - apiInfoDTO = doGetAPIInfo(context, apiVersion); + apiInfoDTO = doGetAPIInfo(messageContext, context, apiVersion); if (isGatewayAPIResourceValidationEnabled) { getResourceCache().put(cacheKey, apiInfoDTO); } @@ -701,15 +708,53 @@ public VerbInfoDTO getVerbInfoDTOFromAPIData(String context, String apiVersion, @MethodStats - protected ArrayList getAllURITemplates(String context, String apiVersion) + protected ArrayList getAllURITemplates(MessageContext messageContext, String context, String apiVersion) throws APISecurityException { - return dataStore.getAllURITemplates(context, apiVersion); + if (uriTemplates == null) { + synchronized (this) { + if (uriTemplates == null) { + String swagger = (String) messageContext.getProperty(APIMgtGatewayConstants.OPEN_API_STRING); + if (swagger != null) { + APIDefinition oasParser; + try { + oasParser = OASParserUtil.getOASParser(swagger); + uriTemplates = new ArrayList<>(); + uriTemplates.addAll(oasParser.getURITemplates(swagger)); + return uriTemplates; + } catch (APIManagementException e) { + log.error("Error while parsing swagger content to get URI Templates", e); + } + } + uriTemplates = dataStore.getAllURITemplates(context, apiVersion); + } + } + } + return uriTemplates; } @MethodStats - protected ArrayList getAPIProductURITemplates(String context, String apiVersion) + protected ArrayList getAPIProductURITemplates(MessageContext messageContext, String context, String apiVersion) throws APISecurityException { - return dataStore.getAPIProductURITemplates(context, apiVersion); + if (uriTemplates == null) { + synchronized (this) { + if (uriTemplates == null) { + String swagger = (String) messageContext.getProperty(APIMgtGatewayConstants.OPEN_API_STRING); + if (swagger != null) { + APIDefinition oasParser; + try { + oasParser = OASParserUtil.getOASParser(swagger); + uriTemplates = new ArrayList<>(); + uriTemplates.addAll(oasParser.getURITemplates(swagger)); + return uriTemplates; + } catch (APIManagementException e) { + log.error("Error while parsing swagger content to get URI Templates", e); + } + } + uriTemplates = dataStore.getAPIProductURITemplates(context, apiVersion); + } + } + } + return uriTemplates; } protected void setGatewayAPIResourceValidationEnabled(boolean gatewayAPIResourceValidationEnabled) { diff --git a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/handlers/security/basicauth/BasicAuthCredentialValidator.java b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/handlers/security/basicauth/BasicAuthCredentialValidator.java index 832d484bbc3a..cd1fc26b8867 100644 --- a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/handlers/security/basicauth/BasicAuthCredentialValidator.java +++ b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/handlers/security/basicauth/BasicAuthCredentialValidator.java @@ -36,7 +36,6 @@ import org.wso2.carbon.apimgt.impl.caching.CacheProvider; import org.wso2.carbon.authenticator.stub.AuthenticationAdminStub; import org.wso2.carbon.authenticator.stub.LoginAuthenticationExceptionException; -import org.wso2.carbon.base.ServerConfiguration; import org.wso2.carbon.um.ws.api.stub.RemoteUserStoreManagerServiceStub; import org.wso2.carbon.um.ws.api.stub.RemoteUserStoreManagerServiceUserStoreExceptionException; import org.wso2.carbon.user.api.UserStoreException; @@ -55,7 +54,6 @@ import java.security.NoSuchAlgorithmException; import java.util.ArrayList; import java.util.HashMap; -import java.util.LinkedHashMap; /** * This class will validate the basic auth credentials. @@ -188,15 +186,7 @@ public boolean validateScopes(String username, OpenAPI openAPI, MessageContext s String resourceRoles = null; String resourceScope = OpenAPIUtils.getScopesOfResource(openAPI, synCtx); if (resourceScope != null) { - ArrayList apiScopes = OpenAPIUtils.getScopeToRoleMappingOfApi(openAPI, synCtx); - if (apiScopes != null) { - for (LinkedHashMap scope : apiScopes) { - if (resourceScope.equals(scope.get(APIConstants.SWAGGER_SCOPE_KEY))) { - resourceRoles = (String) scope.get(APIConstants.SWAGGER_ROLES); - break; - } - } - } + resourceRoles = OpenAPIUtils.getRolesOfScope(openAPI, synCtx, resourceScope); } if (StringUtils.isNotBlank(resourceRoles)) { diff --git a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/utils/OpenAPIUtils.java b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/utils/OpenAPIUtils.java index b6fb0e9285d6..95aab7aaa20f 100644 --- a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/utils/OpenAPIUtils.java +++ b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/main/java/org/wso2/carbon/apimgt/gateway/utils/OpenAPIUtils.java @@ -71,35 +71,71 @@ public static String getResourceAuthenticationScheme(OpenAPI openAPI, MessageCon public static String getScopesOfResource(OpenAPI openAPI, MessageContext synCtx) { Map vendorExtensions = getPathItemExtensions(synCtx, openAPI); if (vendorExtensions != null) { - return (String) vendorExtensions.get(APIConstants.SWAGGER_X_SCOPE); + String resourceScope = (String) vendorExtensions.get(APIConstants.SWAGGER_X_SCOPE); + if (resourceScope == null) { + // If x-scope not found in swagger, check for the scopes in security + ArrayList securityScopes = getPathItemSecurityScopes(synCtx, openAPI); + if (securityScopes == null || securityScopes.isEmpty()) { + return null; + } else { + return securityScopes.get(0); + } + } else { + return resourceScope; + } } return null; } /** - * Return the scope-role mapping of an API + * Return the roles of a given scope attached to a resource using the API swagger. * * @param openAPI OpenAPI of the API * @param synCtx The message containing resource request - * @return the scope-role mapping + * @param resourceScope The scope of the resource + * @return the roles of the scope in the comma separated format */ - public static ArrayList getScopeToRoleMappingOfApi(OpenAPI openAPI, MessageContext synCtx) { - Map vendorExtensions = getPathItemExtensions(synCtx, openAPI); + public static String getRolesOfScope(OpenAPI openAPI, MessageContext synCtx, String resourceScope) { + String resourceRoles = null; + Map vendorExtensions = getPathItemExtensions(synCtx, openAPI); if (vendorExtensions != null) { - String resourceScope = getScopesOfResource(openAPI, synCtx); if (StringUtils.isNotBlank(resourceScope)) { - LinkedHashMap swaggerWSO2Security = (LinkedHashMap) openAPI.getExtensions() - .get(APIConstants.SWAGGER_X_WSO2_SECURITY); - if (swaggerWSO2Security != null) { - LinkedHashMap swaggerObjectAPIM = (LinkedHashMap) swaggerWSO2Security - .get(APIConstants.SWAGGER_OBJECT_NAME_APIM); - if (swaggerObjectAPIM != null) { - return (ArrayList) swaggerObjectAPIM.get(APIConstants.SWAGGER_X_WSO2_SCOPES); + if (openAPI.getExtensions() != null && + openAPI.getExtensions().get(APIConstants.SWAGGER_X_WSO2_SECURITY) != null) { + LinkedHashMap swaggerWSO2Security = (LinkedHashMap) openAPI.getExtensions() + .get(APIConstants.SWAGGER_X_WSO2_SECURITY); + if (swaggerWSO2Security != null && + swaggerWSO2Security.get(APIConstants.SWAGGER_OBJECT_NAME_APIM) != null) { + LinkedHashMap swaggerObjectAPIM = (LinkedHashMap) swaggerWSO2Security + .get(APIConstants.SWAGGER_OBJECT_NAME_APIM); + if (swaggerObjectAPIM != null && swaggerObjectAPIM.get(APIConstants.SWAGGER_X_WSO2_SCOPES) != null) { + ArrayList apiScopes = + (ArrayList) swaggerObjectAPIM.get(APIConstants.SWAGGER_X_WSO2_SCOPES); + for (LinkedHashMap scope: apiScopes) { + if (resourceScope.equals(scope.get(APIConstants.SWAGGER_SCOPE_KEY))) { + resourceRoles = (String) scope.get(APIConstants.SWAGGER_ROLES); + break; + } + } + } } } } } + + if (resourceRoles == null) { + Map extensions = openAPI.getComponents().getSecuritySchemes() + .get(APIConstants.SWAGGER_APIM_DEFAULT_SECURITY).getExtensions(); + if (extensions != null && extensions.get(APIConstants.SWAGGER_X_SCOPES_BINDINGS) != null) { + LinkedHashMap scopeBindings = + (LinkedHashMap) extensions.get(APIConstants.SWAGGER_X_SCOPES_BINDINGS); + if (scopeBindings != null) { + return (String) scopeBindings.get(resourceScope); + } + } + } + return null; } @@ -151,4 +187,41 @@ private static Map getPathItemExtensions(MessageContext synCtx, } return null; } + + private static ArrayList getPathItemSecurityScopes(MessageContext synCtx, OpenAPI openAPI) { + if (openAPI != null) { + String apiElectedResource = (String) synCtx.getProperty(APIConstants.API_ELECTED_RESOURCE); + org.apache.axis2.context.MessageContext axis2MessageContext = + ((Axis2MessageContext) synCtx).getAxis2MessageContext(); + String httpMethod = (String) axis2MessageContext.getProperty(APIConstants.DigestAuthConstants.HTTP_METHOD); + PathItem path = openAPI.getPaths().get(apiElectedResource); + + if (path != null) { + switch (httpMethod) { + case APIConstants.HTTP_GET: + return (ArrayList) path.getGet().getSecurity().get(0) + .get(APIConstants.SWAGGER_APIM_DEFAULT_SECURITY); + case APIConstants.HTTP_POST: + return (ArrayList) path.getPost().getSecurity().get(0) + .get(APIConstants.SWAGGER_APIM_DEFAULT_SECURITY); + case APIConstants.HTTP_PUT: + return (ArrayList) path.getPut().getSecurity().get(0) + .get(APIConstants.SWAGGER_APIM_DEFAULT_SECURITY); + case APIConstants.HTTP_DELETE: + return (ArrayList) path.getDelete().getSecurity().get(0) + .get(APIConstants.SWAGGER_APIM_DEFAULT_SECURITY); + case APIConstants.HTTP_HEAD: + return (ArrayList) path.getHead().getSecurity().get(0) + .get(APIConstants.SWAGGER_APIM_DEFAULT_SECURITY); + case APIConstants.HTTP_OPTIONS: + return (ArrayList) path.getOptions().getSecurity().get(0) + .get(APIConstants.SWAGGER_APIM_DEFAULT_SECURITY); + case APIConstants.HTTP_PATCH: + return (ArrayList) path.getPatch().getSecurity().get(0) + .get(APIConstants.SWAGGER_APIM_DEFAULT_SECURITY); + } + } + } + return null; + } } diff --git a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/handlers/security/APIKeyValidatorTestCase.java b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/handlers/security/APIKeyValidatorTestCase.java index fdebbc2d4143..a7ffbac369e8 100644 --- a/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/handlers/security/APIKeyValidatorTestCase.java +++ b/components/apimgt/org.wso2.carbon.apimgt.gateway/src/test/java/org/wso2/carbon/apimgt/gateway/handlers/security/APIKeyValidatorTestCase.java @@ -416,7 +416,8 @@ public void testGetVerbInfoDTOFromAPIData() throws Exception { Mockito.when(APIUtil.getAPIInfoDTOCacheKey("", "1.0")).thenReturn("abc"); Mockito.when((VerbInfoDTO) CacheProvider.getResourceCache().get("abc")) .thenReturn(verbInfoDTO1); - VerbInfoDTO verbInfoDTOFromAPIData = apiKeyValidator.getVerbInfoDTOFromAPIData(context, apiVersion, + MessageContext messageContext = Mockito.mock(MessageContext.class); + VerbInfoDTO verbInfoDTOFromAPIData = apiKeyValidator.getVerbInfoDTOFromAPIData(messageContext, context, apiVersion, requestPath, httpMethod); Assert.assertEquals("", verbDTO, verbInfoDTOFromAPIData); @@ -460,8 +461,10 @@ public void testGetVerbInfoDTOFromAPIDataWithRequestPath() throws Exception { Mockito.when(APIUtil.getAPIInfoDTOCacheKey("", "1.0")).thenReturn("abc"); Mockito.when((VerbInfoDTO) CacheProvider.getResourceCache().get("abc")) .thenReturn(verbInfoDTO1); + MessageContext messageContext = Mockito.mock(MessageContext.class); + Assert.assertEquals("", verbDTO, - apiKeyValidator.getVerbInfoDTOFromAPIData(context, apiVersion, requestPath, httpMethod)); + apiKeyValidator.getVerbInfoDTOFromAPIData(messageContext, context, apiVersion, requestPath, httpMethod)); } @@ -495,8 +498,9 @@ public void testGetVerbInfoDTOFromAPIDataWithInvalidRequestPath() throws Excepti PowerMockito.when(cacheProvider.getDefaultCacheTimeout()).thenReturn((long) 900); Mockito.when(CacheProvider.getResourceCache()).thenReturn(cache); + MessageContext messageContext = Mockito.mock(MessageContext.class); Assert.assertEquals("", null, - apiKeyValidator.getVerbInfoDTOFromAPIData(context, apiVersion, requestPath, httpMethod)); + apiKeyValidator.getVerbInfoDTOFromAPIData(messageContext, context, apiVersion, requestPath, httpMethod)); } @@ -621,7 +625,7 @@ protected Cache getCache(String cacheManagerName, String cacheName, long modifie } @Override - protected ArrayList getAllURITemplates(String context, String apiVersion) throws + protected ArrayList getAllURITemplates(MessageContext messageContext, String context, String apiVersion) throws APISecurityException { return urlTemplates; }