diff --git a/CHANGELOG.md b/CHANGELOG.md index 50b38450aa724..62cf5e45e80f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -114,6 +114,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add cluster primary balance contraint for rebalancing with buffer ([#12656](https://github.com/opensearch-project/OpenSearch/pull/12656)) - [Remote Store] Make translog transfer timeout configurable ([#12704](https://github.com/opensearch-project/OpenSearch/pull/12704)) - Reject Resize index requests (i.e, split, shrink and clone), While DocRep to SegRep migration is in progress.([#12686](https://github.com/opensearch-project/OpenSearch/pull/12686)) +- Add support for more than one protocol for transport ([#12967](https://github.com/opensearch-project/OpenSearch/pull/12967)) ### Dependencies - Bump `org.apache.commons:commons-configuration2` from 2.10.0 to 2.10.1 ([#12896](https://github.com/opensearch-project/OpenSearch/pull/12896)) diff --git a/server/build.gradle b/server/build.gradle index 7d52849844aaa..cb48142a61159 100644 --- a/server/build.gradle +++ b/server/build.gradle @@ -387,7 +387,7 @@ tasks.register("japicmp", me.champeau.gradle.japicmp.JapicmpTask) { onlyModified = true failOnModification = true ignoreMissingClasses = true - annotationIncludes = ['@org.opensearch.common.annotation.PublicApi'] + annotationIncludes = ['@org.opensearch.common.annotation.PublicApi', '@org.opensearch.common.annotation.DeprecatedApi'] txtOutputFile = layout.buildDirectory.file("reports/java-compatibility/report.txt") htmlOutputFile = layout.buildDirectory.file("reports/java-compatibility/report.html") dependsOn downloadSnapshot diff --git a/server/src/main/java/org/opensearch/transport/Header.java b/server/src/main/java/org/opensearch/transport/Header.java index 57c1da6f46aec..ac30df8dda02c 100644 --- a/server/src/main/java/org/opensearch/transport/Header.java +++ b/server/src/main/java/org/opensearch/transport/Header.java @@ -75,11 +75,11 @@ public int getNetworkMessageSize() { return networkMessageSize; } - Version getVersion() { + public Version getVersion() { return version; } - long getRequestId() { + public long getRequestId() { return requestId; } @@ -87,7 +87,7 @@ byte getStatus() { return status; } - boolean isRequest() { + public boolean isRequest() { return TransportStatus.isRequest(status); } @@ -99,7 +99,7 @@ boolean isError() { return TransportStatus.isError(status); } - boolean isHandshake() { + public boolean isHandshake() { return TransportStatus.isHandshake(status); } diff --git a/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java new file mode 100644 index 0000000000000..276891212e43f --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/InboundBytesHandler.java @@ -0,0 +1,32 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport; + +import org.opensearch.common.bytes.ReleasableBytesReference; + +import java.io.Closeable; +import java.io.IOException; +import java.util.function.BiConsumer; + +/** + * Interface for handling inbound bytes. Can be implemented by different transport protocols. + */ +public interface InboundBytesHandler extends Closeable { + + public void doHandleBytes( + TcpChannel channel, + ReleasableBytesReference reference, + BiConsumer messageHandler + ) throws IOException; + + public boolean canHandleBytes(ReleasableBytesReference reference); + + @Override + void close(); +} diff --git a/server/src/main/java/org/opensearch/transport/InboundDecoder.java b/server/src/main/java/org/opensearch/transport/InboundDecoder.java index 82fc09a985446..d6b7a98e876b3 100644 --- a/server/src/main/java/org/opensearch/transport/InboundDecoder.java +++ b/server/src/main/java/org/opensearch/transport/InboundDecoder.java @@ -50,8 +50,8 @@ */ public class InboundDecoder implements Releasable { - static final Object PING = new Object(); - static final Object END_CONTENT = new Object(); + public static final Object PING = new Object(); + public static final Object END_CONTENT = new Object(); private final Version version; private final PageCacheRecycler recycler; diff --git a/server/src/main/java/org/opensearch/transport/InboundHandler.java b/server/src/main/java/org/opensearch/transport/InboundHandler.java index a8315c3cae4e0..6492900c49a0e 100644 --- a/server/src/main/java/org/opensearch/transport/InboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundHandler.java @@ -32,35 +32,14 @@ package org.opensearch.transport; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.apache.lucene.util.BytesRef; -import org.opensearch.Version; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.AbstractRunnable; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.core.common.io.stream.ByteBufferStreamInput; -import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.transport.TransportAddress; -import org.opensearch.core.transport.TransportResponse; -import org.opensearch.telemetry.tracing.Span; -import org.opensearch.telemetry.tracing.SpanBuilder; -import org.opensearch.telemetry.tracing.SpanScope; import org.opensearch.telemetry.tracing.Tracer; -import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; -import java.io.EOFException; import java.io.IOException; -import java.net.InetSocketAddress; -import java.nio.ByteBuffer; -import java.util.Collection; -import java.util.Collections; import java.util.Map; -import java.util.stream.Collectors; /** * Handler for inbound data @@ -69,21 +48,13 @@ */ public class InboundHandler { - private static final Logger logger = LogManager.getLogger(InboundHandler.class); - private final ThreadPool threadPool; - private final OutboundHandler outboundHandler; - private final NamedWriteableRegistry namedWriteableRegistry; - private final TransportHandshaker handshaker; - private final TransportKeepAlive keepAlive; - private final Transport.ResponseHandlers responseHandlers; - private final Transport.RequestHandlers requestHandlers; private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; private volatile long slowLogThresholdMs = Long.MAX_VALUE; - private final Tracer tracer; + private final Map protocolMessageHandlers; InboundHandler( ThreadPool threadPool, @@ -96,13 +67,19 @@ public class InboundHandler { Tracer tracer ) { this.threadPool = threadPool; - this.outboundHandler = outboundHandler; - this.namedWriteableRegistry = namedWriteableRegistry; - this.handshaker = handshaker; - this.keepAlive = keepAlive; - this.requestHandlers = requestHandlers; - this.responseHandlers = responseHandlers; - this.tracer = tracer; + this.protocolMessageHandlers = Map.of( + NativeInboundMessage.NATIVE_PROTOCOL, + new NativeMessageHandler( + threadPool, + outboundHandler, + namedWriteableRegistry, + handshaker, + requestHandlers, + responseHandlers, + tracer, + keepAlive + ) + ); } void setMessageListener(TransportMessageListener listener) { @@ -117,377 +94,17 @@ void setSlowLogThreshold(TimeValue slowLogThreshold) { this.slowLogThresholdMs = slowLogThreshold.getMillis(); } - void inboundMessage(TcpChannel channel, InboundMessage message) throws Exception { + void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) throws Exception { final long startTime = threadPool.relativeTimeInMillis(); channel.getChannelStats().markAccessed(startTime); - TransportLogger.logInboundMessage(channel, message); - if (message.isPing()) { - keepAlive.receiveKeepAlive(channel); - } else { - messageReceived(channel, message, startTime); - } - } - - // Empty stream constant to avoid instantiating a new stream for empty messages. - private static final StreamInput EMPTY_STREAM_INPUT = new ByteBufferStreamInput(ByteBuffer.wrap(BytesRef.EMPTY_BYTES)); - - private void messageReceived(TcpChannel channel, InboundMessage message, long startTime) throws IOException { - final InetSocketAddress remoteAddress = channel.getRemoteAddress(); - final Header header = message.getHeader(); - assert header.needsToReadVariableHeader() == false; - ThreadContext threadContext = threadPool.getThreadContext(); - try (ThreadContext.StoredContext existing = threadContext.stashContext()) { - // Place the context with the headers from the message - threadContext.setHeaders(header.getHeaders()); - threadContext.putTransient("_remote_address", remoteAddress); - if (header.isRequest()) { - handleRequest(channel, header, message); - } else { - // Responses do not support short circuiting currently - assert message.isShortCircuit() == false; - final TransportResponseHandler handler; - long requestId = header.getRequestId(); - if (header.isHandshake()) { - handler = handshaker.removeHandlerForHandshake(requestId); - } else { - TransportResponseHandler theHandler = responseHandlers.onResponseReceived( - requestId, - messageListener - ); - if (theHandler == null && header.isError()) { - handler = handshaker.removeHandlerForHandshake(requestId); - } else { - handler = theHandler; - } - } - // ignore if its null, the service logs it - if (handler != null) { - final StreamInput streamInput; - if (message.getContentLength() > 0 || header.getVersion().equals(Version.CURRENT) == false) { - streamInput = namedWriteableStream(message.openOrGetStreamInput()); - assertRemoteVersion(streamInput, header.getVersion()); - if (header.isError()) { - handlerResponseError(requestId, streamInput, handler); - } else { - handleResponse(requestId, remoteAddress, streamInput, handler); - } - } else { - assert header.isError() == false; - handleResponse(requestId, remoteAddress, EMPTY_STREAM_INPUT, handler); - } - } - - } - } finally { - final long took = threadPool.relativeTimeInMillis() - startTime; - final long logThreshold = slowLogThresholdMs; - if (logThreshold > 0 && took > logThreshold) { - logger.warn( - "handling inbound transport message [{}] took [{}ms] which is above the warn threshold of [{}ms]", - message, - took, - logThreshold - ); - } - } - } - - private Map> extractHeaders(Map headers) { - return headers.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> Collections.singleton(e.getValue()))); - } - - private void handleRequest(TcpChannel channel, Header header, InboundMessage message) throws IOException { - final String action = header.getActionName(); - final long requestId = header.getRequestId(); - final Version version = header.getVersion(); - final Map> headers = extractHeaders(header.getHeaders().v1()); - Span span = tracer.startSpan(SpanBuilder.from(action, channel), headers); - try (SpanScope spanScope = tracer.withSpanInScope(span)) { - if (header.isHandshake()) { - messageListener.onRequestReceived(requestId, action); - // Cannot short circuit handshakes - assert message.isShortCircuit() == false; - final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); - assertRemoteVersion(stream, header.getVersion()); - final TcpTransportChannel transportChannel = new TcpTransportChannel( - outboundHandler, - channel, - action, - requestId, - version, - header.getFeatures(), - header.isCompressed(), - header.isHandshake(), - message.takeBreakerReleaseControl() - ); - TransportChannel traceableTransportChannel = TraceableTcpTransportChannel.create(transportChannel, span, tracer); - try { - handshaker.handleHandshake(traceableTransportChannel, requestId, stream); - } catch (Exception e) { - if (Version.CURRENT.isCompatible(header.getVersion())) { - sendErrorResponse(action, traceableTransportChannel, e); - } else { - logger.warn( - new ParameterizedMessage( - "could not send error response to handshake received on [{}] using wire format version [{}], closing channel", - channel, - header.getVersion() - ), - e - ); - channel.close(); - } - } - } else { - final TcpTransportChannel transportChannel = new TcpTransportChannel( - outboundHandler, - channel, - action, - requestId, - version, - header.getFeatures(), - header.isCompressed(), - header.isHandshake(), - message.takeBreakerReleaseControl() - ); - TransportChannel traceableTransportChannel = TraceableTcpTransportChannel.create(transportChannel, span, tracer); - try { - messageListener.onRequestReceived(requestId, action); - if (message.isShortCircuit()) { - sendErrorResponse(action, traceableTransportChannel, message.getException()); - } else { - final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); - assertRemoteVersion(stream, header.getVersion()); - final RequestHandlerRegistry reg = requestHandlers.getHandler(action); - assert reg != null; - - final T request = newRequest(requestId, action, stream, reg); - request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); - checkStreamIsFullyConsumed(requestId, action, stream); - - final String executor = reg.getExecutor(); - if (ThreadPool.Names.SAME.equals(executor)) { - try { - reg.processMessageReceived(request, traceableTransportChannel); - } catch (Exception e) { - sendErrorResponse(reg.getAction(), traceableTransportChannel, e); - } - } else { - threadPool.executor(executor).execute(new RequestHandler<>(reg, request, traceableTransportChannel)); - } - } - } catch (Exception e) { - sendErrorResponse(action, traceableTransportChannel, e); - } - } - } catch (Exception e) { - span.setError(e); - span.endSpan(); - throw e; - } - } - - /** - * Creates new request instance out of input stream. Throws IllegalStateException if the end of - * the stream was reached before the request is fully deserialized from the stream. - * @param transport request type - * @param requestId request identifier - * @param action action name - * @param stream stream - * @param reg request handler registry - * @return new request instance - * @throws IOException IOException - * @throws IllegalStateException IllegalStateException - */ - private T newRequest( - final long requestId, - final String action, - final StreamInput stream, - final RequestHandlerRegistry reg - ) throws IOException { - try { - return reg.newRequest(stream); - } catch (final EOFException e) { - // Another favor of (de)serialization issues is when stream contains less bytes than - // the request handler needs to deserialize the payload. - throw new IllegalStateException( - "Message fully read (request) but more data is expected for requestId [" - + requestId - + "], action [" - + action - + "]; resetting", - e - ); - } - } - - /** - * Checks if the stream is fully consumed and throws the exceptions if that is not the case. - * @param requestId request identifier - * @param action action name - * @param stream stream - * @throws IOException IOException - */ - private void checkStreamIsFullyConsumed(final long requestId, final String action, final StreamInput stream) throws IOException { - // in case we throw an exception, i.e. when the limit is hit, we don't want to verify - final int nextByte = stream.read(); - - // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker - if (nextByte != -1) { - throw new IllegalStateException( - "Message not fully read (request) for requestId [" - + requestId - + "], action [" - + action - + "], available [" - + stream.available() - + "]; resetting" - ); - } - } - - /** - * Checks if the stream is fully consumed and throws the exceptions if that is not the case. - * @param requestId request identifier - * @param handler response handler - * @param stream stream - * @param error "true" if response represents error, "false" otherwise - * @throws IOException IOException - */ - private void checkStreamIsFullyConsumed( - final long requestId, - final TransportResponseHandler handler, - final StreamInput stream, - final boolean error - ) throws IOException { - if (stream != EMPTY_STREAM_INPUT) { - // Check the entire message has been read - final int nextByte = stream.read(); - // calling read() is useful to make sure the message is fully read, even if there is an EOS marker - if (nextByte != -1) { - throw new IllegalStateException( - "Message not fully read (response) for requestId [" - + requestId - + "], handler [" - + handler - + "], error [" - + error - + "]; resetting" - ); - } - } - } - - private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) { - try { - transportChannel.sendResponse(e); - } catch (Exception inner) { - inner.addSuppressed(e); - logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", actionName), inner); - } - } - - private void handleResponse( - final long requestId, - InetSocketAddress remoteAddress, - final StreamInput stream, - final TransportResponseHandler handler - ) { - final T response; - try { - response = handler.read(stream); - response.remoteAddress(new TransportAddress(remoteAddress)); - checkStreamIsFullyConsumed(requestId, handler, stream, false); - } catch (Exception e) { - final Exception serializationException = new TransportSerializationException( - "Failed to deserialize response from handler [" + handler + "]", - e - ); - logger.warn(new ParameterizedMessage("Failed to deserialize response from [{}]", remoteAddress), serializationException); - handleException(handler, serializationException); - return; - } - final String executor = handler.executor(); - if (ThreadPool.Names.SAME.equals(executor)) { - doHandleResponse(handler, response); - } else { - threadPool.executor(executor).execute(() -> doHandleResponse(handler, response)); - } + messageReceivedFromPipeline(channel, message, startTime); } - private void doHandleResponse(TransportResponseHandler handler, T response) { - try { - handler.handleResponse(response); - } catch (Exception e) { - handleException(handler, new ResponseHandlerFailureTransportException(e)); - } - } - - private void handlerResponseError(final long requestId, StreamInput stream, final TransportResponseHandler handler) { - Exception error; - try { - error = stream.readException(); - checkStreamIsFullyConsumed(requestId, handler, stream, true); - } catch (Exception e) { - error = new TransportSerializationException( - "Failed to deserialize exception response from stream for handler [" + handler + "]", - e - ); - } - handleException(handler, error); - } - - private void handleException(final TransportResponseHandler handler, Throwable error) { - if (!(error instanceof RemoteTransportException)) { - error = new RemoteTransportException(error.getMessage(), error); - } - final RemoteTransportException rtx = (RemoteTransportException) error; - threadPool.executor(handler.executor()).execute(() -> { - try { - handler.handleException(rtx); - } catch (Exception e) { - logger.error(() -> new ParameterizedMessage("failed to handle exception response [{}]", handler), e); - } - }); - } - - private StreamInput namedWriteableStream(StreamInput delegate) { - return new NamedWriteableAwareStreamInput(delegate, namedWriteableRegistry); - } - - static void assertRemoteVersion(StreamInput in, Version version) { - assert version.equals(in.getVersion()) : "Stream version [" + in.getVersion() + "] does not match version [" + version + "]"; - } - - /** - * Internal request handler - * - * @opensearch.internal - */ - private static class RequestHandler extends AbstractRunnable { - private final RequestHandlerRegistry reg; - private final T request; - private final TransportChannel transportChannel; - - RequestHandler(RequestHandlerRegistry reg, T request, TransportChannel transportChannel) { - this.reg = reg; - this.request = request; - this.transportChannel = transportChannel; - } - - @Override - protected void doRun() throws Exception { - reg.processMessageReceived(request, transportChannel); - } - - @Override - public boolean isForceExecution() { - return reg.isForceExecution(); - } - - @Override - public void onFailure(Exception e) { - sendErrorResponse(reg.getAction(), transportChannel, e); + private void messageReceivedFromPipeline(TcpChannel channel, ProtocolInboundMessage message, long startTime) throws IOException { + ProtocolMessageHandler protocolMessageHandler = protocolMessageHandlers.get(message.getProtocol()); + if (protocolMessageHandler == null) { + throw new IllegalStateException("No protocol message handler found for protocol: " + message.getProtocol()); } + protocolMessageHandler.messageReceived(channel, message, startTime, slowLogThresholdMs, messageListener); } } diff --git a/server/src/main/java/org/opensearch/transport/InboundMessage.java b/server/src/main/java/org/opensearch/transport/InboundMessage.java index 71c4d6973505d..5c68257557061 100644 --- a/server/src/main/java/org/opensearch/transport/InboundMessage.java +++ b/server/src/main/java/org/opensearch/transport/InboundMessage.java @@ -32,105 +32,77 @@ package org.opensearch.transport; -import org.opensearch.common.annotation.PublicApi; +import org.opensearch.common.annotation.DeprecatedApi; import org.opensearch.common.bytes.ReleasableBytesReference; import org.opensearch.common.lease.Releasable; -import org.opensearch.common.lease.Releasables; -import org.opensearch.common.util.io.IOUtils; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.transport.nativeprotocol.NativeInboundMessage; import java.io.IOException; /** * Inbound data as a message - * + * This api is deprecated, please use {@link org.opensearch.transport.nativeprotocol.NativeInboundMessage} instead. * @opensearch.api */ -@PublicApi(since = "1.0.0") -public class InboundMessage implements Releasable { +@DeprecatedApi(since = "2.14.0") +public class InboundMessage implements Releasable, ProtocolInboundMessage { - private final Header header; - private final ReleasableBytesReference content; - private final Exception exception; - private final boolean isPing; - private Releasable breakerRelease; - private StreamInput streamInput; + private final NativeInboundMessage nativeInboundMessage; public InboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { - this.header = header; - this.content = content; - this.breakerRelease = breakerRelease; - this.exception = null; - this.isPing = false; + this.nativeInboundMessage = new NativeInboundMessage(header, content, breakerRelease); } public InboundMessage(Header header, Exception exception) { - this.header = header; - this.content = null; - this.breakerRelease = null; - this.exception = exception; - this.isPing = false; + this.nativeInboundMessage = new NativeInboundMessage(header, exception); } public InboundMessage(Header header, boolean isPing) { - this.header = header; - this.content = null; - this.breakerRelease = null; - this.exception = null; - this.isPing = isPing; + this.nativeInboundMessage = new NativeInboundMessage(header, isPing); } public Header getHeader() { - return header; + return this.nativeInboundMessage.getHeader(); } public int getContentLength() { - if (content == null) { - return 0; - } else { - return content.length(); - } + return this.nativeInboundMessage.getContentLength(); } public Exception getException() { - return exception; + return this.nativeInboundMessage.getException(); } public boolean isPing() { - return isPing; + return this.nativeInboundMessage.isPing(); } public boolean isShortCircuit() { - return exception != null; + return this.nativeInboundMessage.getException() != null; } public Releasable takeBreakerReleaseControl() { - final Releasable toReturn = breakerRelease; - breakerRelease = null; - if (toReturn != null) { - return toReturn; - } else { - return () -> {}; - } + return this.nativeInboundMessage.takeBreakerReleaseControl(); } public StreamInput openOrGetStreamInput() throws IOException { - assert isPing == false && content != null; - if (streamInput == null) { - streamInput = content.streamInput(); - streamInput.setVersion(header.getVersion()); - } - return streamInput; + return this.nativeInboundMessage.openOrGetStreamInput(); } @Override public void close() { - IOUtils.closeWhileHandlingException(streamInput); - Releasables.closeWhileHandlingException(content, breakerRelease); + this.nativeInboundMessage.close(); } @Override public String toString() { - return "InboundMessage{" + header + "}"; + return this.nativeInboundMessage.toString(); } + + @Override + public String getProtocol() { + return this.nativeInboundMessage.getProtocol(); + } + } diff --git a/server/src/main/java/org/opensearch/transport/InboundPipeline.java b/server/src/main/java/org/opensearch/transport/InboundPipeline.java index dd4690e5e6abf..5cee3bb975223 100644 --- a/server/src/main/java/org/opensearch/transport/InboundPipeline.java +++ b/server/src/main/java/org/opensearch/transport/InboundPipeline.java @@ -38,11 +38,11 @@ import org.opensearch.common.lease.Releasables; import org.opensearch.common.util.PageCacheRecycler; import org.opensearch.core.common.breaker.CircuitBreaker; -import org.opensearch.core.common.bytes.CompositeBytesReference; +import org.opensearch.transport.nativeprotocol.NativeInboundBytesHandler; import java.io.IOException; import java.util.ArrayDeque; -import java.util.ArrayList; +import java.util.List; import java.util.function.BiConsumer; import java.util.function.Function; import java.util.function.LongSupplier; @@ -55,17 +55,16 @@ */ public class InboundPipeline implements Releasable { - private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); - private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true); - private final LongSupplier relativeTimeInMillis; private final StatsTracker statsTracker; private final InboundDecoder decoder; private final InboundAggregator aggregator; - private final BiConsumer messageHandler; private Exception uncaughtException; private final ArrayDeque pending = new ArrayDeque<>(2); private boolean isClosed = false; + private final BiConsumer messageHandler; + private final List protocolBytesHandlers; + private InboundBytesHandler currentHandler; public InboundPipeline( Version version, @@ -74,7 +73,7 @@ public InboundPipeline( LongSupplier relativeTimeInMillis, Supplier circuitBreaker, Function> registryFunction, - BiConsumer messageHandler + BiConsumer messageHandler ) { this( statsTracker, @@ -90,18 +89,23 @@ public InboundPipeline( LongSupplier relativeTimeInMillis, InboundDecoder decoder, InboundAggregator aggregator, - BiConsumer messageHandler + BiConsumer messageHandler ) { this.relativeTimeInMillis = relativeTimeInMillis; this.statsTracker = statsTracker; this.decoder = decoder; this.aggregator = aggregator; + this.protocolBytesHandlers = List.of(new NativeInboundBytesHandler(pending, decoder, aggregator, statsTracker)); this.messageHandler = messageHandler; } @Override public void close() { isClosed = true; + if (currentHandler != null) { + currentHandler.close(); + currentHandler = null; + } Releasables.closeWhileHandlingException(decoder, aggregator); Releasables.closeWhileHandlingException(pending); pending.clear(); @@ -124,95 +128,21 @@ public void doHandleBytes(TcpChannel channel, ReleasableBytesReference reference statsTracker.markBytesRead(reference.length()); pending.add(reference.retain()); - final ArrayList fragments = fragmentList.get(); - boolean continueHandling = true; - - while (continueHandling && isClosed == false) { - boolean continueDecoding = true; - while (continueDecoding && pending.isEmpty() == false) { - try (ReleasableBytesReference toDecode = getPendingBytes()) { - final int bytesDecoded = decoder.decode(toDecode, fragments::add); - if (bytesDecoded != 0) { - releasePendingBytes(bytesDecoded); - if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { - continueDecoding = false; - } - } else { - continueDecoding = false; - } - } - } - - if (fragments.isEmpty()) { - continueHandling = false; - } else { - try { - forwardFragments(channel, fragments); - } finally { - for (Object fragment : fragments) { - if (fragment instanceof ReleasableBytesReference) { - ((ReleasableBytesReference) fragment).close(); - } - } - fragments.clear(); - } - } - } - } - - private void forwardFragments(TcpChannel channel, ArrayList fragments) throws IOException { - for (Object fragment : fragments) { - if (fragment instanceof Header) { - assert aggregator.isAggregating() == false; - aggregator.headerReceived((Header) fragment); - } else if (fragment == InboundDecoder.PING) { - assert aggregator.isAggregating() == false; - messageHandler.accept(channel, PING_MESSAGE); - } else if (fragment == InboundDecoder.END_CONTENT) { - assert aggregator.isAggregating(); - try (InboundMessage aggregated = aggregator.finishAggregation()) { - statsTracker.markMessageReceived(); - messageHandler.accept(channel, aggregated); + // If we don't have a current handler, we should try to find one based on the protocol of the incoming bytes. + if (currentHandler == null) { + for (InboundBytesHandler handler : protocolBytesHandlers) { + if (handler.canHandleBytes(reference)) { + currentHandler = handler; + break; } - } else { - assert aggregator.isAggregating(); - assert fragment instanceof ReleasableBytesReference; - aggregator.aggregate((ReleasableBytesReference) fragment); } } - } - private boolean endOfMessage(Object fragment) { - return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; - } - - private ReleasableBytesReference getPendingBytes() { - if (pending.size() == 1) { - return pending.peekFirst().retain(); + // If we have a current handler determined based on protocol, we should continue to use it for the fragmented bytes. + if (currentHandler != null) { + currentHandler.doHandleBytes(channel, reference, messageHandler); } else { - final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; - int index = 0; - for (ReleasableBytesReference pendingReference : pending) { - bytesReferences[index] = pendingReference.retain(); - ++index; - } - final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences); - return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); - } - } - - private void releasePendingBytes(int bytesConsumed) { - int bytesToRelease = bytesConsumed; - while (bytesToRelease != 0) { - try (ReleasableBytesReference reference = pending.pollFirst()) { - assert reference != null; - if (bytesToRelease < reference.length()) { - pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease)); - bytesToRelease -= bytesToRelease; - } else { - bytesToRelease -= reference.length(); - } - } + throw new IllegalStateException("No bytes handler found for the incoming transport protocol"); } } } diff --git a/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java new file mode 100644 index 0000000000000..861b95a8098f2 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/NativeMessageHandler.java @@ -0,0 +1,494 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.util.BytesRef; +import org.opensearch.Version; +import org.opensearch.common.util.concurrent.AbstractRunnable; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.io.stream.ByteBufferStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.transport.TransportResponse; +import org.opensearch.telemetry.tracing.Span; +import org.opensearch.telemetry.tracing.SpanBuilder; +import org.opensearch.telemetry.tracing.SpanScope; +import org.opensearch.telemetry.tracing.Tracer; +import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel; +import org.opensearch.threadpool.ThreadPool; + +import java.io.EOFException; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Native handler for inbound data + * + * @opensearch.internal + */ +public class NativeMessageHandler implements ProtocolMessageHandler { + + private static final Logger logger = LogManager.getLogger(NativeMessageHandler.class); + + private final ThreadPool threadPool; + private final OutboundHandler outboundHandler; + private final NamedWriteableRegistry namedWriteableRegistry; + private final TransportHandshaker handshaker; + private final TransportKeepAlive keepAlive; + private final Transport.ResponseHandlers responseHandlers; + private final Transport.RequestHandlers requestHandlers; + + private final Tracer tracer; + + NativeMessageHandler( + ThreadPool threadPool, + OutboundHandler outboundHandler, + NamedWriteableRegistry namedWriteableRegistry, + TransportHandshaker handshaker, + Transport.RequestHandlers requestHandlers, + Transport.ResponseHandlers responseHandlers, + Tracer tracer, + TransportKeepAlive keepAlive + ) { + this.threadPool = threadPool; + this.outboundHandler = outboundHandler; + this.namedWriteableRegistry = namedWriteableRegistry; + this.handshaker = handshaker; + this.requestHandlers = requestHandlers; + this.responseHandlers = responseHandlers; + this.tracer = tracer; + this.keepAlive = keepAlive; + } + + // Empty stream constant to avoid instantiating a new stream for empty messages. + private static final StreamInput EMPTY_STREAM_INPUT = new ByteBufferStreamInput(ByteBuffer.wrap(BytesRef.EMPTY_BYTES)); + + @Override + public void messageReceived( + TcpChannel channel, + ProtocolInboundMessage message, + long startTime, + long slowLogThresholdMs, + TransportMessageListener messageListener + ) throws IOException { + InboundMessage inboundMessage = (InboundMessage) message; + TransportLogger.logInboundMessage(channel, inboundMessage); + if (inboundMessage.isPing()) { + keepAlive.receiveKeepAlive(channel); + } else { + handleMessage(channel, inboundMessage, startTime, slowLogThresholdMs, messageListener); + } + } + + private void handleMessage( + TcpChannel channel, + InboundMessage message, + long startTime, + long slowLogThresholdMs, + TransportMessageListener messageListener + ) throws IOException { + final InetSocketAddress remoteAddress = channel.getRemoteAddress(); + final Header header = message.getHeader(); + assert header.needsToReadVariableHeader() == false; + ThreadContext threadContext = threadPool.getThreadContext(); + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + // Place the context with the headers from the message + threadContext.setHeaders(header.getHeaders()); + threadContext.putTransient("_remote_address", remoteAddress); + if (header.isRequest()) { + handleRequest(channel, header, message, messageListener); + } else { + // Responses do not support short circuiting currently + assert message.isShortCircuit() == false; + final TransportResponseHandler handler; + long requestId = header.getRequestId(); + if (header.isHandshake()) { + handler = handshaker.removeHandlerForHandshake(requestId); + } else { + TransportResponseHandler theHandler = responseHandlers.onResponseReceived( + requestId, + messageListener + ); + if (theHandler == null && header.isError()) { + handler = handshaker.removeHandlerForHandshake(requestId); + } else { + handler = theHandler; + } + } + // ignore if its null, the service logs it + if (handler != null) { + final StreamInput streamInput; + if (message.getContentLength() > 0 || header.getVersion().equals(Version.CURRENT) == false) { + streamInput = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(streamInput, header.getVersion()); + if (header.isError()) { + handlerResponseError(requestId, streamInput, handler); + } else { + handleResponse(requestId, remoteAddress, streamInput, handler); + } + } else { + assert header.isError() == false; + handleResponse(requestId, remoteAddress, EMPTY_STREAM_INPUT, handler); + } + } + + } + } finally { + final long took = threadPool.relativeTimeInMillis() - startTime; + final long logThreshold = slowLogThresholdMs; + if (logThreshold > 0 && took > logThreshold) { + logger.warn( + "handling inbound transport message [{}] took [{}ms] which is above the warn threshold of [{}ms]", + message, + took, + logThreshold + ); + } + } + } + + private Map> extractHeaders(Map headers) { + return headers.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> Collections.singleton(e.getValue()))); + } + + private void handleRequest( + TcpChannel channel, + Header header, + InboundMessage message, + TransportMessageListener messageListener + ) throws IOException { + final String action = header.getActionName(); + final long requestId = header.getRequestId(); + final Version version = header.getVersion(); + final Map> headers = extractHeaders(header.getHeaders().v1()); + Span span = tracer.startSpan(SpanBuilder.from(action, channel), headers); + try (SpanScope spanScope = tracer.withSpanInScope(span)) { + if (header.isHandshake()) { + messageListener.onRequestReceived(requestId, action); + // Cannot short circuit handshakes + assert message.isShortCircuit() == false; + final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(stream, header.getVersion()); + final TcpTransportChannel transportChannel = new TcpTransportChannel( + outboundHandler, + channel, + action, + requestId, + version, + header.getFeatures(), + header.isCompressed(), + header.isHandshake(), + message.takeBreakerReleaseControl() + ); + TransportChannel traceableTransportChannel = TraceableTcpTransportChannel.create(transportChannel, span, tracer); + try { + handshaker.handleHandshake(traceableTransportChannel, requestId, stream); + } catch (Exception e) { + if (Version.CURRENT.isCompatible(header.getVersion())) { + sendErrorResponse(action, traceableTransportChannel, e); + } else { + logger.warn( + new ParameterizedMessage( + "could not send error response to handshake received on [{}] using wire format version [{}], closing channel", + channel, + header.getVersion() + ), + e + ); + channel.close(); + } + } + } else { + final TcpTransportChannel transportChannel = new TcpTransportChannel( + outboundHandler, + channel, + action, + requestId, + version, + header.getFeatures(), + header.isCompressed(), + header.isHandshake(), + message.takeBreakerReleaseControl() + ); + TransportChannel traceableTransportChannel = TraceableTcpTransportChannel.create(transportChannel, span, tracer); + try { + messageListener.onRequestReceived(requestId, action); + if (message.isShortCircuit()) { + sendErrorResponse(action, traceableTransportChannel, message.getException()); + } else { + final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput()); + assertRemoteVersion(stream, header.getVersion()); + final RequestHandlerRegistry reg = requestHandlers.getHandler(action); + assert reg != null; + + final T request = newRequest(requestId, action, stream, reg); + request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); + checkStreamIsFullyConsumed(requestId, action, stream); + + final String executor = reg.getExecutor(); + if (ThreadPool.Names.SAME.equals(executor)) { + try { + reg.processMessageReceived(request, traceableTransportChannel); + } catch (Exception e) { + sendErrorResponse(reg.getAction(), traceableTransportChannel, e); + } + } else { + threadPool.executor(executor).execute(new RequestHandler<>(reg, request, traceableTransportChannel)); + } + } + } catch (Exception e) { + sendErrorResponse(action, traceableTransportChannel, e); + } + } + } catch (Exception e) { + span.setError(e); + span.endSpan(); + throw e; + } + } + + /** + * Creates new request instance out of input stream. Throws IllegalStateException if the end of + * the stream was reached before the request is fully deserialized from the stream. + * @param transport request type + * @param requestId request identifier + * @param action action name + * @param stream stream + * @param reg request handler registry + * @return new request instance + * @throws IOException IOException + * @throws IllegalStateException IllegalStateException + */ + private T newRequest( + final long requestId, + final String action, + final StreamInput stream, + final RequestHandlerRegistry reg + ) throws IOException { + try { + return reg.newRequest(stream); + } catch (final EOFException e) { + // Another favor of (de)serialization issues is when stream contains less bytes than + // the request handler needs to deserialize the payload. + throw new IllegalStateException( + "Message fully read (request) but more data is expected for requestId [" + + requestId + + "], action [" + + action + + "]; resetting", + e + ); + } + } + + /** + * Checks if the stream is fully consumed and throws the exceptions if that is not the case. + * @param requestId request identifier + * @param action action name + * @param stream stream + * @throws IOException IOException + */ + private void checkStreamIsFullyConsumed(final long requestId, final String action, final StreamInput stream) throws IOException { + // in case we throw an exception, i.e. when the limit is hit, we don't want to verify + final int nextByte = stream.read(); + + // calling read() is useful to make sure the message is fully read, even if there some kind of EOS marker + if (nextByte != -1) { + throw new IllegalStateException( + "Message not fully read (request) for requestId [" + + requestId + + "], action [" + + action + + "], available [" + + stream.available() + + "]; resetting" + ); + } + } + + /** + * Checks if the stream is fully consumed and throws the exceptions if that is not the case. + * @param requestId request identifier + * @param handler response handler + * @param stream stream + * @param error "true" if response represents error, "false" otherwise + * @throws IOException IOException + */ + private void checkStreamIsFullyConsumed( + final long requestId, + final TransportResponseHandler handler, + final StreamInput stream, + final boolean error + ) throws IOException { + if (stream != EMPTY_STREAM_INPUT) { + // Check the entire message has been read + final int nextByte = stream.read(); + // calling read() is useful to make sure the message is fully read, even if there is an EOS marker + if (nextByte != -1) { + throw new IllegalStateException( + "Message not fully read (response) for requestId [" + + requestId + + "], handler [" + + handler + + "], error [" + + error + + "]; resetting" + ); + } + } + } + + private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) { + try { + transportChannel.sendResponse(e); + } catch (Exception inner) { + inner.addSuppressed(e); + logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", actionName), inner); + } + } + + private void handleResponse( + final long requestId, + InetSocketAddress remoteAddress, + final StreamInput stream, + final TransportResponseHandler handler + ) { + final T response; + try { + response = handler.read(stream); + response.remoteAddress(new TransportAddress(remoteAddress)); + checkStreamIsFullyConsumed(requestId, handler, stream, false); + } catch (Exception e) { + final Exception serializationException = new TransportSerializationException( + "Failed to deserialize response from handler [" + handler + "]", + e + ); + logger.warn(new ParameterizedMessage("Failed to deserialize response from [{}]", remoteAddress), serializationException); + handleException(handler, serializationException); + return; + } + final String executor = handler.executor(); + if (ThreadPool.Names.SAME.equals(executor)) { + doHandleResponse(handler, response); + } else { + threadPool.executor(executor).execute(() -> doHandleResponse(handler, response)); + } + } + + private void doHandleResponse(TransportResponseHandler handler, T response) { + try { + handler.handleResponse(response); + } catch (Exception e) { + handleException(handler, new ResponseHandlerFailureTransportException(e)); + } + } + + private void handlerResponseError(final long requestId, StreamInput stream, final TransportResponseHandler handler) { + Exception error; + try { + error = stream.readException(); + checkStreamIsFullyConsumed(requestId, handler, stream, true); + } catch (Exception e) { + error = new TransportSerializationException( + "Failed to deserialize exception response from stream for handler [" + handler + "]", + e + ); + } + handleException(handler, error); + } + + private void handleException(final TransportResponseHandler handler, Throwable error) { + if (!(error instanceof RemoteTransportException)) { + error = new RemoteTransportException(error.getMessage(), error); + } + final RemoteTransportException rtx = (RemoteTransportException) error; + threadPool.executor(handler.executor()).execute(() -> { + try { + handler.handleException(rtx); + } catch (Exception e) { + logger.error(() -> new ParameterizedMessage("failed to handle exception response [{}]", handler), e); + } + }); + } + + private StreamInput namedWriteableStream(StreamInput delegate) { + return new NamedWriteableAwareStreamInput(delegate, namedWriteableRegistry); + } + + static void assertRemoteVersion(StreamInput in, Version version) { + assert version.equals(in.getVersion()) : "Stream version [" + in.getVersion() + "] does not match version [" + version + "]"; + } + + /** + * Internal request handler + * + * @opensearch.internal + */ + private static class RequestHandler extends AbstractRunnable { + private final RequestHandlerRegistry reg; + private final T request; + private final TransportChannel transportChannel; + + RequestHandler(RequestHandlerRegistry reg, T request, TransportChannel transportChannel) { + this.reg = reg; + this.request = request; + this.transportChannel = transportChannel; + } + + @Override + protected void doRun() throws Exception { + reg.processMessageReceived(request, transportChannel); + } + + @Override + public boolean isForceExecution() { + return reg.isForceExecution(); + } + + @Override + public void onFailure(Exception e) { + sendErrorResponse(reg.getAction(), transportChannel, e); + } + } + +} diff --git a/server/src/main/java/org/opensearch/transport/ProtocolInboundMessage.java b/server/src/main/java/org/opensearch/transport/ProtocolInboundMessage.java new file mode 100644 index 0000000000000..43c2d5ffe4c96 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/ProtocolInboundMessage.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport; + +import org.opensearch.common.annotation.PublicApi; + +/** + * Base class for inbound data as a message. + * Different implementations are used for different protocols. + * + * @opensearch.internal + */ +@PublicApi(since = "2.14.0") +public interface ProtocolInboundMessage { + + /** + * @return the protocol used to encode this message + */ + public String getProtocol(); + +} diff --git a/server/src/main/java/org/opensearch/transport/ProtocolMessageHandler.java b/server/src/main/java/org/opensearch/transport/ProtocolMessageHandler.java new file mode 100644 index 0000000000000..714d91d1e74c7 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/ProtocolMessageHandler.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport; + +import java.io.IOException; + +/** + * Interface for message handlers based on transport protocol. + * + * @opensearch.internal + */ +public interface ProtocolMessageHandler { + + public void messageReceived( + TcpChannel channel, + ProtocolInboundMessage message, + long startTime, + long slowLogThresholdMs, + TransportMessageListener messageListener + ) throws IOException; +} diff --git a/server/src/main/java/org/opensearch/transport/TcpTransport.java b/server/src/main/java/org/opensearch/transport/TcpTransport.java index 7d45152089f37..e32bba5e836d3 100644 --- a/server/src/main/java/org/opensearch/transport/TcpTransport.java +++ b/server/src/main/java/org/opensearch/transport/TcpTransport.java @@ -762,12 +762,24 @@ protected void serverAcceptedChannel(TcpChannel channel) { protected abstract void stopInternal(); /** + * @deprecated use {@link #inboundMessage(TcpChannel, ProtocolInboundMessage)} * Handles inbound message that has been decoded. * * @param channel the channel the message is from * @param message the message */ + @Deprecated(since = "2.14.0", forRemoval = true) public void inboundMessage(TcpChannel channel, InboundMessage message) { + inboundMessage(channel, (ProtocolInboundMessage) message); + } + + /** + * Handles inbound message that has been decoded. + * + * @param channel the channel the message is from + * @param message the message + */ + public void inboundMessage(TcpChannel channel, ProtocolInboundMessage message) { try { inboundHandler.inboundMessage(channel, message); } catch (Exception e) { diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java new file mode 100644 index 0000000000000..a8a4c0da7ec0f --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundBytesHandler.java @@ -0,0 +1,167 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.nativeprotocol; + +import org.opensearch.common.bytes.ReleasableBytesReference; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.lease.Releasables; +import org.opensearch.core.common.bytes.CompositeBytesReference; +import org.opensearch.transport.Header; +import org.opensearch.transport.InboundAggregator; +import org.opensearch.transport.InboundBytesHandler; +import org.opensearch.transport.InboundDecoder; +import org.opensearch.transport.InboundMessage; +import org.opensearch.transport.ProtocolInboundMessage; +import org.opensearch.transport.StatsTracker; +import org.opensearch.transport.TcpChannel; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.function.BiConsumer; + +/** + * Handler for inbound bytes for the native protocol. + */ +public class NativeInboundBytesHandler implements InboundBytesHandler { + + private static final ThreadLocal> fragmentList = ThreadLocal.withInitial(ArrayList::new); + private static final InboundMessage PING_MESSAGE = new InboundMessage(null, true); + + private final ArrayDeque pending; + private final InboundDecoder decoder; + private final InboundAggregator aggregator; + private final StatsTracker statsTracker; + private boolean isClosed = false; + + public NativeInboundBytesHandler( + ArrayDeque pending, + InboundDecoder decoder, + InboundAggregator aggregator, + StatsTracker statsTracker + ) { + this.pending = pending; + this.decoder = decoder; + this.aggregator = aggregator; + this.statsTracker = statsTracker; + } + + @Override + public void close() { + isClosed = true; + } + + @Override + public boolean canHandleBytes(ReleasableBytesReference reference) { + return true; + } + + @Override + public void doHandleBytes( + TcpChannel channel, + ReleasableBytesReference reference, + BiConsumer messageHandler + ) throws IOException { + final ArrayList fragments = fragmentList.get(); + boolean continueHandling = true; + + while (continueHandling && isClosed == false) { + boolean continueDecoding = true; + while (continueDecoding && pending.isEmpty() == false) { + try (ReleasableBytesReference toDecode = getPendingBytes()) { + final int bytesDecoded = decoder.decode(toDecode, fragments::add); + if (bytesDecoded != 0) { + releasePendingBytes(bytesDecoded); + if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) { + continueDecoding = false; + } + } else { + continueDecoding = false; + } + } + } + + if (fragments.isEmpty()) { + continueHandling = false; + } else { + try { + forwardFragments(channel, fragments, messageHandler); + } finally { + for (Object fragment : fragments) { + if (fragment instanceof ReleasableBytesReference) { + ((ReleasableBytesReference) fragment).close(); + } + } + fragments.clear(); + } + } + } + } + + private ReleasableBytesReference getPendingBytes() { + if (pending.size() == 1) { + return pending.peekFirst().retain(); + } else { + final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()]; + int index = 0; + for (ReleasableBytesReference pendingReference : pending) { + bytesReferences[index] = pendingReference.retain(); + ++index; + } + final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences); + return new ReleasableBytesReference(CompositeBytesReference.of(bytesReferences), releasable); + } + } + + private void releasePendingBytes(int bytesConsumed) { + int bytesToRelease = bytesConsumed; + while (bytesToRelease != 0) { + try (ReleasableBytesReference reference = pending.pollFirst()) { + assert reference != null; + if (bytesToRelease < reference.length()) { + pending.addFirst(reference.retainedSlice(bytesToRelease, reference.length() - bytesToRelease)); + bytesToRelease -= bytesToRelease; + } else { + bytesToRelease -= reference.length(); + } + } + } + } + + private boolean endOfMessage(Object fragment) { + return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception; + } + + private void forwardFragments( + TcpChannel channel, + ArrayList fragments, + BiConsumer messageHandler + ) throws IOException { + for (Object fragment : fragments) { + if (fragment instanceof Header) { + assert aggregator.isAggregating() == false; + aggregator.headerReceived((Header) fragment); + } else if (fragment == InboundDecoder.PING) { + assert aggregator.isAggregating() == false; + messageHandler.accept(channel, PING_MESSAGE); + } else if (fragment == InboundDecoder.END_CONTENT) { + assert aggregator.isAggregating(); + try (InboundMessage aggregated = aggregator.finishAggregation()) { + statsTracker.markMessageReceived(); + messageHandler.accept(channel, aggregated); + } + } else { + assert aggregator.isAggregating(); + assert fragment instanceof ReleasableBytesReference; + aggregator.aggregate((ReleasableBytesReference) fragment); + } + } + } + +} diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java new file mode 100644 index 0000000000000..1143f129b6319 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/NativeInboundMessage.java @@ -0,0 +1,149 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.transport.nativeprotocol; + +import org.opensearch.common.annotation.PublicApi; +import org.opensearch.common.bytes.ReleasableBytesReference; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.lease.Releasables; +import org.opensearch.common.util.io.IOUtils; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.transport.Header; +import org.opensearch.transport.ProtocolInboundMessage; + +import java.io.IOException; + +/** + * Inbound data as a message + * + * @opensearch.api + */ +@PublicApi(since = "2.14.0") +public class NativeInboundMessage implements Releasable, ProtocolInboundMessage { + + /** + * The protocol used to encode this message + */ + public static String NATIVE_PROTOCOL = "native"; + + private final Header header; + private final ReleasableBytesReference content; + private final Exception exception; + private final boolean isPing; + private Releasable breakerRelease; + private StreamInput streamInput; + + public NativeInboundMessage(Header header, ReleasableBytesReference content, Releasable breakerRelease) { + this.header = header; + this.content = content; + this.breakerRelease = breakerRelease; + this.exception = null; + this.isPing = false; + } + + public NativeInboundMessage(Header header, Exception exception) { + this.header = header; + this.content = null; + this.breakerRelease = null; + this.exception = exception; + this.isPing = false; + } + + public NativeInboundMessage(Header header, boolean isPing) { + this.header = header; + this.content = null; + this.breakerRelease = null; + this.exception = null; + this.isPing = isPing; + } + + @Override + public String getProtocol() { + return NATIVE_PROTOCOL; + } + + public Header getHeader() { + return header; + } + + public int getContentLength() { + if (content == null) { + return 0; + } else { + return content.length(); + } + } + + public Exception getException() { + return exception; + } + + public boolean isPing() { + return isPing; + } + + public boolean isShortCircuit() { + return exception != null; + } + + public Releasable takeBreakerReleaseControl() { + final Releasable toReturn = breakerRelease; + breakerRelease = null; + if (toReturn != null) { + return toReturn; + } else { + return () -> {}; + } + } + + public StreamInput openOrGetStreamInput() throws IOException { + assert isPing == false && content != null; + if (streamInput == null) { + streamInput = content.streamInput(); + streamInput.setVersion(header.getVersion()); + } + return streamInput; + } + + @Override + public void close() { + IOUtils.closeWhileHandlingException(streamInput); + Releasables.closeWhileHandlingException(content, breakerRelease); + } + + @Override + public String toString() { + return "InboundMessage{" + header + "}"; + } + +} diff --git a/server/src/main/java/org/opensearch/transport/nativeprotocol/package-info.java b/server/src/main/java/org/opensearch/transport/nativeprotocol/package-info.java new file mode 100644 index 0000000000000..84f6d7d0ec5c2 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/nativeprotocol/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Native transport protocol package. */ +package org.opensearch.transport.nativeprotocol; diff --git a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java index e002297911788..0d171e17e70e1 100644 --- a/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundHandlerTests.java @@ -275,11 +275,11 @@ public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws // response so we must just close the connection on an error. To avoid the failure disappearing into a black hole we at least log // it. - try (MockLogAppender mockAppender = MockLogAppender.createForLoggers(LogManager.getLogger(InboundHandler.class))) { + try (MockLogAppender mockAppender = MockLogAppender.createForLoggers(LogManager.getLogger(NativeMessageHandler.class))) { mockAppender.addExpectation( new MockLogAppender.SeenEventExpectation( "expected message", - InboundHandler.class.getCanonicalName(), + NativeMessageHandler.class.getCanonicalName(), Level.WARN, "could not send error response to handshake" ) @@ -308,11 +308,11 @@ public void testClosesChannelOnErrorInHandshakeWithIncompatibleVersion() throws } public void testLogsSlowInboundProcessing() throws Exception { - try (MockLogAppender mockAppender = MockLogAppender.createForLoggers(LogManager.getLogger(InboundHandler.class))) { + try (MockLogAppender mockAppender = MockLogAppender.createForLoggers(LogManager.getLogger(NativeMessageHandler.class))) { mockAppender.addExpectation( new MockLogAppender.SeenEventExpectation( "expected message", - InboundHandler.class.getCanonicalName(), + NativeMessageHandler.class.getCanonicalName(), Level.WARN, "handling inbound transport message " ) diff --git a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java index ae4b537223394..2dfe8a0dd8590 100644 --- a/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java +++ b/server/src/test/java/org/opensearch/transport/InboundPipelineTests.java @@ -72,24 +72,25 @@ public void testPipelineHandling() throws IOException { final List> expected = new ArrayList<>(); final List> actual = new ArrayList<>(); final List toRelease = new ArrayList<>(); - final BiConsumer messageHandler = (c, m) -> { + final BiConsumer messageHandler = (c, m) -> { try { - final Header header = m.getHeader(); + InboundMessage message = (InboundMessage) m; + final Header header = message.getHeader(); final MessageData actualData; final Version version = header.getVersion(); final boolean isRequest = header.isRequest(); final long requestId = header.getRequestId(); final boolean isCompressed = header.isCompressed(); - if (m.isShortCircuit()) { + if (message.isShortCircuit()) { actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), null); } else if (isRequest) { - final TestRequest request = new TestRequest(m.openOrGetStreamInput()); + final TestRequest request = new TestRequest(message.openOrGetStreamInput()); actualData = new MessageData(version, requestId, isRequest, isCompressed, header.getActionName(), request.value); } else { - final TestResponse response = new TestResponse(m.openOrGetStreamInput()); + final TestResponse response = new TestResponse(message.openOrGetStreamInput()); actualData = new MessageData(version, requestId, isRequest, isCompressed, null, response.value); } - actual.add(new Tuple<>(actualData, m.getException())); + actual.add(new Tuple<>(actualData, message.getException())); } catch (IOException e) { throw new AssertionError(e); } @@ -214,7 +215,7 @@ public void testPipelineHandling() throws IOException { } public void testDecodeExceptionIsPropagated() throws IOException { - BiConsumer messageHandler = (c, m) -> {}; + BiConsumer messageHandler = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); @@ -268,7 +269,7 @@ public void testDecodeExceptionIsPropagated() throws IOException { } public void testEnsureBodyIsNotPrematurelyReleased() throws IOException { - BiConsumer messageHandler = (c, m) -> {}; + BiConsumer messageHandler = (c, m) -> {}; final StatsTracker statsTracker = new StatsTracker(); final LongSupplier millisSupplier = () -> TimeValue.nsecToMSec(System.nanoTime()); final InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE); diff --git a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java index ff99435f765d8..36ba409a2de03 100644 --- a/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/opensearch/transport/OutboundHandlerTests.java @@ -97,8 +97,9 @@ public void setUp() throws Exception { final InboundAggregator aggregator = new InboundAggregator(breaker, (Predicate) action -> true); pipeline = new InboundPipeline(statsTracker, millisSupplier, decoder, aggregator, (c, m) -> { try (BytesStreamOutput streamOutput = new BytesStreamOutput()) { - Streams.copy(m.openOrGetStreamInput(), streamOutput); - message.set(new Tuple<>(m.getHeader(), streamOutput.bytes())); + InboundMessage m1 = (InboundMessage) m; + Streams.copy(m1.openOrGetStreamInput(), streamOutput); + message.set(new Tuple<>(m1.getHeader(), streamOutput.bytes())); } catch (IOException e) { throw new AssertionError(e); }