diff --git a/src/main/java/io/lettuce/core/protocol/DefaultEndpoint.java b/src/main/java/io/lettuce/core/protocol/DefaultEndpoint.java index d3b9e751fe..3ef9fb5a2e 100644 --- a/src/main/java/io/lettuce/core/protocol/DefaultEndpoint.java +++ b/src/main/java/io/lettuce/core/protocol/DefaultEndpoint.java @@ -63,11 +63,11 @@ public class DefaultEndpoint implements RedisChannelWriter, Endpoint, PushHandle private static final AtomicLong ENDPOINT_COUNTER = new AtomicLong(); - private static final AtomicIntegerFieldUpdater QUEUE_SIZE = AtomicIntegerFieldUpdater - .newUpdater(DefaultEndpoint.class, "queueSize"); + private static final AtomicIntegerFieldUpdater QUEUE_SIZE = AtomicIntegerFieldUpdater.newUpdater( + DefaultEndpoint.class, "queueSize"); - private static final AtomicIntegerFieldUpdater STATUS = AtomicIntegerFieldUpdater - .newUpdater(DefaultEndpoint.class, "status"); + private static final AtomicIntegerFieldUpdater STATUS = AtomicIntegerFieldUpdater.newUpdater( + DefaultEndpoint.class, "status"); private static final int ST_OPEN = 0; @@ -105,6 +105,8 @@ public class DefaultEndpoint implements RedisChannelWriter, Endpoint, PushHandle private boolean inActivation = false; + private Channel inActivationChannel; // No need to be volatile, since it is always accessed from the same thread. + private ConnectionWatchdog connectionWatchdog; private ConnectionFacade connectionFacade; @@ -191,9 +193,9 @@ public RedisCommand write(RedisCommand command) { } if (autoFlushCommands) { - - if (isConnected()) { - writeToChannelAndFlush(command); + Channel channel = inActivation ? inActivationChannel : this.channel; + if (isConnected(channel)) { + writeToChannelAndFlush(channel, command); } else { writeToDisconnectedBuffer(command); } @@ -232,9 +234,9 @@ public RedisCommand write(RedisCommand command) { } if (autoFlushCommands) { - - if (isConnected()) { - writeToChannelAndFlush(commands); + Channel channel = inActivation ? inActivationChannel : this.channel; + if (isConnected(channel)) { + writeToChannelAndFlush(channel, commands); } else { writeToDisconnectedBuffer(commands); } @@ -284,10 +286,9 @@ private RedisException validateWrite(int commands) { return new RedisException("Connection is closed"); } + final boolean connected = isConnected(this.channel); if (usesBoundedQueues()) { - boolean connected = isConnected(); - if (QUEUE_SIZE.get(this) + commands > clientOptions.getRequestQueueSize()) { return new RedisException("Request queue size exceeded: " + clientOptions.getRequestQueueSize() + ". Commands are not accepted until the queue size drops."); @@ -304,7 +305,7 @@ private RedisException validateWrite(int commands) { } } - if (!isConnected() && rejectCommandsWhileDisconnected) { + if (!connected && rejectCommandsWhileDisconnected) { return new RedisException("Currently not connected. Commands are rejected."); } @@ -366,11 +367,11 @@ private void writeToDisconnectedBuffer(RedisCommand command) { commandBuffer.add(command); } - private void writeToChannelAndFlush(RedisCommand command) { + private void writeToChannelAndFlush(Channel channel, RedisCommand command) { QUEUE_SIZE.incrementAndGet(this); - ChannelFuture channelFuture = channelWriteAndFlush(command); + ChannelFuture channelFuture = channelWriteAndFlush(channel, command); if (reliability == Reliability.AT_MOST_ONCE) { // cancel on exceptions and remove from queue, because there is no housekeeping @@ -383,7 +384,7 @@ private void writeToChannelAndFlush(RedisCommand command) { } } - private void writeToChannelAndFlush(Collection> commands) { + private void writeToChannelAndFlush(Channel channel, Collection> commands) { QUEUE_SIZE.addAndGet(this, commands.size()); @@ -391,7 +392,7 @@ private void writeToChannelAndFlush(Collection> // cancel on exceptions and remove from queue, because there is no housekeeping for (RedisCommand command : commands) { - channelWrite(command).addListener(AtMostOnceWriteListener.newInstance(this, command)); + channelWrite(channel, command).addListener(AtMostOnceWriteListener.newInstance(this, command)); } } @@ -399,14 +400,14 @@ private void writeToChannelAndFlush(Collection> // commands are ok to stay within the queue, reconnect will retrigger them for (RedisCommand command : commands) { - channelWrite(command).addListener(RetryListener.newInstance(this, command)); + channelWrite(channel, command).addListener(RetryListener.newInstance(this, command)); } } - channelFlush(); + channelFlush(channel); } - private void channelFlush() { + private void channelFlush(Channel channel) { if (debugEnabled) { logger.debug("{} write() channelFlush", logPrefix()); @@ -415,7 +416,7 @@ private void channelFlush() { channel.flush(); } - private ChannelFuture channelWrite(RedisCommand command) { + private ChannelFuture channelWrite(Channel channel, RedisCommand command) { if (debugEnabled) { logger.debug("{} write() channelWrite command {}", logPrefix(), command); @@ -424,7 +425,7 @@ private ChannelFuture channelWrite(RedisCommand command) { return channel.write(command); } - private ChannelFuture channelWriteAndFlush(RedisCommand command) { + private ChannelFuture channelWriteAndFlush(Channel channel, RedisCommand command) { if (debugEnabled) { logger.debug("{} write() writeAndFlush command {}", logPrefix(), command); @@ -437,7 +438,6 @@ private ChannelFuture channelWriteAndFlush(RedisCommand command) { public void notifyChannelActive(Channel channel) { this.logPrefix = null; - this.channel = channel; this.connectionError = null; if (isClosed()) { @@ -468,13 +468,15 @@ public void notifyChannelActive(Channel channel) { } try { + this.inActivationChannel = channel; inActivation = true; connectionFacade.activated(); } finally { inActivation = false; + this.inActivationChannel = null; } - flushCommands(disconnectedBuffer); + flushCommands(channel, disconnectedBuffer); } catch (Exception e) { if (debugEnabled) { @@ -486,6 +488,8 @@ public void notifyChannelActive(Channel channel) { } throw e; + } finally { + this.channel = channel; } }); } @@ -527,7 +531,7 @@ public void notifyException(Throwable t) { doExclusive(this::drainCommands).forEach(cmd -> cmd.completeExceptionally(t)); } - if (!isConnected()) { + if (!isConnected(this.channel)) { connectionError = t; } } @@ -540,16 +544,16 @@ public void registerConnectionWatchdog(ConnectionWatchdog connectionWatchdog) { @Override @SuppressWarnings({ "rawtypes", "unchecked" }) public void flushCommands() { - flushCommands(commandBuffer); + flushCommands(this.channel, commandBuffer); } - private void flushCommands(Queue> queue) { + private void flushCommands(Channel channel, Queue> queue) { if (debugEnabled) { logger.debug("{} flushCommands()", logPrefix()); } - if (isConnected()) { + if (isConnected(channel)) { List> commands = sharedLock.doExclusive(() -> { @@ -565,7 +569,7 @@ private void flushCommands(Queue> queue) { } if (!commands.isEmpty()) { - writeToChannelAndFlush(commands); + writeToChannelAndFlush(channel, commands); } } } @@ -628,10 +632,10 @@ public void disconnect() { private Channel getOpenChannel() { - Channel currentChannel = this.channel; + Channel channel = this.channel; - if (currentChannel != null) { - return currentChannel; + if (channel != null /* && channel.isOpen() is this deliberately omitted? */) { + return channel; } return null; @@ -648,6 +652,7 @@ public void reset() { logger.debug("{} reset()", logPrefix()); } + Channel channel = this.channel; if (channel != null) { channel.pipeline().fireUserEventTriggered(new ConnectionEvents.Reset()); } @@ -720,8 +725,9 @@ public void notifyDrainQueuedCommands(HasQueuedCommands queuedCommands) { } } - if (isConnected()) { - flushCommands(disconnectedBuffer); + Channel channel = this.channel; + if (isConnected(channel)) { + flushCommands(channel, disconnectedBuffer); } }); } @@ -787,9 +793,7 @@ private void cancelCommands(String message, Iterable(CommandType.SELECT, new StatusOutput<>(StringCodec.UTF8)))); + + sut.registerConnectionWatchdog(connectionWatchdog); + doAnswer(i -> sut.write(new Command<>(CommandType.AUTH, new StatusOutput<>(StringCodec.UTF8)))).when(connectionWatchdog) + .arm(); + when(channel.isActive()).thenReturn(true); + + sut.notifyChannelActive(channel); + + DefaultChannelPromise promise = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + + when(channel.writeAndFlush(any())).thenAnswer(invocation -> { + if (invocation.getArguments()[0] instanceof RedisCommand) { + queue.add((RedisCommand) invocation.getArguments()[0]); + } + + if (invocation.getArguments()[0] instanceof Collection) { + queue.addAll((Collection) invocation.getArguments()[0]); + } + return promise; + }); + + assertThat(queue).hasSize(2).first().hasFieldOrPropertyWithValue("type", CommandType.SELECT); + assertThat(queue).hasSize(2).last().hasFieldOrPropertyWithValue("type", CommandType.AUTH); + } + @Test void writeConnectedShouldWriteCommandToChannel() { @@ -396,11 +424,9 @@ void shouldNotReplayActivationCommands() { when(channel.isActive()).thenReturn(true); ConnectionTestUtil.getDisconnectedBuffer(sut) - .add(new ActivationCommand<>( - new Command<>(CommandType.SELECT, new StatusOutput<>(StringCodec.UTF8)))); + .add(new ActivationCommand<>(new Command<>(CommandType.SELECT, new StatusOutput<>(StringCodec.UTF8)))); ConnectionTestUtil.getDisconnectedBuffer(sut).add(new LatencyMeteredCommand<>( - new ActivationCommand<>( - new Command<>(CommandType.SUBSCRIBE, new StatusOutput<>(StringCodec.UTF8))))); + new ActivationCommand<>(new Command<>(CommandType.SUBSCRIBE, new StatusOutput<>(StringCodec.UTF8))))); doAnswer(i -> {