diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c920c71d0..d5d0620a5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,7 +51,7 @@ jobs: creds: ${{ secrets.AZURE_ACI_CREDENTIALS }} enable-AzPSSession: true - name: Setup RabbitMQ - uses: Particular/setup-rabbitmq-action@v1.6.0 + uses: Particular/setup-rabbitmq-action@v1.7.0 with: connection-string-name: RabbitMQTransport_ConnectionString tag: RabbitMQTransport diff --git a/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ChannelProviderTests.cs b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ChannelProviderTests.cs new file mode 100644 index 000000000..cac893b3d --- /dev/null +++ b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ChannelProviderTests.cs @@ -0,0 +1,168 @@ +namespace NServiceBus.Transport.RabbitMQ.Tests.ConnectionString +{ + using System; + using System.Collections.Generic; + using System.Threading; + using System.Threading.Tasks; + using global::RabbitMQ.Client; + using global::RabbitMQ.Client.Events; + using NUnit.Framework; + + [TestFixture] + public class ChannelProviderTests + { + [Test] + public async Task Should_recover_connection_and_dispose_old_one_when_connection_shutdown() + { + var channelProvider = new TestableChannelProvider(); + channelProvider.CreateConnection(); + + var publishConnection = channelProvider.PublishConnections.Dequeue(); + publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test")); + + channelProvider.DelayTaskCompletionSource.SetResult(); + + await channelProvider.FireAndForgetAction(CancellationToken.None); + + var recoveredConnection = channelProvider.PublishConnections.Dequeue(); + + Assert.That(publishConnection.WasDisposed, Is.True); + Assert.That(recoveredConnection.WasDisposed, Is.False); + } + + [Test] + public void Should_dispose_connection_when_disposed() + { + var channelProvider = new TestableChannelProvider(); + channelProvider.CreateConnection(); + + var publishConnection = channelProvider.PublishConnections.Dequeue(); + channelProvider.Dispose(); + + Assert.That(publishConnection.WasDisposed, Is.True); + } + + [Test] + public async Task Should_not_attempt_to_recover_during_dispose_when_retry_delay_still_pending() + { + var channelProvider = new TestableChannelProvider(); + channelProvider.CreateConnection(); + + var publishConnection = channelProvider.PublishConnections.Dequeue(); + publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test")); + + // Deliberately not completing the delay task with channelProvider.DelayTaskCompletionSource.SetResult(); before disposing + // to simulate a pending delay task + channelProvider.Dispose(); + + await channelProvider.FireAndForgetAction(CancellationToken.None); + + Assert.That(publishConnection.WasDisposed, Is.True); + Assert.That(channelProvider.PublishConnections.TryDequeue(out _), Is.False); + } + + [Test] + public async Task Should_dispose_newly_established_connection() + { + var channelProvider = new TestableChannelProvider(); + channelProvider.CreateConnection(); + + var publishConnection = channelProvider.PublishConnections.Dequeue(); + publishConnection.RaiseConnectionShutdown(new ShutdownEventArgs(ShutdownInitiator.Library, 0, "Test")); + + // This simulates the race of the reconnection loop being fired off with the delay task completed during + // the disposal of the channel provider. To achieve that it is necessary to kick off the reconnection loop + // and await its completion after the channel provider has been disposed. + var fireAndForgetTask = channelProvider.FireAndForgetAction(CancellationToken.None); + channelProvider.DelayTaskCompletionSource.SetResult(); + channelProvider.Dispose(); + + await fireAndForgetTask; + + var recoveredConnection = channelProvider.PublishConnections.Dequeue(); + + Assert.That(publishConnection.WasDisposed, Is.True); + Assert.That(recoveredConnection.WasDisposed, Is.True); + } + + class TestableChannelProvider() : ChannelProvider(null!, TimeSpan.Zero, null!) + { + public Queue PublishConnections { get; } = new(); + + public TaskCompletionSource DelayTaskCompletionSource { get; } = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public Func FireAndForgetAction { get; private set; } + + protected override IConnection CreatePublishConnection() + { + var connection = new FakeConnection(); + PublishConnections.Enqueue(connection); + return connection; + } + + protected override void FireAndForget(Func action, CancellationToken cancellationToken = default) + => FireAndForgetAction = _ => action(cancellationToken); + + protected override async Task DelayReconnect(CancellationToken cancellationToken = default) + { + await using var _ = cancellationToken.Register(() => DelayTaskCompletionSource.TrySetCanceled(cancellationToken)); + await DelayTaskCompletionSource.Task; + } + } + + class FakeConnection : IConnection + { + public int LocalPort { get; } + public int RemotePort { get; } + + public void Dispose() => WasDisposed = true; + + public bool WasDisposed { get; private set; } + + public void UpdateSecret(string newSecret, string reason) => throw new NotImplementedException(); + + public void Abort() => throw new NotImplementedException(); + + public void Abort(ushort reasonCode, string reasonText) => throw new NotImplementedException(); + + public void Abort(TimeSpan timeout) => throw new NotImplementedException(); + + public void Abort(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException(); + + public void Close() => throw new NotImplementedException(); + + public void Close(ushort reasonCode, string reasonText) => throw new NotImplementedException(); + + public void Close(TimeSpan timeout) => throw new NotImplementedException(); + + public void Close(ushort reasonCode, string reasonText, TimeSpan timeout) => throw new NotImplementedException(); + + public IModel CreateModel() => throw new NotImplementedException(); + + public void HandleConnectionBlocked(string reason) => throw new NotImplementedException(); + + public void HandleConnectionUnblocked() => throw new NotImplementedException(); + + public ushort ChannelMax { get; } + public IDictionary ClientProperties { get; } + public ShutdownEventArgs CloseReason { get; } + public AmqpTcpEndpoint Endpoint { get; } + public uint FrameMax { get; } + public TimeSpan Heartbeat { get; } + public bool IsOpen { get; } + public AmqpTcpEndpoint[] KnownHosts { get; } + public IProtocol Protocol { get; } + public IDictionary ServerProperties { get; } + public IList ShutdownReport { get; } + public string ClientProvidedName { get; } = $"FakeConnection{Interlocked.Increment(ref connectionCounter)}"; + public event EventHandler CallbackException = (_, _) => { }; + public event EventHandler ConnectionBlocked = (_, _) => { }; + public event EventHandler ConnectionShutdown = (_, _) => { }; + public event EventHandler ConnectionUnblocked = (_, _) => { }; + + public void RaiseConnectionShutdown(ShutdownEventArgs args) => ConnectionShutdown?.Invoke(this, args); + + static int connectionCounter; + } + } +} \ No newline at end of file diff --git a/src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationTests.cs b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationTests.cs similarity index 100% rename from src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationTests.cs rename to src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationTests.cs diff --git a/src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationWithAmqpTests.cs b/src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationWithAmqpTests.cs similarity index 100% rename from src/NServiceBus.Transport.RabbitMQ.Tests/ConnectionString/ConnectionConfigurationWithAmqpTests.cs rename to src/NServiceBus.Transport.RabbitMQ.Tests/Connection/ConnectionConfigurationWithAmqpTests.cs diff --git a/src/NServiceBus.Transport.RabbitMQ.TransportTests/When_changing_concurrency.cs b/src/NServiceBus.Transport.RabbitMQ.TransportTests/When_changing_concurrency.cs index 665dedfd7..75ab33744 100644 --- a/src/NServiceBus.Transport.RabbitMQ.TransportTests/When_changing_concurrency.cs +++ b/src/NServiceBus.Transport.RabbitMQ.TransportTests/When_changing_concurrency.cs @@ -16,6 +16,7 @@ public class When_changing_concurrency : NServiceBusTransportTest public async Task Should_complete_current_message(TransportTransactionMode transactionMode) { var triggeredChangeConcurrency = CreateTaskCompletionSource(); + var sentMessageReceived = CreateTaskCompletionSource(); Task concurrencyChanged = null; int invocationCounter = 0; @@ -30,6 +31,7 @@ await StartPump(async (context, ct) => await task; }, ct); + sentMessageReceived.SetResult(); await triggeredChangeConcurrency.Task; }, (_, _) => @@ -40,8 +42,10 @@ await StartPump(async (context, ct) => transactionMode); await SendMessage(InputQueueName); + await sentMessageReceived.Task; await concurrencyChanged; await StopPump(); + Assert.AreEqual(1, invocationCounter, "message should successfully complete on first processing attempt"); } @@ -62,6 +66,7 @@ await StartPump((context, _) => if (context.Headers.TryGetValue("FromOnError", out var value) && value == bool.TrueString) { sentMessageReceived.SetResult(); + return Task.CompletedTask; } throw new Exception("triggering recoverability pipeline"); @@ -84,9 +89,9 @@ await SendMessage(InputQueueName, transactionMode); await SendMessage(InputQueueName); - await sentMessageReceived.Task; await StopPump(); + Assert.AreEqual(2, invocationCounter, "there should be exactly 2 messages (initial message and new message from onError pipeline)"); } } diff --git a/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs b/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs index 01419652c..afd572a0f 100644 --- a/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs +++ b/src/NServiceBus.Transport.RabbitMQ/Connection/ChannelProvider.cs @@ -1,3 +1,5 @@ +#nullable enable + namespace NServiceBus.Transport.RabbitMQ { using System; @@ -7,7 +9,7 @@ namespace NServiceBus.Transport.RabbitMQ using global::RabbitMQ.Client; using Logging; - sealed class ChannelProvider : IDisposable + class ChannelProvider : IDisposable { public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay, IRoutingTopology routingTopology) { @@ -19,36 +21,56 @@ public ChannelProvider(ConnectionFactory connectionFactory, TimeSpan retryDelay, channels = new ConcurrentQueue(); } - public void CreateConnection() + public void CreateConnection() => connection = CreateConnectionWithShutdownListener(); + + protected virtual IConnection CreatePublishConnection() => connectionFactory.CreatePublishConnection(); + + IConnection CreateConnectionWithShutdownListener() { - connection = connectionFactory.CreatePublishConnection(); - connection.ConnectionShutdown += Connection_ConnectionShutdown; + var newConnection = CreatePublishConnection(); + newConnection.ConnectionShutdown += Connection_ConnectionShutdown; + return newConnection; } - void Connection_ConnectionShutdown(object sender, ShutdownEventArgs e) + void Connection_ConnectionShutdown(object? sender, ShutdownEventArgs e) { - if (e.Initiator != ShutdownInitiator.Application) + if (e.Initiator == ShutdownInitiator.Application || sender is null) { - var connection = (IConnection)sender; - - // Task.Run() so the call returns immediately instead of waiting for the first await or return down the call stack - _ = Task.Run(() => ReconnectSwallowingExceptions(connection.ClientProvidedName), CancellationToken.None); + return; } + + var connectionThatWasShutdown = (IConnection)sender; + + FireAndForget(cancellationToken => ReconnectSwallowingExceptions(connectionThatWasShutdown.ClientProvidedName, cancellationToken), stoppingTokenSource.Token); } -#pragma warning disable PS0018 // A task-returning method should have a CancellationToken parameter unless it has a parameter implementing ICancellableContext - async Task ReconnectSwallowingExceptions(string connectionName) -#pragma warning restore PS0018 // A task-returning method should have a CancellationToken parameter unless it has a parameter implementing ICancellableContext + async Task ReconnectSwallowingExceptions(string connectionName, CancellationToken cancellationToken) { - while (true) + while (!cancellationToken.IsCancellationRequested) { Logger.InfoFormat("'{0}': Attempting to reconnect in {1} seconds.", connectionName, retryDelay.TotalSeconds); - await Task.Delay(retryDelay).ConfigureAwait(false); - try { - CreateConnection(); + await DelayReconnect(cancellationToken).ConfigureAwait(false); + + var newConnection = CreateConnectionWithShutdownListener(); + + // A race condition is possible where CreatePublishConnection is invoked during Dispose + // where the returned connection isn't disposed so invoking Dispose to be sure + if (cancellationToken.IsCancellationRequested) + { + newConnection.Dispose(); + break; + } + + var oldConnection = Interlocked.Exchange(ref connection, newConnection); + oldConnection?.Dispose(); + break; + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + Logger.InfoFormat("'{0}': Stopped trying to reconnecting to the broker due to shutdown", connectionName); break; } catch (Exception ex) @@ -60,6 +82,12 @@ async Task ReconnectSwallowingExceptions(string connectionName) Logger.InfoFormat("'{0}': Connection to the broker reestablished successfully.", connectionName); } + protected virtual void FireAndForget(Func action, CancellationToken cancellationToken = default) => + // Task.Run() so the call returns immediately instead of waiting for the first await or return down the call stack + _ = Task.Run(() => action(cancellationToken), CancellationToken.None); + + protected virtual Task DelayReconnect(CancellationToken cancellationToken = default) => Task.Delay(retryDelay, cancellationToken); + public ConfirmsAwareChannel GetPublishChannel() { if (!channels.TryDequeue(out var channel) || channel.IsClosed) @@ -86,19 +114,32 @@ public void ReturnPublishChannel(ConfirmsAwareChannel channel) public void Dispose() { - connection?.Dispose(); + if (disposed) + { + return; + } + + stoppingTokenSource.Cancel(); + stoppingTokenSource.Dispose(); + + var oldConnection = Interlocked.Exchange(ref connection, null); + oldConnection?.Dispose(); foreach (var channel in channels) { channel.Dispose(); } + + disposed = true; } readonly ConnectionFactory connectionFactory; readonly TimeSpan retryDelay; readonly IRoutingTopology routingTopology; readonly ConcurrentQueue channels; - IConnection connection; + readonly CancellationTokenSource stoppingTokenSource = new(); + volatile IConnection? connection; + bool disposed; static readonly ILog Logger = LogManager.GetLogger(typeof(ChannelProvider)); }