diff --git a/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs b/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs index 4d40b30..0d1cf88 100644 --- a/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs @@ -1,20 +1,88 @@ using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; namespace Pretender.SourceGenerator.Emitter { internal class GrandEmitter { private readonly ImmutableArray _pretendEmitters; + private readonly ImmutableArray _setupEmitters; + private readonly ImmutableArray _verifyEmitters; + private readonly ImmutableArray _createEmitters; - public GrandEmitter(ImmutableArray pretendEmitters) + public GrandEmitter( + ImmutableArray pretendEmitters, + ImmutableArray setupEmitters, + ImmutableArray verifyEmitters, + ImmutableArray createEmitters) { _pretendEmitters = pretendEmitters; + _setupEmitters = setupEmitters; + _verifyEmitters = verifyEmitters; + _createEmitters = createEmitters; } public CompilationUnitSyntax Emit(CancellationToken cancellationToken) { - throw new NotImplementedException(); + var namespaceDeclaration = KnownBlocks.OurNamespace + .AddUsings( + UsingDirective(ParseName("System")), + KnownBlocks.CompilerServicesUsing, + UsingDirective(ParseName("System.Threading.Tasks")), + KnownBlocks.PretenderUsing, + KnownBlocks.PretenderInternalsUsing + ); + + foreach (var pretendEmitter in _pretendEmitters) + { + namespaceDeclaration = namespaceDeclaration + .AddMembers(pretendEmitter.Emit(cancellationToken)); + } + + var setupInterceptorsClass = ClassDeclaration("SetupInterceptors") + .WithModifiers(TokenList(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword))); + + int setupIndex = 0; + foreach (var setupEmitter in _setupEmitters) + { + setupInterceptorsClass = setupInterceptorsClass + .AddMembers(setupEmitter.Emit(setupIndex, cancellationToken)); + setupIndex++; + } + + var verifyInterceptorsClass = ClassDeclaration("VerifyInterceptors") + .AddModifiers(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword)); + + int verifyIndex = 0; + foreach (var verifyEmitter in _verifyEmitters) + { + verifyInterceptorsClass = verifyInterceptorsClass + .AddMembers(verifyEmitter.Emit(verifyIndex, cancellationToken)); + verifyIndex++; + } + + var createInterceptorsClass = ClassDeclaration("CreateInterceptors") + .AddModifiers(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword)); + + int createIndex = 0; + foreach (var createEmitter in _createEmitters) + { + createInterceptorsClass = createInterceptorsClass + .AddMembers(createEmitter.Emit(cancellationToken)); + createIndex++; + } + + namespaceDeclaration = namespaceDeclaration + .AddMembers(setupInterceptorsClass, verifyInterceptorsClass, createInterceptorsClass); + + return CompilationUnit() + .AddMembers( + KnownBlocks.InterceptsLocationAttribute, + namespaceDeclaration) + .NormalizeWhitespace(); } } } diff --git a/src/Pretender.SourceGenerator/Emitter/PretendEmitter.cs b/src/Pretender.SourceGenerator/Emitter/PretendEmitter.cs index dfaf851..fa44ccd 100644 --- a/src/Pretender.SourceGenerator/Emitter/PretendEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/PretendEmitter.cs @@ -19,7 +19,7 @@ public PretendEmitter(ITypeSymbol pretendType, bool fillExisting) public ITypeSymbol PretendType => _pretendType; - public CompilationUnitSyntax Emit(CancellationToken token) + public TypeDeclarationSyntax Emit(CancellationToken token) { var pretendFieldAssignment = ExpressionStatement( AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, @@ -128,25 +128,8 @@ public CompilationUnitSyntax Emit(CancellationToken token) // TODO: Add properties // TODO: Generate debugger display - classDeclaration = classDeclaration - .WithModifiers(TokenList(Token(SyntaxKind.InternalKeyword))); - - SyntaxTriviaList leadingTrivia = TriviaList( - Comment("// "), - Trivia(NullableDirectiveTrivia(Token(SyntaxKind.EnableKeyword), true)), - Comment("/// ")); - - var sourceGenerationNamespace = KnownBlocks.OurNamespace - .AddMembers(classDeclaration.WithInheritDoc()) - .AddUsings( - UsingDirective(IdentifierName("System.Reflection")), - KnownBlocks.PretenderUsing - ); - - return CompilationUnit() - .AddMembers(sourceGenerationNamespace) - .WithLeadingTrivia(leadingTrivia) - .NormalizeWhitespace(); + return classDeclaration + .WithModifiers(TokenList(Token(SyntaxKind.FileKeyword))); } private static FieldDeclarationSyntax CreateMethodInfoField(IMethodSymbol method, ExpressionSyntax expressionSyntax) diff --git a/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs b/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs index 1052421..e744dde 100644 --- a/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/SetupEmitter.cs @@ -18,13 +18,11 @@ public SetupEmitter(SetupActionEmitter setupActionEmitter, IInvocationOperation } // TODO: Run cancellationToken a lot more - public MemberDeclarationSyntax[] Emit(int index, CancellationToken cancellationToken) + public MemberDeclarationSyntax Emit(int index, CancellationToken cancellationToken) { var setupMethod = _setupActionEmitter.SetupMethod; var pretendType = _setupActionEmitter.PretendType; - var allMembers = new List(); - var interceptsLocation = new InterceptsLocationInfo(_setupInvocation); // TODO: This is wrong @@ -37,7 +35,7 @@ public MemberDeclarationSyntax[] Emit(int index, CancellationToken cancellationT var setupCreatorInvocation = _setupActionEmitter.CreateSetupGetter(cancellationToken); - var fullSetupMethod = MethodDeclaration(returnType, $"Setup{index}") + return MethodDeclaration(returnType, $"Setup{index}") .WithBody(Block(ReturnStatement(setupCreatorInvocation))) .WithParameterList(ParameterList(SeparatedList(new[] { @@ -51,9 +49,6 @@ public MemberDeclarationSyntax[] Emit(int index, CancellationToken cancellationT .WithModifiers(TokenList(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword))) .WithAttributeLists(SingletonList(AttributeList( SingletonSeparatedList(interceptsLocation.ToAttributeSyntax())))); - - allMembers.Add(fullSetupMethod); - return [.. allMembers]; } } } diff --git a/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs b/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs index e60adab..a4af867 100644 --- a/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs +++ b/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs @@ -41,7 +41,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }); #region Pretend - IncrementalValuesProvider<(PretendEmitter? Emitter, ImmutableArray? Diagnostics)> pretendsWithDiagnostics = + IncrementalValuesProvider<(PretendEmitter? Emitter, ImmutableArray? Diagnostics)> pretendEmittersWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider( predicate: (node, _) => PretendInvocation.IsCandidateSyntaxNode(node), transform: PretendInvocation.Create) @@ -56,17 +56,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }) .WithTrackingName("Pretend"); - var pretends = ReportDiagnostics(context, pretendsWithDiagnostics); - - context.RegisterSourceOutput(pretends, static (context, emitter) => - { - var compilationUnit = emitter.Emit(context.CancellationToken); - context.AddSource($"Pretender.Type.{emitter.PretendType.ToPretendName()}.g.cs", compilationUnit.GetText(Encoding.UTF8)); - }); + var pretendEmitters = ReportDiagnostics(context, pretendEmittersWithDiagnostics); #endregion #region Setup - IncrementalValuesProvider<(SetupEmitter? Emitter, ImmutableArray? Diagnostics)> setups = + IncrementalValuesProvider<(SetupEmitter? Emitter, ImmutableArray? Diagnostics)> setupEmittersWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider( predicate: static (node, _) => SetupInvocation.IsCandidateSyntaxNode(node), transform: SetupInvocation.Create) @@ -80,56 +74,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }) .WithTrackingName("Setup"); - context.RegisterSourceOutput(setups.Collect(), static (context, setups) => - { - var members = new List(); - for (var i = 0; i < setups.Length; i++) - { - var setup = setups[i]; - - if (setup.Diagnostics is ImmutableArray diagnostics) - { - foreach (var diagnostic in diagnostics) - { - context.ReportDiagnostic(diagnostic); - } - } - - if (setup.Emitter is SetupEmitter emitter) - { - members.AddRange(emitter.Emit(i, context.CancellationToken)); - } - } - - var classDeclaration = SyntaxFactory.ClassDeclaration("SetupInterceptors") - .WithModifiers(SyntaxFactory.TokenList( - SyntaxFactory.Token(SyntaxKind.FileKeyword), - SyntaxFactory.Token(SyntaxKind.StaticKeyword))) - .AddMembers([.. members]); - - var namespaceDeclaration = SyntaxFactory.NamespaceDeclaration(SyntaxFactory.ParseName("Pretender.SourceGeneration")) - .AddMembers(classDeclaration) - .AddUsings( - SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("System")), - KnownBlocks.CompilerServicesUsing, - SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("System.Linq.Expressions")), - SyntaxFactory.UsingDirective(SyntaxFactory.ParseName("System.Threading.Tasks")), - KnownBlocks.PretenderUsing, - KnownBlocks.PretenderInternalsUsing - ); - - var il = KnownBlocks.InterceptsLocationAttribute; - - var compilationUnit = SyntaxFactory.CompilationUnit() - .AddMembers(il, namespaceDeclaration) - .NormalizeWhitespace(); - - context.AddSource("Pretender.Setups.g.cs", compilationUnit.GetText(Encoding.UTF8)); - }); + var setups = ReportDiagnostics(context, setupEmittersWithDiagnostics); #endregion #region Verify - IncrementalValuesProvider<(VerifyEmitter? Emitter, ImmutableArray? Diagnostics)> verifyCallsWithDiagnostics = + IncrementalValuesProvider<(VerifyEmitter? Emitter, ImmutableArray? Diagnostics)> verifyEmittersWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider( predicate: (node, _) => VerifyInvocation.IsCandidateSyntaxNode(node), transform: VerifyInvocation.Create) @@ -144,30 +93,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }) .WithTrackingName("Verify"); - var verifyEmitters = ReportDiagnostics(context, verifyCallsWithDiagnostics); - - context.RegisterSourceOutput(verifyEmitters.Collect(), (context, inputs) => - { - var methods = new List(); - for (var i = 0; i < inputs.Length; i++) - { - var input = inputs[i]; - - var method = input.Emit(0, context.CancellationToken); - methods.Add(method); - } - - if (methods.Count > 0) - { - // Emit all methods - var compilationUnit = CommonSyntax.CreateVerifyCompilationUnit([.. methods]); - context.AddSource("Pretender.Verifies.g.cs", compilationUnit.GetText(Encoding.UTF8)); - } - }); + var verifyEmitters = ReportDiagnostics(context, verifyEmittersWithDiagnostics); #endregion #region Create - var createCalls = context.SyntaxProvider.CreateSyntaxProvider( + var createEmittersWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider( predicate: (node, _) => CreateInvocation.IsCandidateSyntaxNode(node), transform: CreateInvocation.Create) .Where(i => i is not null)! @@ -180,36 +110,22 @@ public void Initialize(IncrementalGeneratorInitializationContext context) }) .WithTrackingName("Create"); - var createEmitters = ReportDiagnostics(context, createCalls); + var createEmitters = ReportDiagnostics(context, createEmittersWithDiagnostics); + #endregion - context.RegisterSourceOutput(createEmitters, static (context, emitter) => + context.RegisterSourceOutput( + pretendEmitters.Collect() + .Combine(setups.Collect()) + .Combine(verifyEmitters.Collect()) + .Combine(createEmitters.Collect()), (context, emitters) => { - // TODO: Don't actually need a list here - var members = new List(); + var (((pretends, setups), verifies), creates) = emitters; + var grandEmitter = new GrandEmitter(pretends, setups, verifies, creates); - string? pretendName = null; + var compilationUnit = grandEmitter.Emit(context.CancellationToken); - pretendName ??= emitter.Operation.TargetMethod.ReturnType.ToPretendName(); - members.Add(emitter.Emit(context.CancellationToken)); - - if (members.Any()) - { - var createClass = SyntaxFactory.ClassDeclaration("CreateInterceptors") - .AddModifiers(SyntaxFactory.Token(SyntaxKind.FileKeyword), SyntaxFactory.Token(SyntaxKind.StaticKeyword)) - .WithMembers(SyntaxFactory.List(members)); - - var createNamespace = KnownBlocks.OurNamespace - .AddMembers(createClass) - .AddUsings(KnownBlocks.CompilerServicesUsing, KnownBlocks.PretenderUsing); - - var cu = SyntaxFactory.CompilationUnit() - .AddMembers(KnownBlocks.InterceptsLocationAttribute, createNamespace) - .NormalizeWhitespace(); - - context.AddSource($"Pretender.Creates.{pretendName}.g.cs", cu.GetText(Encoding.UTF8)); - } + context.AddSource("Pretender.g.cs", compilationUnit.GetText(Encoding.UTF8)); }); - #endregion } private static IncrementalValuesProvider ReportDiagnostics(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<(T? Emitter, ImmutableArray? Diagnostics)> source) diff --git a/test/SourceGeneratorTests/MainTests.cs b/test/SourceGeneratorTests/MainTests.cs index 06ad198..f04e2fa 100644 --- a/test/SourceGeneratorTests/MainTests.cs +++ b/test/SourceGeneratorTests/MainTests.cs @@ -45,10 +45,10 @@ public TestClass() } """); - Assert.Equal(2, result.GeneratedSources.Length); + Assert.Equal(1, result.GeneratedSources.Length); var text1 = result.GeneratedSources[0].SourceText.ToString(); - var text2 = result.GeneratedSources[1].SourceText.ToString(); + //var text2 = result.GeneratedSources[1].SourceText.ToString(); }