diff --git a/Source/MQTTnet.Benchmarks/AsyncReadWriteLockBenchmark.cs b/Source/MQTTnet.Benchmarks/AsyncReadWriteLockBenchmark.cs new file mode 100644 index 000000000..f6e8fbcb6 --- /dev/null +++ b/Source/MQTTnet.Benchmarks/AsyncReadWriteLockBenchmark.cs @@ -0,0 +1,40 @@ +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Jobs; +using MQTTnet.Internal; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace MQTTnet.Benchmarks +{ + [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net70)] + [MemoryDiagnoser] + public class AsyncReadWriteLockBenchmark : BaseBenchmark + { + + [Benchmark] + public async Task Synchronize_100_Read_Tasks() + { + const int tasksCount = 100; + var tasks = new Task[tasksCount]; + var asyncReadWriteLock = new AsyncReadWriteLock(); + + for (var i = 0; i < tasksCount; i++) + { + tasks[i] = Task.Run(async () => + { + using (await asyncReadWriteLock.EnterReadAsync().ConfigureAwait(false)) + { + await Task.Delay(5).ConfigureAwait(false); + } + }); + } + + await Task.WhenAll(tasks).ConfigureAwait(false); + } + + } +} diff --git a/Source/MQTTnet.Tests/Internal/AsyncReadWriteLock_Tests.cs b/Source/MQTTnet.Tests/Internal/AsyncReadWriteLock_Tests.cs new file mode 100644 index 000000000..9bb1d9534 --- /dev/null +++ b/Source/MQTTnet.Tests/Internal/AsyncReadWriteLock_Tests.cs @@ -0,0 +1,502 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Internal; + +namespace MQTTnet.Tests.Internal +{ + [TestClass] + public sealed class AsyncReadWriteLock_Tests + { + [TestMethod] + public async Task Cancellation_Of_Write_Awaiter() + { + var @lock = new AsyncReadWriteLock(); + + // This call will not yet "release" the lock due to missing _using_. + var releaser = await @lock.EnterWriteAsync().ConfigureAwait(false); + + var counter = 0; + + Debug.WriteLine("Prepared locked lock."); + + _ = Task.Run( + async () => + { + // SHOULD GET TIMEOUT! + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(1))) + { + using (await @lock.EnterWriteAsync(timeout.Token)) + { + Debug.WriteLine("Task 1 incremented"); + counter++; + } + } + }); + + _ = Task.Run( + async () => + { + // SHOULD GET TIMEOUT! + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(2))) + { + using (await @lock.EnterWriteAsync(timeout.Token)) + { + Debug.WriteLine("Task 2 incremented"); + counter++; + } + } + }); + + _ = Task.Run( + async () => + { + // SHOULD GET TIMEOUT! + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(3))) + { + using (await @lock.EnterWriteAsync(timeout.Token)) + { + Debug.WriteLine("Task 3 incremented"); + counter++; + } + } + }); + + _ = Task.Run( + async () => + { + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(4))) + { + using (await @lock.EnterWriteAsync(timeout.Token)) + { + Debug.WriteLine("Task 4 incremented"); + counter++; + } + } + }); + + _ = Task.Run( + async () => + { + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(5))) + { + using (await @lock.EnterWriteAsync(timeout.Token)) + { + Debug.WriteLine("Task 5 incremented"); + counter++; + } + } + }); + + _ = Task.Run( + async () => + { + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(6))) + { + using (await @lock.EnterWriteAsync(timeout.Token)) + { + Debug.WriteLine("Task 6 incremented"); + counter++; + } + } + }); + + Debug.WriteLine("Delay before release..."); + await Task.Delay(TimeSpan.FromSeconds(3.1)); + releaser.Dispose(); + + Debug.WriteLine("Wait for all tasks..."); + await Task.Delay(TimeSpan.FromSeconds(6.1)); + + Assert.AreEqual(3, counter); + } + + [TestMethod] + public async Task Cancellation_Of_Read_Awaiter() + { + var @lock = new AsyncReadWriteLock(); + + // This call will not yet "release" the lock due to missing _using_. + var releaser = await @lock.EnterWriteAsync().ConfigureAwait(false); + + var counter = 0; + var syncLock = new object(); + + Debug.WriteLine("Prepared locked lock."); + + _ = Task.Run( + async () => + { + // SHOULD GET TIMEOUT! + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(1))) + { + using (await @lock.EnterReadAsync(timeout.Token)) + { + Debug.WriteLine("Task 1 incremented"); + lock (syncLock) + counter++; + } + } + }); + + _ = Task.Run( + async () => + { + // SHOULD GET TIMEOUT! + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(2))) + { + using (await @lock.EnterReadAsync(timeout.Token)) + { + Debug.WriteLine("Task 2 incremented"); + lock (syncLock) + counter++; + } + } + }); + + _ = Task.Run( + async () => + { + // SHOULD GET TIMEOUT! + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(3))) + { + using (await @lock.EnterReadAsync(timeout.Token)) + { + Debug.WriteLine("Task 3 incremented"); + lock (syncLock) + counter++; + } + } + }); + + _ = Task.Run( + async () => + { + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(4))) + { + using (await @lock.EnterReadAsync(timeout.Token)) + { + Debug.WriteLine("Task 4 incremented"); + lock (syncLock) + counter++; + } + } + }); + + _ = Task.Run( + async () => + { + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(5))) + { + using (await @lock.EnterReadAsync(timeout.Token)) + { + Debug.WriteLine("Task 5 incremented"); + lock (syncLock) + counter++; + } + } + }); + + _ = Task.Run( + async () => + { + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(6))) + { + using (await @lock.EnterReadAsync(timeout.Token)) + { + Debug.WriteLine("Task 6 incremented"); + lock (syncLock) + counter++; + } + } + }); + + Debug.WriteLine("Delay before release..."); + await Task.Delay(TimeSpan.FromSeconds(3.1)); + releaser.Dispose(); + + Debug.WriteLine("Wait for all tasks..."); + await Task.Delay(TimeSpan.FromSeconds(6.1)); + + Assert.AreEqual(3, counter); + } + + [TestMethod] + public void Lock_Parallel_Tasks() + { + const int taskCount = 50; + + var @lock = new AsyncReadWriteLock(); + + var tasks = new Task[taskCount]; + var globalI = 0; + for (var i = 0; i < taskCount; i++) + { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + tasks[i] = Task.Run( + async () => + { + using (await @lock.EnterWriteAsync()) + { + var localI = globalI; + await Task.Delay(5); // Increase the chance for wrong data. + localI++; + globalI = localI; + } + }); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + } + + Task.WaitAll(tasks); + Assert.AreEqual(taskCount, globalI); + } + + [TestMethod] + public void Lock_Parallel_Write_Tasks_During_Read() + { + const int taskCount = 50; + + var @lock = new AsyncReadWriteLock(); + + var tasks = new Task[taskCount * 5]; + var globalWrittenI = 0; + var globalReadI = 0; + var globalParallelI = 0; + var readSync = new object(); + + for (var i = 0; i < taskCount; i++) + { + tasks[i] = Task.Run( + async () => + { + using (await @lock.EnterReadAsync()) + { + var localParallelI = globalParallelI; + lock (readSync) + globalReadI++; + await Task.Delay(5 * taskCount); // Increase the chance for wrong data. + localParallelI++; + globalParallelI = localParallelI; + } + }); + } + + for (var i = 0; i < taskCount; i++) + { + tasks[i + taskCount] = Task.Run( + async () => + { + using (await @lock.EnterWriteAsync()) + { + var localI = globalWrittenI; + await Task.Delay(5); // Increase the chance for wrong data. + localI++; + globalWrittenI = localI; + } + }); + } + + for (var i = 0; i < taskCount; i++) + { + tasks[i + taskCount * 2] = Task.Run( + async () => + { + using (await @lock.EnterReadAsync()) + { + var localParallelI = globalParallelI; + lock (readSync) + globalReadI++; + await Task.Delay(5 * taskCount); // Increase the chance for wrong data. + localParallelI++; + globalParallelI = localParallelI; + } + }); + } + + for (var i = 0; i < taskCount; i++) + { + tasks[i + taskCount * 3] = Task.Run( + async () => + { + using (await @lock.EnterWriteAsync()) + { + var localI = globalWrittenI; + await Task.Delay(5); // Increase the chance for wrong data. + localI++; + globalWrittenI = localI; + } + }); + } + + for (var i = 0; i < taskCount; i++) + { + tasks[i + taskCount * 4] = Task.Run( + async () => + { + using (await @lock.EnterReadAsync()) + { + var localParallelI = globalParallelI; + lock (readSync) + globalReadI++; + await Task.Delay(5 * taskCount); // Increase the chance for wrong data. + localParallelI++; + globalParallelI = localParallelI; + } + }); + } + + Task.WaitAll(tasks); + Assert.AreEqual(taskCount * 2, globalWrittenI); + Assert.AreEqual(taskCount * 3, globalReadI); // Validates that all reads occurred. + Assert.AreNotEqual(taskCount * 3, globalParallelI); // Ensures the reads happened in parallel. + } + + [TestMethod] + public void Lock_10_Parallel_Tasks_With_Dispose_Doesnt_Lockup() + { + const int ThreadsCount = 10; + + var threads = new Task[ThreadsCount]; + var @lock = new AsyncLock(); + var globalI = 0; + for (var i = 0; i < ThreadsCount; i++) + { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + threads[i] = Task.Run( + async () => + { + using (await @lock.EnterAsync()) + { + var localI = globalI; + await Task.Delay(10); // Increase the chance for wrong data. + localI++; + globalI = localI; + } + }) + .ContinueWith( + x => + { + if (globalI == 5) + { + @lock.Dispose(); + @lock = new AsyncLock(); + } + + if (x.Exception != null) + { + Debug.WriteLine(x.Exception.GetBaseException().GetType().Name); + } + }); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + } + + Task.WaitAll(threads); + + // Expect only 6 because the others are failing due to disposal (if (globalI == 5)). + Assert.AreEqual(6, globalI); + } + + [TestMethod] + public async Task Lock_Serial_Calls() + { + var sum = 0; + + var @lock = new AsyncReadWriteLock(); + for (var i = 0; i < 100; i++) + { + using (await @lock.EnterWriteAsync().ConfigureAwait(false)) + { + sum++; + } + } + + Assert.AreEqual(100, sum); + } + + [TestMethod] + [ExpectedException(typeof(TaskCanceledException))] + public async Task Test_Cancellation() + { + var @lock = new AsyncReadWriteLock(); + + // This call will never "release" the lock due to missing _using_. + await @lock.EnterWriteAsync().ConfigureAwait(false); + + using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(3))) + { + await @lock.EnterWriteAsync(cts.Token).ConfigureAwait(false); + } + } + + [TestMethod] + public async Task Test_Cancellation_With_Later_Access() + { + var asyncLock = new AsyncReadWriteLock(); + + var releaser = await asyncLock.EnterWriteAsync().ConfigureAwait(false); + + try + { + using (var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(3))) + { + await asyncLock.EnterWriteAsync(timeout.Token).ConfigureAwait(false); + } + + Assert.Fail("Exception should be thrown!"); + } + catch (OperationCanceledException) + { + } + + releaser.Dispose(); + + using (await asyncLock.EnterWriteAsync(CancellationToken.None).ConfigureAwait(false)) + { + // When the method finished, the thread got access. + } + } + + [TestMethod] + public async Task Use_After_Cancellation() + { + var @lock = new AsyncReadWriteLock(); + + // This call will not yet "release" the lock due to missing _using_. + var releaser = await @lock.EnterWriteAsync().ConfigureAwait(false); + + try + { + using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(3))) + { + await @lock.EnterWriteAsync(cts.Token).ConfigureAwait(false); + } + } + catch (OperationCanceledException) + { + // Expected exception! + } + + releaser.Dispose(); + + // Regular usage after cancellation. + using (await @lock.EnterWriteAsync().ConfigureAwait(false)) + { + } + + using (await @lock.EnterWriteAsync().ConfigureAwait(false)) + { + } + + using (await @lock.EnterWriteAsync().ConfigureAwait(false)) + { + } + } + + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Internal/AsyncReadWriteLock.cs b/Source/MQTTnet/Internal/AsyncReadWriteLock.cs new file mode 100644 index 000000000..042172e37 --- /dev/null +++ b/Source/MQTTnet/Internal/AsyncReadWriteLock.cs @@ -0,0 +1,215 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.Internal +{ + public class AsyncReadWriteLock : IDisposable + { + private readonly object _syncRoot = new object(); + private readonly Queue _readWaiters = new Queue(64); + private readonly Queue _writeWaiters = new Queue(); + private readonly IDisposable _readReleaser; + private readonly IDisposable _writeReleaser; + private readonly Task _readCompletedTask; + private readonly Task _writeCompletedTask; + private int _readLockCount; + private bool _isLockedForWrite; + private bool _isDisposed; + + public AsyncReadWriteLock() + { + _readReleaser = new Releaser(this, true); + _writeReleaser = new Releaser(this, false); + _readCompletedTask = Task.FromResult(_readReleaser); + _writeCompletedTask = Task.FromResult(_writeReleaser); + } + + public Task EnterReadAsync(CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (Volatile.Read(ref _isDisposed)) + { + throw new ObjectDisposedException(nameof(AsyncReadWriteLock)); + } + + lock (_syncRoot) + { + if (!_isLockedForWrite && _writeWaiters.Count == 0) + { + _readLockCount++; + return _readCompletedTask; + } + + var waiter = new ReadWriteLockWaiter(cancellationToken); + _readWaiters.Enqueue(waiter); + + return waiter.Task; + } + } + + public Task EnterWriteAsync(CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (Volatile.Read(ref _isDisposed)) + { + throw new ObjectDisposedException(nameof(AsyncReadWriteLock)); + } + + lock (_syncRoot) + { + if (!_isLockedForWrite && _readLockCount == 0) + { + _isLockedForWrite = true; + return _writeCompletedTask; + } + + var waiter = new ReadWriteLockWaiter(cancellationToken); + _writeWaiters.Enqueue(waiter); + + return waiter.Task; + } + } + + private void ReleaseRead() + { + lock (_syncRoot) + { + if (_isDisposed) + { + return; + } + + _readLockCount--; + + if (_readLockCount <= 0) + { + _readLockCount = 0; + ApproveNextWaiter(); + } + } + } + + private void ReleaseWrite() + { + lock (_syncRoot) + { + if (_isDisposed) + { + return; + } + + _isLockedForWrite = false; + + ApproveNextWaiter(); + } + } + + private void ApproveNextWaiter() + { + while (_writeWaiters.Count > 0) + { + var waiter = _writeWaiters.Dequeue(); + var isApproved = waiter.Approve(_writeReleaser); + waiter.Dispose(); + + if (isApproved) + { + _isLockedForWrite = true; + return; + } + } + while (_readWaiters.Count > 0) + { + var waiter = _readWaiters.Dequeue(); + var isApproved = waiter.Approve(_readReleaser); + waiter.Dispose(); + + if (isApproved) + _readLockCount++; + } + } + + public void Dispose() + { + Volatile.Write(ref _isDisposed, true); + } + + sealed class ReadWriteLockWaiter : IDisposable + { + private readonly CancellationTokenRegistration _cancellationTokenRegistration; + private readonly bool _hasCancellationRegistration; + private readonly AsyncTaskCompletionSource _promise = new AsyncTaskCompletionSource(); + + public ReadWriteLockWaiter(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (cancellationToken.CanBeCanceled) + { + _cancellationTokenRegistration = cancellationToken.Register(Cancel); + _hasCancellationRegistration = true; + } + } + + public Task Task => _promise.Task; + + public bool Approve(IDisposable scope) + { + if (scope == null) + { + throw new ArgumentNullException(nameof(scope)); + } + + if (_promise.Task.IsCompleted) + { + return false; + } + + return _promise.TrySetResult(scope); + } + + private void Cancel() + { + _promise.TrySetCanceled(); + } + + public void Dispose() + { + if (_hasCancellationRegistration) + { + _cancellationTokenRegistration.Dispose(); + } + + _promise.TrySetCanceled(); + } + + } + + readonly struct Releaser : IDisposable + { + private readonly AsyncReadWriteLock _readWriteLock; + private readonly bool _isRead; + + public Releaser(AsyncReadWriteLock readWriteLock, bool isRead) + { + _readWriteLock = readWriteLock; + _isRead = isRead; + } + + public void Dispose() + { + if (_isRead) + _readWriteLock.ReleaseRead(); + else + _readWriteLock.ReleaseWrite(); + } + } + + } +} diff --git a/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs index b409d959f..7ebcae4b4 100644 --- a/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs @@ -31,7 +31,7 @@ public sealed class MqttClientSubscriptionsManager : IDisposable readonly Dictionary _subscriptions = new Dictionary(); // Use subscription lock to maintain consistency across subscriptions and topic hash dictionaries - readonly AsyncLock _subscriptionsLock = new AsyncLock(); + readonly AsyncReadWriteLock _subscriptionsLock = new AsyncReadWriteLock(); readonly Dictionary _wildcardSubscriptionsByTopicHash = new Dictionary(); public MqttClientSubscriptionsManager( @@ -51,7 +51,7 @@ public CheckSubscriptionsResult CheckSubscriptions(string topic, ulong topicHash var possibleSubscriptions = new List(); // Check for possible subscriptions. They might have collisions but this is fine. - using (_subscriptionsLock.EnterAsync(CancellationToken.None).GetAwaiter().GetResult()) + using (_subscriptionsLock.EnterReadAsync(CancellationToken.None).GetAwaiter().GetResult()) { if (_noWildcardSubscriptionsByTopicHash.TryGetValue(topicHash, out var noWildcardSubscriptions)) { @@ -232,7 +232,7 @@ public async Task Unsubscribe(MqttUnsubscribePacket unsubscri var removedSubscriptions = new List(); - using (await _subscriptionsLock.EnterAsync(cancellationToken).ConfigureAwait(false)) + using (await _subscriptionsLock.EnterWriteAsync(cancellationToken).ConfigureAwait(false)) { foreach (var topicFilter in unsubscribePacket.TopicFilters) { @@ -340,7 +340,7 @@ CreateSubscriptionResult CreateSubscription(MqttTopicFilter topicFilter, uint su // Add to subscriptions and maintain topic hash dictionaries - using (_subscriptionsLock.EnterAsync(CancellationToken.None).GetAwaiter().GetResult()) + using (_subscriptionsLock.EnterWriteAsync(CancellationToken.None).GetAwaiter().GetResult()) { MqttSubscription.CalculateTopicHash(topicFilter.Topic, out var topicHash, out var topicHashMask, out var hasWildcard);