Skip to content

Commit

Permalink
More Fully Implement ToDefaultSyntax
Browse files Browse the repository at this point in the history
  • Loading branch information
justindbaur committed Nov 28, 2023
1 parent 73ba385 commit bcaac87
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 57 deletions.
12 changes: 6 additions & 6 deletions src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ internal sealed class KnownTypeSymbols
public INamedTypeSymbol? Pretend { get; }
public INamedTypeSymbol? Pretend_Unbound { get; }

public INamedTypeSymbol String { get; }
public INamedTypeSymbol? Task { get; }
public INamedTypeSymbol? TaskOfT { get; }
public INamedTypeSymbol? TaskOfT_Unbound { get; }
public INamedTypeSymbol? ValueTask { get; }
public INamedTypeSymbol? ValueTaskOfT { get; }
public INamedTypeSymbol? ValueTaskOfT_Unbound { get; }



Expand All @@ -26,12 +27,11 @@ public KnownTypeSymbols(CSharpCompilation compilation)
Pretend = compilation.GetTypeByMetadataName("Pretender.Pretend`1");
Pretend_Unbound = Pretend?.ConstructUnboundGenericType();

String = compilation.GetSpecialType(SpecialType.System_String);
Task = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task");
// TODO: Create unbounded?
TaskOfT = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1");
TaskOfT_Unbound = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1")?.ConstructUnboundGenericType();
ValueTask = compilation.GetTypeByMetadataName("System.Threading.Tasks.ValueTask");
// TODO: Create unbounded?
ValueTaskOfT = compilation.GetTypeByMetadataName("System.Threading.Tasks.ValueTask`1");
ValueTaskOfT_Unbound = compilation.GetTypeByMetadataName("System.Threading.Tasks.ValueTask`1")?.ConstructUnboundGenericType();
}

public static bool IsPretend(INamedTypeSymbol type)
Expand Down
82 changes: 41 additions & 41 deletions src/Pretender.SourceGenerator/SetupActionEmitter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,47 +136,47 @@ public InvocationExpressionSyntax CreateSetupGetter(CancellationToken cancellati
ExpressionSyntax defaultValue;

// TODO: Is this safe?
// var namedType = (INamedTypeSymbol)SetupMethod.ReturnType;

// defaultValue = namedType.ToDefaultValueSyntax(_knownTypeSymbols);

if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "Task"]))
{
if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1)
{
// Task.FromResult<T>(default)
defaultValue = KnownBlocks.TaskFromResult(
namedType.TypeArguments[0].AsUnknownTypeSyntax(),
LiteralExpression(SyntaxKind.DefaultLiteralExpression));
}
else
{
// Task.CompletedTask
defaultValue = KnownBlocks.TaskCompletedTask;
}
}
else if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "ValueTask"]))
{
if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1)
{
// ValueTask.FromResult<T>(default)
defaultValue = KnownBlocks.ValueTaskFromResult(
namedType.TypeArguments[0].AsUnknownTypeSyntax(),
LiteralExpression(SyntaxKind.DefaultLiteralExpression)
);
}
else
{
// ValueTask.CompletedTask
defaultValue = KnownBlocks.ValueTaskCompletedTask;
}
}
else
{
// TODO: Support custom awaitable
// default
defaultValue = LiteralExpression(SyntaxKind.DefaultLiteralExpression);
}
var namedType = (INamedTypeSymbol)SetupMethod.ReturnType;

defaultValue = namedType.ToDefaultValueSyntax(_knownTypeSymbols);

