diff --git a/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs b/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs index 25f928e..9e8b7ff 100644 --- a/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs +++ b/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs @@ -11,10 +11,11 @@ internal sealed class KnownTypeSymbols public INamedTypeSymbol? Pretend { get; } public INamedTypeSymbol? Pretend_Unbound { get; } + public INamedTypeSymbol String { get; } public INamedTypeSymbol? Task { get; } - public INamedTypeSymbol? TaskOfT { get; } + public INamedTypeSymbol? TaskOfT_Unbound { get; } public INamedTypeSymbol? ValueTask { get; } - public INamedTypeSymbol? ValueTaskOfT { get; } + public INamedTypeSymbol? ValueTaskOfT_Unbound { get; } @@ -26,12 +27,11 @@ public KnownTypeSymbols(CSharpCompilation compilation) Pretend = compilation.GetTypeByMetadataName("Pretender.Pretend`1"); Pretend_Unbound = Pretend?.ConstructUnboundGenericType(); + String = compilation.GetSpecialType(SpecialType.System_String); Task = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task"); - // TODO: Create unbounded? - TaskOfT = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1"); + TaskOfT_Unbound = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1")?.ConstructUnboundGenericType(); ValueTask = compilation.GetTypeByMetadataName("System.Threading.Tasks.ValueTask"); - // TODO: Create unbounded? - ValueTaskOfT = compilation.GetTypeByMetadataName("System.Threading.Tasks.ValueTask`1"); + ValueTaskOfT_Unbound = compilation.GetTypeByMetadataName("System.Threading.Tasks.ValueTask`1")?.ConstructUnboundGenericType(); } public static bool IsPretend(INamedTypeSymbol type) diff --git a/src/Pretender.SourceGenerator/SetupActionEmitter.cs b/src/Pretender.SourceGenerator/SetupActionEmitter.cs index 4c36ddc..d1f55d2 100644 --- a/src/Pretender.SourceGenerator/SetupActionEmitter.cs +++ b/src/Pretender.SourceGenerator/SetupActionEmitter.cs @@ -136,47 +136,47 @@ public InvocationExpressionSyntax CreateSetupGetter(CancellationToken cancellati ExpressionSyntax defaultValue; // TODO: Is this safe? - // var namedType = (INamedTypeSymbol)SetupMethod.ReturnType; - - // defaultValue = namedType.ToDefaultValueSyntax(_knownTypeSymbols); - - if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "Task"])) - { - if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) - { - // Task.FromResult(default) - defaultValue = KnownBlocks.TaskFromResult( - namedType.TypeArguments[0].AsUnknownTypeSyntax(), - LiteralExpression(SyntaxKind.DefaultLiteralExpression)); - } - else - { - // Task.CompletedTask - defaultValue = KnownBlocks.TaskCompletedTask; - } - } - else if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "ValueTask"])) - { - if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) - { - // ValueTask.FromResult(default) - defaultValue = KnownBlocks.ValueTaskFromResult( - namedType.TypeArguments[0].AsUnknownTypeSyntax(), - LiteralExpression(SyntaxKind.DefaultLiteralExpression) - ); - } - else - { - // ValueTask.CompletedTask - defaultValue = KnownBlocks.ValueTaskCompletedTask; - } - } - else - { - // TODO: Support custom awaitable - // default - defaultValue = LiteralExpression(SyntaxKind.DefaultLiteralExpression); - } + var namedType = (INamedTypeSymbol)SetupMethod.ReturnType; + + defaultValue = namedType.ToDefaultValueSyntax(_knownTypeSymbols); + + //if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "Task"])) + //{ + // if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) + // { + // // Task.FromResult(default) + // defaultValue = KnownBlocks.TaskFromResult( + // namedType.TypeArguments[0].AsUnknownTypeSyntax(), + // LiteralExpression(SyntaxKind.DefaultLiteralExpression)); + // } + // else + // { + // // Task.CompletedTask + // defaultValue = KnownBlocks.TaskCompletedTask; + // } + //} + //else if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "ValueTask"])) + //{ + // if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) + // { + // // ValueTask.FromResult(default) + // defaultValue = KnownBlocks.ValueTaskFromResult( + // namedType.TypeArguments[0].AsUnknownTypeSyntax(), + // LiteralExpression(SyntaxKind.DefaultLiteralExpression) + // ); + // } + // else + // { + // // ValueTask.CompletedTask + // defaultValue = KnownBlocks.ValueTaskCompletedTask; + // } + //} + //else + //{ + // // TODO: Support custom awaitable + // // default + // defaultValue = LiteralExpression(SyntaxKind.DefaultLiteralExpression); + //} cancellationToken.ThrowIfCancellationRequested(); diff --git a/src/Pretender.SourceGenerator/SymbolExtensions.cs b/src/Pretender.SourceGenerator/SymbolExtensions.cs index 6d64153..6fd0eee 100644 --- a/src/Pretender.SourceGenerator/SymbolExtensions.cs +++ b/src/Pretender.SourceGenerator/SymbolExtensions.cs @@ -52,19 +52,65 @@ public static TypeSyntax AsUnknownTypeSyntax(this ITypeSymbol type) public static ExpressionSyntax ToDefaultValueSyntax(this INamedTypeSymbol type, KnownTypeSymbols knownTypeSymbols) { // They have explicitly annotated this type as nullable, so return null - if (type.NullableAnnotation == NullableAnnotation.Annotated) + if (type.NullableAnnotation != NullableAnnotation.NotAnnotated) { return LiteralExpression(SyntaxKind.DefaultLiteralExpression); } var comparer = SymbolEqualityComparer.Default; + if (type.IsUnboundGenericType) + { + throw new NotImplementedException("We believe this should have been impossible, please report this issue with a minimally reproducible sample."); + } + + if (type.IsGenericType) + { + var unboundType = type.ConstructUnboundGenericType(); + + if (comparer.Equals(unboundType, knownTypeSymbols.TaskOfT_Unbound)) + { + // Create Task.FromResult(); + // TODO: Is this ever an unsafe cast? How could you have Task? + var resultType = (INamedTypeSymbol)type.TypeArguments[0]; + + // Recursion? Issue? + return KnownBlocks.TaskFromResult(resultType.AsUnknownTypeSyntax(), resultType.ToDefaultValueSyntax(knownTypeSymbols)); + } + + if (comparer.Equals(unboundType, knownTypeSymbols.ValueTaskOfT_Unbound)) + { + var resultType = (INamedTypeSymbol)type.TypeArguments[0]; + + // Recursion! + return KnownBlocks.ValueTaskFromResult(resultType.AsUnknownTypeSyntax(), resultType.ToDefaultValueSyntax(knownTypeSymbols)); + } + + // TODO: Support IEnumerable, Lists, Arrays, and others + } + if (comparer.Equals(type, knownTypeSymbols.Task)) { return KnownBlocks.TaskCompletedTask; } - throw new NotImplementedException(); + if (comparer.Equals(type, knownTypeSymbols.ValueTask)) + { + return KnownBlocks.ValueTaskCompletedTask; + } + + if (comparer.Equals(type, knownTypeSymbols.String)) + { + // They have requested not-null so special case non-null string to be string.Empty + // people may not like this. + return MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + PredefinedType(Token(SyntaxKind.StringKeyword)), + IdentifierName("Empty")); + } + + // No better default found, just use 'default' even though it might not be in line with nullability annotations + return LiteralExpression(SyntaxKind.DefaultLiteralExpression); } public static string ToPretendName(this ITypeSymbol symbol) diff --git a/test/SourceGeneratorTests/MainTests.cs b/test/SourceGeneratorTests/MainTests.cs index 450d664..f85f2f4 100644 --- a/test/SourceGeneratorTests/MainTests.cs +++ b/test/SourceGeneratorTests/MainTests.cs @@ -5,7 +5,7 @@ public partial class MainTests : TestBase [Fact] public async Task ReturningMethod() { - await RunAndCompareAsync($$""" + await RunAndComparePartialAsync($$""" var pretendSimpleInterface = Pretend.That(); pretendSimpleInterface @@ -16,11 +16,42 @@ await RunAndCompareAsync($$""" """); } + [Fact] + public async Task TaskOfTMethod() + { + var (result, _) = await RunGeneratorAsync($$""" + #nullable disable + using System; + using System.Threading.Tasks; + using Pretender; + + public interface IMyInterface + { + Task MethodAsync(string str); + } + + public class TestClass + { + public TestClass() + { + var pretend = Pretend.That(); + + pretend.Setup(i => i.MethodAsync("Hi")); + } + } + """); + + Assert.Equal(2, result.GeneratedSources.Length); + + var text1 = result.GeneratedSources[0].SourceText.ToString(); + var text2 = result.GeneratedSources[1].SourceText.ToString(); + } + [Fact] public async Task Test2() { - var (result, compilation) = await RunGeneratorAsync($$""" + var (result, compilation) = await RunPartialGeneratorAsync($$""" var pretendSimpleInterface = Pretend.That(); var simpleInterface = pretendSimpleInterface.Create(); @@ -39,7 +70,7 @@ public async Task Test2() [Fact] public async Task Test3() { - var (result, compilation) = await RunGeneratorAsync($$""" + var (result, compilation) = await RunPartialGeneratorAsync($$""" var pretendSimpleInterface = Pretend.That(); pretendSimpleInterface diff --git a/test/SourceGeneratorTests/TestBase.cs b/test/SourceGeneratorTests/TestBase.cs index 163eb7a..d248bc0 100644 --- a/test/SourceGeneratorTests/TestBase.cs +++ b/test/SourceGeneratorTests/TestBase.cs @@ -65,9 +65,9 @@ public SimpleAbstractClass(string arg) } """; - public async Task<(GeneratorRunResult GeneratorResult, Compilation UpdateCompilation)> RunGeneratorAsync(string source) + public async Task<(GeneratorRunResult GeneratorResult, Compilation UpdateCompilation)> RunPartialGeneratorAsync(string source) { - var compilation = await CreateCompilationAsync(source); + var compilation = await CreateCompilationFromPartialAsync(source); var generator = new PretenderSourceGenerator().AsSourceGenerator(); GeneratorDriver driver = CSharpGeneratorDriver.Create( @@ -83,9 +83,27 @@ public SimpleAbstractClass(string arg) return (Assert.Single(runResult.Results), updateCompilation); } - public async Task RunAndCompareAsync(string source, [CallerMemberName] string testMethodName = null!) + public async Task<(GeneratorRunResult GeneratorResult, Compilation UpdatedCompilation)> RunGeneratorAsync(string fullSource) { - var (result, _) = await RunGeneratorAsync(source); + var compilation = await CreateCompilationAsync(fullSource); + + var generator = new PretenderSourceGenerator().AsSourceGenerator(); + GeneratorDriver driver = CSharpGeneratorDriver.Create( + generators: [generator], + driverOptions: new GeneratorDriverOptions(IncrementalGeneratorOutputKind.None, trackIncrementalGeneratorSteps: true), + parseOptions: ParseOptions + ); + + driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out var updateCompilation, out var diagnostics); + + var runResult = driver.GetRunResult(); + + return (Assert.Single(runResult.Results), updateCompilation); + } + + public async Task RunAndComparePartialAsync(string source, [CallerMemberName] string testMethodName = null!) + { + var (result, _) = await RunPartialGeneratorAsync(source); Assert.All(result.GeneratedSources, s => { CompareAgainstBaseline(s, testMethodName); @@ -124,7 +142,7 @@ private void CompareAgainstBaseline(GeneratedSourceResult result, string testMet #endif } - private Task CreateCompilationAsync(string source) + private Task CreateCompilationFromPartialAsync(string source) { var fullText = Base + CreateTestTemplate(source); var project = BaseProject @@ -134,6 +152,15 @@ private Task CreateCompilationAsync(string source) return project.GetCompilationAsync()!; } + private static Task CreateCompilationAsync(string fullSource) + { + var project = BaseProject + .AddDocument("MyTest.cs", SourceText.From(fullSource, Encoding.UTF8)) + .Project; + + return project.GetCompilationAsync()!; + } + private static Project CreateProject() { var projectName = $"TestProject-{Guid.NewGuid()}";