diff --git a/.editorconfig b/.editorconfig index b5f39e6..31b6b5c 100644 --- a/.editorconfig +++ b/.editorconfig @@ -25,7 +25,7 @@ insert_final_newline = false [*.{cs,vb}] # Organize usings -dotnet_separate_import_directive_groups = true +dotnet_separate_import_directive_groups = false dotnet_sort_system_directives_first = true file_header_template = unset @@ -74,6 +74,11 @@ dotnet_code_quality_unused_parameters = all:suggestion # Suppression preferences dotnet_remove_unnecessary_suppression_exclusions = none +dotnet_analyzer_diagnostic.category-Performance.severity = suggestion +dotnet_analyzer_diagnostic.category-Design.severity = suggestion +dotnet_analyzer_diagnostic.category-Maintainability.severity = suggestion +dotnet_analyzer_diagnostic.category-Usage.severity = suggestion + #### C# Coding Conventions #### [*.cs] @@ -171,6 +176,12 @@ csharp_space_between_square_brackets = false csharp_preserve_single_line_blocks = true csharp_preserve_single_line_statements = true +# C# Style preferences +csharp_style_namespace_declarations = block_scoped:silent +csharp_style_prefer_method_group_conversion = true:silent +csharp_style_prefer_top_level_statements = true:silent +csharp_style_prefer_primary_constructors = false:silent + #### Naming styles #### [*.{cs,vb}] @@ -361,4 +372,6 @@ dotnet_naming_style.s_camelcase.required_prefix = s_ dotnet_naming_style.s_camelcase.required_suffix = dotnet_naming_style.s_camelcase.word_separator = dotnet_naming_style.s_camelcase.capitalization = camel_case - +tab_width = 4 +indent_size = 4 +end_of_line = crlf diff --git a/README.md b/README.md index 26b70e0..018aca8 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,12 @@ # Pretender -[![Nuget](https://img.shields.io/nuget/v/Pretender)](https://www.nuget.org/packages/Pretender) +[![NuGet](https://img.shields.io/nuget/v/Pretender)](https://www.nuget.org/packages/Pretender) ## Example ```c# -var pretendMyInterface = Pretend.For(); +var pretendMyInterface = Pretend.That(); pretendMyInterface .Setup(i => i.MyMethod(It.IsAny(), 14)) diff --git a/example/Example.csproj b/example/Example.csproj index 586fe95..d616c20 100644 --- a/example/Example.csproj +++ b/example/Example.csproj @@ -7,7 +7,7 @@ false true true - $(Features);InterceptorsPreview + $(InterceptorsPreviewNamespaces);Pretender.SourceGeneration diff --git a/example/Random.cs b/example/Random.cs index 5c1bf16..274ef9b 100644 --- a/example/Random.cs +++ b/example/Random.cs @@ -8,7 +8,7 @@ public class TestClass { public TestClass() { - var pretend = Pretend.For(); + var pretend = Pretend.That(); pretend.Setup(i => i.Greeting("John", 12)); } } diff --git a/example/UnitTest1.cs b/example/UnitTest1.cs index f4406a4..70546e7 100644 --- a/example/UnitTest1.cs +++ b/example/UnitTest1.cs @@ -7,7 +7,7 @@ public class UnitTest1 [Fact] public void Test1() { - var pretendMyInterface = Pretend.For() + var pretendMyInterface = Pretend.That() .Setup(i => i.Greeting("Mike", It.IsAny())) .Returns("Hi Mike!"); @@ -19,7 +19,7 @@ public void Test1() [Fact] public async Task Test2() { - var pretend = Pretend.For(); + var pretend = Pretend.That(); var local = "Value"; @@ -29,15 +29,25 @@ public async Task Test2() var myOtherInterface = pretend.Create(); + var value = await myOtherInterface.Greeting("Value"); + + pretend.Verify(i => i.Greeting(local), 1); Assert.Equal("Thing", value); } - [Fact] + [Fact] public void Test3() { - var pretend = Pretend.For() + var pretend = Pretend.That() .Setup(i => i.Greeting("Hello", It.IsAny())); + + var item = pretend.Pretend.Create(); + + var response = item.Greeting("Hello", 12); + Assert.Null(response); + + pretend.Verify(1); } } diff --git a/example/example.sln b/example/example.sln new file mode 100644 index 0000000..f42766c --- /dev/null +++ b/example/example.sln @@ -0,0 +1,25 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.5.002.0 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Example", "Example.csproj", "{2FBB56BB-D9CA-4D6F-95DB-9096E555B136}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {2FBB56BB-D9CA-4D6F-95DB-9096E555B136}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2FBB56BB-D9CA-4D6F-95DB-9096E555B136}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2FBB56BB-D9CA-4D6F-95DB-9096E555B136}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2FBB56BB-D9CA-4D6F-95DB-9096E555B136}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {75D71B61-4537-4257-BCDB-BB11AA0B1EF6} + EndGlobalSection +EndGlobal diff --git a/perf/Comparison/Comparison.csproj b/perf/Comparison/Comparison.csproj index b375574..993c8c2 100644 --- a/perf/Comparison/Comparison.csproj +++ b/perf/Comparison/Comparison.csproj @@ -2,7 +2,7 @@ Exe - $(Features);InterceptorsPreview + $(InterceptorsPreviewNamespaces);Pretender.SourceGeneration false diff --git a/perf/Comparison/Simple.cs b/perf/Comparison/Simple.cs index 96190c8..44d1410 100644 --- a/perf/Comparison/Simple.cs +++ b/perf/Comparison/Simple.cs @@ -31,7 +31,7 @@ public string NSubstituteTest() [Benchmark(Baseline = true)] public string PretenderTest() { - var pretend = Pretend.For(); + var pretend = Pretend.That(); pretend.Setup(i => i.Foo(It.Is(static i => i == "1"))) .Returns("2"); diff --git a/pretender.sln b/pretender.sln index 35705fb..dc57272 100644 --- a/pretender.sln +++ b/pretender.sln @@ -15,6 +15,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SourceGeneratorTests", "tes EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{9E33F6C1-8032-4CCA-A485-685B730A6EAE}" ProjectSection(SolutionItems) = preProject + .editorconfig = .editorconfig README.md = README.md EndProjectSection EndProject diff --git a/src/Pretender.SourceGenerator/CreateEntrypoint.cs b/src/Pretender.SourceGenerator/CreateEntrypoint.cs index fd77ee6..2a29749 100644 --- a/src/Pretender.SourceGenerator/CreateEntrypoint.cs +++ b/src/Pretender.SourceGenerator/CreateEntrypoint.cs @@ -1,45 +1,86 @@ -using System; -using System.Collections.Generic; -using System.Text; - +using System.Linq; using Microsoft.CodeAnalysis.CSharp.Syntax; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; using Microsoft.CodeAnalysis.Operations; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using System.Collections.Immutable; namespace Pretender.SourceGenerator { internal class CreateEntrypoint { - public CreateEntrypoint(IInvocationOperation operation) + public CreateEntrypoint(IInvocationOperation operation, ImmutableArray? typeArguments) { Operation = operation; Location = new InterceptsLocationInfo(operation); + TypeArguments = typeArguments; // TODO: Do any Diagnostics? } public InterceptsLocationInfo Location { get; } public IInvocationOperation Operation { get; } + public ImmutableArray? TypeArguments { get; } public MethodDeclarationSyntax GetMethodDeclaration(int index) { var returnType = Operation.TargetMethod.ReturnType; - var returnTypeName = returnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var returnTypeSyntax = returnType.AsUnknownTypeSyntax(); + + TypeParameterSyntax[] typeParameters; + ParameterSyntax[] methodParameters; + ArgumentSyntax[] constructorArguments; + + if (TypeArguments.HasValue) + { + typeParameters = new TypeParameterSyntax[TypeArguments.Value.Length]; - var returnStatement = ReturnStatement(ObjectCreationExpression(ParseTypeName(returnType.ToPretendName())) - .AddArgumentListArguments(Argument(IdentifierName("pretend")))); + // We always take the Pretend argument first as a this parameter + methodParameters = new ParameterSyntax[TypeArguments.Value.Length + 1]; + constructorArguments = new ArgumentSyntax[TypeArguments.Value.Length + 1]; - return MethodDeclaration(ParseTypeName(returnTypeName), $"Create{index}") - .WithBody(Block(returnStatement)) - .WithParameterList(ParameterList(SeparatedList(new[] + for (var i = 0; i < TypeArguments.Value.Length; i++) { - Parameter(Identifier("pretend")) - .WithType(ParseTypeName($"Pretend<{returnType}>")) - .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword))), - }))) + var typeName = $"T{i}"; + var argName = $"arg{i}"; + + typeParameters[i] = TypeParameter(typeName); + methodParameters[i + 1] = Parameter(Identifier(argName)) + .WithType(ParseTypeName(typeName)); + constructorArguments[i + 1] = Argument(IdentifierName(argName)); + } + } + else + { + typeParameters = []; + methodParameters = new ParameterSyntax[1]; + constructorArguments = new ArgumentSyntax[1]; + } + + methodParameters[0] = Parameter(Identifier("pretend")) + .WithType(GenericName("Pretend") + .AddTypeArgumentListArguments(returnTypeSyntax)) + .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword)) + ); + + constructorArguments[0] = Argument(IdentifierName("pretend")); + + var objectCreation = ObjectCreationExpression(ParseTypeName(returnType.ToPretendName())) + .WithArgumentList(ArgumentList(SeparatedList(constructorArguments))); + + var method = MethodDeclaration(returnTypeSyntax, $"Create{index}") + .WithBody(Block(ReturnStatement(objectCreation))) + .WithParameterList(ParameterList(SeparatedList(methodParameters))) .WithModifiers(TokenList(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword))); + + if (typeParameters.Length > 0) + { + return method + .WithTypeParameterList(TypeParameterList(SeparatedList(typeParameters))); + } + + return method; } } @@ -49,7 +90,34 @@ public class CreateEntryPointComparer : IEqualityComparer bool IEqualityComparer.Equals(CreateEntrypoint x, CreateEntrypoint y) { - return SymbolEqualityComparer.Default.Equals(x.Operation.TargetMethod.ReturnType, y.Operation.TargetMethod.ReturnType); + return SymbolEqualityComparer.Default.Equals(x.Operation.TargetMethod.ReturnType, y.Operation.TargetMethod.ReturnType) + && CompareTypeArguments(x.TypeArguments, y.TypeArguments); + } + + static bool CompareTypeArguments(ImmutableArray? x, ImmutableArray? y) + { + if (!x.HasValue) + { + return !y.HasValue; + } + + var xArray = x.Value; + var yArray = y!.Value; + + if (xArray.Length != yArray.Length) + { + return false; + } + + for (int i = 0; i < xArray.Length; i++) + { + if (!SymbolEqualityComparer.IncludeNullability.Equals(xArray[i], yArray[i])) + { + return false; + } + } + + return true; } int IEqualityComparer.GetHashCode(CreateEntrypoint obj) diff --git a/src/Pretender.SourceGenerator/Emitter/CommonSyntax.cs b/src/Pretender.SourceGenerator/Emitter/CommonSyntax.cs new file mode 100644 index 0000000..a8c355c --- /dev/null +++ b/src/Pretender.SourceGenerator/Emitter/CommonSyntax.cs @@ -0,0 +1,38 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Pretender.SourceGenerator.Emitter +{ + internal static class CommonSyntax + { + // General + public static PredefinedTypeSyntax VoidType { get; } = PredefinedType(Token(SyntaxKind.VoidKeyword)); + public static TypeSyntax VarType { get; } = ParseTypeName("var"); + public static GenericNameSyntax GenericPretendType { get; } = GenericName("Pretend"); + public static UsingDirectiveSyntax UsingSystem { get; } = UsingDirective(ParseName("System")); + public static UsingDirectiveSyntax UsingSystemThreadingTasks { get; } = UsingDirective(ParseName("System.Threading.Tasks")); + + // Verify + public static SyntaxToken SetupIdentifier { get; } = Identifier("setup"); + public static SyntaxToken CalledIdentifier { get; } = Identifier("called"); + public static ParameterSyntax CalledParameter { get; } = Parameter(CalledIdentifier) + .WithType(ParseTypeName("Called")); + + public static CompilationUnitSyntax CreateVerifyCompilationUnit(MethodDeclarationSyntax[] verifyMethods) + { + var classDeclaration = ClassDeclaration("VerifyInterceptors") + .AddModifiers(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword)) + .AddMembers(verifyMethods); + + var namespaceDeclaration = NamespaceDeclaration(ParseName("Pretender.SourceGeneration")) + .AddMembers(classDeclaration) + .AddUsings(UsingSystem, KnownBlocks.CompilerServicesUsing, UsingSystemThreadingTasks, KnownBlocks.PretenderUsing, KnownBlocks.PretenderInternalsUsing); + + return CompilationUnit() + .AddMembers(KnownBlocks.InterceptsLocationAttribute, namespaceDeclaration) + .NormalizeWhitespace(); + } + } +} diff --git a/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs b/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs new file mode 100644 index 0000000..41fc2b2 --- /dev/null +++ b/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs @@ -0,0 +1,67 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Pretender.SourceGenerator.Emitter +{ + internal class VerifyEmitter + { + private readonly ITypeSymbol _pretendType; + private readonly ITypeSymbol? _returnType; + private readonly SetupCreationSpec _setupCreationSpec; + private readonly IInvocationOperation _invocationOperation; + + public VerifyEmitter(ITypeSymbol pretendType, ITypeSymbol? returnType, SetupCreationSpec setupCreationSpec, IInvocationOperation invocationOperation) + { + _pretendType = pretendType; + _returnType = returnType; + _setupCreationSpec = setupCreationSpec; + _invocationOperation = invocationOperation; + } + + public MethodDeclarationSyntax EmitVerifyMethod(int index, CancellationToken cancellationToken) + { + var setupGetter = _setupCreationSpec.CreateSetupGetter(cancellationToken); + + // var setup = pretend.GetOrCreateSetup(...); + var setupLocal = LocalDeclarationStatement(VariableDeclaration(CommonSyntax.VarType) + .WithVariables(SingletonSeparatedList(VariableDeclarator(CommonSyntax.SetupIdentifier) + .WithInitializer(EqualsValueClause(setupGetter))))); + + TypeSyntax pretendType = ParseTypeName(_pretendType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + + TypeSyntax setupExpressionType = _returnType == null + ? GenericName("Action").AddTypeArgumentListArguments(pretendType) + : GenericName("Func").AddTypeArgumentListArguments(pretendType, _returnType.AsUnknownTypeSyntax()); + + // setup.Verify(called); + var verifyInvocation = InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(CommonSyntax.SetupIdentifier), + IdentifierName("Verify") + ) + ) + .AddArgumentListArguments(Argument(IdentifierName(CommonSyntax.CalledIdentifier))); + + var interceptsInfo = new InterceptsLocationInfo(_invocationOperation); + + // public static void Verify0( + return MethodDeclaration(CommonSyntax.VoidType, Identifier($"Verify{index}")) + .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)) + .AddParameterListParameters( + [ + Parameter(Identifier("pretend")) + .WithType(CommonSyntax.GenericPretendType.AddTypeArgumentListArguments(pretendType)) + .AddModifiers(Token(SyntaxKind.ThisKeyword)), + Parameter(Identifier("setupExpression")) + .WithType(setupExpressionType), + CommonSyntax.CalledParameter + ]) + .AddBodyStatements([setupLocal, ExpressionStatement(verifyInvocation)]) + .AddAttributeLists(AttributeList(SingletonSeparatedList(interceptsInfo.ToAttributeSyntax()))); + } + } +} diff --git a/src/Pretender.SourceGenerator/IncrementalValuesProviderExtensions.cs b/src/Pretender.SourceGenerator/IncrementalValuesProviderExtensions.cs index 7daa586..041c46a 100644 --- a/src/Pretender.SourceGenerator/IncrementalValuesProviderExtensions.cs +++ b/src/Pretender.SourceGenerator/IncrementalValuesProviderExtensions.cs @@ -1,8 +1,4 @@ -using System; -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Text; - +using System.Collections.Immutable; using Microsoft.CodeAnalysis; namespace Pretender.SourceGenerator diff --git a/src/Pretender.SourceGenerator/InvocationOperationExtensions.cs b/src/Pretender.SourceGenerator/InvocationOperationExtensions.cs index 3112654..eb83d8e 100644 --- a/src/Pretender.SourceGenerator/InvocationOperationExtensions.cs +++ b/src/Pretender.SourceGenerator/InvocationOperationExtensions.cs @@ -1,4 +1,6 @@ -using Microsoft.CodeAnalysis; +using System.Collections.Immutable; + +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Operations; @@ -50,5 +52,32 @@ public static bool IsValidSetupOperation(this IOperation operation, Compilation return false; } + + public static bool IsValidCreateOperation(this IOperation? operation, Compilation compilation, out IInvocationOperation invocationOperation, out ImmutableArray? typeArguments) + { + var pretendGeneric = compilation.GetTypeByMetadataName("Pretender.Pretend`1"); + + if (operation is IInvocationOperation targetOperation + && targetOperation.Instance is not null + && SymbolEqualityComparer.Default.Equals(targetOperation.Instance.Type!.OriginalDefinition, pretendGeneric)) + { + invocationOperation = targetOperation; + if (targetOperation.TargetMethod.Parameters.Length == 1 && targetOperation.TargetMethod.Parameters[0].IsParams) + { + // They are in the params fallback, how lol? + typeArguments = null; + } + else + { + typeArguments = targetOperation.TargetMethod.TypeArguments; + } + + return true; + } + + invocationOperation = null!; + typeArguments = null; + return false; + } } } diff --git a/src/Pretender.SourceGenerator/KnownBlocks.cs b/src/Pretender.SourceGenerator/KnownBlocks.cs index e09abd1..df93fc4 100644 --- a/src/Pretender.SourceGenerator/KnownBlocks.cs +++ b/src/Pretender.SourceGenerator/KnownBlocks.cs @@ -1,11 +1,7 @@ -using System; -using System.Collections.Generic; -using System.Reflection; -using System.Text; - -using Microsoft.CodeAnalysis; +using System.Reflection; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Pretender.SourceGenerator { @@ -13,7 +9,7 @@ internal static class KnownBlocks { private static readonly AssemblyName s_assemblyName = typeof(KnownBlocks).Assembly.GetName(); private static readonly string GeneratedCodeAnnotationString = $@"[GeneratedCode(""{s_assemblyName.Name}"", ""{s_assemblyName.Version}"")]"; - public static MemberDeclarationSyntax InterceptsLocationAttribute { get; } = ((CompilationUnitSyntax)SyntaxFactory.ParseSyntaxTree($$""" + public static MemberDeclarationSyntax InterceptsLocationAttribute { get; } = ((CompilationUnitSyntax)ParseSyntaxTree($$""" namespace System.Runtime.CompilerServices { using System; @@ -31,15 +27,49 @@ public InterceptsLocationAttribute(string filePath, int line, int column) """).GetRoot()).Members[0]; public static NamespaceDeclarationSyntax OurNamespace { get; } - = SyntaxFactory.NamespaceDeclaration(SyntaxFactory.IdentifierName("Pretender.SourceGeneration")); + = NamespaceDeclaration(IdentifierName("Pretender.SourceGeneration")); public static UsingDirectiveSyntax PretenderUsing { get; } - = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("Pretender")); + = UsingDirective(ParseName("Pretender")); public static UsingDirectiveSyntax PretenderInternalsUsing { get; } - = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("Pretender.Internals")); + = UsingDirective(ParseName("Pretender.Internals")); public static UsingDirectiveSyntax CompilerServicesUsing { get; } - = SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("System.Runtime.CompilerServices")); + = UsingDirective(ParseName("System.Runtime.CompilerServices")); + + public static MemberAccessExpressionSyntax TaskCompletedTask = MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("Task"), + IdentifierName("CompletedTask") + ); + + public static MemberAccessExpressionSyntax ValueTaskCompletedTask = MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("ValueTask"), + IdentifierName("CompletedTask") + ); + + public static InvocationExpressionSyntax TaskFromResult(TypeSyntax resultType, ExpressionSyntax resultValue) => InvocationExpression( + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("Task"), + GenericName("FromResult") + .AddTypeArgumentListArguments(resultType)) + ) + .AddArgumentListArguments(Argument(resultValue)); + + public static InvocationExpressionSyntax ValueTaskFromResult(TypeSyntax resultType, ExpressionSyntax resultValue) + { + // ValueTask.FromResult + var memberAccess = MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("ValueTask"), + GenericName("FromResult") + .AddTypeArgumentListArguments(resultType)); + + // ValueTask.FromResult(value) + return InvocationExpression(memberAccess, + ArgumentList( + SingletonSeparatedList(Argument(resultValue)))); + } } } diff --git a/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs b/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs new file mode 100644 index 0000000..c41b05d --- /dev/null +++ b/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs @@ -0,0 +1,38 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; + +namespace Pretender.SourceGenerator.Parser +{ + internal sealed class KnownTypeSymbols + { + public CSharpCompilation Compilation { get; } + + // INamedTypeSymbols + public INamedTypeSymbol? Pretend { get; } + public INamedTypeSymbol? Pretend_Unbound { get; } + + public KnownTypeSymbols(CSharpCompilation compilation) + { + Compilation = compilation; + + // TODO: Get known types + Pretend = compilation.GetTypeByMetadataName("Pretender.Pretend`1"); + Pretend_Unbound = Pretend?.ConstructUnboundGenericType(); + } + + public static bool IsPretend(INamedTypeSymbol type) + { + // This should be enough + return type is + { + Name: "Pretend", + ContainingNamespace: + { + Name: "Pretender", + ContainingNamespace.IsGlobalNamespace: true, + }, + ContainingAssembly.Name: "Pretender", + }; + } + } +} diff --git a/src/Pretender.SourceGenerator/Parser/VerifyParser.cs b/src/Pretender.SourceGenerator/Parser/VerifyParser.cs new file mode 100644 index 0000000..96561f9 --- /dev/null +++ b/src/Pretender.SourceGenerator/Parser/VerifyParser.cs @@ -0,0 +1,45 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Pretender.SourceGenerator.Emitter; +using static Pretender.SourceGenerator.PretenderSourceGenerator; + +namespace Pretender.SourceGenerator.Parser +{ + internal class VerifyParser + { + private readonly VerifyInvocation _verifyInvocation; + private readonly KnownTypeSymbols _knownTypeSymbols; + + public VerifyParser(VerifyInvocation verifyInvocation, CompilationData compilationData) + { + _knownTypeSymbols = compilationData.TypeSymbols!; + _verifyInvocation = verifyInvocation; + } + + public (VerifyEmitter? VerifyEmitter, ImmutableArray? Diagnostics) GetVerifyEmitter(CancellationToken cancellationToken) + { + var operation = _verifyInvocation.Operation; + + // Verify calls are expected to have 2 arguments, the first being the setup expression + var setupArgument = operation.Arguments[0]; + + // Verify calls are expected to be called from Pretend so the type argument gives us the type we are pretending + var pretendType = operation.TargetMethod.ContainingType.TypeArguments[0]; + + // TODO: This doesn't exist yet + var useSetMethod = operation.TargetMethod.Name == "VerifySet"; + + // TODO: This should be done in a Parser type class as well + var setupCreationSpec = new SetupCreationSpec(setupArgument, pretendType, useSetMethod); + + var returnType = setupArgument.Parameter!.Type.Name == "Func" + ? ((INamedTypeSymbol)setupArgument.Parameter.Type).TypeArguments[1] // The Func variant is expected to have the return type in the second type argument + : null; + + var emitter = new VerifyEmitter(pretendType, returnType, setupCreationSpec, _verifyInvocation.Operation); + + // TODO: Get diagnostics from elsewhere + return (emitter, null); + } + } +} diff --git a/src/Pretender.SourceGenerator/PretendEntrypoint.cs b/src/Pretender.SourceGenerator/PretendEntrypoint.cs index e8e96e4..e1c3456 100644 --- a/src/Pretender.SourceGenerator/PretendEntrypoint.cs +++ b/src/Pretender.SourceGenerator/PretendEntrypoint.cs @@ -1,10 +1,5 @@ -using System; -using System.Collections.Generic; -using System.Collections.Immutable; +using System.Collections.Immutable; using System.Diagnostics; -using System.Linq; -using System.Threading; - using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -39,13 +34,10 @@ public PretendEntrypoint(ITypeSymbol typeToPretend, Location invocationLocation) invocationLocation, TypeToPretend)); } - - PretendName = TypeToPretend.ToPretendName(); } public ITypeSymbol TypeToPretend { get; } public Location InvocationLocation { get; } - public string PretendName { get; } public List Diagnostics { get; } = new List(); public CompilationUnitSyntax GetCompilationUnit(CancellationToken token) @@ -165,16 +157,23 @@ public CompilationUnitSyntax GetCompilationUnit(CancellationToken token) Trivia(NullableDirectiveTrivia(Token(SyntaxKind.EnableKeyword), true)), Comment("/// ")); + var sourceGenerationNamespace = KnownBlocks.OurNamespace + .AddMembers(classDeclaration.WithInheritDoc()) + .AddUsings( + UsingDirective(IdentifierName("System.Reflection")), + KnownBlocks.PretenderUsing + ); + return CompilationUnit() - .AddMembers(classDeclaration.WithLeadingTrivia(leadingTrivia)) + .AddMembers(sourceGenerationNamespace) .WithLeadingTrivia(leadingTrivia) .NormalizeWhitespace(); } private static FieldDeclarationSyntax CreateMethodInfoField(IMethodSymbol method, ExpressionSyntax expressionSyntax) { - // public static readonly MethodInfo_name_4B2 = !; - return FieldDeclaration(VariableDeclaration(ParseTypeName("global::System.Reflection.MethodInfo"))) + // public static readonly MethodInfo MethodInfo_name_4B2 = !; + return FieldDeclaration(VariableDeclaration(ParseTypeName("MethodInfo"))) .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)) .AddDeclarationVariables(VariableDeclarator(Identifier(method.ToMethodInfoCacheName())) .WithInitializer(EqualsValueClause( @@ -223,14 +222,14 @@ private ParameterSyntax CreateConstructorParameter() private TypeSyntax GetGenericPretendType() { - return GenericName(Identifier("global::Pretender.Pretend"), + return GenericName(Identifier("Pretend"), TypeArgumentList(SingletonSeparatedList(ParseTypeName(TypeToPretend.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))))); } private FieldDeclarationSyntax GetStaticMethodCacheField(IMethodSymbol method, int index) { // TODO: Get method info via argument types - return FieldDeclaration(VariableDeclaration(ParseTypeName("global::System.Reflection.MethodInfo"))) + return FieldDeclaration(VariableDeclaration(ParseTypeName("MethodInfo"))) .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)) .AddDeclarationVariables(VariableDeclarator(Identifier($"__methodInfo_{method.Name}_{index}")) .WithInitializer(EqualsValueClause( @@ -242,20 +241,28 @@ private BlockSyntax CreateMethodBody(IMethodSymbol method) { var methodBodyStatements = new List(); + // This is using the new collection expression syntax in C# 12 + // [arg1, arg2, arg3] var collectionExpression = CollectionExpression() - .AddElements(method.Parameters.Select(p => - { - return ExpressionElement(IdentifierName(p.Name)); - }).ToArray()); + .AddElements(method.Parameters.Select(p + => ExpressionElement(IdentifierName(p.Name))).ToArray()); - // ReadOnlySpan arguments = [arg0, arg1]; + // object?[] + var typeSyntax = ArrayType(NullableType(PredefinedType(Token(SyntaxKind.ObjectKeyword)))) + .WithRankSpecifiers(SingletonList(ArrayRankSpecifier())); + var argumentsIdentifier = IdentifierName("__arguments"); + var callInfoIdentifier = IdentifierName("__callInfo"); + + // I'm not currently able to use Span because I have to store CallInfo for late Setup/Verify + // but I don't want to delete this code in case this becomes possible or I don't want to support that + // Span + //var typeSyntax = GenericName("Span").AddTypeArgumentListArguments(NullableType(PredefinedType(Token(SyntaxKind.ObjectKeyword)))); + + // object? [] arguments = [arg0, arg1]; var argumentsDeclaration = LocalDeclarationStatement( - VariableDeclaration( - // ReadOnlySpan - GenericName("Span").AddTypeArgumentListArguments(NullableType(PredefinedType(Token(SyntaxKind.ObjectKeyword)))) - ) - .AddVariables(VariableDeclarator("arguments") + VariableDeclaration(typeSyntax) + .AddVariables(VariableDeclarator(argumentsIdentifier.Identifier) .WithInitializer(EqualsValueClause(collectionExpression)) ) ); @@ -265,18 +272,20 @@ private BlockSyntax CreateMethodBody(IMethodSymbol method) // var callInfo = new CallInfo(__methodInfo_MethodName_0, arguments); var callInfoCreation = LocalDeclarationStatement( VariableDeclaration(IdentifierName("var")) - .AddVariables(VariableDeclarator(Identifier("callInfo")) - .WithInitializer(EqualsValueClause(ObjectCreationExpression(ParseTypeName("global::Pretender.CallInfo")) - .AddArgumentListArguments(Argument(IdentifierName(method.ToMethodInfoCacheName())), Argument(IdentifierName("arguments"))))))); + .AddVariables(VariableDeclarator(callInfoIdentifier.Identifier) + .WithInitializer(EqualsValueClause(ObjectCreationExpression(ParseTypeName("CallInfo")) + .AddArgumentListArguments(Argument(IdentifierName(method.ToMethodInfoCacheName())), Argument(argumentsIdentifier)))))); methodBodyStatements.Add(callInfoCreation); + // TODO: Call inner implementations when we support them + // _pretend.Handle(callInfo); var handleCall = ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("_pretend"), IdentifierName("Handle"))) .WithArgumentList(ArgumentList(SingletonSeparatedList( - Argument(IdentifierName("callInfo")).WithRefKindKeyword(Token(SyntaxKind.RefKeyword)))))); + Argument(callInfoIdentifier))))); methodBodyStatements.Add(handleCall); @@ -285,6 +294,8 @@ private BlockSyntax CreateMethodBody(IMethodSymbol method) var refAndOutParameters = method.Parameters .Where(p => p.RefKind == RefKind.Ref || p.RefKind == RefKind.Out); + + foreach (var p in refAndOutParameters) { // assign them to the values from arguments @@ -292,7 +303,7 @@ private BlockSyntax CreateMethodBody(IMethodSymbol method) SyntaxKind.SimpleAssignmentExpression, IdentifierName(p.Name), ElementAccessExpression( - IdentifierName("arguments"), + argumentsIdentifier, BracketedArgumentList(SingletonSeparatedList( Argument(LiteralExpression( SyntaxKind.NumericLiteralExpression, @@ -305,8 +316,8 @@ private BlockSyntax CreateMethodBody(IMethodSymbol method) if (method.ReturnType.SpecialType != SpecialType.System_Void) { var returnStatement = ReturnStatement(CastExpression( - ParseTypeName(method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("callInfo"), IdentifierName("ReturnValue")))); + method.ReturnType.AsUnknownTypeSyntax(), + MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, callInfoIdentifier, IdentifierName("ReturnValue")))); methodBodyStatements.Add(returnStatement); } diff --git a/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs b/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs index 0e95ed6..90fcac2 100644 --- a/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs +++ b/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs @@ -1,11 +1,10 @@ -using System.Collections.Generic; -using System.Collections.Immutable; -using System.Linq; +using System.Collections.Immutable; using System.Text; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; +using Pretender.SourceGenerator.Emitter; +using Pretender.SourceGenerator.Parser; namespace Pretender.SourceGenerator { @@ -14,20 +13,29 @@ public class PretenderSourceGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { + // TODO: Refactor our region use + // TODO: Create compilation data + IncrementalValueProvider compilationData = + context.CompilationProvider + .Select((compilation, _) => compilation.Options is CSharpCompilationOptions + ? new CompilationData((CSharpCompilation)compilation) + : null); + + #region Pretend IncrementalValuesProvider pretendsWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider( predicate: static (node, token) => { - // Pretend.For(); + // Pretend.That(); if (node is InvocationExpressionSyntax { Expression: MemberAccessExpressionSyntax { // TODO: Will this work with a using static Pretender.Pretend // ... - // For(); + // That(); Expression: IdentifierNameSyntax { Identifier.ValueText: "Pretend" }, - Name: GenericNameSyntax { Identifier.ValueText: "For", TypeArgumentList.Arguments.Count: 1 }, + Name: GenericNameSyntax { Identifier.ValueText: "That", TypeArgumentList.Arguments.Count: 1 }, }, }) { @@ -70,9 +78,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterSourceOutput(pretends, static (context, pretend) => { var compilationUnit = pretend.Source.GetCompilationUnit(context.CancellationToken); - context.AddSource($"Pretender.Type.{pretend.Source.PretendName}.g.cs", compilationUnit.GetText(Encoding.UTF8)); + context.AddSource($"Pretender.Type.{pretend.Source.TypeToPretend.ToPretendName()}.g.cs", compilationUnit.GetText(Encoding.UTF8)); }); + #endregion + #region Setup var setupCallsWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider( predicate: static (node, _) => node.IsSetupCall(), @@ -135,7 +145,62 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.AddSource("Pretender.Setups.g.cs", compilationUnit.GetText(Encoding.UTF8)); }); + #endregion + + #region Verify + IncrementalValuesProvider<(VerifyEmitter? Emitter, ImmutableArray? Diagnostics)> verifyCallsWithDiagnostics = + context.SyntaxProvider.CreateSyntaxProvider( + predicate: (node, _) => VerifyInvocation.IsCandidateSyntaxNode(node), + transform: VerifyInvocation.Create) + .Where(vi => vi is not null) + .Combine(compilationData) + .Select((tuple, cancellationToken) => + { + if (tuple.Right is not CompilationData compilationData) + { + return (null, null); + } + + // Create new VerifySpec + var parser = new VerifyParser(tuple.Left!, compilationData); + + return parser.GetVerifyEmitter(cancellationToken); + }) + .WithTrackingName("Verify"); + // TODO: Register diagnostics + context.RegisterSourceOutput(verifyCallsWithDiagnostics.Collect(), (context, inputs) => + { + var methods = new List(); + for ( var i = 0; i < inputs.Length; i++) + { + var input = inputs[i]; + if (input.Diagnostics is ImmutableArray diagnostics) + { + foreach (var diagnostic in diagnostics) + { + context.ReportDiagnostic(diagnostic); + } + } + + if (input.Emitter is VerifyEmitter emitter) + { + // TODO: Emit VerifyMethod + var method = emitter.EmitVerifyMethod(0, context.CancellationToken); + methods.Add(method); + } + } + + if (methods.Count > 0) + { + // Emit all methods + var compilationUnit = CommonSyntax.CreateVerifyCompilationUnit([.. methods]); + context.AddSource("Pretender.Verifies.g.cs", compilationUnit.GetText(Encoding.UTF8)); + } + }); + #endregion + + #region Create var createCalls = context.SyntaxProvider.CreateSyntaxProvider( predicate: (node, token) => { @@ -145,7 +210,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { Name.Identifier.ValueText: "Create" }, - ArgumentList.Arguments.Count: 0 } ) { @@ -157,12 +221,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context) transform: (context, token) => { var operation = context.SemanticModel.GetOperation(context.Node); - var invocationOperation = (IInvocationOperation?)operation; - - if (invocationOperation?.Instance is not null) + if (operation.IsValidCreateOperation(context.SemanticModel.Compilation, out var invocation, out var typeArguments)) { - // TODO: Do more validation, we should match the type this is being done on. - return new CreateEntrypoint(invocationOperation); + return new CreateEntrypoint(invocation, typeArguments); } return null; @@ -189,6 +250,26 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.AddSource($"Pretender.Creates.{createCalls.Source.Operation.TargetMethod.ReturnType.ToPretendName()}.g.cs", cu.GetText(Encoding.UTF8)); }); + #endregion + } + + internal sealed class CompilationData + { + public bool LanguageVersionIsSupported { get; } + public KnownTypeSymbols? TypeSymbols { get; } + + public CompilationData(CSharpCompilation compilation) + { + // We don't have a CSharp12 value available yet. Polyfill the value here for forward compat, rather than use the LanguageVersion.Preview enum value. + // https://github.com/dotnet/roslyn/blob/168689931cb4e3150641ec2fb188a64ce4b3b790/src/Compilers/CSharp/Portable/LanguageVersion.cs#L218-L232 + const int LangVersion_CSharp12 = 1200; + LanguageVersionIsSupported = (int)compilation.LanguageVersion >= LangVersion_CSharp12; + + if (LanguageVersionIsSupported) + { + TypeSymbols = new KnownTypeSymbols(compilation); + } + } } } } diff --git a/src/Pretender.SourceGenerator/ScaffoldTypeOptions.cs b/src/Pretender.SourceGenerator/ScaffoldTypeOptions.cs index f4d73b7..113caa3 100644 --- a/src/Pretender.SourceGenerator/ScaffoldTypeOptions.cs +++ b/src/Pretender.SourceGenerator/ScaffoldTypeOptions.cs @@ -8,7 +8,7 @@ namespace Pretender.SourceGenerator; public class ScaffoldTypeOptions { - public ImmutableArray CustomFields { get; set; } = []; + public ImmutableArray CustomFields { get; set; } = default; public Func AddMethodBody { get; set; } = (_) => Block(); // TODO: Is there a better symbol for constructors, methods? diff --git a/src/Pretender.SourceGenerator/SetupArgument.cs b/src/Pretender.SourceGenerator/SetupArgument.cs deleted file mode 100644 index f2b0773..0000000 --- a/src/Pretender.SourceGenerator/SetupArgument.cs +++ /dev/null @@ -1,424 +0,0 @@ -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using Microsoft.CodeAnalysis.Operations; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using System.Diagnostics; -using System; -using System.Linq; -using System.Collections.Immutable; - -namespace Pretender.SourceGenerator -{ - internal class ArgumentTracker - { - private readonly List _neededLocals = new(); - private readonly Stack> _trackedLocals = new(); - - public ArgumentTracker() - { - _trackedLocals = new Stack>(); - _trackedLocals.Push(new HashSet(SymbolEqualityComparer.Default)); - } - - public ImmutableArray NeededLocals => _neededLocals.ToImmutableArray(); - public bool NeedsCapturer { get; private set; } - public void SetNeedsCapturer() - { - NeedsCapturer = true; - } - - public bool TryTrackLocal(ILocalReferenceOperation localReferenceOperation) - { - var currentScope = _trackedLocals.Peek(); - if (currentScope.Contains(localReferenceOperation.Local)) - { - // This is being tracked as created during the current scope, ignore it - return false; - } - - _neededLocals.Add(localReferenceOperation); - return true; - } - - public void LocalDefined(ILocalSymbol local) - { - var currentScope = _trackedLocals.Peek(); - currentScope.Add(local); - } - - public void LocalsDefined(IEnumerable locals) - { - var currentScope = _trackedLocals.Peek(); - foreach (var local in locals) - { - currentScope.Add(local); - } - } - - // TODO: could create an IDisposable for this - public void EnterScope() - { - _trackedLocals.Push(new(SymbolEqualityComparer.Default)); - } - - public void ExitScope() - { - _trackedLocals.Pop(); - } - } - - - internal class SetupArgument - { - private static readonly IdentifierNameSyntax CallInfoIdentifier = IdentifierName("callInfo"); - private static readonly IdentifierNameSyntax ArgumentsPropertyIdentifier = IdentifierName("Arguments"); - - private readonly int _index; - - public SetupArgument(IArgumentOperation argumentOperation, int index, List diagnostics) - { - var argOperationValue = argumentOperation.Value; - var tracker = new ArgumentTracker(); - // Walk the operation tree to find all locals - Visit(argOperationValue, tracker); - - RequiredLocals = tracker.NeededLocals; - NeedsCapturer = tracker.NeedsCapturer; - - ArgumentOperation = argumentOperation; - _index = index; - } - - - public ImmutableArray RequiredLocals { get; } - - public IArgumentOperation ArgumentOperation { get; } - - public bool NeedsCapturer { get; } - public bool IsLiteral => ArgumentOperation.Value is ILiteralOperation; - public bool IsInvocation => ArgumentOperation.Value is IInvocationOperation; - public bool IsLocalReference => ArgumentOperation.Value is ILocalReferenceOperation; - - public ITypeSymbol ParameterType => ArgumentOperation.Parameter!.Type; - public string ArgumentLocalName => $"{ArgumentOperation.Parameter!.Name}_arg"; - - - public LocalDeclarationStatementSyntax EmitArgumentAccessor() - { - // (string?)callInfo.Arguments[index]; - ExpressionSyntax argumentGetter = CastExpression( - ParameterType.AsUnknownTypeSyntax(), - ElementAccessExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, CallInfoIdentifier, ArgumentsPropertyIdentifier)) - .AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(_index)))) - ); - - // var name_arg = (string?)callInfo.Arguments[0]; - return LocalDeclarationStatement(VariableDeclaration(ParseTypeName("var")) - .AddVariables(VariableDeclarator(ArgumentLocalName) - .WithInitializer(EqualsValueClause(argumentGetter)))); - } - - public bool TryEmitInvocationStatements(out StatementSyntax[] statements) - { - Debug.Assert(IsInvocation, "Should have been asserted already."); - var invocationOperation = (IInvocationOperation)ArgumentOperation.Value; - if (TryGetMatcherAttributeType(invocationOperation, out var matcherType)) - { - // AnyMatcher - if (matcherType.EqualsByName(["Pretender", "Matchers", "AnyMatcher"])) - { - statements = []; - return true; - } - - var arguments = new ArgumentSyntax[invocationOperation.Arguments.Length]; - bool allArgumentsSafe = true; - - for ( int i = 0; i < arguments.Length; i++) - { - var arg = invocationOperation.Arguments[i]; - if (arg.Value is ILiteralOperation literalOperation) - { - arguments[i] = Argument(literalOperation.ToLiteralExpression()); - } - else if (arg.Value is IDelegateCreationOperation delegateCreation) - { - - if (delegateCreation.Target is IAnonymousFunctionOperation anonymousFunctionOperation) - { - if (anonymousFunctionOperation.Symbol.IsStatic) // This isn't enough either though, they could call a static method that only exists in their context - { - // If it's guaranteed to be static, we can just rewrite it in our code - arguments[i] = Argument(ParseExpression(delegateCreation.Syntax.GetText().ToString())); - } - else if (false) // Is non-scope capturing - { - // This is a lot more work but also very powerful in terms of speed - // We need to rewrite the delegate and replace all local references with our getter - allArgumentsSafe = false; - } - else - { - // We need a static matcher - allArgumentsSafe = false; - } - } - else - { - allArgumentsSafe = false; - } - } - else - { - allArgumentsSafe = false; - } - } - - if (!allArgumentsSafe) - { - statements = []; - return false; - } - - statements = new StatementSyntax[3]; - statements[0] = EmitArgumentAccessor(); - - var matcherLocalName = $"{ArgumentOperation.Parameter!.Name}_matcher"; - - // var name_matcher = new global::MyMatcher(arg0, arg1); - statements[1] = LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables(VariableDeclarator(matcherLocalName) - .WithInitializer(EqualsValueClause(ObjectCreationExpression(ParseTypeName(matcherType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))) - .AddArgumentListArguments(arguments)) - ) - ) - ); - - statements[2] = CreateArgumentCheck( - PrefixUnaryExpression( - SyntaxKind.LogicalNotExpression, - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(matcherLocalName), - IdentifierName("Matches") - ) - ) - .AddArgumentListArguments(Argument(IdentifierName(ArgumentLocalName))) - ) - ); - - return true; - } - else - { - // TODO: Setup static listener - statements = []; - return false; - } - } - - // Returns true if the visited operation captured a local - private static bool Visit(IOperation? operation, ArgumentTracker tracker) - { - if (operation == null) - { - return false; - } - - // TODO: Handle most operations - switch (operation.Kind) - { - case OperationKind.Block: - var block = (IBlockOperation)operation; - return VisitMany(block.Operations, tracker); - case OperationKind.VariableDeclarationGroup: - var variableDeclarationGroup = (IVariableDeclarationGroupOperation)operation; - return VisitMany(variableDeclarationGroup.Declarations, tracker); - - case OperationKind.Return: - var returnOp = (IReturnOperation)operation; - return Visit(returnOp.ReturnedValue, tracker); - case OperationKind.Literal: - // Literals are the best, they are easy and the end of the line - return false; - case OperationKind.Invocation: - var invocation = (IInvocationOperation)operation; - // The instance could be a local itself - return Visit(invocation.Instance, tracker) - | VisitMany(invocation.Arguments, tracker); - case OperationKind.LocalReference: - var local = (ILocalReferenceOperation)operation; - tracker.TryTrackLocal(local); - return true; - case OperationKind.ParameterReference: - return false; - case OperationKind.Binary: - var binary = (IBinaryOperation)operation; - return Visit(binary.LeftOperand, tracker) | Visit(binary.RightOperand, tracker); - case OperationKind.AnonymousFunction: - // TODO: I'm not sure if this belongs in here or DelegateCreation but lets go with here for now - tracker.EnterScope(); - var anonymousFunction = (IAnonymousFunctionOperation)operation; - var found = Visit(anonymousFunction.Body, tracker); - tracker.ExitScope(); - return found; - case OperationKind.DelegateCreation: - var delegateCreation = (IDelegateCreationOperation)operation; - // TODO: Now that we are in a delegate should we ignore their locals somehow? - return Visit(delegateCreation.Target, tracker); - case OperationKind.VariableInitializer: - var variableInitializer = (IVariableInitializerOperation)operation; - tracker.LocalsDefined(variableInitializer.Locals); - // TODO: Not sure if this is right - Visit(variableInitializer.Value, tracker); - return true; - case OperationKind.VariableDeclaration: - var variableDeclaration = (IVariableDeclarationOperation)operation; - return VisitMany(variableDeclaration.Declarators, tracker) - | Visit(variableDeclaration.Initializer, tracker); - case OperationKind.VariableDeclarator: - var variableDeclarator = (IVariableDeclaratorOperation)operation; - tracker.LocalDefined(variableDeclarator.Symbol); - // TODO: IgnoredArguments property? - return Visit(variableDeclarator.Initializer, tracker); - - - case OperationKind.Argument: - var argument = (IArgumentOperation)operation; - return Visit(argument.Value, tracker); - - - } - - throw new NotImplementedException($"Can't visit operation '{operation.Kind}'"); - } - - private static bool VisitMany(IEnumerable operations, ArgumentTracker tracker) - { - var foundLocal = false; - foreach (var operation in operations) - { - if (Visit(operation, tracker)) - { - foundLocal = true; - } - } - - return foundLocal; - } - - private bool TryGetMatcherAttributeType(IInvocationOperation invocationOperation, out INamedTypeSymbol matcherType) - { - var allAttributes = invocationOperation.TargetMethod.GetAttributes(); - var matcherAttribute = allAttributes.Single(ad => ad.AttributeClass!.EqualsByName(["Pretender", "Matchers", "MatcherAttribute"])); - - matcherType = null!; - - if (matcherAttribute.AttributeClass!.IsGenericType) - { - // We are in the typed version, get the generic arg - matcherType = (INamedTypeSymbol)matcherAttribute.AttributeClass.TypeArguments[0]; - } - else - { - // We are in the base version, get the constructor arg - // TODO: Make this work - // matcherType = matcherAttribute.ConstructorArguments[0]; - var attributeType = matcherAttribute.ConstructorArguments[0]; - // TODO: When can Type be null? - if (!attributeType.Type!.EqualsByName(["System", "Type"])) - { - return false; - } - - if (attributeType.Value is null) - { - return false; - } - - matcherType = (INamedTypeSymbol)attributeType.Value!; - } - - // TODO: Write a lot more tests for this - if (matcherType.IsUnboundGenericType) - { - if (invocationOperation.TargetMethod.TypeArguments.Length != matcherType.TypeArguments.Length) - { - return false; - } - - matcherType = matcherType.ConstructedFrom.Construct([.. invocationOperation.TargetMethod.TypeArguments]); - } - - return true; - } - - public StatementSyntax[] EmitLocalIfCheck(int index) - { - Debug.Assert(IsLocalReference, "Shouldn't have been called."); - - var localOperation = (ILocalReferenceOperation)ArgumentOperation.Value; - - var variableName = $"{ArgumentOperation.Parameter!.Name}_local"; - - var statements = new StatementSyntax[3]; - statements[0] = EmitArgumentAccessor(); - - // This is for calling the UnsafeAccessor method that doesn't seem to work for my needs - //statements[1] = LocalDeclarationStatement(VariableDeclaration(ParseTypeName("var")) - // .AddVariables(VariableDeclarator(variableName) - // .WithInitializer(EqualsValueClause(InvocationExpression( - // MemberAccessExpression( - // SyntaxKind.SimpleMemberAccessExpression, - // IdentifierName($"Setup{index}Accessor"), - // IdentifierName(((ILocalReferenceOperation)ArgumentOperation.Value).Local.Name) - // ) - // ) - // .AddArgumentListArguments(Argument(IdentifierName("target"))))))); - - - //statements[1] = LocalDeclarationStatement(VariableDeclaration(localOperation.Local.Type.AsUnknownTypeSyntax()) - // .AddVariables(VariableDeclarator(variableName) - // .WithInitializer(EqualsValueClause( - // MemberAccessExpression( - // SyntaxKind.SimpleMemberAccessExpression, - // ParenthesizedExpression(CastExpression(ParseTypeName("dynamic"), IdentifierName("target"))), - // IdentifierName(localOperation.Local.Name)))) - // ) - // ); - - // var arg_local = target.GetType().GetField("local").GetValue(target); - - // TODO: This really sucks, but neither other way works - statements[1] = ExpressionStatement( - ParseExpression($"var {variableName} = target.GetType().GetField(\"{localOperation.Local.Name}\")!.GetValue(target)") - ); - - statements[2] = EmitIfCheck(IdentifierName(variableName)); - - return statements; - } - - public IfStatementSyntax EmitIfCheck(ExpressionSyntax right) - { - var binaryExpression = BinaryExpression( - SyntaxKind.NotEqualsExpression, - IdentifierName(ArgumentLocalName), - right - ); - - return CreateArgumentCheck(binaryExpression); - } - - private static IfStatementSyntax CreateArgumentCheck(ExpressionSyntax condition) - { - return IfStatement(condition, Block( - ReturnStatement(LiteralExpression(SyntaxKind.FalseLiteralExpression)) - ) - ); - } - } -} diff --git a/src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs b/src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs new file mode 100644 index 0000000..cc3e3fc --- /dev/null +++ b/src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs @@ -0,0 +1,186 @@ +using System.Collections.Immutable; +using System.Diagnostics; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Pretender.SourceGenerator.SetupArguments +{ + internal class InvocationArgumentSpec : SetupArgumentSpec + { + private readonly IInvocationOperation _invocationOperation; + + public InvocationArgumentSpec( + IInvocationOperation invocationOperation, + IArgumentOperation originalArgument, + int argumentPlacement) + : base(originalArgument, argumentPlacement) + { + _invocationOperation = invocationOperation; + } + + private ImmutableArray? _cachedMatcherStatements; + + public override int NeededMatcherStatements + { + get + { + if (TryGetMatcherAttributeType(out var matcherType)) + { + if (matcherType.EqualsByName(["Pretender", "Matchers", "AnyMatcher"])) + { + _cachedMatcherStatements = ImmutableArray.Empty; + return 0; + } + + var arguments = new ArgumentSyntax[_invocationOperation.Arguments.Length]; + bool allArgumentsSafe = true; + + for (int i = 0; i < arguments.Length; i++) + { + var arg = _invocationOperation.Arguments[i]; + if (arg.Value is ILiteralOperation literalOperation) + { + arguments[i] = Argument(literalOperation.ToLiteralExpression()); + } + else if (arg.Value is IDelegateCreationOperation delegateCreation) + { + + if (delegateCreation.Target is IAnonymousFunctionOperation anonymousFunctionOperation) + { + if (anonymousFunctionOperation.Symbol.IsStatic) // This isn't enough either though, they could call a static method that only exists in their context + { + // If it's guaranteed to be static, we can just rewrite it in our code + arguments[i] = Argument(ParseExpression(delegateCreation.Syntax.GetText().ToString())); + } + else if (false) // Is non-scope capturing + { + // This is a lot more work but also very powerful in terms of speed + // We need to rewrite the delegate and replace all local references with our getter + allArgumentsSafe = false; + } + else + { + // We need a static matcher + allArgumentsSafe = false; + } + } + else + { + allArgumentsSafe = false; + } + } + else + { + allArgumentsSafe = false; + } + } + + if (!allArgumentsSafe) + { + _cachedMatcherStatements = ImmutableArray.Empty; + return 0; + } + + var statements = new StatementSyntax[3]; + var (identifier, accessor) = CreateArgumentAccessor(); + statements[0] = accessor; + + var matcherLocalName = $"{Parameter.Name}_matcher"; + + // var name_matcher = new global::MyMatcher(arg0, arg1); + statements[1] = LocalDeclarationStatement( + VariableDeclaration(ParseTypeName("var")) + .AddVariables(VariableDeclarator(matcherLocalName) + .WithInitializer(EqualsValueClause(ObjectCreationExpression(ParseTypeName(matcherType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))) + .AddArgumentListArguments(arguments)) + ) + ) + ); + + statements[2] = CreateIfCheck( + PrefixUnaryExpression( + SyntaxKind.LogicalNotExpression, + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(matcherLocalName), + IdentifierName("Matches") + ) + ) + .AddArgumentListArguments(Argument(IdentifierName(identifier))) + ) + ); + + _cachedMatcherStatements = ImmutableArray.Create(statements); + return 3; + } + else + { + _cachedMatcherStatements = ImmutableArray.Empty; + return 0; + } + } + } + + public override ImmutableArray CreateMatcherStatements(CancellationToken cancellationToken) + { + Debug.Assert(_cachedMatcherStatements.HasValue, "Should have called NeededStatements first."); + return _cachedMatcherStatements!.Value; + } + + private bool TryGetMatcherAttributeType(out INamedTypeSymbol matcherType) + { + var allAttributes = _invocationOperation.TargetMethod.GetAttributes(); + var matcherAttribute = allAttributes.Single(ad => ad.AttributeClass!.EqualsByName(["Pretender", "Matchers", "MatcherAttribute"])); + + matcherType = null!; + + if (matcherAttribute.AttributeClass!.IsGenericType) + { + // We are in the typed version, get the generic arg + matcherType = (INamedTypeSymbol)matcherAttribute.AttributeClass.TypeArguments[0]; + } + else + { + // We are in the base version, get the constructor arg + // TODO: Make this work + // matcherType = matcherAttribute.ConstructorArguments[0]; + var attributeType = matcherAttribute.ConstructorArguments[0]; + // TODO: When can Type be null? + if (!attributeType.Type!.EqualsByName(["System", "Type"])) + { + return false; + } + + if (attributeType.Value is null) + { + return false; + } + + matcherType = (INamedTypeSymbol)attributeType.Value!; + } + + // TODO: Write a lot more tests for this + if (matcherType.IsUnboundGenericType) + { + if (_invocationOperation.TargetMethod.TypeArguments.Length != matcherType.TypeArguments.Length) + { + return false; + } + + matcherType = matcherType.ConstructedFrom.Construct([.._invocationOperation.TargetMethod.TypeArguments]); + } + + return true; + } + + public override int GetHashCode() + { + // TODO: This is not enought for uniqueness + return SymbolEqualityComparer.Default.GetHashCode(_invocationOperation.TargetMethod); + } + } +} diff --git a/src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentSpec.cs b/src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentSpec.cs new file mode 100644 index 0000000..e953713 --- /dev/null +++ b/src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentSpec.cs @@ -0,0 +1,34 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Pretender.SourceGenerator.SetupArguments +{ + internal class LiteralArgumentSpec : SetupArgumentSpec + { + private readonly ILiteralOperation _literalOperation; + + public LiteralArgumentSpec(ILiteralOperation literalOperation, IArgumentOperation originalArgument, int argumentPlacement) + : base(originalArgument, argumentPlacement) + { + _literalOperation = literalOperation; + } + + public override int NeededMatcherStatements => 2; + + public override ImmutableArray CreateMatcherStatements(CancellationToken cancellationToken) + { + var (argumentName, localDeclaration) = CreateArgumentAccessor(); + var ifCheck = CreateIfCheck(IdentifierName(argumentName), _literalOperation.ToLiteralExpression()); + return ImmutableArray.Create([localDeclaration, ifCheck]); + } + + public override int GetHashCode() + { + return _literalOperation.ConstantValue.HasValue + ? _literalOperation.ConstantValue.Value?.GetHashCode() ?? 41602 // TODO: Magic value? + : 1337; // TODO: Magic value? + } + } +} diff --git a/src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentSpec.cs b/src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentSpec.cs new file mode 100644 index 0000000..3060e11 --- /dev/null +++ b/src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentSpec.cs @@ -0,0 +1,92 @@ +using System; +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Pretender.SourceGenerator.SetupArguments +{ + internal class LocalReferenceArgumentSpec : SetupArgumentSpec + { + private readonly ILocalReferenceOperation _localReferenceOperation; + + public LocalReferenceArgumentSpec( + ILocalReferenceOperation localReferenceOperation, + IArgumentOperation originalArgument, + int argumentPlacement) : base(originalArgument, argumentPlacement) + { + _localReferenceOperation = localReferenceOperation; + } + + public override int NeededMatcherStatements => 3; + + public override ImmutableArray CreateMatcherStatements(CancellationToken cancellationToken) + { + var variableName = $"{Parameter.Name}_local"; + var (identifier, accessor) = CreateArgumentAccessor(); + + // This is for calling the UnsafeAccessor method that doesn't seem to work for my needs + //statements[1] = LocalDeclarationStatement(VariableDeclaration(ParseTypeName("var")) + // .AddVariables(VariableDeclarator(variableName) + // .WithInitializer(EqualsValueClause(InvocationExpression( + // MemberAccessExpression( + // SyntaxKind.SimpleMemberAccessExpression, + // IdentifierName($"Setup{index}Accessor"), + // IdentifierName(((ILocalReferenceOperation)ArgumentOperation.Value).Local.Name) + // ) + // ) + // .AddArgumentListArguments(Argument(IdentifierName("target"))))))); + + + //statements[1] = LocalDeclarationStatement(VariableDeclaration(localOperation.Local.Type.AsUnknownTypeSyntax()) + // .AddVariables(VariableDeclarator(variableName) + // .WithInitializer(EqualsValueClause( + // MemberAccessExpression( + // SyntaxKind.SimpleMemberAccessExpression, + // ParenthesizedExpression(CastExpression(ParseTypeName("dynamic"), IdentifierName("target"))), + // IdentifierName(localOperation.Local.Name)))) + // ) + // ); + + // var arg_local = target.GetType().GetField("local").GetValue(target); + + // target.GetType() + var getTypeInvocation = InvocationExpression(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("target"), + IdentifierName("GetType"))); + + // target.GetType().GetField("local")! + var getFieldInvocation = PostfixUnaryExpression(SyntaxKind.SuppressNullableWarningExpression, InvocationExpression(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + getTypeInvocation, + IdentifierName("GetField"))) + .AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(_localReferenceOperation.Local.Name))))); + + var getValueInvocation = CastExpression(_localReferenceOperation.Local.Type.AsUnknownTypeSyntax(), InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + getFieldInvocation, + IdentifierName("GetValue"))) + .AddArgumentListArguments(Argument(IdentifierName("target")))); + + // TODO: This really sucks, but neither other way works + var localDeclaration = LocalDeclarationStatement(VariableDeclaration(ParseTypeName("var"))) + .AddDeclarationVariables(VariableDeclarator(variableName) + .WithInitializer(EqualsValueClause(getValueInvocation))); + //statements[1] = ExpressionStatement( + // ParseExpression($"var {variableName} = target.GetType().GetField(\"{localOperation.Local.Name}\")!.GetValue(target)") + //); + + var ifCheck = CreateIfCheck(IdentifierName(identifier), IdentifierName(variableName)); + return ImmutableArray.Create([accessor, localDeclaration, ifCheck]); + } + + public override int GetHashCode() + { + return SymbolEqualityComparer.Default.GetHashCode(_localReferenceOperation.Local); + } + } +} diff --git a/src/Pretender.SourceGenerator/SetupArguments/SetupArgumentSpec.cs b/src/Pretender.SourceGenerator/SetupArguments/SetupArgumentSpec.cs new file mode 100644 index 0000000..d5abde0 --- /dev/null +++ b/src/Pretender.SourceGenerator/SetupArguments/SetupArgumentSpec.cs @@ -0,0 +1,254 @@ +using System.Collections.Immutable; +using System.Diagnostics; +using System.Reflection; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Pretender.SourceGenerator.SetupArguments +{ + internal abstract class SetupArgumentSpec + { + private readonly List _diagnostics = []; + public SetupArgumentSpec(IArgumentOperation originalArgument, int argumentPlacement) + { + OriginalArgument = originalArgument; + ArgumentPlacement = argumentPlacement; + + var tracker = new ArgumentTracker(); + Visit(originalArgument, tracker); + + NeedsCapturer = tracker.NeedsCapturer; + NeededLocals = tracker.NeededLocals; + } + + protected IArgumentOperation OriginalArgument { get; } + protected IParameterSymbol Parameter => OriginalArgument.Parameter!; + protected int ArgumentPlacement { get; } + protected void AddDiagnostic(Diagnostic diagnostic) + { + _diagnostics.Add(diagnostic); + } + + public IReadOnlyList Diagnostics => _diagnostics; + public bool NeedsCapturer { get; } + public ImmutableArray NeededLocals { get; } + public abstract int NeededMatcherStatements { get; } + + public abstract ImmutableArray CreateMatcherStatements(CancellationToken cancellationToken); + + protected (SyntaxToken Identifier, LocalDeclarationStatementSyntax Accessor) CreateArgumentAccessor() + { + var argumentLocal = Identifier($"{Parameter.Name}_arg"); + + // (string?)callInfo.Arguments[index]; + ExpressionSyntax argumentGetter = CastExpression( + Parameter.Type.AsUnknownTypeSyntax(), + ElementAccessExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("callInfo"), IdentifierName("Arguments"))) + .AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ArgumentPlacement)))) + ); + + var localAccessor = LocalDeclarationStatement(VariableDeclaration(ParseTypeName("var")) + .AddVariables(VariableDeclarator(argumentLocal) + .WithInitializer(EqualsValueClause(argumentGetter)))); + + return (argumentLocal, localAccessor); + } + + protected IfStatementSyntax CreateIfCheck(ExpressionSyntax left, ExpressionSyntax right) + { + var binaryExpression = BinaryExpression( + SyntaxKind.NotEqualsExpression, + left, + right); + + return CreateIfCheck(binaryExpression); + } + + protected IfStatementSyntax CreateIfCheck(ExpressionSyntax condition) + { + return IfStatement(condition, Block( + ReturnStatement(LiteralExpression(SyntaxKind.FalseLiteralExpression)))); + } + + private static void Visit(IOperation? operation, ArgumentTracker tracker) + { + if (operation == null) + { + return; + } + + // TODO: Handle most operations + switch (operation.Kind) + { + case OperationKind.Block: + var block = (IBlockOperation)operation; + VisitMany(block.Operations, tracker); + break; + case OperationKind.VariableDeclarationGroup: + var variableDeclarationGroup = (IVariableDeclarationGroupOperation)operation; + VisitMany(variableDeclarationGroup.Declarations, tracker); + break; + case OperationKind.Return: + var returnOp = (IReturnOperation)operation; + Visit(returnOp.ReturnedValue, tracker); + break; + case OperationKind.Literal: + // Literals are the best, they are easy and the end of the line + break; + case OperationKind.Invocation: + var invocation = (IInvocationOperation)operation; + // The instance could be a local itself + Visit(invocation.Instance, tracker); + VisitMany(invocation.Arguments, tracker); + break; + case OperationKind.LocalReference: + var local = (ILocalReferenceOperation)operation; + tracker.TryTrackLocal(local); + break; + case OperationKind.ParameterReference: + break; + case OperationKind.Binary: + var binary = (IBinaryOperation)operation; + Visit(binary.LeftOperand, tracker); + Visit(binary.RightOperand, tracker); + break; + case OperationKind.AnonymousFunction: + // TODO: I'm not sure if this belongs in here or DelegateCreation but lets go with here for now + tracker.EnterScope(); + var anonymousFunction = (IAnonymousFunctionOperation)operation; + Visit(anonymousFunction.Body, tracker); + tracker.ExitScope(); + break; + case OperationKind.DelegateCreation: + var delegateCreation = (IDelegateCreationOperation)operation; + // TODO: Now that we are in a delegate should we ignore their locals somehow? + Visit(delegateCreation.Target, tracker); + break; + case OperationKind.VariableInitializer: + var variableInitializer = (IVariableInitializerOperation)operation; + tracker.LocalsDefined(variableInitializer.Locals); + // TODO: Not sure if this is right + Visit(variableInitializer.Value, tracker); + break; + case OperationKind.VariableDeclaration: + var variableDeclaration = (IVariableDeclarationOperation)operation; + VisitMany(variableDeclaration.Declarators, tracker); + Visit(variableDeclaration.Initializer, tracker); + break; + case OperationKind.VariableDeclarator: + var variableDeclarator = (IVariableDeclaratorOperation)operation; + tracker.LocalDefined(variableDeclarator.Symbol); + // TODO: IgnoredArguments property? + Visit(variableDeclarator.Initializer, tracker); + break; + case OperationKind.Argument: + var argument = (IArgumentOperation)operation; + Visit(argument.Value, tracker); + break; + default: +#if DEBUG + // TODO: Figure out what operation this is + Debugger.Launch(); + // TODO: Report diagnostic? + // TODO: Do fallback support? by looping over ChildOperations? +#endif + return; + } + } + + private static void VisitMany(IEnumerable operations, ArgumentTracker tracker) + { + foreach (var operation in operations) + { + Visit(operation, tracker); + } + } + + // Factory method for creating an ArgumentSpec based on the argument operation + public static SetupArgumentSpec Create(IArgumentOperation argumentOperation, int argumentPlacement) + { + var argumentOperationValue = argumentOperation.Value; + switch (argumentOperationValue.Kind) + { + case OperationKind.Literal: + return new LiteralArgumentSpec( + (ILiteralOperation)argumentOperationValue, + argumentOperation, + argumentPlacement); + case OperationKind.Invocation: + return new InvocationArgumentSpec( + (IInvocationOperation)argumentOperationValue, + argumentOperation, + argumentPlacement); + case OperationKind.LocalReference: + return new LocalReferenceArgumentSpec( + (ILocalReferenceOperation)argumentOperationValue, + argumentOperation, + argumentPlacement); + default: + throw new NotImplementedException(); + } + } + + private class ArgumentTracker + { + private readonly List _neededLocals = new(); + private readonly Stack> _trackedLocals = new(); + + public ArgumentTracker() + { + _trackedLocals = new Stack>(); + _trackedLocals.Push(new HashSet(SymbolEqualityComparer.Default)); + } + + public ImmutableArray NeededLocals => _neededLocals.ToImmutableArray(); + public bool NeedsCapturer { get; private set; } + public void SetNeedsCapturer() + { + NeedsCapturer = true; + } + + public bool TryTrackLocal(ILocalReferenceOperation localReferenceOperation) + { + var currentScope = _trackedLocals.Peek(); + if (currentScope.Contains(localReferenceOperation.Local)) + { + // This is being tracked as created during the current scope, ignore it + return false; + } + + _neededLocals.Add(localReferenceOperation); + return true; + } + + public void LocalDefined(ILocalSymbol local) + { + var currentScope = _trackedLocals.Peek(); + currentScope.Add(local); + } + + public void LocalsDefined(IEnumerable locals) + { + var currentScope = _trackedLocals.Peek(); + foreach (var local in locals) + { + currentScope.Add(local); + } + } + + // TODO: could create an IDisposable for this + public void EnterScope() + { + _trackedLocals.Push(new(SymbolEqualityComparer.Default)); + } + + public void ExitScope() + { + _trackedLocals.Pop(); + } + } + } +} diff --git a/src/Pretender.SourceGenerator/SetupCreationSpec.cs b/src/Pretender.SourceGenerator/SetupCreationSpec.cs new file mode 100644 index 0000000..1538d78 --- /dev/null +++ b/src/Pretender.SourceGenerator/SetupCreationSpec.cs @@ -0,0 +1,363 @@ +using System; +using System.Collections.Immutable; +using System.Diagnostics; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using Pretender.SourceGenerator.SetupArguments; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Pretender.SourceGenerator +{ + internal class SetupCreationSpec + { + private readonly IArgumentOperation _setupArgument; + private readonly ITypeSymbol _pretendType; + private readonly bool _useSetMethod; + + private readonly IMethodSymbol? _setupMethod; + private readonly ImmutableArray _argumentSpecs; + + public SetupCreationSpec(IArgumentOperation setupArgument, ITypeSymbol pretendType, bool useSetMethod) + { + _setupArgument = setupArgument; + _pretendType = pretendType; + _useSetMethod = useSetMethod; + + var candidates = GetInvocationCandidates(); + + if (candidates.Length == 0) + { + // TODO: Add diagnostic + return; + } + else if (candidates.Length != 1) + { + // TODO: Add diagnostic + return; + } + + var candidate = candidates[0]; + _setupMethod = candidate.Method; + + var builder = ImmutableArray.CreateBuilder(candidate.Arguments.Length); + for (var i = 0; i < candidate.Arguments.Length; i++) + { + builder.Add(SetupArgumentSpec.Create(candidate.Arguments[i], i)); + } + + _argumentSpecs = builder.MoveToImmutable(); + + // TODO: Get argument specs diagnostics and make them my own + } + + private ImmutableArray GetInvocationCandidates() + { + var builder = ImmutableArray.CreateBuilder(); + TraverseOperation(_setupArgument.Value, builder); + return builder.ToImmutable(); + } + + private void TraverseOperation(IOperation operation, ImmutableArray.Builder invocationCandidates) + { + switch (operation.Kind) + { + case OperationKind.Block: + var blockOperation = (IBlockOperation)operation; + TraverseOperationList(blockOperation.Operations, invocationCandidates); + break; + case OperationKind.Return: + var returnOperation = (IReturnOperation)operation; + if (returnOperation.ReturnedValue != null) + { + TraverseOperation(returnOperation.ReturnedValue, invocationCandidates); + } + break; + case OperationKind.ExpressionStatement: + var expressionStatement = (IExpressionStatementOperation)operation; + TraverseOperation(expressionStatement.Operation, invocationCandidates); + break; + case OperationKind.Conversion: + var conversionOperation = (IConversionOperation)operation; + TraverseOperation(conversionOperation.Operand, invocationCandidates); + break; + case OperationKind.Invocation: + var invocationOperation = (IInvocationOperation)operation; + TryMatchInvocationOperation(invocationOperation, invocationCandidates); + break; + case OperationKind.PropertyReference: + var propertyReferenceOperation = (IPropertyReferenceOperation)operation; + TryMatchPropertyReference(propertyReferenceOperation, invocationCandidates); + break; + case OperationKind.AnonymousFunction: + var anonymousFunctionOperation = (IAnonymousFunctionOperation)operation; + TraverseOperation(anonymousFunctionOperation.Body, invocationCandidates); + break; + case OperationKind.DelegateCreation: + var delegateCreationOperation = (IDelegateCreationOperation)operation; + TraverseOperation(delegateCreationOperation.Target, invocationCandidates); + break; + default: +#if DEBUG + // TODO: Figure out what operation caused this, it's not ideal to "randomly" support operations + Debugger.Launch(); +#endif + // Absolute fallback, most of our operations can be supported this way but it's nicer to be explicit + TraverseOperationList(operation.ChildOperations, invocationCandidates); + break; + } + } + + private void TraverseOperationList(IEnumerable operations, ImmutableArray.Builder invocationCandidates) + { + foreach (var operation in operations) + { + TraverseOperation(operation, invocationCandidates); + } + } + + private void TryMatchPropertyReference(IPropertyReferenceOperation propertyReference, ImmutableArray.Builder invocationCandidates) + { + if (propertyReference.Instance is not IParameterReferenceOperation parameterReference) + { + return; + } + + if (!SymbolEqualityComparer.Default.Equals(parameterReference.Type, _pretendType)) + { + return; + } + + var method = _useSetMethod + ? propertyReference.Property.SetMethod + : propertyReference.Property.GetMethod; + + if (method == null) + { + return; + } + + invocationCandidates.Add(new InvocationCandidate(method, ImmutableArray.Empty)); + } + + private void TryMatchInvocationOperation(IInvocationOperation invocation, ImmutableArray.Builder invocationCandidates) + { + if (invocation.Instance is not IParameterReferenceOperation parameterReference) + { + return; + } + + if (!SymbolEqualityComparer.Default.Equals(parameterReference.Type, _pretendType)) + { + return; + } + + // TODO: Any more validation? + + invocationCandidates.Add(new InvocationCandidate(invocation.TargetMethod, invocation.Arguments)); + } + + public InvocationExpressionSyntax CreateSetupGetter(CancellationToken cancellationToken) + { + Debug.Assert(_setupMethod is not null, "A setup method could not be found, which means there should have been error diagnostics and this method should not have ran."); + + var totalMatchStatements = _argumentSpecs.Sum(sa => sa.NeededMatcherStatements); + cancellationToken.ThrowIfCancellationRequested(); + + var matchStatements = new StatementSyntax[totalMatchStatements]; + int addedStatements = 0; + + for (var i = 0; i < _argumentSpecs.Length; i++) + { + var argument = _argumentSpecs[i]; + + var newStatements = argument.CreateMatcherStatements(cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); + + newStatements.CopyTo(matchStatements, addedStatements); + addedStatements += newStatements.Length; + } + + ArgumentSyntax matcherArgument; + ImmutableArray statements; + if (matchStatements.Length == 0) + { + statements = ImmutableArray.Empty; + + // Nothing actually needs to match this will always return true, so we use a cached matcher that always returns true + matcherArgument = Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName("Cache"), + IdentifierName("NoOpMatcher"))) + .WithNameColon(NameColon("matcher")); + } + else + { + // Other match statements should have added all the ways the method could return false + // so if it gets through all those statements it should return true at the end. + var trueReturnStatement = ReturnStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression)); + + /* + * Matcher matchCall = static (callInfo, target) => + * { + * ...matching calls... + * return true; + * } + */ + var matchCallIdentifier = Identifier("matchCall"); + + var matcherDelegate = ParenthesizedLambdaExpression( + ParameterList(SeparatedList([ + Parameter(Identifier("callInfo")), + Parameter(Identifier("target")) + ])), + Block(List([.. matchStatements, trueReturnStatement]))) + .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))); + + statements = ImmutableArray.Create(LocalDeclarationStatement(VariableDeclaration( + ParseTypeName("Matcher")) + .WithVariables(SingletonSeparatedList( + VariableDeclarator(matchCallIdentifier) + .WithInitializer(EqualsValueClause(matcherDelegate)))))); + + matcherArgument = Argument(IdentifierName(matchCallIdentifier)); + } + + cancellationToken.ThrowIfCancellationRequested(); + + var objectCreationArguments = ArgumentList( + SeparatedList(new[] + { + Argument(IdentifierName("pretend")), + //Argument(IdentifierName("setupExpression")), + Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(_pretendType.ToPretendName()), + IdentifierName(_setupMethod!.ToMethodInfoCacheName()) + )), + matcherArgument, + Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("expr"), + IdentifierName("Target") + )), + })); + + cancellationToken.ThrowIfCancellationRequested(); + + GenericNameSyntax returnObjectName; + SimpleNameSyntax getOrCreateName; + if (_setupMethod!.ReturnsVoid) + { + // VoidCompiledSetup + returnObjectName = GenericName("VoidCompiledSetup") + .AddTypeArgumentListArguments(ParseTypeName(_pretendType.ToFullDisplayString())); + + getOrCreateName = IdentifierName("GetOrCreateSetup"); + } + else + { + // ReturningCompiledSetup + returnObjectName = GenericName("ReturningCompiledSetup") + .AddTypeArgumentListArguments( + ParseTypeName(_pretendType.ToFullDisplayString()), + _setupMethod.ReturnType.AsUnknownTypeSyntax()); + + getOrCreateName = GenericName("GetOrCreateSetup") + .AddTypeArgumentListArguments(_setupMethod.ReturnType.AsUnknownTypeSyntax()); + + // TODO: Recursively mock? + ExpressionSyntax defaultValue; + + if (_setupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "Task"])) + { + if (_setupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) + { + // Task.FromResult(default) + defaultValue = KnownBlocks.TaskFromResult( + namedType.TypeArguments[0].AsUnknownTypeSyntax(), + LiteralExpression(SyntaxKind.DefaultLiteralExpression)); + } + else + { + // Task.CompletedTask + defaultValue = KnownBlocks.TaskCompletedTask; + } + } + else if (_setupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "ValueTask"])) + { + if (_setupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) + { + // ValueTask.FromResult(default) + defaultValue = KnownBlocks.ValueTaskFromResult( + namedType.TypeArguments[0].AsUnknownTypeSyntax(), + LiteralExpression(SyntaxKind.DefaultLiteralExpression) + ); + } + else + { + // ValueTask.CompletedTask + defaultValue = KnownBlocks.ValueTaskCompletedTask; + } + } + else + { + // TODO: Support custom awaitable + // default + defaultValue = LiteralExpression(SyntaxKind.DefaultLiteralExpression); + } + + cancellationToken.ThrowIfCancellationRequested(); + + objectCreationArguments = objectCreationArguments.AddArguments(Argument( + defaultValue).WithNameColon(NameColon("defaultValue"))); + } + + cancellationToken.ThrowIfCancellationRequested(); + + var compiledSetupCreation = ObjectCreationExpression(returnObjectName) + .WithArgumentList(objectCreationArguments); + + // (pretend, expression) => + // { + // return new CompiledSetup(); + // } + var creator = ParenthesizedLambdaExpression() + .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) + .AddParameterListParameters(Parameter(Identifier("pretend")), Parameter(Identifier("expr"))) + .AddBlockStatements([.. statements, ReturnStatement(compiledSetupCreation)]); + + // TODO: The hash code doesn't actually work, right now, this will create a new pretend every call. + // We likely need to create our own class that can calculate the hash code and place that number in here. + + cancellationToken.ThrowIfCancellationRequested(); + // TODO: Should I have a different seed? + //var badHashCode = _argumentSpecs.Aggregate(0, (agg, s) => HashCode.Combine(agg, s.GetHashCode())); + var badHashCode = 0; + + cancellationToken.ThrowIfCancellationRequested(); + + // pretend.GetOrCreateSetup() + return InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("pretend"), + getOrCreateName)) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(badHashCode))), + Argument(creator), + Argument(IdentifierName("setupExpression"))); + } + + private class InvocationCandidate + { + public InvocationCandidate(IMethodSymbol methodSymbol, ImmutableArray argumentOperations) + { + Method = methodSymbol; + Arguments = argumentOperations; + } + + public IMethodSymbol Method { get; } + public ImmutableArray Arguments { get; } + } + } +} diff --git a/src/Pretender.SourceGenerator/SetupEntrypoint.cs b/src/Pretender.SourceGenerator/SetupEntrypoint.cs index 3df09d2..71a2155 100644 --- a/src/Pretender.SourceGenerator/SetupEntrypoint.cs +++ b/src/Pretender.SourceGenerator/SetupEntrypoint.cs @@ -1,12 +1,10 @@ -using System; -using System.Collections.Generic; -using System.Collections.Immutable; +using System.Collections.Immutable; using System.Diagnostics; -using System.Linq; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Operations; +using Pretender.SourceGenerator.SetupArguments; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Pretender.SourceGenerator @@ -17,6 +15,7 @@ public SetupEntrypoint(IInvocationOperation invocationOperation) { OriginalInvocation = invocationOperation; var setupExpressionArg = invocationOperation.Arguments[0]; + SetupExpression = setupExpressionArg; Debug.Assert(invocationOperation.Type is INamedTypeSymbol, "This should have been asserted via making sure it's the right invocation."); @@ -24,8 +23,11 @@ public SetupEntrypoint(IInvocationOperation invocationOperation) PretendType = pretendType; - var setupMethod = SimplifyOperation(setupExpressionArg.Value); + // TODO: Use correct useSetup value + SetupCreation = new SetupCreationSpec(setupExpressionArg, pretendType, false); + // TODO: Consume diagnostics + var setupMethod = SimplifyOperation(setupExpressionArg.Value); if (setupMethod == default) { @@ -45,235 +47,42 @@ public SetupEntrypoint(IInvocationOperation invocationOperation) SetupMethod = setupMethod.Method; - if (setupMethod.Arguments != default) - { - var setupArguments = new SetupArgument[setupMethod.Arguments.Length]; - for (int i = 0; i < setupArguments.Length; i++) - { - setupArguments[i] = new SetupArgument(setupMethod.Arguments[i], i, Diagnostics); - } - Arguments = setupArguments.ToImmutableArray(); - } - else + var setupArguments = new SetupArgumentSpec[setupMethod.Arguments.Length]; + for (int i = 0; i < setupArguments.Length; i++) { - Arguments = []; + setupArguments[i] = SetupArgumentSpec.Create(setupMethod.Arguments[i], i); } + Arguments = setupArguments.ToImmutableArray(); + + Diagnostics.AddRange(Arguments.SelectMany(s => s.Diagnostics)); } + public IArgumentOperation SetupExpression { get; } public IInvocationOperation OriginalInvocation { get; } public ITypeSymbol PretendType { get; } public List Diagnostics { get; } = new List(); public IMethodSymbol SetupMethod { get; } = null!; - public ImmutableArray Arguments { get; } + public SetupCreationSpec SetupCreation { get; } + public ImmutableArray Arguments { get; } public MemberDeclarationSyntax[] GetMembers(int index) { var allMembers = new List(); - var statements = new List(); - - var returnTypeString = SetupMethod.ReturnsVoid - ? null - : SetupMethod.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var interceptsLocation = new InterceptsLocationInfo(OriginalInvocation); - var typeArgumentList = SetupMethod.ReturnsVoid - ? TypeArgumentList(SingletonSeparatedList(ParseTypeName(PretendType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)))) - : TypeArgumentList(SeparatedList([ParseTypeName(PretendType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), ParseTypeName(returnTypeString!)])); + // TODO: This is wrong + var typeArguments = SetupMethod.ReturnsVoid + ? TypeArgumentList(SingletonSeparatedList(ParseTypeName(PretendType.ToFullDisplayString()))) + : TypeArgumentList(SeparatedList([ParseTypeName(PretendType.ToFullDisplayString()), SetupMethod.ReturnType.AsUnknownTypeSyntax()])); var returnType = GenericName("IPretendSetup") - .WithTypeArgumentList(typeArgumentList); - - var matchStatements = new List(); - - // Match method info first - - if (Arguments != default) - { - var distinctLocals = Arguments - .SelectMany(s => s.RequiredLocals) - .Select(l => l.Local) - .Distinct(SymbolEqualityComparer.Default) - .Cast(); - - if (false) - { - var methods = distinctLocals.Select(l => MethodDeclaration( - l.Type.AsUnknownTypeSyntax(), - l.Name - ) - .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.ExternKeyword), Token(SyntaxKind.StaticKeyword)) - .AddParameterListParameters(Parameter(Identifier("target")).WithType(ParseTypeName("object?"))) - .AddAttributeLists(AttributeList(SingletonSeparatedList( - Attribute(IdentifierName("UnsafeAccessor")) - .AddArgumentListArguments(AttributeArgument(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("UnsafeAccessorKind"), - IdentifierName("Field") - ))))) - ) - .WithSemicolonToken(Token(SyntaxKind.SemicolonToken))); - - // We have captured locals - // make a static class for - var accessorClass = ClassDeclaration($"Setup{index}Accessor") - .AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword)) - .AddMembers(methods.ToArray()); - - allMembers.Add(accessorClass); - } - - for (var i = 0; i < Arguments.Length; i++) - { - var argument = Arguments[i]; - - if (argument.IsLiteral) - { - matchStatements.Add(argument.EmitArgumentAccessor()); - matchStatements.Add(argument.EmitIfCheck(((ILiteralOperation)argument.ArgumentOperation.Value).ToLiteralExpression())); - } - else if (argument.IsInvocation) - { - if (argument.TryEmitInvocationStatements(out var invocationStatements)) - { - matchStatements.AddRange(invocationStatements); - continue; - } - } - else if (argument.IsLocalReference) - { - matchStatements.AddRange(argument.EmitLocalIfCheck(index)); - } - else - { - // TODO: Have this but also have a lot more support for different arguments - // throw new NotImplementedException($"We have not implemented arguments of kind '{argument.Kind}', please file an issue."); - } - // TODO: More Argument types - } - } + .WithTypeArgumentList(typeArguments); - ArgumentSyntax matcherArgument; - if (matchStatements.Count == 0) - { - // Nothing actually needs to match this will always return true. - matcherArgument = Argument(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName("Cache"), - IdentifierName("NoOpMatcher"))) - .WithNameColon(NameColon("matcher")); - } - else - { - matchStatements.Add(ReturnStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression))); - - /* - * Matcher matchCall = static (callInfo, target) => - * { - * ...matching calls... - * return true; - * } - */ - var matcherDelegate = ParenthesizedLambdaExpression( - ParameterList(SeparatedList([ - Parameter(Identifier("callInfo")), - Parameter(Identifier("target")) - ])), - Block(matchStatements)) - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))); - - statements.Add(LocalDeclarationStatement(VariableDeclaration( - ParseTypeName("Matcher")) - .WithVariables(SingletonSeparatedList( - VariableDeclarator("matchCall") - .WithInitializer(EqualsValueClause(matcherDelegate)))))); - - matcherArgument = Argument(IdentifierName("matchCall")); - } - - GenericNameSyntax returnObjectType; - var objectCreationArguments = ArgumentList( - SeparatedList(new[] - { - Argument(IdentifierName("pretend")), - Argument(IdentifierName("setupExpression")), - Argument(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(PretendType.ToPretendName()), - IdentifierName(SetupMethod.ToMethodInfoCacheName()) - )), - matcherArgument, - Argument(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("setupExpression"), - IdentifierName("Target") - )), - })); - - if (SetupMethod.ReturnsVoid) - { - returnObjectType = GenericName("VoidCompiledSetup") - .WithTypeArgumentList(typeArgumentList); - } - else - { - returnObjectType = GenericName("ReturningCompiledSetup") - .WithTypeArgumentList(typeArgumentList); - - ExpressionSyntax additionalArgument; - - if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "Task"])) - { - if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) - { - additionalArgument = ParseExpression($"Task.FromResult<{namedType.TypeArguments[0].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>(default)"); - } - else - { - additionalArgument = ParseExpression("Task.CompletedTask"); - } - } - else if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "ValueTask"])) - { - if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) - { - additionalArgument = ParseExpression($"ValueTask.FromResult<{namedType.TypeArguments[0].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>(default)"); - } - else - { - additionalArgument = ParseExpression("ValueTask.CompletedTask"); - } - } - else - { - // TODO: Support custom awaitable - additionalArgument = ParseExpression("default"); - } - - objectCreationArguments = objectCreationArguments.AddArguments(Argument(additionalArgument) - .WithNameColon(NameColon("defaultValue"))); - } - - var compiledSetupCreation = ObjectCreationExpression(returnObjectType) - .WithArgumentList(objectCreationArguments); - - // var setup = new CompiledSetup(pretend, setupExpression, matchCall); - statements.Add(LocalDeclarationStatement(VariableDeclaration(ParseTypeName("var")) - .WithVariables(SingletonSeparatedList(VariableDeclarator("setup") - .WithInitializer(EqualsValueClause(compiledSetupCreation)))))); - - // pretend.Add(setup); - statements.Add(ExpressionStatement(InvocationExpression( - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("pretend"), IdentifierName("AddSetup")), - ArgumentList(SingletonSeparatedList(Argument(IdentifierName("setup"))))))); - - var returnSetupCall = ReturnStatement(IdentifierName("setup")); - - statements.Add(returnSetupCall); - - var interceptsLocation = new InterceptsLocationInfo(OriginalInvocation); + var setupInvocation = SetupCreation.CreateSetupGetter(default); var setupMethod = MethodDeclaration(returnType, $"Setup{index}") - .WithBody(Block(statements.ToArray())) + .WithBody(Block(ReturnStatement(setupInvocation))) .WithParameterList(ParameterList(SeparatedList(new[] { Parameter(Identifier("pretend")) @@ -281,7 +90,7 @@ public MemberDeclarationSyntax[] GetMembers(int index) .WithType(ParseTypeName($"Pretend<{PretendType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>")), Parameter(Identifier("setupExpression")) - .WithType(GenericName(SetupMethod.ReturnsVoid ? "Action" : "Func").WithTypeArgumentList(typeArgumentList)), + .WithType(GenericName(SetupMethod.ReturnsVoid ? "Action" : "Func").WithTypeArgumentList(typeArguments)), }))) .WithModifiers(TokenList(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword))) .WithAttributeLists(SingletonList(AttributeList( @@ -291,66 +100,6 @@ public MemberDeclarationSyntax[] GetMembers(int index) return [.. allMembers]; } - private void ValidateArgument(IArgumentOperation operation) - { - var value = operation.Value; - - var hasSupport = value switch - { - ILiteralOperation => true, - // TODO: It matchers - IInvocationOperation invocationOperation => ValidateInvocationOperation(invocationOperation), - _ => false, - }; - } - - private bool ValidateInvocationOperation(IInvocationOperation operation) - { - if (operation.Instance != null) - { - // TODO: Make its own descriptor and offer fixer - Diagnostics.Add(Diagnostic.Create( - DiagnosticDescriptors.InvalidSetupArgument, - operation.Syntax.GetLocation(), - "Instance invocation" - )); - return false; - } - - if (!operation.TargetMethod.IsStatic) - { - Diagnostics.Add(Diagnostic.Create( - DiagnosticDescriptors.InvalidSetupArgument, - operation.Syntax.GetLocation(), - "invocation where the method is not static." - )); - return false; - } - - // TODO: Validate owning type and check for attributes - var attributes = operation.TargetMethod.GetAttributes(); - - // TODO: When can attribute class be null? - // TODO: Validate this a little more - var matcherAttribute = attributes.SingleOrDefault( - ad => ad.AttributeClass!.Name == "MatcherAttribute"); - - if (matcherAttribute is null) - { - // TODO: Make this be it's own descriptor - Diagnostics.Add(Diagnostic.Create( - DiagnosticDescriptors.InvalidSetupArgument, - operation.Syntax.GetLocation(), - "Static invocation with no matcher attribute on method" - )); - return false; - } - - // TODO: Validate the matcher attribute further - - return true; - } - private (IMethodSymbol Method, ImmutableArray Arguments) SimplifyBlockOperation(IBlockOperation operation) { foreach (var childOperation in operation.Operations) @@ -380,16 +129,15 @@ private bool ValidateInvocationOperation(IInvocationOperation operation) // TODO: Support more operations return operation.Kind switch { - OperationKind.Return => SimplifyReturnOperation((IReturnOperation)operation), - OperationKind.Conversion => SimplifyOperation(((IConversionOperation)operation).Operand), OperationKind.Block => SimplifyBlockOperation((IBlockOperation)operation), - OperationKind.AnonymousFunction => SimplifyOperation(((IAnonymousFunctionOperation)operation).Body), - OperationKind.Invocation => TryMethod((IInvocationOperation)operation), + OperationKind.Return => SimplifyReturnOperation((IReturnOperation)operation), // ExpressionStatement is probably a dead path now but who cares OperationKind.ExpressionStatement => SimplifyOperation(((IExpressionStatementOperation)operation).Operation), - OperationKind.DelegateCreation => SimplifyOperation(((IDelegateCreationOperation)operation).Target), - // TODO: Do something for SetupSet to get the set method instead + OperationKind.Conversion => SimplifyOperation(((IConversionOperation)operation).Operand), + OperationKind.Invocation => TryMethod((IInvocationOperation)operation), OperationKind.PropertyReference => TryProperty((IPropertyReferenceOperation)operation), + OperationKind.AnonymousFunction => SimplifyOperation(((IAnonymousFunctionOperation)operation).Body), + OperationKind.DelegateCreation => SimplifyOperation(((IDelegateCreationOperation)operation).Target), _ => default, }; } @@ -418,7 +166,7 @@ private bool ValidateInvocationOperation(IInvocationOperation operation) } // I still don't return arguments for a property setter right? - return (method, []); + return (method, ImmutableArray.Empty); } private (IMethodSymbol Method, ImmutableArray Arguments) TryMethod(IInvocationOperation operation) diff --git a/src/Pretender.SourceGenerator/SymbolExtensions.cs b/src/Pretender.SourceGenerator/SymbolExtensions.cs index d338d28..0d81068 100644 --- a/src/Pretender.SourceGenerator/SymbolExtensions.cs +++ b/src/Pretender.SourceGenerator/SymbolExtensions.cs @@ -13,42 +13,6 @@ namespace Pretender.SourceGenerator { internal static class SymbolExtensions { - public static IEnumerable> GetGroupedMethods(this ITypeSymbol type) - { - return type.GetMembers() - .OfType() - .GroupBy(m => m.Name); - } - - public static IEnumerable<(IMethodSymbol MethodSymbol, MethodDeclarationSyntax MethodDeclaration)> GetEquivalentMethodSignatures(this IEnumerable methods) - { - foreach (var method in methods) - { - var methodDeclaration = MethodDeclaration( - returnType: ParseTypeName(method.ReturnType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), - identifier: method.Name) - .WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword))) // TODO: Are there other modifiers we need to copy? - .AddParameterListParameters(method.Parameters.Select(GetParameter).ToArray()) - .WithInheritDoc(); - - yield return (method, methodDeclaration); - } - - static ParameterSyntax GetParameter(IParameterSymbol parameter) - { - var parameterSyntax = Parameter(Identifier(parameter.Name)) - .WithType(ParseTypeName(parameter.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))); - - if (parameter.HasExplicitDefaultValue) - { - // TODO: Support default parameters - throw new NotImplementedException("Default parameters are not supported yet."); - } - - return parameterSyntax; - } - } - public static bool EqualsByName(this ITypeSymbol type, string[] name) { var length = name.Length; @@ -67,14 +31,15 @@ public static bool EqualsByName(this ITypeSymbol type, string[] name) } targetNamespace = targetNamespace.ContainingNamespace; } + // Once all namespace parts have been enumerated // we should be in the global namespace - if (targetNamespace.IsGlobalNamespace) + if (!targetNamespace.IsGlobalNamespace) { - return true; + return false; } - return false; + return true; } public static TypeSyntax AsUnknownTypeSyntax(this ITypeSymbol type) @@ -278,7 +243,7 @@ public static ParameterSyntax ToParameterSyntax(this IParameterSymbol parameter) if (parameter.HasExplicitDefaultValue) { parameterSyntax = parameterSyntax - .WithDefault(EqualsValueClause(parameter.ExplicitDefaultValue.ToLiteralExpression())); + .WithDefault(EqualsValueClause(parameter.ToLiteralExpression())); } var modifiers = new List(); @@ -306,7 +271,13 @@ public static ParameterSyntax ToParameterSyntax(this IParameterSymbol parameter) return parameterSyntax; } - public static LiteralExpressionSyntax ToLiteralExpression(this object? value) + public static LiteralExpressionSyntax ToLiteralExpression(this IParameterSymbol parameterSymbol) + { + Debug.Assert(parameterSymbol.HasExplicitDefaultValue); + return ToLiteralExpression(parameterSymbol.ExplicitDefaultValue); + } + + private static LiteralExpressionSyntax ToLiteralExpression(object? value) { if (value == null) { diff --git a/src/Pretender.SourceGenerator/VerifyInvocation.cs b/src/Pretender.SourceGenerator/VerifyInvocation.cs new file mode 100644 index 0000000..8ac5768 --- /dev/null +++ b/src/Pretender.SourceGenerator/VerifyInvocation.cs @@ -0,0 +1,62 @@ +using System.Diagnostics; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using Pretender.SourceGenerator.Parser; + +namespace Pretender.SourceGenerator +{ + // This should be a simple class just holding some information, deeper introspection to find diagnostics should + // be done with a type cache + internal class VerifyInvocation + { + public VerifyInvocation(IInvocationOperation operation, Location location) + { + Operation = operation; + Location = location; + } + + public IInvocationOperation Operation { get; } + public Location Location { get; } + + public static bool IsCandidateSyntaxNode(SyntaxNode node) + { + return node is InvocationExpressionSyntax + { + // pretend.Verify(i => i.Something(), 2); + Expression: MemberAccessExpressionSyntax + { + Name.Identifier.ValueText: "Verify", // TODO: or VerifySet + }, + ArgumentList.Arguments.Count: 2 + }; + } + + public static VerifyInvocation? Create(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + Debug.Assert(IsCandidateSyntaxNode(context.Node)); + var invocationSyntax = (InvocationExpressionSyntax)context.Node; + + return context.SemanticModel.GetOperation(invocationSyntax, cancellationToken) is IInvocationOperation operation + && IsVerifyOperation(operation) + ? new VerifyInvocation(operation, invocationSyntax.GetLocation()) + : null; + } + + private static bool IsVerifyOperation(IInvocationOperation operation) + { + // TODO: Verify ALL of the things, no false positives should escape here + // but we should do it all with string comparisons + if (operation.TargetMethod is not IMethodSymbol + { + Name: "Verify", // TODO: or VerifySet, + ContainingType: INamedTypeSymbol namedTypeSymbol + } || !KnownTypeSymbols.IsPretend(namedTypeSymbol)) + { + return false; + } + + return true; + } + } +} diff --git a/src/Pretender/Argument.cs b/src/Pretender/Argument.cs new file mode 100644 index 0000000..5d2887f --- /dev/null +++ b/src/Pretender/Argument.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Pretender +{ + public struct Argument + { + private readonly Type _declaredType; + private object? _value; + + public Argument(Type declaredType, object? value) + { + _declaredType = declaredType; + _value = value; + } + + public readonly Type DeclaredType => _declaredType; + public object? Value + { + readonly get { return _value; } + set { _value = value; } + } + public readonly Type ActualType => _value != null ? _value.GetType() : _declaredType; + } +} diff --git a/src/Pretender/Behavior.cs b/src/Pretender/Behavior.cs index 65697af..43d095d 100644 --- a/src/Pretender/Behavior.cs +++ b/src/Pretender/Behavior.cs @@ -2,6 +2,6 @@ { public abstract class Behavior { - public abstract void Execute(ref CallInfo callInfo); + public abstract void Execute(CallInfo callInfo); } } diff --git a/src/Pretender/Behaviors/CallbackBehavior.cs b/src/Pretender/Behaviors/CallbackBehavior.cs index 028befa..dfd4af0 100644 --- a/src/Pretender/Behaviors/CallbackBehavior.cs +++ b/src/Pretender/Behaviors/CallbackBehavior.cs @@ -11,7 +11,7 @@ public CallbackBehavior(Callback action) _action = action; } - public override void Execute(ref CallInfo callInfo) + public override void Execute(CallInfo callInfo) { _action(ref callInfo); } diff --git a/src/Pretender/Behaviors/ReturnValueBehavior.cs b/src/Pretender/Behaviors/ReturnValueBehavior.cs index 43b9abe..1c58ee9 100644 --- a/src/Pretender/Behaviors/ReturnValueBehavior.cs +++ b/src/Pretender/Behaviors/ReturnValueBehavior.cs @@ -8,7 +8,7 @@ public ReturnValueBehavior(object? value) { _value = value; } - public override void Execute(ref CallInfo callInfo) + public override void Execute(CallInfo callInfo) { callInfo.ReturnValue = _value; } diff --git a/src/Pretender/Behaviors/ThrowBehavior.cs b/src/Pretender/Behaviors/ThrowBehavior.cs index d912caa..d656e50 100644 --- a/src/Pretender/Behaviors/ThrowBehavior.cs +++ b/src/Pretender/Behaviors/ThrowBehavior.cs @@ -9,7 +9,7 @@ public ThrowBehavior(Exception exception) _exception = exception; } - public override void Execute(ref CallInfo callInfo) + public override void Execute(CallInfo callInfo) { throw _exception; } diff --git a/src/Pretender/CallInfo.cs b/src/Pretender/CallInfo.cs index d986fb1..d74eb4b 100644 --- a/src/Pretender/CallInfo.cs +++ b/src/Pretender/CallInfo.cs @@ -2,10 +2,16 @@ namespace Pretender { - public ref struct CallInfo(MethodInfo methodInfo, Span arguments) + public class CallInfo { - public MethodInfo MethodInfo { get; } = methodInfo; - public Span Arguments { get; } = arguments; + public CallInfo(MethodInfo methodInfo, object?[] arguments) + { + MethodInfo = methodInfo; + Arguments = arguments; + } + + public MethodInfo MethodInfo { get; } + public object?[] Arguments { get; } public object? ReturnValue { get; set; } } } diff --git a/src/Pretender/Called.cs b/src/Pretender/Called.cs new file mode 100644 index 0000000..1707368 --- /dev/null +++ b/src/Pretender/Called.cs @@ -0,0 +1,44 @@ +namespace Pretender +{ + public readonly struct Called + { + private readonly int _from; + private readonly int _to; + private readonly CalledKind _calledKind; + + private Called(int from, int to, CalledKind calledKind) + { + _from = from; + _to = to; + _calledKind = calledKind; + } + + enum CalledKind + { + Exact + } + + public static Called Exactly(int expectedCalls) + => new(expectedCalls, expectedCalls, CalledKind.Exact); + + public static implicit operator Called(int expectedCalls) + => new(expectedCalls, expectedCalls, CalledKind.Exact); + + public void Validate(int callCount) + { + switch (_calledKind) + { + case CalledKind.Exact: + if (callCount != _from) + { + // TODO: Better exception + throw new Exception("It was not called exactly that many times."); + } + break; + default: + throw new Exception("Invalid call kind."); + } + + } + } +} diff --git a/src/Pretender/IPretendSetup.cs b/src/Pretender/IPretendSetup.cs index f4538fb..df7517d 100644 --- a/src/Pretender/IPretendSetup.cs +++ b/src/Pretender/IPretendSetup.cs @@ -5,10 +5,13 @@ namespace Pretender public interface IPretendSetup { Pretend Pretend { get; } - void Execute(ref CallInfo callInfo); + internal void Execute(CallInfo callInfo); + internal bool Matches(CallInfo callInfo); + int TimesCalled { get; } [EditorBrowsable(EditorBrowsableState.Advanced)] void SetBehavior(Behavior behavior); + void Verify(Called called) => Pretend.Verify(pretendSetup: this, called); } public interface IPretendSetup : IPretendSetup diff --git a/src/Pretender/Internals/BaseCompiledSetup.cs b/src/Pretender/Internals/BaseCompiledSetup.cs index 65e6bae..f0079be 100644 --- a/src/Pretender/Internals/BaseCompiledSetup.cs +++ b/src/Pretender/Internals/BaseCompiledSetup.cs @@ -4,6 +4,7 @@ namespace Pretender.Internals { [EditorBrowsable(EditorBrowsableState.Never)] + // TODO: Obsolete this public abstract class BaseCompiledSetup( Pretend pretend, MethodInfo methodInfo, @@ -13,10 +14,10 @@ public abstract class BaseCompiledSetup( private readonly MethodInfo _methodInfo = methodInfo; private readonly Matcher _matcher = matcher; private readonly object? _target = target; - protected Behavior? _behavior; public Pretend Pretend { get; } = pretend; + public int TimesCalled { get; private set; } public void SetBehavior(Behavior behavior) { @@ -28,21 +29,29 @@ public void SetBehavior(Behavior behavior) _behavior = behavior; } - public void ExecuteCore(ref CallInfo callInfo) + public void ExecuteCore(CallInfo callInfo) + { + if (!Matches(callInfo)) + { + return; + } + TimesCalled++; + } + + public bool Matches(CallInfo callInfo) { // TODO: Mark as attempted? if (callInfo.MethodInfo != _methodInfo) { - return; + return false; } if (!_matcher(callInfo, _target)) { - return; + return false; } - // TODO: Mark as matched - // TODO: Set times matched? + return true; } } } diff --git a/src/Pretender/Internals/ReturningCompiledSetup.cs b/src/Pretender/Internals/ReturningCompiledSetup.cs index c09befa..fd7b5af 100644 --- a/src/Pretender/Internals/ReturningCompiledSetup.cs +++ b/src/Pretender/Internals/ReturningCompiledSetup.cs @@ -7,18 +7,17 @@ namespace Pretender.Internals [EditorBrowsable(EditorBrowsableState.Never)] - public class ReturningCompiledSetup(Pretend pretend, Func setupExpression, MethodInfo methodInfo, Matcher matcher, object? target, TResult defaultValue) + public class ReturningCompiledSetup(Pretend pretend, MethodInfo methodInfo, Matcher matcher, object? target, TResult defaultValue) : BaseCompiledSetup(pretend, methodInfo, matcher, target), IPretendSetup { - private readonly Func _setupExpression = setupExpression; private readonly TResult _defaultValue = defaultValue; public Type ReturnType => typeof(TResult); [DebuggerStepThrough] - public void Execute(ref CallInfo callInfo) + public void Execute(CallInfo callInfo) { - ExecuteCore(ref callInfo); + ExecuteCore(callInfo); // Run behavior if (_behavior is null) @@ -27,7 +26,7 @@ public void Execute(ref CallInfo callInfo) return; } - _behavior.Execute(ref callInfo); + _behavior.Execute(callInfo); // This is where I could track nullability state and throw if the return value is null still callInfo.ReturnValue ??= _defaultValue; diff --git a/src/Pretender/Internals/SetupWrapper.cs b/src/Pretender/Internals/SetupWrapper.cs new file mode 100644 index 0000000..8fc7f25 --- /dev/null +++ b/src/Pretender/Internals/SetupWrapper.cs @@ -0,0 +1,33 @@ +using System.Diagnostics.CodeAnalysis; + +namespace Pretender.Internals +{ + public readonly struct SetupWrapper : IEquatable> + { + private readonly IPretendSetup _setup; + private readonly int _hashCode; + + public SetupWrapper(IPretendSetup setup, int hashCode) + { + _setup = setup; + _hashCode = hashCode; + } + + public readonly IPretendSetup Setup => _setup; + + public bool Equals(SetupWrapper other) + { + return _hashCode == other._hashCode; + } + + public override bool Equals([NotNullWhen(true)] object? obj) + { + return obj is SetupWrapper otherWrapper && Equals(otherWrapper); + } + + public override int GetHashCode() + { + return _hashCode; + } + } +} diff --git a/src/Pretender/Internals/VoidCompiledSetup.cs b/src/Pretender/Internals/VoidCompiledSetup.cs index c1a0826..158cf9f 100644 --- a/src/Pretender/Internals/VoidCompiledSetup.cs +++ b/src/Pretender/Internals/VoidCompiledSetup.cs @@ -5,15 +5,13 @@ namespace Pretender.Internals { [EditorBrowsable(EditorBrowsableState.Never)] - public class VoidCompiledSetup(Pretend pretend, Action setupExpression, MethodInfo methodInfo, Matcher matcher, object? target) + public class VoidCompiledSetup(Pretend pretend, MethodInfo methodInfo, Matcher matcher, object? target) : BaseCompiledSetup(pretend, methodInfo, matcher, target), IPretendSetup { - private readonly Action _setupExpression = setupExpression; - [DebuggerStepThrough] - public void Execute(ref CallInfo callInfo) + public void Execute(CallInfo callInfo) { - ExecuteCore(ref callInfo); + ExecuteCore(callInfo); // Run behavior if (_behavior is null) @@ -22,7 +20,7 @@ public void Execute(ref CallInfo callInfo) } // For void returning we just run the behavior - _behavior.Execute(ref callInfo); + _behavior.Execute(callInfo); } } } diff --git a/src/Pretender/Matchers/AnyMatcher.cs b/src/Pretender/Matchers/AnyMatcher.cs index c17f4b8..3f85ae7 100644 --- a/src/Pretender/Matchers/AnyMatcher.cs +++ b/src/Pretender/Matchers/AnyMatcher.cs @@ -1,8 +1,8 @@ namespace Pretender.Matchers { - public class AnyMatcher : IMatcher + public sealed class AnyMatcher : IMatcher { - public static AnyMatcher Instance = new AnyMatcher(); + public static AnyMatcher Instance = new(); public bool Matches(object? argument) { diff --git a/src/Pretender/Pretend.Create.cs b/src/Pretender/Pretend.Create.cs new file mode 100644 index 0000000..8050179 --- /dev/null +++ b/src/Pretender/Pretend.Create.cs @@ -0,0 +1,43 @@ +namespace Pretender +{ + public partial class Pretend + { + public T Create() + { + throw new InvalidProgramException("This method should have been intercepted via a source generator."); + } + + public T Create(T0 arg0) + { + throw new InvalidProgramException("This method should have been intercepted via a source generator."); + } + + public T Create(T0 arg0, T1 arg1) + { + throw new InvalidProgramException("This method should have been intercepted via a source generator."); + } + + public T Create(T0 arg0, T1 arg1, T2 arg2) + { + throw new InvalidProgramException("This method should have been intercepted via a source generator."); + } + + public T Create(T0 arg0, T1 arg1, T2 arg2, T3 arg3) + { + throw new InvalidProgramException("This method should have been intercepted via a source generator."); + } + + public T Create(T0 arg0, T1 arg1, T2 arg2, T3 arg3, T4 arg4) + { + throw new InvalidProgramException("This method should have been intercepted via a source generator."); + } + + public T Create(T0 arg0, T1 arg1, T2 arg2, T3 arg3, T4 arg4, T5 arg5) + { + throw new InvalidProgramException("This method should have been intercepted via a source generator."); + } + + // TODO: Support overloads up to 16 + // TODO: Support params object[] args after that, maybe when params Span comes around? + } +} diff --git a/src/Pretender/Pretend.cs b/src/Pretender/Pretend.cs index 215f01a..3bf41bd 100644 --- a/src/Pretender/Pretend.cs +++ b/src/Pretender/Pretend.cs @@ -1,54 +1,80 @@ using System.ComponentModel; using System.Diagnostics; -using System.Linq.Expressions; +using System.Globalization; +using Pretender.Internals; namespace Pretender; [DebuggerDisplay("{DebuggerToString(),nq}")] -public class Pretend +public partial class Pretend { - private List>? _setups; - private IPretendSetup? _singleSetup; + // TODO: Should we minimize allocations for rarely called mocks? + private List? _calls; public Pretend() { } - // TODO: Create interceptor for returning the configured type - public T Create() + public IPretendSetup Setup(Func setupExpression) { throw new InvalidProgramException("This method should have been intercepted via a source generator."); } - public IPretendSetup Setup(Func setupExpression) + public IPretendSetup SetupSet(Func setupExpression) { throw new InvalidProgramException("This method should have been intercepted via a source generator."); } - public IPretendSetup SetupSet(Func setupExpression) + public IPretendSetup Setup(Action setupExpression) { throw new InvalidProgramException("This method should have been intercepted via a source generator."); } - public IPretendSetup Setup(Action setupExpression) + public void Verify(Action verifyExpression, Called called) { throw new InvalidProgramException("This method should have been intercepted via a source generator."); } - [DebuggerStepThrough] + public void Verify(Func verifyExpression, Called called) + { + throw new InvalidProgramException("This method should have been intercepted via a source generator."); + } + + // TODO: VerifySet? + [EditorBrowsable(EditorBrowsableState.Never)] // TODO: Make this obsolete - public void Handle(ref CallInfo callInfo) + public void Verify(IPretendSetup pretendSetup, Called called) { - if (_singleSetup != null) + // Right now we can't trust that this setup was created before, loop over all the calls and check it + int timesCalled = 0; + if (_calls != null) { - _singleSetup.Execute(ref callInfo); + for (var i = 0; i < _calls.Count; i++) + { + var call = _calls[i]; + if (pretendSetup.Matches(call)) + { + timesCalled++; + } + } } - else if (_setups != null) + + called.Validate(timesCalled); + } + + [EditorBrowsable(EditorBrowsableState.Never)] + // TODO: Make this obsolete + public void Handle(CallInfo callInfo) + { + _calls = []; + _calls.Add(callInfo); + + if (_setups != null) { foreach (var setup in _setups) { - setup.Execute(ref callInfo); + setup.Execute(callInfo); } } } @@ -58,31 +84,33 @@ private string DebuggerToString() return $"Type = {typeof(T).FullName}"; } + // private Dictionary>? _setupDictionary; + private List>? _setups; + [EditorBrowsable(EditorBrowsableState.Never)] - // TODO: Make obsolete? - public void AddSetup(IPretendSetup setup) + // TODO: Make Obsolete + public IPretendSetup GetOrCreateSetup(int hashCode, Func, Action, IPretendSetup> setupCreator, Action setupExpression) { - if (_setups == null && _singleSetup == null) - { - _singleSetup = setup; - } - else if (_setups == null) - { - _setups ??= new List>(); - _setups.Add(_singleSetup!); - _setups.Add(setup); - _singleSetup = null; - } - else - { - _setups.Add(setup); - } + _setups ??= []; + var newSetup = setupCreator(this, setupExpression); + _setups.Add(newSetup); + return newSetup; + } + + [EditorBrowsable(EditorBrowsableState.Never)] + // TODO: Make Obsolete + public IPretendSetup GetOrCreateSetup(int hashCode, Func, Func, IPretendSetup> setupCreator, Func setupExpression) + { + _setups ??= []; + var newSetup = setupCreator(this, setupExpression); + _setups.Add(newSetup); + return newSetup; } } public static class Pretend { - public static Pretend For() + public static Pretend That() { return new Pretend(); } diff --git a/test/SourceGeneratorTests/Baselines/MainTests/ReturningMethod/Pretender_Type_PretendISimpleInterface2AADE68_g_cs.txt b/test/SourceGeneratorTests/Baselines/MainTests/ReturningMethod/Pretender_Type_PretendISimpleInterface2AADE68_g_cs.txt new file mode 100644 index 0000000..051fe36 --- /dev/null +++ b/test/SourceGeneratorTests/Baselines/MainTests/ReturningMethod/Pretender_Type_PretendISimpleInterface2AADE68_g_cs.txt @@ -0,0 +1,77 @@ +// +#nullable enable +/// +internal class PretendISimpleInterface2AADE68 : global::ISimpleInterface +{ + public static readonly global::System.Reflection.MethodInfo MethodInfo_Foo_30A5A51 = typeof(global::ISimpleInterface).GetMethod(nameof(Foo))!; + public static readonly global::System.Reflection.MethodInfo MethodInfo_VoidMethod_369C2F4 = typeof(global::ISimpleInterface).GetMethod(nameof(VoidMethod))!; + public static readonly global::System.Reflection.MethodInfo MethodInfo_AsyncMethod_30D4D3B = typeof(global::ISimpleInterface).GetMethod(nameof(AsyncMethod))!; + public static readonly global::System.Reflection.MethodInfo MethodInfo_AsyncReturningMethod_24A71EA = typeof(global::ISimpleInterface).GetMethod(nameof(AsyncReturningMethod))!; + public static readonly global::System.Reflection.MethodInfo MethodInfo_TryParse_85E202 = typeof(global::ISimpleInterface).GetMethod(nameof(TryParse))!; + public static readonly global::System.Reflection.MethodInfo MethodInfo_get_Bar_3177B39 = typeof(global::ISimpleInterface).GetProperty(nameof(Bar)).GetMethod!; + public static readonly global::System.Reflection.MethodInfo MethodInfo_set_Bar_392D46 = typeof(global::ISimpleInterface).GetProperty(nameof(Bar)).SetMethod!; + private readonly global::Pretender.Pretend _pretend; + /// + public PretendISimpleInterface2AADE68(global::Pretender.Pretend pretend) + { + _pretend = pretend; + } + + public string? Foo(string? bar, int baz) + { + Span arguments = [bar, baz]; + var callInfo = new global::Pretender.CallInfo(MethodInfo_Foo_30A5A51, arguments); + _pretend.Handle(ref callInfo); + return (string)callInfo.ReturnValue; + } + + public void VoidMethod(bool baz) + { + Span arguments = [baz]; + var callInfo = new global::Pretender.CallInfo(MethodInfo_VoidMethod_369C2F4, arguments); + _pretend.Handle(ref callInfo); + } + + public global::System.Threading.Tasks.Task AsyncMethod() + { + Span arguments = []; + var callInfo = new global::Pretender.CallInfo(MethodInfo_AsyncMethod_30D4D3B, arguments); + _pretend.Handle(ref callInfo); + return (global::System.Threading.Tasks.Task)callInfo.ReturnValue; + } + + public global::System.Threading.Tasks.Task AsyncReturningMethod(string bar) + { + Span arguments = [bar]; + var callInfo = new global::Pretender.CallInfo(MethodInfo_AsyncReturningMethod_24A71EA, arguments); + _pretend.Handle(ref callInfo); + return (global::System.Threading.Tasks.Task)callInfo.ReturnValue; + } + + public bool TryParse(string thing, out bool myValue) + { + Span arguments = [thing, myValue]; + var callInfo = new global::Pretender.CallInfo(MethodInfo_TryParse_85E202, arguments); + _pretend.Handle(ref callInfo); + myValue = arguments[1]; + return (bool)callInfo.ReturnValue; + } + + public string Bar + { + get + { + Span arguments = []; + var callInfo = new global::Pretender.CallInfo(MethodInfo_get_Bar_3177B39, arguments); + _pretend.Handle(ref callInfo); + return (string)callInfo.ReturnValue; + } + + set + { + Span arguments = [value]; + var callInfo = new global::Pretender.CallInfo(MethodInfo_set_Bar_392D46, arguments); + _pretend.Handle(ref callInfo); + } + } +} \ No newline at end of file diff --git a/test/SourceGeneratorTests/MainTests.cs b/test/SourceGeneratorTests/MainTests.cs index b3d4b70..5b03519 100644 --- a/test/SourceGeneratorTests/MainTests.cs +++ b/test/SourceGeneratorTests/MainTests.cs @@ -1,39 +1,57 @@ namespace SourceGeneratorTests; -public class MainTests : TestBase +public partial class MainTests : TestBase { [Fact] - public async Task Test1() + public async Task ReturningMethod() { var (result, compilation) = await RunGeneratorAsync($$""" - var pretendSimpleInterface = Pretend.For(); + var pretendSimpleInterface = Pretend.That(); pretendSimpleInterface - .SetupSet(i => i.Bar); + .Setup(i => i.Bar) + .Returns("Hi"); - var simpleInterface = pretendSimpleInterface.Create(); + var pretend = pretendSimpleInterface.Create(); + + pretendSimpleInterface.Verify(i => i.Bar, 2); """); - Assert.Equal(3, result.GeneratedSources.Length); + Assert.Equal(4, result.GeneratedSources.Length); + //Assert.All(result.GeneratedSources, (result) => + //{ + // CompareAgainstBaseline(result); + //}); var source1 = result.GeneratedSources[0]; var text1 = source1.SourceText.ToString(); var source2 = result.GeneratedSources[1]; var text2 = source2.SourceText.ToString(); var source3 = result.GeneratedSources[2]; var text3 = source3.SourceText.ToString(); + var source4 = result.GeneratedSources[3]; + var text4 = source4.SourceText.ToString(); } + [Fact] public async Task Test2() { var (result, compilation) = await RunGeneratorAsync($$""" - var pretendSimpleInterface = Pretend.For(); + var pretendSimpleInterface = Pretend.That(); pretendSimpleInterface .Setup(i => i.Foo("1", 2)) .Returns("Hello"); + pretendSimpleInterface + .Setup(i => i.Foo("1", 2)) + .Returns("Hello"); + + pretendSimpleInterface + .Setup(i => i.Foo("2", 3)) + .Returns("Bye!"); + var simpleInterface = pretendSimpleInterface.Create(); """); @@ -46,4 +64,27 @@ public async Task Test2() var source3 = result.GeneratedSources[2]; var text3 = source3.SourceText.ToString(); } + + [Fact] + public async Task Test3() + { + var (result, compilation) = await RunGeneratorAsync($$""" + var pretendSimpleInterface = Pretend.That(); + + pretendSimpleInterface + .Setup(i => i.Foo("1", 1)) + .Returns("Hi"); + + var pretend = pretendSimpleInterface.Create(); + """); + + Assert.Equal(3, result.GeneratedSources.Length); + + var source1 = result.GeneratedSources[0]; + var text1 = source1.SourceText.ToString(); + var source2 = result.GeneratedSources[1]; + var text2 = source2.SourceText.ToString(); + var source3 = result.GeneratedSources[2]; + var text3 = source3.SourceText.ToString(); + } } diff --git a/test/SourceGeneratorTests/SourceGeneratorTests.csproj b/test/SourceGeneratorTests/SourceGeneratorTests.csproj index f164255..f9f574b 100644 --- a/test/SourceGeneratorTests/SourceGeneratorTests.csproj +++ b/test/SourceGeneratorTests/SourceGeneratorTests.csproj @@ -33,4 +33,8 @@ + + + + diff --git a/test/SourceGeneratorTests/TestBase.cs b/test/SourceGeneratorTests/TestBase.cs index 19646f8..ba1c323 100644 --- a/test/SourceGeneratorTests/TestBase.cs +++ b/test/SourceGeneratorTests/TestBase.cs @@ -1,4 +1,5 @@ -using System.Text; +using System.Runtime.CompilerServices; +using System.Text; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.Text; @@ -81,6 +82,19 @@ public SimpleAbstractClass(string arg) return (Assert.Single(runResult.Results), updateCompilation); } + public void CompareAgainstBaseline(GeneratedSourceResult result, [CallerMemberName] string testMethodName = null!) + { + var resultFileName = result.HintName.Replace('.', '_'); + var baseLineName = $"{GetType().Name}.{testMethodName}.{resultFileName}.txt"; + var resourceName = typeof(TestBase).Assembly.GetManifestResourceNames() + .Single(r => r.EndsWith(baseLineName)); + + using var stream = typeof(TestBase).Assembly.GetManifestResourceStream(resourceName)!; + using var reader = new StreamReader(stream); + Assert.Equal(reader.ReadToEnd(), result.SourceText.ToString()); + + } + private Task CreateCompilationAsync(string source) { var fullText = Base + CreateTestTemplate(source);