From de8c87d4fef5b92c5e2c5947bd30a3a833f4966f Mon Sep 17 00:00:00 2001 From: Justin Baur <19896123+justindbaur@users.noreply.github.com> Date: Sun, 24 Dec 2023 12:10:23 -0500 Subject: [PATCH] Swich to Text Writer dotnet/runtime#95882 --- example/UnitTest1.cs | 12 +- perf/Comparison/Simple.cs | 6 +- .../CSharpSyntaxUtilities.cs | 56 ++ .../Emitter/CommonSyntax.cs | 17 - .../Emitter/CreateEmitter.cs | 83 ++- .../Emitter/GrandEmitter.cs | 117 ++-- .../Emitter/MatcherArgumentEmitter.cs | 82 +++ .../Emitter/NoopArgumentEmitter.cs | 19 + .../Emitter/PretendEmitter.cs | 371 +++++-------- .../Emitter/SetupActionEmitter.cs | 215 ++------ .../Emitter/SetupArgumentEmitter.cs | 38 ++ .../Emitter/SetupEmitter.cs | 58 +- .../Emitter/VerifyEmitter.cs | 62 +-- .../Fakes/IKnownFake.cs | 2 +- .../Fakes/ILoggerFake.cs | 2 +- .../Invocation/PretendInvocation.cs | 8 +- src/Pretender.SourceGenerator/KnownBlocks.cs | 4 +- .../Parser/KnownTypeSymbols.cs | 40 +- .../Parser/MethodStrategy.cs | 86 +++ .../Parser/PretendParser.cs | 12 +- .../Parser/SetupActionParser.cs | 30 +- .../Pretender.SourceGenerator.csproj | 4 + .../PretenderSettings.cs | 1 + .../PretenderSourceGenerator.cs | 6 +- .../SetupArguments/InvocationArgumentSpec.cs | 188 ------- .../SetupArguments/LiteralArgumentEmitter.cs | 32 ++ .../SetupArguments/LiteralArgumentSpec.cs | 34 -- .../LocalReferenceArgumentEmitter.cs | 30 + .../LocalReferenceArgumentSpec.cs | 92 ---- ...ArgumentSpec.cs => SetupArgumentParser.cs} | 172 +++--- .../SymbolExtensions.cs | 35 +- .../Writing/ImmutableArrayBuilder.cs | 364 +++++++++++++ .../Writing/IndentedTextWriter.cs | 515 ++++++++++++++++++ .../Writing/ObjectPool{T}.cs | 154 ++++++ test/SourceGeneratorTests/MainTests.cs | 48 +- test/SourceGeneratorTests/TestBase.cs | 1 - 36 files changed, 1947 insertions(+), 1049 deletions(-) create mode 100644 src/Pretender.SourceGenerator/CSharpSyntaxUtilities.cs create mode 100644 src/Pretender.SourceGenerator/Emitter/MatcherArgumentEmitter.cs create mode 100644 src/Pretender.SourceGenerator/Emitter/NoopArgumentEmitter.cs create mode 100644 src/Pretender.SourceGenerator/Emitter/SetupArgumentEmitter.cs create mode 100644 src/Pretender.SourceGenerator/Parser/MethodStrategy.cs delete mode 100644 src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs create mode 100644 src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentEmitter.cs delete mode 100644 src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentSpec.cs create mode 100644 src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentEmitter.cs delete mode 100644 src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentSpec.cs rename src/Pretender.SourceGenerator/SetupArguments/{SetupArgumentSpec.cs => SetupArgumentParser.cs} (58%) create mode 100644 src/Pretender.SourceGenerator/Writing/ImmutableArrayBuilder.cs create mode 100644 src/Pretender.SourceGenerator/Writing/IndentedTextWriter.cs create mode 100644 src/Pretender.SourceGenerator/Writing/ObjectPool{T}.cs diff --git a/example/UnitTest1.cs b/example/UnitTest1.cs index 70546e7..931dc66 100644 --- a/example/UnitTest1.cs +++ b/example/UnitTest1.cs @@ -39,15 +39,19 @@ public async Task Test2() [Fact] public void Test3() { - var pretend = Pretend.That() + var pretend = Pretend.That(); + + var setup = pretend .Setup(i => i.Greeting("Hello", It.IsAny())); - var item = pretend.Pretend.Create(); + setup.Returns("2"); + + var item = pretend.Create(); var response = item.Greeting("Hello", 12); - Assert.Null(response); + Assert.Equal("2", response); - pretend.Verify(1); + setup.Verify(1); } } diff --git a/perf/Comparison/Simple.cs b/perf/Comparison/Simple.cs index 44d1410..8b63f9d 100644 --- a/perf/Comparison/Simple.cs +++ b/perf/Comparison/Simple.cs @@ -11,7 +11,7 @@ public string MoqTest() { var mock = new Moq.Mock(); - mock.Setup(i => i.Foo(Moq.It.Is(static i => i == "1"))) + mock.Setup(i => i.Foo(Moq.It.IsAny())) .Returns("2"); var simpleInterface = mock.Object; @@ -23,7 +23,7 @@ public string NSubstituteTest() { var substitute = NSubstitute.Substitute.For(); - NSubstitute.SubstituteExtensions.Returns(substitute.Foo(NSubstitute.Arg.Is(static i => i == "1")), "2"); + NSubstitute.SubstituteExtensions.Returns(substitute.Foo(NSubstitute.Arg.Any()), "2"); return substitute.Foo("1"); } @@ -33,7 +33,7 @@ public string PretenderTest() { var pretend = Pretend.That(); - pretend.Setup(i => i.Foo(It.Is(static i => i == "1"))) + pretend.Setup(i => i.Foo(It.IsAny())) .Returns("2"); var simpleInterface = pretend.Create(); diff --git a/src/Pretender.SourceGenerator/CSharpSyntaxUtilities.cs b/src/Pretender.SourceGenerator/CSharpSyntaxUtilities.cs new file mode 100644 index 0000000..e26340f --- /dev/null +++ b/src/Pretender.SourceGenerator/CSharpSyntaxUtilities.cs @@ -0,0 +1,56 @@ +using System.Globalization; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; + +namespace Pretender.SourceGenerator +{ + internal static class CSharpSyntaxUtilities + { + // Standard format for double and single on non-inbox frameworks to ensure value is round-trippable. + public const string DoubleFormatString = "G17"; + public const string SingleFormatString = "G9"; + + // Format a literal in C# format -- works around https://github.com/dotnet/roslyn/issues/58705 + + public static string FormatLiteral(object? value, ITypeSymbol type) + { + if (value == null) + { + return $"default({type.ToFullDisplayString()})"; + } + + switch (value) + { + case string @string: + return SymbolDisplay.FormatLiteral(@string, quote: true); + case char @char: + return SymbolDisplay.FormatLiteral(@char, quote: true); + case double.NegativeInfinity: + return "double.NegativeInfinity"; + case double.PositiveInfinity: + return "double.PositiveInfinity"; + case double.NaN: + return "double.NaN"; + case double @double: + return $"{@double.ToString(DoubleFormatString, CultureInfo.InvariantCulture)}D"; + case float.NegativeInfinity: + return "float.NegativeInfinity"; + case float.PositiveInfinity: + return "float.PositiveInfinity"; + case float.NaN: + return "float.NaN"; + case float @float: + return $"{@float.ToString(SingleFormatString, CultureInfo.InvariantCulture)}F"; + case decimal @decimal: + // we do not need to specify a format string for decimal as it's default is round-trippable on all frameworks. + return $"{@decimal.ToString(CultureInfo.InvariantCulture)}M"; + case bool @bool: + return @bool ? "true" : "false"; + default: + // Assume this is a number. + return FormatNumber(); + } + string FormatNumber() => $"({type.ToFullDisplayString()})({Convert.ToString(value, CultureInfo.InvariantCulture)})"; + } + } +} diff --git a/src/Pretender.SourceGenerator/Emitter/CommonSyntax.cs b/src/Pretender.SourceGenerator/Emitter/CommonSyntax.cs index a8c355c..7866131 100644 --- a/src/Pretender.SourceGenerator/Emitter/CommonSyntax.cs +++ b/src/Pretender.SourceGenerator/Emitter/CommonSyntax.cs @@ -11,28 +11,11 @@ internal static class CommonSyntax public static PredefinedTypeSyntax VoidType { get; } = PredefinedType(Token(SyntaxKind.VoidKeyword)); public static TypeSyntax VarType { get; } = ParseTypeName("var"); public static GenericNameSyntax GenericPretendType { get; } = GenericName("Pretend"); - public static UsingDirectiveSyntax UsingSystem { get; } = UsingDirective(ParseName("System")); - public static UsingDirectiveSyntax UsingSystemThreadingTasks { get; } = UsingDirective(ParseName("System.Threading.Tasks")); // Verify public static SyntaxToken SetupIdentifier { get; } = Identifier("setup"); public static SyntaxToken CalledIdentifier { get; } = Identifier("called"); public static ParameterSyntax CalledParameter { get; } = Parameter(CalledIdentifier) .WithType(ParseTypeName("Called")); - - public static CompilationUnitSyntax CreateVerifyCompilationUnit(MethodDeclarationSyntax[] verifyMethods) - { - var classDeclaration = ClassDeclaration("VerifyInterceptors") - .AddModifiers(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword)) - .AddMembers(verifyMethods); - - var namespaceDeclaration = NamespaceDeclaration(ParseName("Pretender.SourceGeneration")) - .AddMembers(classDeclaration) - .AddUsings(UsingSystem, KnownBlocks.CompilerServicesUsing, UsingSystemThreadingTasks, KnownBlocks.PretenderUsing, KnownBlocks.PretenderInternalsUsing); - - return CompilationUnit() - .AddMembers(KnownBlocks.InterceptsLocationAttribute, namespaceDeclaration) - .NormalizeWhitespace(); - } } } diff --git a/src/Pretender.SourceGenerator/Emitter/CreateEmitter.cs b/src/Pretender.SourceGenerator/Emitter/CreateEmitter.cs index d183d1a..f645138 100644 --- a/src/Pretender.SourceGenerator/Emitter/CreateEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/CreateEmitter.cs @@ -1,9 +1,9 @@ -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Operations; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; using System.Collections.Immutable; using Microsoft.CodeAnalysis; +using Pretender.SourceGenerator.Writing; namespace Pretender.SourceGenerator.Emitter { @@ -24,67 +24,62 @@ public CreateEmitter(IInvocationOperation originalOperation, ImmutableArray _originalOperation; - public MethodDeclarationSyntax Emit(CancellationToken cancellationToken) + public void Emit(IndentedTextWriter writer, CancellationToken cancellationToken) { var returnType = _originalOperation.TargetMethod.ReturnType; - var returnTypeSyntax = returnType.AsUnknownTypeSyntax(); + var returnTypeSyntax = returnType.ToUnknownTypeString(); - TypeParameterSyntax[] typeParameters; - ParameterSyntax[] methodParameters; - ArgumentSyntax[] constructorArguments; - - if (_typeArguments.HasValue) + foreach (var location in _locations) { - typeParameters = new TypeParameterSyntax[_typeArguments.Value.Length]; + writer.WriteLine(@$"[InterceptsLocation(@""{location.FilePath}"", {location.LineNumber}, {location.CharacterNumber})]"); + } + writer.Write($"internal static {returnType.ToUnknownTypeString()} Create{_index}"); - // We always take the Pretend argument first as a this parameter - methodParameters = new ParameterSyntax[_typeArguments.Value.Length + 1]; - constructorArguments = new ArgumentSyntax[_typeArguments.Value.Length + 1]; + if (_typeArguments is ImmutableArray typeArguments && typeArguments.Length > 0) + { + // (this Pretend pretend, T0 arg0, T1 arg1) + writer.Write("<"); for (var i = 0; i < _typeArguments.Value.Length; i++) { - var typeName = $"T{i}"; - var argName = $"arg{i}"; + writer.Write($"T{i}"); + } - typeParameters[i] = TypeParameter(typeName); - methodParameters[i + 1] = Parameter(Identifier(argName)) - .WithType(ParseTypeName(typeName)); - constructorArguments[i + 1] = Argument(IdentifierName(argName)); + writer.Write($">(this Pretend<{returnTypeSyntax}> pretend"); + + for (var i = 0; i < _typeArguments.Value.Length; i++) + { + writer.Write($", T{i} arg{i}"); } + + writer.WriteLine(")"); } else { - typeParameters = []; - methodParameters = new ParameterSyntax[1]; - constructorArguments = new ArgumentSyntax[1]; + // TODO: Handle the params overload + writer.WriteLine($"(this Pretend<{returnTypeSyntax}> pretend)"); } - methodParameters[0] = Parameter(Identifier("pretend")) - .WithType(GenericName("Pretend") - .AddTypeArgumentListArguments(returnTypeSyntax)) - .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword)) - ); - - constructorArguments[0] = Argument(IdentifierName("pretend")); - - var objectCreation = ObjectCreationExpression(ParseTypeName(returnType.ToPretendName())) - .WithArgumentList(ArgumentList(SeparatedList(constructorArguments))); - - var method = MethodDeclaration(returnTypeSyntax, $"Create{_index}") - .WithBody(Block(ReturnStatement(objectCreation))) - .WithParameterList(ParameterList(SeparatedList(methodParameters))) - .WithModifiers(TokenList(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword))); + using (writer.WriteBlock()) + { + writer.Write($"return new {returnType.ToPretendName()}(pretend"); - method = method.WithAttributeLists(List(CreateInterceptsAttributes())); + if (_typeArguments.HasValue) + { + for (int i = 0; i < _typeArguments.Value.Length; i++) + { + writer.Write($", arg{i}"); + } - if (typeParameters.Length > 0) - { - return method - .WithTypeParameterList(TypeParameterList(SeparatedList(typeParameters))); + writer.WriteLine(");"); + } + else + { + // TODO: Handle params overload + writer.WriteLine(");"); + } } - - return method; } private ImmutableArray CreateInterceptsAttributes() diff --git a/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs b/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs index 4283b62..fe7b2a3 100644 --- a/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs @@ -1,7 +1,9 @@ using System.Collections.Immutable; +using System.Security.Cryptography.X509Certificates; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Pretender.SourceGenerator.Writing; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Pretender.SourceGenerator.Emitter @@ -25,78 +27,81 @@ public GrandEmitter( _createEmitters = createEmitters; } - public CompilationUnitSyntax Emit(CancellationToken cancellationToken) + public string Emit(CancellationToken cancellationToken) { - var namespaceDeclaration = KnownBlocks.OurNamespace - .AddUsings( - UsingDirective(ParseName("System")), - KnownBlocks.CompilerServicesUsing, - UsingDirective(ParseName("System.Threading.Tasks")), - KnownBlocks.PretenderUsing, - KnownBlocks.PretenderInternalsUsing - ); + var writer = new IndentedTextWriter(); - foreach (var pretendEmitter in _pretendEmitters) - { - cancellationToken.ThrowIfCancellationRequested(); - namespaceDeclaration = namespaceDeclaration - .AddMembers(pretendEmitter.Emit(cancellationToken)); - } + // InceptsLocationAttribute + writer.Write(KnownBlocks.InterceptsLocationAttribute, isMultiline: true); + writer.WriteLine(); + writer.WriteLine(); - var setupInterceptorsClass = ClassDeclaration("SetupInterceptors") - .WithModifiers(TokenList(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword))); - - cancellationToken.ThrowIfCancellationRequested(); - - int setupIndex = 0; - foreach (var setupEmitter in _setupEmitters) + writer.WriteLine("namespace Pretender.SourceGeneration"); + using (writer.WriteBlock()) { - cancellationToken.ThrowIfCancellationRequested(); - setupInterceptorsClass = setupInterceptorsClass - .AddMembers(setupEmitter.Emit(setupIndex, cancellationToken)); - setupIndex++; - } - - var verifyInterceptorsClass = ClassDeclaration("VerifyInterceptors") - .AddModifiers(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword)); + writer.WriteLine("using System;"); + writer.WriteLine("using System.Reflection;"); + writer.WriteLine("using System.Runtime.CompilerServices;"); + writer.WriteLine("using System.Threading.Tasks;"); + writer.WriteLine("using Pretender;"); + writer.WriteLine("using Pretender.Internals;"); + writer.WriteLine(); + + foreach (var pretendEmitter in _pretendEmitters) + { + cancellationToken.ThrowIfCancellationRequested(); + pretendEmitter.Emit(writer, cancellationToken); + } - cancellationToken.ThrowIfCancellationRequested(); - - int verifyIndex = 0; - foreach (var verifyEmitter in _verifyEmitters) - { cancellationToken.ThrowIfCancellationRequested(); - verifyInterceptorsClass = verifyInterceptorsClass - .AddMembers(verifyEmitter.Emit(verifyIndex, cancellationToken)); - verifyIndex++; - } + writer.WriteLine(); + writer.WriteLine("file static class SetupInterceptors"); + using (writer.WriteBlock()) + { + int setupIndex = 0; + foreach (var setupEmitter in _setupEmitters) + { + cancellationToken.ThrowIfCancellationRequested(); + setupEmitter.Emit(writer, setupIndex, cancellationToken); + setupIndex++; + } + } - var createInterceptorsClass = ClassDeclaration("CreateInterceptors") - .AddModifiers(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword)); + cancellationToken.ThrowIfCancellationRequested(); - cancellationToken.ThrowIfCancellationRequested(); + writer.WriteLine(); + writer.WriteLine("file static class VerifyInterceptors"); + using (writer.WriteBlock()) + { + int verifyIndex = 0; + foreach (var verifyEmitter in _verifyEmitters) + { + cancellationToken.ThrowIfCancellationRequested(); + verifyEmitter.Emit(writer, verifyIndex, cancellationToken); + verifyIndex++; + } + } - int createIndex = 0; - foreach (var createEmitter in _createEmitters) - { cancellationToken.ThrowIfCancellationRequested(); - createInterceptorsClass = createInterceptorsClass - .AddMembers(createEmitter.Emit(cancellationToken)); - createIndex++; + writer.WriteLine(); + writer.WriteLine("file static class CreateInterceptors"); + using (writer.WriteBlock()) + { + int createIndex = 0; + foreach (var createEmitter in _createEmitters) + { + cancellationToken.ThrowIfCancellationRequested(); + createEmitter.Emit(writer, cancellationToken); + createIndex++; + } + } } - namespaceDeclaration = namespaceDeclaration - .AddMembers(setupInterceptorsClass, verifyInterceptorsClass, createInterceptorsClass); - cancellationToken.ThrowIfCancellationRequested(); - return CompilationUnit() - .AddMembers( - KnownBlocks.InterceptsLocationAttribute, - namespaceDeclaration) - .NormalizeWhitespace(); + return writer.ToString(); } } } diff --git a/src/Pretender.SourceGenerator/Emitter/MatcherArgumentEmitter.cs b/src/Pretender.SourceGenerator/Emitter/MatcherArgumentEmitter.cs new file mode 100644 index 0000000..406f5f2 --- /dev/null +++ b/src/Pretender.SourceGenerator/Emitter/MatcherArgumentEmitter.cs @@ -0,0 +1,82 @@ +using Microsoft.CodeAnalysis; +using Pretender.SourceGenerator.SetupArguments; +using Pretender.SourceGenerator.Writing; + +namespace Pretender.SourceGenerator.Emitter +{ + internal class MatcherArgumentEmitter : SetupArgumentEmitter + { + private readonly INamedTypeSymbol _matcherType; + + // TODO: Also take args + public MatcherArgumentEmitter(INamedTypeSymbol matcherType, SetupArgumentSpec argumentSpec) + : base(argumentSpec) + { + _matcherType = matcherType; + } + + public override void EmitArgumentMatcher(IndentedTextWriter writer, CancellationToken cancellationToken) + { + //var arguments = new ArgumentSyntax[_invocationOperation.Arguments.Length]; + //bool allArgumentsSafe = true; + + //for (int i = 0; i < arguments.Length; i++) + //{ + // var arg = _invocationOperation.Arguments[i]; + // if (arg.Value is ILiteralOperation literalOperation) + // { + // arguments[i] = Argument(literalOperation.ToLiteralExpression()); + // } + // else if (arg.Value is IDelegateCreationOperation delegateCreation) + // { + + // if (delegateCreation.Target is IAnonymousFunctionOperation anonymousFunctionOperation) + // { + // if (anonymousFunctionOperation.Symbol.IsStatic) // This isn't enough either though, they could call a static method that only exists in their context + // { + // // If it's guaranteed to be static, we can just rewrite it in our code + // arguments[i] = Argument(ParseExpression(delegateCreation.Syntax.GetText().ToString())); + // } + // else if (false) // Is non-scope capturing + // { + // // This is a lot more work but also very powerful in terms of speed + // // We need to rewrite the delegate and replace all local references with our getter + // allArgumentsSafe = false; + // } + // else + // { + // // We need a static matcher + // allArgumentsSafe = false; + // } + // } + // else + // { + // allArgumentsSafe = false; + // } + // } + // else + // { + // allArgumentsSafe = false; + // } + //} + + //if (!allArgumentsSafe) + //{ + // createdMatchStatements = false; + // return; + //} + + EmitArgumentAccessor(writer); + + var matcherLocalName = $"{ArgumentSpec.Parameter.Name}_matcher"; + + // TODO: Get arguments + writer.WriteLine($"var {matcherLocalName} = new {_matcherType.ToFullDisplayString()}();"); + writer.WriteLine($"if (!{matcherLocalName}.Matches({ArgumentSpec.Parameter.Name}_arg))"); + using (writer.WriteBlock()) + { + writer.WriteLine("return false;"); + } + } + } +} diff --git a/src/Pretender.SourceGenerator/Emitter/NoopArgumentEmitter.cs b/src/Pretender.SourceGenerator/Emitter/NoopArgumentEmitter.cs new file mode 100644 index 0000000..b7da48f --- /dev/null +++ b/src/Pretender.SourceGenerator/Emitter/NoopArgumentEmitter.cs @@ -0,0 +1,19 @@ +using Pretender.SourceGenerator.SetupArguments; +using Pretender.SourceGenerator.Writing; + +namespace Pretender.SourceGenerator.Emitter +{ + internal class NoopArgumentEmitter : SetupArgumentEmitter + { + public NoopArgumentEmitter(SetupArgumentSpec argumentSpec) + : base(argumentSpec) + { } + + public override bool EmitsMatcher => false; + + public override void EmitArgumentMatcher(IndentedTextWriter writer, CancellationToken cancellationToken) + { + // Intentional no-op + } + } +} diff --git a/src/Pretender.SourceGenerator/Emitter/PretendEmitter.cs b/src/Pretender.SourceGenerator/Emitter/PretendEmitter.cs index fa44ccd..8b9f3d8 100644 --- a/src/Pretender.SourceGenerator/Emitter/PretendEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/PretendEmitter.cs @@ -1,290 +1,181 @@ -using System.Collections.Immutable; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using Microsoft.CodeAnalysis; +using Pretender.SourceGenerator.Parser; +using Pretender.SourceGenerator.Writing; namespace Pretender.SourceGenerator.Emitter { internal class PretendEmitter { private readonly ITypeSymbol _pretendType; + private readonly IReadOnlyDictionary _methodStrategies; private readonly bool _fillExisting; - public PretendEmitter(ITypeSymbol pretendType, bool fillExisting) + public PretendEmitter(ITypeSymbol pretendType, IReadOnlyDictionary methodStrategies, bool fillExisting) { _pretendType = pretendType; + _methodStrategies = methodStrategies; _fillExisting = fillExisting; } - public ITypeSymbol PretendType => _pretendType; - - public TypeDeclarationSyntax Emit(CancellationToken token) + public void Emit(IndentedTextWriter writer, CancellationToken token) { - var pretendFieldAssignment = ExpressionStatement( - AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, - IdentifierName("_pretend"), - IdentifierName("pretend") - ) - ); - - var methodSymbols = new List<(IMethodSymbol Method, string Name)>(); - token.ThrowIfCancellationRequested(); - var typeMembers = _pretendType.GetMembers(); - - foreach (var member in typeMembers) + // TODO: Generate debugger display + if (_fillExisting) { - if (member is IMethodSymbol methodSymbol) - { - methodSymbols.Add((methodSymbol, methodSymbol.Name)); - } - else if (member is IPropertySymbol propertySymbol) - { - // Property symbol is taken care of through IMethodSymbol - } - else if (member is IFieldSymbol fieldSymbol) - { - // TODO: Do I need to do anything - // abstract fields? - } - else + writer.WriteLine($"public partial class {_pretendType.Name}"); + } + else + { + writer.WriteLine($"file class {_pretendType.ToPretendName()} : {_pretendType.ToFullDisplayString()}"); + } + using (writer.WriteBlock()) + { + // static fields + foreach (var strategyEntry in _methodStrategies) { - throw new NotImplementedException($"We don't support {member.Kind} quite yet, please file an issue."); + var strategy = strategyEntry.Value; + writer.Write($"public static readonly MethodInfo {strategy.UniqueName}_MethodInfo = typeof({_pretendType.ToPretendName()})"); + strategyEntry.Value.EmitMethodGetter(writer, token); + writer.WriteLine(skipIfPresent: true); } - } + writer.WriteLine(); - token.ThrowIfCancellationRequested(); + // instance fields + writer.WriteLine($"private readonly Pretend<{_pretendType.ToFullDisplayString()}> _pretend;"); + writer.WriteLine(); - var methodInfoFields = new List(); - - // Find the shortest path to uniquify all method info getters - var groupedMethodSymbols = methodSymbols - .Where(m => m.Method.MethodKind == MethodKind.Ordinary) - .GroupBy(m => m.Name); - - foreach (var groupedMethodSymbol in groupedMethodSymbols) - { - var methods = groupedMethodSymbol.ToArray(); - - if (methods.Length == 1) + // main constructor + writer.WriteLine($"public {_pretendType.ToPretendName()}(Pretend<{_pretendType.ToFullDisplayString()}> pretend)"); + using (writer.WriteBlock()) { - var (method, name) = methods[0]; - // No one else has this name - ExpressionSyntax expression = CreateSimpleMethodInfoGetter(name, "GetMethod"); - methodInfoFields.Add(CreateMethodInfoField(method, expression)); - continue; + writer.WriteLine("_pretend = pretend;"); } - // We have to do more work to fine the unique method, I also know it's not a property anymore - // because properties have a unique name - var groupedMethodParameterLengths = groupedMethodSymbol - .GroupBy(m => m.Method.Parameters.Length); + token.ThrowIfCancellationRequested(); + + // TODO: Stub other base type constructors - foreach (var groupedMethodParameterLength in groupedMethodParameterLengths) + // methods/properties + var allMembers = _pretendType.GetMembers(); + + foreach (var member in allMembers) { - methods = groupedMethodParameterLength.ToArray(); + token.ThrowIfCancellationRequested(); - if (methods.Length == 1) + if (member.IsStatic) { - var method = methods[0]; - // This method is unique from it's other matches via it's parameter length - // TODO: Do this + // TODO: I should probably stub out static abstracts continue; } - } - - // TODO: Match all type parameters - throw new NotImplementedException($"Could not find a unique way to identify method '{groupedMethodSymbol.Key}'"); - } - - var propertyMethodSymbols = methodSymbols - .Where(m => m.Method.MethodKind == MethodKind.PropertyGet - || m.Method.MethodKind == MethodKind.PropertySet); - foreach (var (method, name) in propertyMethodSymbols) - { - var methodName = method.MethodKind == MethodKind.PropertyGet - ? "GetMethod" - : "SetMethod"; - ExpressionSyntax expression = CreateSimplePropertyMethodInfoGetter(method.AssociatedSymbol!.Name, methodName); - methodInfoFields.Add(CreateMethodInfoField(method, expression)); + if (member is IMethodSymbol constructorSymbol && constructorSymbol.MethodKind == MethodKind.Constructor) + { + if (constructorSymbol.Parameters.Length != 0) + { + throw new NotImplementedException("We have not implemented constructors with parameters yet."); + } + } + else if (member is IMethodSymbol methodSymbol && methodSymbol.MethodKind == MethodKind.Ordinary) + { + // Emit Method body + writer.WriteLine(); + writer.Write($"public {methodSymbol.ReturnType.ToUnknownTypeString()} {methodSymbol.Name}"); + + var hasTypeParameters = methodSymbol.TypeParameters.Length > 0; + + if (hasTypeParameters) + { + writer.Write($"<{string.Join(", ", methodSymbol.TypeParameters.Select(t => t.Name))}>"); + } + + var parameters = methodSymbol.Parameters.Select(p => + { + string output = ""; + if (p.RefKind == RefKind.Out) + { + output += "out "; + } + else if (p.RefKind == RefKind.Ref) + { + output += "ref "; + } + else if (p.RefKind == RefKind.RefReadOnly) + { + output += "ref readonly "; + } + + output += $"{p.Type.ToUnknownTypeString()} {p.Name}"; + return output; + }); + writer.WriteLine($"({string.Join(", ", parameters)})"); + EmitMethodBody(writer, methodSymbol); + } + else if (member is IPropertySymbol propertySymbol) + { + // Emit property + writer.WriteLine(); + writer.WriteLine($"public {propertySymbol.Type.ToUnknownTypeString()} {propertySymbol.Name}"); + using (writer.WriteBlock()) + { + if (propertySymbol.GetMethod is not null) + { + writer.WriteLine("get"); + EmitMethodBody(writer, propertySymbol.GetMethod); + } + + if (propertySymbol.SetMethod is not null) + { + writer.WriteLine("set"); + EmitMethodBody(writer, propertySymbol.SetMethod); + } + } + } + } } - - var instanceField = FieldDeclaration(VariableDeclaration(GetGenericPretendType(), SingletonSeparatedList(VariableDeclarator(Identifier("_pretend"))))) - .WithModifiers(TokenList(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword))) - .WithTrailingTrivia(CarriageReturnLineFeed); - - methodInfoFields.Add(instanceField); - - var classDeclaration = _pretendType.ScaffoldImplementation(new ScaffoldTypeOptions - { - CustomFields = methodInfoFields.ToImmutableArray(), - AddMethodBody = CreateMethodBody, - CustomizeConstructor = () => (CreateConstructorParameter(), [pretendFieldAssignment]), - }); - - // TODO: Add properties - - // TODO: Generate debugger display - return classDeclaration - .WithModifiers(TokenList(Token(SyntaxKind.FileKeyword))); - } - - private static FieldDeclarationSyntax CreateMethodInfoField(IMethodSymbol method, ExpressionSyntax expressionSyntax) - { - // public static readonly MethodInfo MethodInfo_name_4B2 = !; - return FieldDeclaration(VariableDeclaration(ParseTypeName("MethodInfo"))) - .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)) - .AddDeclarationVariables(VariableDeclarator(Identifier(method.ToMethodInfoCacheName())) - .WithInitializer(EqualsValueClause( - PostfixUnaryExpression(SyntaxKind.SuppressNullableWarningExpression, expressionSyntax))) - ); - } - - private InvocationExpressionSyntax CreateSimpleMethodInfoGetter(string name, string afterTypeOfMethod) - { - - return InvocationExpression(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - TypeOfExpression(ParseTypeName(_pretendType.ToFullDisplayString())), - IdentifierName(afterTypeOfMethod))) - .AddArgumentListArguments(Argument(NameOfExpression(name))); } - private MemberAccessExpressionSyntax CreateSimplePropertyMethodInfoGetter(string propertyName, string type) + private void EmitMethodBody(IndentedTextWriter writer, IMethodSymbol methodSymbol) { - return MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - CreateSimpleMethodInfoGetter(propertyName, "GetProperty"), - IdentifierName(type)); - } - - private static InvocationExpressionSyntax NameOfExpression(string identifier) - { - var text = SyntaxFacts.GetText(SyntaxKind.NameOfKeyword); - - var identifierSyntax = Identifier(default, - SyntaxKind.NameOfKeyword, - text, - text, - default); - - return InvocationExpression( - IdentifierName(identifierSyntax), - ArgumentList(SingletonSeparatedList(Argument(IdentifierName(identifier))))); - } + using (writer.WriteBlock()) + { + writer.WriteLine($"object?[] __arguments__ = [{string.Join(", ", methodSymbol.Parameters.Select(p => p.Name))}];"); + // TODO: Probably create an Argument object + writer.WriteLine($"var __callInfo__ = new CallInfo({_methodStrategies[methodSymbol].UniqueName}_MethodInfo, __arguments__);"); + writer.WriteLine("_pretend.Handle(__callInfo__);"); - private ParameterSyntax CreateConstructorParameter() - { - return Parameter(Identifier("pretend")) - .WithType(GetGenericPretendType()); - } + foreach (var parameter in methodSymbol.Parameters) + { + if (parameter.RefKind != RefKind.Ref && parameter.RefKind != RefKind.Out) + { + continue; + } - private TypeSyntax GetGenericPretendType() - { - return GenericName(Identifier("Pretend"), - TypeArgumentList(SingletonSeparatedList(ParseTypeName(_pretendType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))))); - } + writer.WriteLine($"{parameter.Name} = __arguments__[{parameter.Ordinal}];"); + } - private FieldDeclarationSyntax GetStaticMethodCacheField(IMethodSymbol method, int index) - { - // TODO: Get method info via argument types - return FieldDeclaration(VariableDeclaration(ParseTypeName("MethodInfo"))) - .AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword)) - .AddDeclarationVariables(VariableDeclarator(Identifier($"__methodInfo_{method.Name}_{index}")) - .WithInitializer(EqualsValueClause( - InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, TypeOfExpression(ParseTypeName(_pretendType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))), IdentifierName("GetMethod"))) - .WithArgumentList(ArgumentList(SingletonSeparatedList(Argument(ParseExpression($"nameof({method.Name})")))))))); + if (methodSymbol.ReturnType.SpecialType != SpecialType.System_Void) + { + // TODO: What do I do about the nullability issues? + writer.WriteLine($"return ({methodSymbol.ReturnType.ToUnknownTypeString()})__callInfo__.ReturnValue;"); + } + } } - private BlockSyntax CreateMethodBody(IMethodSymbol method) + public class Comparer : IEqualityComparer { - var methodBodyStatements = new List(); - - // This is using the new collection expression syntax in C# 12 - // [arg1, arg2, arg3] - var collectionExpression = CollectionExpression() - .AddElements(method.Parameters.Select(p - => ExpressionElement(IdentifierName(p.Name))).ToArray()); - - // object?[] - var typeSyntax = ArrayType(NullableType(PredefinedType(Token(SyntaxKind.ObjectKeyword)))) - .WithRankSpecifiers(SingletonList(ArrayRankSpecifier())); - - var argumentsIdentifier = IdentifierName("__arguments"); - var callInfoIdentifier = IdentifierName("__callInfo"); - - // I'm not currently able to use Span because I have to store CallInfo for late Setup/Verify - // but I don't want to delete this code in case this becomes possible or I don't want to support that - // Span - //var typeSyntax = GenericName("Span").AddTypeArgumentListArguments(NullableType(PredefinedType(Token(SyntaxKind.ObjectKeyword)))); - - // object? [] arguments = [arg0, arg1]; - var argumentsDeclaration = LocalDeclarationStatement( - VariableDeclaration(typeSyntax) - .AddVariables(VariableDeclarator(argumentsIdentifier.Identifier) - .WithInitializer(EqualsValueClause(collectionExpression)) - ) - ); - - methodBodyStatements.Add(argumentsDeclaration); - - // var callInfo = new CallInfo(__methodInfo_MethodName_0, arguments); - var callInfoCreation = LocalDeclarationStatement( - VariableDeclaration(IdentifierName("var")) - .AddVariables(VariableDeclarator(callInfoIdentifier.Identifier) - .WithInitializer(EqualsValueClause(ObjectCreationExpression(ParseTypeName("CallInfo")) - .AddArgumentListArguments(Argument(IdentifierName(method.ToMethodInfoCacheName())), Argument(argumentsIdentifier)))))); + public static Comparer Default = new(); - methodBodyStatements.Add(callInfoCreation); - - // TODO: Call inner implementations when we support them - - // _pretend.Handle(callInfo); - var handleCall = ExpressionStatement(InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("_pretend"), - IdentifierName("Handle"))) - .WithArgumentList(ArgumentList(SingletonSeparatedList( - Argument(callInfoIdentifier))))); - - methodBodyStatements.Add(handleCall); - - // Set ref and out parameters - // TODO: Do I need to do refs? - var refAndOutParameters = method.Parameters - .Where(p => p.RefKind == RefKind.Ref || p.RefKind == RefKind.Out); - - - - foreach (var p in refAndOutParameters) + bool IEqualityComparer.Equals(PretendEmitter x, PretendEmitter y) { - // assign them to the values from arguments - var refOrOutAssignment = AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(p.Name), - ElementAccessExpression( - argumentsIdentifier, - BracketedArgumentList(SingletonSeparatedList( - Argument(LiteralExpression( - SyntaxKind.NumericLiteralExpression, - Literal(p.Ordinal)))))) - ); - - methodBodyStatements.Add(ExpressionStatement(refOrOutAssignment)); + return SymbolEqualityComparer.Default.Equals(x._pretendType, y._pretendType); } - if (method.ReturnType.SpecialType != SpecialType.System_Void) + int IEqualityComparer.GetHashCode(PretendEmitter obj) { - var returnStatement = ReturnStatement(CastExpression( - method.ReturnType.AsUnknownTypeSyntax(), - MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, callInfoIdentifier, IdentifierName("ReturnValue")))); - - methodBodyStatements.Add(returnStatement); + return SymbolEqualityComparer.Default.GetHashCode(obj._pretendType); } - - return Block(methodBodyStatements); } } } diff --git a/src/Pretender.SourceGenerator/Emitter/SetupActionEmitter.cs b/src/Pretender.SourceGenerator/Emitter/SetupActionEmitter.cs index 4f8458a..3516cf7 100644 --- a/src/Pretender.SourceGenerator/Emitter/SetupActionEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/SetupActionEmitter.cs @@ -1,221 +1,82 @@ using System.Collections.Immutable; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Pretender.SourceGenerator.Parser; -using Pretender.SourceGenerator.SetupArguments; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using Pretender.SourceGenerator.Writing; namespace Pretender.SourceGenerator.Emitter { internal class SetupActionEmitter { - private readonly ImmutableArray _setupArgumentSpecs; + private readonly ImmutableArray _setupArgumentEmitters; private readonly KnownTypeSymbols _knownTypeSymbols; - public SetupActionEmitter(ITypeSymbol pretendType, IMethodSymbol setupMethod, ImmutableArray setupArgumentSpecs, KnownTypeSymbols knownTypeSymbols) + public SetupActionEmitter(ITypeSymbol pretendType, IMethodSymbol setupMethod, ImmutableArray setupArgumentEmitters, KnownTypeSymbols knownTypeSymbols) { PretendType = pretendType; SetupMethod = setupMethod; - _setupArgumentSpecs = setupArgumentSpecs; + _setupArgumentEmitters = setupArgumentEmitters; _knownTypeSymbols = knownTypeSymbols; } public ITypeSymbol PretendType { get; } public IMethodSymbol SetupMethod { get; } - public InvocationExpressionSyntax CreateSetupGetter(CancellationToken cancellationToken) + public void Emit(IndentedTextWriter writer, CancellationToken cancellationToken) { - var totalMatchStatements = _setupArgumentSpecs.Sum(sa => sa.NeededMatcherStatements); cancellationToken.ThrowIfCancellationRequested(); - var matchStatements = new StatementSyntax[totalMatchStatements]; - int addedStatements = 0; + writer.Write("pretend.GetOrCreateSetup"); - for (var i = 0; i < _setupArgumentSpecs.Length; i++) + var returnType = SetupMethod.ReturnType.SpecialType != SpecialType.System_Void + ? SetupMethod.ReturnType : null; + + if (returnType is not null) { - var argument = _setupArgumentSpecs[i]; + writer.Write($"<{returnType.ToUnknownTypeString()}>"); + } - var newStatements = argument.CreateMatcherStatements(cancellationToken); - cancellationToken.ThrowIfCancellationRequested(); + writer.WriteLine("(0, static (pretend, expr) =>"); + writer.WriteLine("{"); + writer.IncreaseIndent(); - newStatements.CopyTo(matchStatements, addedStatements); - addedStatements += newStatements.Length; - } + var anyEmitMatcherStatements = _setupArgumentEmitters.Any(e => e.EmitsMatcher); - ArgumentSyntax matcherArgument; - ImmutableArray statements; - if (matchStatements.Length == 0) + string matcherName; + if (anyEmitMatcherStatements) { - statements = ImmutableArray.Empty; + matcherName = "matchCall"; + writer.WriteLine("Matcher matchCall = (callInfo, target) =>"); + writer.WriteLine("{"); + writer.IncreaseIndent(); + + foreach (var argumentEmitter in _setupArgumentEmitters) + { + argumentEmitter.EmitArgumentMatcher(writer, cancellationToken); + } - // Nothing actually needs to match this will always return true, so we use a cached matcher that always returns true - matcherArgument = Argument(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName("Cache"), - IdentifierName("NoOpMatcher"))) - .WithNameColon(NameColon("matcher")); + writer.WriteLine("return true;"); + writer.DecreaseIndent(); + writer.WriteLine("};"); } else { - // Other match statements should have added all the ways the method could return false - // so if it gets through all those statements it should return true at the end. - var trueReturnStatement = ReturnStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression)); - - /* - * Matcher matchCall = static (callInfo, target) => - * { - * ...matching calls... - * return true; - * } - */ - var matchCallIdentifier = Identifier("matchCall"); - - var matcherDelegate = ParenthesizedLambdaExpression( - ParameterList(SeparatedList([ - Parameter(Identifier("callInfo")), - Parameter(Identifier("target")) - ])), - Block(List([.. matchStatements, trueReturnStatement]))) - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))); - - statements = ImmutableArray.Create(LocalDeclarationStatement(VariableDeclaration( - ParseTypeName("Matcher")) - .WithVariables(SingletonSeparatedList( - VariableDeclarator(matchCallIdentifier) - .WithInitializer(EqualsValueClause(matcherDelegate)))))); - - matcherArgument = Argument(IdentifierName(matchCallIdentifier)); + matcherName = "Cache.NoOpMatcher"; } - cancellationToken.ThrowIfCancellationRequested(); - - var objectCreationArguments = ArgumentList( - SeparatedList(new[] - { - Argument(IdentifierName("pretend")), - //Argument(IdentifierName("setupExpression")), - Argument(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(PretendType.ToPretendName()), - IdentifierName(SetupMethod.ToMethodInfoCacheName()) - )), - matcherArgument, - Argument(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("expr"), - IdentifierName("Target") - )), - })); - - cancellationToken.ThrowIfCancellationRequested(); - - GenericNameSyntax returnObjectName; - SimpleNameSyntax getOrCreateName; - if (SetupMethod.ReturnsVoid) + if (returnType is not null) { - // VoidCompiledSetup - returnObjectName = GenericName("VoidCompiledSetup") - .AddTypeArgumentListArguments(ParseTypeName(PretendType.ToFullDisplayString())); + var methodStrategy = _knownTypeSymbols.GetSingleMethodStrategy(SetupMethod); - getOrCreateName = IdentifierName("GetOrCreateSetup"); + // TODO: default value + writer.WriteLine($"return new ReturningCompiledSetup<{PretendType.ToFullDisplayString()}, {returnType.ToUnknownTypeString()}>(pretend, {PretendType.ToPretendName()}.{methodStrategy.UniqueName}_MethodInfo, {matcherName}, expr.Target, defaultValue: default);"); } else { - - // ReturningCompiledSetup - returnObjectName = GenericName("ReturningCompiledSetup") - .AddTypeArgumentListArguments( - ParseTypeName(PretendType.ToFullDisplayString()), - SetupMethod.ReturnType.AsUnknownTypeSyntax()); - - getOrCreateName = GenericName("GetOrCreateSetup") - .AddTypeArgumentListArguments(SetupMethod.ReturnType.AsUnknownTypeSyntax()); - - // TODO: Recursively mock? - ExpressionSyntax defaultValue; - - // TODO: Is this safe? - var namedType = (INamedTypeSymbol)SetupMethod.ReturnType; - - defaultValue = namedType.ToDefaultValueSyntax(_knownTypeSymbols); - - //if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "Task"])) - //{ - // if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) - // { - // // Task.FromResult(default) - // defaultValue = KnownBlocks.TaskFromResult( - // namedType.TypeArguments[0].AsUnknownTypeSyntax(), - // LiteralExpression(SyntaxKind.DefaultLiteralExpression)); - // } - // else - // { - // // Task.CompletedTask - // defaultValue = KnownBlocks.TaskCompletedTask; - // } - //} - //else if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "ValueTask"])) - //{ - // if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) - // { - // // ValueTask.FromResult(default) - // defaultValue = KnownBlocks.ValueTaskFromResult( - // namedType.TypeArguments[0].AsUnknownTypeSyntax(), - // LiteralExpression(SyntaxKind.DefaultLiteralExpression) - // ); - // } - // else - // { - // // ValueTask.CompletedTask - // defaultValue = KnownBlocks.ValueTaskCompletedTask; - // } - //} - //else - //{ - // // TODO: Support custom awaitable - // // default - // defaultValue = LiteralExpression(SyntaxKind.DefaultLiteralExpression); - //} - - cancellationToken.ThrowIfCancellationRequested(); - - objectCreationArguments = objectCreationArguments.AddArguments(Argument( - defaultValue).WithNameColon(NameColon("defaultValue"))); + writer.WriteLine($"return new VoidCompiledSetup<{PretendType.ToFullDisplayString()}>();"); } - cancellationToken.ThrowIfCancellationRequested(); - - var compiledSetupCreation = ObjectCreationExpression(returnObjectName) - .WithArgumentList(objectCreationArguments); - - // (pretend, expression) => - // { - // return new CompiledSetup(); - // } - var creator = ParenthesizedLambdaExpression() - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) - .AddParameterListParameters(Parameter(Identifier("pretend")), Parameter(Identifier("expr"))) - .AddBlockStatements([.. statements, ReturnStatement(compiledSetupCreation)]); - - // TODO: The hash code doesn't actually work, right now, this will create a new pretend every call. - // We likely need to create our own class that can calculate the hash code and place that number in here. - - cancellationToken.ThrowIfCancellationRequested(); - // TODO: Should I have a different seed? - //var badHashCode = _argumentSpecs.Aggregate(0, (agg, s) => HashCode.Combine(agg, s.GetHashCode())); - var badHashCode = 0; - - cancellationToken.ThrowIfCancellationRequested(); - - // pretend.GetOrCreateSetup() - return InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("pretend"), - getOrCreateName)) - .AddArgumentListArguments( - Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(badHashCode))), - Argument(creator), - Argument(IdentifierName("setupExpression"))); + writer.DecreaseIndent(); + writer.WriteLine("}, setupExpression);"); } } } diff --git a/src/Pretender.SourceGenerator/Emitter/SetupArgumentEmitter.cs b/src/Pretender.SourceGenerator/Emitter/SetupArgumentEmitter.cs new file mode 100644 index 0000000..ac3bcce --- /dev/null +++ b/src/Pretender.SourceGenerator/Emitter/SetupArgumentEmitter.cs @@ -0,0 +1,38 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Pretender.SourceGenerator.SetupArguments; +using Pretender.SourceGenerator.Writing; + +namespace Pretender.SourceGenerator.Emitter +{ + internal abstract class SetupArgumentEmitter + { + protected SetupArgumentEmitter(SetupArgumentSpec argumentSpec) + { + ArgumentSpec = argumentSpec; + } + + protected SetupArgumentSpec ArgumentSpec { get; } + + public virtual bool EmitsMatcher => true; + public ImmutableArray NeededLocals { get; } + public bool NeedsCapturer { get; } + + public abstract void EmitArgumentMatcher(IndentedTextWriter writer, CancellationToken cancellationToken); + + protected void EmitArgumentAccessor(IndentedTextWriter writer) + { + // var name_arg = (string?)callInfo[0]; + writer.WriteLine($"var {ArgumentSpec.Parameter.Name}_arg = ({ArgumentSpec.Parameter.Type.ToUnknownTypeString()})callInfo.Arguments[{ArgumentSpec.Parameter.Ordinal}];"); + } + + protected void EmitIfReturnFalseCheck(IndentedTextWriter writer, string left, string right) + { + writer.WriteLine($"if ({left} != {right})"); + using (writer.WriteBlock()) + { + writer.WriteLine("return false;"); + } + } + } +} diff --git a/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs b/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs index e744dde..9b9d280 100644 --- a/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs @@ -1,8 +1,5 @@ -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; -using Microsoft.CodeAnalysis.Operations; +using Microsoft.CodeAnalysis.Operations; +using Pretender.SourceGenerator.Writing; namespace Pretender.SourceGenerator.Emitter { @@ -18,37 +15,34 @@ public SetupEmitter(SetupActionEmitter setupActionEmitter, IInvocationOperation } // TODO: Run cancellationToken a lot more - public MemberDeclarationSyntax Emit(int index, CancellationToken cancellationToken) + public void Emit(IndentedTextWriter writer, int index, CancellationToken cancellationToken) { var setupMethod = _setupActionEmitter.SetupMethod; var pretendType = _setupActionEmitter.PretendType; - var interceptsLocation = new InterceptsLocationInfo(_setupInvocation); - - // TODO: This is wrong - var typeArguments = setupMethod.ReturnsVoid - ? TypeArgumentList(SingletonSeparatedList(ParseTypeName(pretendType.ToFullDisplayString()))) - : TypeArgumentList(SeparatedList([ParseTypeName(pretendType.ToFullDisplayString()), setupMethod.ReturnType.AsUnknownTypeSyntax()])); - - var returnType = GenericName("IPretendSetup") - .WithTypeArgumentList(typeArguments); - - var setupCreatorInvocation = _setupActionEmitter.CreateSetupGetter(cancellationToken); - - return MethodDeclaration(returnType, $"Setup{index}") - .WithBody(Block(ReturnStatement(setupCreatorInvocation))) - .WithParameterList(ParameterList(SeparatedList(new[] - { - Parameter(Identifier("pretend")) - .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword))) - .WithType(ParseTypeName($"Pretend<{pretendType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>")), - - Parameter(Identifier("setupExpression")) - .WithType(GenericName(setupMethod.ReturnsVoid ? "Action" : "Func").WithTypeArgumentList(typeArguments)), - }))) - .WithModifiers(TokenList(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword))) - .WithAttributeLists(SingletonList(AttributeList( - SingletonSeparatedList(interceptsLocation.ToAttributeSyntax())))); + var location = new InterceptsLocationInfo(_setupInvocation); + + string typeArgs; + string actionType; + if (setupMethod.ReturnsVoid) + { + typeArgs = $"<{pretendType.ToFullDisplayString()}>"; + actionType = "Action"; + } + else + { + typeArgs = $"<{pretendType.ToFullDisplayString()}, {setupMethod.ReturnType.ToUnknownTypeString()}>"; + actionType = "Func"; + } + + writer.WriteLine(@$"[InterceptsLocation(@""{location.FilePath}"", {location.LineNumber}, {location.CharacterNumber})]"); + writer.Write($"internal static IPretendSetup{typeArgs} Setup{index}"); + writer.WriteLine($"(this Pretend<{pretendType.ToUnknownTypeString()}> pretend, {actionType}{typeArgs} setupExpression)"); + using (writer.WriteBlock()) + { + writer.Write("return "); + _setupActionEmitter.Emit(writer, cancellationToken); + } } } } diff --git a/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs b/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs index da9c6c8..446fc18 100644 --- a/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs @@ -1,8 +1,6 @@ using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Operations; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using Pretender.SourceGenerator.Writing; namespace Pretender.SourceGenerator.Emitter { @@ -21,47 +19,25 @@ public VerifyEmitter(ITypeSymbol pretendType, ITypeSymbol? returnType, SetupActi _invocationOperation = invocationOperation; } - public MethodDeclarationSyntax Emit(int index, CancellationToken cancellationToken) + public void Emit(IndentedTextWriter writer, int index, CancellationToken cancellationToken) { - var setupInvocation = _setupActionEmitter.CreateSetupGetter(cancellationToken); - - // var setup = pretend.GetOrCreateSetup(...); - var setupLocal = LocalDeclarationStatement(VariableDeclaration(CommonSyntax.VarType) - .WithVariables(SingletonSeparatedList(VariableDeclarator(CommonSyntax.SetupIdentifier) - .WithInitializer(EqualsValueClause(setupInvocation))))); - - TypeSyntax pretendType = ParseTypeName(_pretendType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); - - TypeSyntax setupExpressionType = _returnType == null - ? GenericName("Action").AddTypeArgumentListArguments(pretendType) - : GenericName("Func").AddTypeArgumentListArguments(pretendType, _returnType.AsUnknownTypeSyntax()); - - // setup.Verify(called); - var verifyInvocation = InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(CommonSyntax.SetupIdentifier), - IdentifierName("Verify") - ) - ) - .AddArgumentListArguments(Argument(IdentifierName(CommonSyntax.CalledIdentifier))); - - var interceptsInfo = new InterceptsLocationInfo(_invocationOperation); - - // public static void Verify0( - return MethodDeclaration(CommonSyntax.VoidType, Identifier($"Verify{index}")) - .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)) - .AddParameterListParameters( - [ - Parameter(Identifier("pretend")) - .WithType(CommonSyntax.GenericPretendType.AddTypeArgumentListArguments(pretendType)) - .AddModifiers(Token(SyntaxKind.ThisKeyword)), - Parameter(Identifier("setupExpression")) - .WithType(setupExpressionType), - CommonSyntax.CalledParameter - ]) - .AddBodyStatements([setupLocal, ExpressionStatement(verifyInvocation)]) - .AddAttributeLists(AttributeList(SingletonSeparatedList(interceptsInfo.ToAttributeSyntax()))); + // TODO: Create method trivia + var location = new InterceptsLocationInfo(_invocationOperation); + writer.WriteLine(@$"[InterceptsLocation(@""{location.FilePath}"", {location.LineNumber}, {location.CharacterNumber})]"); + + // TODO: Get property setup expression type + var setupExpressionType = _returnType != null + ? $"Func<{_pretendType.ToFullDisplayString()}, {_returnType.ToUnknownTypeString()}>" + : $"Action<{_pretendType.ToFullDisplayString()}>"; + + + writer.WriteLine($"internal static void Verify{index}(this Pretend<{_pretendType.ToUnknownTypeString()}> pretend, {setupExpressionType} setupExpression, Called called)"); + using (writer.WriteBlock()) + { + writer.Write("var setup = "); + _setupActionEmitter.Emit(writer, cancellationToken); + writer.WriteLine("setup.Verify(called);"); + } } } } diff --git a/src/Pretender.SourceGenerator/Fakes/IKnownFake.cs b/src/Pretender.SourceGenerator/Fakes/IKnownFake.cs index 1b1a626..ea15cbb 100644 --- a/src/Pretender.SourceGenerator/Fakes/IKnownFake.cs +++ b/src/Pretender.SourceGenerator/Fakes/IKnownFake.cs @@ -5,6 +5,6 @@ namespace Pretender.SourceGenerator.Fakes { internal interface IKnownFake { - bool TryConstruct(INamedTypeSymbol typeSymbol, KnownTypeSymbols knownTypeSymbols, CancellationToken cancellationToken, out ITypeSymbol? fakeType); + bool TryConstruct(INamedTypeSymbol typeSymbol, KnownTypeSymbols knownTypeSymbols, CancellationToken cancellationToken, out INamedTypeSymbol? fakeType); } } diff --git a/src/Pretender.SourceGenerator/Fakes/ILoggerFake.cs b/src/Pretender.SourceGenerator/Fakes/ILoggerFake.cs index 7e8ff76..e16bff6 100644 --- a/src/Pretender.SourceGenerator/Fakes/ILoggerFake.cs +++ b/src/Pretender.SourceGenerator/Fakes/ILoggerFake.cs @@ -5,7 +5,7 @@ namespace Pretender.SourceGenerator.Fakes { internal class ILoggerFake : IKnownFake { - public bool TryConstruct(INamedTypeSymbol typeSymbol, KnownTypeSymbols knownTypeSymbols, CancellationToken cancellationToken, out ITypeSymbol? fakeType) + public bool TryConstruct(INamedTypeSymbol typeSymbol, KnownTypeSymbols knownTypeSymbols, CancellationToken cancellationToken, out INamedTypeSymbol? fakeType) { fakeType = null; if (SymbolEqualityComparer.Default.Equals(typeSymbol, knownTypeSymbols.MicrosoftExtensionsLoggingILogger)) diff --git a/src/Pretender.SourceGenerator/Invocation/PretendInvocation.cs b/src/Pretender.SourceGenerator/Invocation/PretendInvocation.cs index fefe32a..856b408 100644 --- a/src/Pretender.SourceGenerator/Invocation/PretendInvocation.cs +++ b/src/Pretender.SourceGenerator/Invocation/PretendInvocation.cs @@ -8,14 +8,14 @@ namespace Pretender.SourceGenerator.Invocation { internal class PretendInvocation { - public PretendInvocation(ITypeSymbol pretendType, Location location, bool fillExisting) + public PretendInvocation(INamedTypeSymbol pretendType, Location location, bool fillExisting) { PretendType = pretendType; Location = location; FillExisting = fillExisting; } - public ITypeSymbol PretendType { get; } + public INamedTypeSymbol PretendType { get; } public Location Location { get; } public bool FillExisting { get; } @@ -70,12 +70,12 @@ public static bool IsCandidateSyntaxNode(SyntaxNode node) } return CreateFromTypeSymbol( - operation.TargetMethod.TypeArguments[0], + (INamedTypeSymbol)operation.TargetMethod.TypeArguments[0], // This should be a totally safe cast operation.Syntax.GetLocation(), fillExisting: false); } - private static PretendInvocation? CreateFromTypeSymbol(ITypeSymbol typeSymbol, Location location, bool fillExisting) + private static PretendInvocation? CreateFromTypeSymbol(INamedTypeSymbol typeSymbol, Location location, bool fillExisting) { // TODO: Maybe check that ITypeSymbol is INamedTypeSymbol? return new PretendInvocation(typeSymbol, location, fillExisting); diff --git a/src/Pretender.SourceGenerator/KnownBlocks.cs b/src/Pretender.SourceGenerator/KnownBlocks.cs index df93fc4..6911176 100644 --- a/src/Pretender.SourceGenerator/KnownBlocks.cs +++ b/src/Pretender.SourceGenerator/KnownBlocks.cs @@ -9,7 +9,7 @@ internal static class KnownBlocks { private static readonly AssemblyName s_assemblyName = typeof(KnownBlocks).Assembly.GetName(); private static readonly string GeneratedCodeAnnotationString = $@"[GeneratedCode(""{s_assemblyName.Name}"", ""{s_assemblyName.Version}"")]"; - public static MemberDeclarationSyntax InterceptsLocationAttribute { get; } = ((CompilationUnitSyntax)ParseSyntaxTree($$""" + public static string InterceptsLocationAttribute { get; } = $$""" namespace System.Runtime.CompilerServices { using System; @@ -24,7 +24,7 @@ public InterceptsLocationAttribute(string filePath, int line, int column) } } } - """).GetRoot()).Members[0]; + """; public static NamespaceDeclarationSyntax OurNamespace { get; } = NamespaceDeclaration(IdentifierName("Pretender.SourceGeneration")); diff --git a/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs b/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs index e5e2a6a..9fe40b4 100644 --- a/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs +++ b/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs @@ -1,15 +1,19 @@ -using Microsoft.CodeAnalysis; +using System.Collections.Concurrent; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; namespace Pretender.SourceGenerator.Parser { internal sealed class KnownTypeSymbols { + private readonly ConcurrentDictionary> _cachedTypeMethodNames = new(SymbolEqualityComparer.Default); + public Compilation Compilation { get; } // INamedTypeSymbols public INamedTypeSymbol? Pretend { get; } public INamedTypeSymbol? Pretend_Unbound { get; } + public INamedTypeSymbol? AnyMatcher { get; } public INamedTypeSymbol String { get; } public INamedTypeSymbol? Task { get; } @@ -33,6 +37,8 @@ public KnownTypeSymbols(Compilation compilation) // TODO: Get known types Pretend = compilation.GetTypeByMetadataName("Pretender.Pretend`1"); Pretend_Unbound = Pretend?.ConstructUnboundGenericType(); + AnyMatcher = compilation.GetTypeByMetadataName("Pretender.Matchers.AnyMatcher"); + String = compilation.GetSpecialType(SpecialType.System_String); Task = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task"); @@ -57,6 +63,38 @@ public KnownTypeSymbols(Compilation compilation) // TODO: FakeTimeProvider } + public MethodStrategy GetSingleMethodStrategy(IMethodSymbol method) + { + return GetTypesStrategies(method.ContainingType)[method]; + } + + public IReadOnlyDictionary GetTypesStrategies(INamedTypeSymbol type) + { + return _cachedTypeMethodNames.GetOrAdd( + type, + static (type) => + { + Dictionary methodDictionary = new(SymbolEqualityComparer.Default); + var groupedByNameMethods = type.GetApplicableMethods() + .GroupBy(m => m.Name); + + foreach (var groupedByNameMethod in groupedByNameMethods) + { + var methods = groupedByNameMethod.ToArray(); + if (methods.Length == 1) + { + methodDictionary.Add(methods[0], new ByNameMethodStrategy(methods[0])); + continue; + } + + // More than on method has this name, next try number of arguments + } + + return methodDictionary; + } + ); + } + public static bool IsPretend(INamedTypeSymbol type) { // This should be enough diff --git a/src/Pretender.SourceGenerator/Parser/MethodStrategy.cs b/src/Pretender.SourceGenerator/Parser/MethodStrategy.cs new file mode 100644 index 0000000..1c2db46 --- /dev/null +++ b/src/Pretender.SourceGenerator/Parser/MethodStrategy.cs @@ -0,0 +1,86 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Pretender.SourceGenerator.Writing; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Pretender.SourceGenerator.Parser +{ + internal abstract class MethodStrategy + { + protected MethodStrategy(IMethodSymbol method) + { + Method = method; + } + + public IMethodSymbol Method { get; } + public abstract string UniqueName { get; } + + public abstract void EmitMethodGetter(IndentedTextWriter writer, CancellationToken cancellationToken); + + protected static InvocationExpressionSyntax NameOfExpression(string identifier) + { + var text = SyntaxFacts.GetText(SyntaxKind.NameOfKeyword); + + var identifierSyntax = Identifier(default, + SyntaxKind.NameOfKeyword, + text, + text, + default); + + return InvocationExpression( + IdentifierName(identifierSyntax), + ArgumentList(SingletonSeparatedList(Argument(IdentifierName(identifier))))); + } + } + + internal class ByNameMethodStrategy : MethodStrategy + { + public ByNameMethodStrategy(IMethodSymbol method) + : base(method) + { + + } + + public override string UniqueName => Method.Name; + + public override void EmitMethodGetter(IndentedTextWriter writer, CancellationToken cancellationToken) + { + if (Method.MethodKind == MethodKind.Ordinary) + { + writer.Write($".GetMethod(nameof({Method.Name}))!;"); + } + else if (Method.MethodKind == MethodKind.PropertyGet) + { + writer.Write($".GetProperty(nameof({Method.AssociatedSymbol!.Name}))!.GetMethod;"); + } + else if (Method.MethodKind == MethodKind.PropertySet) + { + writer.Write($".GetProperty(nameof({Method.AssociatedSymbol!.Name}))!.SetMethod;"); + } + else + { + throw new InvalidOperationException($"Did not expect {Method.MethodKind}"); + } + } + } + + internal class ByParameterCountMethodStrategy : MethodStrategy + { + private readonly int _parameterCount; + + public ByParameterCountMethodStrategy(IMethodSymbol method, int parameterCount) + : base(method) + { + _parameterCount = parameterCount; + } + + // TODO: Is this the naming we want? I think I just want an incrementing index + public override string UniqueName => $"{Method.Name}_{_parameterCount}"; + + public override void EmitMethodGetter(IndentedTextWriter writer, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/Pretender.SourceGenerator/Parser/PretendParser.cs b/src/Pretender.SourceGenerator/Parser/PretendParser.cs index b1ec1f8..502050f 100644 --- a/src/Pretender.SourceGenerator/Parser/PretendParser.cs +++ b/src/Pretender.SourceGenerator/Parser/PretendParser.cs @@ -9,7 +9,7 @@ namespace Pretender.SourceGenerator.Parser { internal class PretendParser { - private static readonly List _knownFakes = + private static readonly List s_knownFakes = [ new ILoggerFake(), ]; @@ -43,21 +43,25 @@ public PretendParser(PretendInvocation pretendInvocation, KnownTypeSymbols known if (_settings.Behavior == PretendBehavior.PreferFakes) { - foreach (var fake in _knownFakes) + foreach (var fake in s_knownFakes) { cancellationToken.ThrowIfCancellationRequested(); - if (fake.TryConstruct((INamedTypeSymbol)pretendType, _knownTypeSymbols, cancellationToken, out var fakeType)) + if (fake.TryConstruct(pretendType, _knownTypeSymbols, cancellationToken, out var fakeType)) { // TODO: Do something } } } + var methodStrategies = _knownTypeSymbols.GetTypesStrategies(pretendType); + + cancellationToken.ThrowIfCancellationRequested(); + // TODO: Do a larger amount of parsing cancellationToken.ThrowIfCancellationRequested(); - return (new PretendEmitter(PretendInvocation.PretendType, PretendInvocation.FillExisting), null); + return (new PretendEmitter(PretendInvocation.PretendType, methodStrategies, PretendInvocation.FillExisting), null); } } } diff --git a/src/Pretender.SourceGenerator/Parser/SetupActionParser.cs b/src/Pretender.SourceGenerator/Parser/SetupActionParser.cs index 168cea8..c152dcf 100644 --- a/src/Pretender.SourceGenerator/Parser/SetupActionParser.cs +++ b/src/Pretender.SourceGenerator/Parser/SetupActionParser.cs @@ -42,13 +42,37 @@ public SetupActionParser(IOperation setupActionArgument, ITypeSymbol pretendType var arguments = candidate.Arguments; - var builder = ImmutableArray.CreateBuilder(arguments.Length); + var builder = ImmutableArray.CreateBuilder(arguments.Length); + + var argumentDiagnostics = new List(); + for (var i = 0; i < arguments.Length; i++) { - builder.Add(SetupArgumentSpec.Create(arguments[i], i)); + var argumentSpec = new SetupArgumentSpec(arguments[i], _knownTypeSymbols); + var argumentParser = new SetupArgumentParser(argumentSpec); + + var (emitter, diagnostics) = argumentParser.Parse(cancellationToken); + + // If any emitter comes back null, return the diagnostics it came back with and all the ones we've collected for other parsing operations + if (emitter == null) + { + Debug.Assert(diagnostics.HasValue); + argumentDiagnostics.AddRange(diagnostics!.Value); + return (null, argumentDiagnostics.ToImmutableArray()); + } + + if (diagnostics is ImmutableArray parseDiagnostics) + { + argumentDiagnostics.AddRange(parseDiagnostics); + } + + builder.Add(emitter); } - return (new SetupActionEmitter(_pretendType, candidate.Method, builder.MoveToImmutable(), _knownTypeSymbols), null); + return ( + new SetupActionEmitter(_pretendType, candidate.Method, builder.MoveToImmutable(), _knownTypeSymbols), + argumentDiagnostics.Count != 0 ? argumentDiagnostics.ToImmutableArray() : null + ); } private ImmutableArray GetInvocationCandidates(CancellationToken cancellationToken) diff --git a/src/Pretender.SourceGenerator/Pretender.SourceGenerator.csproj b/src/Pretender.SourceGenerator/Pretender.SourceGenerator.csproj index f0dec8b..66d80c1 100644 --- a/src/Pretender.SourceGenerator/Pretender.SourceGenerator.csproj +++ b/src/Pretender.SourceGenerator/Pretender.SourceGenerator.csproj @@ -15,6 +15,10 @@ all + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + diff --git a/src/Pretender.SourceGenerator/PretenderSettings.cs b/src/Pretender.SourceGenerator/PretenderSettings.cs index e02acdb..49fedf8 100644 --- a/src/Pretender.SourceGenerator/PretenderSettings.cs +++ b/src/Pretender.SourceGenerator/PretenderSettings.cs @@ -3,6 +3,7 @@ namespace Pretender.SourceGenerator { + // The properties on this class need to match the property names in PretenderSettingsAttribute internal class PretenderSettings { public static PretenderSettings Default { get; } = new PretenderSettings( diff --git a/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs b/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs index 1663b8e..4b82dda 100644 --- a/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs +++ b/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs @@ -102,7 +102,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var createEmitters = ReportDiagnostics(context, createEmittersWithDiagnostics); context.RegisterSourceOutput( - pretendEmitters.Collect() + pretendEmitters.GroupWith(e => e, PretendEmitter.Comparer.Default).Select((t, _) => t.Source).Collect() .Combine(setups.Collect()) .Combine(verifyEmitters.Collect()) .Combine(createEmitters.Collect()), (context, emitters) => @@ -113,11 +113,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var grandEmitter = new GrandEmitter(pretends, setups, verifies, creates); - var compilationUnit = grandEmitter.Emit(context.CancellationToken); + var sourceText = grandEmitter.Emit(context.CancellationToken); context.CancellationToken.ThrowIfCancellationRequested(); - context.AddSource("Pretender.g.cs", compilationUnit.GetText(Encoding.UTF8)); + context.AddSource("Pretender.g.cs", sourceText); }); } diff --git a/src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs b/src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs deleted file mode 100644 index 3f4fd1d..0000000 --- a/src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs +++ /dev/null @@ -1,188 +0,0 @@ -using System.Collections.Immutable; -using System.Diagnostics; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Pretender.SourceGenerator.SetupArguments -{ - // TODO: Should probably have Specs for "safe" invocations and capturing specs - internal class InvocationArgumentSpec : SetupArgumentSpec - { - private readonly IInvocationOperation _invocationOperation; - - public InvocationArgumentSpec( - IInvocationOperation invocationOperation, - IArgumentOperation originalArgument, - int argumentPlacement) - : base(originalArgument, argumentPlacement) - { - _invocationOperation = invocationOperation; - } - - private ImmutableArray? _cachedMatcherStatements; - - public override int NeededMatcherStatements - { - get - { - if (TryGetMatcherAttributeType(out var matcherType)) - { - // TODO: Match with KnownTypeSymbols - if (matcherType.EqualsByName(["Pretender", "Matchers", "AnyMatcher"])) - { - _cachedMatcherStatements = ImmutableArray.Empty; - return 0; - } - - var arguments = new ArgumentSyntax[_invocationOperation.Arguments.Length]; - bool allArgumentsSafe = true; - - for (int i = 0; i < arguments.Length; i++) - { - var arg = _invocationOperation.Arguments[i]; - if (arg.Value is ILiteralOperation literalOperation) - { - arguments[i] = Argument(literalOperation.ToLiteralExpression()); - } - else if (arg.Value is IDelegateCreationOperation delegateCreation) - { - - if (delegateCreation.Target is IAnonymousFunctionOperation anonymousFunctionOperation) - { - if (anonymousFunctionOperation.Symbol.IsStatic) // This isn't enough either though, they could call a static method that only exists in their context - { - // If it's guaranteed to be static, we can just rewrite it in our code - arguments[i] = Argument(ParseExpression(delegateCreation.Syntax.GetText().ToString())); - } - else if (false) // Is non-scope capturing - { - // This is a lot more work but also very powerful in terms of speed - // We need to rewrite the delegate and replace all local references with our getter - allArgumentsSafe = false; - } - else - { - // We need a static matcher - allArgumentsSafe = false; - } - } - else - { - allArgumentsSafe = false; - } - } - else - { - allArgumentsSafe = false; - } - } - - if (!allArgumentsSafe) - { - _cachedMatcherStatements = ImmutableArray.Empty; - return 0; - } - - var statements = new StatementSyntax[3]; - var (identifier, accessor) = CreateArgumentAccessor(); - statements[0] = accessor; - - var matcherLocalName = $"{Parameter.Name}_matcher"; - - // var name_matcher = new global::MyMatcher(arg0, arg1); - statements[1] = LocalDeclarationStatement( - VariableDeclaration(ParseTypeName("var")) - .AddVariables(VariableDeclarator(matcherLocalName) - .WithInitializer(EqualsValueClause(ObjectCreationExpression(ParseTypeName(matcherType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))) - .AddArgumentListArguments(arguments)) - ) - ) - ); - - statements[2] = CreateIfCheck( - PrefixUnaryExpression( - SyntaxKind.LogicalNotExpression, - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(matcherLocalName), - IdentifierName("Matches") - ) - ) - .AddArgumentListArguments(Argument(IdentifierName(identifier))) - ) - ); - - _cachedMatcherStatements = ImmutableArray.Create(statements); - return 3; - } - else - { - _cachedMatcherStatements = ImmutableArray.Empty; - return 0; - } - } - } - - public override ImmutableArray CreateMatcherStatements(CancellationToken cancellationToken) - { - Debug.Assert(_cachedMatcherStatements.HasValue, "Should have called NeededStatements first."); - return _cachedMatcherStatements!.Value; - } - - private bool TryGetMatcherAttributeType(out INamedTypeSymbol matcherType) - { - var allAttributes = _invocationOperation.TargetMethod.GetAttributes(); - var matcherAttribute = allAttributes.Single(ad => ad.AttributeClass!.EqualsByName(["Pretender", "Matchers", "MatcherAttribute"])); - - matcherType = null!; - - if (matcherAttribute.AttributeClass!.IsGenericType) - { - // We are in the typed version, get the generic arg - matcherType = (INamedTypeSymbol)matcherAttribute.AttributeClass.TypeArguments[0]; - } - else - { - // We are in the base version, get the constructor arg - // TODO: Make this work - // matcherType = matcherAttribute.ConstructorArguments[0]; - var attributeType = matcherAttribute.ConstructorArguments[0]; - // TODO: When can Type be null? - if (!attributeType.Type!.EqualsByName(["System", "Type"])) - { - return false; - } - - if (attributeType.Value is null) - { - return false; - } - - matcherType = (INamedTypeSymbol)attributeType.Value!; - } - - // TODO: Write a lot more tests for this - if (matcherType.IsUnboundGenericType) - { - if (_invocationOperation.TargetMethod.TypeArguments.Length != matcherType.TypeArguments.Length) - { - return false; - } - - matcherType = matcherType.ConstructedFrom.Construct([.._invocationOperation.TargetMethod.TypeArguments]); - } - - return true; - } - - public override int GetHashCode() - { - // TODO: This is not enough for uniqueness - return SymbolEqualityComparer.Default.GetHashCode(_invocationOperation.TargetMethod); - } - } -} diff --git a/src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentEmitter.cs b/src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentEmitter.cs new file mode 100644 index 0000000..1c0958d --- /dev/null +++ b/src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentEmitter.cs @@ -0,0 +1,32 @@ +using Microsoft.CodeAnalysis.Operations; +using Pretender.SourceGenerator.Emitter; +using Pretender.SourceGenerator.Writing; + +namespace Pretender.SourceGenerator.SetupArguments +{ + internal class LiteralArgumentEmitter : SetupArgumentEmitter + { + private readonly ILiteralOperation _literalOperation; + + public LiteralArgumentEmitter(ILiteralOperation literalOperation, SetupArgumentSpec argumentSpec) + : base(argumentSpec) + { + _literalOperation = literalOperation; + } + + public override void EmitArgumentMatcher(IndentedTextWriter writer, CancellationToken cancellationToken) + { + EmitArgumentAccessor(writer); + EmitIfReturnFalseCheck(writer, + $"{ArgumentSpec.Parameter.Name}_arg", + CSharpSyntaxUtilities.FormatLiteral(_literalOperation.ConstantValue.Value, ArgumentSpec.Parameter.Type)); + } + + public override int GetHashCode() + { + return _literalOperation.ConstantValue.HasValue + ? _literalOperation.ConstantValue.Value?.GetHashCode() ?? 41602 // TODO: Magic value? + : 1337; // TODO: Magic value? + } + } +} diff --git a/src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentSpec.cs b/src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentSpec.cs deleted file mode 100644 index e953713..0000000 --- a/src/Pretender.SourceGenerator/SetupArguments/LiteralArgumentSpec.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System.Collections.Immutable; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Pretender.SourceGenerator.SetupArguments -{ - internal class LiteralArgumentSpec : SetupArgumentSpec - { - private readonly ILiteralOperation _literalOperation; - - public LiteralArgumentSpec(ILiteralOperation literalOperation, IArgumentOperation originalArgument, int argumentPlacement) - : base(originalArgument, argumentPlacement) - { - _literalOperation = literalOperation; - } - - public override int NeededMatcherStatements => 2; - - public override ImmutableArray CreateMatcherStatements(CancellationToken cancellationToken) - { - var (argumentName, localDeclaration) = CreateArgumentAccessor(); - var ifCheck = CreateIfCheck(IdentifierName(argumentName), _literalOperation.ToLiteralExpression()); - return ImmutableArray.Create([localDeclaration, ifCheck]); - } - - public override int GetHashCode() - { - return _literalOperation.ConstantValue.HasValue - ? _literalOperation.ConstantValue.Value?.GetHashCode() ?? 41602 // TODO: Magic value? - : 1337; // TODO: Magic value? - } - } -} diff --git a/src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentEmitter.cs b/src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentEmitter.cs new file mode 100644 index 0000000..8a146c0 --- /dev/null +++ b/src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentEmitter.cs @@ -0,0 +1,30 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Operations; +using Pretender.SourceGenerator.Emitter; +using Pretender.SourceGenerator.Writing; + +namespace Pretender.SourceGenerator.SetupArguments +{ + internal class LocalReferenceArgumentEmitter : SetupArgumentEmitter + { + private readonly ILocalReferenceOperation _localReferenceOperation; + + public LocalReferenceArgumentEmitter(ILocalReferenceOperation localReferenceOperation, SetupArgumentSpec argumentSpec) : base(argumentSpec) + { + _localReferenceOperation = localReferenceOperation; + } + + public override void EmitArgumentMatcher(IndentedTextWriter writer, CancellationToken cancellationToken) + { + var localVariableName = $"{ArgumentSpec.Parameter.Name}_local"; + EmitArgumentAccessor(writer); + writer.WriteLine(@$"var {localVariableName} = target.GetType().GetField(""{_localReferenceOperation.Local.Name}"").GetValue(target);"); + EmitIfReturnFalseCheck(writer, $"{ArgumentSpec.Parameter.Name}_arg", localVariableName); + } + + public override int GetHashCode() + { + return SymbolEqualityComparer.Default.GetHashCode(_localReferenceOperation.Local); + } + } +} diff --git a/src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentSpec.cs b/src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentSpec.cs deleted file mode 100644 index 3060e11..0000000 --- a/src/Pretender.SourceGenerator/SetupArguments/LocalReferenceArgumentSpec.cs +++ /dev/null @@ -1,92 +0,0 @@ -using System; -using System.Collections.Immutable; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Pretender.SourceGenerator.SetupArguments -{ - internal class LocalReferenceArgumentSpec : SetupArgumentSpec - { - private readonly ILocalReferenceOperation _localReferenceOperation; - - public LocalReferenceArgumentSpec( - ILocalReferenceOperation localReferenceOperation, - IArgumentOperation originalArgument, - int argumentPlacement) : base(originalArgument, argumentPlacement) - { - _localReferenceOperation = localReferenceOperation; - } - - public override int NeededMatcherStatements => 3; - - public override ImmutableArray CreateMatcherStatements(CancellationToken cancellationToken) - { - var variableName = $"{Parameter.Name}_local"; - var (identifier, accessor) = CreateArgumentAccessor(); - - // This is for calling the UnsafeAccessor method that doesn't seem to work for my needs - //statements[1] = LocalDeclarationStatement(VariableDeclaration(ParseTypeName("var")) - // .AddVariables(VariableDeclarator(variableName) - // .WithInitializer(EqualsValueClause(InvocationExpression( - // MemberAccessExpression( - // SyntaxKind.SimpleMemberAccessExpression, - // IdentifierName($"Setup{index}Accessor"), - // IdentifierName(((ILocalReferenceOperation)ArgumentOperation.Value).Local.Name) - // ) - // ) - // .AddArgumentListArguments(Argument(IdentifierName("target"))))))); - - - //statements[1] = LocalDeclarationStatement(VariableDeclaration(localOperation.Local.Type.AsUnknownTypeSyntax()) - // .AddVariables(VariableDeclarator(variableName) - // .WithInitializer(EqualsValueClause( - // MemberAccessExpression( - // SyntaxKind.SimpleMemberAccessExpression, - // ParenthesizedExpression(CastExpression(ParseTypeName("dynamic"), IdentifierName("target"))), - // IdentifierName(localOperation.Local.Name)))) - // ) - // ); - - // var arg_local = target.GetType().GetField("local").GetValue(target); - - // target.GetType() - var getTypeInvocation = InvocationExpression(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("target"), - IdentifierName("GetType"))); - - // target.GetType().GetField("local")! - var getFieldInvocation = PostfixUnaryExpression(SyntaxKind.SuppressNullableWarningExpression, InvocationExpression(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - getTypeInvocation, - IdentifierName("GetField"))) - .AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(_localReferenceOperation.Local.Name))))); - - var getValueInvocation = CastExpression(_localReferenceOperation.Local.Type.AsUnknownTypeSyntax(), InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - getFieldInvocation, - IdentifierName("GetValue"))) - .AddArgumentListArguments(Argument(IdentifierName("target")))); - - // TODO: This really sucks, but neither other way works - var localDeclaration = LocalDeclarationStatement(VariableDeclaration(ParseTypeName("var"))) - .AddDeclarationVariables(VariableDeclarator(variableName) - .WithInitializer(EqualsValueClause(getValueInvocation))); - //statements[1] = ExpressionStatement( - // ParseExpression($"var {variableName} = target.GetType().GetField(\"{localOperation.Local.Name}\")!.GetValue(target)") - //); - - var ifCheck = CreateIfCheck(IdentifierName(identifier), IdentifierName(variableName)); - return ImmutableArray.Create([accessor, localDeclaration, ifCheck]); - } - - public override int GetHashCode() - { - return SymbolEqualityComparer.Default.GetHashCode(_localReferenceOperation.Local); - } - } -} diff --git a/src/Pretender.SourceGenerator/SetupArguments/SetupArgumentSpec.cs b/src/Pretender.SourceGenerator/SetupArguments/SetupArgumentParser.cs similarity index 58% rename from src/Pretender.SourceGenerator/SetupArguments/SetupArgumentSpec.cs rename to src/Pretender.SourceGenerator/SetupArguments/SetupArgumentParser.cs index d5abde0..68a5dc9 100644 --- a/src/Pretender.SourceGenerator/SetupArguments/SetupArgumentSpec.cs +++ b/src/Pretender.SourceGenerator/SetupArguments/SetupArgumentParser.cs @@ -1,76 +1,126 @@ using System.Collections.Immutable; using System.Diagnostics; -using System.Reflection; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Operations; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using Pretender.SourceGenerator.Emitter; +using Pretender.SourceGenerator.Parser; namespace Pretender.SourceGenerator.SetupArguments { - internal abstract class SetupArgumentSpec + internal class SetupArgumentSpec { - private readonly List _diagnostics = []; - public SetupArgumentSpec(IArgumentOperation originalArgument, int argumentPlacement) + public SetupArgumentSpec(IArgumentOperation argument, KnownTypeSymbols knownTypeSymbols) { - OriginalArgument = originalArgument; - ArgumentPlacement = argumentPlacement; - - var tracker = new ArgumentTracker(); - Visit(originalArgument, tracker); - - NeedsCapturer = tracker.NeedsCapturer; - NeededLocals = tracker.NeededLocals; + Argument = argument; + KnownTypeSymbols = knownTypeSymbols; } - protected IArgumentOperation OriginalArgument { get; } - protected IParameterSymbol Parameter => OriginalArgument.Parameter!; - protected int ArgumentPlacement { get; } - protected void AddDiagnostic(Diagnostic diagnostic) - { - _diagnostics.Add(diagnostic); - } + public IArgumentOperation Argument { get; } - public IReadOnlyList Diagnostics => _diagnostics; - public bool NeedsCapturer { get; } - public ImmutableArray NeededLocals { get; } - public abstract int NeededMatcherStatements { get; } + // I don't think a SetupArgument will ever be __argList so I'm not worried about this null assurance + public IParameterSymbol Parameter => Argument.Parameter!; - public abstract ImmutableArray CreateMatcherStatements(CancellationToken cancellationToken); + public KnownTypeSymbols KnownTypeSymbols { get; } + } + + internal class SetupArgumentParser + { + private readonly SetupArgumentSpec _setupArgumentSpec; - protected (SyntaxToken Identifier, LocalDeclarationStatementSyntax Accessor) CreateArgumentAccessor() + public SetupArgumentParser(SetupArgumentSpec setupArgumentSpec) { - var argumentLocal = Identifier($"{Parameter.Name}_arg"); + _setupArgumentSpec = setupArgumentSpec; + } - // (string?)callInfo.Arguments[index]; - ExpressionSyntax argumentGetter = CastExpression( - Parameter.Type.AsUnknownTypeSyntax(), - ElementAccessExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName("callInfo"), IdentifierName("Arguments"))) - .AddArgumentListArguments(Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(ArgumentPlacement)))) - ); - var localAccessor = LocalDeclarationStatement(VariableDeclaration(ParseTypeName("var")) - .AddVariables(VariableDeclarator(argumentLocal) - .WithInitializer(EqualsValueClause(argumentGetter)))); + public (SetupArgumentEmitter? SetupArgumentEmitter, ImmutableArray? Diagnostics) Parse(CancellationToken cancellationToken) + { + var argumentValue = _setupArgumentSpec.Argument.Value; - return (argumentLocal, localAccessor); + return argumentValue.Kind switch + { + OperationKind.Literal => (new LiteralArgumentEmitter((ILiteralOperation)argumentValue, _setupArgumentSpec), null), + OperationKind.Invocation => ParseInvocation((IInvocationOperation)argumentValue, cancellationToken), + OperationKind.LocalReference => (new LocalReferenceArgumentEmitter((ILocalReferenceOperation)argumentValue, _setupArgumentSpec), null), + _ => throw new NotImplementedException($"{argumentValue.Kind} is not a supported operation in setup arguments."), + }; } - protected IfStatementSyntax CreateIfCheck(ExpressionSyntax left, ExpressionSyntax right) + private (SetupArgumentEmitter? Emitter, ImmutableArray? Diagnostics) ParseInvocation(IInvocationOperation invocation, CancellationToken cancellationToken) { - var binaryExpression = BinaryExpression( - SyntaxKind.NotEqualsExpression, - left, - right); + if (TryGetMatcherAttributeType(invocation, out var matcherType, cancellationToken)) + { + cancellationToken.ThrowIfCancellationRequested(); + + // Special case AnyMatcher + if (SymbolEqualityComparer.Default.Equals(matcherType, _setupArgumentSpec.KnownTypeSymbols.AnyMatcher)) + { + return (new NoopArgumentEmitter(_setupArgumentSpec), null); + } - return CreateIfCheck(binaryExpression); + // TODO: Parse args passed into the invocation + return (new MatcherArgumentEmitter(matcherType, _setupArgumentSpec), null); + } + else + { + // They likely invoked their own method, we will need to run and capture output for value/matcher + throw new NotImplementedException("We don't support user scoped invocations quite yet."); + } } - protected IfStatementSyntax CreateIfCheck(ExpressionSyntax condition) + private bool TryGetMatcherAttributeType(IInvocationOperation invocation, out INamedTypeSymbol matcherType, CancellationToken cancellationToken) { - return IfStatement(condition, Block( - ReturnStatement(LiteralExpression(SyntaxKind.FalseLiteralExpression)))); + var allAttributes = invocation.TargetMethod.GetAttributes(); + + // TODO: Use KnownTypeSymbols + var matcherAttribute = allAttributes.Single(ad => ad.AttributeClass!.EqualsByName(["Pretender", "Matchers", "MatcherAttribute"])); + + matcherType = null!; + + cancellationToken.ThrowIfCancellationRequested(); + + if (matcherAttribute.AttributeClass!.IsGenericType) + { + // We are in the typed version, get the generic arg + matcherType = (INamedTypeSymbol)matcherAttribute.AttributeClass.TypeArguments[0]; + } + else + { + // We are in the base version, get the constructor arg + // TODO: Make this work + // matcherType = matcherAttribute.ConstructorArguments[0]; + var attributeType = matcherAttribute.ConstructorArguments[0]; + // TODO: When can Type be null? + // TODO: Use KnownTypeSymbols + if (!attributeType.Type!.EqualsByName(["System", "Type"])) + { + return false; + } + + cancellationToken.ThrowIfCancellationRequested(); + + if (attributeType.Value is null) + { + return false; + } + + // Always an okay cast? + matcherType = (INamedTypeSymbol)attributeType.Value!; + } + + cancellationToken.ThrowIfCancellationRequested(); + // TODO: Write a lot more tests for this + if (matcherType.IsUnboundGenericType) + { + if (invocation.TargetMethod.TypeArguments.Length != matcherType.TypeArguments.Length) + { + return false; + } + + matcherType = matcherType.ConstructedFrom.Construct([.. invocation.TargetMethod.TypeArguments]); + } + + return true; } private static void Visit(IOperation? operation, ArgumentTracker tracker) @@ -167,32 +217,6 @@ private static void VisitMany(IEnumerable operations, ArgumentTracke } } - // Factory method for creating an ArgumentSpec based on the argument operation - public static SetupArgumentSpec Create(IArgumentOperation argumentOperation, int argumentPlacement) - { - var argumentOperationValue = argumentOperation.Value; - switch (argumentOperationValue.Kind) - { - case OperationKind.Literal: - return new LiteralArgumentSpec( - (ILiteralOperation)argumentOperationValue, - argumentOperation, - argumentPlacement); - case OperationKind.Invocation: - return new InvocationArgumentSpec( - (IInvocationOperation)argumentOperationValue, - argumentOperation, - argumentPlacement); - case OperationKind.LocalReference: - return new LocalReferenceArgumentSpec( - (ILocalReferenceOperation)argumentOperationValue, - argumentOperation, - argumentPlacement); - default: - throw new NotImplementedException(); - } - } - private class ArgumentTracker { private readonly List _neededLocals = new(); diff --git a/src/Pretender.SourceGenerator/SymbolExtensions.cs b/src/Pretender.SourceGenerator/SymbolExtensions.cs index 53964b7..2871210 100644 --- a/src/Pretender.SourceGenerator/SymbolExtensions.cs +++ b/src/Pretender.SourceGenerator/SymbolExtensions.cs @@ -1,8 +1,10 @@ -using System.Diagnostics; +using System.Collections.Immutable; +using System.Diagnostics; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Pretender.SourceGenerator.Parser; +using Pretender.SourceGenerator.Writing; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Pretender.SourceGenerator @@ -50,6 +52,13 @@ public static TypeSyntax AsUnknownTypeSyntax(this ITypeSymbol type) return typeSyntax; } + public static string ToUnknownTypeString(this ITypeSymbol type) + { + return type.NullableAnnotation == NullableAnnotation.Annotated + ? $"{type.ToFullDisplayString()}?" + : type.ToFullDisplayString(); + } + public static ExpressionSyntax ToDefaultValueSyntax(this INamedTypeSymbol type, KnownTypeSymbols knownTypeSymbols) { // They have explicitly annotated this type as nullable, so return null @@ -131,16 +140,19 @@ public static string ToFullDisplayString(this ITypeSymbol type) return type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); } - public static ClassDeclarationSyntax ScaffoldImplementation(this ITypeSymbol type, ScaffoldTypeOptions options) + public static void ScaffoldImplementation(this ITypeSymbol type, IndentedTextWriter writer, ScaffoldTypeOptions options) { Debug.Assert(!type.IsSealed, "I can't scaffold an implementation of a sealed class"); // Add the base type PretendIMyType : IMyType - var classDeclaration = ClassDeclaration(type.ToPretendName()) - .AddBaseListTypes(SimpleBaseType(ParseTypeName(type.ToFullDisplayString()))); + writer.WriteLine($"file class {type.ToPretendName()} : {type.ToFullDisplayString()}"); + using (writer.WriteBlock()) + { + + } // Add fields first - classDeclaration = classDeclaration.AddMembers([.. options.CustomFields]); + //classDeclaration = classDeclaration.AddMembers([.. options.CustomFields]); // TODO: Only public and non-sealed? var typeMembers = type.GetMembers(); @@ -252,7 +264,18 @@ public static ClassDeclarationSyntax ScaffoldImplementation(this ITypeSymbol typ // TODO: Add GeneratedCodeAttribute // TODO: Add ExcludeFromCodeCoverageAttribute - return classDeclaration.AddMembers([.. members]); + //return classDeclaration.AddMembers([.. members]); + } + + public static ImmutableArray GetApplicableMethods(this INamedTypeSymbol type) + { + return type.GetMembers() + .Where(m => !m.IsStatic) + .OfType() + .Where(m => m.MethodKind == MethodKind.Ordinary + || m.MethodKind == MethodKind.PropertyGet + || m.MethodKind == MethodKind.PropertySet) + .ToImmutableArray(); } public static MethodDeclarationSyntax ToMethodSyntax(this IMethodSymbol method) diff --git a/src/Pretender.SourceGenerator/Writing/ImmutableArrayBuilder.cs b/src/Pretender.SourceGenerator/Writing/ImmutableArrayBuilder.cs new file mode 100644 index 0000000..52d8a1b --- /dev/null +++ b/src/Pretender.SourceGenerator/Writing/ImmutableArrayBuilder.cs @@ -0,0 +1,364 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Runtime.CompilerServices; + +namespace Pretender.SourceGenerator.Writing; + +/// +/// A helper type to build sequences of values with pooled buffers. +/// +/// The type of items to create sequences for. +internal struct ImmutableArrayBuilder : IDisposable +{ + /// + /// The shared instance to share objects. + /// + private static readonly ObjectPool SharedObjectPool = new(static () => new Writer()); + + /// + /// The rented instance to use. + /// + private Writer? writer; + + /// + /// Creates a new object. + /// + public ImmutableArrayBuilder() + { + this.writer = SharedObjectPool.Allocate(); + } + + /// + /// Gets the data written to the underlying buffer so far, as a . + /// + public readonly ReadOnlySpan WrittenSpan + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => this.writer!.WrittenSpan; + } + + /// + /// Gets the number of elements currently written in the current instance. + /// + public readonly int Count + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => this.writer!.Count; + } + + /// + /// Advances the current writer and gets a to the requested memory area. + /// + /// The requested size to advance by. + /// A to the requested memory area. + /// + /// No other data should be written to the builder while the returned + /// is in use, as it could invalidate the memory area wrapped by it, if resizing occurs. + /// + public readonly Span Advance(int requestedSize) + { + return this.writer!.Advance(requestedSize); + } + + /// + public readonly void Add(T item) + { + this.writer!.Add(item); + } + + /// + /// Adds the specified items to the end of the array. + /// + /// The items to add at the end of the array. + public readonly void AddRange(ReadOnlySpan items) + { + this.writer!.AddRange(items); + } + + /// + public readonly void Clear() + { + this.writer!.Clear(); + } + + /// + /// Inserts an item to the builder at the specified index. + /// + /// The zero-based index at which should be inserted. + /// The object to insert into the current instance. + public readonly void Insert(int index, T item) + { + this.writer!.Insert(index, item); + } + + /// + /// Gets an instance for the current builder. + /// + /// An instance for the current builder. + /// + /// The builder should not be mutated while an enumerator is in use. + /// + public readonly IEnumerable AsEnumerable() + { + return this.writer!; + } + + /// + public readonly ImmutableArray ToImmutable() + { + T[] array = this.writer!.WrittenSpan.ToArray(); + + return Unsafe.As>(ref array); + } + + /// + public readonly T[] ToArray() + { + return this.writer!.WrittenSpan.ToArray(); + } + + /// + public override readonly string ToString() + { + return this.writer!.WrittenSpan.ToString(); + } + + /// + public void Dispose() + { + Writer? writer = this.writer; + + this.writer = null; + + if (writer is not null) + { + writer.Clear(); + + SharedObjectPool.Free(writer); + } + } + + /// + /// A class handling the actual buffer writing. + /// + private sealed class Writer : IList, IReadOnlyList + { + /// + /// The underlying array. + /// + private T[] array; + + /// + /// The starting offset within . + /// + private int index; + + /// + /// Creates a new instance with the specified parameters. + /// + public Writer() + { + if (typeof(T) == typeof(char)) + { + this.array = new T[1024]; + } + else + { + this.array = new T[8]; + } + + this.index = 0; + } + + /// + public int Count => this.index; + + /// + public ReadOnlySpan WrittenSpan + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => new(this.array, 0, this.index); + } + + /// + bool ICollection.IsReadOnly => true; + + /// + T IReadOnlyList.this[int index] => WrittenSpan[index]; + + /// + T IList.this[int index] + { + get => WrittenSpan[index]; + set => throw new NotSupportedException(); + } + + /// + public Span Advance(int requestedSize) + { + EnsureCapacity(requestedSize); + + Span span = this.array.AsSpan(this.index, requestedSize); + + this.index += requestedSize; + + return span; + } + + /// + public void Add(T value) + { + EnsureCapacity(1); + + this.array[this.index++] = value; + } + + /// + public void AddRange(ReadOnlySpan items) + { + EnsureCapacity(items.Length); + + items.CopyTo(this.array.AsSpan(this.index)); + + this.index += items.Length; + } + + /// + public void Clear(ReadOnlySpan items) + { + this.index = 0; + } + + /// + public void Insert(int index, T item) + { + if (index < 0 || index > this.index) + { + ImmutableArrayBuilder.ThrowArgumentOutOfRangeExceptionForIndex(); + } + + EnsureCapacity(1); + + if (index < this.index) + { + Array.Copy(this.array, index, this.array, index + 1, this.index - index); + } + + this.array[index] = item; + this.index++; + } + + /// + /// Clears the items in the current writer. + /// + public void Clear() + { + if (typeof(T) != typeof(byte) && + typeof(T) != typeof(char) && + typeof(T) != typeof(int)) + { + this.array.AsSpan(0, this.index).Clear(); + } + + this.index = 0; + } + + /// + /// Ensures that has enough free space to contain a given number of new items. + /// + /// The minimum number of items to ensure space for in . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void EnsureCapacity(int requestedSize) + { + if (requestedSize > this.array.Length - this.index) + { + ResizeBuffer(requestedSize); + } + } + + /// + /// Resizes to ensure it can fit the specified number of new items. + /// + /// The minimum number of items to ensure space for in . + [MethodImpl(MethodImplOptions.NoInlining)] + private void ResizeBuffer(int sizeHint) + { + int minimumSize = this.index + sizeHint; + int requestedSize = Math.Max(this.array.Length * 2, minimumSize); + + T[] newArray = new T[requestedSize]; + + Array.Copy(this.array, newArray, this.index); + + this.array = newArray; + } + + /// + int IList.IndexOf(T item) + { + return Array.IndexOf(this.array, item, 0, this.index); + } + + /// + void IList.RemoveAt(int index) + { + throw new NotSupportedException(); + } + + /// + bool ICollection.Contains(T item) + { + return Array.IndexOf(this.array, item, 0, this.index) >= 0; + } + + /// + void ICollection.CopyTo(T[] array, int arrayIndex) + { + Array.Copy(this.array, 0, array, arrayIndex, this.index); + } + + /// + bool ICollection.Remove(T item) + { + throw new NotSupportedException(); + } + + /// + IEnumerator IEnumerable.GetEnumerator() + { + T?[] array = this.array!; + int length = this.index; + + for (int i = 0; i < length; i++) + { + yield return array[i]!; + } + } + + /// + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)this).GetEnumerator(); + } + } +} + +/// +/// Private helpers for the type. +/// +file static class ImmutableArrayBuilder +{ + /// + /// Throws an for "index". + /// + public static void ThrowArgumentOutOfRangeExceptionForIndex() + { + throw new ArgumentOutOfRangeException("index"); + } +} \ No newline at end of file diff --git a/src/Pretender.SourceGenerator/Writing/IndentedTextWriter.cs b/src/Pretender.SourceGenerator/Writing/IndentedTextWriter.cs new file mode 100644 index 0000000..12bb53d --- /dev/null +++ b/src/Pretender.SourceGenerator/Writing/IndentedTextWriter.cs @@ -0,0 +1,515 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.ComponentModel; +using System.Globalization; +using System.Runtime.CompilerServices; +using System.Text; + +#pragma warning disable IDE0290 + +namespace Pretender.SourceGenerator.Writing; + +/// +/// A helper type to build sequences of values with pooled buffers. +/// +internal sealed class IndentedTextWriter : IDisposable +{ + /// + /// The default indentation (4 spaces). + /// + private const string DefaultIndentation = " "; + + /// + /// The default new line ('\n'). + /// + private const char DefaultNewLine = '\n'; + + /// + /// The instance that text will be written to. + /// + private ImmutableArrayBuilder builder; + + /// + /// The current indentation level. + /// + private int currentIndentationLevel; + + /// + /// The current indentation, as text. + /// + private string currentIndentation = ""; + + /// + /// The cached array of available indentations, as text. + /// + private string[] availableIndentations; + + /// + /// Creates a new object. + /// + public IndentedTextWriter() + { + builder = new ImmutableArrayBuilder(); + currentIndentationLevel = 0; + currentIndentation = ""; + availableIndentations = new string[4]; + availableIndentations[0] = ""; + + for (int i = 1, n = availableIndentations.Length; i < n; i++) + { + availableIndentations[i] = availableIndentations[i - 1] + DefaultIndentation; + } + } + + /// + /// Advances the current writer and gets a to the requested memory area. + /// + /// The requested size to advance by. + /// A to the requested memory area. + /// + /// No other data should be written to the writer while the returned + /// is in use, as it could invalidate the memory area wrapped by it, if resizing occurs. + /// + public Span Advance(int requestedSize) + { + // Add the leading whitespace if needed (same as WriteRawText below) + if (builder.Count == 0 || builder.WrittenSpan[^1] == DefaultNewLine) + { + builder.AddRange(currentIndentation.AsSpan()); + } + + return builder.Advance(requestedSize); + } + + /// + /// Increases the current indentation level. + /// + public void IncreaseIndent() + { + currentIndentationLevel++; + + if (currentIndentationLevel == availableIndentations.Length) + { + Array.Resize(ref availableIndentations, availableIndentations.Length * 2); + } + + // Set both the current indentation and the current position in the indentations + // array to the expected indentation for the incremented level (ie. one level more). + currentIndentation = availableIndentations[currentIndentationLevel] + ??= availableIndentations[currentIndentationLevel - 1] + DefaultIndentation; + } + + /// + /// Decreases the current indentation level. + /// + public void DecreaseIndent() + { + currentIndentationLevel--; + currentIndentation = availableIndentations[currentIndentationLevel]; + } + + /// + /// Writes a block to the underlying buffer. + /// + /// A value to close the open block with. + public Block WriteBlock() + { + WriteLine("{"); + IncreaseIndent(); + + return new(this); + } + + /// + /// Writes content to the underlying buffer. + /// + /// The content to write. + /// Whether the input content is multiline. + public void Write(string content, bool isMultiline = false) + { + Write(content.AsSpan(), isMultiline); + } + + /// + /// Writes content to the underlying buffer. + /// + /// The content to write. + /// Whether the input content is multiline. + public void Write(ReadOnlySpan content, bool isMultiline = false) + { + if (isMultiline) + { + while (content.Length > 0) + { + int newLineIndex = content.IndexOf(DefaultNewLine); + + if (newLineIndex < 0) + { + // There are no new lines left, so the content can be written as a single line + WriteRawText(content); + + break; + } + else + { + ReadOnlySpan line = content[..newLineIndex]; + + // Write the current line (if it's empty, we can skip writing the text entirely). + // This ensures that raw multiline string literals with blank lines don't have + // extra whitespace at the start of those lines, which would otherwise happen. + WriteIf(!line.IsEmpty, line); + WriteLine(); + + // Move past the new line character (the result could be an empty span) + content = content[(newLineIndex + 1)..]; + } + } + } + else + { + WriteRawText(content); + } + } + + /// + /// Writes content to the underlying buffer. + /// + /// The interpolated string handler with content to write. + public void Write([InterpolatedStringHandlerArgument("")] ref WriteInterpolatedStringHandler handler) + { + _ = this; + } + + /// + /// Writes content to the underlying buffer depending on an input condition. + /// + /// The condition to use to decide whether or not to write content. + /// The content to write. + /// Whether the input content is multiline. + public void WriteIf(bool condition, string content, bool isMultiline = false) + { + if (condition) + { + Write(content.AsSpan(), isMultiline); + } + } + + /// + /// Writes content to the underlying buffer depending on an input condition. + /// + /// The condition to use to decide whether or not to write content. + /// The content to write. + /// Whether the input content is multiline. + public void WriteIf(bool condition, ReadOnlySpan content, bool isMultiline = false) + { + if (condition) + { + Write(content, isMultiline); + } + } + + /// + /// Writes content to the underlying buffer depending on an input condition. + /// + /// The condition to use to decide whether or not to write content. + /// The interpolated string handler with content to write. + public void WriteIf(bool condition, [InterpolatedStringHandlerArgument("", nameof(condition))] ref WriteIfInterpolatedStringHandler handler) + { + _ = this; + } + + /// + /// Writes a line to the underlying buffer. + /// + /// Indicates whether to skip adding the line if there already is one. + public void WriteLine(bool skipIfPresent = false) + { + if (skipIfPresent && builder.WrittenSpan is [.., '\n', '\n']) + { + return; + } + + builder.Add(DefaultNewLine); + } + + /// + /// Writes content to the underlying buffer and appends a trailing new line. + /// + /// The content to write. + /// Whether the input content is multiline. + public void WriteLine(string content, bool isMultiline = false) + { + WriteLine(content.AsSpan(), isMultiline); + } + + /// + /// Writes content to the underlying buffer and appends a trailing new line. + /// + /// The content to write. + /// Whether the input content is multiline. + public void WriteLine(ReadOnlySpan content, bool isMultiline = false) + { + Write(content, isMultiline); + WriteLine(); + } + + /// + /// Writes content to the underlying buffer and appends a trailing new line. + /// + /// The interpolated string handler with content to write. + public void WriteLine([InterpolatedStringHandlerArgument("")] ref WriteInterpolatedStringHandler handler) + { + WriteLine(); + } + + /// + /// Writes a line to the underlying buffer depending on an input condition. + /// + /// The condition to use to decide whether or not to write content. + /// Indicates whether to skip adding the line if there already is one. + public void WriteLineIf(bool condition, bool skipIfPresent = false) + { + if (condition) + { + WriteLine(skipIfPresent); + } + } + + /// + /// Writes content to the underlying buffer and appends a trailing new line depending on an input condition. + /// + /// The condition to use to decide whether or not to write content. + /// The content to write. + /// Whether the input content is multiline. + public void WriteLineIf(bool condition, string content, bool isMultiline = false) + { + if (condition) + { + WriteLine(content.AsSpan(), isMultiline); + } + } + + /// + /// Writes content to the underlying buffer and appends a trailing new line depending on an input condition. + /// + /// The condition to use to decide whether or not to write content. + /// The content to write. + /// Whether the input content is multiline. + public void WriteLineIf(bool condition, ReadOnlySpan content, bool isMultiline = false) + { + if (condition) + { + Write(content, isMultiline); + WriteLine(); + } + } + + /// + /// Writes content to the underlying buffer and appends a trailing new line depending on an input condition. + /// + /// The condition to use to decide whether or not to write content. + /// The interpolated string handler with content to write. + public void WriteLineIf(bool condition, [InterpolatedStringHandlerArgument("", nameof(condition))] ref WriteIfInterpolatedStringHandler handler) + { + if (condition) + { + WriteLine(); + } + } + + /// + public override string ToString() + { + return builder.WrittenSpan.Trim().ToString(); + } + + /// + public void Dispose() + { + builder.Dispose(); + } + + /// + /// Writes raw text to the underlying buffer, adding leading indentation if needed. + /// + /// The raw text to write. + private void WriteRawText(ReadOnlySpan content) + { + if (builder.Count == 0 || builder.WrittenSpan[^1] == DefaultNewLine) + { + builder.AddRange(currentIndentation.AsSpan()); + } + + builder.AddRange(content); + } + + /// + /// A delegate representing a callback to write data into an instance. + /// + /// The type of data to use. + /// The input data to use to write into . + /// The instance to write into. + public delegate void Callback(T value, IndentedTextWriter writer); + + /// + /// Represents an indented block that needs to be closed. + /// + /// The input instance to wrap. + public struct Block(IndentedTextWriter writer) : IDisposable + { + /// + /// The instance to write to. + /// + private IndentedTextWriter? writer = writer; + + /// + public void Dispose() + { + IndentedTextWriter? writer = this.writer; + + this.writer = null; + + if (writer is not null) + { + writer.DecreaseIndent(); + writer.WriteLine("}"); + } + } + } + + /// + /// Provides a handler used by the language compiler to append interpolated strings into instances. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + [InterpolatedStringHandler] + public readonly ref struct WriteInterpolatedStringHandler + { + /// The associated to which to append. + private readonly IndentedTextWriter writer; + + /// Creates a handler used to append an interpolated string into a . + /// The number of constant characters outside of interpolation expressions in the interpolated string. + /// The number of interpolation expressions in the interpolated string. + /// The associated to which to append. + /// This is intended to be called only by compiler-generated code. Arguments are not validated as they'd otherwise be for members intended to be used directly. + public WriteInterpolatedStringHandler(int literalLength, int formattedCount, IndentedTextWriter writer) + { + this.writer = writer; + } + + /// Writes the specified string to the handler. + /// The string to write. + public void AppendLiteral(string value) + { + writer.Write(value); + } + + /// Writes the specified value to the handler. + /// The value to write. + public void AppendFormatted(string? value) + { + AppendFormatted(value); + } + + /// Writes the specified character span to the handler. + /// The span to write. + public void AppendFormatted(ReadOnlySpan value) + { + writer.Write(value); + } + + /// Writes the specified value to the handler. + /// The value to write. + /// The type of the value to write. + public void AppendFormatted(T value) + { + if (value is not null) + { + writer.Write(value.ToString()); + } + } + + /// Writes the specified value to the handler. + /// The value to write. + /// The format string. + /// The type of the value to write. + public void AppendFormatted(T value, string? format) + { + if (value is IFormattable) + { + writer.Write(((IFormattable)value).ToString(format, CultureInfo.InvariantCulture)); + } + else if (value is not null) + { + writer.Write(value.ToString()); + } + } + } + + /// + /// Provides a handler used by the language compiler to conditionally append interpolated strings into instances. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + [InterpolatedStringHandler] + public readonly ref struct WriteIfInterpolatedStringHandler + { + /// The associated to use. + private readonly WriteInterpolatedStringHandler handler; + + /// Creates a handler used to append an interpolated string into a . + /// The number of constant characters outside of interpolation expressions in the interpolated string. + /// The number of interpolation expressions in the interpolated string. + /// The associated to which to append. + /// The condition to use to decide whether or not to write content. + /// A value indicating whether formatting should proceed. + /// This is intended to be called only by compiler-generated code. Arguments are not validated as they'd otherwise be for members intended to be used directly. + public WriteIfInterpolatedStringHandler(int literalLength, int formattedCount, IndentedTextWriter writer, bool condition, out bool shouldAppend) + { + if (condition) + { + handler = new WriteInterpolatedStringHandler(literalLength, formattedCount, writer); + + shouldAppend = true; + } + else + { + handler = default; + + shouldAppend = false; + } + } + + /// + public void AppendLiteral(string value) + { + handler.AppendLiteral(value); + } + + /// + public void AppendFormatted(string? value) + { + handler.AppendFormatted(value); + } + + /// + public void AppendFormatted(ReadOnlySpan value) + { + handler.AppendFormatted(value); + } + + /// + public void AppendFormatted(T value) + { + handler.AppendFormatted(value); + } + + /// + public void AppendFormatted(T value, string? format) + { + handler.AppendFormatted(value, format); + } + } +} \ No newline at end of file diff --git a/src/Pretender.SourceGenerator/Writing/ObjectPool{T}.cs b/src/Pretender.SourceGenerator/Writing/ObjectPool{T}.cs new file mode 100644 index 0000000..650491a --- /dev/null +++ b/src/Pretender.SourceGenerator/Writing/ObjectPool{T}.cs @@ -0,0 +1,154 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +// Ported from Roslyn, see: https://github.com/dotnet/roslyn/blob/main/src/Dependencies/PooledObjects/ObjectPool%601.cs. + +using System; +using System.Runtime.CompilerServices; +using System.Threading; + +#pragma warning disable RS1035 + +namespace Pretender.SourceGenerator.Writing; + +/// +/// +/// Generic implementation of object pooling pattern with predefined pool size limit. The main purpose +/// is that limited number of frequently used objects can be kept in the pool for further recycling. +/// +/// +/// Notes: +/// +/// +/// It is not the goal to keep all returned objects. Pool is not meant for storage. If there +/// is no space in the pool, extra returned objects will be dropped. +/// +/// +/// It is implied that if object was obtained from a pool, the caller will return it back in +/// a relatively short time. Keeping checked out objects for long durations is ok, but +/// reduces usefulness of pooling. Just new up your own. +/// +/// +/// +/// +/// Not returning objects to the pool in not detrimental to the pool's work, but is a bad practice. +/// Rationale: if there is no intent for reusing the object, do not use pool - just use "new". +/// +/// +/// The type of objects to pool. +/// The input factory to produce items. +/// +/// The factory is stored for the lifetime of the pool. We will call this only when pool needs to +/// expand. compared to "new T()", Func gives more flexibility to implementers and faster than "new T()". +/// +/// The pool size to use. +internal sealed class ObjectPool(Func factory, int size) + where T : class +{ + /// + /// The array of cached items. + /// + private readonly Element[] items = new Element[size - 1]; + + /// + /// Storage for the pool objects. The first item is stored in a dedicated field + /// because we expect to be able to satisfy most requests from it. + /// + private T? firstItem; + + /// + /// Creates a new instance with the specified parameters. + /// + /// The input factory to produce items. + public ObjectPool(Func factory) + : this(factory, Environment.ProcessorCount * 2) + { + } + + /// + /// Produces a instance. + /// + /// The returned item to use. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public T Allocate() + { + T? item = this.firstItem; + + if (item is null || item != Interlocked.CompareExchange(ref this.firstItem, null, item)) + { + item = AllocateSlow(); + } + + return item; + } + + /// + /// Returns a given instance to the pool. + /// + /// The instance to return. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Free(T obj) + { + if (this.firstItem is null) + { + this.firstItem = obj; + } + else + { + FreeSlow(obj); + } + } + + /// + /// Allocates a new item. + /// + /// The returned item to use. + [MethodImpl(MethodImplOptions.NoInlining)] + private T AllocateSlow() + { + foreach (ref Element element in this.items.AsSpan()) + { + T? instance = element.Value; + + if (instance is not null) + { + if (instance == Interlocked.CompareExchange(ref element.Value, null, instance)) + { + return instance; + } + } + } + + return factory(); + } + + /// + /// Frees a given item. + /// + /// The item to return to the pool. + [MethodImpl(MethodImplOptions.NoInlining)] + private void FreeSlow(T obj) + { + foreach (ref Element element in this.items.AsSpan()) + { + if (element.Value is null) + { + element.Value = obj; + + break; + } + } + } + + /// + /// A container for a produced item (using a wrapper to avoid covariance checks). + /// + private struct Element + { + /// + /// The value held at the current element. + /// + internal T? Value; + } +} \ No newline at end of file diff --git a/test/SourceGeneratorTests/MainTests.cs b/test/SourceGeneratorTests/MainTests.cs index bcd2462..1ce47f6 100644 --- a/test/SourceGeneratorTests/MainTests.cs +++ b/test/SourceGeneratorTests/MainTests.cs @@ -53,22 +53,36 @@ public TestClass() [Fact] - public async Task Test2() + public async Task AbstractClass() { - var (result, compilation) = await RunPartialGeneratorAsync($$""" - var pretendSimpleInterface = Pretend.That(); + var (result, c) = await RunPartialGeneratorAsync($$""" + #nullable enable + using System; + using System.Threading.Tasks; + using Pretender; + + namespace AbstractClass; - var simpleInterface = pretendSimpleInterface.Create(); + public abstract class MyAbstractClass + { + abstract Task MethodAsync(string str); + abstract string Name { get; set; } + } + + public class TestClass + { + public TestClass() + { + var pretend = Pretend.That(); + + pretend.Setup(c => c.MethodAsync("Hi")); + } + } """); - Assert.Equal(3, result.GeneratedSources.Length); + var source = Assert.Single(result.GeneratedSources); - var source1 = result.GeneratedSources[0]; - var text1 = source1.SourceText.ToString(); - var source2 = result.GeneratedSources[1]; - var text2 = source2.SourceText.ToString(); - var source3 = result.GeneratedSources[2]; - var text3 = source3.SourceText.ToString(); + var sourceText = source.SourceText.ToString(); } [Fact] @@ -82,15 +96,11 @@ public async Task Test3() .Returns("Hi"); var pretend = pretendSimpleInterface.Create(); - """); - Assert.Equal(3, result.GeneratedSources.Length); + pretendSimpleInterface.Verify(i => i.Foo("1", 1), 2); + """); - var source1 = result.GeneratedSources[0]; - var text1 = source1.SourceText.ToString(); - var source2 = result.GeneratedSources[1]; - var text2 = source2.SourceText.ToString(); - var source3 = result.GeneratedSources[2]; - var text3 = source3.SourceText.ToString(); + var source = Assert.Single(result.GeneratedSources); + var text = source.SourceText.ToString(); } } diff --git a/test/SourceGeneratorTests/TestBase.cs b/test/SourceGeneratorTests/TestBase.cs index d248bc0..4282500 100644 --- a/test/SourceGeneratorTests/TestBase.cs +++ b/test/SourceGeneratorTests/TestBase.cs @@ -3,7 +3,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.Text; -using Microsoft.VisualStudio.TestPlatform.PlatformAbstractions; using Pretender; using Pretender.SourceGenerator;