diff --git a/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs b/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs new file mode 100644 index 0000000..1052421 --- /dev/null +++ b/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs @@ -0,0 +1,59 @@ +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using Microsoft.CodeAnalysis.Operations; + +namespace Pretender.SourceGenerator.Emitter +{ + internal class SetupEmitter + { + private readonly SetupActionEmitter _setupActionEmitter; + private readonly IInvocationOperation _setupInvocation; + + public SetupEmitter(SetupActionEmitter setupActionEmitter, IInvocationOperation setupInvocation) + { + _setupActionEmitter = setupActionEmitter; + _setupInvocation = setupInvocation; + } + + // TODO: Run cancellationToken a lot more + public MemberDeclarationSyntax[] Emit(int index, CancellationToken cancellationToken) + { + var setupMethod = _setupActionEmitter.SetupMethod; + var pretendType = _setupActionEmitter.PretendType; + + var allMembers = new List(); + + var interceptsLocation = new InterceptsLocationInfo(_setupInvocation); + + // TODO: This is wrong + var typeArguments = setupMethod.ReturnsVoid + ? TypeArgumentList(SingletonSeparatedList(ParseTypeName(pretendType.ToFullDisplayString()))) + : TypeArgumentList(SeparatedList([ParseTypeName(pretendType.ToFullDisplayString()), setupMethod.ReturnType.AsUnknownTypeSyntax()])); + + var returnType = GenericName("IPretendSetup") + .WithTypeArgumentList(typeArguments); + + var setupCreatorInvocation = _setupActionEmitter.CreateSetupGetter(cancellationToken); + + var fullSetupMethod = MethodDeclaration(returnType, $"Setup{index}") + .WithBody(Block(ReturnStatement(setupCreatorInvocation))) + .WithParameterList(ParameterList(SeparatedList(new[] + { + Parameter(Identifier("pretend")) + .WithModifiers(TokenList(Token(SyntaxKind.ThisKeyword))) + .WithType(ParseTypeName($"Pretend<{pretendType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}>")), + + Parameter(Identifier("setupExpression")) + .WithType(GenericName(setupMethod.ReturnsVoid ? "Action" : "Func").WithTypeArgumentList(typeArguments)), + }))) + .WithModifiers(TokenList(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword))) + .WithAttributeLists(SingletonList(AttributeList( + SingletonSeparatedList(interceptsLocation.ToAttributeSyntax())))); + + allMembers.Add(fullSetupMethod); + return [.. allMembers]; + } + } +} diff --git a/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs b/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs index 41fc2b2..da9c6c8 100644 --- a/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/VerifyEmitter.cs @@ -10,25 +10,25 @@ internal class VerifyEmitter { private readonly ITypeSymbol _pretendType; private readonly ITypeSymbol? _returnType; - private readonly SetupCreationSpec _setupCreationSpec; + private readonly SetupActionEmitter _setupActionEmitter; private readonly IInvocationOperation _invocationOperation; - public VerifyEmitter(ITypeSymbol pretendType, ITypeSymbol? returnType, SetupCreationSpec setupCreationSpec, IInvocationOperation invocationOperation) + public VerifyEmitter(ITypeSymbol pretendType, ITypeSymbol? returnType, SetupActionEmitter setupActionEmitter, IInvocationOperation invocationOperation) { _pretendType = pretendType; _returnType = returnType; - _setupCreationSpec = setupCreationSpec; + _setupActionEmitter = setupActionEmitter; _invocationOperation = invocationOperation; } - public MethodDeclarationSyntax EmitVerifyMethod(int index, CancellationToken cancellationToken) + public MethodDeclarationSyntax Emit(int index, CancellationToken cancellationToken) { - var setupGetter = _setupCreationSpec.CreateSetupGetter(cancellationToken); + var setupInvocation = _setupActionEmitter.CreateSetupGetter(cancellationToken); // var setup = pretend.GetOrCreateSetup(...); var setupLocal = LocalDeclarationStatement(VariableDeclaration(CommonSyntax.VarType) .WithVariables(SingletonSeparatedList(VariableDeclarator(CommonSyntax.SetupIdentifier) - .WithInitializer(EqualsValueClause(setupGetter))))); + .WithInitializer(EqualsValueClause(setupInvocation))))); TypeSyntax pretendType = ParseTypeName(_pretendType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); diff --git a/src/Pretender.SourceGenerator/InvocationOperationExtensions.cs b/src/Pretender.SourceGenerator/InvocationOperationExtensions.cs index eb83d8e..3dd2163 100644 --- a/src/Pretender.SourceGenerator/InvocationOperationExtensions.cs +++ b/src/Pretender.SourceGenerator/InvocationOperationExtensions.cs @@ -1,7 +1,5 @@ using System.Collections.Immutable; - using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Operations; namespace Pretender.SourceGenerator @@ -20,39 +18,6 @@ public static bool IsInvocationOperation(this IOperation? operation, out IInvoca return false; } - public static bool IsSetupCall(this SyntaxNode node) - { - return node is InvocationExpressionSyntax - { - Expression: MemberAccessExpressionSyntax - { - // pretend.Setup(i => i.Something()); - Name.Identifier.ValueText: "Setup" or "SetupSet", - }, - ArgumentList.Arguments.Count: 1 - }; - } - - public static bool IsValidSetupOperation(this IOperation operation, Compilation compilation, out IInvocationOperation? invocation) - { - var pretendType = compilation.GetTypeByMetadataName("Pretender.Pretend`1"); - invocation = null; - - // TODO: Probably need to check a few more things - // Someone could make a Setup extension method, that doesn't look - // like I think it should, I need to check the return type and first arg type - // a lot more closely. - if (operation is IInvocationOperation targetOperation - && targetOperation.Instance is not null - && SymbolEqualityComparer.Default.Equals(targetOperation.Instance.Type!.OriginalDefinition, pretendType)) - { - invocation = targetOperation; - return true; - } - - return false; - } - public static bool IsValidCreateOperation(this IOperation? operation, Compilation compilation, out IInvocationOperation invocationOperation, out ImmutableArray? typeArguments) { var pretendGeneric = compilation.GetTypeByMetadataName("Pretender.Pretend`1"); diff --git a/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs b/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs index c41b05d..25f928e 100644 --- a/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs +++ b/src/Pretender.SourceGenerator/Parser/KnownTypeSymbols.cs @@ -11,6 +11,13 @@ internal sealed class KnownTypeSymbols public INamedTypeSymbol? Pretend { get; } public INamedTypeSymbol? Pretend_Unbound { get; } + public INamedTypeSymbol? Task { get; } + public INamedTypeSymbol? TaskOfT { get; } + public INamedTypeSymbol? ValueTask { get; } + public INamedTypeSymbol? ValueTaskOfT { get; } + + + public KnownTypeSymbols(CSharpCompilation compilation) { Compilation = compilation; @@ -18,6 +25,13 @@ public KnownTypeSymbols(CSharpCompilation compilation) // TODO: Get known types Pretend = compilation.GetTypeByMetadataName("Pretender.Pretend`1"); Pretend_Unbound = Pretend?.ConstructUnboundGenericType(); + + Task = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task"); + // TODO: Create unbounded? + TaskOfT = compilation.GetTypeByMetadataName("System.Threading.Tasks.Task`1"); + ValueTask = compilation.GetTypeByMetadataName("System.Threading.Tasks.ValueTask"); + // TODO: Create unbounded? + ValueTaskOfT = compilation.GetTypeByMetadataName("System.Threading.Tasks.ValueTask`1"); } public static bool IsPretend(INamedTypeSymbol type) diff --git a/src/Pretender.SourceGenerator/Parser/SetupActionParser.cs b/src/Pretender.SourceGenerator/Parser/SetupActionParser.cs new file mode 100644 index 0000000..023ab86 --- /dev/null +++ b/src/Pretender.SourceGenerator/Parser/SetupActionParser.cs @@ -0,0 +1,177 @@ +using System.Collections.Immutable; +using System.Diagnostics; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Operations; +using Pretender.SourceGenerator.SetupArguments; + +namespace Pretender.SourceGenerator.Parser +{ + internal class SetupActionParser + { + private readonly IOperation _setupActionArgument; + private readonly ITypeSymbol _pretendType; + private readonly bool _forcePropertySetter; + + // TODO: Should I have a higher IOperation kind here? Like InvocationOperation? + public SetupActionParser(IOperation setupActionArgument, ITypeSymbol pretendType, bool forcePropertySetter) + { + _setupActionArgument = setupActionArgument; + _pretendType = pretendType; + _forcePropertySetter = forcePropertySetter; + } + + public (SetupActionEmitter? Emitter, ImmutableArray? Diagnostics) Parse(CancellationToken cancellationToken) + { + var candidates = GetInvocationCandidates(cancellationToken); + + if (candidates.Length == 0) + { + // TODO: Create error diagnostic + return (null, null); + } + else if (candidates.Length != 1) + { + // TODO: Create error diagnostic + return (null, null); + } + + var candidate = candidates[0]; + + var arguments = candidate.Arguments; + + var builder = ImmutableArray.CreateBuilder(arguments.Length); + for (var i = 0; i < arguments.Length; i++) + { + builder.Add(SetupArgumentSpec.Create(arguments[i], i)); + } + + return (new SetupActionEmitter(_pretendType, candidate.Method, builder.MoveToImmutable()), null); + } + + private ImmutableArray GetInvocationCandidates(CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + TraverseOperation(_setupActionArgument, builder, cancellationToken); + return builder.ToImmutable(); + } + + private void TraverseOperation(IOperation operation, ImmutableArray.Builder invocationCandidates, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + switch (operation.Kind) + { + case OperationKind.Block: + var blockOperation = (IBlockOperation)operation; + TraverseOperationList(blockOperation.Operations, invocationCandidates, cancellationToken); + break; + case OperationKind.Return: + var returnOperation = (IReturnOperation)operation; + if (returnOperation.ReturnedValue != null) + { + TraverseOperation(returnOperation.ReturnedValue, invocationCandidates, cancellationToken); + } + break; + case OperationKind.ExpressionStatement: + var expressionStatement = (IExpressionStatementOperation)operation; + TraverseOperation(expressionStatement.Operation, invocationCandidates, cancellationToken); + break; + case OperationKind.Conversion: + var conversionOperation = (IConversionOperation)operation; + TraverseOperation(conversionOperation.Operand, invocationCandidates, cancellationToken); + break; + case OperationKind.Invocation: + var invocationOperation = (IInvocationOperation)operation; + TryMatchInvocationOperation(invocationOperation, invocationCandidates); + break; + case OperationKind.PropertyReference: + var propertyReferenceOperation = (IPropertyReferenceOperation)operation; + TryMatchPropertyReference(propertyReferenceOperation, invocationCandidates); + break; + case OperationKind.AnonymousFunction: + var anonymousFunctionOperation = (IAnonymousFunctionOperation)operation; + TraverseOperation(anonymousFunctionOperation.Body, invocationCandidates, cancellationToken); + break; + case OperationKind.DelegateCreation: + var delegateCreationOperation = (IDelegateCreationOperation)operation; + TraverseOperation(delegateCreationOperation.Target, invocationCandidates, cancellationToken); + break; + default: +#if DEBUG + // TODO: Figure out what operation caused this, it's not ideal to "randomly" support operations + Debugger.Launch(); +#endif + // Absolute fallback, most of our operations can be supported this way but it's nicer to be explicit + TraverseOperationList(operation.ChildOperations, invocationCandidates, cancellationToken); + break; + } + } + + private void TraverseOperationList(IEnumerable operations, ImmutableArray.Builder invocationCandidates, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + foreach (var operation in operations) + { + cancellationToken.ThrowIfCancellationRequested(); + TraverseOperation(operation, invocationCandidates, cancellationToken); + } + } + + private void TryMatchPropertyReference(IPropertyReferenceOperation propertyReference, ImmutableArray.Builder invocationCandidates) + { + if (propertyReference.Instance is not IParameterReferenceOperation parameterReference) + { + return; + } + + if (!SymbolEqualityComparer.Default.Equals(parameterReference.Type, _pretendType)) + { + return; + } + + var method = _forcePropertySetter + ? propertyReference.Property.SetMethod + : propertyReference.Property.GetMethod; + + if (method == null) + { + return; + } + + invocationCandidates.Add(new InvocationCandidate(method, ImmutableArray.Empty)); + } + + private void TryMatchInvocationOperation(IInvocationOperation invocation, ImmutableArray.Builder invocationCandidates) + { + if (_forcePropertySetter) + { + return; + } + + if (invocation.Instance is not IParameterReferenceOperation parameterReference) + { + return; + } + + if (!SymbolEqualityComparer.Default.Equals(parameterReference.Type, _pretendType)) + { + return; + } + + // TODO: Any more validation? + invocationCandidates.Add(new InvocationCandidate(invocation.TargetMethod, invocation.Arguments)); + } + + private class InvocationCandidate + { + public InvocationCandidate(IMethodSymbol methodSymbol, ImmutableArray argumentOperations) + { + Method = methodSymbol; + Arguments = argumentOperations; + } + + public IMethodSymbol Method { get; } + public ImmutableArray Arguments { get; } + } + } +} diff --git a/src/Pretender.SourceGenerator/Parser/SetupParser.cs b/src/Pretender.SourceGenerator/Parser/SetupParser.cs new file mode 100644 index 0000000..a6d888d --- /dev/null +++ b/src/Pretender.SourceGenerator/Parser/SetupParser.cs @@ -0,0 +1,56 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Pretender.SourceGenerator.Emitter; +using static Pretender.SourceGenerator.PretenderSourceGenerator; + +namespace Pretender.SourceGenerator.Parser +{ + internal class SetupParser + { + private readonly SetupInvocation _setupInvocation; + private readonly bool _isLanguageVersionSupported; + private readonly KnownTypeSymbols _knownTypeSymbols; + + public SetupParser(SetupInvocation setupInvocation, CompilationData compilationData) + { + _setupInvocation = setupInvocation; + _isLanguageVersionSupported = compilationData.LanguageVersionIsSupported; + _knownTypeSymbols = compilationData.TypeSymbols!; + } + + public (SetupEmitter? Emitter, ImmutableArray? Diagnostics) Parse(CancellationToken cancellationToken) + { + if (!_isLanguageVersionSupported) + { + // TODO: Create error diagnostic + return (null, null); + } + + var operation = _setupInvocation.Operation; + + // Setup calls are expected to have a single argument, being the setup action argument + var setupArgument = operation.Arguments[0]; + + cancellationToken.ThrowIfCancellationRequested(); + + // Setup calls are expected to be called from Pretend so the type argument gives us the type we are pretending + // TODO: Assert the containing type maybe? + var pretendType = operation.TargetMethod.ContainingType.TypeArguments[0]; + + var useSetMethod = operation.TargetMethod.Name == "SetupSet"; + + var parser = new SetupActionParser(setupArgument.Value, pretendType, useSetMethod); + + var (setupActionEmitter, setupActionDiagnostics) = parser.Parse(cancellationToken); + + cancellationToken.ThrowIfCancellationRequested(); + + if (setupActionEmitter == null) + { + return (null, setupActionDiagnostics); + } + + return (new SetupEmitter(setupActionEmitter, operation), setupActionDiagnostics); + } + } +} diff --git a/src/Pretender.SourceGenerator/Parser/VerifyParser.cs b/src/Pretender.SourceGenerator/Parser/VerifyParser.cs index 96561f9..6dcb265 100644 --- a/src/Pretender.SourceGenerator/Parser/VerifyParser.cs +++ b/src/Pretender.SourceGenerator/Parser/VerifyParser.cs @@ -7,39 +7,60 @@ namespace Pretender.SourceGenerator.Parser { internal class VerifyParser { - private readonly VerifyInvocation _verifyInvocation; + private readonly bool _isLanguageVersionSupported; private readonly KnownTypeSymbols _knownTypeSymbols; + private readonly VerifyInvocation _verifyInvocation; public VerifyParser(VerifyInvocation verifyInvocation, CompilationData compilationData) { + _isLanguageVersionSupported = compilationData.LanguageVersionIsSupported; _knownTypeSymbols = compilationData.TypeSymbols!; _verifyInvocation = verifyInvocation; } - public (VerifyEmitter? VerifyEmitter, ImmutableArray? Diagnostics) GetVerifyEmitter(CancellationToken cancellationToken) + public (VerifyEmitter? VerifyEmitter, ImmutableArray? Diagnostics) Parse(CancellationToken cancellationToken) { + if (!_isLanguageVersionSupported) + { + // TODO: Create error diagnostic + return (null, null); + } + + cancellationToken.ThrowIfCancellationRequested(); + var operation = _verifyInvocation.Operation; // Verify calls are expected to have 2 arguments, the first being the setup expression var setupArgument = operation.Arguments[0]; + cancellationToken.ThrowIfCancellationRequested(); + // Verify calls are expected to be called from Pretend so the type argument gives us the type we are pretending var pretendType = operation.TargetMethod.ContainingType.TypeArguments[0]; // TODO: This doesn't exist yet var useSetMethod = operation.TargetMethod.Name == "VerifySet"; - // TODO: This should be done in a Parser type class as well - var setupCreationSpec = new SetupCreationSpec(setupArgument, pretendType, useSetMethod); + var parser = new SetupActionParser(setupArgument.Value, pretendType, useSetMethod); + + var (setupActionEmitter, setupActionDiagnostics) = parser.Parse(cancellationToken); + + cancellationToken.ThrowIfCancellationRequested(); + + if (setupActionEmitter == null) + { + return (null, setupActionDiagnostics); + } var returnType = setupArgument.Parameter!.Type.Name == "Func" ? ((INamedTypeSymbol)setupArgument.Parameter.Type).TypeArguments[1] // The Func variant is expected to have the return type in the second type argument : null; - var emitter = new VerifyEmitter(pretendType, returnType, setupCreationSpec, _verifyInvocation.Operation); + cancellationToken.ThrowIfCancellationRequested(); + + var emitter = new VerifyEmitter(pretendType, returnType, setupActionEmitter, _verifyInvocation.Operation); - // TODO: Get diagnostics from elsewhere - return (emitter, null); + return (emitter, setupActionDiagnostics); } } } diff --git a/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs b/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs index 50020f5..bd2751e 100644 --- a/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs +++ b/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs @@ -63,41 +63,44 @@ public void Initialize(IncrementalGeneratorInitializationContext context) #endregion #region Setup - var setupCallsWithDiagnostics = + IncrementalValuesProvider<(SetupEmitter? Emitter, ImmutableArray? Diagnostics)> setups = context.SyntaxProvider.CreateSyntaxProvider( - predicate: static (node, _) => node.IsSetupCall(), - transform: static (context, token) => + predicate: static (node, _) => SetupInvocation.IsCandidateSyntaxNode(node), + transform: SetupInvocation.Create) + .Where(i => i is not null) + .Combine(compilationData) + .Select(static (tuple, token) => + { + if (tuple.Right is not CompilationData compilationData) { - // All of this should be asserted in the predicate - var operation = context.SemanticModel.GetOperation(context.Node, token); - if (operation!.IsValidSetupOperation(context.SemanticModel.Compilation, out var invocation)) - { - return new SetupEntrypoint(invocation!); - } - return null; - }) - .Where(i => i is not null); + return (null, null); + } - context.RegisterSourceOutput(setupCallsWithDiagnostics, static (context, setup) => - { - foreach (var diagnostic in setup!.Diagnostics) - { - context.ReportDiagnostic(diagnostic); - } - }); + var parser = new SetupParser(tuple.Left!, compilationData); - var setups = setupCallsWithDiagnostics - .Where(s => s!.Diagnostics.Count == 0); + return parser.Parse(token); + }) + .WithTrackingName("Setup"); context.RegisterSourceOutput(setups.Collect(), static (context, setups) => { - var members = new List(); - for (var i = 0; i < setups.Length; i++) { var setup = setups[i]; - members.AddRange(setup!.GetMembers(i)); + + if (setup.Diagnostics is ImmutableArray diagnostics) + { + foreach (var diagnostic in diagnostics) + { + context.ReportDiagnostic(diagnostic); + } + } + + if (setup.Emitter is SetupEmitter emitter) + { + members.AddRange(emitter.Emit(i, context.CancellationToken)); + } } var classDeclaration = SyntaxFactory.ClassDeclaration("SetupInterceptors") @@ -144,15 +147,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context) // Create new VerifySpec var parser = new VerifyParser(tuple.Left!, compilationData); - return parser.GetVerifyEmitter(cancellationToken); + return parser.Parse(cancellationToken); }) .WithTrackingName("Verify"); - // TODO: Register diagnostics context.RegisterSourceOutput(verifyCallsWithDiagnostics.Collect(), (context, inputs) => { var methods = new List(); - for ( var i = 0; i < inputs.Length; i++) + for (var i = 0; i < inputs.Length; i++) { var input = inputs[i]; if (input.Diagnostics is ImmutableArray diagnostics) @@ -166,7 +168,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) if (input.Emitter is VerifyEmitter emitter) { // TODO: Emit VerifyMethod - var method = emitter.EmitVerifyMethod(0, context.CancellationToken); + var method = emitter.Emit(0, context.CancellationToken); methods.Add(method); } } diff --git a/src/Pretender.SourceGenerator/SetupActionEmitter.cs b/src/Pretender.SourceGenerator/SetupActionEmitter.cs new file mode 100644 index 0000000..9922b7b --- /dev/null +++ b/src/Pretender.SourceGenerator/SetupActionEmitter.cs @@ -0,0 +1,212 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Pretender.SourceGenerator.SetupArguments; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Pretender.SourceGenerator +{ + internal class SetupActionEmitter + { + private readonly ImmutableArray _setupArgumentSpecs; + + public SetupActionEmitter(ITypeSymbol pretendType, IMethodSymbol setupMethod, ImmutableArray setupArgumentSpecs) + { + PretendType = pretendType; + SetupMethod = setupMethod; + _setupArgumentSpecs = setupArgumentSpecs; + } + + public ITypeSymbol PretendType { get; } + public IMethodSymbol SetupMethod { get; } + + public InvocationExpressionSyntax CreateSetupGetter(CancellationToken cancellationToken) + { + var totalMatchStatements = _setupArgumentSpecs.Sum(sa => sa.NeededMatcherStatements); + cancellationToken.ThrowIfCancellationRequested(); + + var matchStatements = new StatementSyntax[totalMatchStatements]; + int addedStatements = 0; + + for (var i = 0; i < _setupArgumentSpecs.Length; i++) + { + var argument = _setupArgumentSpecs[i]; + + var newStatements = argument.CreateMatcherStatements(cancellationToken); + cancellationToken.ThrowIfCancellationRequested(); + + newStatements.CopyTo(matchStatements, addedStatements); + addedStatements += newStatements.Length; + } + + ArgumentSyntax matcherArgument; + ImmutableArray statements; + if (matchStatements.Length == 0) + { + statements = ImmutableArray.Empty; + + // Nothing actually needs to match this will always return true, so we use a cached matcher that always returns true + matcherArgument = Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName("Cache"), + IdentifierName("NoOpMatcher"))) + .WithNameColon(NameColon("matcher")); + } + else + { + // Other match statements should have added all the ways the method could return false + // so if it gets through all those statements it should return true at the end. + var trueReturnStatement = ReturnStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression)); + + /* + * Matcher matchCall = static (callInfo, target) => + * { + * ...matching calls... + * return true; + * } + */ + var matchCallIdentifier = Identifier("matchCall"); + + var matcherDelegate = ParenthesizedLambdaExpression( + ParameterList(SeparatedList([ + Parameter(Identifier("callInfo")), + Parameter(Identifier("target")) + ])), + Block(List([.. matchStatements, trueReturnStatement]))) + .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))); + + statements = ImmutableArray.Create(LocalDeclarationStatement(VariableDeclaration( + ParseTypeName("Matcher")) + .WithVariables(SingletonSeparatedList( + VariableDeclarator(matchCallIdentifier) + .WithInitializer(EqualsValueClause(matcherDelegate)))))); + + matcherArgument = Argument(IdentifierName(matchCallIdentifier)); + } + + cancellationToken.ThrowIfCancellationRequested(); + + var objectCreationArguments = ArgumentList( + SeparatedList(new[] + { + Argument(IdentifierName("pretend")), + //Argument(IdentifierName("setupExpression")), + Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName(PretendType.ToPretendName()), + IdentifierName(SetupMethod.ToMethodInfoCacheName()) + )), + matcherArgument, + Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("expr"), + IdentifierName("Target") + )), + })); + + cancellationToken.ThrowIfCancellationRequested(); + + GenericNameSyntax returnObjectName; + SimpleNameSyntax getOrCreateName; + if (SetupMethod.ReturnsVoid) + { + // VoidCompiledSetup + returnObjectName = GenericName("VoidCompiledSetup") + .AddTypeArgumentListArguments(ParseTypeName(PretendType.ToFullDisplayString())); + + getOrCreateName = IdentifierName("GetOrCreateSetup"); + } + else + { + // ReturningCompiledSetup + returnObjectName = GenericName("ReturningCompiledSetup") + .AddTypeArgumentListArguments( + ParseTypeName(PretendType.ToFullDisplayString()), + SetupMethod.ReturnType.AsUnknownTypeSyntax()); + + getOrCreateName = GenericName("GetOrCreateSetup") + .AddTypeArgumentListArguments(SetupMethod.ReturnType.AsUnknownTypeSyntax()); + + // TODO: Recursively mock? + ExpressionSyntax defaultValue; + + if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "Task"])) + { + if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) + { + // Task.FromResult(default) + defaultValue = KnownBlocks.TaskFromResult( + namedType.TypeArguments[0].AsUnknownTypeSyntax(), + LiteralExpression(SyntaxKind.DefaultLiteralExpression)); + } + else + { + // Task.CompletedTask + defaultValue = KnownBlocks.TaskCompletedTask; + } + } + else if (SetupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "ValueTask"])) + { + if (SetupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) + { + // ValueTask.FromResult(default) + defaultValue = KnownBlocks.ValueTaskFromResult( + namedType.TypeArguments[0].AsUnknownTypeSyntax(), + LiteralExpression(SyntaxKind.DefaultLiteralExpression) + ); + } + else + { + // ValueTask.CompletedTask + defaultValue = KnownBlocks.ValueTaskCompletedTask; + } + } + else + { + // TODO: Support custom awaitable + // default + defaultValue = LiteralExpression(SyntaxKind.DefaultLiteralExpression); + } + + cancellationToken.ThrowIfCancellationRequested(); + + objectCreationArguments = objectCreationArguments.AddArguments(Argument( + defaultValue).WithNameColon(NameColon("defaultValue"))); + } + + cancellationToken.ThrowIfCancellationRequested(); + + var compiledSetupCreation = ObjectCreationExpression(returnObjectName) + .WithArgumentList(objectCreationArguments); + + // (pretend, expression) => + // { + // return new CompiledSetup(); + // } + var creator = ParenthesizedLambdaExpression() + .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) + .AddParameterListParameters(Parameter(Identifier("pretend")), Parameter(Identifier("expr"))) + .AddBlockStatements([.. statements, ReturnStatement(compiledSetupCreation)]); + + // TODO: The hash code doesn't actually work, right now, this will create a new pretend every call. + // We likely need to create our own class that can calculate the hash code and place that number in here. + + cancellationToken.ThrowIfCancellationRequested(); + // TODO: Should I have a different seed? + //var badHashCode = _argumentSpecs.Aggregate(0, (agg, s) => HashCode.Combine(agg, s.GetHashCode())); + var badHashCode = 0; + + cancellationToken.ThrowIfCancellationRequested(); + + // pretend.GetOrCreateSetup() + return InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("pretend"), + getOrCreateName)) + .AddArgumentListArguments( + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(badHashCode))), + Argument(creator), + Argument(IdentifierName("setupExpression"))); + } + } +} diff --git a/src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs b/src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs index cc3e3fc..3f4fd1d 100644 --- a/src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs +++ b/src/Pretender.SourceGenerator/SetupArguments/InvocationArgumentSpec.cs @@ -8,6 +8,7 @@ namespace Pretender.SourceGenerator.SetupArguments { + // TODO: Should probably have Specs for "safe" invocations and capturing specs internal class InvocationArgumentSpec : SetupArgumentSpec { private readonly IInvocationOperation _invocationOperation; @@ -29,6 +30,7 @@ public override int NeededMatcherStatements { if (TryGetMatcherAttributeType(out var matcherType)) { + // TODO: Match with KnownTypeSymbols if (matcherType.EqualsByName(["Pretender", "Matchers", "AnyMatcher"])) { _cachedMatcherStatements = ImmutableArray.Empty; @@ -179,7 +181,7 @@ private bool TryGetMatcherAttributeType(out INamedTypeSymbol matcherType) public override int GetHashCode() { - // TODO: This is not enought for uniqueness + // TODO: This is not enough for uniqueness return SymbolEqualityComparer.Default.GetHashCode(_invocationOperation.TargetMethod); } } diff --git a/src/Pretender.SourceGenerator/SetupCreationSpec.cs b/src/Pretender.SourceGenerator/SetupCreationSpec.cs deleted file mode 100644 index 4ab3778..0000000 --- a/src/Pretender.SourceGenerator/SetupCreationSpec.cs +++ /dev/null @@ -1,362 +0,0 @@ -using System.Collections.Immutable; -using System.Diagnostics; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Operations; -using Pretender.SourceGenerator.SetupArguments; -using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; - -namespace Pretender.SourceGenerator -{ - internal class SetupCreationSpec - { - private readonly IArgumentOperation _setupArgument; - private readonly ITypeSymbol _pretendType; - private readonly bool _useSetMethod; - - private readonly IMethodSymbol? _setupMethod; - private readonly ImmutableArray _argumentSpecs; - - public SetupCreationSpec(IArgumentOperation setupArgument, ITypeSymbol pretendType, bool useSetMethod) - { - _setupArgument = setupArgument; - _pretendType = pretendType; - _useSetMethod = useSetMethod; - - var candidates = GetInvocationCandidates(); - - if (candidates.Length == 0) - { - // TODO: Add diagnostic - return; - } - else if (candidates.Length != 1) - { - // TODO: Add diagnostic - return; - } - - var candidate = candidates[0]; - _setupMethod = candidate.Method; - - var builder = ImmutableArray.CreateBuilder(candidate.Arguments.Length); - for (var i = 0; i < candidate.Arguments.Length; i++) - { - builder.Add(SetupArgumentSpec.Create(candidate.Arguments[i], i)); - } - - _argumentSpecs = builder.MoveToImmutable(); - - // TODO: Get argument specs diagnostics and make them my own - } - - private ImmutableArray GetInvocationCandidates() - { - var builder = ImmutableArray.CreateBuilder(); - TraverseOperation(_setupArgument.Value, builder); - return builder.ToImmutable(); - } - - private void TraverseOperation(IOperation operation, ImmutableArray.Builder invocationCandidates) - { - switch (operation.Kind) - { - case OperationKind.Block: - var blockOperation = (IBlockOperation)operation; - TraverseOperationList(blockOperation.Operations, invocationCandidates); - break; - case OperationKind.Return: - var returnOperation = (IReturnOperation)operation; - if (returnOperation.ReturnedValue != null) - { - TraverseOperation(returnOperation.ReturnedValue, invocationCandidates); - } - break; - case OperationKind.ExpressionStatement: - var expressionStatement = (IExpressionStatementOperation)operation; - TraverseOperation(expressionStatement.Operation, invocationCandidates); - break; - case OperationKind.Conversion: - var conversionOperation = (IConversionOperation)operation; - TraverseOperation(conversionOperation.Operand, invocationCandidates); - break; - case OperationKind.Invocation: - var invocationOperation = (IInvocationOperation)operation; - TryMatchInvocationOperation(invocationOperation, invocationCandidates); - break; - case OperationKind.PropertyReference: - var propertyReferenceOperation = (IPropertyReferenceOperation)operation; - TryMatchPropertyReference(propertyReferenceOperation, invocationCandidates); - break; - case OperationKind.AnonymousFunction: - var anonymousFunctionOperation = (IAnonymousFunctionOperation)operation; - TraverseOperation(anonymousFunctionOperation.Body, invocationCandidates); - break; - case OperationKind.DelegateCreation: - var delegateCreationOperation = (IDelegateCreationOperation)operation; - TraverseOperation(delegateCreationOperation.Target, invocationCandidates); - break; - default: -#if DEBUG - // TODO: Figure out what operation caused this, it's not ideal to "randomly" support operations - Debugger.Launch(); -#endif - // Absolute fallback, most of our operations can be supported this way but it's nicer to be explicit - TraverseOperationList(operation.ChildOperations, invocationCandidates); - break; - } - } - - private void TraverseOperationList(IEnumerable operations, ImmutableArray.Builder invocationCandidates) - { - foreach (var operation in operations) - { - TraverseOperation(operation, invocationCandidates); - } - } - - private void TryMatchPropertyReference(IPropertyReferenceOperation propertyReference, ImmutableArray.Builder invocationCandidates) - { - if (propertyReference.Instance is not IParameterReferenceOperation parameterReference) - { - return; - } - - if (!SymbolEqualityComparer.Default.Equals(parameterReference.Type, _pretendType)) - { - return; - } - - var method = _useSetMethod - ? propertyReference.Property.SetMethod - : propertyReference.Property.GetMethod; - - if (method == null) - { - return; - } - - invocationCandidates.Add(new InvocationCandidate(method, ImmutableArray.Empty)); - } - - private void TryMatchInvocationOperation(IInvocationOperation invocation, ImmutableArray.Builder invocationCandidates) - { - if (invocation.Instance is not IParameterReferenceOperation parameterReference) - { - return; - } - - if (!SymbolEqualityComparer.Default.Equals(parameterReference.Type, _pretendType)) - { - return; - } - - // TODO: Any more validation? - - invocationCandidates.Add(new InvocationCandidate(invocation.TargetMethod, invocation.Arguments)); - } - - public InvocationExpressionSyntax CreateSetupGetter(CancellationToken cancellationToken) - { - Debug.Assert(_setupMethod is not null, "A setup method could not be found, which means there should have been error diagnostics and this method should not have ran."); - - var totalMatchStatements = _argumentSpecs.Sum(sa => sa.NeededMatcherStatements); - cancellationToken.ThrowIfCancellationRequested(); - - var matchStatements = new StatementSyntax[totalMatchStatements]; - int addedStatements = 0; - - for (var i = 0; i < _argumentSpecs.Length; i++) - { - var argument = _argumentSpecs[i]; - - var newStatements = argument.CreateMatcherStatements(cancellationToken); - cancellationToken.ThrowIfCancellationRequested(); - - newStatements.CopyTo(matchStatements, addedStatements); - addedStatements += newStatements.Length; - } - - ArgumentSyntax matcherArgument; - ImmutableArray statements; - if (matchStatements.Length == 0) - { - statements = ImmutableArray.Empty; - - // Nothing actually needs to match this will always return true, so we use a cached matcher that always returns true - matcherArgument = Argument(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseTypeName("Cache"), - IdentifierName("NoOpMatcher"))) - .WithNameColon(NameColon("matcher")); - } - else - { - // Other match statements should have added all the ways the method could return false - // so if it gets through all those statements it should return true at the end. - var trueReturnStatement = ReturnStatement(LiteralExpression(SyntaxKind.TrueLiteralExpression)); - - /* - * Matcher matchCall = static (callInfo, target) => - * { - * ...matching calls... - * return true; - * } - */ - var matchCallIdentifier = Identifier("matchCall"); - - var matcherDelegate = ParenthesizedLambdaExpression( - ParameterList(SeparatedList([ - Parameter(Identifier("callInfo")), - Parameter(Identifier("target")) - ])), - Block(List([.. matchStatements, trueReturnStatement]))) - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))); - - statements = ImmutableArray.Create(LocalDeclarationStatement(VariableDeclaration( - ParseTypeName("Matcher")) - .WithVariables(SingletonSeparatedList( - VariableDeclarator(matchCallIdentifier) - .WithInitializer(EqualsValueClause(matcherDelegate)))))); - - matcherArgument = Argument(IdentifierName(matchCallIdentifier)); - } - - cancellationToken.ThrowIfCancellationRequested(); - - var objectCreationArguments = ArgumentList( - SeparatedList(new[] - { - Argument(IdentifierName("pretend")), - //Argument(IdentifierName("setupExpression")), - Argument(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName(_pretendType.ToPretendName()), - IdentifierName(_setupMethod!.ToMethodInfoCacheName()) - )), - matcherArgument, - Argument(MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("expr"), - IdentifierName("Target") - )), - })); - - cancellationToken.ThrowIfCancellationRequested(); - - GenericNameSyntax returnObjectName; - SimpleNameSyntax getOrCreateName; - if (_setupMethod!.ReturnsVoid) - { - // VoidCompiledSetup - returnObjectName = GenericName("VoidCompiledSetup") - .AddTypeArgumentListArguments(ParseTypeName(_pretendType.ToFullDisplayString())); - - getOrCreateName = IdentifierName("GetOrCreateSetup"); - } - else - { - // ReturningCompiledSetup - returnObjectName = GenericName("ReturningCompiledSetup") - .AddTypeArgumentListArguments( - ParseTypeName(_pretendType.ToFullDisplayString()), - _setupMethod.ReturnType.AsUnknownTypeSyntax()); - - getOrCreateName = GenericName("GetOrCreateSetup") - .AddTypeArgumentListArguments(_setupMethod.ReturnType.AsUnknownTypeSyntax()); - - // TODO: Recursively mock? - ExpressionSyntax defaultValue; - - if (_setupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "Task"])) - { - if (_setupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) - { - // Task.FromResult(default) - defaultValue = KnownBlocks.TaskFromResult( - namedType.TypeArguments[0].AsUnknownTypeSyntax(), - LiteralExpression(SyntaxKind.DefaultLiteralExpression)); - } - else - { - // Task.CompletedTask - defaultValue = KnownBlocks.TaskCompletedTask; - } - } - else if (_setupMethod.ReturnType.EqualsByName(["System", "Threading", "Tasks", "ValueTask"])) - { - if (_setupMethod.ReturnType is INamedTypeSymbol namedType && namedType.TypeArguments.Length == 1) - { - // ValueTask.FromResult(default) - defaultValue = KnownBlocks.ValueTaskFromResult( - namedType.TypeArguments[0].AsUnknownTypeSyntax(), - LiteralExpression(SyntaxKind.DefaultLiteralExpression) - ); - } - else - { - // ValueTask.CompletedTask - defaultValue = KnownBlocks.ValueTaskCompletedTask; - } - } - else - { - // TODO: Support custom awaitable - // default - defaultValue = LiteralExpression(SyntaxKind.DefaultLiteralExpression); - } - - cancellationToken.ThrowIfCancellationRequested(); - - objectCreationArguments = objectCreationArguments.AddArguments(Argument( - defaultValue).WithNameColon(NameColon("defaultValue"))); - } - - cancellationToken.ThrowIfCancellationRequested(); - - var compiledSetupCreation = ObjectCreationExpression(returnObjectName) - .WithArgumentList(objectCreationArguments); - - // (pretend, expression) => - // { - // return new CompiledSetup(); - // } - var creator = ParenthesizedLambdaExpression() - .WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword))) - .AddParameterListParameters(Parameter(Identifier("pretend")), Parameter(Identifier("expr"))) - .AddBlockStatements([.. statements, ReturnStatement(compiledSetupCreation)]); - - // TODO: The hash code doesn't actually work, right now, this will create a new pretend every call. - // We likely need to create our own class that can calculate the hash code and place that number in here. - - cancellationToken.ThrowIfCancellationRequested(); - // TODO: Should I have a different seed? - //var badHashCode = _argumentSpecs.Aggregate(0, (agg, s) => HashCode.Combine(agg, s.GetHashCode())); - var badHashCode = 0; - - cancellationToken.ThrowIfCancellationRequested(); - - // pretend.GetOrCreateSetup() - return InvocationExpression(MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, - IdentifierName("pretend"), - getOrCreateName)) - .AddArgumentListArguments( - Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(badHashCode))), - Argument(creator), - Argument(IdentifierName("setupExpression"))); - } - - private class InvocationCandidate - { - public InvocationCandidate(IMethodSymbol methodSymbol, ImmutableArray argumentOperations) - { - Method = methodSymbol; - Arguments = argumentOperations; - } - - public IMethodSymbol Method { get; } - public ImmutableArray Arguments { get; } - } - } -} diff --git a/src/Pretender.SourceGenerator/SetupEntrypoint.cs b/src/Pretender.SourceGenerator/SetupEntrypoint.cs index 71a2155..d2244f0 100644 --- a/src/Pretender.SourceGenerator/SetupEntrypoint.cs +++ b/src/Pretender.SourceGenerator/SetupEntrypoint.cs @@ -4,6 +4,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Operations; +using Pretender.SourceGenerator.Parser; using Pretender.SourceGenerator.SetupArguments; using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; @@ -24,7 +25,13 @@ public SetupEntrypoint(IInvocationOperation invocationOperation) PretendType = pretendType; // TODO: Use correct useSetup value - SetupCreation = new SetupCreationSpec(setupExpressionArg, pretendType, false); + var parser = new SetupActionParser(setupExpressionArg.Value, pretendType, false); + + // TODO: Use the parser properly + var (emitter, diagnostics) = parser.Parse(default); + + // TODO: Don't override null + SetupCreation = emitter!; // TODO: Consume diagnostics var setupMethod = SimplifyOperation(setupExpressionArg.Value); @@ -54,6 +61,7 @@ public SetupEntrypoint(IInvocationOperation invocationOperation) } Arguments = setupArguments.ToImmutableArray(); + // TODO: Don't do this Diagnostics.AddRange(Arguments.SelectMany(s => s.Diagnostics)); } @@ -62,7 +70,7 @@ public SetupEntrypoint(IInvocationOperation invocationOperation) public ITypeSymbol PretendType { get; } public List Diagnostics { get; } = new List(); public IMethodSymbol SetupMethod { get; } = null!; - public SetupCreationSpec SetupCreation { get; } + public SetupActionEmitter SetupCreation { get; } public ImmutableArray Arguments { get; } public MemberDeclarationSyntax[] GetMembers(int index) diff --git a/src/Pretender.SourceGenerator/SetupInvocation.cs b/src/Pretender.SourceGenerator/SetupInvocation.cs new file mode 100644 index 0000000..a7213c2 --- /dev/null +++ b/src/Pretender.SourceGenerator/SetupInvocation.cs @@ -0,0 +1,60 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; +using Pretender.SourceGenerator.Parser; + +namespace Pretender.SourceGenerator +{ + internal class SetupInvocation + { + public SetupInvocation(IInvocationOperation operation, Location location) + { + Operation = operation; + Location = location; + } + + public IInvocationOperation Operation { get; } + public Location Location { get; } + + public static bool IsCandidateSyntaxNode(SyntaxNode node) + { + return node is InvocationExpressionSyntax + { + Expression: MemberAccessExpressionSyntax + { + Name.Identifier.ValueText: "Setup" or "SetupSet", + }, + ArgumentList.Arguments.Count: 1, + }; + } + + public static SetupInvocation? Create(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + Debug.Assert(IsCandidateSyntaxNode(context.Node)); + var invocationSyntax = (InvocationExpressionSyntax)context.Node; + + return context.SemanticModel.GetOperation(invocationSyntax, cancellationToken) is IInvocationOperation operation + && IsSetupOperation(operation) + ? new SetupInvocation(operation, invocationSyntax.GetLocation()) + : null; + } + + private static bool IsSetupOperation(IInvocationOperation operation) + { + if (operation.TargetMethod is not IMethodSymbol + { + Name: "Setup" or "SetupSet", + ContainingType: INamedTypeSymbol namedTypeSymbol + } || !KnownTypeSymbols.IsPretend(namedTypeSymbol)) + { + return false; + } + + return true; + } + } +} diff --git a/src/Pretender.SourceGenerator/SymbolExtensions.cs b/src/Pretender.SourceGenerator/SymbolExtensions.cs index 0d81068..7d936e1 100644 --- a/src/Pretender.SourceGenerator/SymbolExtensions.cs +++ b/src/Pretender.SourceGenerator/SymbolExtensions.cs @@ -53,6 +53,19 @@ public static TypeSyntax AsUnknownTypeSyntax(this ITypeSymbol type) return typeSyntax; } + public static ExpressionSyntax ToDefaultValueSyntax(this ITypeSymbol type) + { + // They have explicitly annotated this type as nullable, so return null + if (type.NullableAnnotation == NullableAnnotation.Annotated) + { + return LiteralExpression(SyntaxKind.DefaultLiteralExpression); + } + + + + throw new NotImplementedException(); + } + public static string ToPretendName(this ITypeSymbol symbol) { return $"Pretend{symbol.Name}{SymbolEqualityComparer.Default.GetHashCode(symbol):X}"; diff --git a/src/Pretender.SourceGenerator/VerifyInvocation.cs b/src/Pretender.SourceGenerator/VerifyInvocation.cs index 8ac5768..4c9015c 100644 --- a/src/Pretender.SourceGenerator/VerifyInvocation.cs +++ b/src/Pretender.SourceGenerator/VerifyInvocation.cs @@ -49,6 +49,7 @@ private static bool IsVerifyOperation(IInvocationOperation operation) // but we should do it all with string comparisons if (operation.TargetMethod is not IMethodSymbol { + // TODO: The name has already been asserted, do I need to do this again? Name: "Verify", // TODO: or VerifySet, ContainingType: INamedTypeSymbol namedTypeSymbol } || !KnownTypeSymbols.IsPretend(namedTypeSymbol)) diff --git a/src/Pretender/Called.cs b/src/Pretender/Called.cs index 1472a7b..35dc57a 100644 --- a/src/Pretender/Called.cs +++ b/src/Pretender/Called.cs @@ -25,10 +25,12 @@ enum CalledKind public static Called Exactly(int expectedCalls) => new(expectedCalls, expectedCalls, CalledKind.Exact); - public static Called AtLeastOnce() - => new(1, int.MaxValue, CalledKind.AtLeast); + public static Called AtLeastOnce() => AtLeast(1); - public static implicit operator Called(Range range) + public static Called AtLeast(int minimumCalls) + => new(minimumCalls, int.MaxValue, CalledKind.AtLeast); + + public static Called Range(Range range) { if (range.Start.IsFromEnd || range.End.IsFromEnd) { @@ -38,8 +40,9 @@ public static implicit operator Called(Range range) return new(range.Start.Value, range.End.Value, CalledKind.Range); } - public static implicit operator Called(int expectedCalls) - => new(expectedCalls, expectedCalls, CalledKind.Exact); + public static implicit operator Called(Range range) => Range(range); + + public static implicit operator Called(int expectedCalls) => Exactly(expectedCalls); [StackTraceHidden] public void Validate(int callCount) diff --git a/src/Pretender/Matchers/MatcherListener.cs b/src/Pretender/Matchers/MatcherListener.cs new file mode 100644 index 0000000..103fde1 --- /dev/null +++ b/src/Pretender/Matchers/MatcherListener.cs @@ -0,0 +1,68 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace Pretender.Matchers +{ + public sealed class MatcherListener : IDisposable + { + [ThreadStatic] + private static Stack? s_listeners; + + public static MatcherListener StartListening() + { + var listener = new MatcherListener(); + var listeners = s_listeners; + if (listeners == null) + { + s_listeners = listeners = new Stack(); + } + + listeners.Push(listener); + + return listener; + } + + public static bool IsListening([MaybeNullWhen(false)] out MatcherListener listener) + { + var listeners = s_listeners; + + if (listeners != null && listeners.Count > 0) + { + listener = listeners.Peek(); + return true; + } + + listener = null; + return false; + } + + private List _matchers; + + public void OnMatch(IMatcher matcher) + { + if (_matchers == null) + { + _matchers = []; + } + + _matchers.Add(matcher); + } + + public IEnumerable GetMatchers() + { + if (_matchers == null) + { + return []; + } + + return _matchers; + } + + public void Dispose() + { + var listeners = s_listeners; + Debug.Assert(listeners != null && listeners.Count > 0); + listeners.Pop(); + } + } +} diff --git a/test/Pretender.Tests/CalledTests.cs b/test/Pretender.Tests/CalledTests.cs index 42f5893..abb543c 100644 --- a/test/Pretender.Tests/CalledTests.cs +++ b/test/Pretender.Tests/CalledTests.cs @@ -12,6 +12,9 @@ public static IEnumerable Validate_DoesNotThrowData() yield return Data(Called.AtLeastOnce(), 1); yield return Data(Called.AtLeastOnce(), 2); + yield return Data(Called.AtLeast(3), 3); + yield return Data(Called.AtLeast(3), 10); + // Range yield return Data(1..4, 1); yield return Data(1..4, 2); @@ -38,6 +41,8 @@ public static IEnumerable Validate_ThrowsData() // AtLeast yield return Data(Called.AtLeastOnce(), 0); + yield return Data(Called.AtLeast(5), 0); + yield return Data(Called.AtLeast(5), 4); // Range yield return Data(2..5, 0); diff --git a/test/Pretender.Tests/Matchers/MatcherListenerTests.cs b/test/Pretender.Tests/Matchers/MatcherListenerTests.cs new file mode 100644 index 0000000..edf7687 --- /dev/null +++ b/test/Pretender.Tests/Matchers/MatcherListenerTests.cs @@ -0,0 +1,22 @@ +using Pretender.Matchers; + +namespace Pretender.Tests.Matchers +{ + public class MatcherListenerTests + { + [Fact] + public void StartListening_ReturnsSameListenerAsIsListening() + { + using var startedListener = MatcherListener.StartListening(); + Assert.True(MatcherListener.IsListening(out var listener)); + Assert.Equal(startedListener, listener); + } + + [Fact] + public void IsListening_ReturnsFalse_WhenNotStarted() + { + Assert.False(MatcherListener.IsListening(out var listener)); + Assert.Null(listener); + } + } +} diff --git a/test/SourceGeneratorTests/TestBase.cs b/test/SourceGeneratorTests/TestBase.cs index 0023b8b..163eb7a 100644 --- a/test/SourceGeneratorTests/TestBase.cs +++ b/test/SourceGeneratorTests/TestBase.cs @@ -98,8 +98,8 @@ private void CompareAgainstBaseline(GeneratedSourceResult result, string testMet #if !GENERATE_SOURCE var resultFileName = result.HintName.Replace('.', '_'); var baseLineName = $"{GetType().Name}.{testMethodName}.{normalizedName}"; - var resourceName = typeof(TestBase).Assembly.GetManifestResourceNames() - .Single(r => r.EndsWith(baseLineName)); + var resourceName = Assert.Single(typeof(TestBase).Assembly.GetManifestResourceNames() + .Where(r => r.EndsWith(baseLineName))); using var stream = typeof(TestBase).Assembly.GetManifestResourceStream(resourceName)!; using var reader = new StreamReader(stream);