Skip to content

Commit

Permalink
Remove X-Forwarded-For from RemoteAddressFilter
Browse files Browse the repository at this point in the history
  • Loading branch information
eager-signal committed Apr 11, 2024
1 parent 39fd955 commit 05a9249
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 161 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,6 @@ public void run(WhisperServerConfiguration config, Environment environment) thro

MetricsUtil.configureRegistries(config, environment, dynamicConfigurationManager);

final boolean useRemoteAddress = Optional.ofNullable(
System.getenv("SIGNAL_USE_REMOTE_ADDRESS"))
.isPresent();

if (config.getServerFactory() instanceof DefaultServerFactory defaultServerFactory) {
defaultServerFactory.getApplicationConnectors()
.forEach(connectorFactory -> {
Expand Down Expand Up @@ -823,7 +819,7 @@ protected void configureServer(final ServerBuilder<?> serverBuilder) {

final List<Filter> filters = new ArrayList<>();
filters.add(remoteDeprecationFilter);
filters.add(new RemoteAddressFilter(useRemoteAddress));
filters.add(new RemoteAddressFilter());

for (Filter filter : filters) {
environment.servlets()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,28 @@
package org.whispersystems.textsecuregcm.filters;

import java.io.IOException;
import java.util.Optional;
import javax.annotation.Nullable;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.util.HttpServletRequestUtil;

/**
* Sets a {@link HttpServletRequest} attribute (that will also be available as a
* {@link javax.ws.rs.container.ContainerRequestContext} property) with the remote address of the connection, using
* either the {@link HttpServletRequest#getRemoteAddr()} or the {@code X-Forwarded-For} HTTP header value, depending on
* whether {@link #preferRemoteAddress} is {@code true}.
* {@link HttpServletRequest#getRemoteAddr()}.
*/
public class RemoteAddressFilter implements Filter {

public static final String REMOTE_ADDRESS_ATTRIBUTE_NAME = RemoteAddressFilter.class.getName() + ".remoteAddress";
private static final Logger logger = LoggerFactory.getLogger(RemoteAddressFilter.class);

private final boolean preferRemoteAddress;


public RemoteAddressFilter(boolean preferRemoteAddress) {
this.preferRemoteAddress = preferRemoteAddress;
public RemoteAddressFilter() {
}

@Override
Expand All @@ -43,16 +36,7 @@ public void doFilter(final ServletRequest request, final ServletResponse respons

if (request instanceof HttpServletRequest httpServletRequest) {

final String remoteAddress;

if (preferRemoteAddress) {
remoteAddress = HttpServletRequestUtil.getRemoteAddress(httpServletRequest);
} else {
final String forwardedFor = httpServletRequest.getHeader(com.google.common.net.HttpHeaders.X_FORWARDED_FOR);
remoteAddress = getMostRecentProxy(forwardedFor)
.orElseGet(() -> HttpServletRequestUtil.getRemoteAddress(httpServletRequest));
}

final String remoteAddress = HttpServletRequestUtil.getRemoteAddress(httpServletRequest);
request.setAttribute(REMOTE_ADDRESS_ATTRIBUTE_NAME, remoteAddress);

} else {
Expand All @@ -62,23 +46,4 @@ public void doFilter(final ServletRequest request, final ServletResponse respons
chain.doFilter(request, response);
}

/**
* Returns the most recent proxy in a chain described by an {@code X-Forwarded-For} header.
*
* @param forwardedFor the value of an X-Forwarded-For header
* @return the IP address of the most recent proxy in the forwarding chain, or empty if none was found or
* {@code forwardedFor} was null
* @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For">X-Forwarded-For - HTTP |
* MDN</a>
*/
public static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
return Optional.ofNullable(forwardedFor)
.map(ff -> {
final int idx = forwardedFor.lastIndexOf(',') + 1;
return idx < forwardedFor.length()
? forwardedFor.substring(idx).trim()
: null;
})
.filter(StringUtils::isNotBlank);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.signal.libsignal.protocol.ecc.ECKeyPair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;

/**
Expand Down Expand Up @@ -113,7 +114,7 @@ private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerCont
if (trustForwardedFor && handshakeCompleteEvent.requestHeaders().contains(FORWARDED_FOR_HEADER)) {
final String forwardedFor = handshakeCompleteEvent.requestHeaders().get(FORWARDED_FOR_HEADER);

return RemoteAddressFilter.getMostRecentProxy(forwardedFor).map(mostRecentProxy -> {
return getMostRecentProxy(forwardedFor).map(mostRecentProxy -> {
try {
return InetAddresses.forString(mostRecentProxy);
} catch (final IllegalArgumentException e) {
Expand All @@ -131,4 +132,25 @@ private Optional<InetAddress> getPreferredRemoteAddress(final ChannelHandlerCont
}
}
}

/**
* Returns the most recent proxy in a chain described by an {@code X-Forwarded-For} header.
*
* @param forwardedFor the value of an X-Forwarded-For header
* @return the IP address of the most recent proxy in the forwarding chain, or empty if none was found or
* {@code forwardedFor} was null
* @see <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For">X-Forwarded-For - HTTP |
* MDN</a>
*/
@VisibleForTesting
static Optional<String> getMostRecentProxy(@Nullable final String forwardedFor) {
return Optional.ofNullable(forwardedFor)
.map(ff -> {
final int idx = forwardedFor.lastIndexOf(',') + 1;
return idx < forwardedFor.length()
? forwardedFor.substring(idx).trim()
: null;
})
.filter(StringUtils::isNotBlank);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.when;
import static org.whispersystems.textsecuregcm.filters.RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME;

import io.dropwizard.core.Application;
Expand Down Expand Up @@ -77,10 +75,10 @@ public void run(final Configuration configuration, final Environment environment

environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest ->
ReusableAuth.authenticated(mock(AuthenticatedAccount.class), PrincipalSupplier.forImmutablePrincipal()));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ public void run(final Configuration configuration, final Environment environment

environment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(testController);
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.setAuthenticator(upgradeRequest -> ReusableAuth.authenticated(ACCOUNT, PRINCIPAL_SUPPLIER));

webSocketEnvironment.jersey().property(ServerProperties.UNWRAP_COMPLETION_STAGE_IN_WRITER_ENABLE, Boolean.TRUE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ public void run(final Configuration configuration, final Environment environment
environment.jersey().register(testController);
webSocketEnvironment.jersey().register(testController);
environment.servlets()
.addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addFilter("RemoteAddressFilter", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");
webSocketEnvironment.jersey().register(new RemoteAddressFilter(true));
webSocketEnvironment.jersey().register(new RemoteAddressFilter());
webSocketEnvironment.jersey()
.register(new WebsocketRefreshApplicationEventListener(ACCOUNTS_MANAGER, CLIENT_PRESENCE));
environment.jersey()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

import com.google.common.net.HttpHeaders;
import io.dropwizard.core.Application;
import io.dropwizard.core.Configuration;
import io.dropwizard.core.setup.Environment;
Expand Down Expand Up @@ -39,15 +38,13 @@
import org.eclipse.jetty.util.HostPort;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.whispersystems.textsecuregcm.util.SystemMapper;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
Expand All @@ -62,7 +59,6 @@ class RemoteAddressFilterIntegrationTest {

private static final String WEBSOCKET_PREFIX = "/websocket";
private static final String REMOTE_ADDRESS_PATH = "/remoteAddress";
private static final String FORWARDED_FOR_PATH = "/forwardedFor";
private static final String WS_REQUEST_PATH = "/wsRequest";

// The Grizzly test container does not match the Jetty container used in real deployments, and JettyTestContainerFactory
Expand Down Expand Up @@ -92,22 +88,6 @@ void testRemoteAddress(String ip) throws Exception {

assertEquals(ip, response.remoteAddress());
}

@ParameterizedTest
@CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1",
"127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t")
void testForwardedFor(String forwardedFor, String expectedIp) {

Client client = EXTENSION.client();

final RemoteAddressFilterIntegrationTest.TestResponse response = client.target(
String.format("http://localhost:%d%s", EXTENSION.getLocalPort(), FORWARDED_FOR_PATH))
.request("application/json")
.header(HttpHeaders.X_FORWARDED_FOR, forwardedFor)
.get(RemoteAddressFilterIntegrationTest.TestResponse.class);

assertEquals(expectedIp, response.remoteAddress());
}
}

@Nested
Expand Down Expand Up @@ -149,28 +129,6 @@ void testRemoteAddress(String ip) throws Exception {

assertEquals(ip, response.remoteAddress());
}

@ParameterizedTest
@CsvSource(value = {"127.0.0.1, 192.168.1.1 \t 192.168.1.1",
"127.0.0.1, fe80:1:1:1:1:1:1:1 \t fe80:1:1:1:1:1:1:1"}, delimiterString = "\t")
void testForwardedFor(String forwardedFor, String expectedIp) throws Exception {

final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader(HttpHeaders.X_FORWARDED_FOR, forwardedFor);

final CompletableFuture<byte[]> responseFuture = new CompletableFuture<>();

client.connect(new ClientEndpoint(WS_REQUEST_PATH, responseFuture),
URI.create(
String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), WEBSOCKET_PREFIX + FORWARDED_FOR_PATH)),
upgradeRequest);

final byte[] responseBytes = responseFuture.get(1, TimeUnit.SECONDS);

final TestResponse response = SystemMapper.jsonMapper().readValue(responseBytes, TestResponse.class);

assertEquals(expectedIp, response.remoteAddress());
}
}

private static class ClientEndpoint implements WebSocketListener {
Expand Down Expand Up @@ -233,11 +191,6 @@ public static class TestRemoteAddressController extends TestController {

}

@Path(FORWARDED_FOR_PATH)
public static class TestForwardedForController extends TestController {

}

@Path(WS_REQUEST_PATH)
public static class TestWebSocketController extends TestController {

Expand All @@ -253,17 +206,11 @@ public static class TestApplication extends Application<Configuration> {
public void run(final Configuration configuration,
final Environment environment) throws Exception {

// 2 filters, to cover useRemoteAddress = {true, false}
// each has explicit (not wildcard) path matching
environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter(true))
environment.servlets().addFilter("RemoteAddressFilterRemoteAddress", new RemoteAddressFilter())
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, REMOTE_ADDRESS_PATH,
WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);
environment.servlets().addFilter("RemoteAddressFilterForwardedFor", new RemoteAddressFilter(false))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, FORWARDED_FOR_PATH,
WEBSOCKET_PREFIX + FORWARDED_FOR_PATH);

environment.jersey().register(new TestRemoteAddressController());
environment.jersey().register(new TestForwardedForController());

// WebSocket set up
final WebSocketConfiguration webSocketConfiguration = new WebSocketConfiguration();
Expand All @@ -279,9 +226,6 @@ public void run(final Configuration configuration,
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);

// 2 servlets, because the filter only runs for the Upgrade request
environment.servlets().addServlet("WebSocketForwardedFor", webSocketServlet)
.addMapping(WEBSOCKET_PREFIX + FORWARDED_FOR_PATH);
environment.servlets().addServlet("WebSocketRemoteAddress", webSocketServlet)
.addMapping(WEBSOCKET_PREFIX + REMOTE_ADDRESS_PATH);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,17 @@

package org.whispersystems.textsecuregcm.filters;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.params.provider.Arguments.arguments;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.net.HttpHeaders;
import java.util.Optional;
import java.util.stream.Stream;
import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.MethodSource;

class RemoteAddressFilterTest {

Expand All @@ -36,7 +29,7 @@ void testGetRemoteAddress(final String remoteAddr, final String expectedRemoteAd
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getRemoteAddr()).thenReturn(remoteAddr);

final RemoteAddressFilter filter = new RemoteAddressFilter(true);
final RemoteAddressFilter filter = new RemoteAddressFilter();

final FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain);
Expand All @@ -45,41 +38,4 @@ void testGetRemoteAddress(final String remoteAddr, final String expectedRemoteAd
verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
}

@ParameterizedTest
@CsvSource(value = {
"192.168.1.1, 127.0.0.1 \t 127.0.0.1",
"192.168.1.1, 0:0:0:0:0:0:0:1 \t 0:0:0:0:0:0:0:1"
}, delimiterString = "\t")
void testGetRemoteAddressFromHeader(final String forwardedFor, final String expectedRemoteAddr) throws Exception {
final HttpServletRequest httpServletRequest = mock(HttpServletRequest.class);
when(httpServletRequest.getHeader(HttpHeaders.X_FORWARDED_FOR)).thenReturn(forwardedFor);

final RemoteAddressFilter filter = new RemoteAddressFilter(false);

final FilterChain filterChain = mock(FilterChain.class);
filter.doFilter(httpServletRequest, mock(ServletResponse.class), filterChain);

verify(httpServletRequest).setAttribute(RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME, expectedRemoteAddr);
verify(filterChain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
}

@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
@ParameterizedTest
@MethodSource("argumentsForGetMostRecentProxy")
void getMostRecentProxy(final String forwardedFor, final Optional<String> expectedMostRecentProxy) {
assertEquals(expectedMostRecentProxy, RemoteAddressFilter.getMostRecentProxy(forwardedFor));
}

private static Stream<Arguments> argumentsForGetMostRecentProxy() {
return Stream.of(
arguments(null, Optional.empty()),
arguments("", Optional.empty()),
arguments(" ", Optional.empty()),
arguments("203.0.113.195,", Optional.empty()),
arguments("203.0.113.195, ", Optional.empty()),
arguments("203.0.113.195", Optional.of("203.0.113.195")),
arguments("203.0.113.195, 70.41.3.18, 150.172.238.178", Optional.of("150.172.238.178"))
);
}

}
Loading

0 comments on commit 05a9249

Please sign in to comment.