diff --git a/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs b/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs index 0d1cf88..4283b62 100644 --- a/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs +++ b/src/Pretender.SourceGenerator/Emitter/GrandEmitter.cs @@ -38,6 +38,7 @@ public CompilationUnitSyntax Emit(CancellationToken cancellationToken) foreach (var pretendEmitter in _pretendEmitters) { + cancellationToken.ThrowIfCancellationRequested(); namespaceDeclaration = namespaceDeclaration .AddMembers(pretendEmitter.Emit(cancellationToken)); } @@ -45,9 +46,12 @@ public CompilationUnitSyntax Emit(CancellationToken cancellationToken) var setupInterceptorsClass = ClassDeclaration("SetupInterceptors") .WithModifiers(TokenList(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword))); + cancellationToken.ThrowIfCancellationRequested(); + int setupIndex = 0; foreach (var setupEmitter in _setupEmitters) { + cancellationToken.ThrowIfCancellationRequested(); setupInterceptorsClass = setupInterceptorsClass .AddMembers(setupEmitter.Emit(setupIndex, cancellationToken)); setupIndex++; @@ -56,9 +60,13 @@ public CompilationUnitSyntax Emit(CancellationToken cancellationToken) var verifyInterceptorsClass = ClassDeclaration("VerifyInterceptors") .AddModifiers(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword)); + cancellationToken.ThrowIfCancellationRequested(); + int verifyIndex = 0; foreach (var verifyEmitter in _verifyEmitters) { + cancellationToken.ThrowIfCancellationRequested(); + verifyInterceptorsClass = verifyInterceptorsClass .AddMembers(verifyEmitter.Emit(verifyIndex, cancellationToken)); verifyIndex++; @@ -67,9 +75,13 @@ public CompilationUnitSyntax Emit(CancellationToken cancellationToken) var createInterceptorsClass = ClassDeclaration("CreateInterceptors") .AddModifiers(Token(SyntaxKind.FileKeyword), Token(SyntaxKind.StaticKeyword)); + cancellationToken.ThrowIfCancellationRequested(); + int createIndex = 0; foreach (var createEmitter in _createEmitters) { + cancellationToken.ThrowIfCancellationRequested(); + createInterceptorsClass = createInterceptorsClass .AddMembers(createEmitter.Emit(cancellationToken)); createIndex++; @@ -78,6 +90,8 @@ public CompilationUnitSyntax Emit(CancellationToken cancellationToken) namespaceDeclaration = namespaceDeclaration .AddMembers(setupInterceptorsClass, verifyInterceptorsClass, createInterceptorsClass); + cancellationToken.ThrowIfCancellationRequested(); + return CompilationUnit() .AddMembers( KnownBlocks.InterceptsLocationAttribute, diff --git a/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs b/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs index a4af867..1663b8e 100644 --- a/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs +++ b/src/Pretender.SourceGenerator/PretenderSourceGenerator.cs @@ -1,8 +1,6 @@ using System.Collections.Immutable; using System.Text; using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Pretender.SourceGenerator.Emitter; using Pretender.SourceGenerator.Invocation; using Pretender.SourceGenerator.Parser; @@ -14,12 +12,10 @@ public class PretenderSourceGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { - // TODO: Refactor our region use IncrementalValueProvider knownTypeSymbols = context.CompilationProvider .Select((compilation, _) => new KnownTypeSymbols(compilation)); - // TODO: Read settings off of IncrementalValueProvider settings = context.SyntaxProvider.ForAttributeWithMetadataName( "Pretender.PretenderSettingsAttribute", predicate: static (node, _) => true, @@ -40,7 +36,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) return PretenderSettings.FromAttribute(settings[0]); }); - #region Pretend IncrementalValuesProvider<(PretendEmitter? Emitter, ImmutableArray? Diagnostics)> pretendEmittersWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider( predicate: (node, _) => PretendInvocation.IsCandidateSyntaxNode(node), @@ -57,9 +52,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithTrackingName("Pretend"); var pretendEmitters = ReportDiagnostics(context, pretendEmittersWithDiagnostics); - #endregion - #region Setup IncrementalValuesProvider<(SetupEmitter? Emitter, ImmutableArray? Diagnostics)> setupEmittersWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider( predicate: static (node, _) => SetupInvocation.IsCandidateSyntaxNode(node), @@ -75,9 +68,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithTrackingName("Setup"); var setups = ReportDiagnostics(context, setupEmittersWithDiagnostics); - #endregion - #region Verify IncrementalValuesProvider<(VerifyEmitter? Emitter, ImmutableArray? Diagnostics)> verifyEmittersWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider( predicate: (node, _) => VerifyInvocation.IsCandidateSyntaxNode(node), @@ -94,9 +85,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithTrackingName("Verify"); var verifyEmitters = ReportDiagnostics(context, verifyEmittersWithDiagnostics); - #endregion - #region Create var createEmittersWithDiagnostics = context.SyntaxProvider.CreateSyntaxProvider( predicate: (node, _) => CreateInvocation.IsCandidateSyntaxNode(node), transform: CreateInvocation.Create) @@ -111,7 +100,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .WithTrackingName("Create"); var createEmitters = ReportDiagnostics(context, createEmittersWithDiagnostics); - #endregion context.RegisterSourceOutput( pretendEmitters.Collect() @@ -120,10 +108,15 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .Combine(createEmitters.Collect()), (context, emitters) => { var (((pretends, setups), verifies), creates) = emitters; + + context.CancellationToken.ThrowIfCancellationRequested(); + var grandEmitter = new GrandEmitter(pretends, setups, verifies, creates); var compilationUnit = grandEmitter.Emit(context.CancellationToken); + context.CancellationToken.ThrowIfCancellationRequested(); + context.AddSource("Pretender.g.cs", compilationUnit.GetText(Encoding.UTF8)); }); } diff --git a/test/SourceGeneratorTests/MainTests.cs b/test/SourceGeneratorTests/MainTests.cs index f04e2fa..bcd2462 100644 --- a/test/SourceGeneratorTests/MainTests.cs +++ b/test/SourceGeneratorTests/MainTests.cs @@ -45,9 +45,9 @@ public TestClass() } """); - Assert.Equal(1, result.GeneratedSources.Length); + var source = Assert.Single(result.GeneratedSources); - var text1 = result.GeneratedSources[0].SourceText.ToString(); + var text1 = source.SourceText.ToString(); //var text2 = result.GeneratedSources[1].SourceText.ToString(); }