diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index d5d1fb00f..aeb517396 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -37,6 +37,7 @@ import java.util.List; import java.util.Optional; import java.util.ServiceLoader; +import java.util.Set; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ExecutorService; @@ -97,8 +98,8 @@ import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3; import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV4; -import org.whispersystems.textsecuregcm.controllers.CallRoutingController; import org.whispersystems.textsecuregcm.controllers.CallLinkController; +import org.whispersystems.textsecuregcm.controllers.CallRoutingController; import org.whispersystems.textsecuregcm.controllers.CertificateController; import org.whispersystems.textsecuregcm.controllers.ChallengeController; import org.whispersystems.textsecuregcm.controllers.DeviceController; @@ -792,7 +793,11 @@ public void run(WhisperServerConfiguration config, Environment environment) thro .setAuthenticator(accountAuthenticator) .buildAuthFilter(); - final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(clientReleaseManager); + final String websocketServletPath = "/v1/websocket/"; + final String provisioningWebsocketServletPath = "/v1/websocket/provisioning/"; + + final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(clientReleaseManager, + Set.of(websocketServletPath, provisioningWebsocketServletPath, "/health-check")); metricsHttpChannelListener.configure(environment); environment.jersey().register(new VirtualExecutorServiceProvider("managed-async-virtual-thread-")); @@ -950,10 +955,10 @@ public void run(WhisperServerConfiguration config, Environment environment) thro ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet); ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet); - websocket.addMapping("/v1/websocket/"); + websocket.addMapping(websocketServletPath); websocket.setAsyncSupported(true); - provisioning.addMapping("/v1/websocket/provisioning/"); + provisioning.addMapping(provisioningWebsocketServletPath); provisioning.setAsyncSupported(true); environment.admin().addTask(new SetRequestLoggingEnabledTask()); diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java index cb4855130..ff6c654ac 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListener.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.Set; import javax.annotation.Nullable; import javax.ws.rs.container.ContainerRequestContext; import javax.ws.rs.container.ContainerResponseContext; @@ -54,6 +55,7 @@ private record RequestInfo(String path, String method, int statusCode, @Nullable } private final ClientReleaseManager clientReleaseManager; + private final Set servletPaths; public static final String REQUEST_COUNTER_NAME = name(MetricsHttpChannelListener.class, "request"); public static final String REQUESTS_BY_VERSION_COUNTER_NAME = name(MetricsHttpChannelListener.class, @@ -76,14 +78,16 @@ private record RequestInfo(String path, String method, int statusCode, @Nullable private final MeterRegistry meterRegistry; - public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager) { - this(Metrics.globalRegistry, clientReleaseManager); + public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager, final Set servletPaths) { + this(Metrics.globalRegistry, clientReleaseManager, servletPaths); } @VisibleForTesting - MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager) { + MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager, + final Set servletPaths) { this.meterRegistry = meterRegistry; this.clientReleaseManager = clientReleaseManager; + this.servletPaths = servletPaths; } public void configure(final Environment environment) { @@ -158,7 +162,12 @@ public void filter(final ContainerRequestContext requestContext, final Container private RequestInfo getRequestInfo(Request request) { final String path = Optional.ofNullable(request.getAttribute(URI_INFO_PROPERTY_NAME)) .map(attr -> UriInfoUtil.getPathTemplate((ExtendedUriInfo) attr)) - .orElse("unknown"); + .orElseGet(() -> { + if (servletPaths.contains(request.getPathInfo())) { + return request.getPathInfo(); + } + return "unknown"; + }); final String method = Optional.ofNullable(request.getMethod()).orElse("unknown"); // Response cannot be null, but its status might not always reflect an actual response status, since it gets // initialized to 200 diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java index b2dcbf2c1..e58a5993a 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerIntegrationTest.java @@ -27,8 +27,10 @@ import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; import java.io.IOException; +import java.net.URI; import java.security.Principal; import java.time.Duration; +import java.util.EnumSet; import java.util.HashSet; import java.util.Set; import java.util.concurrent.CompletableFuture; @@ -38,6 +40,7 @@ import java.util.stream.Stream; import javax.annotation.Priority; import javax.security.auth.Subject; +import javax.servlet.DispatcherType; import javax.ws.rs.GET; import javax.ws.rs.InternalServerErrorException; import javax.ws.rs.NotAuthorizedException; @@ -55,13 +58,21 @@ import org.eclipse.jetty.server.Request; import org.eclipse.jetty.util.component.Container; import org.eclipse.jetty.util.component.LifeCycle; +import org.eclipse.jetty.websocket.api.Session; +import org.eclipse.jetty.websocket.api.WebSocketListener; +import org.eclipse.jetty.websocket.client.ClientUpgradeRequest; +import org.eclipse.jetty.websocket.client.WebSocketClient; import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; import org.whispersystems.websocket.WebSocketResourceProviderFactory; import org.whispersystems.websocket.configuration.WebSocketConfiguration; @@ -148,6 +159,64 @@ void testSimplePath(String requestPath, String expectedTagPath, String expectedR assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); } + + @Nested + class WebSocket { + + private WebSocketClient client; + + @BeforeEach + void setUp() throws Exception { + client = new WebSocketClient(); + client.start(); + } + + @AfterEach + void tearDown() throws Exception { + client.stop(); + } + + @Test + void testWebSocketUpgrade() throws Exception { + final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest(); + upgradeRequest.setHeader(HttpHeaders.USER_AGENT, "Signal-Android/4.53.7 (Android 8.1)"); + + final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); + when(METER_REGISTRY.counter(anyString(), any(Iterable.class))) + .thenAnswer(a -> MetricsHttpChannelListener.REQUEST_COUNTER_NAME.equals(a.getArgument(0, String.class)) + ? COUNTER + : mock(Counter.class)) + .thenReturn(COUNTER); + + client.connect(new WebSocketListener() { + @Override + public void onWebSocketConnect(final Session session) { + session.close(1000, "OK"); + } + }, + URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")), upgradeRequest) + .get(1, TimeUnit.SECONDS); + + verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); + verify(COUNTER).increment(); + + final Iterable tagIterable = tagCaptor.getValue(); + final Set tags = new HashSet<>(); + + for (final Tag tag : tagIterable) { + tags.add(tag); + } + + assertEquals(5, tags.size()); + assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, "/v1/websocket"))); + assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET"))); + assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(101)))); + assertTrue( + tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase()))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); + } + } + static Stream testSimplePath() { return Stream.of( Arguments.of("/v1/test/hello", "/v1/test/hello", "Hello!", 200), @@ -166,11 +235,16 @@ public void run(final Configuration configuration, final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener( METER_REGISTRY, - mock(ClientReleaseManager.class)); + mock(ClientReleaseManager.class), + Set.of("/v1/websocket") + ); metricsHttpChannelListener.configure(environment); environment.lifecycle().addEventListener(new TestListener(LISTENER_FUTURE_REFERENCE)); + environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter(true)) + .addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*"); + environment.jersey().register(new TestResource()); environment.jersey().register(new TestAuthFilter()); @@ -185,9 +259,11 @@ public void run(final Configuration configuration, JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null); WebSocketResourceProviderFactory webSocketServlet = new WebSocketResourceProviderFactory<>( - webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, "ignored"); + webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); - environment.servlets().addServlet("WebSocket", webSocketServlet); + environment.servlets().addServlet("WebSocket", webSocketServlet) + .addMapping("/v1/websocket"); } } @@ -273,4 +349,5 @@ public boolean implies(final Subject subject) { return false; } } + } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java index 614ca99af..5fbdf2f91 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/metrics/MetricsHttpChannelListenerTest.java @@ -11,12 +11,14 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import com.google.common.net.HttpHeaders; import io.micrometer.core.instrument.Counter; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.Tag; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -27,29 +29,39 @@ import org.glassfish.jersey.uri.UriTemplate; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.ArgumentCaptor; import org.whispersystems.textsecuregcm.storage.ClientReleaseManager; class MetricsHttpChannelListenerTest { private MeterRegistry meterRegistry; - private Counter counter; + private Counter requestCounter; + private Counter requestsByVersionCounter; + private ClientReleaseManager clientReleaseManager; private MetricsHttpChannelListener listener; @BeforeEach void setup() { meterRegistry = mock(MeterRegistry.class); - counter = mock(Counter.class); + requestCounter = mock(Counter.class); + requestsByVersionCounter = mock(Counter.class); - final ClientReleaseManager clientReleaseManager = mock(ClientReleaseManager.class); - when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(false); + when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class))) + .thenReturn(requestCounter); + + when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), any(Iterable.class))) + .thenReturn(requestsByVersionCounter); - listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager); + clientReleaseManager = mock(ClientReleaseManager.class); + + listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager, Collections.emptySet()); } @Test @SuppressWarnings("unchecked") - void testOnEvent() { + void testRequests() { final String path = "/test"; final String method = "GET"; final int statusCode = 200; @@ -70,17 +82,15 @@ void testOnEvent() { when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path))); final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); - when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class))) - .thenReturn(counter); listener.onComplete(request); + verify(requestCounter).increment(); + verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture()); - final Iterable tagIterable = tagCaptor.getValue(); final Set tags = new HashSet<>(); - - for (final Tag tag : tagIterable) { + for (final Tag tag : tagCaptor.getValue()) { tags.add(tag); } @@ -92,4 +102,50 @@ void testOnEvent() { tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase()))); assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + @SuppressWarnings("unchecked") + void testRequestsByVersion(final boolean versionActive) { + when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(versionActive); + final String path = "/test"; + final String method = "GET"; + final int statusCode = 200; + + final HttpURI httpUri = mock(HttpURI.class); + when(httpUri.getPath()).thenReturn(path); + + final Request request = mock(Request.class); + when(request.getMethod()).thenReturn(method); + when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)"); + when(request.getHttpURI()).thenReturn(httpUri); + + final Response response = mock(Response.class); + when(response.getStatus()).thenReturn(statusCode); + when(request.getResponse()).thenReturn(response); + final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class); + when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo); + when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path))); + + listener.onComplete(request); + + if (versionActive) { + final ArgumentCaptor> tagCaptor = ArgumentCaptor.forClass(Iterable.class); + verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), + tagCaptor.capture()); + final Set tags = new HashSet<>(); + tags.clear(); + for (final Tag tag : tagCaptor.getValue()) { + tags.add(tag); + } + + assertEquals(2, tags.size()); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.VERSION_TAG, "6.53.7"))); + assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android"))); + } else { + verifyNoInteractions(requestsByVersionCounter); + } + + + } }