diff --git a/MQTTnet.sln b/MQTTnet.sln index 15a0d2323..984482f18 100644 --- a/MQTTnet.sln +++ b/MQTTnet.sln @@ -7,11 +7,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MQTTnet", "Source\MQTTnet\M EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{B3F60ECB-45BA-4C66-8903-8BB89CA67998}" ProjectSection(SolutionItems) = preProject + .github\workflows\ci.yml = .github\workflows\ci.yml CODE-OF-CONDUCT.md = CODE-OF-CONDUCT.md LICENSE = LICENSE README.md = README.md Source\ReleaseNotes.md = Source\ReleaseNotes.md - .github\workflows\ci.yml = .github\workflows\ci.yml EndProjectSection EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MQTTnet.AspNetCore", "Source\MQTTnet.AspnetCore\MQTTnet.AspNetCore.csproj", "{F10C4060-F7EE-4A83-919F-FF723E72F94A}" @@ -85,6 +85,4 @@ Global GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {07536672-5CBC-4BE3-ACE0-708A431A7894} EndGlobalSection - GlobalSection(NestedProjects) = preSolution - EndGlobalSection EndGlobal diff --git a/Samples/MQTTnet.Samples.csproj b/Samples/MQTTnet.Samples.csproj index 5fe84d380..2441f1c9c 100644 --- a/Samples/MQTTnet.Samples.csproj +++ b/Samples/MQTTnet.Samples.csproj @@ -14,7 +14,7 @@ all true low - latest-Recommended + diff --git a/Samples/Server/Server_ASP_NET_Samples.cs b/Samples/Server/Server_ASP_NET_Samples.cs index 9247093e2..5ee50b357 100644 --- a/Samples/Server/Server_ASP_NET_Samples.cs +++ b/Samples/Server/Server_ASP_NET_Samples.cs @@ -12,6 +12,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; using MQTTnet.AspNetCore; using MQTTnet.Server; @@ -19,92 +20,87 @@ namespace MQTTnet.Samples.Server; public static class Server_ASP_NET_Samples { + static readonly string unixSocketPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "mqtt.socks"); + public static Task Start_Server_With_WebSockets_Support() { - /* - * This sample starts a minimal ASP.NET Webserver including a hosted MQTT server. - */ - var host = Host.CreateDefaultBuilder(Array.Empty()) - .ConfigureWebHostDefaults( - webBuilder => - { - webBuilder.UseKestrel( - o => - { - // This will allow MQTT connections based on TCP port 1883. - o.ListenAnyIP(1883, l => l.UseMqtt()); - - // This will allow MQTT connections based on HTTP WebSockets with URI "localhost:5000/mqtt" - // See code below for URI configuration. - o.ListenAnyIP(5000); // Default HTTP pipeline - }); - - webBuilder.UseStartup(); - }); - - return host.RunConsoleAsync(); + File.Delete(unixSocketPath); + + var builder = WebApplication.CreateBuilder(); + builder.Services.AddMqttServer(s => s.WithDefaultEndpoint().WithEncryptedEndpoint()); + builder.Services.AddMqttClient(); + builder.Services.AddHostedService(); + + builder.WebHost.UseKestrel(kestrel => + { + // Need ConfigureMqttServer(s => ...) to enable the endpoints + kestrel.ListenMqtt(); + + // We can also manually listen to a specific port without ConfigureMqttServer() + kestrel.ListenUnixSocket(unixSocketPath, l => l.UseMqtt()); + // kestrel.ListenAnyIP(1883, l => l.UseMqtt()); // mqtt over tcp + // kestrel.ListenAnyIP(8883, l => l.UseHttps().UseMqtt()); // mqtt over tls over tcp + }); + + var app = builder.Build(); + app.MapMqtt("/mqtt"); + app.UseMqttServer(); + return app.RunAsync(); } - sealed class MqttController + sealed class MqttServerController { - public MqttController() + private readonly ILogger _logger; + + public MqttServerController( + MqttServer mqttServer, + ILogger logger) { - // Inject other services via constructor. + _logger = logger; + + mqttServer.ValidatingConnectionAsync += ValidateConnection; + mqttServer.ClientConnectedAsync += OnClientConnected; } public Task OnClientConnected(ClientConnectedEventArgs eventArgs) { - Console.WriteLine($"Client '{eventArgs.ClientId}' connected."); + _logger.LogInformation($"Client '{eventArgs.ClientId}' connected."); return Task.CompletedTask; } - public Task ValidateConnection(ValidatingConnectionEventArgs eventArgs) { - Console.WriteLine($"Client '{eventArgs.ClientId}' wants to connect. Accepting!"); + _logger.LogInformation($"Client '{eventArgs.ClientId}' wants to connect. Accepting!"); return Task.CompletedTask; } } - sealed class Startup + sealed class MqttClientController : BackgroundService { - public void Configure(IApplicationBuilder app, IWebHostEnvironment environment, MqttController mqttController) + private readonly IMqttClientFactory _mqttClientFactory; + + public MqttClientController(IMqttClientFactory mqttClientFactory) { - app.UseRouting(); - - app.UseEndpoints( - endpoints => - { - endpoints.MapConnectionHandler( - "/mqtt", - httpConnectionDispatcherOptions => httpConnectionDispatcherOptions.WebSockets.SubProtocolSelector = - protocolList => protocolList.FirstOrDefault() ?? string.Empty); - }); - - app.UseMqttServer( - server => - { - /* - * Attach event handlers etc. if required. - */ - - server.ValidatingConnectionAsync += mqttController.ValidateConnection; - server.ClientConnectedAsync += mqttController.OnClientConnected; - }); + _mqttClientFactory = mqttClientFactory; } - public void ConfigureServices(IServiceCollection services) + protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - services.AddHostedMqttServer( - optionsBuilder => - { - optionsBuilder.WithDefaultEndpoint(); - }); + await Task.Delay(1000); + using var client = _mqttClientFactory.CreateMqttClient(); + + // var mqttUri = "mqtt://localhost:1883"; + // var mqttsUri = "mqtts://localhost:8883"; + // var wsMqttUri = "ws://localhost:1883/mqtt"; + var wssMqttUri = "wss://localhost:8883/mqtt"; - services.AddMqttConnectionHandler(); - services.AddConnections(); + var options = new MqttClientOptionsBuilder() + //.WithEndPoint(new UnixDomainSocketEndPoint(unixSocketPath)) + .WithConnectionUri(wssMqttUri) + .Build(); - services.AddSingleton(); + await client.ConnectAsync(options, stoppingToken); + await client.DisconnectAsync(); } } } \ No newline at end of file diff --git a/Source/MQTTnet.AspTestApp/MQTTnet.AspTestApp.csproj b/Source/MQTTnet.AspTestApp/MQTTnet.AspTestApp.csproj index 254069b60..ede29a49f 100644 --- a/Source/MQTTnet.AspTestApp/MQTTnet.AspTestApp.csproj +++ b/Source/MQTTnet.AspTestApp/MQTTnet.AspTestApp.csproj @@ -13,7 +13,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.AspTestApp/Program.cs b/Source/MQTTnet.AspTestApp/Program.cs index 317c5f7eb..0b2b48571 100644 --- a/Source/MQTTnet.AspTestApp/Program.cs +++ b/Source/MQTTnet.AspTestApp/Program.cs @@ -11,11 +11,15 @@ builder.Services.AddRazorPages(); // Setup MQTT stuff. -builder.Services.AddMqttServer(); -builder.Services.AddConnections(); +builder.Services.AddMqttServer(s => s.WithDefaultEndpoint().WithDefaultEndpointPort(5000)); -var app = builder.Build(); +// ListenMqtt +builder.WebHost.UseKestrel(kestrel => +{ + kestrel.ListenMqtt(MqttProtocols.WebSocket); +}); +var app = builder.Build(); if (!app.Environment.IsDevelopment()) { app.UseExceptionHandler("/Error"); @@ -29,7 +33,7 @@ app.MapRazorPages(); -// Setup MQTT stuff. +// mqtt over websocket app.MapMqtt("/mqtt"); app.UseMqttServer(server => diff --git a/Source/MQTTnet.AspTestApp/appsettings.json b/Source/MQTTnet.AspTestApp/appsettings.json index 10f68b8c8..0f22ea11d 100644 --- a/Source/MQTTnet.AspTestApp/appsettings.json +++ b/Source/MQTTnet.AspTestApp/appsettings.json @@ -1,4 +1,4 @@ -{ +{ "Logging": { "LogLevel": { "Default": "Information", diff --git a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs index f98983da9..e01a6556c 100644 --- a/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ApplicationBuilderExtensions.cs @@ -2,52 +2,39 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; using MQTTnet.Server; +using System; +using System.Diagnostics.CodeAnalysis; namespace MQTTnet.AspNetCore; public static class ApplicationBuilderExtensions { - [Obsolete( - "This class is obsolete and will be removed in a future version. The recommended alternative is to use MapMqtt inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")] - public static IApplicationBuilder UseMqttEndpoint(this IApplicationBuilder app, string path = "/mqtt") - { - app.UseWebSockets(); - app.Use( - async (context, next) => - { - if (!context.WebSockets.IsWebSocketRequest || context.Request.Path != path) - { - await next(); - return; - } - - string subProtocol = null; - - if (context.Request.Headers.TryGetValue("Sec-WebSocket-Protocol", out var requestedSubProtocolValues)) - { - subProtocol = MqttSubProtocolSelector.SelectSubProtocol(requestedSubProtocolValues); - } - - var adapter = app.ApplicationServices.GetRequiredService(); - using (var webSocket = await context.WebSockets.AcceptWebSocketAsync(subProtocol).ConfigureAwait(false)) - { - await adapter.RunWebSocketConnectionAsync(webSocket, context); - } - }); - - return app; - } - + /// + /// Get and use + /// + /// Also, you can inject into your service + /// + /// + /// public static IApplicationBuilder UseMqttServer(this IApplicationBuilder app, Action configure) { var server = app.ApplicationServices.GetRequiredService(); - configure(server); + return app; + } + /// + /// Active MqttServer's wrapper service + /// + /// + /// + /// + public static IApplicationBuilder UseMqttServer<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TMQttServerWrapper>(this IApplicationBuilder app) + { + ActivatorUtilities.GetServiceOrCreateInstance(app.ApplicationServices); return app; } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/AspNetCoreMqttNetLoggerOptions.cs b/Source/MQTTnet.AspnetCore/AspNetCoreMqttNetLoggerOptions.cs new file mode 100644 index 000000000..273463f4d --- /dev/null +++ b/Source/MQTTnet.AspnetCore/AspNetCoreMqttNetLoggerOptions.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Logging; +using MQTTnet.Diagnostics.Logger; +using System; + +namespace MQTTnet.AspNetCore +{ + public sealed class AspNetCoreMqttNetLoggerOptions + { + public string? CategoryNamePrefix { get; set; } = "MQTTnet.AspNetCore."; + + public Func LogLevelConverter { get; set; } = ConvertLogLevel; + + private static LogLevel ConvertLogLevel(MqttNetLogLevel level) + { + return level switch + { + MqttNetLogLevel.Verbose => LogLevel.Trace, + MqttNetLogLevel.Info => LogLevel.Information, + MqttNetLogLevel.Warning => LogLevel.Warning, + MqttNetLogLevel.Error => LogLevel.Error, + _ => LogLevel.None + }; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/AspNetMqttServerOptionsBuilder.cs b/Source/MQTTnet.AspnetCore/AspNetMqttServerOptionsBuilder.cs deleted file mode 100644 index 394483959..000000000 --- a/Source/MQTTnet.AspnetCore/AspNetMqttServerOptionsBuilder.cs +++ /dev/null @@ -1,18 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using MQTTnet.Server; - -namespace MQTTnet.AspNetCore; - -public sealed class AspNetMqttServerOptionsBuilder : MqttServerOptionsBuilder -{ - public AspNetMqttServerOptionsBuilder(IServiceProvider serviceProvider) - { - ServiceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); - } - - public IServiceProvider ServiceProvider { get; } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/BufferExtensions.cs b/Source/MQTTnet.AspnetCore/BufferExtensions.cs deleted file mode 100644 index 47a5c0747..000000000 --- a/Source/MQTTnet.AspnetCore/BufferExtensions.cs +++ /dev/null @@ -1,21 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.InteropServices; - -namespace MQTTnet.AspNetCore; - -public static class BufferExtensions -{ - public static ArraySegment GetArray(this ReadOnlyMemory memory) - { - if (!MemoryMarshal.TryGetArray(memory, out var result)) - { - throw new InvalidOperationException("Buffer backed by array was expected"); - } - - return result; - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs index 9ea8922ea..f45380831 100644 --- a/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ConnectionBuilderExtensions.cs @@ -3,14 +3,30 @@ // See the LICENSE file in the project root for more information. using Microsoft.AspNetCore.Connections; +using Microsoft.Extensions.DependencyInjection; +using MQTTnet.Adapter; +using MQTTnet.Server; +using System; namespace MQTTnet.AspNetCore { public static class ConnectionBuilderExtensions { - public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder) + /// + /// Handle the connection using the specified MQTT protocols + /// + /// + /// + /// + /// + public static IConnectionBuilder UseMqtt(this IConnectionBuilder builder, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket, Func? allowPacketFragmentationSelector = null) { - return builder.UseConnectionHandler(); + // check services.AddMqttServer() + builder.ApplicationServices.GetRequiredService(); + builder.ApplicationServices.GetRequiredService().UseFlag = true; + + var middleware = builder.ApplicationServices.GetRequiredService(); + return builder.Use(next => context => middleware.InvokeAsync(next, context, protocols, allowPacketFragmentationSelector)); } } } diff --git a/Source/MQTTnet.AspnetCore/DuplexPipe.cs b/Source/MQTTnet.AspnetCore/DuplexPipe.cs deleted file mode 100644 index 35075e800..000000000 --- a/Source/MQTTnet.AspnetCore/DuplexPipe.cs +++ /dev/null @@ -1,44 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.IO.Pipelines; - -namespace MQTTnet.AspNetCore; - -public class DuplexPipe : IDuplexPipe -{ - public DuplexPipe(PipeReader reader, PipeWriter writer) - { - Input = reader; - Output = writer; - } - - public PipeReader Input { get; } - - public PipeWriter Output { get; } - - public static DuplexPipePair CreateConnectionPair(PipeOptions inputOptions, PipeOptions outputOptions) - { - var input = new Pipe(inputOptions); - var output = new Pipe(outputOptions); - - var transportToApplication = new DuplexPipe(output.Reader, input.Writer); - var applicationToTransport = new DuplexPipe(input.Reader, output.Writer); - - return new DuplexPipePair(applicationToTransport, transportToApplication); - } - - // This class exists to work around issues with value tuple on .NET Framework - public readonly struct DuplexPipePair - { - public IDuplexPipe Transport { get; } - public IDuplexPipe Application { get; } - - public DuplexPipePair(IDuplexPipe transport, IDuplexPipe application) - { - Transport = transport; - Application = application; - } - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs new file mode 100644 index 000000000..0448c3154 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/EndpointRouteBuilderExtensions.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; +using MQTTnet.Server; +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace MQTTnet.AspNetCore +{ + public static class EndpointRouteBuilderExtensions + { + /// + /// Specify the matching path for mqtt-over-websocket + /// + /// + /// + /// + public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern) + { + return endpoints.MapMqtt(pattern, null); + } + + /// + /// Specify the matching path for mqtt-over-websocket + /// + /// + /// + /// + /// + public static ConnectionEndpointRouteBuilder MapMqtt(this IEndpointRouteBuilder endpoints, [StringSyntax("Route")] string pattern, Action? configureOptions) + { + // check services.AddMqttServer() + endpoints.ServiceProvider.GetRequiredService(); + + endpoints.ServiceProvider.GetRequiredService().MapFlag = true; + return endpoints.MapConnectionHandler(pattern, ConfigureOptions); + + + void ConfigureOptions(HttpConnectionDispatcherOptions options) + { + options.Transports = HttpTransportType.WebSockets; + options.WebSockets.SubProtocolSelector = SelectSubProtocol; + configureOptions?.Invoke(options); + } + + static string SelectSubProtocol(IList requestedSubProtocolValues) + { + // Order the protocols to also match "mqtt", "mqttv-3.1", "mqttv-3.11" etc. + return requestedSubProtocolValues.OrderByDescending(p => p.Length).FirstOrDefault(p => p.ToLower().StartsWith("mqtt"))!; + } + } + } +} + diff --git a/Source/MQTTnet.AspnetCore/EndpointRouterExtensions.cs b/Source/MQTTnet.AspnetCore/EndpointRouterExtensions.cs deleted file mode 100644 index 8e96a6c28..000000000 --- a/Source/MQTTnet.AspnetCore/EndpointRouterExtensions.cs +++ /dev/null @@ -1,24 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Routing; - -namespace MQTTnet.AspNetCore -{ - public static class EndpointRouterExtensions - { - public static void MapMqtt(this IEndpointRouteBuilder endpoints, string pattern) - { - ArgumentNullException.ThrowIfNull(endpoints); - - endpoints.MapConnectionHandler(pattern, options => - { - options.WebSockets.SubProtocolSelector = MqttSubProtocolSelector.SelectSubProtocol; - }); - } - } -} - diff --git a/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs new file mode 100644 index 000000000..eb05414a4 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Features/PacketFragmentationFeature.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Adapter; +using MQTTnet.Server; +using System; + +namespace MQTTnet.AspNetCore +{ + sealed class PacketFragmentationFeature(Func allowPacketFragmentationSelector) + { + public Func AllowPacketFragmentationSelector { get; } = allowPacketFragmentationSelector; + + public static bool CanAllowPacketFragmentation(IMqttChannelAdapter channelAdapter, MqttServerTcpEndpointBaseOptions? endpointOptions) + { + //if (endpointOptions != null && endpointOptions.AllowPacketFragmentationSelector != null) + //{ + // return endpointOptions.AllowPacketFragmentationSelector(channelAdapter); + //} + + // In the AspNetCore environment, we need to exclude WebSocket before AllowPacketFragmentation. + if (channelAdapter.IsWebSocketConnection() == true) + { + return false; + } + + return endpointOptions == null || endpointOptions.AllowPacketFragmentation; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Features/TlsConnectionFeature.cs b/Source/MQTTnet.AspnetCore/Features/TlsConnectionFeature.cs new file mode 100644 index 000000000..5fd7fd6e3 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Features/TlsConnectionFeature.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http.Features; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore +{ + sealed class TlsConnectionFeature : ITlsConnectionFeature + { + public static readonly TlsConnectionFeature WithoutClientCertificate = new(null); + + public X509Certificate2? ClientCertificate { get; set; } + + public Task GetClientCertificateAsync(CancellationToken cancellationToken) + { + return Task.FromResult(ClientCertificate); + } + + public TlsConnectionFeature(X509Certificate? clientCertificate) + { + ClientCertificate = clientCertificate as X509Certificate2; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Features/WebSocketConnectionFeature.cs b/Source/MQTTnet.AspnetCore/Features/WebSocketConnectionFeature.cs new file mode 100644 index 000000000..872a76515 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Features/WebSocketConnectionFeature.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace MQTTnet.AspNetCore +{ + sealed class WebSocketConnectionFeature(string path) + { + /// + /// The path of WebSocket request. + /// + public string Path { get; } = path; + } +} diff --git a/Source/MQTTnet.AspnetCore/IMqttBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttBuilder.cs new file mode 100644 index 000000000..a4438ff4f --- /dev/null +++ b/Source/MQTTnet.AspnetCore/IMqttBuilder.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.DependencyInjection; + +namespace MQTTnet.AspNetCore +{ + public interface IMqttBuilder + { + IServiceCollection Services { get; } + } +} diff --git a/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs new file mode 100644 index 000000000..575ce7c61 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/IMqttClientBuilder.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace MQTTnet.AspNetCore +{ + /// + /// Builder of + /// + public interface IMqttClientBuilder: IMqttBuilder + { + } +} diff --git a/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs b/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs new file mode 100644 index 000000000..30472ce54 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/IMqttClientFactory.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.LowLevelClient; + +namespace MQTTnet.AspNetCore +{ + public interface IMqttClientFactory + { + IMqttClient CreateMqttClient(); + + ILowLevelMqttClient CreateLowLevelMqttClient(); + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs b/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs new file mode 100644 index 000000000..28c71acab --- /dev/null +++ b/Source/MQTTnet.AspnetCore/IMqttServerBuilder.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Server; + +namespace MQTTnet.AspNetCore +{ + /// + /// Builder of + /// + public interface IMqttServerBuilder : IMqttBuilder + { + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs new file mode 100644 index 000000000..f739de225 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientAdapterFactory.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Adapter; +using MQTTnet.Diagnostics.Logger; +using MQTTnet.Formatter; +using System; + +namespace MQTTnet.AspNetCore +{ + sealed class AspNetCoreMqttClientAdapterFactory : IMqttClientAdapterFactory + { + public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) + { + ArgumentNullException.ThrowIfNull(nameof(options)); + var bufferWriter = new MqttBufferWriter(options.WriterBufferSize, options.WriterBufferSizeMax); + var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, bufferWriter); + return new MqttClientChannelAdapter(formatter, options.ChannelOptions, options.AllowPacketFragmentation, packetInspector); + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs new file mode 100644 index 000000000..7d38129dc --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttClientFactory.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Adapter; +using MQTTnet.Diagnostics.Logger; + +namespace MQTTnet.AspNetCore +{ + sealed class AspNetCoreMqttClientFactory : MqttClientFactory, IMqttClientFactory + { + public AspNetCoreMqttClientFactory( + IMqttNetLogger logger, + IMqttClientAdapterFactory clientAdapterFactory) : base(logger, clientAdapterFactory) + { + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs new file mode 100644 index 000000000..58b3bc387 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttHostedServer.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Hosting; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore +{ + sealed class AspNetCoreMqttHostedServer : BackgroundService + { + private readonly AspNetCoreMqttServer _aspNetCoreMqttServer; + private readonly Task _applicationStartedTask; + + public AspNetCoreMqttHostedServer( + AspNetCoreMqttServer aspNetCoreMqttServer, + IHostApplicationLifetime hostApplicationLifetime) + { + _aspNetCoreMqttServer = aspNetCoreMqttServer; + _applicationStartedTask = WaitApplicationStartedAsync(hostApplicationLifetime); + } + + private static Task WaitApplicationStartedAsync(IHostApplicationLifetime hostApplicationLifetime) + { + var taskCompletionSource = new TaskCompletionSource(); + hostApplicationLifetime.ApplicationStarted.Register(OnApplicationStarted); + return taskCompletionSource.Task; + + void OnApplicationStarted() + { + taskCompletionSource.TrySetResult(); + } + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + await _applicationStartedTask.WaitAsync(stoppingToken).ConfigureAwait(false); + await _aspNetCoreMqttServer.StartAsync(stoppingToken).ConfigureAwait(false) ; + } + + public override Task StopAsync(CancellationToken cancellationToken) + { + return _aspNetCoreMqttServer.StopAsync(cancellationToken); + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs new file mode 100644 index 000000000..caa1c07ca --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttNetLogger.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using MQTTnet.Diagnostics.Logger; +using System; + +namespace MQTTnet.AspNetCore +{ + sealed class AspNetCoreMqttNetLogger : IMqttNetLogger + { + private readonly ILoggerFactory _loggerFactory; + private readonly AspNetCoreMqttNetLoggerOptions _loggerOptions; + + public bool IsEnabled => true; + + public AspNetCoreMqttNetLogger( + ILoggerFactory loggerFactory, + IOptions loggerOptions) + { + _loggerFactory = loggerFactory; + _loggerOptions = loggerOptions.Value; + } + + public void Publish(MqttNetLogLevel logLevel, string? source, string? message, object[]? parameters, Exception? exception) + { + try + { + var categoryName = $"{_loggerOptions.CategoryNamePrefix}{source}"; + var logger = _loggerFactory.CreateLogger(categoryName); + var level = _loggerOptions.LogLevelConverter(logLevel); + logger.Log(level, exception, message, parameters ?? []); + } + catch (ObjectDisposedException) + { + } + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs new file mode 100644 index 000000000..e2966b2f0 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServer.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Diagnostics.Logger; +using MQTTnet.Exceptions; +using MQTTnet.Server; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +sealed class AspNetCoreMqttServer : MqttServer +{ + private readonly MqttConnectionHandler _connectionHandler; + private readonly MqttServerStopOptions _stopOptions; + private readonly IEnumerable _adapters; + + public AspNetCoreMqttServer( + MqttConnectionHandler connectionHandler, + MqttServerOptions serverOptions, + MqttServerStopOptions stopOptions, + IEnumerable adapters, + IMqttNetLogger logger) : base(serverOptions, adapters, logger) + { + _connectionHandler = connectionHandler; + _stopOptions = stopOptions; + _adapters = adapters; + } + + public Task StartAsync(CancellationToken cancellationToken) + { + if (!_connectionHandler.ListenFlag && + !_connectionHandler.UseFlag && + !_connectionHandler.MapFlag && + _adapters.All(item => item.GetType() == typeof(AspNetCoreMqttServerAdapter))) + { + throw new MqttConfigurationException("ListenMqtt() or UseMqtt() or MapMqtt() must be called in at least one place"); + } + + return base.StartAsync(); + } + + public Task StopAsync(CancellationToken cancellationToken) + { + return base.StopAsync(_stopOptions); + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs new file mode 100644 index 000000000..f4e6e167e --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/AspNetCoreMqttServerAdapter.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Adapter; +using MQTTnet.Diagnostics.Logger; +using MQTTnet.Server; +using System; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +sealed class AspNetCoreMqttServerAdapter : IMqttServerAdapter +{ + readonly MqttConnectionHandler _connectionHandler; + public Func? ClientHandler + { + get => _connectionHandler.ClientHandler; + set => _connectionHandler.ClientHandler = value; + } + + public AspNetCoreMqttServerAdapter(MqttConnectionHandler connectionHandler) + { + _connectionHandler = connectionHandler; + } + + public Task StartAsync(MqttServerOptions options, IMqttNetLogger logger) + { + if (!_connectionHandler.ListenFlag) + { + if (options.DefaultEndpointOptions.IsEnabled) + { + var message = "DefaultEndpointOptions has been ignored because the user called UseMqtt() on the specified listener."; + logger.Publish(MqttNetLogLevel.Warning, nameof(AspNetCoreMqttServerAdapter), message, null, null); + } + + if (options.TlsEndpointOptions.IsEnabled) + { + var message = "TlsEndpointOptions has been ignored because the user called UseMqtt() on the specified listener."; + logger.Publish(MqttNetLogLevel.Warning, nameof(AspNetCoreMqttServerAdapter), message, null, null); + } + } + + return Task.CompletedTask; + } + + public Task StopAsync() + { + return Task.CompletedTask; + } + + public void Dispose() + { + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs new file mode 100644 index 000000000..fe3caf6fa --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.Tcp.cs @@ -0,0 +1,193 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http.Features; +using System; +using System.Net; +using System.Net.Security; +using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore +{ + partial class ClientConnectionContext + { + public static async Task CreateAsync(MqttClientTcpOptions options, CancellationToken cancellationToken) + { + Socket socket; + if (options.RemoteEndpoint is UnixDomainSocketEndPoint) + { + socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + } + else if (options.AddressFamily == AddressFamily.Unspecified) + { + socket = new Socket(SocketType.Stream, options.ProtocolType); + } + else + { + socket = new Socket(options.AddressFamily, SocketType.Stream, options.ProtocolType); + } + + if (options.LocalEndpoint != null) + { + socket.Bind(options.LocalEndpoint); + } + + socket.ReceiveBufferSize = options.BufferSize; + socket.SendBufferSize = options.BufferSize; + + if (options.ProtocolType == ProtocolType.Tcp && options.RemoteEndpoint is not UnixDomainSocketEndPoint) + { + // Other protocol types do not support the Nagle algorithm. + socket.NoDelay = options.NoDelay; + } + + if (options.LingerState != null) + { + socket.LingerState = options.LingerState; + } + + if (options.DualMode.HasValue) + { + // It is important to avoid setting the flag if no specific value is set by the user + // because on IPv4 only networks the setter will always throw an exception. Regardless + // of the actual value. + socket.DualMode = options.DualMode.Value; + } + + try + { + await socket.ConnectAsync(options.RemoteEndpoint, cancellationToken).ConfigureAwait(false); + } + catch (Exception) + { + socket.Dispose(); + throw; + } + + var networkStream = new NetworkStream(socket, ownsSocket: true); + if (options.TlsOptions?.UseTls != true) + { + return new ClientConnectionContext(networkStream) + { + LocalEndPoint = socket.LocalEndPoint, + RemoteEndPoint = socket.RemoteEndPoint, + }; + } + + var targetHost = options.TlsOptions.TargetHost; + if (string.IsNullOrEmpty(targetHost)) + { + if (options.RemoteEndpoint is DnsEndPoint dns) + { + targetHost = dns.Host; + } + } + + SslStream sslStream; + if (options.TlsOptions.CertificateSelectionHandler != null) + { + sslStream = new SslStream( + networkStream, + leaveInnerStreamOpen: false, + InternalUserCertificateValidationCallback, + InternalUserCertificateSelectionCallback); + } + else + { + // Use a different constructor depending on the options for MQTTnet so that we do not have + // to copy the exact same behavior of the selection handler. + sslStream = new SslStream( + networkStream, + leaveInnerStreamOpen: false, + InternalUserCertificateValidationCallback); + } + + var sslOptions = new SslClientAuthenticationOptions + { + ApplicationProtocols = options.TlsOptions.ApplicationProtocols, + ClientCertificates = LoadCertificates(), + EnabledSslProtocols = options.TlsOptions.SslProtocol, + CertificateRevocationCheckMode = options.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : options.TlsOptions.RevocationMode, + TargetHost = targetHost, + CipherSuitesPolicy = options.TlsOptions.CipherSuitesPolicy, + EncryptionPolicy = options.TlsOptions.EncryptionPolicy, + AllowRenegotiation = options.TlsOptions.AllowRenegotiation + }; + + if (options.TlsOptions.TrustChain?.Count > 0) + { + sslOptions.CertificateChainPolicy = new X509ChainPolicy + { + TrustMode = X509ChainTrustMode.CustomRootTrust, + VerificationFlags = X509VerificationFlags.IgnoreEndRevocationUnknown, + RevocationMode = options.TlsOptions.IgnoreCertificateRevocationErrors ? X509RevocationMode.NoCheck : options.TlsOptions.RevocationMode + }; + + sslOptions.CertificateChainPolicy.CustomTrustStore.AddRange(options.TlsOptions.TrustChain); + } + + try + { + await sslStream.AuthenticateAsClientAsync(sslOptions, cancellationToken).ConfigureAwait(false); + } + catch (Exception) + { + await sslStream.DisposeAsync().ConfigureAwait(false); + throw; + } + + var connection = new ClientConnectionContext(sslStream) + { + LocalEndPoint = socket.LocalEndPoint, + RemoteEndPoint = socket.RemoteEndPoint, + }; + + connection.Features.Set(new TlsConnectionFeature(sslStream.LocalCertificate)); + return connection; + + + X509Certificate InternalUserCertificateSelectionCallback(object sender, string targetHost, X509CertificateCollection? localCertificates, X509Certificate? remoteCertificate, string[] acceptableIssuers) + { + var certificateSelectionHandler = options?.TlsOptions?.CertificateSelectionHandler; + if (certificateSelectionHandler != null) + { + var eventArgs = new MqttClientCertificateSelectionEventArgs(targetHost, localCertificates, remoteCertificate, acceptableIssuers, options); + return certificateSelectionHandler(eventArgs); + } + + if (localCertificates?.Count > 0) + { + return localCertificates[0]; + } + + return null!; + } + + bool InternalUserCertificateValidationCallback(object sender, X509Certificate? x509Certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors) + { + var certificateValidationHandler = options?.TlsOptions?.CertificateValidationHandler; + if (certificateValidationHandler != null) + { + var eventArgs = new MqttClientCertificateValidationEventArgs(x509Certificate, chain, sslPolicyErrors, options); + return certificateValidationHandler(eventArgs); + } + + if (options?.TlsOptions?.IgnoreCertificateChainErrors ?? false) + { + sslPolicyErrors &= ~SslPolicyErrors.RemoteCertificateChainErrors; + } + + return sslPolicyErrors == SslPolicyErrors.None; + } + + X509CertificateCollection? LoadCertificates() + { + return options.TlsOptions.ClientCertificatesProvider?.GetCertificates(); + } + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs new file mode 100644 index 000000000..4965278fc --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.WebSocket.cs @@ -0,0 +1,204 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http.Features; +using MQTTnet.Exceptions; +using System; +using System.IO; +using System.Net; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore +{ + partial class ClientConnectionContext + { + public static async Task CreateAsync(MqttClientWebSocketOptions options, CancellationToken cancellationToken) + { + var uri = new Uri(options.Uri, UriKind.Absolute); + if (uri.Scheme != Uri.UriSchemeWs && uri.Scheme != Uri.UriSchemeWss) + { + throw new MqttConfigurationException("The scheme of the WebSocket Uri must be ws or wss."); + } + + // Patching TlsOptions + options.TlsOptions ??= new MqttClientTlsOptions(); + // Scheme decides whether to use TLS + options.TlsOptions.UseTls = uri.Scheme == Uri.UriSchemeWss; + + var clientWebSocket = new ClientWebSocket(); + try + { + SetupClientWebSocket(clientWebSocket.Options, options); + await clientWebSocket.ConnectAsync(uri, cancellationToken).ConfigureAwait(false); + } + catch + { + // Prevent a memory leak when always creating new instance which will fail while connecting. + clientWebSocket.Dispose(); + throw; + } + + var webSocketStream = new WebSocketStream(clientWebSocket); + var connection = new ClientConnectionContext(webSocketStream) + { + LocalEndPoint = null, + RemoteEndPoint = new DnsEndPoint(uri.Host, uri.Port), + }; + + connection.Features.Set(new WebSocketConnectionFeature(uri.AbsolutePath)); + if (uri.Scheme == Uri.UriSchemeWss) + { + connection.Features.Set(TlsConnectionFeature.WithoutClientCertificate); + } + return connection; + } + + private static void SetupClientWebSocket(ClientWebSocketOptions webSocketOptions, MqttClientWebSocketOptions options) + { + if (options.ProxyOptions != null) + { + webSocketOptions.Proxy = CreateProxy(options); + } + + if (options.RequestHeaders != null) + { + foreach (var requestHeader in options.RequestHeaders) + { + webSocketOptions.SetRequestHeader(requestHeader.Key, requestHeader.Value); + } + } + + if (options.SubProtocols != null) + { + foreach (var subProtocol in options.SubProtocols) + { + webSocketOptions.AddSubProtocol(subProtocol); + } + } + + if (options.CookieContainer != null) + { + webSocketOptions.Cookies = options.CookieContainer; + } + + if (options.TlsOptions.UseTls) + { + var certificates = options.TlsOptions.ClientCertificatesProvider?.GetCertificates(); + if (certificates?.Count > 0) + { + webSocketOptions.ClientCertificates = certificates; + } + } + + // Only set the value if it is actually true. This property is not supported on all platforms + // and will throw a _PlatformNotSupported_ (i.e. WASM) exception when being used regardless of the actual value. + if (options.UseDefaultCredentials) + { + webSocketOptions.UseDefaultCredentials = options.UseDefaultCredentials; + } + + if (options.KeepAliveInterval != WebSocket.DefaultKeepAliveInterval) + { + webSocketOptions.KeepAliveInterval = options.KeepAliveInterval; + } + + if (options.Credentials != null) + { + webSocketOptions.Credentials = options.Credentials; + } + + var certificateValidationHandler = options.TlsOptions.CertificateValidationHandler; + if (certificateValidationHandler != null) + { + webSocketOptions.RemoteCertificateValidationCallback = (_, certificate, chain, sslPolicyErrors) => + { + // TODO: Find a way to add client options to same callback. Problem is that they have a different type. + var context = new MqttClientCertificateValidationEventArgs(certificate, chain, sslPolicyErrors, options); + return certificateValidationHandler(context); + }; + + var certificateSelectionHandler = options.TlsOptions.CertificateSelectionHandler; + if (certificateSelectionHandler != null) + { + throw new NotSupportedException("Remote certificate selection callback is not supported for WebSocket connections."); + } + } + } + + private static IWebProxy? CreateProxy(MqttClientWebSocketOptions options) + { + if (!Uri.TryCreate(options.ProxyOptions?.Address, UriKind.Absolute, out var proxyUri)) + { + return null; + } + + + WebProxy webProxy; + if (!string.IsNullOrEmpty(options.ProxyOptions.Username) && !string.IsNullOrEmpty(options.ProxyOptions.Password)) + { + var credentials = new NetworkCredential(options.ProxyOptions.Username, options.ProxyOptions.Password, options.ProxyOptions.Domain); + webProxy = new WebProxy(proxyUri, options.ProxyOptions.BypassOnLocal, options.ProxyOptions.BypassList, credentials); + } + else + { + webProxy = new WebProxy(proxyUri, options.ProxyOptions.BypassOnLocal, options.ProxyOptions.BypassList); + } + + if (options.ProxyOptions.UseDefaultCredentials) + { + // Only update the property if required because setting it to false will alter + // the used credentials internally! + webProxy.UseDefaultCredentials = true; + } + + return webProxy; + } + + + private sealed class WebSocketStream(WebSocket webSocket) : Stream + { + private readonly WebSocket _webSocket = webSocket; + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length => throw new NotSupportedException(); + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Flush() { } + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return _webSocket.SendAsync(buffer, WebSocketMessageType.Binary, false, cancellationToken); + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + var result = await _webSocket.ReceiveAsync(buffer, cancellationToken); + return result.MessageType == WebSocketMessageType.Close ? 0 : result.Count; + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + protected override void Dispose(bool disposing) + { + _webSocket.Dispose(); + base.Dispose(disposing); + } + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs new file mode 100644 index 000000000..e9ef720d6 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/ClientConnectionContext.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; +using System; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore +{ + sealed partial class ClientConnectionContext : ConnectionContext + { + private readonly Stream _stream; + private readonly CancellationTokenSource _connectionCloseSource = new(); + + private IDictionary? _items; + + public override IDuplexPipe Transport { get; set; } + + public override CancellationToken ConnectionClosed + { + get => _connectionCloseSource.Token; + set => throw new InvalidOperationException(); + } + + public override string ConnectionId { get; set; } = string.Empty; + + public override IFeatureCollection Features { get; } = new FeatureCollection(); + + public override IDictionary Items + { + get => _items ??= new Dictionary(); + set => _items = value; + } + + public ClientConnectionContext(Stream stream) + { + _stream = stream; + Transport = new StreamTransport(stream); + } + + public override async ValueTask DisposeAsync() + { + await _stream.DisposeAsync().ConfigureAwait(false); + _connectionCloseSource.Cancel(); + _connectionCloseSource.Dispose(); + } + + public override void Abort() + { + _stream.Close(); + _connectionCloseSource.Cancel(); + } + + + private class StreamTransport(Stream stream) : IDuplexPipe + { + public PipeReader Input { get; } = PipeReader.Create(stream, new StreamPipeReaderOptions(leaveOpen: true)); + + public PipeWriter Output { get; } = PipeWriter.Create(stream, new StreamPipeWriterOptions(leaveOpen: true)); + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannel.cs new file mode 100644 index 000000000..5eb5358ae --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/IAspNetCoreMqttChannel.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http; + +namespace MQTTnet.AspNetCore +{ + interface IAspNetCoreMqttChannel + { + HttpContext? HttpContext { get; } + + bool IsWebSocketConnection { get; } + + TFeature? GetFeature(); + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs new file mode 100644 index 000000000..c9cf437f5 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttBufferWriterPool.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Options; +using MQTTnet.Formatter; +using MQTTnet.Server; +using System; +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; + +namespace MQTTnet.AspNetCore +{ + [DebuggerDisplay("Count = {Count}")] + sealed class MqttBufferWriterPool : IReadOnlyCollection + { + private readonly MqttServerOptions _serverOptions; + private readonly IOptionsMonitor _poolOptions; + private readonly ConcurrentQueue _bufferWriterQueue = new(); + + public int Count => _bufferWriterQueue.Count; + + public MqttBufferWriterPool( + MqttServerOptions serverOptions, + IOptionsMonitor poolOptions) + { + _serverOptions = serverOptions; + _poolOptions = poolOptions; + } + + public ChannelMqttBufferWriter Rent() + { + if (_bufferWriterQueue.TryDequeue(out var bufferWriter)) + { + bufferWriter.Reset(); + } + else + { + var writer = new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax); + bufferWriter = new ChannelMqttBufferWriter(writer); + } + return bufferWriter; + } + + public bool Return(ChannelMqttBufferWriter bufferWriter) + { + if (CanReturn(bufferWriter)) + { + _bufferWriterQueue.Enqueue(bufferWriter); + return true; + } + return false; + } + + private bool CanReturn(ChannelMqttBufferWriter bufferWriter) + { + var options = _poolOptions.CurrentValue; + if (bufferWriter.Lifetime < options.MaxLifetime) + { + return true; + } + + if (options.LargeBufferSizeEnabled && bufferWriter.BufferSize > _serverOptions.WriterBufferSize) + { + return true; + } + + return false; + } + + public IEnumerator GetEnumerator() + { + return _bufferWriterQueue.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _bufferWriterQueue.GetEnumerator(); + } + + + [DebuggerDisplay("BufferSize = {BufferSize}, LifeTime = {LifeTime}")] + public sealed class ChannelMqttBufferWriter(MqttBufferWriter bufferWriter) + { + private long _tickCount = Environment.TickCount64; + private readonly MqttBufferWriter _bufferWriter = bufferWriter; + + public int BufferSize => _bufferWriter.GetBuffer().Length; + public TimeSpan Lifetime => TimeSpan.FromMilliseconds(Environment.TickCount64 - _tickCount); + + public void Reset() + { + _tickCount = Environment.TickCount64; + } + + public static implicit operator MqttBufferWriter(ChannelMqttBufferWriter channelMqttBufferWriter) + { + return channelMqttBufferWriter._bufferWriter; + } + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs new file mode 100644 index 000000000..dc6cdacdd --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttChannel.cs @@ -0,0 +1,343 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using MQTTnet.Adapter; +using MQTTnet.Exceptions; +using MQTTnet.Formatter; +using MQTTnet.Internal; +using MQTTnet.Packets; +using System; +using System.Buffers; +using System.IO; +using System.IO.Pipelines; +using System.Net; +using System.Net.Sockets; +using System.Runtime.InteropServices; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +class MqttChannel : IAspNetCoreMqttChannel, IDisposable +{ + readonly ConnectionContext _connection; + readonly HttpContext? _httpContext; + + readonly AsyncLock _writerLock = new(); + readonly PipeReader _input; + readonly PipeWriter _output; + readonly MqttPacketInspector? _packetInspector; + bool _allowPacketFragmentation = false; + + public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } + + public long BytesReceived { get; private set; } + + public long BytesSent { get; private set; } + + public X509Certificate2? ClientCertificate { get; } + + public EndPoint? RemoteEndPoint { get; private set; } + + public bool IsSecureConnection { get; } + + public bool IsWebSocketConnection { get; } + + public HttpContext? HttpContext => _httpContext; + + public MqttChannel( + MqttPacketFormatterAdapter packetFormatterAdapter, + ConnectionContext connection, + HttpContext? httpContext, + MqttPacketInspector? packetInspector) + { + PacketFormatterAdapter = packetFormatterAdapter; + _connection = connection; + _httpContext = httpContext; + _packetInspector = packetInspector; + + _input = connection.Transport.Input; + _output = connection.Transport.Output; + + var tlsConnectionFeature = GetFeature(); + var webSocketConnectionFeature = GetFeature(); + + IsWebSocketConnection = webSocketConnectionFeature != null; + IsSecureConnection = tlsConnectionFeature != null; + ClientCertificate = tlsConnectionFeature?.ClientCertificate; + RemoteEndPoint = GetRemoteEndPoint(connection.RemoteEndPoint, httpContext); + } + + + public TFeature? GetFeature() + { + var feature = _connection.Features.Get(); + if (feature != null) + { + return feature; + } + + if (_httpContext != null) + { + return _httpContext.Features.Get(); + } + + return default; + } + + private static EndPoint? GetRemoteEndPoint(EndPoint? remoteEndPoint, HttpContext? httpContext) + { + if (remoteEndPoint != null) + { + return remoteEndPoint; + } + + if (httpContext != null) + { + var httpConnection = httpContext.Connection; + var remoteAddress = httpConnection.RemoteIpAddress; + if (remoteAddress != null) + { + return new IPEndPoint(remoteAddress, httpConnection.RemotePort); + } + } + + return null; + } + + public void SetAllowPacketFragmentation(bool value) + { + _allowPacketFragmentation = value; + } + + public async Task DisconnectAsync() + { + try + { + await _input.CompleteAsync().ConfigureAwait(false); + await _output.CompleteAsync().ConfigureAwait(false); + } + catch (Exception exception) + { + if (!WrapAndThrowException(exception)) + { + throw; + } + } + } + + public virtual void Dispose() + { + _writerLock.Dispose(); + } + + public async Task ReceivePacketAsync(CancellationToken cancellationToken) + { + try + { + return await ReceivePacketCoreAsync(cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + } + catch (ObjectDisposedException) + { + } + catch (Exception exception) + { + if (!WrapAndThrowException(exception)) + { + throw; + } + } + + return null; + } + + private async Task ReceivePacketCoreAsync(CancellationToken cancellationToken) + { + try + { + _packetInspector?.BeginReceivePacket(); + + while (!cancellationToken.IsCancellationRequested) + { + ReadResult readResult; + var readTask = _input.ReadAsync(cancellationToken); + if (readTask.IsCompleted) + { + readResult = readTask.Result; + } + else + { + readResult = await readTask.ConfigureAwait(false); + } + + var buffer = readResult.Buffer; + + var consumed = buffer.Start; + var observed = buffer.Start; + + try + { + if (!buffer.IsEmpty) + { + if (PacketFormatterAdapter.TryDecode(buffer, _packetInspector, out var packet, out consumed, out observed, out var received)) + { + BytesReceived += received; + + if (_packetInspector != null) + { + await _packetInspector.EndReceivePacket().ConfigureAwait(false); + } + return packet; + } + } + else if (readResult.IsCompleted) + { + throw new MqttCommunicationException("Connection Aborted"); + } + } + finally + { + // The buffer was sliced up to where it was consumed, so we can just advance to the start. + // We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data + // before yielding the read again. + _input.AdvanceTo(consumed, observed); + } + } + } + catch (Exception) + { + // completing the channel makes sure that there is no more data read after a protocol error + await _input.CompleteAsync().ConfigureAwait(false); + await _output.CompleteAsync().ConfigureAwait(false); + + throw; + } + + cancellationToken.ThrowIfCancellationRequested(); + return null; + } + + public void ResetStatistics() + { + BytesReceived = 0; + BytesSent = 0; + } + + public async Task SendPacketAsync(MqttPacket packet, CancellationToken cancellationToken) + { + try + { + await SendPacketCoreAsync(packet, cancellationToken).ConfigureAwait(false); + } + catch (Exception exception) + { + if (!WrapAndThrowException(exception)) + { + throw; + } + } + } + + private async Task SendPacketCoreAsync(MqttPacket packet, CancellationToken cancellationToken) + { + using (await _writerLock.EnterAsync(cancellationToken).ConfigureAwait(false)) + { + try + { + var buffer = PacketFormatterAdapter.Encode(packet); + if (_packetInspector != null) + { + await _packetInspector.BeginSendPacket(buffer).ConfigureAwait(false); + } + + if (buffer.Payload.Length == 0) + { + // zero copy + // https://github.com/dotnet/runtime/blob/e31ddfdc4f574b26231233dc10c9a9c402f40590/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/StreamPipeWriter.cs#L279 + await _output.WriteAsync(buffer.Packet, cancellationToken).ConfigureAwait(false); + } + else if (_allowPacketFragmentation) + { + await _output.WriteAsync(buffer.Packet, cancellationToken).ConfigureAwait(false); + foreach (var memory in buffer.Payload) + { + await _output.WriteAsync(memory, cancellationToken).ConfigureAwait(false); + } + } + else + { + // Make sure the MQTT packet is in a WebSocket frame to be compatible with JavaScript WebSocket + WritePacketBuffer(_output, buffer); + await _output.FlushAsync(cancellationToken).ConfigureAwait(false); + } + + BytesSent += buffer.Length; + } + finally + { + PacketFormatterAdapter.Cleanup(); + } + } + } + + static void WritePacketBuffer(PipeWriter output, MqttPacketBuffer buffer) + { + // copy MqttPacketBuffer's Packet and Payload to the same buffer block of PipeWriter + // MqttPacket will be transmitted within the bounds of a WebSocket frame after PipeWriter.FlushAsync + + var span = output.GetSpan(buffer.Length); + + buffer.Packet.AsSpan().CopyTo(span); + var offset = buffer.Packet.Count; + buffer.Payload.CopyTo(destination: span.Slice(offset)); + output.Advance(buffer.Length); + } + + public static bool WrapAndThrowException(Exception exception) + { + if (exception is OperationCanceledException || + exception is MqttCommunicationTimedOutException || + exception is MqttCommunicationException || + exception is MqttProtocolViolationException) + { + return false; + } + + if (exception is IOException && exception.InnerException is SocketException innerException) + { + exception = innerException; + } + + if (exception is SocketException socketException) + { + if (socketException.SocketErrorCode == SocketError.OperationAborted) + { + throw new OperationCanceledException(); + } + + if (socketException.SocketErrorCode == SocketError.ConnectionAborted) + { + throw new MqttCommunicationException(socketException); + } + } + + if (exception is COMException comException) + { + const uint ErrorOperationAborted = 0x800703E3; + if ((uint)comException.HResult == ErrorOperationAborted) + { + throw new OperationCanceledException(); + } + } + + throw new MqttCommunicationException(exception); + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs new file mode 100644 index 000000000..311a469ba --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttClientChannelAdapter.cs @@ -0,0 +1,136 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http; +using MQTTnet.Adapter; +using MQTTnet.Formatter; +using MQTTnet.Packets; +using System; +using System.Net; +using System.Runtime.CompilerServices; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +sealed class MqttClientChannelAdapter : IAsyncDisposable, IMqttChannelAdapter, IAspNetCoreMqttChannel +{ + private bool _disposed = false; + private ConnectionContext? _connection; + private MqttChannel? _channel; + private readonly MqttPacketFormatterAdapter _packetFormatterAdapter; + private readonly IMqttClientChannelOptions _channelOptions; + private readonly bool _allowPacketFragmentation; + private readonly MqttPacketInspector? _packetInspector; + + public MqttClientChannelAdapter( + MqttPacketFormatterAdapter packetFormatterAdapter, + IMqttClientChannelOptions channelOptions, + bool allowPacketFragmentation, + MqttPacketInspector? packetInspector) + { + _packetFormatterAdapter = packetFormatterAdapter; + _channelOptions = channelOptions; + _allowPacketFragmentation = allowPacketFragmentation; + _packetInspector = packetInspector; + } + + public MqttPacketFormatterAdapter PacketFormatterAdapter => GetChannel().PacketFormatterAdapter; + + public long BytesReceived => GetChannel().BytesReceived; + + public long BytesSent => GetChannel().BytesSent; + + public X509Certificate2? ClientCertificate => GetChannel().ClientCertificate; + + public EndPoint? RemoteEndPoint => GetChannel().RemoteEndPoint; + + public bool IsSecureConnection => GetChannel().IsSecureConnection; + + public bool IsWebSocketConnection => GetChannel().IsSecureConnection; + + public HttpContext? HttpContext => GetChannel().HttpContext; + + public TFeature? GetFeature() + { + return GetChannel().GetFeature(); + } + + + public async Task ConnectAsync(CancellationToken cancellationToken) + { + try + { + _connection = _channelOptions switch + { + MqttClientTcpOptions tcpOptions => await ClientConnectionContext.CreateAsync(tcpOptions, cancellationToken).ConfigureAwait(false), + MqttClientWebSocketOptions webSocketOptions => await ClientConnectionContext.CreateAsync(webSocketOptions, cancellationToken).ConfigureAwait(false), + _ => throw new NotSupportedException(), + }; + _channel = new MqttChannel(_packetFormatterAdapter, _connection, httpContext: null, _packetInspector); + _channel.SetAllowPacketFragmentation(_allowPacketFragmentation); + } + catch (Exception ex) + { + if (!MqttChannel.WrapAndThrowException(ex)) + { + throw; + } + } + } + + public Task DisconnectAsync(CancellationToken cancellationToken) + { + return GetChannel().DisconnectAsync(); + } + + public async ValueTask DisposeAsync() + { + if (_disposed) + { + return; + } + + _disposed = true; + + if (_channel != null) + { + await _channel.DisconnectAsync().ConfigureAwait(false); + _channel.Dispose(); + } + + if (_connection != null) + { + await _connection.DisposeAsync().ConfigureAwait(false); + } + } + + public void Dispose() + { + DisposeAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + } + + public Task ReceivePacketAsync(CancellationToken cancellationToken) + { + return GetChannel().ReceivePacketAsync(cancellationToken); + } + + public void ResetStatistics() + { + GetChannel().ResetStatistics(); + } + + public Task SendPacketAsync(MqttPacket packet, CancellationToken cancellationToken) + { + return GetChannel().SendPacketAsync(packet, cancellationToken); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private MqttChannel GetChannel() + { + return _channel ?? throw new InvalidOperationException(); + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs new file mode 100644 index 000000000..c5675b18a --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionHandler.cs @@ -0,0 +1,75 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Connections; +using MQTTnet.Adapter; +using MQTTnet.Diagnostics.Logger; +using MQTTnet.Formatter; +using MQTTnet.Server; +using System; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +sealed class MqttConnectionHandler : ConnectionHandler +{ + readonly IMqttNetLogger _logger; + readonly MqttBufferWriterPool _bufferWriterPool; + + public bool UseFlag { get; set; } + + public bool MapFlag { get; set; } + + public bool ListenFlag { get; set; } + + public Func? ClientHandler { get; set; } + + public MqttConnectionHandler( + IMqttNetLogger logger, + MqttBufferWriterPool bufferWriterPool) + { + _logger = logger; + _bufferWriterPool = bufferWriterPool; + } + + public override async Task OnConnectedAsync(ConnectionContext connection) + { + var clientHandler = ClientHandler; + if (clientHandler == null) + { + connection.Abort(); + _logger.Publish(MqttNetLogLevel.Warning, nameof(MqttConnectionHandler), $"{nameof(MqttServer)} has not been started yet.", null, null); + return; + } + + // required for websocket transport to work + var transferFormatFeature = connection.Features.Get(); + if (transferFormatFeature != null) + { + transferFormatFeature.ActiveFormat = TransferFormat.Binary; + } + + // WebSocketConnectionFeature will be accessed in MqttChannel + var httpContext = connection.GetHttpContext(); + if (httpContext != null && httpContext.WebSockets.IsWebSocketRequest) + { + var path = httpContext.Request.Path; + connection.Features.Set(new WebSocketConnectionFeature(path)); + } + + var bufferWriter = _bufferWriterPool.Rent(); + try + { + var formatter = new MqttPacketFormatterAdapter(bufferWriter); + using var adapter = new MqttServerChannelAdapter(formatter, connection, httpContext); + await clientHandler(adapter).ConfigureAwait(false); + } + finally + { + _bufferWriterPool.Return(bufferWriter); + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs new file mode 100644 index 000000000..58b7aef8d --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttConnectionMiddleware.cs @@ -0,0 +1,76 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Connections; +using MQTTnet.Adapter; +using System; +using System.Buffers; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +/// +/// Middleware that connection using the specified MQTT protocols +/// +sealed class MqttConnectionMiddleware +{ + private static readonly byte[] _mqtt = "MQTT"u8.ToArray(); + private static readonly byte[] _mqisdp = "MQIsdp"u8.ToArray(); + private readonly MqttConnectionHandler _connectionHandler; + + public MqttConnectionMiddleware(MqttConnectionHandler connectionHandler) + { + _connectionHandler = connectionHandler; + } + + public async Task InvokeAsync( + ConnectionDelegate next, + ConnectionContext connection, + MqttProtocols protocols, + Func? allowPacketFragmentationSelector) + { + if (allowPacketFragmentationSelector != null) + { + connection.Features.Set(new PacketFragmentationFeature(allowPacketFragmentationSelector)); + } + + if (protocols == MqttProtocols.MqttAndWebSocket) + { + var input = connection.Transport.Input; + var readResult = await input.ReadAsync(); + var isMqtt = IsMqttRequest(readResult.Buffer); + input.AdvanceTo(readResult.Buffer.Start); + + protocols = isMqtt ? MqttProtocols.Mqtt : MqttProtocols.WebSocket; + } + + if (protocols == MqttProtocols.Mqtt) + { + await _connectionHandler.OnConnectedAsync(connection).ConfigureAwait(false); + } + else if (protocols == MqttProtocols.WebSocket) + { + await next(connection).ConfigureAwait(false); + } + else + { + throw new NotSupportedException(protocols.ToString()); + } + } + + public static bool IsMqttRequest(ReadOnlySequence buffer) + { + if (!buffer.IsEmpty) + { + var span = buffer.FirstSpan; + if (span.Length > 4) + { + var protocol = span[4..]; + return protocol.StartsWith(_mqtt) || protocol.StartsWith(_mqisdp); + } + } + + return false; + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs b/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs new file mode 100644 index 000000000..18d21c9a1 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttOptionsFactory.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Options; +using System; +using System.Collections.Generic; + +namespace MQTTnet.AspNetCore +{ + class MqttOptionsFactory where TOptions : class + { + private readonly Func _defaultOptionsFactory; + private readonly IEnumerable> _setups; + private readonly IEnumerable> _postConfigures; + + public MqttOptionsFactory( + Func defaultOptionsFactory, + IEnumerable> setups, + IEnumerable> postConfigures) + { + _defaultOptionsFactory = defaultOptionsFactory; + _setups = setups; + _postConfigures = postConfigures; + } + + public TOptions CreateOptions() + { + var options = _defaultOptionsFactory(); + var name = Options.DefaultName; + + foreach (var setup in _setups) + { + if (setup is IConfigureNamedOptions namedSetup) + { + namedSetup.Configure(name, options); + } + else if (name == Options.DefaultName) + { + setup.Configure(options); + } + } + foreach (var post in _postConfigures) + { + post.PostConfigure(name, options); + } + return options; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/ReaderExtensions.cs b/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs similarity index 90% rename from Source/MQTTnet.AspnetCore/ReaderExtensions.cs rename to Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs index 9b4f24ca5..796bdc01f 100644 --- a/Source/MQTTnet.AspnetCore/ReaderExtensions.cs +++ b/Source/MQTTnet.AspnetCore/Internal/MqttPacketFormatterAdapterExtensions.cs @@ -2,22 +2,24 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Buffers; -using System.Runtime.InteropServices; using MQTTnet.Adapter; using MQTTnet.Exceptions; using MQTTnet.Formatter; using MQTTnet.Packets; +using System; +using System.Buffers; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; namespace MQTTnet.AspNetCore; -public static class ReaderExtensions +static class MqttPacketFormatterAdapterExtensions { public static bool TryDecode( this MqttPacketFormatterAdapter formatter, in ReadOnlySequence input, - out MqttPacket packet, + MqttPacketInspector? packetInspector, + [MaybeNullWhen(false)] out MqttPacket packet, out SequencePosition consumed, out SequencePosition observed, out int bytesRead) @@ -50,6 +52,12 @@ public static bool TryDecode( var bodySlice = copy.Slice(0, bodyLength); var bodySegment = GetArraySegment(ref bodySlice); + if (packetInspector != null) + { + packetInspector.FillReceiveBuffer(input.Slice(input.Start, headerLength).ToArray()); + packetInspector.FillReceiveBuffer(bodySegment.ToArray()); + } + var receivedMqttPacket = new ReceivedMqttPacket(fixedHeader, bodySegment, headerLength + bodyLength); if (formatter.ProtocolVersion == MqttProtocolVersion.Unknown) { diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs new file mode 100644 index 000000000..c6b565ac6 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerChannelAdapter.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http; +using MQTTnet.Adapter; +using MQTTnet.Formatter; +using System.Threading; +using System.Threading.Tasks; + +namespace MQTTnet.AspNetCore; + +sealed class MqttServerChannelAdapter : MqttChannel, IMqttChannelAdapter, IAspNetCoreMqttChannel +{ + public MqttServerChannelAdapter(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection, HttpContext? httpContext) + : base(packetFormatterAdapter, connection, httpContext, packetInspector: null) + { + var packetFragmentationFeature = GetFeature(); + if (packetFragmentationFeature == null) + { + var value = PacketFragmentationFeature.CanAllowPacketFragmentation(this, null); + SetAllowPacketFragmentation(value); + } + else + { + var value = packetFragmentationFeature.AllowPacketFragmentationSelector(this); + SetAllowPacketFragmentation(value); + } + } + + /// + /// This method will never be called + /// + /// + /// + public Task ConnectAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public Task DisconnectAsync(CancellationToken cancellationToken) + { + return base.DisconnectAsync(); + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs new file mode 100644 index 000000000..452d29e3c --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerOptionsFactory.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Options; +using MQTTnet.Server; +using System.Collections.Generic; + +namespace MQTTnet.AspNetCore +{ + sealed class MqttServerOptionsFactory : MqttOptionsFactory + { + public MqttServerOptionsFactory( + IOptions optionsBuilderOptions, + IEnumerable> setups, + IEnumerable> postConfigures) + : base(optionsBuilderOptions.Value.Build, setups, postConfigures) + { + } + } +} diff --git a/Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs b/Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs new file mode 100644 index 000000000..b49570e85 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/Internal/MqttServerStopOptionsFactory.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.Options; +using MQTTnet.Server; +using System.Collections.Generic; + +namespace MQTTnet.AspNetCore +{ + sealed class MqttServerStopOptionsFactory : MqttOptionsFactory + { + public MqttServerStopOptionsFactory( + IOptions optionsBuilderOptions, + IEnumerable> setups, + IEnumerable> postConfigures) + : base(optionsBuilderOptions.Value.Build, setups, postConfigures) + { + } + } +} diff --git a/Source/MQTTnet.AspnetCore/InternalsVisible.cs b/Source/MQTTnet.AspnetCore/InternalsVisible.cs new file mode 100644 index 000000000..b823bc96d --- /dev/null +++ b/Source/MQTTnet.AspnetCore/InternalsVisible.cs @@ -0,0 +1,7 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("MQTTnet.Tests")] +[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("MQTTnet.AspTestApp")] +[assembly: System.Runtime.CompilerServices.InternalsVisibleTo("MQTTnet.Benchmarks")] \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs new file mode 100644 index 000000000..afbbed92b --- /dev/null +++ b/Source/MQTTnet.AspnetCore/KestrelServerOptionsExtensions.cs @@ -0,0 +1,171 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.AspNetCore.Server.Kestrel.Https; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using MQTTnet.Exceptions; +using MQTTnet.Server; +using System; +using System.Net; +using System.Net.Sockets; +using System.Security.Cryptography.X509Certificates; + +namespace MQTTnet.AspNetCore +{ + public static class KestrelServerOptionsExtensions + { + /// + /// Listen all endponts in + /// + /// + /// + /// + /// + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, MqttProtocols protocols = MqttProtocols.MqttAndWebSocket) + { + return kestrel.ListenMqtt(protocols, default(Action)); + } + + /// + /// Listen all endponts in + /// + /// + /// + /// + /// + /// + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, MqttProtocols protocols, X509Certificate2? serverCertificate) + { + return kestrel.ListenMqtt(protocols, tls => tls.ServerCertificate = serverCertificate); + } + + /// + /// Listen all endponts in + /// + /// + /// + /// + /// + /// + public static KestrelServerOptions ListenMqtt(this KestrelServerOptions kestrel, MqttProtocols protocols, Action? tlsConfigure) + { + // check services.AddMqttServer() + kestrel.ApplicationServices.GetRequiredService(); + + var serverOptions = kestrel.ApplicationServices.GetRequiredService(); + var connectionHandler = kestrel.ApplicationServices.GetRequiredService(); + var listenSocketFactory = kestrel.ApplicationServices.GetRequiredService>().Value.CreateBoundListenSocket ?? SocketTransportOptions.CreateDefaultBoundListenSocket; + + Listen(serverOptions.DefaultEndpointOptions); + Listen(serverOptions.TlsEndpointOptions); + + return connectionHandler.ListenFlag + ? kestrel + : throw new MqttConfigurationException("None of the MqttServerOptions Endpoints are enabled."); + + void Listen(MqttServerTcpEndpointBaseOptions endpoint) + { + if (!endpoint.IsEnabled) + { + return; + } + + // No need to listen IPv4EndPoint when IPv6EndPoint's DualMode is true. + var ipV6EndPoint = new IPEndPoint(endpoint.BoundInterNetworkV6Address, endpoint.Port); + using var listenSocket = listenSocketFactory.Invoke(ipV6EndPoint); + if (!listenSocket.DualMode) + { + kestrel.Listen(endpoint.BoundInterNetworkAddress, endpoint.Port, UseMiddlewares); + } + + kestrel.Listen(ipV6EndPoint, UseMiddlewares); + connectionHandler.ListenFlag = true; + + + void UseMiddlewares(ListenOptions listenOptions) + { + listenOptions.Use(next => context => + { + var socketFeature = context.Features.Get(); + if (socketFeature != null) + { + endpoint.AdaptTo(socketFeature.Socket); + } + return next(context); + }); + + if (endpoint is MqttServerTlsTcpEndpointOptions tlsEndPoint) + { + listenOptions.UseHttps(httpsOptions => + { + tlsEndPoint.AdaptTo(httpsOptions); + tlsConfigure?.Invoke(httpsOptions); + }); + } + + listenOptions.UseMqtt(protocols, channelAdapter => PacketFragmentationFeature.CanAllowPacketFragmentation(channelAdapter, endpoint)); + } + } + } + + private static void AdaptTo(this MqttServerTcpEndpointBaseOptions endpoint, Socket acceptSocket) + { + acceptSocket.NoDelay = endpoint.NoDelay; + if (endpoint.LingerState != null) + { + acceptSocket.LingerState = endpoint.LingerState; + } + + if (endpoint.KeepAlive.HasValue) + { + var value = endpoint.KeepAlive.Value; + acceptSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, value); + } + + if (endpoint.TcpKeepAliveInterval.HasValue) + { + var value = endpoint.TcpKeepAliveInterval.Value; + acceptSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveInterval, value); + } + + if (endpoint.TcpKeepAliveRetryCount.HasValue) + { + var value = endpoint.TcpKeepAliveRetryCount.Value; + acceptSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveRetryCount, value); + } + + if (endpoint.TcpKeepAliveTime.HasValue) + { + var value = endpoint.TcpKeepAliveTime.Value; + acceptSocket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.TcpKeepAliveTime, value); + } + } + + private static void AdaptTo(this MqttServerTlsTcpEndpointOptions tlsEndPoint, HttpsConnectionAdapterOptions httpsOptions) + { + httpsOptions.SslProtocols = tlsEndPoint.SslProtocol; + httpsOptions.CheckCertificateRevocation = tlsEndPoint.CheckCertificateRevocation; + + if (tlsEndPoint.ClientCertificateRequired) + { + httpsOptions.ClientCertificateMode = ClientCertificateMode.RequireCertificate; + } + + if (tlsEndPoint.CertificateProvider != null) + { + httpsOptions.ServerCertificateSelector = (context, host) => tlsEndPoint.CertificateProvider.GetCertificate(); + } + + if (tlsEndPoint.RemoteCertificateValidationCallback != null) + { + httpsOptions.ClientCertificateValidation = (cert, chain, errors) => tlsEndPoint.RemoteCertificateValidationCallback(tlsEndPoint, cert, chain, errors); + } + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj index 5357dd702..8864320ac 100644 --- a/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj +++ b/Source/MQTTnet.AspnetCore/MQTTnet.AspNetCore.csproj @@ -37,7 +37,8 @@ true low low - latest-Recommended + enable + @@ -48,16 +49,12 @@ - - + + - - - - - + diff --git a/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs b/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs new file mode 100644 index 000000000..9927404a6 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttBufferWriterPoolOptions.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using MQTTnet.Formatter; +using System; + +namespace MQTTnet.AspNetCore +{ + public sealed class MqttBufferWriterPoolOptions + { + /// + /// When the lifetime of the is less than this value, is pooled. + /// + public TimeSpan MaxLifetime { get; set; } = TimeSpan.FromMinutes(1d); + + /// + /// Whether to pool with BufferSize greater than the default buffer size. + /// + public bool LargeBufferSizeEnabled { get; set; } = true; + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs new file mode 100644 index 000000000..c0623a048 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttBuilderExtensions.cs @@ -0,0 +1,73 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using MQTTnet.Diagnostics.Logger; +using System; +using System.Diagnostics.CodeAnalysis; + +namespace MQTTnet.AspNetCore +{ + public static class MqttBuilderExtensions + { + /// + /// Use as + /// + /// + /// + /// + public static IMqttBuilder UseAspNetCoreMqttNetLogger(this IMqttBuilder builder, Action configure) + { + builder.Services.Configure(configure); + return builder.UseAspNetCoreMqttNetLogger(); + } + + /// + /// Use as + /// + /// + /// + public static IMqttBuilder UseAspNetCoreMqttNetLogger(this IMqttBuilder builder) + { + return builder.UseLogger(); + } + + /// + /// Use as + /// + /// + /// + public static IMqttBuilder UseMqttNetNullLogger(this IMqttBuilder builder) + { + return builder.UseLogger(MqttNetNullLogger.Instance); + } + + /// + /// Use a logger + /// + /// + /// + /// + public static IMqttBuilder UseLogger<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TLogger>(this IMqttBuilder builder) + where TLogger : class, IMqttNetLogger + { + builder.Services.Replace(ServiceDescriptor.Singleton()); + return builder; + } + + /// + /// Use a logger + /// + /// + /// + /// + public static IMqttBuilder UseLogger(this IMqttBuilder builder, IMqttNetLogger logger) + { + ArgumentNullException.ThrowIfNull(logger); + builder.Services.Replace(ServiceDescriptor.Singleton(logger)); + return builder; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs b/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs new file mode 100644 index 000000000..ff96dd0fc --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttChannelAdapterExtensions.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Http; +using MQTTnet.Adapter; +using System; + +namespace MQTTnet.AspNetCore +{ + public static class MqttChannelAdapterExtensions + { + public static bool? IsWebSocketConnection(this IMqttChannelAdapter channelAdapter) + { + ArgumentNullException.ThrowIfNull(channelAdapter); + return channelAdapter is IAspNetCoreMqttChannel channel + ? channel.IsWebSocketConnection + : null; + } + + /// + /// Retrieves the requested feature from the feature collection of channelAdapter. + /// + /// + /// + /// + public static TFeature? GetFeature(this IMqttChannelAdapter channelAdapter) + { + ArgumentNullException.ThrowIfNull(channelAdapter); + return channelAdapter is IAspNetCoreMqttChannel channel + ? channel.GetFeature() + : default; + } + + /// + /// When the channelAdapter is a WebSocket connection, it can get an associated . + /// + /// + /// + public static HttpContext? GetHttpContext(this IMqttChannelAdapter channelAdapter) + { + ArgumentNullException.ThrowIfNull(channelAdapter); + return channelAdapter is IAspNetCoreMqttChannel channel + ? channel.HttpContext + : null; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs new file mode 100644 index 000000000..1c79663c4 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttClientBuilderExtensions.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using MQTTnet.Adapter; +using MQTTnet.Implementations; +using System; +using System.Diagnostics.CodeAnalysis; + +namespace MQTTnet.AspNetCore +{ + public static class MqttClientBuilderExtensions + { + /// + /// Replace the implementation of to + /// + /// + /// + public static IMqttClientBuilder UseMQTTnetMqttClientAdapterFactory(this IMqttClientBuilder builder) + { + return builder.UseMqttClientAdapterFactory(); + } + + /// + /// Replace the implementation of to + /// + /// + /// + public static IMqttClientBuilder UseAspNetCoreMqttClientAdapterFactory(this IMqttClientBuilder builder) + { + return builder.UseMqttClientAdapterFactory(); + } + + /// + /// Replace the implementation of to TMqttClientAdapterFactory + /// + /// + /// + /// + public static IMqttClientBuilder UseMqttClientAdapterFactory<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TMqttClientAdapterFactory>(this IMqttClientBuilder builder) + where TMqttClientAdapterFactory : class, IMqttClientAdapterFactory + { + builder.Services.Replace(ServiceDescriptor.Singleton()); + return builder; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttClientConnectionContextFactory.cs b/Source/MQTTnet.AspnetCore/MqttClientConnectionContextFactory.cs deleted file mode 100644 index 0ddbbd8f1..000000000 --- a/Source/MQTTnet.AspnetCore/MqttClientConnectionContextFactory.cs +++ /dev/null @@ -1,34 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using MQTTnet.Adapter; -using MQTTnet.Diagnostics.Logger; -using MQTTnet.Formatter; - -namespace MQTTnet.AspNetCore -{ - public sealed class MqttClientConnectionContextFactory : IMqttClientAdapterFactory - { - public IMqttChannelAdapter CreateClientAdapter(MqttClientOptions options, MqttPacketInspector packetInspector, IMqttNetLogger logger) - { - if (options == null) throw new ArgumentNullException(nameof(options)); - - switch (options.ChannelOptions) - { - case MqttClientTcpOptions tcpOptions: - { - var tcpConnection = new SocketConnection(tcpOptions.RemoteEndpoint); - - var formatter = new MqttPacketFormatterAdapter(options.ProtocolVersion, new MqttBufferWriter(4096, 65535)); - return new MqttConnectionContext(formatter, tcpConnection); - } - default: - { - throw new NotSupportedException(); - } - } - } - } -} diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs b/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs deleted file mode 100644 index 61184c4d5..000000000 --- a/Source/MQTTnet.AspnetCore/MqttConnectionContext.cs +++ /dev/null @@ -1,239 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Buffers; -using System.IO.Pipelines; -using System.Net; -using System.Security.Cryptography.X509Certificates; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Http.Connections.Features; -using Microsoft.AspNetCore.Http.Features; -using MQTTnet.Adapter; -using MQTTnet.Exceptions; -using MQTTnet.Formatter; -using MQTTnet.Internal; -using MQTTnet.Packets; - -namespace MQTTnet.AspNetCore; - -public sealed class MqttConnectionContext : IMqttChannelAdapter -{ - readonly ConnectionContext _connection; - readonly AsyncLock _writerLock = new(); - - PipeReader _input; - PipeWriter _output; - - public MqttConnectionContext(MqttPacketFormatterAdapter packetFormatterAdapter, ConnectionContext connection) - { - PacketFormatterAdapter = packetFormatterAdapter ?? throw new ArgumentNullException(nameof(packetFormatterAdapter)); - _connection = connection ?? throw new ArgumentNullException(nameof(connection)); - - if (!(_connection is SocketConnection tcp) || tcp.IsConnected) - { - _input = connection.Transport.Input; - _output = connection.Transport.Output; - } - } - - public long BytesReceived { get; private set; } - - public long BytesSent { get; private set; } - - public X509Certificate2 ClientCertificate - { - get - { - // mqtt over tcp - var tlsFeature = _connection.Features.Get(); - if (tlsFeature != null) - { - return tlsFeature.ClientCertificate; - } - - // mqtt over websocket - var httpFeature = _connection.Features.Get(); - return httpFeature?.HttpContext?.Connection.ClientCertificate; - } - } - - public EndPoint RemoteEndPoint - { - get - { - // mqtt over tcp - if (_connection.RemoteEndPoint != null) - { - return _connection.RemoteEndPoint; - } - - // mqtt over websocket - var httpFeature = _connection.Features.Get(); - if (httpFeature?.RemoteIpAddress != null) - { - return new IPEndPoint(httpFeature.RemoteIpAddress, httpFeature.RemotePort); - } - - return null; - } - } - - public bool IsSecureConnection - { - get - { - // mqtt over tcp - var tlsFeature = _connection.Features.Get(); - if (tlsFeature != null) - { - return true; - } - - // mqtt over websocket - var httpFeature = _connection.Features.Get(); - if (httpFeature?.HttpContext != null) - { - return httpFeature.HttpContext.Request.IsHttps; - } - - return false; - } - } - - public MqttPacketFormatterAdapter PacketFormatterAdapter { get; } - - public async Task ConnectAsync(CancellationToken cancellationToken) - { - if (_connection is SocketConnection tcp && !tcp.IsConnected) - { - await tcp.StartAsync().ConfigureAwait(false); - } - - _input = _connection.Transport.Input; - _output = _connection.Transport.Output; - } - - public Task DisconnectAsync(CancellationToken cancellationToken) - { - _input?.Complete(); - _output?.Complete(); - - return Task.CompletedTask; - } - - public void Dispose() - { - _writerLock.Dispose(); - } - - public async Task ReceivePacketAsync(CancellationToken cancellationToken) - { - try - { - while (!cancellationToken.IsCancellationRequested) - { - ReadResult readResult; - var readTask = _input.ReadAsync(cancellationToken); - if (readTask.IsCompleted) - { - readResult = readTask.Result; - } - else - { - readResult = await readTask.ConfigureAwait(false); - } - - var buffer = readResult.Buffer; - - var consumed = buffer.Start; - var observed = buffer.Start; - - try - { - if (!buffer.IsEmpty) - { - if (PacketFormatterAdapter.TryDecode(buffer, out var packet, out consumed, out observed, out var received)) - { - BytesReceived += received; - return packet; - } - } - else if (readResult.IsCompleted) - { - throw new MqttCommunicationException("Connection Aborted"); - } - } - finally - { - // The buffer was sliced up to where it was consumed, so we can just advance to the start. - // We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data - // before yielding the read again. - _input.AdvanceTo(consumed, observed); - } - } - } - catch (Exception exception) - { - // completing the channel makes sure that there is no more data read after a protocol error - _input?.Complete(exception); - _output?.Complete(exception); - - throw; - } - - cancellationToken.ThrowIfCancellationRequested(); - return null; - } - - public void ResetStatistics() - { - BytesReceived = 0; - BytesSent = 0; - } - - public async Task SendPacketAsync(MqttPacket packet, CancellationToken cancellationToken) - { - using (await _writerLock.EnterAsync(cancellationToken).ConfigureAwait(false)) - { - try - { - var buffer = PacketFormatterAdapter.Encode(packet); - - if (buffer.Payload.Length == 0) - { - // zero copy - // https://github.com/dotnet/runtime/blob/e31ddfdc4f574b26231233dc10c9a9c402f40590/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/StreamPipeWriter.cs#L279 - await _output.WriteAsync(buffer.Packet, cancellationToken).ConfigureAwait(false); - } - else - { - WritePacketBuffer(_output, buffer); - await _output.FlushAsync(cancellationToken).ConfigureAwait(false); - } - - BytesSent += buffer.Length; - } - finally - { - PacketFormatterAdapter.Cleanup(); - } - } - } - - static void WritePacketBuffer(PipeWriter output, MqttPacketBuffer buffer) - { - // copy MqttPacketBuffer's Packet and Payload to the same buffer block of PipeWriter - // MqttPacket will be transmitted within the bounds of a WebSocket frame after PipeWriter.FlushAsync - - var span = output.GetSpan(buffer.Length); - - buffer.Packet.AsSpan().CopyTo(span); - int offset = buffer.Packet.Count; - buffer.Payload.CopyTo(destination: span.Slice(offset)); - output.Advance(buffer.Length); - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs b/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs deleted file mode 100644 index b4cbc42a8..000000000 --- a/Source/MQTTnet.AspnetCore/MqttConnectionHandler.cs +++ /dev/null @@ -1,57 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Connections.Features; -using MQTTnet.Adapter; -using MQTTnet.Diagnostics.Logger; -using MQTTnet.Formatter; -using MQTTnet.Server; - -namespace MQTTnet.AspNetCore; - -public sealed class MqttConnectionHandler : ConnectionHandler, IMqttServerAdapter -{ - MqttServerOptions _serverOptions; - - public Func ClientHandler { get; set; } - - public void Dispose() - { - } - - public override async Task OnConnectedAsync(ConnectionContext connection) - { - // required for websocket transport to work - var transferFormatFeature = connection.Features.Get(); - if (transferFormatFeature != null) - { - transferFormatFeature.ActiveFormat = TransferFormat.Binary; - } - - var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax)); - using (var adapter = new MqttConnectionContext(formatter, connection)) - { - var clientHandler = ClientHandler; - if (clientHandler != null) - { - await clientHandler(adapter).ConfigureAwait(false); - } - } - } - - public Task StartAsync(MqttServerOptions options, IMqttNetLogger logger) - { - _serverOptions = options; - - return Task.CompletedTask; - } - - public Task StopAsync() - { - return Task.CompletedTask; - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MqttHostedServer.cs b/Source/MQTTnet.AspnetCore/MqttHostedServer.cs deleted file mode 100644 index 4c74f6a43..000000000 --- a/Source/MQTTnet.AspnetCore/MqttHostedServer.cs +++ /dev/null @@ -1,48 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.Hosting; -using MQTTnet.Diagnostics.Logger; -using MQTTnet.Server; - -namespace MQTTnet.AspNetCore; - -public sealed class MqttHostedServer : MqttServer, IHostedService -{ - readonly IHostApplicationLifetime _hostApplicationLifetime; - readonly MqttServerFactory _mqttFactory; - - public MqttHostedServer( - IHostApplicationLifetime hostApplicationLifetime, - MqttServerFactory mqttFactory, - MqttServerOptions options, - IEnumerable adapters, - IMqttNetLogger logger) : base(options, adapters, logger) - { - _mqttFactory = mqttFactory ?? throw new ArgumentNullException(nameof(mqttFactory)); - _hostApplicationLifetime = hostApplicationLifetime; - } - - public async Task StartAsync(CancellationToken cancellationToken) - { - // The yield makes sure that the hosted service is considered up and running. - await Task.Yield(); - - _hostApplicationLifetime.ApplicationStarted.Register(OnStarted); - } - - public Task StopAsync(CancellationToken cancellationToken) - { - return StopAsync(_mqttFactory.CreateMqttServerStopOptionsBuilder().Build()); - } - - void OnStarted() - { - _ = StartAsync(); - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MqttProtocols.cs b/Source/MQTTnet.AspnetCore/MqttProtocols.cs new file mode 100644 index 000000000..f1701445c --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttProtocols.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace MQTTnet.AspNetCore +{ + public enum MqttProtocols + { + /// + /// Only support Mqtt + /// + Mqtt, + + /// + /// Only support Mqtt-over-WebSocket + /// + WebSocket, + + /// + /// Support both Mqtt and Mqtt-over-WebSocket + /// + MqttAndWebSocket + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs new file mode 100644 index 000000000..8dbce7773 --- /dev/null +++ b/Source/MQTTnet.AspnetCore/MqttServerBuilderExtensions.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using MQTTnet.Formatter; +using MQTTnet.Server; +using System; +using System.Diagnostics.CodeAnalysis; + +namespace MQTTnet.AspNetCore +{ + public static class MqttServerBuilderExtensions + { + /// + /// Configure + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttServer(this IMqttServerBuilder builder, Action builderConfigure) + { + builder.Services.Configure(builderConfigure); + return builder; + } + + /// + /// Configure and + /// + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttServer(this IMqttServerBuilder builder, Action builderConfigure, Action optionsConfigure) + { + builder.Services.Configure(builderConfigure).Configure(optionsConfigure); + return builder; + } + + /// + /// Configure + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder builder, Action builderConfigure) + { + builder.Services.Configure(builderConfigure); + return builder; + } + + /// + /// Configure and + /// + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttServerStop(this IMqttServerBuilder builder, Action builderConfigure, Action optionsConfigure) + { + builder.Services.Configure(builderConfigure).Configure(optionsConfigure); + return builder; + } + + /// + /// Configure the pool of + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttBufferWriterPool(this IMqttServerBuilder builder, Action configure) + { + builder.Services.Configure(configure); + return builder; + } + + /// + /// Configure the socket of mqtt listener + /// + /// + /// + /// + public static IMqttServerBuilder ConfigureMqttSocketTransport(this IMqttServerBuilder builder, Action configure) + { + builder.Services.Configure(configure); + return builder; + } + + /// + /// Add an to + /// + /// + /// + /// + public static IMqttServerBuilder AddMqttServerAdapter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TMqttServerAdapter>(this IMqttServerBuilder builder) + where TMqttServerAdapter : class, IMqttServerAdapter + { + builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton()); + return builder; + } + } +} diff --git a/Source/MQTTnet.AspnetCore/MqttSubProtocolSelector.cs b/Source/MQTTnet.AspnetCore/MqttSubProtocolSelector.cs deleted file mode 100644 index c6acdfa8e..000000000 --- a/Source/MQTTnet.AspnetCore/MqttSubProtocolSelector.cs +++ /dev/null @@ -1,34 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.AspNetCore.Http; - -namespace MQTTnet.AspNetCore; - -public static class MqttSubProtocolSelector -{ - public static string SelectSubProtocol(HttpRequest request) - { - ArgumentNullException.ThrowIfNull(request); - - string subProtocol = null; - if (request.Headers.TryGetValue("Sec-WebSocket-Protocol", out var requestedSubProtocolValues)) - { - subProtocol = SelectSubProtocol(requestedSubProtocolValues); - } - - return subProtocol; - } - - public static string SelectSubProtocol(IList requestedSubProtocolValues) - { - ArgumentNullException.ThrowIfNull(requestedSubProtocolValues); - - // Order the protocols to also match "mqtt", "mqttv-3.1", "mqttv-3.11" etc. - return requestedSubProtocolValues.OrderByDescending(p => p.Length).FirstOrDefault(p => p.ToLower().StartsWith("mqtt")); - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs b/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs deleted file mode 100644 index 18a382a12..000000000 --- a/Source/MQTTnet.AspnetCore/MqttWebSocketServerAdapter.cs +++ /dev/null @@ -1,68 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Net; -using System.Net.WebSockets; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; -using MQTTnet.Adapter; -using MQTTnet.Diagnostics.Logger; -using MQTTnet.Formatter; -using MQTTnet.Implementations; -using MQTTnet.Server; - -namespace MQTTnet.AspNetCore; - -public sealed class MqttWebSocketServerAdapter : IMqttServerAdapter -{ - IMqttNetLogger _logger = MqttNetNullLogger.Instance; - - public Func ClientHandler { get; set; } - - public void Dispose() - { - } - - public async Task RunWebSocketConnectionAsync(WebSocket webSocket, HttpContext httpContext) - { - ArgumentNullException.ThrowIfNull(webSocket); - - var remoteAddress = httpContext.Connection.RemoteIpAddress; - var remoteEndPoint = remoteAddress == null ? null : new IPEndPoint(remoteAddress, httpContext.Connection.RemotePort); - - var clientCertificate = await httpContext.Connection.GetClientCertificateAsync().ConfigureAwait(false); - try - { - var isSecureConnection = clientCertificate != null; - - var clientHandler = ClientHandler; - if (clientHandler != null) - { - var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(4096, 65535)); - var channel = new MqttWebSocketChannel(webSocket, remoteEndPoint, isSecureConnection, clientCertificate); - - using (var channelAdapter = new MqttChannelAdapter(channel, formatter, _logger)) - { - await clientHandler(channelAdapter).ConfigureAwait(false); - } - } - } - finally - { - clientCertificate?.Dispose(); - } - } - - public Task StartAsync(MqttServerOptions options, IMqttNetLogger logger) - { - _logger = logger ?? throw new ArgumentNullException(nameof(logger)); - return Task.CompletedTask; - } - - public Task StopAsync() - { - return Task.CompletedTask; - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs index 915f6791c..a028f1b65 100644 --- a/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs +++ b/Source/MQTTnet.AspnetCore/ServiceCollectionExtensions.cs @@ -2,108 +2,77 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; -using Microsoft.Extensions.Hosting; +using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Server; -using MQTTnet.Server.Internal.Adapter; +using System; namespace MQTTnet.AspNetCore; public static class ServiceCollectionExtensions { - public static IServiceCollection AddHostedMqttServer(this IServiceCollection services, MqttServerOptions options) - { - ArgumentNullException.ThrowIfNull(services); - ArgumentNullException.ThrowIfNull(options); - - services.AddSingleton(options); - services.AddHostedMqttServer(); - - return services; - } - - public static IServiceCollection AddHostedMqttServer(this IServiceCollection services, Action configure) - { - ArgumentNullException.ThrowIfNull(services); - - var serverOptionsBuilder = new MqttServerOptionsBuilder(); - - configure?.Invoke(serverOptionsBuilder); - - var options = serverOptionsBuilder.Build(); - - return AddHostedMqttServer(services, options); - } - - public static void AddHostedMqttServer(this IServiceCollection services) - { - // The user may have these services already registered. - services.TryAddSingleton(MqttNetNullLogger.Instance); - services.TryAddSingleton(new MqttServerFactory()); - - services.AddSingleton(); - services.AddSingleton(s => s.GetService()); - services.AddSingleton(s => s.GetService()); - } - - public static IServiceCollection AddHostedMqttServerWithServices(this IServiceCollection services, Action configure) + /// + /// Register as a singleton service + /// + /// + /// + /// + public static IMqttServerBuilder AddMqttServer(this IServiceCollection services, Action configure) { - ArgumentNullException.ThrowIfNull(services); - - services.AddSingleton( - s => - { - var builder = new AspNetMqttServerOptionsBuilder(s); - configure(builder); - return builder.Build(); - }); - - services.AddHostedMqttServer(); - - return services; + services.Configure(configure); + return services.AddMqttServer(); } - public static IServiceCollection AddMqttConnectionHandler(this IServiceCollection services) + /// + /// Register as a singleton service + /// + /// + /// + public static IMqttServerBuilder AddMqttServer(this IServiceCollection services) { - services.AddSingleton(); - services.AddSingleton(s => s.GetService()); - - return services; + services.AddOptions(); + services.AddConnections(); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(s => s.GetRequiredService().CreateOptions()); + services.TryAddSingleton(s => s.GetRequiredService().CreateOptions()); + services.TryAddEnumerable(ServiceDescriptor.Singleton()); + + services.TryAddSingleton(); + services.AddHostedService(); + services.TryAddSingleton(s => s.GetRequiredService()); + + return services.AddMqtt(); } - public static void AddMqttLogger(this IServiceCollection services, IMqttNetLogger logger) + /// + /// Register and as singleton service + /// + /// + /// + public static IMqttClientBuilder AddMqttClient(this IServiceCollection services) { - ArgumentNullException.ThrowIfNull(services); - - services.AddSingleton(logger); + services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(s => s.GetRequiredService()); + services.TryAddSingleton(s => s.GetRequiredService()); + return services.AddMqtt(); } - public static IServiceCollection AddMqttServer(this IServiceCollection serviceCollection, Action configure = null) + private static MqttBuilder AddMqtt(this IServiceCollection services) { - ArgumentNullException.ThrowIfNull(serviceCollection); - - serviceCollection.AddMqttConnectionHandler(); - serviceCollection.AddHostedMqttServer(configure); - - return serviceCollection; + services.AddLogging(); + services.TryAddSingleton(); + return new MqttBuilder(services); } - public static IServiceCollection AddMqttTcpServerAdapter(this IServiceCollection services) + private class MqttBuilder(IServiceCollection services) : IMqttServerBuilder, IMqttClientBuilder { - services.AddSingleton(); - services.AddSingleton(s => s.GetService()); - - return services; - } - - public static IServiceCollection AddMqttWebSocketServerAdapter(this IServiceCollection services) - { - services.AddSingleton(); - services.AddSingleton(s => s.GetService()); - - return services; + public IServiceCollection Services { get; } = services; } } \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/SocketAwaitable.cs b/Source/MQTTnet.AspnetCore/SocketAwaitable.cs deleted file mode 100644 index 2c9607279..000000000 --- a/Source/MQTTnet.AspnetCore/SocketAwaitable.cs +++ /dev/null @@ -1,77 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Diagnostics; -using System.IO.Pipelines; -using System.Net.Sockets; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Tasks; - -namespace MQTTnet.AspNetCore; - -public class SocketAwaitable : ICriticalNotifyCompletion -{ - static readonly Action _callbackCompleted = () => - { - }; - - readonly PipeScheduler _ioScheduler; - int _bytesTransferred; - - Action _callback; - SocketError _error; - - public SocketAwaitable(PipeScheduler ioScheduler) - { - _ioScheduler = ioScheduler; - } - - public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted); - - public void Complete(int bytesTransferred, SocketError socketError) - { - _error = socketError; - _bytesTransferred = bytesTransferred; - var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted); - - if (continuation != null) - { - _ioScheduler.Schedule(state => ((Action)state)(), continuation); - } - } - - public SocketAwaitable GetAwaiter() - { - return this; - } - - public int GetResult() - { - Debug.Assert(ReferenceEquals(_callback, _callbackCompleted)); - - _callback = null; - - if (_error != SocketError.Success) - { - throw new SocketException((int)_error); - } - - return _bytesTransferred; - } - - public void OnCompleted(Action continuation) - { - if (ReferenceEquals(_callback, _callbackCompleted) || ReferenceEquals(Interlocked.CompareExchange(ref _callback, continuation, null), _callbackCompleted)) - { - Task.Run(continuation); - } - } - - public void UnsafeOnCompleted(Action continuation) - { - OnCompleted(continuation); - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/SocketConnection.cs b/Source/MQTTnet.AspnetCore/SocketConnection.cs deleted file mode 100644 index 2021eccac..000000000 --- a/Source/MQTTnet.AspnetCore/SocketConnection.cs +++ /dev/null @@ -1,261 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.IO; -using System.IO.Pipelines; -using System.Net; -using System.Net.Sockets; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Http.Features; -using MQTTnet.Exceptions; - -namespace MQTTnet.AspNetCore; - -public sealed class SocketConnection : ConnectionContext -{ - readonly EndPoint _endPoint; - volatile bool _aborted; - IDuplexPipe _application; - SocketReceiver _receiver; - SocketSender _sender; - - Socket _socket; - - public SocketConnection(EndPoint endPoint) - { - _endPoint = endPoint; - } - - public SocketConnection(Socket socket) - { - _socket = socket; - _endPoint = socket.RemoteEndPoint; - - _sender = new SocketSender(_socket, PipeScheduler.ThreadPool); - _receiver = new SocketReceiver(_socket, PipeScheduler.ThreadPool); - } - - public override string ConnectionId { get; set; } - public override IFeatureCollection Features { get; } - - public bool IsConnected { get; private set; } - public override IDictionary Items { get; set; } - public override IDuplexPipe Transport { get; set; } - - public override ValueTask DisposeAsync() - { - IsConnected = false; - - Transport?.Output.Complete(); - Transport?.Input.Complete(); - - _socket?.Dispose(); - - return base.DisposeAsync(); - } - - public async Task StartAsync() - { - if (_socket == null) - { - _socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); - _sender = new SocketSender(_socket, PipeScheduler.ThreadPool); - _receiver = new SocketReceiver(_socket, PipeScheduler.ThreadPool); - await _socket.ConnectAsync(_endPoint); - } - - var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); - - Transport = pair.Transport; - _application = pair.Application; - - _ = ExecuteAsync(); - - IsConnected = true; - } - - Exception ConnectionAborted() - { - return new MqttCommunicationException("Connection Aborted"); - } - - async Task DoReceive() - { - Exception error = null; - - try - { - await ProcessReceives(); - } - catch (SocketException ex) when (ex.SocketErrorCode == SocketError.ConnectionReset) - { - error = new MqttCommunicationException(ex); - } - catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted || ex.SocketErrorCode == SocketError.ConnectionAborted || - ex.SocketErrorCode == SocketError.Interrupted || ex.SocketErrorCode == SocketError.InvalidArgument) - { - if (!_aborted) - { - // Calling Dispose after ReceiveAsync can cause an "InvalidArgument" error on *nix. - error = ConnectionAborted(); - } - } - catch (ObjectDisposedException) - { - if (!_aborted) - { - error = ConnectionAborted(); - } - } - catch (IOException ex) - { - error = ex; - } - catch (Exception ex) - { - error = new IOException(ex.Message, ex); - } - finally - { - if (_aborted) - { - error = error ?? ConnectionAborted(); - } - - _application.Output.Complete(error); - } - } - - async Task DoSend() - { - Exception error = null; - - try - { - await ProcessSends(); - } - catch (SocketException ex) when (ex.SocketErrorCode == SocketError.OperationAborted) - { - } - catch (ObjectDisposedException) - { - } - catch (IOException ex) - { - error = ex; - } - catch (Exception ex) - { - error = new IOException(ex.Message, ex); - } - finally - { - _aborted = true; - _socket.Shutdown(SocketShutdown.Both); - } - - return error; - } - - async Task ExecuteAsync() - { - Exception sendError = null; - try - { - // Spawn send and receive logic - var receiveTask = DoReceive(); - var sendTask = DoSend(); - - // If the sending task completes then close the receive - // We don't need to do this in the other direction because the kestrel - // will trigger the output closing once the input is complete. - if (await Task.WhenAny(receiveTask, sendTask).ConfigureAwait(false) == sendTask) - { - // Tell the reader it's being aborted - _socket.Dispose(); - } - - // Now wait for both to complete - await receiveTask; - sendError = await sendTask; - - // Dispose the socket(should noop if already called) - _socket.Dispose(); - } - catch (Exception ex) - { - Console.WriteLine($"Unexpected exception in {nameof(SocketConnection)}.{nameof(StartAsync)}: " + ex); - } - finally - { - // Complete the output after disposing the socket - await _application.Input.CompleteAsync(sendError).ConfigureAwait(false); - } - } - - async Task ProcessReceives() - { - while (true) - { - // Ensure we have some reasonable amount of buffer space - var buffer = _application.Output.GetMemory(); - - var bytesReceived = await _receiver.ReceiveAsync(buffer); - - if (bytesReceived == 0) - { - // FIN - break; - } - - _application.Output.Advance(bytesReceived); - - var flushTask = _application.Output.FlushAsync(); - - if (!flushTask.IsCompleted) - { - await flushTask; - } - - var result = flushTask.GetAwaiter().GetResult(); - if (result.IsCompleted) - { - // Pipe consumer is shut down, do we stop writing - break; - } - } - } - - async Task ProcessSends() - { - while (true) - { - // Wait for data to write from the pipe producer - var result = await _application.Input.ReadAsync(); - var buffer = result.Buffer; - - if (result.IsCanceled) - { - break; - } - - var end = buffer.End; - var isCompleted = result.IsCompleted; - if (!buffer.IsEmpty) - { - await _sender.SendAsync(buffer); - } - - _application.Input.AdvanceTo(end); - - if (isCompleted) - { - break; - } - } - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/SocketReceiver.cs b/Source/MQTTnet.AspnetCore/SocketReceiver.cs deleted file mode 100644 index f8b628fb5..000000000 --- a/Source/MQTTnet.AspnetCore/SocketReceiver.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.IO.Pipelines; -using System.Net.Sockets; - -namespace MQTTnet.AspNetCore; - -public sealed class SocketReceiver -{ - readonly SocketAwaitable _awaitable; - readonly SocketAsyncEventArgs _eventArgs = new(); - readonly Socket _socket; - - public SocketReceiver(Socket socket, PipeScheduler scheduler) - { - _socket = socket; - _awaitable = new SocketAwaitable(scheduler); - _eventArgs.UserToken = _awaitable; - _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); - } - - public SocketAwaitable ReceiveAsync(Memory buffer) - { - _eventArgs.SetBuffer(buffer); - - if (!_socket.ReceiveAsync(_eventArgs)) - { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); - } - - return _awaitable; - } -} \ No newline at end of file diff --git a/Source/MQTTnet.AspnetCore/SocketSender.cs b/Source/MQTTnet.AspnetCore/SocketSender.cs deleted file mode 100644 index fc06ea6cf..000000000 --- a/Source/MQTTnet.AspnetCore/SocketSender.cs +++ /dev/null @@ -1,93 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Buffers; -using System.Collections.Generic; -using System.Diagnostics; -using System.IO.Pipelines; -using System.Net.Sockets; -using System.Runtime.InteropServices; - -namespace MQTTnet.AspNetCore; - -public sealed class SocketSender -{ - readonly SocketAwaitable _awaitable; - readonly SocketAsyncEventArgs _eventArgs = new(); - readonly Socket _socket; - - List> _bufferList; - - public SocketSender(Socket socket, PipeScheduler scheduler) - { - _socket = socket; - _awaitable = new SocketAwaitable(scheduler); - _eventArgs.UserToken = _awaitable; - _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); - } - - public SocketAwaitable SendAsync(in ReadOnlySequence buffers) - { - if (buffers.IsSingleSegment) - { - return SendAsync(buffers.First); - } - - if (!_eventArgs.MemoryBuffer.Equals(Memory.Empty)) - { - _eventArgs.SetBuffer(null, 0, 0); - } - - _eventArgs.BufferList = GetBufferList(buffers); - - if (!_socket.SendAsync(_eventArgs)) - { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); - } - - return _awaitable; - } - - List> GetBufferList(in ReadOnlySequence buffer) - { - Debug.Assert(!buffer.IsEmpty); - Debug.Assert(!buffer.IsSingleSegment); - - if (_bufferList == null) - { - _bufferList = new List>(); - } - else - { - // Buffers are pooled, so it's OK to root them until the next multi-buffer write. - _bufferList.Clear(); - } - - foreach (var b in buffer) - { - _bufferList.Add(b.GetArray()); - } - - return _bufferList; - } - - SocketAwaitable SendAsync(ReadOnlyMemory memory) - { - // The BufferList getter is much less expensive then the setter. - if (_eventArgs.BufferList != null) - { - _eventArgs.BufferList = null; - } - - _eventArgs.SetBuffer(MemoryMarshal.AsMemory(memory)); - - if (!_socket.SendAsync(_eventArgs)) - { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); - } - - return _awaitable; - } -} \ No newline at end of file diff --git a/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs b/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs index 27a348d4f..96b40c568 100644 --- a/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/AsyncLockBenchmark.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class AsyncLockBenchmark : BaseBenchmark { diff --git a/Source/MQTTnet.Benchmarks/LoggerBenchmark.cs b/Source/MQTTnet.Benchmarks/LoggerBenchmark.cs index 6b8b5e410..45e536e74 100644 --- a/Source/MQTTnet.Benchmarks/LoggerBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/LoggerBenchmark.cs @@ -8,7 +8,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter] [MemoryDiagnoser] public class LoggerBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj b/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj index d50ca5cd9..b206fe801 100644 --- a/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj +++ b/Source/MQTTnet.Benchmarks/MQTTnet.Benchmarks.csproj @@ -14,7 +14,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Benchmarks/MemoryCopyBenchmark.cs b/Source/MQTTnet.Benchmarks/MemoryCopyBenchmark.cs index 0733e2bb9..012f8d847 100644 --- a/Source/MQTTnet.Benchmarks/MemoryCopyBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MemoryCopyBenchmark.cs @@ -5,7 +5,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class MemoryCopyBenchmark diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingAspNetCoreBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingAspNetCoreBenchmark.cs new file mode 100644 index 000000000..6d8e85d4d --- /dev/null +++ b/Source/MQTTnet.Benchmarks/MessageProcessingAspNetCoreBenchmark.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Jobs; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using MQTTnet.AspNetCore; +using System.Threading.Tasks; + +namespace MQTTnet.Benchmarks; + +[SimpleJob(RuntimeMoniker.Net80)] +[RPlotExporter] +[RankColumn] +[MemoryDiagnoser] +public class MessageProcessingAspNetCoreBenchmark : BaseBenchmark +{ + IMqttClient _mqttClient; + string _payload = string.Empty; + + [Params(1 * 1024, 4 * 1024, 8 * 1024)] + public int PayloadSize { get; set; } + + [Benchmark] + public async Task Send_1000_Messages_AspNetCore() + { + for (var i = 0; i < 1000; i++) + { + await _mqttClient.PublishStringAsync("A", _payload); + } + } + + [GlobalSetup] + public async Task Setup() + { + var builder = WebApplication.CreateBuilder(); + + builder.Services.AddMqttServer(s => s.WithDefaultEndpoint()); + builder.Services.AddMqttClient(); + builder.WebHost.UseKestrel(k => k.ListenMqtt()); + + var app = builder.Build(); + await app.StartAsync(); + + _mqttClient = app.Services.GetRequiredService().CreateMqttClient(); + var clientOptions = new MqttClientOptionsBuilder() + .WithTcpServer("localhost") + .Build(); + + await _mqttClient.ConnectAsync(clientOptions); + + _payload = string.Empty.PadLeft(PayloadSize, '0'); + } +} \ No newline at end of file diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs index 894ef19e5..cca6e4804 100644 --- a/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MessageProcessingBenchmark.cs @@ -8,7 +8,7 @@ namespace MQTTnet.Benchmarks; -[SimpleJob(RuntimeMoniker.Net60)] +[SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter] [RankColumn] [MemoryDiagnoser] diff --git a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs b/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs deleted file mode 100644 index b22d365e8..000000000 --- a/Source/MQTTnet.Benchmarks/MessageProcessingMqttConnectionContextBenchmark.cs +++ /dev/null @@ -1,73 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Jobs; -using Microsoft.AspNetCore; -using Microsoft.AspNetCore.Hosting; -using MQTTnet.AspNetCore; -using MQTTnet.Diagnostics.Logger; - -namespace MQTTnet.Benchmarks -{ - [SimpleJob(RuntimeMoniker.Net60)] - [MemoryDiagnoser] - public class MessageProcessingMqttConnectionContextBenchmark : BaseBenchmark - { - IWebHost _host; - IMqttClient _mqttClient; - MqttApplicationMessage _message; - - [GlobalSetup] - public void Setup() - { - _host = WebHost.CreateDefaultBuilder() - .UseKestrel(o => o.ListenAnyIP(1883, l => l.UseMqtt())) - .ConfigureServices(services => { - services - .AddHostedMqttServer(mqttServerOptions => mqttServerOptions.WithoutDefaultEndpoint()) - .AddMqttConnectionHandler(); - }) - .Configure(app => { - app.UseMqttServer(s => { - - }); - }) - .Build(); - - var factory = new MqttClientFactory(); - _mqttClient = factory.CreateMqttClient(new MqttNetEventLogger(), new MqttClientConnectionContextFactory()); - - _host.StartAsync().GetAwaiter().GetResult(); - - var clientOptions = new MqttClientOptionsBuilder() - .WithTcpServer("localhost").Build(); - - _mqttClient.ConnectAsync(clientOptions).GetAwaiter().GetResult(); - - _message = new MqttApplicationMessageBuilder() - .WithTopic("A") - .Build(); - } - - [GlobalCleanup] - public void Cleanup() - { - _mqttClient.DisconnectAsync().GetAwaiter().GetResult(); - _mqttClient.Dispose(); - - _host.StopAsync().GetAwaiter().GetResult(); - _host.Dispose(); - } - - [Benchmark] - public void Send_10000_Messages() - { - for (var i = 0; i < 10000; i++) - { - _mqttClient.PublishAsync(_message).GetAwaiter().GetResult(); - } - } - } -} diff --git a/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs b/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs index bfa3d209c..c8529535c 100644 --- a/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MqttBufferReaderBenchmark.cs @@ -10,7 +10,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class MqttBufferReaderBenchmark { diff --git a/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs b/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs index 0efc7ffac..4190d4bfb 100644 --- a/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MqttPacketReaderWriterBenchmark.cs @@ -10,7 +10,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class MqttPacketReaderWriterBenchmark : BaseBenchmark { diff --git a/Source/MQTTnet.Benchmarks/MqttTcpChannelBenchmark.cs b/Source/MQTTnet.Benchmarks/MqttTcpChannelBenchmark.cs index 613647471..95257ef21 100644 --- a/Source/MQTTnet.Benchmarks/MqttTcpChannelBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/MqttTcpChannelBenchmark.cs @@ -17,7 +17,7 @@ namespace MQTTnet.Benchmarks; -[SimpleJob(RuntimeMoniker.Net60)] +[SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class MqttTcpChannelBenchmark : BaseBenchmark { diff --git a/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs b/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs index 5f2242461..3a65dfa99 100644 --- a/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/ReaderExtensionsBenchmark.cs @@ -14,7 +14,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class ReaderExtensionsBenchmark @@ -116,7 +116,7 @@ public async Task After() { if (!buffer.IsEmpty) { - if (ReaderExtensions.TryDecode(mqttPacketFormatter, buffer, out var packet, out consumed, out observed, out var received)) + if (MqttPacketFormatterAdapterExtensions.TryDecode(mqttPacketFormatter, buffer, null, out var packet, out consumed, out observed, out var received)) { break; } diff --git a/Source/MQTTnet.Benchmarks/RoundtripProcessingBenchmark.cs b/Source/MQTTnet.Benchmarks/RoundtripProcessingBenchmark.cs index e3358fb91..69fedda41 100644 --- a/Source/MQTTnet.Benchmarks/RoundtripProcessingBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/RoundtripProcessingBenchmark.cs @@ -5,7 +5,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class RoundtripProcessingBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs b/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs index b31782e66..a4e7d05d6 100644 --- a/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/SendPacketAsyncBenchmark.cs @@ -9,7 +9,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class SendPacketAsyncBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs b/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs index 48117232c..d1638644c 100644 --- a/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/SerializerBenchmark.cs @@ -19,7 +19,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter] [MemoryDiagnoser] public class SerializerBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Benchmarks/ServerProcessingBenchmark.cs b/Source/MQTTnet.Benchmarks/ServerProcessingBenchmark.cs index fbac6dc02..f2e582af7 100644 --- a/Source/MQTTnet.Benchmarks/ServerProcessingBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/ServerProcessingBenchmark.cs @@ -9,7 +9,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter, RankColumn] [MemoryDiagnoser] public class ServerProcessingBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Benchmarks/TcpPipesBenchmark.cs b/Source/MQTTnet.Benchmarks/TcpPipesBenchmark.cs index 7692f78b3..4111be6e8 100644 --- a/Source/MQTTnet.Benchmarks/TcpPipesBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/TcpPipesBenchmark.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Jobs; using System.IO.Pipelines; using System.Net; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; -using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Jobs; -using MQTTnet.AspNetCore; namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [MemoryDiagnoser] public class TcpPipesBenchmark : BaseBenchmark { @@ -29,17 +28,17 @@ public void Setup() var task = Task.Run(() => server.AcceptSocket()); - var clientConnection = new SocketConnection(new IPEndPoint(IPAddress.Loopback, 1883)); + var clientConnection = new Socket(SocketType.Stream, ProtocolType.Tcp); + clientConnection.Connect(new IPEndPoint(IPAddress.Loopback, 1883)); + _client = new SocketDuplexPipe(clientConnection); - clientConnection.StartAsync().GetAwaiter().GetResult(); - _client = clientConnection.Transport; - - var serverConnection = new SocketConnection(task.GetAwaiter().GetResult()); - serverConnection.StartAsync().GetAwaiter().GetResult(); - _server = serverConnection.Transport; + var serverConnection =task.GetAwaiter().GetResult(); + _server = new SocketDuplexPipe(serverConnection); } + + [Benchmark] public async Task Send_10000_Chunks_Pipe() { @@ -76,5 +75,19 @@ async Task WriteAsync(int iterations, int size) await output.WriteAsync(new byte[size], CancellationToken.None).ConfigureAwait(false); } } + + private class SocketDuplexPipe : IDuplexPipe + { + public PipeReader Input { get; } + + public PipeWriter Output { get; } + + public SocketDuplexPipe(Socket socket) + { + var stream = new NetworkStream(socket); + this.Input = PipeReader.Create(stream); + this.Output = PipeWriter.Create(stream); + } + } } } diff --git a/Source/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs b/Source/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs index f78cb81c3..be8f2e7c8 100644 --- a/Source/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs +++ b/Source/MQTTnet.Benchmarks/TopicFilterComparerBenchmark.cs @@ -8,7 +8,7 @@ namespace MQTTnet.Benchmarks { - [SimpleJob(RuntimeMoniker.Net60)] + [SimpleJob(RuntimeMoniker.Net80)] [RPlotExporter] [MemoryDiagnoser] public class TopicFilterComparerBenchmark : BaseBenchmark diff --git a/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj b/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj index b38fb489d..f4353d21a 100644 --- a/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj +++ b/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj @@ -35,7 +35,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Server/MQTTnet.Server.csproj b/Source/MQTTnet.Server/MQTTnet.Server.csproj index df4863607..094fb10f8 100644 --- a/Source/MQTTnet.Server/MQTTnet.Server.csproj +++ b/Source/MQTTnet.Server/MQTTnet.Server.csproj @@ -36,7 +36,7 @@ low enable disable - latest-Recommended + diff --git a/Source/MQTTnet.TestApp/MQTTnet.TestApp.csproj b/Source/MQTTnet.TestApp/MQTTnet.TestApp.csproj index 374ded794..dec52e679 100644 --- a/Source/MQTTnet.TestApp/MQTTnet.TestApp.csproj +++ b/Source/MQTTnet.TestApp/MQTTnet.TestApp.csproj @@ -13,7 +13,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs b/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs index f45284cc2..8e2e73020 100644 --- a/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs +++ b/Source/MQTTnet.Tests/ASP/Mockups/ConnectionHandlerMockup.cs @@ -5,6 +5,7 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Connections; using MQTTnet.Adapter; using MQTTnet.AspNetCore; using MQTTnet.Diagnostics.Logger; @@ -16,7 +17,7 @@ namespace MQTTnet.Tests.ASP.Mockups; public sealed class ConnectionHandlerMockup : IMqttServerAdapter { public Func ClientHandler { get; set; } - public TaskCompletionSource Context { get; } = new(); + TaskCompletionSource Context { get; } = new(); public void Dispose() { @@ -27,7 +28,7 @@ public async Task OnConnectedAsync(ConnectionContext connection) try { var formatter = new MqttPacketFormatterAdapter(new MqttBufferWriter(4096, 65535)); - var context = new MqttConnectionContext(formatter, connection); + var context = new MqttServerChannelAdapter(formatter, connection, connection.GetHttpContext()); Context.TrySetResult(context); await ClientHandler(context); diff --git a/Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs b/Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs new file mode 100644 index 000000000..8adb9e771 --- /dev/null +++ b/Source/MQTTnet.Tests/ASP/MqttBufferWriterPoolTest.cs @@ -0,0 +1,49 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.AspNetCore; +using MQTTnet.Formatter; +using MQTTnet.Server; +using System; +using System.Threading.Tasks; + +namespace MQTTnet.Tests.ASP +{ + [TestClass] + public class MqttBufferWriterPoolTest + { + [TestMethod] + public async Task RentReturnTest() + { + var services = new ServiceCollection(); + services.AddMqttServer().ConfigureMqttBufferWriterPool(p => + { + p.MaxLifetime = TimeSpan.FromSeconds(1d); + }); + + var s = services.BuildServiceProvider(); + var pool = s.GetRequiredService(); + var options = s.GetRequiredService(); + + var bufferWriter = pool.Rent(); + Assert.AreEqual(0, pool.Count); + + Assert.IsTrue(pool.Return(bufferWriter)); + Assert.AreEqual(1, pool.Count); + + bufferWriter = pool.Rent(); + Assert.AreEqual(0, pool.Count); + + await Task.Delay(TimeSpan.FromSeconds(2d)); + + Assert.IsFalse(pool.Return(bufferWriter)); + Assert.AreEqual(0, pool.Count); + + MqttBufferWriter writer = bufferWriter; + writer.Seek(options.WriterBufferSize + 1); + Assert.IsTrue(bufferWriter.BufferSize > options.WriterBufferSize); + + Assert.IsTrue(pool.Return(bufferWriter)); + Assert.AreEqual(1, pool.Count); + } + } +} diff --git a/Source/MQTTnet.Tests/ASP/MqttBuilderTest.cs b/Source/MQTTnet.Tests/ASP/MqttBuilderTest.cs new file mode 100644 index 000000000..a8e4c00a5 --- /dev/null +++ b/Source/MQTTnet.Tests/ASP/MqttBuilderTest.cs @@ -0,0 +1,31 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.AspNetCore; +using MQTTnet.Diagnostics.Logger; + +namespace MQTTnet.Tests.ASP +{ + [TestClass] + public class MqttBuilderTest + { + [TestMethod] + public void UseMqttNetNullLoggerTest() + { + var services = new ServiceCollection(); + services.AddMqttServer().UseMqttNetNullLogger(); + var s = services.BuildServiceProvider(); + var logger = s.GetRequiredService(); + Assert.IsInstanceOfType(logger); + } + + [TestMethod] + public void UseAspNetCoreMqttNetLoggerTest() + { + var services = new ServiceCollection(); + services.AddMqttServer().UseAspNetCoreMqttNetLogger(); + var s = services.BuildServiceProvider(); + var logger = s.GetRequiredService(); + Assert.IsInstanceOfType(logger); + } + } +} diff --git a/Source/MQTTnet.Tests/ASP/MqttClientBuilderTest.cs b/Source/MQTTnet.Tests/ASP/MqttClientBuilderTest.cs new file mode 100644 index 000000000..5258acd3e --- /dev/null +++ b/Source/MQTTnet.Tests/ASP/MqttClientBuilderTest.cs @@ -0,0 +1,50 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Adapter; +using MQTTnet.AspNetCore; +using MQTTnet.Implementations; + +namespace MQTTnet.Tests.ASP +{ + [TestClass] + public class MqttClientBuilderTest + { + [TestMethod] + public void AddMqttClientTest() + { + var services = new ServiceCollection(); + services.AddMqttClient(); + var s = services.BuildServiceProvider(); + + var mqttClientFactory1 = s.GetRequiredService(); + var mqttClientFactory2 = s.GetRequiredService(); + Assert.IsTrue(ReferenceEquals(mqttClientFactory2, mqttClientFactory2)); + + Assert.IsInstanceOfType(mqttClientFactory1); + Assert.IsInstanceOfType(mqttClientFactory1); + } + + [TestMethod] + public void UseMQTTnetMqttClientAdapterFactoryTest() + { + var services = new ServiceCollection(); + services.AddMqttClient().UseMQTTnetMqttClientAdapterFactory(); + var s = services.BuildServiceProvider(); + var adapterFactory = s.GetRequiredService(); + + Assert.IsInstanceOfType(adapterFactory); + } + + + [TestMethod] + public void UseAspNetCoreMqttClientAdapterFactoryTest() + { + var services = new ServiceCollection(); + services.AddMqttClient().UseAspNetCoreMqttClientAdapterFactory(); + var s = services.BuildServiceProvider(); + var adapterFactory = s.GetRequiredService(); + + Assert.IsInstanceOfType(adapterFactory); + } + } +} diff --git a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs index bfd0f8431..83493671c 100644 --- a/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs +++ b/Source/MQTTnet.Tests/ASP/MqttConnectionContextTest.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Connections; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.AspNetCore; using MQTTnet.Exceptions; @@ -30,14 +31,14 @@ public async Task TestCorruptedConnectPacket() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new MqttConnectionContext(serializer, connection); + var ctx = new MqttServerChannelAdapter(serializer, connection, connection.GetHttpContext()); await pipe.Receive.Writer.WriteAsync(writer.AddMqttHeader(MqttControlPacketType.Connect, Array.Empty())); await Assert.ThrowsExceptionAsync(() => ctx.ReceivePacketAsync(CancellationToken.None)); // the first exception should complete the pipes so if someone tries to use the connection after that it should throw immidiatly - await Assert.ThrowsExceptionAsync(() => ctx.ReceivePacketAsync(CancellationToken.None)); + await Assert.ThrowsExceptionAsync(() => ctx.ReceivePacketAsync(CancellationToken.None)); } // TODO: Fix test @@ -98,7 +99,7 @@ public async Task TestLargePacket() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new MqttConnectionContext(serializer, connection); + var ctx = new MqttServerChannelAdapter(serializer, connection, connection.GetHttpContext()); await ctx.SendPacketAsync(new MqttPublishPacket { PayloadSegment = new byte[20_000] }, CancellationToken.None).ConfigureAwait(false); @@ -113,7 +114,7 @@ public async Task TestReceivePacketAsyncThrowsWhenReaderCompleted() var pipe = new DuplexPipeMockup(); var connection = new DefaultConnectionContext(); connection.Transport = pipe; - var ctx = new MqttConnectionContext(serializer, connection); + var ctx = new MqttServerChannelAdapter(serializer, connection, connection.GetHttpContext()); pipe.Receive.Writer.Complete(); diff --git a/Source/MQTTnet.Tests/ASP/MqttConnectionMiddlewareTest.cs b/Source/MQTTnet.Tests/ASP/MqttConnectionMiddlewareTest.cs new file mode 100644 index 000000000..a25faf88c --- /dev/null +++ b/Source/MQTTnet.Tests/ASP/MqttConnectionMiddlewareTest.cs @@ -0,0 +1,24 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.AspNetCore; +using System; +using System.Buffers; + +namespace MQTTnet.Tests.ASP +{ + [TestClass] + public class MqttConnectionMiddlewareTest + { + [TestMethod] + public void IsMqttRequestTest() + { + var mqttv31Request = Convert.FromHexString("102800044d51545404c0003c0008636c69656e7469640008757365726e616d650008706173736f777264"); + var mqttv50Request = Convert.FromHexString("102900044d51545405c0003c000008636c69656e7469640008757365726e616d650008706173736f777264"); + + var isMqttv31 = MqttConnectionMiddleware.IsMqttRequest(new ReadOnlySequence(mqttv31Request)); + var isMqttv50 = MqttConnectionMiddleware.IsMqttRequest(new ReadOnlySequence(mqttv50Request)); + + Assert.IsTrue(isMqttv31); + Assert.IsTrue(isMqttv50); + } + } +} diff --git a/Source/MQTTnet.Tests/ASP/ReaderExtensionsTest.cs b/Source/MQTTnet.Tests/ASP/MqttPacketFormatterAdapterExtensionsTest.cs similarity index 74% rename from Source/MQTTnet.Tests/ASP/ReaderExtensionsTest.cs rename to Source/MQTTnet.Tests/ASP/MqttPacketFormatterAdapterExtensionsTest.cs index 6c9cac8f8..7f810feb3 100644 --- a/Source/MQTTnet.Tests/ASP/ReaderExtensionsTest.cs +++ b/Source/MQTTnet.Tests/ASP/MqttPacketFormatterAdapterExtensionsTest.cs @@ -11,7 +11,7 @@ namespace MQTTnet.Tests.ASP; [TestClass] -public sealed class ReaderExtensionsTest +public sealed class MqttPacketFormatterAdapterExtensionsTest { [TestMethod] public void TestTryDeserialize() @@ -28,19 +28,19 @@ public void TestTryDeserialize() var read = 0; part = sequence.Slice(sequence.Start, 0); // empty message should fail - var result = serializer.TryDecode(part, out _, out consumed, out observed, out read); + var result = serializer.TryDecode(part,null, out _, out consumed, out observed, out read); Assert.IsFalse(result); part = sequence.Slice(sequence.Start, 1); // partial fixed header should fail - result = serializer.TryDecode(part, out _, out consumed, out observed, out read); + result = serializer.TryDecode(part, null, out _, out consumed, out observed, out read); Assert.IsFalse(result); part = sequence.Slice(sequence.Start, 4); // partial body should fail - result = serializer.TryDecode(part, out _, out consumed, out observed, out read); + result = serializer.TryDecode(part, null, out _, out consumed, out observed, out read); Assert.IsFalse(result); part = sequence; // complete msg should work - result = serializer.TryDecode(part, out _, out consumed, out observed, out read); + result = serializer.TryDecode(part, null, out _, out consumed, out observed, out read); Assert.IsTrue(result); } } \ No newline at end of file diff --git a/Source/MQTTnet.Tests/ASP/MqttServerBuilderTest.cs b/Source/MQTTnet.Tests/ASP/MqttServerBuilderTest.cs new file mode 100644 index 000000000..cd162bab9 --- /dev/null +++ b/Source/MQTTnet.Tests/ASP/MqttServerBuilderTest.cs @@ -0,0 +1,81 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.AspNetCore; +using MQTTnet.Server; +using MQTTnet.Server.Internal.Adapter; +using System.Collections.Generic; +using System.Linq; + +namespace MQTTnet.Tests.ASP +{ + [TestClass] + public class MqttServerBuilderTest + { + [TestMethod] + public void AddMqttServerTest() + { + var services = new ServiceCollection(); + services.AddMqttServer(); + var s = services.BuildServiceProvider(); + + var mqttServer1 = s.GetRequiredService(); + var mqttServer2 = s.GetRequiredService(); + Assert.IsInstanceOfType(mqttServer1); + Assert.AreEqual(mqttServer1, mqttServer2); + } + + [TestMethod] + public void ConfigureMqttServerTest() + { + const int TcpKeepAliveTime1 = 19; + const int TcpKeepAliveTime2 = 20; + + var services = new ServiceCollection(); + services.AddMqttServer().ConfigureMqttServer( + b => b.WithTcpKeepAliveTime(TcpKeepAliveTime1), + o => + { + Assert.AreEqual(TcpKeepAliveTime1, o.DefaultEndpointOptions.TcpKeepAliveTime); + o.DefaultEndpointOptions.TcpKeepAliveTime = TcpKeepAliveTime2; + }); + + var s = services.BuildServiceProvider(); + var options = s.GetRequiredService(); + Assert.AreEqual(TcpKeepAliveTime2, options.DefaultEndpointOptions.TcpKeepAliveTime); + } + + [TestMethod] + public void ConfigureMqttServerStopTest() + { + const string ReasonString1 = "ReasonString1"; + const string ReasonString2 = "ReasonString2"; + + var services = new ServiceCollection(); + services.AddMqttServer().ConfigureMqttServerStop( + b => b.WithDefaultClientDisconnectOptions(c => c.WithReasonString(ReasonString1)), + o => + { + Assert.AreEqual(ReasonString1, o.DefaultClientDisconnectOptions.ReasonString); + o.DefaultClientDisconnectOptions.ReasonString = ReasonString2; + }); + + var s = services.BuildServiceProvider(); + var options = s.GetRequiredService(); + Assert.AreEqual(ReasonString2, options.DefaultClientDisconnectOptions.ReasonString); + } + + [TestMethod] + public void AddMqttServerAdapterTest() + { + var services = new ServiceCollection(); + services.AddMqttServer().AddMqttServerAdapter(); + services.AddMqttServer().AddMqttServerAdapter(); + + var s = services.BuildServiceProvider(); + var adapters = s.GetRequiredService>().ToArray(); + Assert.AreEqual(2, adapters.Length); + Assert.IsInstanceOfType(adapters[0]); + Assert.IsInstanceOfType(adapters[1]); + } + } +} diff --git a/Source/MQTTnet.Tests/BaseTestClass.cs b/Source/MQTTnet.Tests/BaseTestClass.cs index 8e5248e7f..9441b09a0 100644 --- a/Source/MQTTnet.Tests/BaseTestClass.cs +++ b/Source/MQTTnet.Tests/BaseTestClass.cs @@ -2,21 +2,35 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Formatter; using MQTTnet.Tests.Mockups; +using System; +using System.Threading.Tasks; namespace MQTTnet.Tests { public abstract class BaseTestClass { public TestContext TestContext { get; set; } - - protected TestEnvironment CreateTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) + + protected TestEnvironmentCollection CreateMQTTnetTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) + { + var mqttnet = new TestEnvironment(TestContext, protocolVersion); + return new TestEnvironmentCollection(mqttnet); + } + + protected TestEnvironmentCollection CreateAspNetCoreTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) + { + var aspnetcore = new AspNetCoreTestEnvironment(TestContext, protocolVersion); + return new TestEnvironmentCollection(aspnetcore); + } + + protected TestEnvironmentCollection CreateMixedTestEnvironment(MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) { - return new TestEnvironment(TestContext, protocolVersion); + var mqttnet = new TestEnvironment(TestContext, protocolVersion); + var aspnetcore = new AspNetCoreTestEnvironment(TestContext, protocolVersion); + return new TestEnvironmentCollection(mqttnet, aspnetcore); } protected Task LongTestDelay() diff --git a/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs index 4c8a44ad8..86d0477bc 100644 --- a/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/LowLevelMqttClient/LowLevelMqttClient_Tests.cs @@ -22,7 +22,8 @@ public sealed class LowLevelMqttClient_Tests : BaseTestClass [TestMethod] public async Task Authenticate() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -43,7 +44,8 @@ public async Task Authenticate() [TestMethod] public async Task Connect_And_Disconnect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -78,7 +80,8 @@ public async Task Connect_To_Wrong_Host() [TestMethod] public async Task Loose_Connection() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreServerLogErrors = true; @@ -116,7 +119,8 @@ public async Task Loose_Connection() [TestMethod] public async Task Maintain_IsConnected_Property() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreServerLogErrors = true; @@ -161,7 +165,8 @@ public async Task Maintain_IsConnected_Property() [TestMethod] public async Task Subscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs index 75b00d39d..f8173bca6 100644 --- a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs @@ -82,7 +82,8 @@ public async Task ConnectTimeout_Throws_Exception() [TestMethod] public async Task Disconnect_Clean() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -110,7 +111,8 @@ public async Task Disconnect_Clean() [TestMethod] public async Task Disconnect_Clean_With_Custom_Reason() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -138,7 +140,8 @@ public async Task Disconnect_Clean_With_Custom_Reason() [TestMethod] public async Task Disconnect_Clean_With_User_Properties() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -169,7 +172,8 @@ public async Task Disconnect_Clean_With_User_Properties() [TestMethod] public async Task No_Unobserved_Exception() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -201,7 +205,8 @@ public async Task No_Unobserved_Exception() [TestMethod] public async Task Return_Non_Success() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs index b8b1ba9e2..5c87a4621 100644 --- a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs @@ -36,7 +36,8 @@ public async Task Concurrent_Processing(MqttQualityOfServiceLevel qos) long concurrency = 0; var success = false; - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var publisher = await testEnvironment.ConnectClient(); @@ -74,7 +75,8 @@ async Task InvokeInternal() [TestMethod] public async Task Connect_Disconnect_Connect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -91,7 +93,8 @@ public async Task Connect_Disconnect_Connect() [ExpectedException(typeof(InvalidOperationException))] public async Task Connect_Multiple_Times_Should_Fail() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -136,7 +139,8 @@ public async Task Disconnect_Event_Contains_Exception() [TestMethod] public async Task Ensure_Queue_Drain() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); var client = await testEnvironment.ConnectLowLevelClient(); @@ -177,7 +181,8 @@ await client.SendAsync( [TestMethod] public async Task Fire_Disconnected_Event_On_Server_Shutdown() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -200,7 +205,8 @@ public async Task Fire_Disconnected_Event_On_Server_Shutdown() [TestMethod] public async Task Frequent_Connects() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -268,7 +274,8 @@ public async Task Invalid_Connect_Throws_Exception() [TestMethod] public async Task No_Payload() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -304,7 +311,8 @@ await receiver.SubscribeAsync( [TestMethod] public async Task NoConnectedHandler_Connect_DoesNotThrowException() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -317,7 +325,8 @@ public async Task NoConnectedHandler_Connect_DoesNotThrowException() [TestMethod] public async Task NoDisconnectedHandler_Disconnect_DoesNotThrowException() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -332,7 +341,8 @@ public async Task NoDisconnectedHandler_Disconnect_DoesNotThrowException() [TestMethod] public async Task PacketIdentifier_In_Publish_Result() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -365,7 +375,8 @@ public async Task Preserve_Message_Order() // is an issue). const int MessagesCount = 50; - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -410,7 +421,8 @@ public async Task Preserve_Message_Order_With_Delayed_Acknowledgement() // is an issue). const int MessagesCount = 50; - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -454,7 +466,8 @@ Task Handler1(MqttApplicationMessageReceivedEventArgs eventArgs) [TestMethod] public async Task Publish_QoS_0_Over_Period_Exceeding_KeepAlive() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { const int KeepAlivePeriodSecs = 3; @@ -486,7 +499,8 @@ public async Task Publish_QoS_0_Over_Period_Exceeding_KeepAlive() [TestMethod] public async Task Publish_QoS_1_In_ApplicationMessageReceiveHandler() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -530,7 +544,8 @@ public async Task Publish_QoS_1_In_ApplicationMessageReceiveHandler() [TestMethod] public async Task Publish_With_Correct_Retain_Flag() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -563,7 +578,8 @@ public async Task Publish_With_Correct_Retain_Flag() [TestMethod] public async Task Reconnect() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -586,7 +602,8 @@ public async Task Reconnect() [TestMethod] public async Task Reconnect_From_Disconnected_Event() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -627,7 +644,8 @@ public async Task Reconnect_From_Disconnected_Event() [TestMethod] public async Task Reconnect_While_Server_Offline() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -665,7 +683,8 @@ public async Task Reconnect_While_Server_Offline() [TestMethod] public async Task Send_Manual_Ping() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -677,7 +696,8 @@ public async Task Send_Manual_Ping() [TestMethod] public async Task Send_Reply_For_Any_Received_Message() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -723,7 +743,8 @@ Task Handler2(MqttApplicationMessageReceivedEventArgs eventArgs) [TestMethod] public async Task Send_Reply_In_Message_Handler() { - using (var testEnvironment = new TestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client1 = await testEnvironment.ConnectClient(); @@ -770,7 +791,8 @@ public async Task Send_Reply_In_Message_Handler() [TestMethod] public async Task Send_Reply_In_Message_Handler_For_Same_Client() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -806,7 +828,8 @@ public async Task Send_Reply_In_Message_Handler_For_Same_Client() [TestMethod] public async Task Set_ClientWasConnected_On_ClientDisconnect() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -826,7 +849,8 @@ public async Task Set_ClientWasConnected_On_ClientDisconnect() [TestMethod] public async Task Set_ClientWasConnected_On_ServerDisconnect() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -847,7 +871,8 @@ public async Task Set_ClientWasConnected_On_ServerDisconnect() [TestMethod] public async Task Subscribe_In_Callback_Events() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -887,7 +912,8 @@ public async Task Subscribe_In_Callback_Events() [TestMethod] public async Task Subscribe_With_QoS2() { - using (var testEnvironment = new TestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client1 = await testEnvironment.ConnectClient(o => o.WithProtocolVersion(MqttProtocolVersion.V500)); diff --git a/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs b/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs index 7db513f58..71e17c68e 100644 --- a/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs +++ b/Source/MQTTnet.Tests/Diagnostics/PacketInspection_Tests.cs @@ -18,7 +18,8 @@ public sealed class PacketInspection_Tests : BaseTestClass [TestMethod] public async Task Inspect_Client_Packets() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs b/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs index 02dddbfa8..209b51ae1 100644 --- a/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs +++ b/Source/MQTTnet.Tests/Extensions/Rpc_Tests.cs @@ -22,7 +22,8 @@ public sealed class Rpc_Tests : BaseTestClass [TestMethod] public async Task Execute_Success_MQTT_V5_Mixed_Clients() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var responseSender = await testEnvironment.ConnectClient(); @@ -54,7 +55,8 @@ public async Task Execute_Success_Parameters_Propagated_Correctly() { TestParametersTopicGenerationStrategy.ExpectedParamName, "123" } }; - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -132,7 +134,8 @@ public Task Execute_Success_With_QoS_2_MQTT_V5_Use_ResponseTopic() [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_Timeout() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -147,7 +150,8 @@ public async Task Execute_Timeout() [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_Timeout_MQTT_V5_Mixed_Clients() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var responseSender = await testEnvironment.ConnectClient(); @@ -172,7 +176,8 @@ public async Task Execute_Timeout_MQTT_V5_Mixed_Clients() [ExpectedException(typeof(MqttCommunicationTimedOutException))] public async Task Execute_With_Custom_Topic_Names() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -196,7 +201,8 @@ public void Use_Factory() async Task Execute_Success(MqttQualityOfServiceLevel qosLevel, MqttProtocolVersion protocolVersion) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var responseSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(protocolVersion)); @@ -217,7 +223,8 @@ async Task Execute_Success(MqttQualityOfServiceLevel qosLevel, MqttProtocolVersi async Task Execute_Success_MQTT_V5(MqttQualityOfServiceLevel qosLevel) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var responseSender = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); diff --git a/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs b/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs index 7476ffea8..765429e6b 100644 --- a/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs +++ b/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs @@ -2,14 +2,19 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Implementations; using System; +using System.Linq; using System.Net; +using System.Net.NetworkInformation; using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Implementations; namespace MQTTnet.Tests.Internal { @@ -19,19 +24,47 @@ public class CrossPlatformSocket_Tests [TestMethod] public async Task Connect_Send_Receive() { + var serverPort = GetServerPort(); + var responseContent = "Connect_Send_Receive"; + + // create a localhost web server. + var builder = WebApplication.CreateSlimBuilder(); + builder.WebHost.UseKestrel(k => k.ListenLocalhost(serverPort)); + + await using var webApp = builder.Build(); + var webAppStartedSource = new TaskCompletionSource(); + webApp.Lifetime.ApplicationStarted.Register(() => webAppStartedSource.TrySetResult()); + webApp.Use(next => context => context.Response.WriteAsync(responseContent)); + await webApp.StartAsync(); + await webAppStartedSource.Task; + + var crossPlatformSocket = new CrossPlatformSocket(ProtocolType.Tcp); - await crossPlatformSocket.ConnectAsync(new DnsEndPoint("www.google.de", 80), CancellationToken.None); + await crossPlatformSocket.ConnectAsync(new DnsEndPoint("localhost", serverPort), CancellationToken.None); - var requestBuffer = Encoding.UTF8.GetBytes("GET / HTTP/1.1\r\nHost: www.google.de\r\n\r\n"); - await crossPlatformSocket.SendAsync(new ArraySegment(requestBuffer), System.Net.Sockets.SocketFlags.None); + var requestBuffer = Encoding.UTF8.GetBytes($"GET /test/path HTTP/1.1\r\nHost: localhost:{serverPort}\r\n\r\n"); + await crossPlatformSocket.SendAsync(new ArraySegment(requestBuffer), SocketFlags.None); var buffer = new byte[1024]; - var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment(buffer), System.Net.Sockets.SocketFlags.None); + var length = await crossPlatformSocket.ReceiveAsync(new ArraySegment(buffer), SocketFlags.None); crossPlatformSocket.Dispose(); var responseText = Encoding.UTF8.GetString(buffer, 0, length); - Assert.IsTrue(responseText.Contains("HTTP/1.1 200 OK")); + Assert.IsTrue(responseText.Contains(responseContent)); + + + static int GetServerPort(int defaultPort = 9999) + { + var listeners = IPGlobalProperties.GetIPGlobalProperties().GetActiveTcpListeners(); + var portSet = listeners.Select(i => i.Port).ToHashSet(); + + while (!portSet.Add(defaultPort)) + { + defaultPort += 1; + } + return defaultPort; + } } [TestMethod] diff --git a/Source/MQTTnet.Tests/MQTTnet.Tests.csproj b/Source/MQTTnet.Tests/MQTTnet.Tests.csproj index c89d8057b..e0fc8ee21 100644 --- a/Source/MQTTnet.Tests/MQTTnet.Tests.csproj +++ b/Source/MQTTnet.Tests/MQTTnet.Tests.csproj @@ -11,7 +11,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs b/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs index 7b7570a99..d2d2fdf93 100644 --- a/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs +++ b/Source/MQTTnet.Tests/MQTTv5/Client_Tests.cs @@ -21,7 +21,8 @@ public sealed class Client_Tests : BaseTestClass [TestMethod] public async Task Connect_With_New_Mqtt_Features() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -66,7 +67,8 @@ await client.PublishAsync(new MqttApplicationMessageBuilder() [TestMethod] public async Task Connect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); await testEnvironment.ConnectClient(o => o.WithProtocolVersion(MqttProtocolVersion.V500).Build()); @@ -76,7 +78,8 @@ public async Task Connect() [TestMethod] public async Task Connect_And_Disconnect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -88,7 +91,8 @@ public async Task Connect_And_Disconnect() [TestMethod] public async Task Subscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -114,7 +118,8 @@ public async Task Subscribe() [TestMethod] public async Task Unsubscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -133,7 +138,8 @@ public async Task Unsubscribe() public async Task Publish_QoS_0_LargeBuffer() { using var recyclableMemoryStream = GetLargePayload(); - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -149,7 +155,8 @@ public async Task Publish_QoS_0_LargeBuffer() public async Task Publish_QoS_1_LargeBuffer() { using var recyclableMemoryStream = GetLargePayload(); - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -165,7 +172,8 @@ public async Task Publish_QoS_1_LargeBuffer() public async Task Publish_QoS_2_LargeBuffer() { using var recyclableMemoryStream = GetLargePayload(); - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -180,7 +188,8 @@ public async Task Publish_QoS_2_LargeBuffer() [TestMethod] public async Task Publish_QoS_0() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -195,7 +204,8 @@ public async Task Publish_QoS_0() [TestMethod] public async Task Publish_QoS_1() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -210,7 +220,8 @@ public async Task Publish_QoS_1() [TestMethod] public async Task Publish_QoS_2() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -226,7 +237,8 @@ public async Task Publish_QoS_2() public async Task Publish_With_RecyclableMemoryStream() { var memoryManager = new RecyclableMemoryStreamManager(options: new RecyclableMemoryStreamManager.Options { BlockSize = 4096 }); - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -264,7 +276,8 @@ public async Task Publish_With_RecyclableMemoryStream() [TestMethod] public async Task Publish_With_Properties() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -293,7 +306,8 @@ public async Task Publish_With_Properties() [TestMethod] public async Task Subscribe_And_Publish() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -321,7 +335,8 @@ public async Task Subscribe_And_Publish() [TestMethod] public async Task Publish_And_Receive_New_Properties() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs b/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs index 067fed791..6a2849cb3 100644 --- a/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs +++ b/Source/MQTTnet.Tests/MQTTv5/Server_Tests.cs @@ -10,6 +10,7 @@ using MQTTnet.Internal; using MQTTnet.Protocol; using MQTTnet.Server; +using System; namespace MQTTnet.Tests.MQTTv5 { @@ -19,7 +20,8 @@ public sealed class Server_Tests : BaseTestClass [TestMethod] public async Task Will_Message_Send() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -28,9 +30,11 @@ public async Task Will_Message_Send() var c1 = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithProtocolVersion(MqttProtocolVersion.V500)); var receivedMessagesCount = 0; + var taskSource = new TaskCompletionSource(); c1.ApplicationMessageReceivedAsync += e => { Interlocked.Increment(ref receivedMessagesCount); + taskSource.TrySetResult(); return CompletedTask.Instance; }; @@ -39,7 +43,7 @@ public async Task Will_Message_Send() var c2 = await testEnvironment.ConnectClient(clientOptions); c2.Dispose(); // Dispose will not send a DISCONNECT packet first so the will message must be sent. - await LongTestDelay(); + await taskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.AreEqual(1, receivedMessagesCount); } @@ -48,7 +52,8 @@ public async Task Will_Message_Send() [TestMethod] public async Task Validate_IsSessionPresent() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { // Create server with persistent sessions enabled @@ -85,7 +90,8 @@ public async Task Validate_IsSessionPresent() [TestMethod] public async Task Connect_with_Undefined_SessionExpiryInterval() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { // Create server with persistent sessions enabled @@ -125,7 +131,8 @@ public async Task Connect_with_Undefined_SessionExpiryInterval() [TestMethod] public async Task Reconnect_with_different_SessionExpiryInterval() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { // Create server with persistent sessions enabled @@ -177,17 +184,18 @@ public async Task Reconnect_with_different_SessionExpiryInterval() [TestMethod] public async Task Disconnect_with_Reason() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { - var disconnectReason = MqttClientDisconnectReason.UnspecifiedError; + var disconnectReasonTaskSource = new TaskCompletionSource(); - string testClientId = null; + var testClientIdTaskSource = new TaskCompletionSource(); await testEnvironment.StartServer(); testEnvironment.Server.ClientConnectedAsync += e => { - testClientId = e.ClientId; + testClientIdTaskSource.TrySetResult(e.ClientId); return CompletedTask.Instance; }; @@ -195,7 +203,7 @@ public async Task Disconnect_with_Reason() client.DisconnectedAsync += e => { - disconnectReason = e.Reason; + disconnectReasonTaskSource.TrySetResult(e.Reason); return CompletedTask.Instance; }; @@ -203,14 +211,14 @@ public async Task Disconnect_with_Reason() // Test client should be connected now + var testClientId = await testClientIdTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsTrue(testClientId != null); // Have the server disconnect the client with AdministrativeAction reason await testEnvironment.Server.DisconnectClientAsync(testClientId, MqttDisconnectReasonCode.AdministrativeAction); - await LongTestDelay(); - + var disconnectReason = await disconnectReasonTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); // The reason should be returned to the client in the DISCONNECT packet Assert.AreEqual(MqttClientDisconnectReason.AdministrativeAction, disconnectReason); diff --git a/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs new file mode 100644 index 000000000..739f318b8 --- /dev/null +++ b/Source/MQTTnet.Tests/Mockups/AspNetCoreTestEnvironment.cs @@ -0,0 +1,155 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.AspNetCore; +using MQTTnet.Diagnostics.Logger; +using MQTTnet.Formatter; +using MQTTnet.Internal; +using MQTTnet.LowLevelClient; +using MQTTnet.Protocol; +using MQTTnet.Server; +using System; +using System.Linq; +using System.Net.NetworkInformation; +using System.Threading.Tasks; + +namespace MQTTnet.Tests.Mockups +{ + public sealed class AspNetCoreTestEnvironment : TestEnvironment + { + private WebApplication _app; + + public AspNetCoreTestEnvironment() + : this(null) + { + } + + public AspNetCoreTestEnvironment(TestContext testContext, MqttProtocolVersion protocolVersion = MqttProtocolVersion.V311) + : base(testContext, protocolVersion) + { + } + + protected override IMqttClient CreateClientCore() + { + return CreateClientFactory().CreateMqttClient(); + } + + protected override ILowLevelMqttClient CreateLowLevelClientCore() + { + return CreateClientFactory().CreateLowLevelMqttClient(); + } + + private IMqttClientFactory CreateClientFactory() + { + var services = new ServiceCollection(); + + var logger = EnableLogger ? (IMqttNetLogger)ClientLogger : MqttNetNullLogger.Instance; + services.AddMqttClient().UseLogger(logger); + + return services.BuildServiceProvider().GetRequiredService(); + } + + public override MqttServer CreateServer(MqttServerOptions options) + { + throw new NotSupportedException("Can not create MqttServer in AspNetCoreTestEnvironment."); + } + + public override Task StartServer(Action configure) + { + var optionsBuilder = new MqttServerOptionsBuilder(); + configure?.Invoke(optionsBuilder); + return StartServer(optionsBuilder); + } + + public override Task StartServer(MqttServerOptionsBuilder optionsBuilder) + { + optionsBuilder.WithDefaultEndpoint(); + optionsBuilder.WithDefaultEndpointPort(ServerPort); + optionsBuilder.WithMaxPendingMessagesPerClient(int.MaxValue); + var serverOptions = optionsBuilder.Build(); + return StartServer(serverOptions); + } + + private async Task StartServer(MqttServerOptions serverOptions) + { + if (Server != null) + { + throw new InvalidOperationException("Server already started."); + } + + if (serverOptions.DefaultEndpointOptions.Port == 0) + { + var serverPort = ServerPort > 0 ? ServerPort : GetServerPort(); + serverOptions.DefaultEndpointOptions.Port = serverPort; + } + + var appBuilder = WebApplication.CreateBuilder(); + appBuilder.Services.AddSingleton(serverOptions); + + var logger = EnableLogger ? (IMqttNetLogger)ServerLogger : new MqttNetNullLogger(); + appBuilder.Services.AddMqttServer().UseLogger(logger); + + appBuilder.WebHost.UseKestrel(k => k.ListenMqtt()); + appBuilder.Host.ConfigureHostOptions(h => h.ShutdownTimeout = TimeSpan.FromMilliseconds(500d)); + + _app = appBuilder.Build(); + + Server = _app.Services.GetRequiredService(); + ServerPort = serverOptions.DefaultEndpointOptions.Port; + + Server.ValidatingConnectionAsync += e => + { + if (TestContext != null) + { + // Null is used when the client id is assigned from the server! + if (!string.IsNullOrEmpty(e.ClientId) && !e.ClientId.StartsWith(TestContext.TestName)) + { + TrackException(new InvalidOperationException($"Invalid client ID used ({e.ClientId}). It must start with UnitTest name.")); + e.ReasonCode = MqttConnectReasonCode.ClientIdentifierNotValid; + } + } + + return CompletedTask.Instance; + }; + + var appStartedSource = new TaskCompletionSource(); + _app.Lifetime.ApplicationStarted.Register(() => appStartedSource.TrySetResult()); + + await _app.StartAsync(); + await appStartedSource.Task; + + return Server; + } + + + private static int GetServerPort() + { + var listeners = IPGlobalProperties.GetIPGlobalProperties().GetActiveTcpListeners(); + var portSet = listeners.Select(i => i.Port).ToHashSet(); + + var port = 1883; + while (!portSet.Add(port)) + { + port += 1; + } + return port; + } + + public override void Dispose() + { + base.Dispose(); + if (_app != null) + { + _app.StopAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + _app.DisposeAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + _app = null; + } + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs index 4f1391f15..a11ebda0e 100644 --- a/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs +++ b/Source/MQTTnet.Tests/Mockups/TestEnvironment.cs @@ -19,8 +19,9 @@ namespace MQTTnet.Tests.Mockups { - public sealed class TestEnvironment : IDisposable + public class TestEnvironment : IDisposable { + bool _disposed = false; readonly List _clientErrors = new(); readonly List _clients = new(); readonly List _exceptions = new(); @@ -87,7 +88,7 @@ public TestEnvironment(TestContext testContext, MqttProtocolVersion protocolVers public bool IgnoreServerLogErrors { get; set; } - public MqttServer Server { get; private set; } + public MqttServer Server { get; protected set; } public MqttNetEventLogger ServerLogger { get; } = new("server"); @@ -197,9 +198,7 @@ public TestApplicationMessageReceivedHandler CreateApplicationMessageHandler(IMq public IMqttClient CreateClient() { - var logger = EnableLogger ? (IMqttNetLogger)ClientLogger : MqttNetNullLogger.Instance; - - var client = ClientFactory.CreateMqttClient(logger); + var client = CreateClientCore(); client.ConnectingAsync += e => { @@ -224,6 +223,12 @@ public IMqttClient CreateClient() return client; } + protected virtual IMqttClient CreateClientCore() + { + var logger = EnableLogger ? (IMqttNetLogger)ClientLogger : MqttNetNullLogger.Instance; + return ClientFactory.CreateMqttClient(logger); + } + public MqttClientOptions CreateDefaultClientOptions() { return CreateDefaultClientOptionsBuilder().Build(); @@ -239,7 +244,7 @@ public MqttClientOptionsBuilder CreateDefaultClientOptionsBuilder() public ILowLevelMqttClient CreateLowLevelClient() { - var client = ClientFactory.CreateLowLevelMqttClient(ClientLogger); + var client = CreateLowLevelClientCore(); lock (_lowLevelClients) { @@ -249,7 +254,13 @@ public ILowLevelMqttClient CreateLowLevelClient() return client; } - public MqttServer CreateServer(MqttServerOptions options) + protected virtual ILowLevelMqttClient CreateLowLevelClientCore() + { + return ClientFactory.CreateLowLevelMqttClient(ClientLogger); + } + + + public virtual MqttServer CreateServer(MqttServerOptions options) { if (Server != null) { @@ -278,8 +289,14 @@ public MqttServer CreateServer(MqttServerOptions options) return Server; } - public void Dispose() + public virtual void Dispose() { + if (_disposed) + { + return; + } + _disposed = true; + try { lock (_clients) @@ -350,7 +367,7 @@ public Task StartServer() return StartServer(ServerFactory.CreateServerOptionsBuilder()); } - public async Task StartServer(MqttServerOptionsBuilder optionsBuilder) + public virtual async Task StartServer(MqttServerOptionsBuilder optionsBuilder) { optionsBuilder.WithDefaultEndpoint(); optionsBuilder.WithDefaultEndpointPort(ServerPort); @@ -365,7 +382,7 @@ public async Task StartServer(MqttServerOptionsBuilder optionsBuilde return server; } - public async Task StartServer(Action configure) + public virtual async Task StartServer(Action configure) { var optionsBuilder = ServerFactory.CreateServerOptionsBuilder(); diff --git a/Source/MQTTnet.Tests/Mockups/TestEnvironmentCollection.cs b/Source/MQTTnet.Tests/Mockups/TestEnvironmentCollection.cs new file mode 100644 index 000000000..2fc57593a --- /dev/null +++ b/Source/MQTTnet.Tests/Mockups/TestEnvironmentCollection.cs @@ -0,0 +1,39 @@ +using System; +using System.Collections; +using System.Collections.Generic; + +namespace MQTTnet.Tests.Mockups +{ + public class TestEnvironmentCollection : IReadOnlyCollection, IDisposable + { + private readonly TestEnvironment[] _testEnvironments; + + public int Count => _testEnvironments.Length; + + public TestEnvironmentCollection(params TestEnvironment[] testEnvironments) + { + _testEnvironments = testEnvironments; + } + + public IEnumerator GetEnumerator() + { + foreach (var environment in _testEnvironments) + { + yield return environment; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public void Dispose() + { + foreach (var environment in _testEnvironments) + { + environment.Dispose(); + } + } + } +} diff --git a/Source/MQTTnet.Tests/RoundtripTime_Tests.cs b/Source/MQTTnet.Tests/RoundtripTime_Tests.cs index b1359bc3d..b8b0fe25c 100644 --- a/Source/MQTTnet.Tests/RoundtripTime_Tests.cs +++ b/Source/MQTTnet.Tests/RoundtripTime_Tests.cs @@ -13,14 +13,14 @@ namespace MQTTnet.Tests { [TestClass] - public class RoundtripTime_Tests + public class RoundtripTime_Tests : BaseTestClass { - public TestContext TestContext { get; set; } [TestMethod] public async Task Round_Trip_Time() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Assigned_Client_ID_Tests.cs b/Source/MQTTnet.Tests/Server/Assigned_Client_ID_Tests.cs index 69ed38875..14f8b00ac 100644 --- a/Source/MQTTnet.Tests/Server/Assigned_Client_ID_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Assigned_Client_ID_Tests.cs @@ -28,7 +28,8 @@ public Task Connect_With_Client_Id() async Task Connect_With_Client_Id(string expectedClientId, string expectedReturnedClientId, string usedClientId, string assignedClientId) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { string serverConnectedClientId = null; string serverDisconnectedClientId = null; diff --git a/Source/MQTTnet.Tests/Server/Connection_Tests.cs b/Source/MQTTnet.Tests/Server/Connection_Tests.cs index ef2a482c4..26d180269 100644 --- a/Source/MQTTnet.Tests/Server/Connection_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Connection_Tests.cs @@ -20,7 +20,8 @@ public sealed class Connection_Tests : BaseTestClass [TestMethod] public async Task Close_Idle_Connection_On_Connect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); @@ -49,7 +50,8 @@ public async Task Close_Idle_Connection_On_Connect() [TestMethod] public async Task Send_Garbage() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); diff --git a/Source/MQTTnet.Tests/Server/Cross_Version_Tests.cs b/Source/MQTTnet.Tests/Server/Cross_Version_Tests.cs index 83a30ecfe..f3b946d7f 100644 --- a/Source/MQTTnet.Tests/Server/Cross_Version_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Cross_Version_Tests.cs @@ -12,7 +12,8 @@ public sealed class Cross_Version_Tests : BaseTestClass [TestMethod] public async Task Send_V311_Receive_V500() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -36,7 +37,8 @@ public async Task Send_V311_Receive_V500() [TestMethod] public async Task Send_V500_Receive_V311() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -48,9 +50,9 @@ public async Task Send_V500_Receive_V311() var applicationMessage = new MqttApplicationMessageBuilder().WithTopic("My/Message") .WithPayload("My_Payload") - .WithUserProperty("A", "B") - .WithResponseTopic("Response") - .WithCorrelationData(Encoding.UTF8.GetBytes("Correlation")) + //.WithUserProperty("A", "B") + //.WithResponseTopic("Response") + //.WithCorrelationData(Encoding.UTF8.GetBytes("Correlation")) .Build(); await sender.PublishAsync(applicationMessage); diff --git a/Source/MQTTnet.Tests/Server/Events_Tests.cs b/Source/MQTTnet.Tests/Server/Events_Tests.cs index bce9197ed..d6f3895e5 100644 --- a/Source/MQTTnet.Tests/Server/Events_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Events_Tests.cs @@ -9,6 +9,7 @@ using MQTTnet.Internal; using MQTTnet.Protocol; using MQTTnet.Server; +using MQTTnet.Tests.Mockups; namespace MQTTnet.Tests.Server { @@ -18,20 +19,21 @@ public sealed class Events_Tests : BaseTestClass [TestMethod] public async Task Fire_Client_Connected_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - - ClientConnectedEventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.ClientConnectedAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser", "ThePassword")); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); @@ -46,21 +48,23 @@ public async Task Fire_Client_Connected_Event() [TestMethod] public async Task Fire_Client_Disconnected_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - ClientDisconnectedEventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.ClientDisconnectedAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; var client = await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser", "ThePassword")); await client.DisconnectAsync(); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); @@ -76,21 +80,23 @@ public async Task Fire_Client_Disconnected_Event() [TestMethod] public async Task Fire_Client_Subscribed_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - ClientSubscribedTopicEventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.ClientSubscribedTopicAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; var client = await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser")); await client.SubscribeAsync("The/Topic", MqttQualityOfServiceLevel.AtLeastOnce); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); @@ -104,21 +110,23 @@ public async Task Fire_Client_Subscribed_Event() [TestMethod] public async Task Fire_Client_Unsubscribed_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - ClientUnsubscribedTopicEventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.ClientUnsubscribedTopicAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; var client = await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser")); await client.UnsubscribeAsync("The/Topic"); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); @@ -131,21 +139,23 @@ public async Task Fire_Client_Unsubscribed_Event() [TestMethod] public async Task Fire_Application_Message_Received_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - InterceptingPublishEventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.InterceptingPublishAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; var client = await testEnvironment.ConnectClient(o => o.WithCredentials("TheUser")); await client.PublishStringAsync("The_Topic", "The_Payload"); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); @@ -159,20 +169,21 @@ public async Task Fire_Application_Message_Received_Event() [TestMethod] public async Task Fire_Started_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMQTTnetTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = testEnvironment.CreateServer(new MqttServerOptions()); - EventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); server.StartedAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; await server.StartAsync(); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); } @@ -181,20 +192,21 @@ public async Task Fire_Started_Event() [TestMethod] public async Task Fire_Stopped_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - - EventArgs eventArgs = null; + var eventArgsTaskSource = new TaskCompletionSource(); + server.StoppedAsync += e => { - eventArgs = e; + eventArgsTaskSource.TrySetResult(e); return CompletedTask.Instance; }; await server.StopAsync(); - await LongTestDelay(); + var eventArgs = await eventArgsTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.IsNotNull(eventArgs); } diff --git a/Source/MQTTnet.Tests/Server/General.cs b/Source/MQTTnet.Tests/Server/General.cs index 45cff1983..3228f0a6f 100644 --- a/Source/MQTTnet.Tests/Server/General.cs +++ b/Source/MQTTnet.Tests/Server/General.cs @@ -25,7 +25,8 @@ public sealed class General_Tests : BaseTestClass [TestMethod] public async Task Client_Disconnect_Without_Errors() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { bool clientWasConnected; @@ -54,7 +55,8 @@ public async Task Client_Disconnect_Without_Errors() [TestMethod] public async Task Collect_Messages_In_Disconnected_Session() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithPersistentSessions()); @@ -88,7 +90,8 @@ public async Task Collect_Messages_In_Disconnected_Session() [TestMethod] public async Task Deny_Connection() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -110,7 +113,8 @@ public async Task Deny_Connection() [TestMethod] public async Task Do_Not_Send_Retained_Messages_For_Denied_Subscription() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -162,7 +166,8 @@ public async Task Do_Not_Send_Retained_Messages_For_Denied_Subscription() [TestMethod] public async Task Handle_Clean_Disconnect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder()); @@ -204,7 +209,8 @@ public async Task Handle_Lots_Of_Parallel_Retained_Messages() { const int clientCount = 50; - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -268,7 +274,8 @@ await client.PublishAsync( [TestMethod] public async Task Intercept_Application_Message() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -299,7 +306,8 @@ public async Task Intercept_Application_Message() [TestMethod] public async Task Intercept_Message() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); server.InterceptingPublishAsync += e => @@ -331,7 +339,8 @@ public async Task Intercept_Message() [TestMethod] public async Task Intercept_Undelivered() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var undeliverd = string.Empty; @@ -357,7 +366,8 @@ public async Task Intercept_Undelivered() [TestMethod] public async Task No_Messages_If_No_Subscription() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -392,7 +402,8 @@ public async Task No_Messages_If_No_Subscription() [TestMethod] public async Task Persist_Retained_Message() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { List savedRetainedMessages = null; @@ -416,7 +427,8 @@ public async Task Persist_Retained_Message() [TestMethod] public async Task Publish_After_Client_Connects() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); server.ClientConnectedAsync += async e => @@ -477,7 +489,8 @@ public async Task Publish_Exactly_Once_0x02() [TestMethod] public async Task Publish_From_Server() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -508,10 +521,10 @@ await server.InjectApplicationMessage( [TestMethod] public async Task Publish_Multiple_Clients() { - var receivedMessagesCount = 0; - - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { + var receivedMessagesCount = 0; await testEnvironment.StartServer(); var c1 = await testEnvironment.ConnectClient(); @@ -549,7 +562,8 @@ public async Task Publish_Multiple_Clients() [TestMethod] public async Task Remove_Session() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder()); @@ -568,7 +582,8 @@ public async Task Remove_Session() [TestMethod] public async Task Same_Client_Id_Connect_Disconnect_Event_Order() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -655,7 +670,8 @@ public async Task Same_Client_Id_Connect_Disconnect_Event_Order() [TestMethod] public async Task Same_Client_Id_Refuse_Connection() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -754,7 +770,8 @@ public async Task Same_Client_Id_Refuse_Connection() [TestMethod] public async Task Send_Long_Body() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { const int PayloadSizeInMB = 30; const int CharCount = PayloadSizeInMB * 1024 * 1024; @@ -774,15 +791,16 @@ public async Task Send_Long_Body() } } - byte[] receivedBody = null; + TaskCompletionSource> receivedBodyTaskSource = new(); await testEnvironment.StartServer(); var client1 = await testEnvironment.ConnectClient(); client1.ApplicationMessageReceivedAsync += e => { - receivedBody = e.ApplicationMessage.Payload.ToArray(); - return CompletedTask.Instance; + var payload = e.ApplicationMessage.Payload; + receivedBodyTaskSource.TrySetResult(payload); + return Task.CompletedTask; }; await client1.SubscribeAsync("string"); @@ -790,16 +808,17 @@ public async Task Send_Long_Body() var client2 = await testEnvironment.ConnectClient(); await client2.PublishBinaryAsync("string", longBody); - await Task.Delay(TimeSpan.FromSeconds(5)); + var receivedBody = await receivedBodyTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); - Assert.IsTrue(longBody.SequenceEqual(receivedBody ?? new byte[0])); + Assert.IsTrue(MqttMemoryHelper.SequenceEqual(receivedBody, new ReadOnlySequence(longBody))); } } [TestMethod] public async Task Set_Subscription_At_Server() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -837,7 +856,8 @@ public async Task Set_Subscription_At_Server() [TestMethod] public async Task Shutdown_Disconnects_Clients_Gracefully() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder()); @@ -863,7 +883,8 @@ public async Task Shutdown_Disconnects_Clients_Gracefully() [TestMethod] public async Task Stop_And_Restart() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -895,7 +916,8 @@ public async Task Stop_And_Restart() [DataRow(null, null)] public async Task Use_Admissible_Credentials(string username, string password) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -913,7 +935,8 @@ public async Task Use_Admissible_Credentials(string username, string password) [TestMethod] public async Task Use_Empty_Client_ID() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client2 = await testEnvironment.ConnectClient(new MqttClientOptionsBuilder().WithClientId("b").WithCleanSession(false)); @@ -936,7 +959,8 @@ public async Task Use_Empty_Client_ID() [TestMethod] public async Task Disconnect_Client_with_Reason() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var disconnectPacketReceived = false; @@ -1007,7 +1031,8 @@ async Task TestPublishAsync( MqttQualityOfServiceLevel filterQualityOfServiceLevel, int expectedReceivedMessagesCount) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Injection_Tests.cs b/Source/MQTTnet.Tests/Server/Injection_Tests.cs index cefbc34dd..9c6db90d0 100644 --- a/Source/MQTTnet.Tests/Server/Injection_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Injection_Tests.cs @@ -11,7 +11,8 @@ public sealed class Injection_Tests : BaseTestClass [TestMethod] public async Task Inject_Application_Message_At_Session_Level() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); var receiver1 = await testEnvironment.ConnectClient(); @@ -40,7 +41,8 @@ public async Task Inject_Application_Message_At_Session_Level() [TestMethod] public async Task Inject_ApplicationMessage_At_Server_Level() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -64,7 +66,8 @@ public async Task Inject_ApplicationMessage_At_Server_Level() [TestMethod] public async Task Intercept_Injected_Application_Message() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs b/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs index 8bcd8b0a8..43b781bbf 100644 --- a/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Keep_Alive_Tests.cs @@ -9,6 +9,7 @@ using MQTTnet.Formatter; using MQTTnet.Packets; using MQTTnet.Protocol; +using MQTTnet.Tests.Mockups; namespace MQTTnet.Tests.Server { @@ -18,7 +19,8 @@ public sealed class KeepAlive_Tests : BaseTestClass [TestMethod] public async Task Disconnect_Client_DueTo_KeepAlive() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -40,12 +42,12 @@ await client.SendAsync(new MqttConnectPacket for (var i = 0; i < 6; i++) { await Task.Delay(500); - + await client.SendAsync(MqttPingReqPacket.Instance, CancellationToken.None); responsePacket = await client.ReceiveAsync(CancellationToken.None); Assert.IsTrue(responsePacket is MqttPingRespPacket); } - + // If we reach this point everything works as expected (server did not close the connection // due to proper ping messages. // Now we will wait 1.1 seconds because the server MUST wait 1.5 seconds in total (See spec). diff --git a/Source/MQTTnet.Tests/Server/Load_Tests.cs b/Source/MQTTnet.Tests/Server/Load_Tests.cs index 466e1cfa0..019fa21cb 100644 --- a/Source/MQTTnet.Tests/Server/Load_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Load_Tests.cs @@ -14,7 +14,8 @@ public sealed class Load_Tests : BaseTestClass [TestMethod] public async Task Handle_100_000_Messages_In_Receiving_Client() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -64,7 +65,8 @@ await client.PublishAsync(message) [TestMethod] public async Task Handle_100_000_Messages_In_Low_Level_Client() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -127,7 +129,8 @@ await client.SendAsync(publishPacket, CancellationToken.None) [TestMethod] public async Task Handle_100_000_Messages_In_Server() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/No_Local_Tests.cs b/Source/MQTTnet.Tests/Server/No_Local_Tests.cs index 5a8627d2f..208d3979c 100644 --- a/Source/MQTTnet.Tests/Server/No_Local_Tests.cs +++ b/Source/MQTTnet.Tests/Server/No_Local_Tests.cs @@ -27,7 +27,8 @@ async Task ExecuteTest( bool noLocal, int expectedCountAfterPublish) { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Publishing_Tests.cs b/Source/MQTTnet.Tests/Server/Publishing_Tests.cs index e829e2573..469b03fd4 100644 --- a/Source/MQTTnet.Tests/Server/Publishing_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Publishing_Tests.cs @@ -19,7 +19,8 @@ public sealed class Publishing_Tests : BaseTestClass [ExpectedException(typeof(MqttClientDisconnectedException))] public async Task Disconnect_While_Publishing() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -34,7 +35,8 @@ public async Task Disconnect_While_Publishing() [TestMethod] public async Task Return_NoMatchingSubscribers_When_Not_Subscribed() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -55,7 +57,8 @@ public async Task Return_NoMatchingSubscribers_When_Not_Subscribed() [TestMethod] public async Task Return_Success_When_Subscribed() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -76,7 +79,8 @@ public async Task Return_Success_When_Subscribed() [TestMethod] public async Task Intercept_Client_Enqueue() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -110,7 +114,8 @@ public async Task Intercept_Client_Enqueue() [TestMethod] public async Task Intercept_Client_Enqueue_Multiple_Clients_Subscribed_Messages_Are_Filtered() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/QoS_Tests.cs b/Source/MQTTnet.Tests/Server/QoS_Tests.cs index b48a1a99c..b52c159d8 100644 --- a/Source/MQTTnet.Tests/Server/QoS_Tests.cs +++ b/Source/MQTTnet.Tests/Server/QoS_Tests.cs @@ -17,7 +17,8 @@ public sealed class QoS_Tests : BaseTestClass [TestMethod] public async Task Fire_Event_On_Client_Acknowledges_QoS_0() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -44,7 +45,8 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_0() [TestMethod] public async Task Fire_Event_On_Client_Acknowledges_QoS_1() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -75,7 +77,8 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_1() [TestMethod] public async Task Fire_Event_On_Client_Acknowledges_QoS_2() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -114,7 +117,8 @@ public async Task Fire_Event_On_Client_Acknowledges_QoS_2() [TestMethod] public async Task Preserve_Message_Order_For_Queued_Messages() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(o => o.WithPersistentSessions()); diff --git a/Source/MQTTnet.Tests/Server/Retain_As_Published_Tests.cs b/Source/MQTTnet.Tests/Server/Retain_As_Published_Tests.cs index 43e706247..4dcc495fa 100644 --- a/Source/MQTTnet.Tests/Server/Retain_As_Published_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Retain_As_Published_Tests.cs @@ -25,7 +25,8 @@ public Task Subscribe_Without_Retain_As_Published() async Task ExecuteTest(bool retainAsPublished) { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Retain_Handling_Tests.cs b/Source/MQTTnet.Tests/Server/Retain_Handling_Tests.cs index aa4b7f227..78d1aa4e6 100644 --- a/Source/MQTTnet.Tests/Server/Retain_Handling_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Retain_Handling_Tests.cs @@ -36,7 +36,8 @@ async Task ExecuteTest( int expectedCountAfterSecondPublish, int expectedCountAfterSecondSubscribe) { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs b/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs index fbbe500d8..583d72643 100644 --- a/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs @@ -17,7 +17,8 @@ public sealed class Retained_Messages_Tests : BaseTestClass [TestMethod] public async Task Clear_Retained_Message_With_Empty_Payload() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -42,7 +43,8 @@ public async Task Clear_Retained_Message_With_Empty_Payload() [TestMethod] public async Task Clear_Retained_Message_With_Null_Payload() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -67,7 +69,8 @@ public async Task Clear_Retained_Message_With_Null_Payload() [TestMethod] public async Task Downgrade_QoS_Level() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -97,7 +100,8 @@ await c1.PublishAsync( [TestMethod] public async Task No_Upgrade_QoS_Level() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -127,7 +131,8 @@ await c1.PublishAsync( [TestMethod] public async Task Receive_No_Retained_Message_After_Subscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -148,7 +153,8 @@ public async Task Receive_No_Retained_Message_After_Subscribe() [TestMethod] public async Task Receive_Retained_Message_After_Subscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -171,7 +177,8 @@ public async Task Receive_Retained_Message_After_Subscribe() [TestMethod] public async Task Receive_Retained_Messages_From_Higher_Qos_Level() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -202,7 +209,8 @@ await c1.PublishAsync( [TestMethod] public async Task Retained_Messages_Flow() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var retainedMessage = new MqttApplicationMessageBuilder().WithTopic("r").WithPayload("r").WithRetainFlag().Build(); @@ -234,7 +242,8 @@ public async Task Retained_Messages_Flow() [TestMethod] public async Task Server_Reports_Retained_Messages_Supported_V3() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -253,7 +262,8 @@ public async Task Server_Reports_Retained_Messages_Supported_V3() [TestMethod] public async Task Server_Reports_Retained_Messages_Supported_V5() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Security_Tests.cs b/Source/MQTTnet.Tests/Server/Security_Tests.cs index 404314c43..5717dadc3 100644 --- a/Source/MQTTnet.Tests/Server/Security_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Security_Tests.cs @@ -18,7 +18,8 @@ public sealed class Security_Tests : BaseTestClass [TestMethod] public async Task Do_Not_Affect_Authorized_Clients() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -116,11 +117,12 @@ public Task Handle_Wrong_UserName_And_Password() [TestMethod] public async Task Use_Username_Null_Password_Empty() { - string username = null; - var password = string.Empty; - - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { + string username = null; + var password = string.Empty; + testEnvironment.IgnoreClientLogErrors = true; await testEnvironment.StartServer(); @@ -137,7 +139,8 @@ public async Task Use_Username_Null_Password_Empty() async Task TestCredentials(string userName, string password) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; diff --git a/Source/MQTTnet.Tests/Server/Server_Reference_Tests.cs b/Source/MQTTnet.Tests/Server/Server_Reference_Tests.cs index 2b4f96c1c..5b53f2c95 100644 --- a/Source/MQTTnet.Tests/Server/Server_Reference_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Server_Reference_Tests.cs @@ -16,7 +16,8 @@ public sealed class Server_Reference_Tests : BaseTestClass [TestMethod] public async Task Server_Reports_With_Reference_Server() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; diff --git a/Source/MQTTnet.Tests/Server/Session_Tests.cs b/Source/MQTTnet.Tests/Server/Session_Tests.cs index 91345da56..ab28c896f 100644 --- a/Source/MQTTnet.Tests/Server/Session_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Session_Tests.cs @@ -21,7 +21,8 @@ public sealed class Session_Tests : BaseTestClass [TestMethod] public async Task Clean_Session_Persistence() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { // Create server with persistent sessions enabled @@ -72,7 +73,8 @@ public async Task Clean_Session_Persistence() [TestMethod] public async Task Do_Not_Use_Expired_Session() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(o => o.WithPersistentSessions()); @@ -94,15 +96,16 @@ public async Task Do_Not_Use_Expired_Session() [TestMethod] public async Task Fire_Deleted_Event() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { // Arrange client and server. var server = await testEnvironment.StartServer(o => o.WithPersistentSessions(false)); - var deletedEventFired = false; + var deletedEventFiredTaskSource = new TaskCompletionSource(); server.SessionDeletedAsync += e => { - deletedEventFired = true; + deletedEventFiredTaskSource.TrySetResult(true); return CompletedTask.Instance; }; @@ -111,7 +114,7 @@ public async Task Fire_Deleted_Event() // Act: Disconnect the client -> Event must be fired. await client.DisconnectAsync(); - await LongTestDelay(); + var deletedEventFired = await deletedEventFiredTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); // Assert that the event was fired properly. Assert.IsTrue(deletedEventFired); @@ -121,7 +124,8 @@ public async Task Fire_Deleted_Event() [TestMethod] public async Task Get_Session_Items_In_Status() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -149,7 +153,8 @@ public async Task Get_Session_Items_In_Status() [DataRow(MqttProtocolVersion.V500)] public async Task Handle_Parallel_Connection_Attempts(MqttProtocolVersion protocolVersion) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { testEnvironment.IgnoreClientLogErrors = true; @@ -157,10 +162,11 @@ public async Task Handle_Parallel_Connection_Attempts(MqttProtocolVersion protoc var options = new MqttClientOptionsBuilder().WithClientId("1").WithTimeout(TimeSpan.FromSeconds(10)).WithProtocolVersion(protocolVersion); - var hasReceive = false; + + var hasReceiveTaskSource = new TaskCompletionSource(); void OnReceive() { - hasReceive = true; + hasReceiveTaskSource.TrySetResult(true); } // Try to connect 50 clients at the same time. @@ -176,7 +182,7 @@ void OnReceive() var sendClient = await testEnvironment.ConnectClient(option2); await sendClient.PublishStringAsync("aaa", "1"); - await LongTestDelay(); + var hasReceive = await hasReceiveTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.AreEqual(true, hasReceive); } @@ -187,9 +193,10 @@ void OnReceive() [DataRow(MqttQualityOfServiceLevel.AtLeastOnce)] public async Task Retry_If_Not_PubAck(MqttQualityOfServiceLevel qos) { - long count = 0; - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { + long count = 0; await testEnvironment.StartServer(o => o.WithPersistentSessions()); var publisher = await testEnvironment.ConnectClient(); @@ -223,7 +230,8 @@ public async Task Retry_If_Not_PubAck(MqttQualityOfServiceLevel qos) [TestMethod] public async Task Session_Takeover() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -235,9 +243,11 @@ public async Task Session_Takeover() await Task.Delay(500); var disconnectReason = MqttClientDisconnectReason.NormalDisconnection; + var disconnectTaskSource = new TaskCompletionSource(); client1.DisconnectedAsync += c => { disconnectReason = c.Reason; + disconnectTaskSource.TrySetResult(); return CompletedTask.Instance; }; @@ -247,6 +257,7 @@ public async Task Session_Takeover() Assert.IsFalse(client1.IsConnected); Assert.IsTrue(client2.IsConnected); + await disconnectTaskSource.Task.WaitAsync(TimeSpan.FromSeconds(10d)); Assert.AreEqual(MqttClientDisconnectReason.SessionTakenOver, disconnectReason); } } @@ -254,16 +265,15 @@ public async Task Session_Takeover() [TestMethod] public async Task Set_Session_Item() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); - server.ValidatingConnectionAsync += e => { // Don't validate anything. Just set some session items. e.SessionItems["can_subscribe_x"] = true; e.SessionItems["default_payload"] = "Hello World"; - return CompletedTask.Instance; }; @@ -311,7 +321,8 @@ public async Task Set_Session_Item() [TestMethod] public async Task Use_Clean_Session() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -331,7 +342,8 @@ public async Task Use_Clean_Session() [TestMethod] public async Task Will_Message_Do_Not_Send_On_Takeover() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; @@ -355,7 +367,7 @@ public async Task Will_Message_Do_Not_Send_On_Takeover() // C3 will do the connection takeover. await testEnvironment.ConnectClient(clientOptions); - await Task.Delay(1000); + await LongTestDelay(); Assert.AreEqual(0, receivedMessagesCount); } diff --git a/Source/MQTTnet.Tests/Server/Shared_Subscriptions_Tests.cs b/Source/MQTTnet.Tests/Server/Shared_Subscriptions_Tests.cs index c1a39289d..0dac081fc 100644 --- a/Source/MQTTnet.Tests/Server/Shared_Subscriptions_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Shared_Subscriptions_Tests.cs @@ -15,7 +15,8 @@ public sealed class Shared_Subscriptions_Tests : BaseTestClass [TestMethod] public async Task Server_Reports_Shared_Subscriptions_Not_Supported() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -31,7 +32,8 @@ public async Task Server_Reports_Shared_Subscriptions_Not_Supported() [TestMethod] public async Task Subscription_Of_Shared_Subscription_Is_Denied() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Status_Tests.cs b/Source/MQTTnet.Tests/Server/Status_Tests.cs index ef9419eb3..ec4fe6c3e 100644 --- a/Source/MQTTnet.Tests/Server/Status_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Status_Tests.cs @@ -18,7 +18,8 @@ public sealed class Status_Tests : BaseTestClass [TestMethod] public async Task Disconnect_Client() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -46,7 +47,8 @@ public async Task Disconnect_Client() [TestMethod] public async Task Keep_Persistent_Session_Version311() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(o => o.WithPersistentSessions()); @@ -80,7 +82,8 @@ public async Task Keep_Persistent_Session_Version311() [TestMethod] public async Task Keep_Persistent_Session_Version500() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(o => o.WithPersistentSessions()); @@ -116,7 +119,8 @@ public async Task Keep_Persistent_Session_Version500() [TestMethod] public async Task Show_Client_And_Session_Statistics() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -150,7 +154,8 @@ public async Task Show_Client_And_Session_Statistics() [TestMethod] public async Task Track_Sent_Application_Messages() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithPersistentSessions()); @@ -171,7 +176,8 @@ public async Task Track_Sent_Application_Messages() [TestMethod] public async Task Track_Sent_Packets() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithPersistentSessions()); diff --git a/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs b/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs index 9dd48c317..ff89c3afb 100644 --- a/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Subscribe_Tests.cs @@ -39,7 +39,8 @@ public sealed class Subscribe_Tests : BaseTestClass [DataRow("A/B1/B2/C", "A/+/C", false)] public async Task Subscription_Roundtrip(string topic, string filter, bool shouldWork) { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer().ConfigureAwait(false); @@ -66,7 +67,8 @@ public async Task Subscription_Roundtrip(string topic, string filter, bool shoul [TestMethod] public async Task Deny_Invalid_Topic() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -93,7 +95,8 @@ public async Task Deny_Invalid_Topic() [TestMethod] public async Task Intercept_Subscribe_With_User_Properties() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -117,7 +120,8 @@ public async Task Intercept_Subscribe_With_User_Properties() [ExpectedException(typeof(MqttClientDisconnectedException))] public async Task Disconnect_While_Subscribing() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -132,7 +136,8 @@ public async Task Disconnect_While_Subscribing() [TestMethod] public async Task Enqueue_Message_After_Subscription() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -156,7 +161,8 @@ public async Task Enqueue_Message_After_Subscription() [TestMethod] public async Task Intercept_Subscription() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -199,7 +205,8 @@ public async Task Intercept_Subscription() [TestMethod] public async Task Response_Contains_Equal_Reason_Codes() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); var client = await testEnvironment.ConnectClient(); @@ -219,7 +226,8 @@ public async Task Response_Contains_Equal_Reason_Codes() [TestMethod] public async Task Subscribe_Lots_In_Multiple_Requests() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; @@ -262,7 +270,8 @@ public async Task Subscribe_Lots_In_Multiple_Requests() [TestMethod] public async Task Subscribe_Lots_In_Single_Request() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; @@ -302,7 +311,8 @@ public async Task Subscribe_Lots_In_Single_Request() [TestMethod] public async Task Subscribe_Multiple_In_Multiple_Request() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; @@ -340,7 +350,8 @@ public async Task Subscribe_Multiple_In_Multiple_Request() [TestMethod] public async Task Subscribe_Multiple_In_Single_Request() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; @@ -374,7 +385,8 @@ public async Task Subscribe_Multiple_In_Single_Request() [TestMethod] public async Task Subscribe_Unsubscribe() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var receivedMessagesCount = 0; diff --git a/Source/MQTTnet.Tests/Server/Subscription_Identifier_Tests.cs b/Source/MQTTnet.Tests/Server/Subscription_Identifier_Tests.cs index 24b370d15..fa1ee652e 100644 --- a/Source/MQTTnet.Tests/Server/Subscription_Identifier_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Subscription_Identifier_Tests.cs @@ -14,7 +14,8 @@ public sealed class Subscription_Identifier_Tests : BaseTestClass [TestMethod] public async Task Server_Reports_Subscription_Identifiers_Supported() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -30,7 +31,8 @@ public async Task Server_Reports_Subscription_Identifiers_Supported() [TestMethod] public async Task Subscribe_With_Subscription_Identifier() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -57,7 +59,8 @@ public async Task Subscribe_With_Subscription_Identifier() [TestMethod] public async Task Subscribe_With_Multiple_Subscription_Identifiers() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Tls_Tests.cs b/Source/MQTTnet.Tests/Server/Tls_Tests.cs index a87b0dc4c..57b9a46b1 100644 --- a/Source/MQTTnet.Tests/Server/Tls_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Tls_Tests.cs @@ -50,7 +50,7 @@ static X509Certificate2 CreateCertificate(string oid) [TestMethod] public async Task Tls_Swap_Test() { - var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500); + using var testEnvironment = new TestEnvironment(TestContext, MqttProtocolVersion.V500); var serverOptionsBuilder = testEnvironment.ServerFactory.CreateServerOptionsBuilder(); var firstOid = "1.3.6.1.5.5.7.3.1"; diff --git a/Source/MQTTnet.Tests/Server/Topic_Alias_Tests.cs b/Source/MQTTnet.Tests/Server/Topic_Alias_Tests.cs index 024fca16c..e104e58b2 100644 --- a/Source/MQTTnet.Tests/Server/Topic_Alias_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Topic_Alias_Tests.cs @@ -17,7 +17,8 @@ public sealed class Topic_Alias_Tests : BaseTestClass [TestMethod] public async Task Server_Reports_Topic_Alias_Supported() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -35,7 +36,8 @@ public async Task Server_Reports_Topic_Alias_Supported() [TestMethod] public async Task Publish_With_Topic_Alias() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Unsubscribe_Tests.cs b/Source/MQTTnet.Tests/Server/Unsubscribe_Tests.cs index 9f97ad88e..5fc7a3035 100644 --- a/Source/MQTTnet.Tests/Server/Unsubscribe_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Unsubscribe_Tests.cs @@ -10,6 +10,7 @@ using MQTTnet.Internal; using MQTTnet.Protocol; using MQTTnet.Server; +using MQTTnet.Tests.Mockups; namespace MQTTnet.Tests.Server { @@ -20,7 +21,8 @@ public sealed class Unsubscribe_Tests : BaseTestClass [ExpectedException(typeof(MqttClientDisconnectedException))] public async Task Disconnect_While_Unsubscribing() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); @@ -36,7 +38,8 @@ public async Task Disconnect_While_Unsubscribing() [TestMethod] public async Task Intercept_Unsubscribe_With_User_Properties() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/User_Properties_Tests.cs b/Source/MQTTnet.Tests/Server/User_Properties_Tests.cs index 2a6f07137..84cd8e122 100644 --- a/Source/MQTTnet.Tests/Server/User_Properties_Tests.cs +++ b/Source/MQTTnet.Tests/Server/User_Properties_Tests.cs @@ -15,14 +15,13 @@ namespace MQTTnet.Tests.Server { [TestClass] - public class Feature_Tests + public class Feature_Tests : BaseTestClass { - public TestContext TestContext { get; set; } - [TestMethod] public async Task Use_User_Properties() { - using (var testEnvironment = new TestEnvironment(TestContext)) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Wildcard_Subscription_Available_Tests.cs b/Source/MQTTnet.Tests/Server/Wildcard_Subscription_Available_Tests.cs index 85b224d63..fcdcea44c 100644 --- a/Source/MQTTnet.Tests/Server/Wildcard_Subscription_Available_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Wildcard_Subscription_Available_Tests.cs @@ -14,7 +14,8 @@ public sealed class Wildcard_Subscription_Available_Tests : BaseTestClass [TestMethod] public async Task Server_Reports_Wildcard_Subscription_Available_Tests_Supported_V3() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -30,7 +31,8 @@ public async Task Server_Reports_Wildcard_Subscription_Available_Tests_Supported [TestMethod] public async Task Server_Reports_Wildcard_Subscription_Available_Tests_Supported_V5() { - using (var testEnvironment = CreateTestEnvironment(MqttProtocolVersion.V500)) + using var testEnvironments = CreateMixedTestEnvironment(MqttProtocolVersion.V500); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet.Tests/Server/Will_Tests.cs b/Source/MQTTnet.Tests/Server/Will_Tests.cs index 0e823bcde..b0f860036 100644 --- a/Source/MQTTnet.Tests/Server/Will_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Will_Tests.cs @@ -11,7 +11,8 @@ public sealed class Will_Tests : BaseTestClass [TestMethod] public async Task Intercept_Will_Message() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { var server = await testEnvironment.StartServer().ConfigureAwait(false); @@ -36,7 +37,8 @@ public async Task Intercept_Will_Message() [TestMethod] public async Task Will_Message_Do_Not_Send_On_Clean_Disconnect() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); @@ -59,7 +61,8 @@ public async Task Will_Message_Do_Not_Send_On_Clean_Disconnect() [TestMethod] public async Task Will_Message_Send() { - using (var testEnvironment = CreateTestEnvironment()) + using var testEnvironments = CreateMixedTestEnvironment(); + foreach (var testEnvironment in testEnvironments) { await testEnvironment.StartServer(); diff --git a/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs b/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs index 0a4031f31..6d7e6b40e 100644 --- a/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs +++ b/Source/MQTTnet/Implementations/MqttClientAdapterFactory.cs @@ -3,10 +3,10 @@ // See the LICENSE file in the project root for more information. using MQTTnet.Adapter; -using MQTTnet.Formatter; -using System; using MQTTnet.Channel; using MQTTnet.Diagnostics.Logger; +using MQTTnet.Formatter; +using System; namespace MQTTnet.Implementations { diff --git a/Source/MQTTnet/MQTTnet.csproj b/Source/MQTTnet/MQTTnet.csproj index ece6812a0..7c7a93f19 100644 --- a/Source/MQTTnet/MQTTnet.csproj +++ b/Source/MQTTnet/MQTTnet.csproj @@ -44,7 +44,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet/MqttClient.cs b/Source/MQTTnet/MqttClient.cs index 9d19ce574..f89c6bdad 100644 --- a/Source/MQTTnet/MqttClient.cs +++ b/Source/MQTTnet/MqttClient.cs @@ -2,10 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; using MQTTnet.Adapter; using MQTTnet.Diagnostics.Logger; using MQTTnet.Diagnostics.PacketInspection; @@ -15,6 +11,10 @@ using MQTTnet.PacketDispatcher; using MQTTnet.Packets; using MQTTnet.Protocol; +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; namespace MQTTnet; @@ -270,21 +270,21 @@ public Task PublishAsync(MqttApplicationMessage applica switch (applicationMessage.QualityOfServiceLevel) { case MqttQualityOfServiceLevel.AtMostOnce: - { - return PublishAtMostOnce(publishPacket, cancellationToken); - } + { + return PublishAtMostOnce(publishPacket, cancellationToken); + } case MqttQualityOfServiceLevel.AtLeastOnce: - { - return PublishAtLeastOnce(publishPacket, cancellationToken); - } + { + return PublishAtLeastOnce(publishPacket, cancellationToken); + } case MqttQualityOfServiceLevel.ExactlyOnce: - { - return PublishExactlyOnce(publishPacket, cancellationToken); - } + { + return PublishExactlyOnce(publishPacket, cancellationToken); + } default: - { - throw new NotSupportedException(); - } + { + throw new NotSupportedException(); + } } } @@ -395,34 +395,34 @@ Task AcknowledgeReceivedPublishPacket(MqttApplicationMessageReceivedEventArgs ev switch (eventArgs.PublishPacket.QualityOfServiceLevel) { case MqttQualityOfServiceLevel.AtMostOnce: - { - // no response required - break; - } - case MqttQualityOfServiceLevel.AtLeastOnce: - { - if (!eventArgs.ProcessingFailed) { - var pubAckPacket = MqttPubAckPacketFactory.Create(eventArgs); - return Send(pubAckPacket, cancellationToken); + // no response required + break; } + case MqttQualityOfServiceLevel.AtLeastOnce: + { + if (!eventArgs.ProcessingFailed) + { + var pubAckPacket = MqttPubAckPacketFactory.Create(eventArgs); + return Send(pubAckPacket, cancellationToken); + } - break; - } + break; + } case MqttQualityOfServiceLevel.ExactlyOnce: - { - if (!eventArgs.ProcessingFailed) { - var pubRecPacket = MqttPubRecPacketFactory.Create(eventArgs); - return Send(pubRecPacket, cancellationToken); - } + if (!eventArgs.ProcessingFailed) + { + var pubRecPacket = MqttPubRecPacketFactory.Create(eventArgs); + return Send(pubRecPacket, cancellationToken); + } - break; - } + break; + } default: - { - throw new MqttProtocolViolationException("Received a not supported QoS level."); - } + { + throw new MqttProtocolViolationException("Received a not supported QoS level."); + } } return CompletedTask.Instance; @@ -442,22 +442,22 @@ async Task Authenticate(IMqttChannelAdapter channelAdap switch (receivedPacket) { case MqttConnAckPacket connAckPacket: - { - result = MqttClientResultFactory.ConnectResult.Create(connAckPacket, channelAdapter.PacketFormatterAdapter.ProtocolVersion); - break; - } + { + result = MqttClientResultFactory.ConnectResult.Create(connAckPacket, channelAdapter.PacketFormatterAdapter.ProtocolVersion); + break; + } case MqttAuthPacket _: - { - throw new NotSupportedException("Extended authentication handler is not yet supported"); - } + { + throw new NotSupportedException("Extended authentication handler is not yet supported"); + } case null: - { - throw new MqttCommunicationException("Connection closed."); - } + { + throw new MqttCommunicationException("Connection closed."); + } default: - { - throw new InvalidOperationException($"Received an unexpected MQTT packet ({receivedPacket})."); - } + { + throw new InvalidOperationException($"Received an unexpected MQTT packet ({receivedPacket})."); + } } } catch (Exception exception) @@ -967,14 +967,14 @@ async Task TryProcessReceivedPacket(MqttPacket packet, CancellationToken cancell case MqttPingReqPacket _: throw new MqttProtocolViolationException("The PINGREQ Packet is sent from a client to the server only."); default: - { - if (!_packetDispatcher.TryDispatch(packet)) { - throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); - } + if (!_packetDispatcher.TryDispatch(packet)) + { + throw new MqttProtocolViolationException($"Received packet '{packet}' at an unexpected time."); + } - break; - } + break; + } } } catch (Exception exception)