Skip to content

Commit

Permalink
Report telemetry metrics for Exploit Prevention (#7314)
Browse files Browse the repository at this point in the history
* Report telemetry metrics for Exploit Prevention

* Performance improvement

* Missing test

* Missing test

* Performance improvement

* Spotless

* Performance improvement
  • Loading branch information
ValentinZakharov authored Jul 18, 2024
1 parent 2e9ba7a commit 144efa8
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public BlockingDetails shouldBlockUser(@Nonnull String userId) {
}
SingletonDataBundle<String> db = new SingletonDataBundle<>(KnownAddresses.USER_ID, userId);
try {
GatewayContext gwCtx = new GatewayContext(true, false);
GatewayContext gwCtx = new GatewayContext(true);
flow = eventProducer.publishDataEvent(subInfo, reqCtx, db, gwCtx);
break;
} catch (ExpiredSubscriberInfoException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import datadog.trace.api.gateway.SubscriptionService;
import datadog.trace.api.http.StoredBodySupplier;
import datadog.trace.api.internal.TraceSegment;
import datadog.trace.api.telemetry.RuleType;
import datadog.trace.api.telemetry.WafMetricCollector;
import datadog.trace.bootstrap.instrumentation.api.Tags;
import datadog.trace.bootstrap.instrumentation.api.URIDataAdapter;
Expand Down Expand Up @@ -235,7 +236,7 @@ public void init() {
DataBundle bundle =
new SingletonDataBundle<>(KnownAddresses.REQUEST_PATH_PARAMS, data);
try {
GatewayContext gwCtx = new GatewayContext(false, false);
GatewayContext gwCtx = new GatewayContext(false);
return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
} catch (ExpiredSubscriberInfoException e) {
pathParamsSubInfo = null;
Expand Down Expand Up @@ -270,7 +271,7 @@ public void init() {
DataBundle bundle =
new SingletonDataBundle<>(KnownAddresses.REQUEST_BODY_RAW, bodyContent);
try {
GatewayContext gwCtx = new GatewayContext(false, false);
GatewayContext gwCtx = new GatewayContext(false);
return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
} catch (ExpiredSubscriberInfoException e) {
rawRequestBodySubInfo = null;
Expand Down Expand Up @@ -308,7 +309,7 @@ public void init() {
new SingletonDataBundle<>(
KnownAddresses.REQUEST_BODY_OBJECT, ObjectIntrospection.convert(obj));
try {
GatewayContext gwCtx = new GatewayContext(false, false);
GatewayContext gwCtx = new GatewayContext(false);
return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
} catch (ExpiredSubscriberInfoException e) {
requestBodySubInfo = null;
Expand Down Expand Up @@ -388,7 +389,7 @@ public void init() {
DataBundle bundle =
new SingletonDataBundle<>(KnownAddresses.GRPC_SERVER_METHOD, method);
try {
GatewayContext gwCtx = new GatewayContext(true, false);
GatewayContext gwCtx = new GatewayContext(true);
return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
} catch (ExpiredSubscriberInfoException e) {
grpcServerMethodSubInfo = null;
Expand Down Expand Up @@ -417,7 +418,7 @@ public void init() {
DataBundle bundle =
new SingletonDataBundle<>(KnownAddresses.GRPC_SERVER_REQUEST_MESSAGE, convObj);
try {
GatewayContext gwCtx = new GatewayContext(true, false);
GatewayContext gwCtx = new GatewayContext(true);
return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
} catch (ExpiredSubscriberInfoException e) {
grpcServerRequestMsgSubInfo = null;
Expand Down Expand Up @@ -445,7 +446,7 @@ public void init() {
DataBundle bundle =
new SingletonDataBundle<>(KnownAddresses.GRAPHQL_SERVER_ALL_RESOLVERS, data);
try {
GatewayContext gwCtx = new GatewayContext(true, false);
GatewayContext gwCtx = new GatewayContext(true);
return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
} catch (ExpiredSubscriberInfoException e) {
graphqlServerRequestMsgSubInfo = null;
Expand Down Expand Up @@ -487,7 +488,7 @@ public void init() {
.add(KnownAddresses.DB_SQL_QUERY, sql)
.build();
try {
GatewayContext gwCtx = new GatewayContext(false, true);
GatewayContext gwCtx = new GatewayContext(false, RuleType.SQL_INJECTION);
return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
} catch (ExpiredSubscriberInfoException e) {
dbSqlQuerySubInfo = null;
Expand Down Expand Up @@ -685,7 +686,7 @@ private Flow<Void> maybePublishRequestData(AppSecRequestContext ctx) {
}

try {
GatewayContext gwCtx = new GatewayContext(false, false);
GatewayContext gwCtx = new GatewayContext(false);
return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
} catch (ExpiredSubscriberInfoException e) {
this.initialReqDataSubInfo = null;
Expand Down Expand Up @@ -718,7 +719,7 @@ private Flow<Void> maybePublishResponseData(AppSecRequestContext ctx) {
}

try {
GatewayContext gwCtx = new GatewayContext(false, false);
GatewayContext gwCtx = new GatewayContext(false);
return producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
} catch (ExpiredSubscriberInfoException e) {
respDataSubInfo = null;
Expand Down Expand Up @@ -751,7 +752,7 @@ private void maybeExtractSchemas(AppSecRequestContext ctx) {
KnownAddresses.WAF_CONTEXT_PROCESSOR,
Collections.singletonMap("extract-schema", true));
try {
GatewayContext gwCtx = new GatewayContext(false, false);
GatewayContext gwCtx = new GatewayContext(false);
producerService.publishDataEvent(subInfo, ctx, bundle, gwCtx);
return;
} catch (ExpiredSubscriberInfoException e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
package com.datadog.appsec.gateway;

import datadog.trace.api.telemetry.RuleType;

public class GatewayContext {
public final boolean isTransient;
public final boolean isRasp;

public GatewayContext(final boolean isTransient, final boolean isRasp) {
public final RuleType raspRuleType;

public GatewayContext(final boolean isTransient) {
this(isTransient, null);
}

public GatewayContext(final boolean isTransient, final RuleType raspRuleType) {
this.isTransient = isTransient;
this.isRasp = isRasp;
this.isRasp = raspRuleType != null;
this.raspRuleType = raspRuleType;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,18 @@ public void onDataAvailable(
start = System.currentTimeMillis();
}

if (gwCtx.isRasp) {
WafMetricCollector.get().raspRuleEval(gwCtx.raspRuleType);
}

try {
resultWithData = doRunPowerwaf(reqCtx, newData, ctxAndAddr, gwCtx);
} catch (TimeoutPowerwafException tpe) {
reqCtx.increaseTimeouts();
log.debug(LogCollector.EXCLUDE_TELEMETRY, "Timeout calling the WAF", tpe);
if (gwCtx.isRasp) {
WafMetricCollector.get().raspTimeout(gwCtx.raspRuleType);
}
return;
} catch (AbstractPowerwafException e) {
log.error("Error calling WAF", e);
Expand All @@ -455,6 +462,10 @@ public void onDataAvailable(
log.warn("WAF signalled result {}: {}", resultWithData.result, resultWithData.data);
}

if (gwCtx.isRasp) {
WafMetricCollector.get().raspRuleMatch(gwCtx.raspRuleType);
}

for (Map.Entry<String, Map<String, Object>> action : resultWithData.actions.entrySet()) {
String actionType = action.getKey();
Map<String, Object> actionParams = action.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class EventDispatcherSpecification extends DDSpecification {
set.addSubscription([KnownAddresses.REQUEST_CLIENT_IP], dataListener2)
set.addSubscription([KnownAddresses.REQUEST_METHOD], dataListener3)
dispatcher.subscribeDataAvailable(set)
def gwCtx = new GatewayContext(true, false)
def gwCtx = new GatewayContext(true)

when:
def subscribers = dispatcher.getDataSubscribers(KnownAddresses.REQUEST_CLIENT_IP, KnownAddresses.REQUEST_METHOD)
Expand Down Expand Up @@ -72,7 +72,7 @@ class EventDispatcherSpecification extends DDSpecification {
[KnownAddresses.REQUEST_CLIENT_IP, KnownAddresses.HEADERS_NO_COOKIES],
listener)
dispatcher.subscribeDataAvailable(set)
def gwCtx = new GatewayContext(true, false)
def gwCtx = new GatewayContext(true)

when:
def subscribers = dispatcher.getDataSubscribers(KnownAddresses.REQUEST_CLIENT_IP, KnownAddresses.HEADERS_NO_COOKIES)
Expand All @@ -95,7 +95,7 @@ class EventDispatcherSpecification extends DDSpecification {
set.addSubscription([KnownAddresses.REQUEST_CLIENT_IP], dataListener1)
set.addSubscription([KnownAddresses.REQUEST_CLIENT_IP], dataListener2)
dispatcher.subscribeDataAvailable(set)
def gwCtx = new GatewayContext(true, false)
def gwCtx = new GatewayContext(true)

when:
def subscribers = dispatcher.getDataSubscribers(KnownAddresses.REQUEST_CLIENT_IP)
Expand Down Expand Up @@ -123,7 +123,7 @@ class EventDispatcherSpecification extends DDSpecification {
def set = new EventDispatcher.DataSubscriptionSet()
set.addSubscription([KnownAddresses.REQUEST_CLIENT_IP], listener)
dispatcher.subscribeDataAvailable(set)
def gwCtx = new GatewayContext(false, false)
def gwCtx = new GatewayContext(false)

when:
def subscribers = dispatcher.getDataSubscribers(KnownAddresses.REQUEST_CLIENT_IP)
Expand Down Expand Up @@ -160,7 +160,7 @@ class EventDispatcherSpecification extends DDSpecification {
EventDispatcher anotherDispatcher = new EventDispatcher()
EventProducerService.DataSubscriberInfo subInfo =
anotherDispatcher.getDataSubscribers(KnownAddresses.REQUEST_CLIENT_IP)
def gwCtx = new GatewayContext(false, false)
def gwCtx = new GatewayContext(false)
dispatcher.publishDataEvent(subInfo, Stub(AppSecRequestContext), Stub(DataBundle), gwCtx)

then:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class GatewayBridgeSpecification extends DDSpecification {
i.empty >> false
i
}()
GatewayContext gwCtx = new GatewayContext(false, false)

TraceSegmentPostProcessor pp = Mock()
GatewayBridge bridge = new GatewayBridge(ig, eventDispatcher, null, [pp])
Expand Down Expand Up @@ -746,6 +745,36 @@ class GatewayBridgeSpecification extends DDSpecification {
flow.action == Flow.Action.Noop.INSTANCE
}

void 'process database type'() {
setup:
eventDispatcher.getDataSubscribers({ KnownAddresses.DB_TYPE in it }) >> nonEmptyDsInfo

when:
databaseConnectionCB.accept(ctx, 'postgresql')

then:
arCtx.dbType == 'postgresql'
}

void 'process jdbc statement query object'() {
setup:
eventDispatcher.getDataSubscribers({ KnownAddresses.DB_SQL_QUERY in it }) >> nonEmptyDsInfo
DataBundle bundle
GatewayContext gatewayContext

when:
Flow<?> flow = databaseSqlQueryCB.apply(ctx, 'SELECT * FROM foo')

then:
1 * eventDispatcher.publishDataEvent(nonEmptyDsInfo, ctx.data, _ as DataBundle, _ as GatewayContext) >>
{ a, b, db, gw -> bundle = db; gatewayContext = gw; NoopFlow.INSTANCE }
bundle.get(KnownAddresses.DB_SQL_QUERY) == 'SELECT * FROM foo'
flow.result == null
flow.action == Flow.Action.Noop.INSTANCE
gatewayContext.isTransient == false
gatewayContext.isRasp == true
}

void 'calls trace segment post processor'() {
setup:
AgentSpan span = Stub()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class PowerWAFModuleSpecification extends DDSpecification {
}

AppSecRequestContext ctx = Spy()
GatewayContext gwCtx = new GatewayContext(false, false)
GatewayContext gwCtx = new GatewayContext(false)

StubAppSecConfigService service
PowerWAFModule pwafModule = new PowerWAFModule()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package datadog.trace.api.telemetry;

public enum RuleType {
LIF("lfi"),
SQL_INJECTION("sql_injection"),
SSRF("ssrf");

public final String name;
private static final int numValues = RuleType.values().length;

RuleType(String name) {
this.name = name;
}

public static int getNumValues() {
return numValues;
}

@Override
public String toString() {
return name;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicLongArray;

public class WafMetricCollector implements MetricCollector<WafMetricCollector.WafMetric> {

Expand All @@ -28,6 +29,12 @@ public static WafMetricCollector get() {
private static final AtomicRequestCounter wafRequestCounter = new AtomicRequestCounter();
private static final AtomicRequestCounter wafTriggeredRequestCounter = new AtomicRequestCounter();
private static final AtomicRequestCounter wafBlockedRequestCounter = new AtomicRequestCounter();
private static final AtomicLongArray raspRuleEvalCounter =
new AtomicLongArray(RuleType.getNumValues());
private static final AtomicLongArray raspRuleMatchCounter =
new AtomicLongArray(RuleType.getNumValues());
private static final AtomicLongArray respTimeoutCounter =
new AtomicLongArray(RuleType.getNumValues());

/** WAF version that will be initialized with wafInit and reused for all metrics. */
private static String wafVersion = "";
Expand Down Expand Up @@ -70,6 +77,18 @@ public void wafRequestBlocked() {
wafBlockedRequestCounter.increment();
}

public void raspRuleEval(final RuleType ruleType) {
raspRuleEvalCounter.incrementAndGet(ruleType.ordinal());
}

public void raspRuleMatch(final RuleType ruleType) {
raspRuleMatchCounter.incrementAndGet(ruleType.ordinal());
}

public void raspTimeout(final RuleType ruleType) {
respTimeoutCounter.incrementAndGet(ruleType.ordinal());
}

@Override
public Collection<WafMetric> drain() {
if (!rawMetricsQueue.isEmpty()) {
Expand Down Expand Up @@ -112,13 +131,48 @@ public void prepareMetrics() {

// Blocked requests
if (wafBlockedRequestCounter.get() > 0) {
rawMetricsQueue.offer(
if (!rawMetricsQueue.offer(
new WafRequestsRawMetric(
wafBlockedRequestCounter.getAndReset(),
WafMetricCollector.wafVersion,
WafMetricCollector.rulesVersion,
true,
true));
true))) {
return;
}
}

// RASP rule eval per rule type
for (RuleType ruleType : RuleType.values()) {
long counter = raspRuleEvalCounter.getAndSet(ruleType.ordinal(), 0);
if (counter > 0) {
if (!rawMetricsQueue.offer(
new RaspRuleEval(counter, ruleType, WafMetricCollector.wafVersion))) {
return;
}
}
}

// RASP rule match per rule type
for (RuleType ruleType : RuleType.values()) {
long counter = raspRuleMatchCounter.getAndSet(ruleType.ordinal(), 0);
if (counter > 0) {
if (!rawMetricsQueue.offer(
new RaspRuleMatch(counter, ruleType, WafMetricCollector.wafVersion))) {
return;
}
}
}

// RASP timeout per rule type
for (RuleType ruleType : RuleType.values()) {
long counter = respTimeoutCounter.getAndSet(ruleType.ordinal(), 0);
if (counter > 0) {
if (!rawMetricsQueue.offer(
new RaspTimeout(counter, ruleType, WafMetricCollector.wafVersion))) {
return;
}
}
}
}

Expand Down Expand Up @@ -165,6 +219,24 @@ public WafRequestsRawMetric(
}
}

public static class RaspRuleEval extends WafMetric {
public RaspRuleEval(final long counter, final RuleType ruleType, final String wafVersion) {
super("rasp.rule.eval", counter, "rule_type:" + ruleType, "waf_version:" + wafVersion);
}
}

public static class RaspRuleMatch extends WafMetric {
public RaspRuleMatch(final long counter, final RuleType ruleType, final String wafVersion) {
super("rasp.rule.match", counter, "rule_type:" + ruleType, "waf_version:" + wafVersion);
}
}

public static class RaspTimeout extends WafMetric {
public RaspTimeout(final long counter, final RuleType ruleType, final String wafVersion) {
super("rasp.timeout", counter, "rule_type:" + ruleType, "waf_version:" + wafVersion);
}
}

public static class AtomicRequestCounter {

private final AtomicLong atomicLong = new AtomicLong();
Expand Down
Loading

0 comments on commit 144efa8

Please sign in to comment.