Skip to content

Commit

Permalink
Add static servlet paths to MetricsHttpChannelListener
Browse files Browse the repository at this point in the history
  • Loading branch information
eager-signal committed Feb 14, 2024
1 parent f90ccd3 commit 9ce2b75
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.util.List;
import java.util.Optional;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -97,8 +98,8 @@
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV2;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV3;
import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV4;
import org.whispersystems.textsecuregcm.controllers.CallRoutingController;
import org.whispersystems.textsecuregcm.controllers.CallLinkController;
import org.whispersystems.textsecuregcm.controllers.CallRoutingController;
import org.whispersystems.textsecuregcm.controllers.CertificateController;
import org.whispersystems.textsecuregcm.controllers.ChallengeController;
import org.whispersystems.textsecuregcm.controllers.DeviceController;
Expand Down Expand Up @@ -792,7 +793,11 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
.setAuthenticator(accountAuthenticator)
.buildAuthFilter();

final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(clientReleaseManager);
final String websocketServletPath = "/v1/websocket/";
final String provisioningWebsocketServletPath = "/v1/websocket/provisioning/";

final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(clientReleaseManager,
Set.of(websocketServletPath, provisioningWebsocketServletPath, "/health-check"));
metricsHttpChannelListener.configure(environment);

environment.jersey().register(new VirtualExecutorServiceProvider("managed-async-virtual-thread-"));
Expand Down Expand Up @@ -950,10 +955,10 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
ServletRegistration.Dynamic websocket = environment.servlets().addServlet("WebSocket", webSocketServlet);
ServletRegistration.Dynamic provisioning = environment.servlets().addServlet("Provisioning", provisioningServlet);

websocket.addMapping("/v1/websocket/");
websocket.addMapping(websocketServletPath);
websocket.setAsyncSupported(true);

provisioning.addMapping("/v1/websocket/provisioning/");
provisioning.addMapping(provisioningWebsocketServletPath);
provisioning.setAsyncSupported(true);

environment.admin().addTask(new SetRequestLoggingEnabledTask());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerResponseContext;
Expand Down Expand Up @@ -54,6 +55,7 @@ private record RequestInfo(String path, String method, int statusCode, @Nullable
}

private final ClientReleaseManager clientReleaseManager;
private final Set<String> servletPaths;

