diff --git a/src/main/java/io/lettuce/core/protocol/DefaultAutoBatchFlushEndpoint.java b/src/main/java/io/lettuce/core/protocol/DefaultAutoBatchFlushEndpoint.java index 8053cf3eb3..6cfd36aaf5 100644 --- a/src/main/java/io/lettuce/core/protocol/DefaultAutoBatchFlushEndpoint.java +++ b/src/main/java/io/lettuce/core/protocol/DefaultAutoBatchFlushEndpoint.java @@ -144,7 +144,7 @@ protected static void cancelCommandOnEndpointClose(RedisCommand cmd) { private final boolean debugEnabled = logger.isDebugEnabled(); - protected final CompletableFuture closeFuture = new CompletableFuture<>(); + private final CompletableFuture closeFuture = new CompletableFuture<>(); private String logPrefix; @@ -152,17 +152,17 @@ protected static void cancelCommandOnEndpointClose(RedisCommand cmd) { private boolean inActivation = false; - protected @Nullable ConnectionWatchdog connectionWatchdog; + private @Nullable ConnectionWatchdog connectionWatchdog; private ConnectionFacade connectionFacade; private final String cachedEndpointId; - protected final UnboundedOfferFirstQueue taskQueue; + private final UnboundedOfferFirstQueue taskQueue; - private final boolean canFire; + private final OwnershipSynchronizer taskQueueConsumeSync; // make sure only one consumer exists at any given time - private volatile EventExecutor lastEventExecutor; + private final boolean canFire; private volatile Throwable connectionError; @@ -172,8 +172,6 @@ protected static void cancelCommandOnEndpointClose(RedisCommand cmd) { private final int batchSize; - private final boolean usesMpscQueue; - /** * Create a new {@link AutoBatchFlushEndpoint}. * @@ -197,13 +195,14 @@ protected DefaultAutoBatchFlushEndpoint(ClientOptions clientOptions, ClientResou this.rejectCommandsWhileDisconnected = isRejectCommand(clientOptions); long endpointId = ENDPOINT_COUNTER.incrementAndGet(); this.cachedEndpointId = "0x" + Long.toHexString(endpointId); - this.usesMpscQueue = clientOptions.getAutoBatchFlushOptions().usesMpscQueue(); - this.taskQueue = usesMpscQueue ? new JcToolsUnboundedMpscOfferFirstQueue<>() : new ConcurrentLinkedOfferFirstQueue<>(); + this.taskQueue = clientOptions.getAutoBatchFlushOptions().usesMpscQueue() ? new JcToolsUnboundedMpscOfferFirstQueue<>() + : new ConcurrentLinkedOfferFirstQueue<>(); this.canFire = false; this.callbackOnClose = callbackOnClose; this.writeSpinCount = clientOptions.getAutoBatchFlushOptions().getWriteSpinCount(); this.batchSize = clientOptions.getAutoBatchFlushOptions().getBatchSize(); - this.lastEventExecutor = clientResources.eventExecutorGroup().next(); + this.taskQueueConsumeSync = new OwnershipSynchronizer(clientResources.eventExecutorGroup().next(), + Thread.currentThread().getName(), true/* allows to be preempted by first event loop thread */); } @Override @@ -324,7 +323,8 @@ public void notifyChannelActive(Channel channel) { return; } - this.lastEventExecutor = channel.eventLoop(); + this.taskQueueConsumeSync.preempt(channel.eventLoop(), Thread.currentThread().getName(), + false /* disallow preempt until reached quiescent point, see onEndpointQuiescence() */); this.connectionError = null; this.inProtectMode = false; this.logPrefix = null; @@ -379,7 +379,7 @@ public void notifyReconnectFailed(Throwable t) { return; } - syncAfterTerminated(() -> { + taskQueueConsumeSync.execute(() -> { if (isClosed()) { onEndpointClosed(); } else { @@ -474,10 +474,10 @@ public void flushCommands() { final ContextualChannel chan = this.channel; switch (chan.context.initialState) { case ENDPOINT_CLOSED: - syncAfterTerminated(this::onEndpointClosed); + taskQueueConsumeSync.execute(this::onEndpointClosed); return; case RECONNECT_FAILED: - syncAfterTerminated(() -> { + taskQueueConsumeSync.execute(() -> { if (isClosed()) { onEndpointClosed(); } else { @@ -563,7 +563,6 @@ public void disconnect() { */ @Override public void reset() { - if (debugEnabled) { logger.debug("{} reset()", logPrefix()); } @@ -572,10 +571,7 @@ public void reset() { if (chan.context.initialState.isConnected()) { chan.pipeline().fireUserEventTriggered(new ConnectionEvents.Reset()); } - if (!usesMpscQueue) { - cancelCommands("reset"); - } - // Otherwise, unsafe to call cancelBufferedCommands() here. + taskQueueConsumeSync.execute(() -> cancelCommands("reset")); } private void resetInternal() { @@ -587,7 +583,6 @@ private void resetInternal() { if (chan.context.initialState.isConnected()) { chan.pipeline().fireUserEventTriggered(new ConnectionEvents.Reset()); } - LettuceAssert.assertState(lastEventExecutor.inEventLoop(), "must be called in lastEventLoop thread"); cancelCommands("resetInternal"); } @@ -596,10 +591,8 @@ private void resetInternal() { */ @Override public void initialState() { - if (!usesMpscQueue) { - cancelCommands("initialState"); - } - // Otherwise, unsafe to call cancelBufferedCommands() here. + taskQueueConsumeSync.execute(() -> cancelCommands("initialState")); + ContextualChannel currentChannel = this.channel; if (currentChannel.context.initialState.isConnected()) { ChannelFuture close = currentChannel.close(); @@ -637,8 +630,6 @@ public String getId() { } private void scheduleSendJobOnConnected(final ContextualChannel chan) { - LettuceAssert.assertState(chan.eventLoop().inEventLoop(), "must be called in event loop thread"); - // Schedule directly loopSend(chan, false); } @@ -758,7 +749,6 @@ private int pollBatch(final AutoBatchFlushEndPointContext autoBatchFlushEndPoint private void trySetEndpointQuiescence(ContextualChannel chan) { final EventLoop eventLoop = chan.eventLoop(); LettuceAssert.isTrue(eventLoop.inEventLoop(), "unexpected: not in event loop"); - LettuceAssert.isTrue(eventLoop == lastEventExecutor, "unexpected: lastEventLoop not match"); final ConnectionContext connectionContext = chan.context; final @Nullable ConnectionContext.CloseStatus closeStatus = connectionContext.getCloseStatus(); @@ -827,6 +817,8 @@ private void onWontReconnect(@Nonnull final ConnectionContext.CloseStatus closeS } private void onEndpointQuiescence() { + taskQueueConsumeSync.done(1); // allows preemption + if (channel.context.initialState == ConnectionContext.State.ENDPOINT_CLOSED) { return; } @@ -864,7 +856,7 @@ private final void onEndpointClosed(Queue>... queues) { fulfillCommands("endpoint closed", callbackOnClose, queues); } - private final void onReconnectFailed() { + private void onReconnectFailed() { fulfillCommands("reconnect failed", cmd -> cmd.completeExceptionally(getFailedToReconnectReason())); } @@ -996,7 +988,7 @@ private Throwable validateWrite(ContextualChannel chan, int commands, boolean is private void onUnexpectedState(String caller, ConnectionContext.State exp) { final ConnectionContext.State actual = this.channel.context.initialState; logger.error("{}[{}][unexpected] : unexpected state: exp '{}' got '{}'", logPrefix(), caller, exp, actual); - syncAfterTerminated( + taskQueueConsumeSync.execute( () -> cancelCommands(String.format("%s: state not match: expect '%s', got '%s'", caller, exp, actual))); } @@ -1017,23 +1009,6 @@ private ChannelFuture channelWrite(Channel channel, RedisCommand comman return channel.write(command); } - /* - * Synchronize after the endpoint is terminated. This is to ensure only one thread can access the task queue after endpoint - * is terminated (state is RECONNECT_FAILED/ENDPOINT_CLOSED) - */ - private void syncAfterTerminated(Runnable runnable) { - final EventExecutor localLastEventExecutor = lastEventExecutor; - if (localLastEventExecutor.inEventLoop()) { - runnable.run(); - } else { - localLastEventExecutor.execute(() -> { - runnable.run(); - LettuceAssert.isTrue(lastEventExecutor == localLastEventExecutor, - "lastEventLoop must not be changed after terminated"); - }); - } - } - private enum Reliability { AT_MOST_ONCE, AT_LEAST_ONCE } @@ -1103,7 +1078,7 @@ public void operationComplete(Future future) { final Throwable retryableErr = checkSendResult(future); if (retryableErr != null && autoBatchFlushEndPointContext.addRetryableFailedToSendCommand(cmd, retryableErr)) { - // Close connection on first transient write failure + // Close connection on first transient write failure. internalCloseConnectionIfNeeded(retryableErr); } @@ -1163,6 +1138,7 @@ private void internalCloseConnectionIfNeeded(Throwable reason) { return; } + // It is really rare (maybe impossible?) that the connection is still active. logger.error( "[internalCloseConnectionIfNeeded][interesting][{}] close the connection due to write error, reason: '{}'", endpoint.logPrefix(), reason.getMessage(), reason); @@ -1184,4 +1160,145 @@ private void recycle() { } + public static class OwnershipSynchronizer { + + private static class Owner { + + private final EventExecutor thread; + + private final String threadName; + + // if positive, no other thread can preempt the ownership. + private final int semaphore; + + public Owner(EventExecutor thread, String threadName, int semaphore) { + LettuceAssert.assertState(semaphore >= 0, () -> String.format("negative semaphore: %d", semaphore)); + this.thread = thread; + this.threadName = threadName; + this.semaphore = semaphore; + } + + public boolean isCurrentThread() { + return thread.inEventLoop(); + } + + public Owner toAdd(int n) { + return new Owner(thread, threadName, semaphore + n); + } + + public Owner toDone(int n) { + return new Owner(thread, threadName, semaphore - n); + } + + public boolean isDone() { + return semaphore == 0; + } + + } + + private static final AtomicReferenceFieldUpdater OWNER = AtomicReferenceFieldUpdater + .newUpdater(OwnershipSynchronizer.class, Owner.class, "owner"); + + private volatile Owner owner; + + public OwnershipSynchronizer(EventExecutor thread, String threadName, boolean allowsPreemptByOtherThreads) { + this.owner = new Owner(thread, threadName, allowsPreemptByOtherThreads ? 0 : 1); + } + + /** + * Preempt ownership only when there is no running tasks in current owner + * + * @param thread new thread + * @param threadName thread name + * @param allowsPreemptByOtherThreads whether allows a third thread to preempt after @param `thread` preempts from + * current owner thread, if true, initial running task number will be set to 1. + */ + public void preempt(EventExecutor thread, String threadName, boolean allowsPreemptByOtherThreads) { + Owner cur; + Owner newOwner = null; + while (true) { + cur = this.owner; + if (cur.thread == thread) { + if (allowsPreemptByOtherThreads) { + return; + } + if (OWNER.compareAndSet(this, cur, cur.toAdd(1))) { // prevent preempt + return; + } + continue; + } + + if (!cur.isDone()) { + // unsafe to preempt + continue; + } + + if (newOwner == null) { + newOwner = new Owner(thread, threadName, allowsPreemptByOtherThreads ? 0 : 1); + } + if (OWNER.compareAndSet(this, cur, newOwner)) { + logger.debug("ownership preempted by a new thread [{}]", threadName); + // established happens-before with done() + return; + } + } + } + + /** + * done n tasks in current owner. + * + * @param n number of tasks to be done. + */ + public void done(int n) { + Owner cur; + do { + cur = this.owner; + assertIsOwnerThreadAndPreemptPrevented(cur); + } while (!OWNER.compareAndSet(this, cur, cur.toDone(n))); + // create happens-before with preempt() + } + + /** + * Safely run a task in current owner thread and release its memory effect to next owner thread. + * + * @param task task to run + */ + public void execute(Runnable task) { + Owner cur; + do { + cur = this.owner; + if (isOwnerCurrentThreadAndPreemptPrevented(cur)) { + // already prevented preemption, safe to skip expensive add/done calls + task.run(); + return; + } + } while (!OWNER.compareAndSet(this, cur, cur.toAdd(1))); + + if (cur.isCurrentThread()) { + executeInOwnerWithPreemptPrevention(task); + } else { + cur.thread.execute(() -> executeInOwnerWithPreemptPrevention(task)); + } + } + + private void executeInOwnerWithPreemptPrevention(Runnable task) { + try { + task.run(); + } finally { + done(1); + } + } + + private void assertIsOwnerThreadAndPreemptPrevented(Owner cur) { + LettuceAssert.assertState(isOwnerCurrentThreadAndPreemptPrevented(cur), + () -> "[executeInOwnerWithPreemptPrevention] unexpected: " + + (cur.isCurrentThread() ? "preemption not prevented" : "owner is not this thread")); + } + + private boolean isOwnerCurrentThreadAndPreemptPrevented(Owner owner) { + return owner.isCurrentThread() && !owner.isDone(); + } + + } + }