//if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "Task"]))
//{
// if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1)
// {
// // Task.FromResult<T>(default)
// defaultValue = KnownBlocks.TaskFromResult(
// namedType.TypeArguments[0].AsUnknownTypeSyntax(),
// LiteralExpression(SyntaxKind.DefaultLiteralExpression));
// }
// else
// {
// // Task.CompletedTask
// defaultValue = KnownBlocks.TaskCompletedTask;
// }
//}
//else if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "ValueTask"]))
//{
// if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1)
// {
// // ValueTask.FromResult<T>(default)
// defaultValue = KnownBlocks.ValueTaskFromResult(
// namedType.TypeArguments[0].AsUnknownTypeSyntax(),
// LiteralExpression(SyntaxKind.DefaultLiteralExpression)
// );
// }
// else
// {
// // ValueTask.CompletedTask
// defaultValue = KnownBlocks.ValueTaskCompletedTask;
// }
//}
//else
//{
// // TODO: Support custom awaitable
// // default
// defaultValue = LiteralExpression(SyntaxKind.DefaultLiteralExpression);
//}

cancellationToken.ThrowIfCancellationRequested();

Expand Down
50 changes: 48 additions & 2 deletions src/Pretender.SourceGenerator/SymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,65 @@ public static TypeSyntax AsUnknownTypeSyntax(this ITypeSymbol type)
public static ExpressionSyntax ToDefaultValueSyntax(this INamedTypeSymbol type, KnownTypeSymbols knownTypeSymbols)
{
// They have explicitly annotated this type as nullable, so return null
if (type.NullableAnnotation == NullableAnnotation.Annotated)
if (type.NullableAnnotation != NullableAnnotation.NotAnnotated)
{
return LiteralExpression(SyntaxKind.DefaultLiteralExpression);
}

var comparer = SymbolEqualityComparer.Default;

if (type.IsUnboundGenericType)
{
throw new NotImplementedException("We believe this should have been impossible, please report this issue with a minimally reproducible sample.");
}

if (type.IsGenericType)
{
var unboundType = type.ConstructUnboundGenericType();

if (comparer.Equals(unboundType, knownTypeSymbols.TaskOfT_Unbound))
{
// Create Task.FromResult();
// TODO: Is this ever an unsafe cast? How could you have Task<anonymous_type>?
var resultType = (INamedTypeSymbol)type.TypeArguments[0];

// Recursion? Issue?
return KnownBlocks.TaskFromResult(resultType.AsUnknownTypeSyntax(), resultType.ToDefaultValueSyntax(knownTypeSymbols));
}

if (comparer.Equals(unboundType, knownTypeSymbols.ValueTaskOfT_Unbound))
{
var resultType = (INamedTypeSymbol)type.TypeArguments[0];

// Recursion!
return KnownBlocks.ValueTaskFromResult(resultType.AsUnknownTypeSyntax(), resultType.ToDefaultValueSyntax(knownTypeSymbols));
}

// TODO: Support IEnumerable, Lists, Arrays, and others
}

if (comparer.Equals(type, knownTypeSymbols.Task))
{
return KnownBlocks.TaskCompletedTask;
}

throw new NotImplementedException();
if (comparer.Equals(type, knownTypeSymbols.ValueTask))
{
return KnownBlocks.ValueTaskCompletedTask;
}

if (comparer.Equals(type, knownTypeSymbols.String))
{
// They have requested not-null so special case non-null string to be string.Empty
// people may not like this.
return MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
PredefinedType(Token(SyntaxKind.StringKeyword)),
IdentifierName("Empty"));
}

// No better default found, just use 'default' even though it might not be in line with nullability annotations
return LiteralExpression(SyntaxKind.DefaultLiteralExpression);
}

