Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix session rewriting false positives #7323

Merged
merged 8 commits into from
Jul 24, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package datadog.trace.instrumentation.servlet3;

import static datadog.trace.agent.tooling.bytebuddy.matcher.ClassLoaderMatchers.hasClassNamed;
import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.extendsClass;
import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface;
import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named;
import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.namedNoneOf;
import static net.bytebuddy.matcher.ElementMatchers.*;

import com.google.auto.service.AutoService;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.InstrumenterModule;
import datadog.trace.api.iast.InstrumentationBridge;
import datadog.trace.api.iast.Sink;
import datadog.trace.api.iast.VulnerabilityTypes;
import datadog.trace.api.iast.sink.ApplicationModule;
import datadog.trace.bootstrap.InstrumentationContext;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import javax.servlet.ServletContext;
import javax.servlet.SessionTrackingMode;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import net.bytebuddy.asm.Advice;
import net.bytebuddy.description.type.TypeDescription;
import net.bytebuddy.matcher.ElementMatcher;

@AutoService(InstrumenterModule.class)
public class IastOptOutHttpServletRequest3Instrumentation extends InstrumenterModule.Iast
implements Instrumenter.ForTypeHierarchy {

public IastOptOutHttpServletRequest3Instrumentation() {
super("servlet", "servlet-3");
}

@Override
public String muzzleDirective() {
return "servlet-3.x";
}

@Override
public ElementMatcher.Junction<ClassLoader> classLoaderMatcher() {
// Avoid matching request before servlet-3.x which don't have session tracking modes
return hasClassNamed("javax.servlet.SessionTrackingMode");
}

@Override
public String hierarchyMarkerType() {
return "javax.servlet.http.HttpServletRequest";
}

@Override
public ElementMatcher<TypeDescription> hierarchyMatcher() {
return implementsInterface(named(hierarchyMarkerType()))
// ignore wrappers that ship with servlet-api
.and(namedNoneOf("javax.servlet.http.HttpServletRequestWrapper"))
.and(not(extendsClass(named("javax.servlet.http.HttpServletRequestWrapper"))));
}

@Override
public void methodAdvice(MethodTransformer transformer) {
transformer.applyAdvice(
named("getSession").and(returns(named("javax.servlet.http.HttpSession"))).and(isPublic()),
getClass().getName() + "$GetHttpSessionAdvice");
}

@Override
public Map<String, String> contextStore() {
return Collections.singletonMap(
"javax.servlet.ServletContext", "javax.servlet.SessionTrackingMode");
}

@Override
protected boolean isOptOutEnabled() {
return true;
}

public static class GetHttpSessionAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Sink(VulnerabilityTypes.SESSION_REWRITING)
public static void onExit(
@Advice.This final HttpServletRequest request, @Advice.Return final HttpSession session) {
if (session == null) {
return;
}
final ApplicationModule module = InstrumentationBridge.APPLICATION;
if (module == null) {
return;
}
final ServletContext context = request.getServletContext();
if (InstrumentationContext.get(ServletContext.class, SessionTrackingMode.class).get(context)
!= null) {
return;
}
// We only want to report it once per application
InstrumentationContext.get(ServletContext.class, SessionTrackingMode.class)
.put(context, SessionTrackingMode.URL);
if (context.getEffectiveSessionTrackingModes() != null
&& !context.getEffectiveSessionTrackingModes().isEmpty()) {
Set<String> sessionTrackingModes = new HashSet<>();
for (SessionTrackingMode mode : context.getEffectiveSessionTrackingModes()) {
sessionTrackingModes.add(mode.name());
}
module.checkSessionTrackingModes(sessionTrackingModes);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
import datadog.trace.api.iast.VulnerabilityTypes;
import datadog.trace.api.iast.sink.ApplicationModule;
import datadog.trace.bootstrap.InstrumentationContext;
import java.util.HashSet;
import java.util.Set;
import javax.servlet.ServletContext;
import javax.servlet.ServletRequest;
import javax.servlet.SessionTrackingMode;
import javax.servlet.http.HttpServletRequest;
import net.bytebuddy.asm.Advice;

Expand All @@ -32,14 +29,6 @@ public static void onExit(@Advice.Argument(0) ServletRequest request) {
InstrumentationContext.get(ServletContext.class, Boolean.class).put(context, true);
if (applicationModule != null) {
applicationModule.onRealPath(context.getRealPath("/"));
if (context.getEffectiveSessionTrackingModes() != null
&& !context.getEffectiveSessionTrackingModes().isEmpty()) {
Set<String> sessionTrackingModes = new HashSet<>();
for (SessionTrackingMode mode : context.getEffectiveSessionTrackingModes()) {
sessionTrackingModes.add(mode.name());
}
applicationModule.checkSessionTrackingModes(sessionTrackingModes);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import groovy.servlet.AbstractHttpServlet
import org.eclipse.jetty.server.Request
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.handler.ErrorHandler
import org.eclipse.jetty.server.session.SessionHandler
import org.eclipse.jetty.servlet.ServletContextHandler
import javax.servlet.AsyncEvent
import javax.servlet.AsyncListener
Expand Down Expand Up @@ -52,6 +53,7 @@ abstract class JettyServlet3Test extends AbstractServlet3Test<Server, ServletCon
}

ServletContextHandler servletContext = new ServletContextHandler(null, "/$context", ServletContextHandler.SESSIONS)
servletContext.sessionHandler = new SessionHandler()
servletContext.errorHandler = new ErrorHandler() {
@Override
void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response) throws IOException {
Expand Down Expand Up @@ -522,6 +524,11 @@ class JettyServlet3ServeFromAsyncTimeout extends JettyServlet3Test {

class IastJettyServlet3ForkedTest extends JettyServlet3TestSync {

@Override
Class<Servlet> servlet() {
return TestServlet3.GetSession
}

@Override
void configurePreAgent() {
super.configurePreAgent()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,11 @@ class TomcatServlet3TestDispatchAsync extends TomcatServlet3Test {

class IastTomcatServlet3ForkedTest extends TomcatServlet3TestSync {

@Override
Class<Servlet> servlet() {
return TestServlet3.GetSession
}

@Override
void configurePreAgent() {
super.configurePreAgent()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,13 @@ class TestServlet3 {
}
}
}

@WebServlet
static class GetSession extends Sync {
@Override
void service(HttpServletRequest req, HttpServletResponse resp) {
req.getSession(true)
super.service(req,resp)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,9 @@
import datadog.trace.api.iast.sink.ApplicationModule;
import datadog.trace.bootstrap.InstrumentationContext;
import jakarta.servlet.ServletContext;
import jakarta.servlet.SessionTrackingMode;
import jakarta.servlet.http.HttpServlet;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import net.bytebuddy.asm.Advice;
import net.bytebuddy.description.type.TypeDescription;
import net.bytebuddy.matcher.ElementMatcher;
Expand Down Expand Up @@ -78,14 +75,6 @@ public static void after(@Advice.This final HttpServlet servlet) {
InstrumentationContext.get(ServletContext.class, Boolean.class).put(context, true);
if (applicationModule != null) {
applicationModule.onRealPath(context.getRealPath("/"));
if (context.getEffectiveSessionTrackingModes() != null
&& !context.getEffectiveSessionTrackingModes().isEmpty()) {
Set<String> sessionTrackingModes = new HashSet<>();
for (SessionTrackingMode mode : context.getEffectiveSessionTrackingModes()) {
sessionTrackingModes.add(mode.name());
}
applicationModule.checkSessionTrackingModes(sessionTrackingModes);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package datadog.trace.instrumentation.servlet5;

import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.extendsClass;
import static datadog.trace.agent.tooling.bytebuddy.matcher.HierarchyMatchers.implementsInterface;
import static datadog.trace.agent.tooling.bytebuddy.matcher.NameMatchers.named;
import static net.bytebuddy.matcher.ElementMatchers.*;

import com.google.auto.service.AutoService;
import datadog.trace.agent.tooling.Instrumenter;
import datadog.trace.agent.tooling.InstrumenterModule;
import datadog.trace.api.iast.*;
import datadog.trace.api.iast.sink.ApplicationModule;
import datadog.trace.bootstrap.InstrumentationContext;
import jakarta.servlet.ServletContext;
import jakarta.servlet.SessionTrackingMode;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpSession;
import java.util.*;
import net.bytebuddy.asm.Advice;
import net.bytebuddy.description.type.TypeDescription;
import net.bytebuddy.matcher.ElementMatcher;

@SuppressWarnings("unused")
@AutoService(InstrumenterModule.class)
public class IastOptOutJakartaHttpServletRequestInstrumentation extends InstrumenterModule.Iast
implements Instrumenter.ForTypeHierarchy {

private static final String CLASS_NAME =
IastOptOutJakartaHttpServletRequestInstrumentation.class.getName();
private static final ElementMatcher.Junction<? super TypeDescription> WRAPPER_CLASS =
named("jakarta.servlet.http.HttpServletRequestWrapper");

public IastOptOutJakartaHttpServletRequestInstrumentation() {
super("servlet", "servlet-5", "servlet-request");
}

@Override
public String hierarchyMarkerType() {
return "jakarta.servlet.http.HttpServletRequest";
}

@Override
public ElementMatcher<TypeDescription> hierarchyMatcher() {
return implementsInterface(named(hierarchyMarkerType()))
.and(not(WRAPPER_CLASS))
.and(not(extendsClass(WRAPPER_CLASS)));
}

@Override
protected boolean isOptOutEnabled() {
Copy link
Member Author

@jandro996 jandro996 Jul 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although we already have a JakartaHttpServletRequestInstrumentation for IAST we need another one due to session rewriting is an opt-out feature

return true;
}

@Override
public void methodAdvice(MethodTransformer transformer) {
transformer.applyAdvice(
named("getSession").and(returns(named("jakarta.servlet.http.HttpSession"))).and(isPublic()),
CLASS_NAME + "$GetHttpSessionAdvice");
}

@Override
public Map<String, String> contextStore() {
return Collections.singletonMap(
"jakarta.servlet.ServletContext", "jakarta.servlet.SessionTrackingMode");
}

public static class GetHttpSessionAdvice {
@Advice.OnMethodExit(suppress = Throwable.class)
@Sink(VulnerabilityTypes.SESSION_REWRITING)
public static void onExit(
@Advice.This final HttpServletRequest request, @Advice.Return final HttpSession session) {
if (session == null) {
return;
}
final ApplicationModule module = InstrumentationBridge.APPLICATION;
if (module == null) {
return;
}
final ServletContext context = request.getServletContext();

if (InstrumentationContext.get(ServletContext.class, SessionTrackingMode.class).get(context)
!= null) {
return;
}
// We only want to report it once per application
InstrumentationContext.get(ServletContext.class, SessionTrackingMode.class)
.put(context, SessionTrackingMode.URL);
if (context.getEffectiveSessionTrackingModes() != null
&& !context.getEffectiveSessionTrackingModes().isEmpty()) {
Set<String> sessionTrackingModes = new HashSet<>();
for (SessionTrackingMode mode : context.getEffectiveSessionTrackingModes()) {
sessionTrackingModes.add(mode.name());
}
module.checkSessionTrackingModes(sessionTrackingModes);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import jakarta.servlet.Servlet
import jakarta.servlet.ServletRequest
import jakarta.servlet.ServletResponse



class IastJakartaServletInstrumentationTest extends AgentTestRunner{

@Override
Expand Down Expand Up @@ -45,7 +43,6 @@ class IastJakartaServletInstrumentationTest extends AgentTestRunner{

then:
1 * module.onRealPath(_)
1 * module.checkSessionTrackingModes(['COOKIE', 'URL'] as Set<String>)
0 * _
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@ import datadog.trace.api.iast.IastContext
import datadog.trace.api.iast.InstrumentationBridge
import datadog.trace.api.iast.SourceTypes
import datadog.trace.api.iast.propagation.PropagationModule
import datadog.trace.api.iast.sink.ApplicationModule
import datadog.trace.api.iast.sink.UnvalidatedRedirectModule
import datadog.trace.bootstrap.instrumentation.api.AgentTracer
import datadog.trace.bootstrap.instrumentation.api.TagContext
import foo.bar.smoketest.JakartaHttpServletRequestTestSuite
import foo.bar.smoketest.JakartaHttpServletRequestWrapperTestSuite
import foo.bar.smoketest.ServletRequestTestSuite
import jakarta.servlet.RequestDispatcher
import jakarta.servlet.ServletContext
import jakarta.servlet.ServletInputStream
import jakarta.servlet.SessionTrackingMode
import jakarta.servlet.http.Cookie
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletRequestWrapper

import datadog.trace.agent.tooling.iast.TaintableEnumeration
import jakarta.servlet.http.HttpSession

class JakartaHttpServletRequestInstrumentationTest extends AgentTestRunner {

Expand Down Expand Up @@ -431,6 +435,31 @@ class JakartaHttpServletRequestInstrumentationTest extends AgentTestRunner {
suite << testSuiteCallSites()
}

void 'test getSession'() {
setup:
final iastModule = Mock(ApplicationModule)
InstrumentationBridge.registerIastModule(iastModule)
final session = Mock(HttpSession)
final servletContext = Mock(ServletContext) {
getEffectiveSessionTrackingModes() >> new HashSet<SessionTrackingMode>(Arrays.asList(SessionTrackingMode.URL))
}
final mock = Mock(HttpServletRequest)
final request = suite.call(mock)

when:
final result = runUnderIastTrace { request.getSession() }

then:
result == session
1 * mock.getSession() >> session
1 * mock.getServletContext() >> servletContext
1 * iastModule.checkSessionTrackingModes(_)
0 * iastModule._

where:
suite << testSuite()
}

protected <E> E runUnderIastTrace(Closure<E> cl) {
final ddctx = new TagContext().withRequestContextDataIast(iastCtx)
final span = TEST_TRACER.startSpan("test", "test-iast-span", ddctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ private VulnerabilityTypes() {}
SESSION_TIMEOUT,
DIRECTORY_LISTING_LEAK,
INSECURE_JSP_LAYOUT,
SESSION_REWRITING,
DEFAULT_APP_DEPLOYED,
};

Expand Down
Loading