Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
//-----------------------------------------------------------------------
// <copyright file="ChannelSourceFromReaderRegressionSpecs.cs" company="Akka.NET Project">
// Copyright (C) 2009-2022 Lightbend Inc. <http://www.lightbend.com>
// Copyright (C) 2013-2025 .NET Foundation <https://github.com/akkadotnet/akka.net>
// </copyright>
//-----------------------------------------------------------------------

using System;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Akka.Actor;
using Akka.Streams.Dsl;
using Xunit;
using Xunit.Abstractions;

namespace Akka.Streams.Tests.Implementation;

// Simple reference type used in the reproducer (matches the Discord/GitHub thread)
public sealed class Message<TKey, TValue>
{
public TKey Key { get; init; }
public TValue Value { get; init; }
}

public sealed class ChannelSourceFromReaderRegressionSpecs : Akka.TestKit.Xunit2.TestKit
{
private readonly IMaterializer _mat;

public ChannelSourceFromReaderRegressionSpecs(ITestOutputHelper output) : base(output: output)
{
_mat = Sys.Materializer();
}

[Fact(DisplayName = "FromReader: closing without writing any elements should complete stream (no NRE)")]
public async Task FromReader_should_complete_cleanly_with_zero_elements()
{
var ch = Channel.CreateBounded<Message<string, string>>(new BoundedChannelOptions(8)
{
SingleReader = true,
SingleWriter = true,
FullMode = BoundedChannelFullMode.Wait
});

var src = ChannelSource.FromReader(ch.Reader);

// Collect to a list to ensure materialized task completes on stage completion
var resultTask = src.RunWith(Sink.Seq<Message<string, string>>(), _mat);

// Complete the writer without sending any items (problematic path pre-fix)
ch.Writer.Complete();

var results = await resultTask.Within(TimeSpan.FromSeconds(5));
Assert.Empty(results); // main assertion is actually "no exception"
}

[Fact(DisplayName = "FromReader: one element then close should complete stream (no NRE)")]
public async Task FromReader_should_complete_cleanly_with_one_element_then_close()
{
var ch = Channel.CreateBounded<Message<string, string>>(new BoundedChannelOptions(8)
{
SingleReader = true,
SingleWriter = true,
FullMode = BoundedChannelFullMode.Wait
});

var src = ChannelSource.FromReader(ch.Reader);
var resultTask = src.RunWith(Sink.Seq<Message<string, string>>(), _mat);

// Write a single reference-type element then complete
ch.Writer.TryWrite(new Message<string, string> { Key = "k1", Value = "v1" });
ch.Writer.Complete();

var results = await resultTask.Within(TimeSpan.FromSeconds(5));
Assert.Single(results);
Assert.Equal("k1", results[0].Key);
Assert.Equal("v1", results[0].Value);
}

[Fact(DisplayName = "FromReader: failure completion should fail the stream with the same exception")]
public async Task FromReader_should_propagate_failure_instead_of_throwing_NRE()
{
var ch = Channel.CreateBounded<Message<string, string>>(new BoundedChannelOptions(8)
{
SingleReader = true,
SingleWriter = true,
FullMode = BoundedChannelFullMode.Wait
});

var src = ChannelSource.FromReader(ch.Reader);

// Materialize to Ignore; we only care that the materialized task faults with our exception
var resultTask = src.RunWith(Sink.Ignore<Message<string, string>>(), _mat);

var boom = new InvalidOperationException("boom");
ch.Writer.TryComplete(boom);

var ex = await Assert.ThrowsAsync<InvalidOperationException>(async () =>
{
await resultTask.Within(TimeSpan.FromSeconds(5));
});
Assert.Equal("boom", ex.Message);
}

[Fact(DisplayName = "FromReader: value type smoke test should not regress")]
public async Task FromReader_should_work_with_value_types()
{
var ch = Channel.CreateBounded<int>(new BoundedChannelOptions(8)
{
SingleReader = true,
SingleWriter = true,
FullMode = BoundedChannelFullMode.Wait
});

var src = ChannelSource.FromReader(ch.Reader);
var resultTask = src.RunWith(Sink.Seq<int>(), _mat);

ch.Writer.TryWrite(42);
ch.Writer.Complete();

var results = await resultTask.Within(TimeSpan.FromSeconds(5));
Assert.Single(results);
Assert.Equal(42, results[0]);
}
}

