Skip to content

Commit

Permalink
Add grpc.server.method to WAF addresses with FQN of the grpc method (#…
Browse files Browse the repository at this point in the history
…7079)

Add grpc.server.method to WAF addresses with FQN of the grpc method
  • Loading branch information
manuel-alvarez-alvarez committed Jun 14, 2024
1 parent 38271ed commit ca01312
Show file tree
Hide file tree
Showing 16 changed files with 248 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ public interface KnownAddresses {
Address<CaseInsensitiveMap<List<String>>> HEADERS_NO_COOKIES =
new Address<>("server.request.headers.no_cookies");

Address<Object> GRPC_SERVER_METHOD = new Address<>("grpc.server.method");

Address<Object> GRPC_SERVER_REQUEST_MESSAGE = new Address<>("grpc.server.request.message");

// XXX: Not really used yet, but it's a known address and we should not treat it as unknown.
Expand Down Expand Up @@ -159,6 +161,8 @@ static Address<?> forName(String name) {
return REQUEST_QUERY;
case "server.request.headers.no_cookies":
return HEADERS_NO_COOKIES;
case "grpc.server.method":
return GRPC_SERVER_METHOD;
case "grpc.server.request.message":
return GRPC_SERVER_REQUEST_MESSAGE;
case "grpc.server.request.metadata":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public class GatewayBridge {
private volatile DataSubscriberInfo requestBodySubInfo;
private volatile DataSubscriberInfo pathParamsSubInfo;
private volatile DataSubscriberInfo respDataSubInfo;
private volatile DataSubscriberInfo grpcServerMethodSubInfo;
private volatile DataSubscriberInfo grpcServerRequestMsgSubInfo;
private volatile DataSubscriberInfo graphqlServerRequestMsgSubInfo;
private volatile DataSubscriberInfo requestEndSubInfo;
Expand Down Expand Up @@ -361,6 +362,32 @@ public void init() {
return maybePublishResponseData(ctx);
});

subscriptionService.registerCallback(
EVENTS.grpcServerMethod(),
(ctx_, method) -> {
AppSecRequestContext ctx = ctx_.getData(RequestContextSlot.APPSEC);
if (ctx == null || method == null || method.isEmpty()) {
return NoopFlow.INSTANCE;
}
while (true) {
DataSubscriberInfo subInfo = grpcServerMethodSubInfo;
if (subInfo == null) {
subInfo = producerService.getDataSubscribers(KnownAddresses.GRPC_SERVER_METHOD);
grpcServerMethodSubInfo = subInfo;
}
if (subInfo == null || subInfo.isEmpty()) {
return NoopFlow.INSTANCE;
}
DataBundle bundle =
new SingletonDataBundle<>(KnownAddresses.GRPC_SERVER_METHOD, method);
try {
return producerService.publishDataEvent(subInfo, ctx, bundle, true);
} catch (ExpiredSubscriberInfoException e) {
grpcServerMethodSubInfo = null;
}
}
});

subscriptionService.registerCallback(
EVENTS.grpcServerRequestMessage(),
(ctx_, obj) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class KnownAddressesSpecification extends Specification {
'server.request.body.combined_file_size',
'server.request.query',
'server.request.headers.no_cookies',
'grpc.server.method',
'grpc.server.request.message',
'grpc.server.request.metadata',
'graphql.server.all_resolvers',
Expand All @@ -41,7 +42,7 @@ class KnownAddressesSpecification extends Specification {

void 'number of known addresses is expected number'() {
expect:
Address.instanceCount() == 29
Address.instanceCount() == 30
KnownAddresses.WAF_CONTEXT_PROCESSOR.serial == Address.instanceCount() - 1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class GatewayBridgeSpecification extends DDSpecification {
BiFunction<RequestContext, Integer, Flow<Void>> responseStartedCB
TriConsumer<RequestContext, String, String> respHeaderCB
Function<RequestContext, Flow<Void>> respHeadersDoneCB
BiFunction<RequestContext, String, Flow<Void>> grpcServerMethodCB
BiFunction<RequestContext, Object, Flow<Void>> grpcServerRequestMessageCB
BiFunction<RequestContext, Map<String, Object>, Flow<Void>> graphqlServerRequestMessageCB
BiConsumer<RequestContext, String> databaseConnectionCB
Expand Down Expand Up @@ -413,6 +414,7 @@ class GatewayBridgeSpecification extends DDSpecification {
1 * ig.registerCallback(EVENTS.responseStarted(), _) >> { responseStartedCB = it[1]; null }
1 * ig.registerCallback(EVENTS.responseHeader(), _) >> { respHeaderCB = it[1]; null }
1 * ig.registerCallback(EVENTS.responseHeaderDone(), _) >> { respHeadersDoneCB = it[1]; null }
1 * ig.registerCallback(EVENTS.grpcServerMethod(), _) >> { grpcServerMethodCB = it[1]; null }
1 * ig.registerCallback(EVENTS.grpcServerRequestMessage(), _) >> { grpcServerRequestMessageCB = it[1]; null }
1 * ig.registerCallback(EVENTS.graphqlServerRequestMessage(), _) >> { graphqlServerRequestMessageCB = it[1]; null }
1 * ig.registerCallback(EVENTS.databaseConnection(), _) >> { databaseConnectionCB = it[1]; null }
Expand Down Expand Up @@ -710,6 +712,22 @@ class GatewayBridgeSpecification extends DDSpecification {
flow.action == Flow.Action.Noop.INSTANCE
}

void 'grpc server method publishes'() {
setup:
eventDispatcher.getDataSubscribers(KnownAddresses.GRPC_SERVER_METHOD) >> nonEmptyDsInfo
DataBundle bundle

when:
Flow<?> flow = grpcServerMethodCB.apply(ctx, '/my.package.Greeter/SayHello')

then:
1 * eventDispatcher.publishDataEvent(nonEmptyDsInfo, ctx.data, _ as DataBundle, true) >>
{ args -> bundle = args[2]; NoopFlow.INSTANCE }
bundle.get(KnownAddresses.GRPC_SERVER_METHOD) == '/my.package.Greeter/SayHello'
flow.result == null
flow.action == Flow.Action.Noop.INSTANCE
}

void 'calls trace segment post processor'() {
setup:
AgentSpan span = Stub()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
Expand Down Expand Up @@ -76,6 +77,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
if (reqContext != null) {
callIGCallbackClientAddress(cbp, reqContext, call);
callIGCallbackHeaders(cbp, reqContext, headers);
callIGCallbackGrpcServerMethod(cbp, reqContext, call.getMethodDescriptor());
}

DECORATE.afterStart(span);
Expand Down Expand Up @@ -315,6 +317,16 @@ private static void callIGCallbackRequestEnded(@Nonnull final AgentSpan span) {
}
}

private static <ReqT, RespT> void callIGCallbackGrpcServerMethod(
CallbackProvider cbp, RequestContext ctx, MethodDescriptor<ReqT, RespT> methodDescriptor) {
String method = methodDescriptor.getFullMethodName();
BiFunction<RequestContext, String, Flow<Void>> cb = cbp.getCallback(EVENTS.grpcServerMethod());
if (method == null || cb == null) {
return;
}
cb.apply(ctx, method);
}

private static void callIGCallbackGrpcMessage(@Nonnull final AgentSpan span, Object obj) {
if (obj == null) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ abstract class ArmeriaGrpcTest extends VersionedNamingTestBase {

def collectedAppSecHeaders = [:]
boolean appSecHeaderDone = false
def collectedAppSecServerMethods = []
def collectedAppSecReqMsgs = []

final Duration timeoutDuration() {
Expand Down Expand Up @@ -97,6 +98,10 @@ abstract class ArmeriaGrpcTest extends VersionedNamingTestBase {
collectedAppSecReqMsgs << obj
Flow.ResultFlow.empty()
} as BiFunction<RequestContext, Object, Flow<Void>>)
ig.registerCallback(EVENTS.grpcServerMethod(), { reqCtx, method ->
collectedAppSecServerMethods << method
Flow.ResultFlow.empty()
} as BiFunction<RequestContext, String, Flow<Void>>)
}

def cleanup() {
Expand Down Expand Up @@ -230,6 +235,8 @@ abstract class ArmeriaGrpcTest extends VersionedNamingTestBase {
traceId.toLong() as String == collectedAppSecHeaders['x-datadog-trace-id']
collectedAppSecReqMsgs.size() == 1
collectedAppSecReqMsgs.first().name == name
collectedAppSecServerMethods.size() == 1
collectedAppSecServerMethods.first() == 'example.Greeter/SayHello'

and:
if (isDataStreamsEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.grpc.ForwardingServerCallListener;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
Expand Down Expand Up @@ -75,6 +76,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
if (reqContext != null) {
callIGCallbackClientAddress(cbp, reqContext, call);
callIGCallbackHeaders(cbp, reqContext, headers);
callIGCallbackGrpcServerMethod(cbp, reqContext, call.getMethodDescriptor());
}

DECORATE.afterStart(span);
Expand Down Expand Up @@ -314,6 +316,16 @@ private static void callIGCallbackRequestEnded(@Nonnull final AgentSpan span) {
}
}

private static <ReqT, RespT> void callIGCallbackGrpcServerMethod(
CallbackProvider cbp, RequestContext ctx, MethodDescriptor<ReqT, RespT> methodDescriptor) {
String method = methodDescriptor.getFullMethodName();
BiFunction<RequestContext, String, Flow<Void>> cb = cbp.getCallback(EVENTS.grpcServerMethod());
if (method == null || cb == null) {
return;
}
cb.apply(ctx, method);
}

private static void callIGCallbackGrpcMessage(@Nonnull final AgentSpan span, Object obj) {
if (obj == null) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ abstract class GrpcTest extends VersionedNamingTestBase {
def collectedAppSecHeaders = [:]
boolean appSecHeaderDone = false
def collectedAppSecReqMsgs = []
def collectedAppSecServerMethods = []

@Override
final String service() {
Expand Down Expand Up @@ -89,6 +90,10 @@ abstract class GrpcTest extends VersionedNamingTestBase {
collectedAppSecReqMsgs << obj
Flow.ResultFlow.empty()
} as BiFunction<RequestContext, Object, Flow<Void>>)
ig.registerCallback(EVENTS.grpcServerMethod(), { reqCtx, method ->
collectedAppSecServerMethods << method
Flow.ResultFlow.empty()
} as BiFunction<RequestContext, String, Flow<Void>>)
}

def cleanup() {
Expand Down Expand Up @@ -217,6 +222,8 @@ abstract class GrpcTest extends VersionedNamingTestBase {
traceId.toLong() as String == collectedAppSecHeaders['x-datadog-trace-id']
collectedAppSecReqMsgs.size() == 1
collectedAppSecReqMsgs.first().name == name
collectedAppSecServerMethods.size() == 1
collectedAppSecServerMethods.first() == 'example.Greeter/SayHello'

and:
if (isDataStreamsEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package datadog.smoketest.appsec

abstract class AbstractSpringBootWithGRPCAppSecTest extends AbstractAppSecServerSmokeTest {

@Override
ProcessBuilder createProcessBuilder() {
String springBootShadowJar = System.getProperty("datadog.smoketest.appsec.springboot-grpc.shadowJar.path")
assert springBootShadowJar != null

List<String> command = [
javaPath(),
*defaultJavaProperties,
*defaultAppSecProperties,
"-jar",
springBootShadowJar,
"--server.port=${httpPort}"
].collect { it as String }

ProcessBuilder processBuilder = new ProcessBuilder(command)
processBuilder.directory(new File(buildDirectory))
}

static final String ROUTE = 'async_annotation_greeting'
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package datadog.smoketest.appsec


import okhttp3.Request
import spock.lang.Shared

class ServerMethodTest extends AbstractSpringBootWithGRPCAppSecTest {

@Shared
String buildDir = new File(System.getProperty("datadog.smoketest.builddir")).absolutePath
@Shared
String customRulesPath = "${buildDir}/appsec_custom_rules.json"

@Override
ProcessBuilder createProcessBuilder() {
// We run this here to ensure it runs before starting the process. Child setupSpec runs after parent setupSpec,
// so it is not a valid location.
appendRules(customRulesPath, [
[
id : '__test_server_method_bock',
name : 'test rule to block on server method',
tags : [
type : 'test',
category : 'test',
confidence: '1',
],
conditions : [
[
parameters: [
inputs: [[address: 'grpc.server.method']],
regex : 'Greeter',
],
operator : 'match_regex',
]
],
transformers: [],
on_match : ['block']
]
])
return super.createProcessBuilder()
}

void 'test grpc.server.method address'() {
setup:
String url = "http://localhost:${httpPort}/${ROUTE}"
def request = new Request.Builder()
.url("${url}?message=${'Hello!'.bytes.encodeBase64()}")
.get().build()

when:
def response = client.newCall(request).execute()

then:
def responseBodyStr = response.body().string()
responseBodyStr != null
responseBodyStr.contains("bye")
response.body().contentType().toString().contains("text/plain")
response.code() == 200

and:
waitForTraceCount(2) == 2
rootSpans.size() == 2
def grpcRootSpan = rootSpans.find { it.triggers }
grpcRootSpan != null
def match = grpcRootSpan.triggers[0]['rule_matches'][0]
match != null
match['parameters'][0]['address'] == 'grpc.server.method'
match['parameters'][0]['value'] == 'smoketest.Greeter/Hello'
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,9 @@ package datadog.smoketest.appsec

import okhttp3.Request

class SpringBootWithGRPCAppSecTest extends AbstractAppSecServerSmokeTest {
class ServerRequestMessageTest extends AbstractSpringBootWithGRPCAppSecTest {

@Override
ProcessBuilder createProcessBuilder() {
String springBootShadowJar = System.getProperty("datadog.smoketest.appsec.springboot-grpc.shadowJar.path")
assert springBootShadowJar != null

List<String> command = [
javaPath(),
*defaultJavaProperties,
*defaultAppSecProperties,
"-jar",
springBootShadowJar,
"--server.port=${httpPort}"
].collect { it as String }

ProcessBuilder processBuilder = new ProcessBuilder(command)
processBuilder.directory(new File(buildDirectory))
}

static final String ROUTE = 'async_annotation_greeting'

def greeter() {
void 'test grpc.server.request.message address'() {
setup:
String url = "http://localhost:${httpPort}/${ROUTE}"
def request = new Request.Builder()
Expand Down
Loading

0 comments on commit ca01312

Please sign in to comment.