public static string ToPretendName(this ITypeSymbol symbol)
Expand Down
37 changes: 34 additions & 3 deletions test/SourceGeneratorTests/MainTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ public partial class MainTests : TestBase
[Fact]
public async Task ReturningMethod()
{
await RunAndCompareAsync($$"""
await RunAndComparePartialAsync($$"""
var pretendSimpleInterface = Pretend.That<ISimpleInterface>();
pretendSimpleInterface
Expand All @@ -16,11 +16,42 @@ await RunAndCompareAsync($$"""
""");
}

[Fact]
public async Task TaskOfTMethod()
{
var (result, _) = await RunGeneratorAsync($$"""
#nullable disable
using System;
using System.Threading.Tasks;
using Pretender;
public interface IMyInterface
{
Task<string> MethodAsync(string str);
}
public class TestClass
{
public TestClass()
{
var pretend = Pretend.That<IMyInterface>();
pretend.Setup(i => i.MethodAsync("Hi"));
}
}
""");

Assert.Equal(2, result.GeneratedSources.Length);

var text1 = result.GeneratedSources[0].SourceText.ToString();
var text2 = result.GeneratedSources[1].SourceText.ToString();
}


[Fact]
public async Task Test2()
{
var (result, compilation) = await RunGeneratorAsync($$"""
var (result, compilation) = await RunPartialGeneratorAsync($$"""
var pretendSimpleInterface = Pretend.That<SimpleAbstractClass>();
var simpleInterface = pretendSimpleInterface.Create();
Expand All @@ -39,7 +70,7 @@ public async Task Test2()
[Fact]
public async Task Test3()
{
var (result, compilation) = await RunGeneratorAsync($$"""
var (result, compilation) = await RunPartialGeneratorAsync($$"""
var pretendSimpleInterface = Pretend.That<ISimpleInterface>();
pretendSimpleInterface
Expand Down
37 changes: 32 additions & 5 deletions test/SourceGeneratorTests/TestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ public SimpleAbstractClass(string arg)
}
""";

public async Task<(GeneratorRunResult GeneratorResult, Compilation UpdateCompilation)> RunGeneratorAsync(string source)
public async Task<(GeneratorRunResult GeneratorResult, Compilation UpdateCompilation)> RunPartialGeneratorAsync(string source)
{
var compilation = await CreateCompilationAsync(source);
var compilation = await CreateCompilationFromPartialAsync(source);

var generator = new PretenderSourceGenerator().AsSourceGenerator();
GeneratorDriver driver = CSharpGeneratorDriver.Create(
Expand All @@ -83,9 +83,27 @@ public SimpleAbstractClass(string arg)
return (Assert.Single(runResult.Results), updateCompilation);
}

public async Task RunAndCompareAsync(string source, [CallerMemberName] string testMethodName = null!)
public async Task<(GeneratorRunResult GeneratorResult, Compilation UpdatedCompilation)> RunGeneratorAsync(string fullSource)
{
var (result, _) = await RunGeneratorAsync(source);
var compilation = await CreateCompilationAsync(fullSource);

var generator = new PretenderSourceGenerator().AsSourceGenerator();
GeneratorDriver driver = CSharpGeneratorDriver.Create(
generators: [generator],
driverOptions: new GeneratorDriverOptions(IncrementalGeneratorOutputKind.None, trackIncrementalGeneratorSteps: true),
parseOptions: ParseOptions
);

driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out var updateCompilation, out var diagnostics);

var runResult = driver.GetRunResult();

return (Assert.Single(runResult.Results), updateCompilation);
}

public async Task RunAndComparePartialAsync(string source, [CallerMemberName] string testMethodName = null!)
{
var (result, _) = await RunPartialGeneratorAsync(source);
Assert.All(result.GeneratedSources, s =>
{
CompareAgainstBaseline(s, testMethodName);
Expand Down Expand Up @@ -124,7 +142,7 @@ private void CompareAgainstBaseline(GeneratedSourceResult result, string testMet
#endif
}

private Task<Compilation> CreateCompilationAsync(string source)
private Task<Compilation> CreateCompilationFromPartialAsync(string source)
{
var fullText = Base + CreateTestTemplate(source);
var project = BaseProject
Expand All @@ -134,6 +152,15 @@ private Task<Compilation> CreateCompilationAsync(string source)
return project.GetCompilationAsync()!;
}

private static Task<Compilation> CreateCompilationAsync(string fullSource)
{
var project = BaseProject
.AddDocument("MyTest.cs", SourceText.From(fullSource, Encoding.UTF8))
.Project;

return project.GetCompilationAsync()!;
}

private static Project CreateProject()
{
var projectName = $"TestProject-{Guid.NewGuid()}";
Expand Down

0 comments on commit bcaac87

Please sign in to comment.