internal static class TaskTimeoutExtensions
{
/// <summary>
/// Helper to await a Task with a timeout (throws if time is exceeded).
/// </summary>
public static async Task<T> Within<T>(this Task<T> task, TimeSpan timeout)
{
using var cts = new CancellationTokenSource(timeout);
var completed = await Task.WhenAny(task, Task.Delay(Timeout.InfiniteTimeSpan, cts.Token));
if (completed != task)
throw new TimeoutException($"Task did not complete within {timeout}.");
return await task; // unwrap exceptions if any
}

public static async Task Within(this Task task, TimeSpan timeout)
{
using var cts = new CancellationTokenSource(timeout);
var completed = await Task.WhenAny(task, Task.Delay(Timeout.InfiniteTimeSpan, cts.Token));
if (completed != task)
throw new TimeoutException($"Task did not complete within {timeout}.");
await task; // unwrap exceptions if any
}
}
44 changes: 32 additions & 12 deletions src/core/Akka.Streams/Implementation/ChannelSources.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,26 @@
using System.Threading.Tasks;
using Akka.Streams.Stage;

#nullable enable
namespace Akka.Streams.Implementation
{
sealed class ChannelSourceLogic<T> : OutGraphStageLogic
{
private struct ReaderCompleted
{
public ReaderCompleted(Exception? reason)
{
Reason = reason;
}

public Exception? Reason { get; }
}

private readonly Outlet<T> _outlet;
private readonly ChannelReader<T> _reader;
private readonly Action<bool> _onValueRead;
private readonly Action<Exception> _onValueReadFailure;
private readonly Action<Exception> _onReaderComplete;
private readonly Action<ReaderCompleted> _onReaderComplete;
private readonly Action<Task<bool>> _onReadReady;

public ChannelSourceLogic(SourceShape<T> source, Outlet<T> outlet,
Expand All @@ -29,25 +40,35 @@ public ChannelSourceLogic(SourceShape<T> source, Outlet<T> outlet,
_onValueRead = GetAsyncCallback<bool>(OnValueRead);
_onValueReadFailure =
GetAsyncCallback<Exception>(OnValueReadFailure);
_onReaderComplete = GetAsyncCallback<Exception>(OnReaderComplete);
_onReaderComplete = GetAsyncCallback<ReaderCompleted>(OnReaderComplete);
_onReadReady = ContinueAsyncRead;
_reader.Completion.ContinueWith(t =>
{
if (t.IsFaulted) _onReaderComplete(t.Exception);
if (t.IsFaulted) _onReaderComplete(new ReaderCompleted(FlattenException(t.Exception)));
else if (t.IsCanceled)
_onReaderComplete(new TaskCanceledException(t));
else _onReaderComplete(null);
_onReaderComplete(new ReaderCompleted(new TaskCanceledException(t)));
else _onReaderComplete(new ReaderCompleted(null));
});

SetHandler(_outlet, this);
}

private void OnReaderComplete(Exception reason)
private static Exception? FlattenException(Exception? exception)
{
if (reason is null)
if (exception is AggregateException agg)
{
var flat = agg.Flatten();
return flat.InnerExceptions.Count == 1 ? flat.InnerExceptions[0] : exception;
}
return exception;
}

private void OnReaderComplete(ReaderCompleted completion)
{
if (completion.Reason is null)
CompleteStage();
else
FailStage(reason);
FailStage(completion.Reason);
}

private void OnValueReadFailure(Exception reason) => FailStage(reason);
Expand Down Expand Up @@ -84,8 +105,8 @@ public override void OnPull()

private void ContinueAsyncRead(Task<bool> t)
{
if (t.IsFaulted)
_onValueReadFailure(t.Exception);
if (t.IsFaulted)
_onValueReadFailure(t.Exception ?? new Exception("Channel read failed"));
else if (t.IsCanceled)
_onValueReadFailure(new TaskCanceledException(t));
else
Expand Down Expand Up @@ -135,7 +156,6 @@ public override ILogicAndMaterializedValue<ChannelWriter<T>>

internal sealed class ChannelReaderSource<T> : GraphStage<SourceShape<T>>
{

private readonly ChannelReader<T> _reader;

public ChannelReaderSource(ChannelReader<T> reader)
Expand All @@ -152,4 +172,4 @@ protected override GraphStageLogic
CreateLogic(Attributes inheritedAttributes) =>
new ChannelSourceLogic<T>(Shape, Outlet, _reader);
}
}
}
Loading