public static final String REQUEST_COUNTER_NAME = name(MetricsHttpChannelListener.class, "request");
public static final String REQUESTS_BY_VERSION_COUNTER_NAME = name(MetricsHttpChannelListener.class,
Expand All @@ -76,14 +78,16 @@ private record RequestInfo(String path, String method, int statusCode, @Nullable
private final MeterRegistry meterRegistry;


public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager) {
this(Metrics.globalRegistry, clientReleaseManager);
public MetricsHttpChannelListener(final ClientReleaseManager clientReleaseManager, final Set<String> servletPaths) {
this(Metrics.globalRegistry, clientReleaseManager, servletPaths);
}

@VisibleForTesting
MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager) {
MetricsHttpChannelListener(final MeterRegistry meterRegistry, final ClientReleaseManager clientReleaseManager,
final Set<String> servletPaths) {
this.meterRegistry = meterRegistry;
this.clientReleaseManager = clientReleaseManager;
this.servletPaths = servletPaths;
}

public void configure(final Environment environment) {
Expand Down Expand Up @@ -158,7 +162,12 @@ public void filter(final ContainerRequestContext requestContext, final Container
private RequestInfo getRequestInfo(Request request) {
final String path = Optional.ofNullable(request.getAttribute(URI_INFO_PROPERTY_NAME))
.map(attr -> UriInfoUtil.getPathTemplate((ExtendedUriInfo) attr))
.orElse("unknown");
.orElseGet(() -> {
if (servletPaths.contains(request.getPathInfo())) {
return request.getPathInfo();
}
return "unknown";
});
final String method = Optional.ofNullable(request.getMethod()).orElse("unknown");
// Response cannot be null, but its status might not always reflect an actual response status, since it gets
// initialized to 200
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import java.io.IOException;
import java.net.URI;
import java.security.Principal;
import java.time.Duration;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
Expand All @@ -38,6 +40,7 @@
import java.util.stream.Stream;
import javax.annotation.Priority;
import javax.security.auth.Subject;
import javax.servlet.DispatcherType;
import javax.ws.rs.GET;
import javax.ws.rs.InternalServerErrorException;
import javax.ws.rs.NotAuthorizedException;
Expand All @@ -55,13 +58,21 @@
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.util.component.Container;
import org.eclipse.jetty.util.component.LifeCycle;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.config.JettyWebSocketServletContainerInitializer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;
import org.whispersystems.websocket.WebSocketResourceProviderFactory;
import org.whispersystems.websocket.configuration.WebSocketConfiguration;
Expand Down Expand Up @@ -148,6 +159,64 @@ void testSimplePath(String requestPath, String expectedTagPath, String expectedR
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
}


@Nested
class WebSocket {

private WebSocketClient client;

@BeforeEach
void setUp() throws Exception {
client = new WebSocketClient();
client.start();
}

@AfterEach
void tearDown() throws Exception {
client.stop();
}

@Test
void testWebSocketUpgrade() throws Exception {
final ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setHeader(HttpHeaders.USER_AGENT, "Signal-Android/4.53.7 (Android 8.1)");

final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(METER_REGISTRY.counter(anyString(), any(Iterable.class)))
.thenAnswer(a -> MetricsHttpChannelListener.REQUEST_COUNTER_NAME.equals(a.getArgument(0, String.class))
? COUNTER
: mock(Counter.class))
.thenReturn(COUNTER);

client.connect(new WebSocketListener() {
@Override
public void onWebSocketConnect(final Session session) {
session.close(1000, "OK");
}
},
URI.create(String.format("ws://localhost:%d%s", EXTENSION.getLocalPort(), "/v1/websocket")), upgradeRequest)
.get(1, TimeUnit.SECONDS);

verify(METER_REGISTRY).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());
verify(COUNTER).increment();

final Iterable<Tag> tagIterable = tagCaptor.getValue();
final Set<Tag> tags = new HashSet<>();

for (final Tag tag : tagIterable) {
tags.add(tag);
}

assertEquals(5, tags.size());
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.PATH_TAG, "/v1/websocket")));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.METHOD_TAG, "GET")));
assertTrue(tags.contains(Tag.of(MetricsHttpChannelListener.STATUS_CODE_TAG, String.valueOf(101))));
assertTrue(
tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TRAFFIC_SOURCE.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
}
}

static Stream<Arguments> testSimplePath() {
return Stream.of(
Arguments.of("/v1/test/hello", "/v1/test/hello", "Hello!", 200),
Expand All @@ -166,11 +235,16 @@ public void run(final Configuration configuration,

final MetricsHttpChannelListener metricsHttpChannelListener = new MetricsHttpChannelListener(
METER_REGISTRY,
mock(ClientReleaseManager.class));
mock(ClientReleaseManager.class),
Set.of("/v1/websocket")
);

metricsHttpChannelListener.configure(environment);
environment.lifecycle().addEventListener(new TestListener(LISTENER_FUTURE_REFERENCE));

environment.servlets().addFilter("RemoteAddressFilter", new RemoteAddressFilter(true))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), false, "/*");

environment.jersey().register(new TestResource());
environment.jersey().register(new TestAuthFilter());

Expand All @@ -185,9 +259,11 @@ public void run(final Configuration configuration,
JettyWebSocketServletContainerInitializer.configure(environment.getApplicationContext(), null);

WebSocketResourceProviderFactory<TestPrincipal> webSocketServlet = new WebSocketResourceProviderFactory<>(
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration, "ignored");
webSocketEnvironment, TestPrincipal.class, webSocketConfiguration,
RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME);

environment.servlets().addServlet("WebSocket", webSocketServlet);
environment.servlets().addServlet("WebSocket", webSocketServlet)
.addMapping("/v1/websocket");
}
}

