From bb8120d01bfd706260bac914954792d7f1b11a8f Mon Sep 17 00:00:00 2001 From: Gregorius Soedharmo Date: Thu, 4 Sep 2025 04:35:11 +0700 Subject: [PATCH] Implement nullable for Akka.Streams.Implementation.ChannelSources --- .../ChannelSourceFromReaderRegressionSpecs.cs | 149 ++++++++++++++++++ .../Implementation/ChannelSources.cs | 44 ++++-- 2 files changed, 181 insertions(+), 12 deletions(-) create mode 100644 src/core/Akka.Streams.Tests/Implementation/ChannelSourceFromReaderRegressionSpecs.cs diff --git a/src/core/Akka.Streams.Tests/Implementation/ChannelSourceFromReaderRegressionSpecs.cs b/src/core/Akka.Streams.Tests/Implementation/ChannelSourceFromReaderRegressionSpecs.cs new file mode 100644 index 00000000000..07271814c0b --- /dev/null +++ b/src/core/Akka.Streams.Tests/Implementation/ChannelSourceFromReaderRegressionSpecs.cs @@ -0,0 +1,149 @@ +//----------------------------------------------------------------------- +// +// Copyright (C) 2009-2022 Lightbend Inc. +// Copyright (C) 2013-2025 .NET Foundation +// +//----------------------------------------------------------------------- + +using System; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Akka.Actor; +using Akka.Streams.Dsl; +using Xunit; +using Xunit.Abstractions; + +namespace Akka.Streams.Tests.Implementation; + +// Simple reference type used in the reproducer (matches the Discord/GitHub thread) +public sealed class Message +{ + public TKey Key { get; init; } + public TValue Value { get; init; } +} + +public sealed class ChannelSourceFromReaderRegressionSpecs : Akka.TestKit.Xunit2.TestKit +{ + private readonly IMaterializer _mat; + + public ChannelSourceFromReaderRegressionSpecs(ITestOutputHelper output) : base(output: output) + { + _mat = Sys.Materializer(); + } + + [Fact(DisplayName = "FromReader: closing without writing any elements should complete stream (no NRE)")] + public async Task FromReader_should_complete_cleanly_with_zero_elements() + { + var ch = Channel.CreateBounded>(new BoundedChannelOptions(8) + { + SingleReader = true, + SingleWriter = true, + FullMode = BoundedChannelFullMode.Wait + }); + + var src = ChannelSource.FromReader(ch.Reader); + + // Collect to a list to ensure materialized task completes on stage completion + var resultTask = src.RunWith(Sink.Seq>(), _mat); + + // Complete the writer without sending any items (problematic path pre-fix) + ch.Writer.Complete(); + + var results = await resultTask.Within(TimeSpan.FromSeconds(5)); + Assert.Empty(results); // main assertion is actually "no exception" + } + + [Fact(DisplayName = "FromReader: one element then close should complete stream (no NRE)")] + public async Task FromReader_should_complete_cleanly_with_one_element_then_close() + { + var ch = Channel.CreateBounded>(new BoundedChannelOptions(8) + { + SingleReader = true, + SingleWriter = true, + FullMode = BoundedChannelFullMode.Wait + }); + + var src = ChannelSource.FromReader(ch.Reader); + var resultTask = src.RunWith(Sink.Seq>(), _mat); + + // Write a single reference-type element then complete + ch.Writer.TryWrite(new Message { Key = "k1", Value = "v1" }); + ch.Writer.Complete(); + + var results = await resultTask.Within(TimeSpan.FromSeconds(5)); + Assert.Single(results); + Assert.Equal("k1", results[0].Key); + Assert.Equal("v1", results[0].Value); + } + + [Fact(DisplayName = "FromReader: failure completion should fail the stream with the same exception")] + public async Task FromReader_should_propagate_failure_instead_of_throwing_NRE() + { + var ch = Channel.CreateBounded>(new BoundedChannelOptions(8) + { + SingleReader = true, + SingleWriter = true, + FullMode = BoundedChannelFullMode.Wait + }); + + var src = ChannelSource.FromReader(ch.Reader); + + // Materialize to Ignore; we only care that the materialized task faults with our exception + var resultTask = src.RunWith(Sink.Ignore>(), _mat); + + var boom = new InvalidOperationException("boom"); + ch.Writer.TryComplete(boom); + + var ex = await Assert.ThrowsAsync(async () => + { + await resultTask.Within(TimeSpan.FromSeconds(5)); + }); + Assert.Equal("boom", ex.Message); + } + + [Fact(DisplayName = "FromReader: value type smoke test should not regress")] + public async Task FromReader_should_work_with_value_types() + { + var ch = Channel.CreateBounded(new BoundedChannelOptions(8) + { + SingleReader = true, + SingleWriter = true, + FullMode = BoundedChannelFullMode.Wait + }); + + var src = ChannelSource.FromReader(ch.Reader); + var resultTask = src.RunWith(Sink.Seq(), _mat); + + ch.Writer.TryWrite(42); + ch.Writer.Complete(); + + var results = await resultTask.Within(TimeSpan.FromSeconds(5)); + Assert.Single(results); + Assert.Equal(42, results[0]); + } +} + +internal static class TaskTimeoutExtensions +{ + /// + /// Helper to await a Task with a timeout (throws if time is exceeded). + /// + public static async Task Within(this Task task, TimeSpan timeout) + { + using var cts = new CancellationTokenSource(timeout); + var completed = await Task.WhenAny(task, Task.Delay(Timeout.InfiniteTimeSpan, cts.Token)); + if (completed != task) + throw new TimeoutException($"Task did not complete within {timeout}."); + return await task; // unwrap exceptions if any + } + + public static async Task Within(this Task task, TimeSpan timeout) + { + using var cts = new CancellationTokenSource(timeout); + var completed = await Task.WhenAny(task, Task.Delay(Timeout.InfiniteTimeSpan, cts.Token)); + if (completed != task) + throw new TimeoutException($"Task did not complete within {timeout}."); + await task; // unwrap exceptions if any + } +} \ No newline at end of file diff --git a/src/core/Akka.Streams/Implementation/ChannelSources.cs b/src/core/Akka.Streams/Implementation/ChannelSources.cs index 7097bef3516..dced1aa6bb3 100644 --- a/src/core/Akka.Streams/Implementation/ChannelSources.cs +++ b/src/core/Akka.Streams/Implementation/ChannelSources.cs @@ -10,15 +10,26 @@ using System.Threading.Tasks; using Akka.Streams.Stage; +#nullable enable namespace Akka.Streams.Implementation { sealed class ChannelSourceLogic : OutGraphStageLogic { + private struct ReaderCompleted + { + public ReaderCompleted(Exception? reason) + { + Reason = reason; + } + + public Exception? Reason { get; } + } + private readonly Outlet _outlet; private readonly ChannelReader _reader; private readonly Action _onValueRead; private readonly Action _onValueReadFailure; - private readonly Action _onReaderComplete; + private readonly Action _onReaderComplete; private readonly Action> _onReadReady; public ChannelSourceLogic(SourceShape source, Outlet outlet, @@ -29,25 +40,35 @@ public ChannelSourceLogic(SourceShape source, Outlet outlet, _onValueRead = GetAsyncCallback(OnValueRead); _onValueReadFailure = GetAsyncCallback(OnValueReadFailure); - _onReaderComplete = GetAsyncCallback(OnReaderComplete); + _onReaderComplete = GetAsyncCallback(OnReaderComplete); _onReadReady = ContinueAsyncRead; _reader.Completion.ContinueWith(t => { - if (t.IsFaulted) _onReaderComplete(t.Exception); + if (t.IsFaulted) _onReaderComplete(new ReaderCompleted(FlattenException(t.Exception))); else if (t.IsCanceled) - _onReaderComplete(new TaskCanceledException(t)); - else _onReaderComplete(null); + _onReaderComplete(new ReaderCompleted(new TaskCanceledException(t))); + else _onReaderComplete(new ReaderCompleted(null)); }); SetHandler(_outlet, this); } - private void OnReaderComplete(Exception reason) + private static Exception? FlattenException(Exception? exception) { - if (reason is null) + if (exception is AggregateException agg) + { + var flat = agg.Flatten(); + return flat.InnerExceptions.Count == 1 ? flat.InnerExceptions[0] : exception; + } + return exception; + } + + private void OnReaderComplete(ReaderCompleted completion) + { + if (completion.Reason is null) CompleteStage(); else - FailStage(reason); + FailStage(completion.Reason); } private void OnValueReadFailure(Exception reason) => FailStage(reason); @@ -84,8 +105,8 @@ public override void OnPull() private void ContinueAsyncRead(Task t) { - if (t.IsFaulted) - _onValueReadFailure(t.Exception); + if (t.IsFaulted) + _onValueReadFailure(t.Exception ?? new Exception("Channel read failed")); else if (t.IsCanceled) _onValueReadFailure(new TaskCanceledException(t)); else @@ -135,7 +156,6 @@ public override ILogicAndMaterializedValue> internal sealed class ChannelReaderSource : GraphStage> { - private readonly ChannelReader _reader; public ChannelReaderSource(ChannelReader reader) @@ -152,4 +172,4 @@ protected override GraphStageLogic CreateLogic(Attributes inheritedAttributes) => new ChannelSourceLogic(Shape, Outlet, _reader); } -} +} \ No newline at end of file