Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Choreo] Support passing API-Key for WebSocket Requests using sec-websocket-protocol Header #3562

Merged
merged 11 commits into from
Aug 9, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ public class RequestContext {
// For example, reason for denying a request
private String extAuthDetails;

private Map<String, String> responseHeadersToAddMap;

/**
* The dynamic metadata sent from enforcer are stored in this metadata map.
* @return dynamic metadata map
Expand Down Expand Up @@ -358,6 +360,20 @@ public void setExtAuthDetails(String extAuthDetails) {
this.extAuthDetails = extAuthDetails;
}

/**
* Specifies if headers needs to be added for the response based on request
*
* @return response headers to add map
*/
public Map<String, String> getResponseHeadersToAddMap() {
return responseHeadersToAddMap;
}

public void setResponseHeadersToAddMap(Map<String, String> responseHeadersToAddMap) {
this.responseHeadersToAddMap = responseHeadersToAddMap;
}


/**
* Implements builder pattern to build an {@link RequestContext} object.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public class ResponseObject {
private String requestPath;
private String apiUuid;
private String extAuthDetails;
private Map<String, String> responseHeadersToAddMap;

public ArrayList<String> getRemoveHeaderMap() {
return removeHeaderMap;
Expand All @@ -48,6 +49,14 @@ public void setRemoveHeaderMap(ArrayList<String> removeHeaderMap) {
this.removeHeaderMap = removeHeaderMap;
}

public Map<String, String> getResponseHeadersToAddMap() {
return responseHeadersToAddMap;
}

public void setResponseHeadersToAddMap(Map<String, String> responseHeadersToAddMap) {
this.responseHeadersToAddMap = responseHeadersToAddMap;
}

public ResponseObject(String correlationID) {
this.correlationID = correlationID;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,17 @@ public ResponseObject process(RequestContext requestContext) {
Utils.populateRemoveAndProtectedHeaders(requestContext);

if (executeFilterChain(requestContext)) {
responseObject.setRemoveHeaderMap(requestContext.getRemoveHeaders());
responseObject.setQueryParamsToRemove(requestContext.getQueryParamsToRemove());
responseObject.setQueryParamMap(requestContext.getQueryParameters());
responseObject.setStatusCode(APIConstants.StatusCodes.OK.getCode());
if (requestContext.getAddHeaders() != null && requestContext.getAddHeaders().size() > 0) {
responseObject.setHeaderMap(requestContext.getAddHeaders());
}
if (requestContext.getResponseHeadersToAddMap() != null
&& requestContext.getResponseHeadersToAddMap().size() > 0) {
responseObject.setResponseHeadersToAddMap(requestContext.getResponseHeadersToAddMap());
}
logger.debug("ext_authz metadata: {}", requestContext.getMetadataMap());
responseObject.setMetaDataMap(requestContext.getMetadataMap());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,7 @@ public class Constants {
public static final String PROP_CON_FACTORY = "connectionfactory.TopicConnectionFactory";
public static final String DEFAULT_DESTINATION_TYPE = "Topic";
public static final String DEFAULT_CON_FACTORY_JNDI_NAME = "TopicConnectionFactory";

// keyword to identify API-Key sent in sec-websocket-protocol header
public static final String WS_API_KEY_IDENTIFIER = "choreo-internal-API-Key";
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ public class HttpConstants {
public static final String X_REQUEST_ID_HEADER = "x-request-id";
public static final String APPLICATION_JSON = "application/json";
public static final String BASIC_LOWER = "basic";
public static final String WEBSOCKET_PROTOCOL_HEADER = "sec-websocket-protocol";
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ private CheckResponse buildResponse(CheckRequest request, ResponseObject respons
.build();
} else {
OkHttpResponse.Builder okResponseBuilder = OkHttpResponse.newBuilder();

// If the user is sending the APIKey credentials within query parameters, those query parameters should
// not be sent to the backend. Hence, the :path header needs to be constructed again removing the apiKey
// query parameter. In this scenario, apiKey query parameter is sent within the property called
Expand All @@ -175,6 +175,16 @@ private CheckResponse buildResponse(CheckRequest request, ResponseObject respons
}
);
}

if (responseObject.getResponseHeadersToAddMap() != null) {
responseObject.getResponseHeadersToAddMap().forEach((key, value) -> {
HeaderValueOption headerValueOption = HeaderValueOption.newBuilder()
.setHeader(HeaderValue.newBuilder().setKey(key).setValue(value).build())
.build();
okResponseBuilder.addResponseHeadersToAdd(headerValueOption);
}
);
}
okResponseBuilder.addAllHeadersToRemove(responseObject.getRemoveHeaderMap());
if (responseObject.getMetaDataMap() != null) {
responseObject.getMetaDataMap().forEach((key, value) ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import org.wso2.choreo.connect.enforcer.config.EnforcerConfig;
import org.wso2.choreo.connect.enforcer.constants.APIConstants;
import org.wso2.choreo.connect.enforcer.constants.APISecurityConstants;
import org.wso2.choreo.connect.enforcer.constants.Constants;
import org.wso2.choreo.connect.enforcer.constants.HttpConstants;
import org.wso2.choreo.connect.enforcer.dto.APIKeyValidationInfoDTO;
import org.wso2.choreo.connect.enforcer.dto.JWTTokenPayloadInfo;
import org.wso2.choreo.connect.enforcer.exception.APISecurityException;
Expand All @@ -47,6 +49,10 @@
import org.wso2.choreo.connect.enforcer.util.FilterUtils;

import java.text.ParseException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Implements the authenticator interface to authenticate request using an Internal Key.
Expand All @@ -69,8 +75,16 @@ public InternalAPIKeyAuthenticator(String securityParam) {

@Override
public boolean canAuthenticate(RequestContext requestContext) {
String apiType = requestContext.getMatchedAPI().getApiType();
String internalKey = requestContext.getHeaders().get(
ConfigHolder.getInstance().getConfig().getAuthHeader().getTestConsoleHeaderName().toLowerCase());
if (apiType.equalsIgnoreCase("WS")) {
if (internalKey == null) {
internalKey = extractInternalKeyInWSProtocolHeader(requestContext);
}
addWSProtocolResponseHeaderIfRequired(requestContext);
}

return isAPIKey(internalKey);
}

Expand Down Expand Up @@ -281,13 +295,65 @@ public String getName() {
}

private String extractInternalKey(RequestContext requestContext) {
String internalKey = requestContext.getHeaders().get(securityParam);
String internalKey;
internalKey = requestContext.getHeaders().get(securityParam);
if (internalKey != null) {
return internalKey.trim();
}
if (requestContext.getMatchedAPI().getApiType().equalsIgnoreCase("WS")) {
internalKey = extractInternalKeyInWSProtocolHeader(requestContext);
if (internalKey != null && !internalKey.isEmpty()) {
String protocols = getProtocolsToSetInRequestHeaders(requestContext);
if (protocols != null) {
requestContext.addOrModifyHeaders(HttpConstants.WEBSOCKET_PROTOCOL_HEADER, protocols);
}
return internalKey.trim();
}
}
return null;
}

public String extractInternalKeyInWSProtocolHeader(RequestContext requestContext) {
String protocolHeader = requestContext.getHeaders().get(
HttpConstants.WEBSOCKET_PROTOCOL_HEADER);
if (protocolHeader != null) {
String[] secProtocolHeaderValues = protocolHeader.split(",");
if (secProtocolHeaderValues.length > 1 && secProtocolHeaderValues[0].equals(
Constants.WS_API_KEY_IDENTIFIER)) {
return secProtocolHeaderValues[1].trim();
}
}
return "";
}

public String getProtocolsToSetInRequestHeaders(RequestContext requestContext) {
String[] secProtocolHeaderValues = requestContext.getHeaders().get(
HttpConstants.WEBSOCKET_PROTOCOL_HEADER).split(",");
if (secProtocolHeaderValues.length > 2) {
return Arrays.stream(secProtocolHeaderValues, 2, secProtocolHeaderValues.length)
.collect(Collectors.joining(",")).trim();
}
return null;
}

public void addWSProtocolResponseHeaderIfRequired(RequestContext requestContext) {
String secProtocolHeader = requestContext.getHeaders().get(HttpConstants.WEBSOCKET_PROTOCOL_HEADER);
if (secProtocolHeader != null) {
String[] secProtocolHeaderValues = secProtocolHeader.split(",");
if (secProtocolHeaderValues[0].equals(Constants.WS_API_KEY_IDENTIFIER) &&
secProtocolHeaderValues.length == 2) {
Map<String, String> responseHeadersToAddMap = requestContext.getResponseHeadersToAddMap();

if (responseHeadersToAddMap == null) {
responseHeadersToAddMap = new HashMap<>();
}
responseHeadersToAddMap.put(
HttpConstants.WEBSOCKET_PROTOCOL_HEADER, Constants.WS_API_KEY_IDENTIFIER);
requestContext.setResponseHeadersToAddMap(responseHeadersToAddMap);
}
}
}

@Override
public int getPriority() {
return -10;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/*
* Copyright (c) 2024, WSO2 LLC. (http://www.wso2.org) All Rights Reserved.
*
* WSO2 Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.wso2.choreo.connect.enforcer.security.jwt;

import java.util.HashMap;
import java.util.Map;

import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mockito;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PowerMockIgnore;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import org.wso2.carbon.apimgt.common.gateway.dto.JWTConfigurationDto;
import org.wso2.choreo.connect.enforcer.commons.model.APIConfig;
import org.wso2.choreo.connect.enforcer.commons.model.RequestContext;
import org.wso2.choreo.connect.enforcer.config.ConfigHolder;
import org.wso2.choreo.connect.enforcer.config.EnforcerConfig;
import org.wso2.choreo.connect.enforcer.config.dto.CacheDto;
@RunWith(PowerMockRunner.class)
@PrepareForTest({ConfigHolder.class})
@PowerMockIgnore("javax.management.*")
public class InternalAPIKeyAuthenticatorTest {

@Test
public void extractInternalKeyInWSProtocolHeaderTest() {
PowerMockito.mockStatic(ConfigHolder.class);
ConfigHolder configHolder = Mockito.mock(ConfigHolder.class);
EnforcerConfig enforcerConfig = Mockito.mock(EnforcerConfig.class);
CacheDto cacheDto = Mockito.mock(CacheDto.class);
Mockito.when(cacheDto.isEnabled()).thenReturn(true);
Mockito.when(enforcerConfig.getCacheDto()).thenReturn(cacheDto);
JWTConfigurationDto jwtConfigurationDto = Mockito.mock(JWTConfigurationDto.class);
Mockito.when(jwtConfigurationDto.isEnabled()).thenReturn(false);
Mockito.when(enforcerConfig.getJwtConfigurationDto()).thenReturn(jwtConfigurationDto);
Mockito.when(configHolder.getConfig()).thenReturn(enforcerConfig);
Mockito.when(ConfigHolder.getInstance()).thenReturn(configHolder);

String securityParam = "API-Key";

String mockToken = "eyJraWQiOiJnYXRld2F5XUlMyNTYifQlzaGVyXC92MlwvYXBpc1wvaW50ZXJuYlzaGVyXC92XBpc1wvaW50ZXJuY." +
"eyJzdWIiOiJhMzllYGV2OjQ0M1wvYXBpXC9hbVwvcHVibGlzaGVyXC92MlwvYXBpc1wvaW50ZXJuYWwta2V5Iiwia2V5dHlwZcl." +
"cnZpY2VcL3YxLjAiLCJwdWJsaXNoZXIiOiJjaG9yZW9fZGV2X2FwaW1fYWRtaW4iLCJ2ZXJzaW9uIjoidj7MIXRnS-2UWHdrmd7";

String secWebsocketProtocolHeader = "sec-websocket-protocol";

// Test case to test for an Upgrade request sent from the choreo console
// The token will be set to the sec-websocket-protocol header with choreo-internal-API-Key keyword
// the value after choreo-internal-API-Key will be the token
RequestContext.Builder builder = new RequestContext.Builder("/pets");
builder.matchedAPI(new APIConfig.Builder("Petstore")
.basePath("/choreo")
.apiType("WS")
.build());
Map<String, String> headersMap = new HashMap<>();
headersMap.put(
secWebsocketProtocolHeader,
"choreo-internal-API-Key," + mockToken);
builder.headers(headersMap);
RequestContext requestContext = builder.build();
InternalAPIKeyAuthenticator internalAPIKeyAuthenticator = new InternalAPIKeyAuthenticator(securityParam);
Assert.assertEquals(internalAPIKeyAuthenticator.extractInternalKeyInWSProtocolHeader(requestContext), mockToken);

// Test case to test for an Upgrade request sent from a client with api-key
RequestContext.Builder builder2 = new RequestContext.Builder("/pets");
builder2.matchedAPI(new APIConfig.Builder("Petstore")
.basePath("/choreo")
.apiType("WS")
.build());
Map<String, String> headersMap2 = new HashMap<>();
headersMap2.put(securityParam, mockToken);
builder2.headers(headersMap2);
RequestContext requestContext2 = builder2.build();
Assert.assertEquals(internalAPIKeyAuthenticator.extractInternalKeyInWSProtocolHeader(requestContext2), "");

}

@Test
public void getProtocolsToSetInRequestHeadersTest() {
PowerMockito.mockStatic(ConfigHolder.class);
ConfigHolder configHolder = Mockito.mock(ConfigHolder.class);
EnforcerConfig enforcerConfig = Mockito.mock(EnforcerConfig.class);
CacheDto cacheDto = Mockito.mock(CacheDto.class);
Mockito.when(cacheDto.isEnabled()).thenReturn(true);
Mockito.when(enforcerConfig.getCacheDto()).thenReturn(cacheDto);
JWTConfigurationDto jwtConfigurationDto = Mockito.mock(JWTConfigurationDto.class);
Mockito.when(jwtConfigurationDto.isEnabled()).thenReturn(false);
Mockito.when(enforcerConfig.getJwtConfigurationDto()).thenReturn(jwtConfigurationDto);
Mockito.when(configHolder.getConfig()).thenReturn(enforcerConfig);
Mockito.when(ConfigHolder.getInstance()).thenReturn(configHolder);

String securityParam = "API-Key";

String secWebsocketProtocolHeader = "sec-websocket-protocol";

String mockToken = "eyJraWQiOiJnYXRld2F5XUlMyNTYifQlzaGVyXC92MlwvYXBpc1wvaW50ZXJuYlzaGVyXC92XBpc1wvaW50ZXJuY." +
"eyJzdWIiOiJhMzllYGV2OjQ0M1wvYXBpXC9hbVwvcHVibGlzaGVyXC92MlwvYXBpc1wvaW50ZXJuYWwta2V5Iiwia2V5dHlwZcl." +
"cnZpY2VcL3YxLjAiLCJwdWJsaXNoZXIiOiJjaG9yZW9fZGV2X2FwaW1fYWRtaW4iLCJ2ZXJzaW9uIjoidj7MIXRnS-2UWHdrmd7";

RequestContext.Builder builder = new RequestContext.Builder("/pets");
builder.matchedAPI(new APIConfig.Builder("Petstore")
.basePath("/choreo")
.apiType("WS")
.build());
Map<String, String> headersMap = new HashMap<>();
headersMap.put(
secWebsocketProtocolHeader,
"choreo-internal-API-Key, " + mockToken + ", " + "chat, bar");
builder.headers(headersMap);
RequestContext requestContext = builder.build();
InternalAPIKeyAuthenticator internalAPIKeyAuthenticator = new InternalAPIKeyAuthenticator(securityParam);
Assert.assertEquals(internalAPIKeyAuthenticator.getProtocolsToSetInRequestHeaders(requestContext), "chat, bar");

}
}
Loading