Expand Down Expand Up @@ -273,4 +349,5 @@ public boolean implies(final Subject subject) {
return false;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;

import com.google.common.net.HttpHeaders;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tag;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
Expand All @@ -27,29 +29,39 @@
import org.glassfish.jersey.uri.UriTemplate;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.whispersystems.textsecuregcm.storage.ClientReleaseManager;

class MetricsHttpChannelListenerTest {

private MeterRegistry meterRegistry;
private Counter counter;
private Counter requestCounter;
private Counter requestsByVersionCounter;
private ClientReleaseManager clientReleaseManager;
private MetricsHttpChannelListener listener;

@BeforeEach
void setup() {
meterRegistry = mock(MeterRegistry.class);
counter = mock(Counter.class);
requestCounter = mock(Counter.class);
requestsByVersionCounter = mock(Counter.class);

final ClientReleaseManager clientReleaseManager = mock(ClientReleaseManager.class);
when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(false);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestCounter);

when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME), any(Iterable.class)))
.thenReturn(requestsByVersionCounter);

listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager);
clientReleaseManager = mock(ClientReleaseManager.class);

listener = new MetricsHttpChannelListener(meterRegistry, clientReleaseManager, Collections.emptySet());
}

@Test
@SuppressWarnings("unchecked")
void testOnEvent() {
void testRequests() {
final String path = "/test";
final String method = "GET";
final int statusCode = 200;
Expand All @@ -70,17 +82,15 @@ void testOnEvent() {
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));

final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
when(meterRegistry.counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), any(Iterable.class)))
.thenReturn(counter);

listener.onComplete(request);

verify(requestCounter).increment();

verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUEST_COUNTER_NAME), tagCaptor.capture());

final Iterable<Tag> tagIterable = tagCaptor.getValue();
final Set<Tag> tags = new HashSet<>();

for (final Tag tag : tagIterable) {
for (final Tag tag : tagCaptor.getValue()) {
tags.add(tag);
}

Expand All @@ -92,4 +102,50 @@ void testOnEvent() {
tags.contains(Tag.of(MetricsHttpChannelListener.TRAFFIC_SOURCE_TAG, TrafficSource.HTTP.name().toLowerCase())));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
@SuppressWarnings("unchecked")
void testRequestsByVersion(final boolean versionActive) {
when(clientReleaseManager.isVersionActive(any(), any())).thenReturn(versionActive);
final String path = "/test";
final String method = "GET";
final int statusCode = 200;

final HttpURI httpUri = mock(HttpURI.class);
when(httpUri.getPath()).thenReturn(path);

final Request request = mock(Request.class);
when(request.getMethod()).thenReturn(method);
when(request.getHeader(HttpHeaders.USER_AGENT)).thenReturn("Signal-Android/6.53.7 (Android 8.1)");
when(request.getHttpURI()).thenReturn(httpUri);

final Response response = mock(Response.class);
when(response.getStatus()).thenReturn(statusCode);
when(request.getResponse()).thenReturn(response);
final ExtendedUriInfo extendedUriInfo = mock(ExtendedUriInfo.class);
when(request.getAttribute(MetricsHttpChannelListener.URI_INFO_PROPERTY_NAME)).thenReturn(extendedUriInfo);
when(extendedUriInfo.getMatchedTemplates()).thenReturn(List.of(new UriTemplate(path)));

listener.onComplete(request);

if (versionActive) {
final ArgumentCaptor<Iterable<Tag>> tagCaptor = ArgumentCaptor.forClass(Iterable.class);
verify(meterRegistry).counter(eq(MetricsHttpChannelListener.REQUESTS_BY_VERSION_COUNTER_NAME),
tagCaptor.capture());
final Set<Tag> tags = new HashSet<>();
tags.clear();
for (final Tag tag : tagCaptor.getValue()) {
tags.add(tag);
}

assertEquals(2, tags.size());
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.VERSION_TAG, "6.53.7")));
assertTrue(tags.contains(Tag.of(UserAgentTagUtil.PLATFORM_TAG, "android")));
} else {
verifyNoInteractions(requestsByVersionCounter);
}


}
}

0 comments on commit 9ce2b75

Please sign in to comment.