From 6b0b682da2cb0feabb54489731472a05af1d4b1b Mon Sep 17 00:00:00 2001 From: Craig Perkins Date: Fri, 6 Oct 2023 18:48:59 -0400 Subject: [PATCH] Add early rejection from RestHandler for unauthorized requests (#3418) Previously unauthorized requests were fully processed and rejected once they reached the RestHandler. This allocations more memory and resources for these requests that might not be useful if they are already detected as unauthorized. Using the headerVerifer and decompressor customization from [1], perform an early authorization check when only the headers are available, save an 'early response' for transmission and do not perform the decompression on the request to speed up closing out the connection. - Resolves https://github.com/opensearch-project/OpenSearch/issues/10260 Signed-off-by: Peter Nied Signed-off-by: Craig Perkins Signed-off-by: Craig Perkins Co-authored-by: Peter Nied --- .github/workflows/ci.yml | 27 ++ build.gradle | 21 +- .../security/ResourceFocusedTests.java | 257 ++++++++++++++++++ .../framework/cluster/TestRestClient.java | 2 +- .../auth/http/saml/HTTPSamlAuthenticator.java | 2 +- .../security/OpenSearchSecurityPlugin.java | 7 +- .../security/auditlog/sink/WebhookSink.java | 3 +- .../security/auth/BackendRegistry.java | 11 +- .../security/filter/NettyAttribute.java | 49 ++++ .../security/filter/NettyRequest.java | 104 +++++++ .../security/filter/NettyRequestChannel.java | 54 ++++ .../filter/OpenSearchRequestChannel.java | 45 --- .../filter/SecurityRequestChannel.java | 7 +- .../filter/SecurityRequestFactory.java | 7 + .../security/filter/SecurityResponse.java | 11 + .../security/filter/SecurityRestFilter.java | 47 +++- .../security/filter/SecurityRestUtils.java | 12 + .../http/SecurityHttpServerTransport.java | 17 +- .../SecurityNonSslHttpServerTransport.java | 26 +- .../opensearch/security/http/XFFResolver.java | 13 +- .../ssl/OpenSearchSecuritySSLPlugin.java | 7 +- .../netty/Netty4ConditionalDecompressor.java | 37 +++ .../Netty4HttpRequestHeaderVerifier.java | 139 ++++++++++ .../SecuritySSLNettyHttpServerTransport.java | 19 +- .../security/SystemIntegratorsTests.java | 14 +- .../integration/BasicAuditlogTest.java | 5 +- .../filter/SecurityRestFilterUnitTests.java | 111 ++++++++ .../helper/cluster/ClusterConfiguration.java | 7 +- .../security/test/helper/rest/RestHelper.java | 2 +- .../test/plugin/UserInjectorPlugin.java | 159 ----------- 30 files changed, 947 insertions(+), 275 deletions(-) create mode 100644 src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java create mode 100644 src/main/java/org/opensearch/security/filter/NettyAttribute.java create mode 100644 src/main/java/org/opensearch/security/filter/NettyRequest.java create mode 100644 src/main/java/org/opensearch/security/filter/NettyRequestChannel.java create mode 100644 src/main/java/org/opensearch/security/filter/SecurityRestUtils.java create mode 100644 src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java create mode 100644 src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java create mode 100644 src/test/java/org/opensearch/security/filter/SecurityRestFilterUnitTests.java delete mode 100644 src/test/java/org/opensearch/security/test/plugin/UserInjectorPlugin.java diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e58ff4b4f..fbafa851e2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,6 +10,7 @@ on: env: GRADLE_OPTS: -Dhttp.keepAlive=false + CI_ENVIRONMENT: normal jobs: generate-test-list: @@ -108,6 +109,32 @@ jobs: arguments: | integrationTest -Dbuild.snapshot=false + resource-tests: + env: + CI_ENVIRONMENT: resource-test + strategy: + fail-fast: false + matrix: + jdk: [17] + platform: [ubuntu-latest] + runs-on: ${{ matrix.platform }} + + steps: + - name: Set up JDK for build and test + uses: actions/setup-java@v3 + with: + distribution: temurin # Temurin is a distribution of adoptium + java-version: ${{ matrix.jdk }} + + - name: Checkout security + uses: actions/checkout@v4 + + - name: Build and Test + uses: gradle/gradle-build-action@v2 + with: + cache-disabled: true + arguments: | + integrationTest -Dbuild.snapshot=false --tests org.opensearch.security.ResourceFocusedTests backward-compatibility-build: runs-on: ubuntu-latest steps: diff --git a/build.gradle b/build.gradle index 12d67b9e0d..0574d5b3b4 100644 --- a/build.gradle +++ b/build.gradle @@ -459,16 +459,27 @@ sourceSets { //add new task that runs integration tests task integrationTest(type: Test) { + doFirst { + // Only run resources tests on resource-test CI environments or locally + if (System.getenv('CI_ENVIRONMENT') == 'resource-test' || System.getenv('CI_ENVIRONMENT') == null) { + include '**/ResourceFocusedTests.class' + } else { + exclude '**/ResourceFocusedTests.class' + } + // Only run with retries while in CI systems + if (System.getenv('CI_ENVIRONMENT') == 'normal') { + retry { + failOnPassedAfterRetry = false + maxRetries = 2 + maxFailures = 10 + } + } + } description = 'Run integration tests.' group = 'verification' systemProperty "java.util.logging.manager", "org.apache.logging.log4j.jul.LogManager" testClassesDirs = sourceSets.integrationTest.output.classesDirs classpath = sourceSets.integrationTest.runtimeClasspath - retry { - failOnPassedAfterRetry = false - maxRetries = 2 - maxFailures = 10 - } //run the integrationTest task after the test task shouldRunAfter test } diff --git a/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java b/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java new file mode 100644 index 0000000000..2ab19d15ce --- /dev/null +++ b/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java @@ -0,0 +1,257 @@ +package org.opensearch.security; + +import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; +import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL; +import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.management.GarbageCollectorMXBean; +import java.lang.management.ManagementFactory; +import java.lang.management.MemoryPoolMXBean; +import java.lang.management.MemoryUsage; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.zip.GZIPOutputStream; + +import org.apache.hc.client5.http.classic.methods.HttpPost; +import org.apache.hc.core5.http.ContentType; +import org.apache.hc.core5.http.io.entity.ByteArrayEntity; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.client.Client; +import org.opensearch.test.framework.TestSecurityConfig; +import org.opensearch.test.framework.TestSecurityConfig.User; +import org.opensearch.test.framework.cluster.ClusterManager; +import org.opensearch.test.framework.cluster.LocalCluster; +import org.opensearch.test.framework.cluster.TestRestClient; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class ResourceFocusedTests { + private static final User ADMIN_USER = new User("admin").roles(ALL_ACCESS); + private static final User LIMITED_USER = new User("limited_user").roles( + new TestSecurityConfig.Role("limited-role").clusterPermissions( + "indices:data/read/mget", + "indices:data/read/msearch", + "indices:data/read/scroll", + "cluster:monitor/state", + "cluster:monitor/health" + ) + .indexPermissions( + "indices:data/read/search", + "indices:data/read/mget*", + "indices:monitor/settings/get", + "indices:monitor/stats" + ) + .on("*") + ); + + @ClassRule + public static LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.THREE_CLUSTER_MANAGERS) + .authc(AUTHC_HTTPBASIC_INTERNAL) + .users(ADMIN_USER, LIMITED_USER) + .anonymousAuth(false) + .doNotFailOnForbidden(true) + .build(); + + @BeforeClass + public static void createTestData() { + try (Client client = cluster.getInternalNodeClient()) { + client.index(new IndexRequest().setRefreshPolicy(IMMEDIATE).index("document").source(Map.of("foo", "bar", "abc", "xyz"))) + .actionGet(); + } + } + + @Test + public void testUnauthenticatedFewBig() { + // Tweaks: + final RequestBodySize size = RequestBodySize.XLarge; + final String requestPath = "/*/_search"; + final int parrallelism = 5; + final int totalNumberOfRequests = 100; + final boolean statsPrinter = false; + + runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter); + } + + @Test + public void testUnauthenticatedManyMedium() { + // Tweaks: + final RequestBodySize size = RequestBodySize.Medium; + final String requestPath = "/*/_search"; + final int parrallelism = 20; + final int totalNumberOfRequests = 10_000; + final boolean statsPrinter = false; + + runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter); + } + + @Test + public void testUnauthenticatedTonsSmall() { + // Tweaks: + final RequestBodySize size = RequestBodySize.Small; + final String requestPath = "/*/_search"; + final int parrallelism = 100; + final int totalNumberOfRequests = 1_000_000; + final boolean statsPrinter = false; + + runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter); + } + + private Long runResourceTest( + final RequestBodySize size, + final String requestPath, + final int parrallelism, + final int totalNumberOfRequests, + final boolean statsPrinter + ) { + final byte[] compressedRequestBody = createCompressedRequestBody(size); + try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) { + + if (statsPrinter) { + printStats(); + } + final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); + post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); + + final ForkJoinPool forkJoinPool = new ForkJoinPool(parrallelism); + + final List> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests) + .boxed() + .map(i -> CompletableFuture.runAsync(() -> client.executeRequest(post), forkJoinPool)) + .collect(Collectors.toList()); + Supplier getCount = () -> waitingOn.stream().filter(cf -> cf.isDone() && !cf.isCompletedExceptionally()).count(); + + CompletableFuture statPrinter = statsPrinter ? CompletableFuture.runAsync(() -> { + while (true) { + printStats(); + System.out.println(" & Succesful completions: " + getCount.get()); + try { + Thread.sleep(500); + } catch (Exception e) { + break; + } + } + }, forkJoinPool) : CompletableFuture.completedFuture(null); + + final CompletableFuture allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0])); + + try { + allOfThem.get(30, TimeUnit.SECONDS); + statPrinter.cancel(true); + } catch (final Exception e) { + // Ignored + } + + if (statsPrinter) { + printStats(); + System.out.println(" & Succesful completions: " + getCount.get()); + } + return getCount.get(); + } + } + + static enum RequestBodySize { + Small(1), + Medium(1_000), + XLarge(1_000_000); + + public final int elementCount; + + private RequestBodySize(final int elementCount) { + this.elementCount = elementCount; + } + } + + private byte[] createCompressedRequestBody(final RequestBodySize size) { + final int repeatCount = size.elementCount; + final String prefix = "{ \"items\": ["; + final String repeatedElement = IntStream.range(0, 20) + .mapToObj(n -> ('a' + n) + "") + .map(n -> '"' + n + '"' + ": 123") + .collect(Collectors.joining(",", "{", "}")); + final String postfix = "]}"; + long uncompressedBytesSize = 0; + + try ( + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + final GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream) + ) { + + final byte[] prefixBytes = prefix.getBytes(StandardCharsets.UTF_8); + final byte[] repeatedElementBytes = repeatedElement.getBytes(StandardCharsets.UTF_8); + final byte[] postfixBytes = postfix.getBytes(StandardCharsets.UTF_8); + + gzipOutputStream.write(prefixBytes); + uncompressedBytesSize = uncompressedBytesSize + prefixBytes.length; + for (int i = 0; i < repeatCount; i++) { + gzipOutputStream.write(repeatedElementBytes); + uncompressedBytesSize = uncompressedBytesSize + repeatedElementBytes.length; + } + gzipOutputStream.write(postfixBytes); + uncompressedBytesSize = uncompressedBytesSize + postfixBytes.length; + gzipOutputStream.finish(); + + final byte[] compressedRequestBody = byteArrayOutputStream.toByteArray(); + System.out.println( + "^^^" + + String.format( + "Original size was %,d bytes, compressed to %,d bytes, ratio %,.2f", + uncompressedBytesSize, + compressedRequestBody.length, + ((double) uncompressedBytesSize / compressedRequestBody.length) + ) + ); + return compressedRequestBody; + } catch (final IOException ioe) { + throw new RuntimeException(ioe); + } + } + + private void printStats() { + System.out.println("** Stats "); + printMemory(); + printMemoryPools(); + printGCPools(); + } + + private void printMemory() { + final Runtime runtime = Runtime.getRuntime(); + + final long totalMemory = runtime.totalMemory(); // Total allocated memory + final long freeMemory = runtime.freeMemory(); // Amount of free memory + final long usedMemory = totalMemory - freeMemory; // Amount of used memory + + System.out.println(" Memory Total: " + totalMemory + " Free:" + freeMemory + " Used:" + usedMemory); + } + + private void printMemoryPools() { + List memoryPools = ManagementFactory.getMemoryPoolMXBeans(); + for (MemoryPoolMXBean memoryPool : memoryPools) { + MemoryUsage usage = memoryPool.getUsage(); + System.out.println(" " + memoryPool.getName() + " USED: " + usage.getUsed() + " MAX: " + usage.getMax()); + } + } + + private void printGCPools() { + List garbageCollectors = ManagementFactory.getGarbageCollectorMXBeans(); + for (GarbageCollectorMXBean garbageCollector : garbageCollectors) { + System.out.println(" " + garbageCollector.getName() + " COLLECTION TIME: " + garbageCollector.getCollectionTime()); + } + } + +} diff --git a/src/integrationTest/java/org/opensearch/test/framework/cluster/TestRestClient.java b/src/integrationTest/java/org/opensearch/test/framework/cluster/TestRestClient.java index eff2a1db9c..f494f00796 100644 --- a/src/integrationTest/java/org/opensearch/test/framework/cluster/TestRestClient.java +++ b/src/integrationTest/java/org/opensearch/test/framework/cluster/TestRestClient.java @@ -268,7 +268,7 @@ public void createRoleMapping(String backendRoleName, String roleName) { response.assertStatusCode(201); } - protected final String getHttpServerUri() { + public final String getHttpServerUri() { return "http" + (enableHTTPClientSSL ? "s" : "") + "://" + nodeHttpAddress.getHostString() + ":" + nodeHttpAddress.getPort(); } diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index a3f37ba46e..846573289f 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -79,7 +79,7 @@ public class HTTPSamlAuthenticator implements HTTPAuthenticator, Destroyable { public static final String IDP_METADATA_FILE = "idp.metadata_file"; public static final String IDP_METADATA_CONTENT = "idp.metadata_content"; - private static final String API_AUTHTOKEN_SUFFIX = "api/authtoken"; + public static final String API_AUTHTOKEN_SUFFIX = "api/authtoken"; private static final String AUTHINFO_SUFFIX = "authinfo"; private static final String REGEX_PATH_PREFIX = "/(" + LEGACY_OPENDISTRO_PREFIX + "|" + PLUGINS_PREFIX + ")/" + "(.*)"; private static final Pattern PATTERN_PATH_PREFIX = Pattern.compile(REGEX_PATH_PREFIX); diff --git a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java index fd9590e04d..7cf2d6d70f 100644 --- a/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java +++ b/src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java @@ -220,7 +220,6 @@ public final class OpenSearchSecurityPlugin extends OpenSearchSecuritySSLPlugin public static final String PLUGINS_PREFIX = "_plugins/_security"; private boolean sslCertReloadEnabled; - private volatile SecurityRestFilter securityRestHandler; private volatile SecurityInterceptor si; private volatile PrivilegesEvaluator evaluator; private volatile UserService userService; @@ -914,7 +913,8 @@ public Map> getHttpTransports( validatingDispatcher, clusterSettings, sharedGroupFactory, - tracer + tracer, + securityRestHandler ); return Collections.singletonMap("org.opensearch.security.http.SecurityHttpServerTransport", () -> odshst); @@ -930,7 +930,8 @@ public Map> getHttpTransports( dispatcher, clusterSettings, sharedGroupFactory, - tracer + tracer, + securityRestHandler ) ); } diff --git a/src/main/java/org/opensearch/security/auditlog/sink/WebhookSink.java b/src/main/java/org/opensearch/security/auditlog/sink/WebhookSink.java index 78780fb8e1..8616fa9df5 100644 --- a/src/main/java/org/opensearch/security/auditlog/sink/WebhookSink.java +++ b/src/main/java/org/opensearch/security/auditlog/sink/WebhookSink.java @@ -37,12 +37,11 @@ import org.apache.hc.client5.http.ssl.NoopHostnameVerifier; import org.apache.hc.client5.http.ssl.SSLConnectionSocketFactory; import org.apache.hc.core5.http.ContentType; -import org.apache.hc.core5.http.HttpStatus; import org.apache.hc.core5.http.io.SocketConfig; import org.apache.hc.core5.http.io.entity.StringEntity; import org.apache.hc.core5.ssl.SSLContextBuilder; import org.apache.hc.core5.ssl.TrustStrategy; - +import org.apache.http.HttpStatus; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.Strings; import org.opensearch.security.auditlog.impl.AuditMessage; diff --git a/src/main/java/org/opensearch/security/auth/BackendRegistry.java b/src/main/java/org/opensearch/security/auth/BackendRegistry.java index 88e88c3c52..f4abb5361a 100644 --- a/src/main/java/org/opensearch/security/auth/BackendRegistry.java +++ b/src/main/java/org/opensearch/security/auth/BackendRegistry.java @@ -51,6 +51,7 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.rest.RestStatus; import org.opensearch.security.auditlog.AuditLog; @@ -145,11 +146,14 @@ public BackendRegistry( this.auditLog = auditLog; this.threadPool = threadPool; this.userInjector = new UserInjector(settings, threadPool, auditLog, xffResolver); + this.restAuthDomains = Collections.emptySortedSet(); + this.ipAuthFailureListeners = Collections.emptyList(); this.ttlInMin = settings.getAsInt(ConfigConstants.SECURITY_CACHE_TTL_MINUTES, 60); // This is going to be defined in the opensearch.yml, so it's best suited to be initialized once. this.injectedUserEnabled = opensearchSettings.getAsBoolean(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, false); + initialized = this.injectedUserEnabled; createCaches(); } @@ -186,7 +190,6 @@ public void onDynamicConfigModelChanged(DynamicConfigModel dcm) { /** * * @param request - * @param channel * @return The authenticated user, null means another roundtrip * @throws OpenSearchSecurityException */ @@ -205,11 +208,13 @@ public boolean authenticate(final SecurityRequestChannel request) { return false; } - final String sslPrincipal = (String) threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_PRINCIPAL); + ThreadContext threadContext = this.threadPool.getThreadContext(); + + final String sslPrincipal = (String) threadContext.getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_PRINCIPAL); if (adminDns.isAdminDN(sslPrincipal)) { // PKI authenticated REST call - threadPool.getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, new User(sslPrincipal)); + threadContext.putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, new User(sslPrincipal)); auditLog.logSucceededLogin(sslPrincipal, true, null, request); return true; } diff --git a/src/main/java/org/opensearch/security/filter/NettyAttribute.java b/src/main/java/org/opensearch/security/filter/NettyAttribute.java new file mode 100644 index 0000000000..685e94e199 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/NettyAttribute.java @@ -0,0 +1,49 @@ +package org.opensearch.security.filter; + +import java.util.Optional; + +import org.opensearch.http.netty4.Netty4HttpChannel; +import org.opensearch.rest.RestRequest; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.AttributeKey; + +public class NettyAttribute { + + /** + * Gets an attribute value from the request context and clears it from that context + */ + public static Optional popFrom(final RestRequest request, final AttributeKey attribute) { + if (request.getHttpChannel() instanceof Netty4HttpChannel) { + Channel nettyChannel = ((Netty4HttpChannel) request.getHttpChannel()).getNettyChannel(); + return Optional.ofNullable(nettyChannel.attr(attribute).getAndSet(null)); + } + return Optional.empty(); + } + + /** + * Gets an attribute value from the channel handler context and clears it from that context + */ + public static Optional popFrom(final ChannelHandlerContext ctx, final AttributeKey attribute) { + return Optional.ofNullable(ctx.channel().attr(attribute).getAndSet(null)); + } + + /** + * Gets an attribute value from the channel handler context + */ + public static Optional peekFrom(final ChannelHandlerContext ctx, final AttributeKey attribute) { + return Optional.ofNullable(ctx.channel().attr(attribute).get()); + } + + /** + * Clears an attribute value from the channel handler context + */ + public static void clearAttribute(final RestRequest request, final AttributeKey attribute) { + if (request.getHttpChannel() instanceof Netty4HttpChannel) { + Channel nettyChannel = ((Netty4HttpChannel) request.getHttpChannel()).getNettyChannel(); + nettyChannel.attr(attribute).set(null); + } + } + +} diff --git a/src/main/java/org/opensearch/security/filter/NettyRequest.java b/src/main/java/org/opensearch/security/filter/NettyRequest.java new file mode 100644 index 0000000000..f3f2827367 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/NettyRequest.java @@ -0,0 +1,104 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.TreeMap; + +import javax.net.ssl.SSLEngine; + +import io.netty.handler.ssl.SslHandler; +import org.opensearch.http.netty4.Netty4HttpChannel; +import org.opensearch.rest.RestRequest.Method; + +import io.netty.handler.codec.http.HttpRequest; +import org.opensearch.rest.RestUtils; + +/** + * Wraps the functionality of HttpRequest for use in the security plugin + */ +public class NettyRequest implements SecurityRequest { + + protected final HttpRequest underlyingRequest; + protected final Netty4HttpChannel underlyingChannel; + + NettyRequest(final HttpRequest request, final Netty4HttpChannel channel) { + this.underlyingRequest = request; + this.underlyingChannel = channel; + } + + @Override + public Map> getHeaders() { + final Map> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + underlyingRequest.headers().forEach(h -> headers.put(h.getKey(), List.of(h.getValue()))); + return headers; + } + + @Override + public SSLEngine getSSLEngine() { + // We look for Ssl_handler called `ssl_http` in the outbound pipeline of Netty channel first, and if its not + // present we look for it in inbound channel. If its present in neither we return null, else we return the sslHandler. + SslHandler sslhandler = (SslHandler) underlyingChannel.getNettyChannel().pipeline().get("ssl_http"); + if (sslhandler == null && underlyingChannel.inboundPipeline() != null) { + sslhandler = (SslHandler) underlyingChannel.inboundPipeline().get("ssl_http"); + } + + return sslhandler != null ? sslhandler.engine() : null; + } + + @Override + public String path() { + String rawPath = SecurityRestUtils.path(underlyingRequest.uri()); + return RestUtils.decodeComponent(rawPath); + } + + @Override + public Method method() { + return Method.valueOf(underlyingRequest.method().name()); + } + + @Override + public Optional getRemoteAddress() { + return Optional.ofNullable(this.underlyingChannel.getRemoteAddress()); + } + + @Override + public String uri() { + return underlyingRequest.uri(); + } + + @Override + public Map params() { + return params(underlyingRequest.uri()); + } + + private static Map params(String uri) { + // Sourced from + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/http/AbstractHttpServerTransport.java#L419-L422 + final Map params = new HashMap<>(); + final int index = uri.indexOf(63); + if (index >= 0) { + try { + RestUtils.decodeQueryString(uri, index + 1, params); + } catch (IllegalArgumentException var4) { + return Collections.emptyMap(); + } + } + + return params; + } +} diff --git a/src/main/java/org/opensearch/security/filter/NettyRequestChannel.java b/src/main/java/org/opensearch/security/filter/NettyRequestChannel.java new file mode 100644 index 0000000000..a83ecdea8a --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/NettyRequestChannel.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import io.netty.handler.codec.http.HttpRequest; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.http.netty4.Netty4HttpChannel; + +public class NettyRequestChannel extends NettyRequest implements SecurityRequestChannel { + private final Logger log = LogManager.getLogger(NettyRequestChannel.class); + + private AtomicBoolean hasCompleted = new AtomicBoolean(false); + private final AtomicReference responseRef = new AtomicReference(null); + + NettyRequestChannel(final HttpRequest request, Netty4HttpChannel channel) { + super(request, channel); + } + + @Override + public void queueForSending(SecurityResponse response) { + if (underlyingChannel == null) { + throw new UnsupportedOperationException("Channel was not defined"); + } + + if (hasCompleted.get()) { + throw new UnsupportedOperationException("This channel has already completed"); + } + + if (getQueuedResponse().isPresent()) { + throw new UnsupportedOperationException("Another response was already queued"); + } + + responseRef.set(response); + } + + @Override + public Optional getQueuedResponse() { + return Optional.ofNullable(responseRef.get()); + } +} diff --git a/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java b/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java index 45035e0d83..24c90488cb 100644 --- a/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java +++ b/src/main/java/org/opensearch/security/filter/OpenSearchRequestChannel.java @@ -12,22 +12,14 @@ package org.opensearch.security.filter; import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; public class OpenSearchRequestChannel extends OpenSearchRequest implements SecurityRequestChannel { - private final Logger log = LogManager.getLogger(OpenSearchRequest.class); - private final AtomicReference responseRef = new AtomicReference(null); - private final AtomicBoolean hasCompleted = new AtomicBoolean(false); private final RestChannel underlyingChannel; OpenSearchRequestChannel(final RestRequest request, final RestChannel channel) { @@ -46,10 +38,6 @@ public void queueForSending(final SecurityResponse response) { throw new UnsupportedOperationException("Channel was not defined"); } - if (hasCompleted.get()) { - throw new UnsupportedOperationException("This channel has already completed"); - } - if (getQueuedResponse().isPresent()) { throw new UnsupportedOperationException("Another response was already queued"); } @@ -61,37 +49,4 @@ public void queueForSending(final SecurityResponse response) { public Optional getQueuedResponse() { return Optional.ofNullable(responseRef.get()); } - - @Override - public boolean sendResponse() { - if (underlyingChannel == null) { - throw new UnsupportedOperationException("Channel was not defined"); - } - - if (hasCompleted.get()) { - throw new UnsupportedOperationException("This channel has already completed"); - } - - if (getQueuedResponse().isEmpty()) { - throw new UnsupportedOperationException("No response has been associated with this channel"); - } - - final SecurityResponse response = responseRef.get(); - - try { - final BytesRestResponse restResponse = new BytesRestResponse(RestStatus.fromCode(response.getStatus()), response.getBody()); - if (response.getHeaders() != null) { - response.getHeaders().forEach(restResponse::addHeader); - } - underlyingChannel.sendResponse(restResponse); - - return true; - } catch (final Exception e) { - log.error("Error when attempting to send response", e); - throw new RuntimeException(e); - } finally { - hasCompleted.set(true); - } - - } } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java b/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java index 1eec754c08..66744d01dd 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestChannel.java @@ -19,11 +19,8 @@ public interface SecurityRequestChannel extends SecurityRequest { /** Associate a response with this channel */ - public void queueForSending(final SecurityResponse response); + void queueForSending(final SecurityResponse response); /** Acess the queued response */ - public Optional getQueuedResponse(); - - /** Send the response through the channel */ - public boolean sendResponse(); + Optional getQueuedResponse(); } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java b/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java index de74df01ff..0b64d0220d 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestFactory.java @@ -11,6 +11,8 @@ package org.opensearch.security.filter; +import io.netty.handler.codec.http.HttpRequest; +import org.opensearch.http.netty4.Netty4HttpChannel; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; @@ -24,6 +26,11 @@ public static SecurityRequest from(final RestRequest request) { return new OpenSearchRequest(request); } + /** Creates a security request from a netty HttpRequest object */ + public static SecurityRequestChannel from(HttpRequest request, Netty4HttpChannel channel) { + return new NettyRequestChannel(request, channel); + } + /** Creates a security request channel from a RestRequest & RestChannel */ public static SecurityRequestChannel from(final RestRequest request, final RestChannel channel) { return new OpenSearchRequestChannel(request, channel); diff --git a/src/main/java/org/opensearch/security/filter/SecurityResponse.java b/src/main/java/org/opensearch/security/filter/SecurityResponse.java index 8618be3aab..be80a5d4d4 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityResponse.java +++ b/src/main/java/org/opensearch/security/filter/SecurityResponse.java @@ -14,6 +14,9 @@ import java.util.Map; import org.apache.http.HttpHeaders; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestResponse; public class SecurityResponse { @@ -41,4 +44,12 @@ public String getBody() { return body; } + public RestResponse asRestResponse() { + final RestResponse restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), getBody()); + if (getHeaders() != null) { + getHeaders().forEach(restResponse::addHeader); + } + return restResponse; + } + } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java index ce80a9e143..87fe7f5323 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java @@ -63,10 +63,14 @@ import org.opensearch.security.support.ConfigConstants; import org.opensearch.security.support.HTTPHelper; import org.opensearch.security.user.User; +import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import static org.opensearch.security.OpenSearchSecurityPlugin.LEGACY_OPENDISTRO_PREFIX; import static org.opensearch.security.OpenSearchSecurityPlugin.PLUGINS_PREFIX; +import static org.opensearch.security.http.SecurityHttpServerTransport.CONTEXT_TO_RESTORE; +import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE; +import static org.opensearch.security.http.SecurityHttpServerTransport.IS_AUTHENTICATED; public class SecurityRestFilter { @@ -83,11 +87,11 @@ public class SecurityRestFilter { private WhitelistingSettings whitelistingSettings; private AllowlistingSettings allowlistingSettings; - private static final String HEALTH_SUFFIX = "health"; - private static final String WHO_AM_I_SUFFIX = "whoami"; + public static final String HEALTH_SUFFIX = "health"; + public static final String WHO_AM_I_SUFFIX = "whoami"; - private static final String REGEX_PATH_PREFIX = "/(" + LEGACY_OPENDISTRO_PREFIX + "|" + PLUGINS_PREFIX + ")/" + "(.*)"; - private static final Pattern PATTERN_PATH_PREFIX = Pattern.compile(REGEX_PATH_PREFIX); + public static final String REGEX_PATH_PREFIX = "/(" + LEGACY_OPENDISTRO_PREFIX + "|" + PLUGINS_PREFIX + ")/" + "(.*)"; + public static final Pattern PATTERN_PATH_PREFIX = Pattern.compile(REGEX_PATH_PREFIX); public SecurityRestFilter( final BackendRegistry registry, @@ -127,13 +131,33 @@ public SecurityRestFilter( */ public RestHandler wrap(RestHandler original, AdminDNs adminDNs) { return (request, channel, client) -> { - org.apache.logging.log4j.ThreadContext.clearAll(); + + final Optional maybeSavedResponse = NettyAttribute.popFrom(request, EARLY_RESPONSE); + if (maybeSavedResponse.isPresent()) { + NettyAttribute.clearAttribute(request, CONTEXT_TO_RESTORE); + NettyAttribute.clearAttribute(request, IS_AUTHENTICATED); + channel.sendResponse(maybeSavedResponse.get().asRestResponse()); + return; + } + + NettyAttribute.popFrom(request, CONTEXT_TO_RESTORE).ifPresent(storedContext -> { + // X_OPAQUE_ID will be overritten on restore - save to apply after restoring the saved context + final String xOpaqueId = threadContext.getHeader(Task.X_OPAQUE_ID); + storedContext.restore(); + if (xOpaqueId != null) { + threadContext.putHeader(Task.X_OPAQUE_ID, xOpaqueId); + } + }); + final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(request, channel); // Authenticate request - checkAndAuthenticateRequest(requestChannel); + if (!NettyAttribute.popFrom(request, IS_AUTHENTICATED).orElse(false)) { + // we aren't authenticated so we should skip this step + checkAndAuthenticateRequest(requestChannel); + } if (requestChannel.getQueuedResponse().isPresent()) { - requestChannel.sendResponse(); + channel.sendResponse(requestChannel.getQueuedResponse().get().asRestResponse()); return; } @@ -149,14 +173,13 @@ public RestHandler wrap(RestHandler original, AdminDNs adminDNs) { .or(() -> allowlistingSettings.checkRequestIsAllowed(requestChannel)); if (deniedResponse.isPresent()) { - requestChannel.queueForSending(deniedResponse.orElseThrow()); - requestChannel.sendResponse(); + channel.sendResponse(deniedResponse.get().asRestResponse()); return; } authorizeRequest(original, requestChannel, user); if (requestChannel.getQueuedResponse().isPresent()) { - requestChannel.sendResponse(); + channel.sendResponse(requestChannel.getQueuedResponse().get().asRestResponse()); return; } @@ -168,11 +191,11 @@ public RestHandler wrap(RestHandler original, AdminDNs adminDNs) { /** * Checks if a given user is a SuperAdmin */ - private boolean userIsSuperAdmin(User user, AdminDNs adminDNs) { + boolean userIsSuperAdmin(User user, AdminDNs adminDNs) { return user != null && adminDNs.isAdmin(user); } - private void authorizeRequest(RestHandler original, SecurityRequestChannel request, User user) { + void authorizeRequest(RestHandler original, SecurityRequestChannel request, User user) { List restRoutes = original.routes(); Optional handler = restRoutes.stream() .filter(rh -> rh.getMethod().equals(request.method())) diff --git a/src/main/java/org/opensearch/security/filter/SecurityRestUtils.java b/src/main/java/org/opensearch/security/filter/SecurityRestUtils.java new file mode 100644 index 0000000000..1599346b90 --- /dev/null +++ b/src/main/java/org/opensearch/security/filter/SecurityRestUtils.java @@ -0,0 +1,12 @@ +package org.opensearch.security.filter; + +public class SecurityRestUtils { + public static String path(final String uri) { + final int index = uri.indexOf('?'); + if (index >= 0) { + return uri.substring(0, index); + } else { + return uri; + } + } +} diff --git a/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java b/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java index fc36e2411b..3b70a5ebda 100644 --- a/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java +++ b/src/main/java/org/opensearch/security/http/SecurityHttpServerTransport.java @@ -26,11 +26,15 @@ package org.opensearch.security.http; +import io.netty.util.AttributeKey; import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.security.filter.SecurityResponse; +import org.opensearch.security.filter.SecurityRestFilter; import org.opensearch.security.ssl.SecurityKeyStore; import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.security.ssl.http.netty.SecuritySSLNettyHttpServerTransport; @@ -41,6 +45,13 @@ public class SecurityHttpServerTransport extends SecuritySSLNettyHttpServerTransport { + public static final AttributeKey EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response"); + public static final AttributeKey CONTEXT_TO_RESTORE = AttributeKey.newInstance( + "opensearch-http-request-thread-context" + ); + public static final AttributeKey SHOULD_DECOMPRESS = AttributeKey.newInstance("opensearch-http-should-decompress"); + public static final AttributeKey IS_AUTHENTICATED = AttributeKey.newInstance("opensearch-http-is-authenticated"); + public SecurityHttpServerTransport( final Settings settings, final NetworkService networkService, @@ -52,7 +63,8 @@ public SecurityHttpServerTransport( final ValidatingDispatcher dispatcher, final ClusterSettings clusterSettings, SharedGroupFactory sharedGroupFactory, - Tracer tracer + Tracer tracer, + SecurityRestFilter restFilter ) { super( settings, @@ -65,7 +77,8 @@ public SecurityHttpServerTransport( sslExceptionHandler, clusterSettings, sharedGroupFactory, - tracer + tracer, + restFilter ); } } diff --git a/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java b/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java index a8e675ec74..71586a2dff 100644 --- a/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java +++ b/src/main/java/org/opensearch/security/http/SecurityNonSslHttpServerTransport.java @@ -29,6 +29,7 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; import org.opensearch.common.network.NetworkService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -36,12 +37,18 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.http.HttpHandlingSettings; import org.opensearch.http.netty4.Netty4HttpServerTransport; +import org.opensearch.security.filter.SecurityRestFilter; +import org.opensearch.security.ssl.http.netty.Netty4ConditionalDecompressor; +import org.opensearch.security.ssl.http.netty.Netty4HttpRequestHeaderVerifier; import org.opensearch.telemetry.tracing.Tracer; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.SharedGroupFactory; public class SecurityNonSslHttpServerTransport extends Netty4HttpServerTransport { + private final ChannelInboundHandlerAdapter headerVerifier; + private final ChannelInboundHandlerAdapter conditionalDecompressor; + public SecurityNonSslHttpServerTransport( final Settings settings, final NetworkService networkService, @@ -49,9 +56,10 @@ public SecurityNonSslHttpServerTransport( final ThreadPool threadPool, final NamedXContentRegistry namedXContentRegistry, final Dispatcher dispatcher, - ClusterSettings clusterSettings, - SharedGroupFactory sharedGroupFactory, - Tracer tracer + final ClusterSettings clusterSettings, + final SharedGroupFactory sharedGroupFactory, + final Tracer tracer, + final SecurityRestFilter restFilter ) { super( settings, @@ -64,6 +72,8 @@ public SecurityNonSslHttpServerTransport( sharedGroupFactory, tracer ); + headerVerifier = new Netty4HttpRequestHeaderVerifier(restFilter, threadPool, settings); + conditionalDecompressor = new Netty4ConditionalDecompressor(); } @Override @@ -82,4 +92,14 @@ protected void initChannel(Channel ch) throws Exception { super.initChannel(ch); } } + + @Override + protected ChannelInboundHandlerAdapter createHeaderVerifier() { + return headerVerifier; + } + + @Override + protected ChannelInboundHandlerAdapter createDecompressor() { + return conditionalDecompressor; + } } diff --git a/src/main/java/org/opensearch/security/http/XFFResolver.java b/src/main/java/org/opensearch/security/http/XFFResolver.java index e9ad412831..64c5f4b60c 100644 --- a/src/main/java/org/opensearch/security/http/XFFResolver.java +++ b/src/main/java/org/opensearch/security/http/XFFResolver.java @@ -34,11 +34,8 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.http.netty4.Netty4HttpChannel; -import org.opensearch.rest.RestRequest; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.security.filter.SecurityRequest; -import org.opensearch.security.filter.OpenSearchRequest; import org.opensearch.security.securityconf.DynamicConfigModel; import org.opensearch.security.support.ConfigConstants; import org.opensearch.threadpool.ThreadPool; @@ -61,15 +58,7 @@ public TransportAddress resolve(final SecurityRequest request) throws OpenSearch log.trace("resolve {}", request.getRemoteAddress().orElse(null)); } - boolean requestFromNetty = false; - if (request instanceof OpenSearchRequest) { - final OpenSearchRequest securityRequestChannel = (OpenSearchRequest) request; - final RestRequest restRequest = securityRequestChannel.breakEncapsulationForRequest(); - - requestFromNetty = restRequest.getHttpChannel() instanceof Netty4HttpChannel; - } - - if (enabled && request.getRemoteAddress().isPresent() && requestFromNetty) { + if (enabled && request.getRemoteAddress().isPresent()) { final InetSocketAddress remoteAddress = request.getRemoteAddress().get(); final InetSocketAddress isa = new InetSocketAddress(detector.detect(request, threadContext), remoteAddress.getPort()); diff --git a/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java b/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java index 18ec7457e9..722e55370e 100644 --- a/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java +++ b/src/main/java/org/opensearch/security/ssl/OpenSearchSecuritySSLPlugin.java @@ -71,6 +71,7 @@ import org.opensearch.script.ScriptService; import org.opensearch.security.DefaultObjectMapper; import org.opensearch.security.NonValidatingObjectMapper; +import org.opensearch.security.filter.SecurityRestFilter; import org.opensearch.security.ssl.http.netty.SecuritySSLNettyHttpServerTransport; import org.opensearch.security.ssl.http.netty.ValidatingDispatcher; import org.opensearch.security.ssl.rest.SecuritySSLInfoAction; @@ -96,12 +97,13 @@ public class OpenSearchSecuritySSLPlugin extends Plugin implements SystemIndexPl ); public static final boolean OPENSSL_SUPPORTED = (PlatformDependent.javaVersion() < 12) && USE_NETTY_DEFAULT_ALLOCATOR; protected final Logger log = LogManager.getLogger(this.getClass()); - protected static final String CLIENT_TYPE = "client.type"; + public static final String CLIENT_TYPE = "client.type"; protected final boolean client; protected final boolean httpSSLEnabled; protected final boolean transportSSLEnabled; protected final boolean extendedKeyUsageEnabled; protected final Settings settings; + protected volatile SecurityRestFilter securityRestHandler; protected final SharedGroupFactory sharedGroupFactory; protected final SecurityKeyStore sks; protected PrincipalExtractor principalExtractor; @@ -267,7 +269,8 @@ public Map> getHttpTransports( NOOP_SSL_EXCEPTION_HANDLER, clusterSettings, sharedGroupFactory, - tracer + tracer, + securityRestHandler ); return Collections.singletonMap("org.opensearch.security.ssl.http.netty.SecuritySSLNettyHttpServerTransport", () -> sgsnht); diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java new file mode 100644 index 0000000000..c8059fad5d --- /dev/null +++ b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4ConditionalDecompressor.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.ssl.http.netty; + +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.HttpContentDecompressor; + +import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE; +import static org.opensearch.security.http.SecurityHttpServerTransport.SHOULD_DECOMPRESS; + +import org.opensearch.security.filter.NettyAttribute; + +@Sharable +public class Netty4ConditionalDecompressor extends HttpContentDecompressor { + + @Override + protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception { + final boolean hasAnEarlyReponse = NettyAttribute.peekFrom(ctx, EARLY_RESPONSE).isPresent(); + final boolean shouldDecompress = NettyAttribute.popFrom(ctx, SHOULD_DECOMPRESS).orElse(false); + if (hasAnEarlyReponse || !shouldDecompress) { + // If there was an error prompting an early response,... don't decompress + // If there is no explicit decompress flag,... don't decompress + // If there is a decompress flag and it is false,... don't decompress + return super.newContentDecoder("identity"); + } + + // Decompresses the content based on its encoding + return super.newContentDecoder(contentEncoding); + } +} diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java new file mode 100644 index 0000000000..16e23ea361 --- /dev/null +++ b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java @@ -0,0 +1,139 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.security.ssl.http.netty; + +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.util.ReferenceCountUtil; +import org.opensearch.ExceptionsHelper; +import org.opensearch.common.util.concurrent.ThreadContext; + +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import org.opensearch.http.netty4.Netty4HttpChannel; +import org.opensearch.http.netty4.Netty4HttpServerTransport; +import org.opensearch.rest.RestUtils; +import org.opensearch.security.filter.SecurityRequestChannel; +import org.opensearch.security.filter.SecurityRequestChannelUnsupported; +import org.opensearch.security.filter.SecurityRequestFactory; +import org.opensearch.security.filter.SecurityResponse; +import org.opensearch.security.filter.SecurityRestFilter; +import org.opensearch.security.filter.SecurityRestUtils; +import org.opensearch.security.ssl.transport.SSLConfig; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.ssl.OpenSearchSecuritySSLPlugin; +import org.opensearch.common.settings.Settings; +import org.opensearch.OpenSearchSecurityException; + +import java.util.regex.Matcher; + +import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.API_AUTHTOKEN_SUFFIX; +import static org.opensearch.security.filter.SecurityRestFilter.HEALTH_SUFFIX; +import static org.opensearch.security.filter.SecurityRestFilter.PATTERN_PATH_PREFIX; +import static org.opensearch.security.filter.SecurityRestFilter.WHO_AM_I_SUFFIX; +import static org.opensearch.security.http.SecurityHttpServerTransport.CONTEXT_TO_RESTORE; +import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE; +import static org.opensearch.security.http.SecurityHttpServerTransport.SHOULD_DECOMPRESS; +import static org.opensearch.security.http.SecurityHttpServerTransport.IS_AUTHENTICATED; + +@Sharable +public class Netty4HttpRequestHeaderVerifier extends SimpleChannelInboundHandler { + private final SecurityRestFilter restFilter; + private final ThreadPool threadPool; + private final SSLConfig sslConfig; + private final boolean injectUserEnabled; + private final boolean passthrough; + + public Netty4HttpRequestHeaderVerifier(SecurityRestFilter restFilter, ThreadPool threadPool, Settings settings) { + this.restFilter = restFilter; + this.threadPool = threadPool; + + this.injectUserEnabled = settings.getAsBoolean(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, false); + boolean disabled = settings.getAsBoolean(ConfigConstants.SECURITY_DISABLED, false); + if (disabled) { + sslConfig = new SSLConfig(false, false); + } else { + sslConfig = new SSLConfig(settings); + } + boolean client = !"node".equals(settings.get(OpenSearchSecuritySSLPlugin.CLIENT_TYPE)); + this.passthrough = client || disabled || sslConfig.isSslOnlyMode(); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) throws Exception { + // DefaultHttpRequest should always be first and contain headers + ReferenceCountUtil.retain(msg); + + if (passthrough) { + ctx.fireChannelRead(msg); + return; + } + + // Start by setting this value to false, only requests that meet all the criteria will be decompressed + ctx.channel().attr(SHOULD_DECOMPRESS).set(Boolean.FALSE); + ctx.channel().attr(IS_AUTHENTICATED).set(Boolean.FALSE); + + final Netty4HttpChannel httpChannel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get(); + String rawPath = SecurityRestUtils.path(msg.uri()); + String path = RestUtils.decodeComponent(rawPath); + Matcher matcher = PATTERN_PATH_PREFIX.matcher(path); + final String suffix = matcher.matches() ? matcher.group(2) : null; + if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { + ctx.fireChannelRead(msg); + return; + } + + final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(msg, httpChannel); + ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { + injectUser(msg, threadContext); + + boolean shouldSkipAuthentication = HttpMethod.OPTIONS.equals(msg.method()) + || HEALTH_SUFFIX.equals(suffix) + || WHO_AM_I_SUFFIX.equals(suffix); + + if (!shouldSkipAuthentication) { + // If request channel is completed and a response is sent, then there was a failure during authentication + restFilter.checkAndAuthenticateRequest(requestChannel); + } + + ThreadContext.StoredContext contextToRestore = threadPool.getThreadContext().newStoredContext(false); + ctx.channel().attr(CONTEXT_TO_RESTORE).set(contextToRestore); + + requestChannel.getQueuedResponse().ifPresent(response -> ctx.channel().attr(EARLY_RESPONSE).set(response)); + + boolean shouldDecompress = !shouldSkipAuthentication && requestChannel.getQueuedResponse().isEmpty(); + + if (requestChannel.getQueuedResponse().isEmpty() || shouldSkipAuthentication) { + // Only allow decompression on authenticated requests that also aren't one of those ^ + ctx.channel().attr(SHOULD_DECOMPRESS).set(Boolean.valueOf(shouldDecompress)); + ctx.channel().attr(IS_AUTHENTICATED).set(Boolean.TRUE); + } + } catch (final OpenSearchSecurityException e) { + final SecurityResponse earlyResponse = new SecurityResponse(ExceptionsHelper.status(e).getStatus(), null, e.getMessage()); + ctx.channel().attr(EARLY_RESPONSE).set(earlyResponse); + } catch (final SecurityRequestChannelUnsupported srcu) { + // Use defaults for unsupported channels + } finally { + ctx.fireChannelRead(msg); + } + } + + private void injectUser(HttpRequest request, ThreadContext threadContext) { + if (this.injectUserEnabled) { + threadContext.putTransient( + ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, + request.headers().get(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER) + ); + } + } +} diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java b/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java index 081cc13f3e..41e44ce371 100644 --- a/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java +++ b/src/main/java/org/opensearch/security/ssl/http/netty/SecuritySSLNettyHttpServerTransport.java @@ -20,6 +20,7 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.codec.DecoderException; import io.netty.handler.ssl.ApplicationProtocolNames; import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler; @@ -36,6 +37,7 @@ import org.opensearch.http.HttpHandlingSettings; import org.opensearch.http.netty4.Netty4HttpChannel; import org.opensearch.http.netty4.Netty4HttpServerTransport; +import org.opensearch.security.filter.SecurityRestFilter; import org.opensearch.security.ssl.SecurityKeyStore; import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.telemetry.tracing.Tracer; @@ -46,6 +48,8 @@ public class SecuritySSLNettyHttpServerTransport extends Netty4HttpServerTranspo private static final Logger logger = LogManager.getLogger(SecuritySSLNettyHttpServerTransport.class); private final SecurityKeyStore sks; private final SslExceptionHandler errorHandler; + private final ChannelInboundHandlerAdapter headerVerifier; + private final ChannelInboundHandlerAdapter conditionalDecompressor; public SecuritySSLNettyHttpServerTransport( final Settings settings, @@ -58,7 +62,8 @@ public SecuritySSLNettyHttpServerTransport( final SslExceptionHandler errorHandler, ClusterSettings clusterSettings, SharedGroupFactory sharedGroupFactory, - Tracer tracer + Tracer tracer, + SecurityRestFilter restFilter ) { super( settings, @@ -73,6 +78,8 @@ public SecuritySSLNettyHttpServerTransport( ); this.sks = sks; this.errorHandler = errorHandler; + headerVerifier = new Netty4HttpRequestHeaderVerifier(restFilter, threadPool, settings); + conditionalDecompressor = new Netty4ConditionalDecompressor(); } @Override @@ -149,4 +156,14 @@ protected void configurePipeline(Channel ch) { ch.pipeline().addLast(new Http2OrHttpHandler()); } } + + @Override + protected ChannelInboundHandlerAdapter createHeaderVerifier() { + return headerVerifier; + } + + @Override + protected ChannelInboundHandlerAdapter createDecompressor() { + return conditionalDecompressor; + } } diff --git a/src/test/java/org/opensearch/security/SystemIntegratorsTests.java b/src/test/java/org/opensearch/security/SystemIntegratorsTests.java index 3c2195b6f5..27e44b1ce5 100644 --- a/src/test/java/org/opensearch/security/SystemIntegratorsTests.java +++ b/src/test/java/org/opensearch/security/SystemIntegratorsTests.java @@ -44,10 +44,7 @@ public class SystemIntegratorsTests extends SingleClusterTest { @Test public void testInjectedUserMalformed() throws Exception { - final Settings settings = Settings.builder() - .put(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, true) - .put("http.type", "org.opensearch.security.http.UserInjectingServerTransport") - .build(); + final Settings settings = Settings.builder().put(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, true).build(); setup(settings, ClusterConfiguration.USERINJECTOR); @@ -115,10 +112,7 @@ public void testInjectedUserMalformed() throws Exception { @Test public void testInjectedUser() throws Exception { - final Settings settings = Settings.builder() - .put(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, true) - .put("http.type", "org.opensearch.security.http.UserInjectingServerTransport") - .build(); + final Settings settings = Settings.builder().put(ConfigConstants.SECURITY_UNSUPPORTED_INJECT_USER_ENABLED, true).build(); setup(settings, ClusterConfiguration.USERINJECTOR); @@ -250,7 +244,7 @@ public void testInjectedUser() throws Exception { @Test public void testInjectedUserDisabled() throws Exception { - final Settings settings = Settings.builder().put("http.type", "org.opensearch.security.http.UserInjectingServerTransport").build(); + final Settings settings = Settings.builder().build(); setup(settings, ClusterConfiguration.USERINJECTOR); @@ -276,7 +270,6 @@ public void testInjectedAdminUser() throws Exception { ConfigConstants.SECURITY_AUTHCZ_ADMIN_DN, Lists.newArrayList("CN=kirk,OU=client,O=client,L=Test,C=DE", "injectedadmin") ) - .put("http.type", "org.opensearch.security.http.UserInjectingServerTransport") .build(); setup(settings, ClusterConfiguration.USERINJECTOR); @@ -312,7 +305,6 @@ public void testInjectedAdminUserAdminInjectionDisabled() throws Exception { ConfigConstants.SECURITY_AUTHCZ_ADMIN_DN, Lists.newArrayList("CN=kirk,OU=client,O=client,L=Test,C=DE", "injectedadmin") ) - .put("http.type", "org.opensearch.security.http.UserInjectingServerTransport") .build(); setup(settings, ClusterConfiguration.USERINJECTOR); diff --git a/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java b/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java index 0edc55c47c..6c1812c32b 100644 --- a/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java +++ b/src/test/java/org/opensearch/security/auditlog/integration/BasicAuditlogTest.java @@ -510,12 +510,15 @@ public void testUpdateSettings() throws Exception { + "}" + "}"; + String expectedRequestBodyLog = + "{\\\"persistent_settings\\\":{\\\"indices\\\":{\\\"recovery\\\":{\\\"*\\\":null}}},\\\"transient_settings\\\":{\\\"indices\\\":{\\\"recovery\\\":{\\\"*\\\":null}}}}"; + HttpResponse response = rh.executePutRequest("_cluster/settings", json, encodeBasicHeader("admin", "admin")); Assert.assertEquals(HttpStatus.SC_OK, response.getStatusCode()); String auditLogImpl = TestAuditlogImpl.sb.toString(); Assert.assertTrue(auditLogImpl.contains("AUTHENTICATED")); Assert.assertTrue(auditLogImpl.contains("cluster:admin/settings/update")); - Assert.assertTrue(auditLogImpl.contains("indices.recovery.*")); + Assert.assertTrue(auditLogImpl.contains(expectedRequestBodyLog)); // may vary because we log may hit cluster manager directly or not Assert.assertTrue(TestAuditlogImpl.messages.size() > 1); Assert.assertTrue(validateMsgs(TestAuditlogImpl.messages)); diff --git a/src/test/java/org/opensearch/security/filter/SecurityRestFilterUnitTests.java b/src/test/java/org/opensearch/security/filter/SecurityRestFilterUnitTests.java new file mode 100644 index 0000000000..96d08659c4 --- /dev/null +++ b/src/test/java/org/opensearch/security/filter/SecurityRestFilterUnitTests.java @@ -0,0 +1,111 @@ +package org.opensearch.security.filter; + +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.security.auditlog.AuditLog; +import org.opensearch.security.auth.BackendRegistry; +import org.opensearch.security.configuration.AdminDNs; +import org.opensearch.security.configuration.CompatConfig; +import org.opensearch.security.privileges.RestLayerPrivilegesEvaluator; +import org.opensearch.security.ssl.transport.PrincipalExtractor; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.ThreadPool; + +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +public class SecurityRestFilterUnitTests { + + SecurityRestFilter sf; + RestHandler testRestHandler; + + class TestRestHandler implements RestHandler { + + @Override + public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception { + channel.sendResponse(new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY)); + } + } + + @Before + public void setUp() throws NoSuchMethodException { + testRestHandler = new TestRestHandler(); + + ThreadPool tp = spy(new ThreadPool(Settings.builder().put("node.name", "mock").build())); + doReturn(new ThreadContext(Settings.EMPTY)).when(tp).getThreadContext(); + + sf = new SecurityRestFilter( + mock(BackendRegistry.class), + mock(RestLayerPrivilegesEvaluator.class), + mock(AuditLog.class), + tp, + mock(PrincipalExtractor.class), + Settings.EMPTY, + mock(Path.class), + mock(CompatConfig.class) + ); + } + + @Ignore + @Test + public void testDoesCallDelegateOnSuccessfulAuthorization() throws Exception { + SecurityRestFilter filterSpy = spy(sf); + AdminDNs adminDNs = mock(AdminDNs.class); + + RestHandler testRestHandlerSpy = spy(testRestHandler); + RestHandler wrappedRestHandler = filterSpy.wrap(testRestHandlerSpy, adminDNs); + + doReturn(false).when(filterSpy).userIsSuperAdmin(any(), any()); + // doReturn(true).when(filterSpy).authorizeRequest(any(), any(), any()); + + FakeRestRequest fakeRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withPath("/test") + .withMethod(RestRequest.Method.POST) + .withHeaders(Map.of("Content-Type", List.of("application/json"))) + .build(); + + wrappedRestHandler.handleRequest(fakeRequest, mock(RestChannel.class), mock(NodeClient.class)); + + verify(testRestHandlerSpy).handleRequest(any(), any(), any()); + } + + @Ignore + @Test + public void testDoesNotCallDelegateOnUnauthorized() throws Exception { + SecurityRestFilter filterSpy = spy(sf); + AdminDNs adminDNs = mock(AdminDNs.class); + + RestHandler testRestHandlerSpy = spy(testRestHandler); + RestHandler wrappedRestHandler = filterSpy.wrap(testRestHandlerSpy, adminDNs); + + doReturn(false).when(filterSpy).userIsSuperAdmin(any(), any()); + // doReturn(false).when(filterSpy).authorizeRequest(any(), any(), any()); + + FakeRestRequest fakeRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withPath("/test") + .withMethod(RestRequest.Method.POST) + .withHeaders(Map.of("Content-Type", List.of("application/json"))) + .build(); + + wrappedRestHandler.handleRequest(fakeRequest, mock(RestChannel.class), mock(NodeClient.class)); + + verify(testRestHandlerSpy, never()).handleRequest(any(), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/security/test/helper/cluster/ClusterConfiguration.java b/src/test/java/org/opensearch/security/test/helper/cluster/ClusterConfiguration.java index e9b503e669..f653b272aa 100644 --- a/src/test/java/org/opensearch/security/test/helper/cluster/ClusterConfiguration.java +++ b/src/test/java/org/opensearch/security/test/helper/cluster/ClusterConfiguration.java @@ -41,7 +41,6 @@ import org.opensearch.script.mustache.MustacheModulePlugin; import org.opensearch.search.aggregations.matrix.MatrixAggregationModulePlugin; import org.opensearch.security.OpenSearchSecurityPlugin; -import org.opensearch.security.test.plugin.UserInjectorPlugin; import org.opensearch.transport.Netty4ModulePlugin; public enum ClusterConfiguration { @@ -79,11 +78,7 @@ public enum ClusterConfiguration { CLIENTNODE(new NodeSettings(true, false), new NodeSettings(false, true), new NodeSettings(false, true), new NodeSettings(false, false)), // 3 nodes (1m, 2d) plus additional UserInjectorPlugin - USERINJECTOR( - new NodeSettings(true, false, Lists.newArrayList(UserInjectorPlugin.class)), - new NodeSettings(false, true, Lists.newArrayList(UserInjectorPlugin.class)), - new NodeSettings(false, true, Lists.newArrayList(UserInjectorPlugin.class)) - ); + USERINJECTOR(new NodeSettings(true, false), new NodeSettings(false, true), new NodeSettings(false, true)); private List nodeSettings = new LinkedList<>(); diff --git a/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java b/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java index 7eefd22273..bd82fa2dd5 100644 --- a/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java +++ b/src/test/java/org/opensearch/security/test/helper/rest/RestHelper.java @@ -149,7 +149,7 @@ public void cancelled() { final SimpleHttpResponse response = future.join(); if (response.getCode() >= 300) { - throw new Exception("Statuscode " + response.getCode()); + throw new Exception("Statuscode " + response.getCode() + ", body " + response.getBodyText()); } if (enableHTTPClientSSL && !response.getVersion().equals(HttpVersion.HTTP_2)) { diff --git a/src/test/java/org/opensearch/security/test/plugin/UserInjectorPlugin.java b/src/test/java/org/opensearch/security/test/plugin/UserInjectorPlugin.java deleted file mode 100644 index 73ede93651..0000000000 --- a/src/test/java/org/opensearch/security/test/plugin/UserInjectorPlugin.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright 2015-2018 _floragunn_ GmbH - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.security.test.plugin; - -import java.nio.file.Path; -import java.util.Map; -import java.util.function.Supplier; - -import com.google.common.collect.ImmutableMap; - -import org.opensearch.common.network.NetworkService; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.util.BigArrays; -import org.opensearch.common.util.PageCacheRecycler; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.core.indices.breaker.CircuitBreakerService; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.http.HttpServerTransport; -import org.opensearch.http.HttpServerTransport.Dispatcher; -import org.opensearch.http.netty4.Netty4HttpServerTransport; -import org.opensearch.plugins.NetworkPlugin; -import org.opensearch.plugins.Plugin; -import org.opensearch.rest.RestChannel; -import org.opensearch.rest.RestRequest; -import org.opensearch.security.support.ConfigConstants; -import org.opensearch.telemetry.tracing.Tracer; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.SharedGroupFactory; - -/** - * Mimics the behavior of system integrators that run their own plugins (i.e. server transports) - * in front of OpenSearch Security. This transport just copies the user string from the - * REST headers to the ThreadContext to test user injection. - * @author jkressin - */ -public class UserInjectorPlugin extends Plugin implements NetworkPlugin { - - Settings settings; - private final SharedGroupFactory sharedGroupFactory; - ThreadPool threadPool; - - public UserInjectorPlugin(final Settings settings, final Path configPath) { - this.settings = settings; - sharedGroupFactory = new SharedGroupFactory(settings); - } - - @Override - public Map> getHttpTransports( - Settings settings, - ThreadPool threadPool, - BigArrays bigArrays, - PageCacheRecycler pageCacheRecycler, - CircuitBreakerService circuitBreakerService, - NamedXContentRegistry xContentRegistry, - NetworkService networkService, - Dispatcher dispatcher, - ClusterSettings clusterSettings, - Tracer tracer - ) { - - final UserInjectingDispatcher validatingDispatcher = new UserInjectingDispatcher(dispatcher); - return ImmutableMap.of( - "org.opensearch.security.http.UserInjectingServerTransport", - () -> new UserInjectingServerTransport( - settings, - networkService, - bigArrays, - threadPool, - xContentRegistry, - validatingDispatcher, - clusterSettings, - sharedGroupFactory, - tracer - ) - ); - } - - class UserInjectingServerTransport extends Netty4HttpServerTransport { - - public UserInjectingServerTransport( - final Settings settings, - final NetworkService networkService, - final BigArrays bigArrays, - final ThreadPool threadPool, - final NamedXContentRegistry namedXContentRegistry, - final Dispatcher dispatcher, - ClusterSettings clusterSettings, - SharedGroupFactory sharedGroupFactory, - Tracer tracer - ) { - super( - settings, - networkService, - bigArrays, - threadPool, - namedXContentRegistry, - dispatcher, - clusterSettings, - sharedGroupFactory, - tracer - ); - } - } - - class UserInjectingDispatcher implements Dispatcher { - - private Dispatcher originalDispatcher; - - public UserInjectingDispatcher(final Dispatcher originalDispatcher) { - super(); - this.originalDispatcher = originalDispatcher; - } - - @Override - public void dispatchRequest(RestRequest request, RestChannel channel, ThreadContext threadContext) { - threadContext.putTransient( - ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, - request.header(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER) - ); - originalDispatcher.dispatchRequest(request, channel, threadContext); - - } - - @Override - public void dispatchBadRequest(RestChannel channel, ThreadContext threadContext, Throwable cause) { - threadContext.putTransient( - ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, - channel.request().header(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER) - ); - originalDispatcher.dispatchBadRequest(channel, threadContext, cause); - } - } - -}