diff --git a/README.md b/README.md index f654241..7e1ed2d 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,8 @@ It's rather simple. `EnumeratorMacro` will: * Remove empty lines from the final generated code, to get rid of possible excessive empty lines. -* +* Remove last trailing comma in a case switch, which is an error. For easier templating. + * Search for the `testRemovesLastErroneousCommaInCaseSwitch()` test for an example. diff --git a/Sources/EnumeratorMacroImpl/EnumeratorMacroType.swift b/Sources/EnumeratorMacroImpl/EnumeratorMacroType.swift index ad7a85a..714f0c1 100644 --- a/Sources/EnumeratorMacroImpl/EnumeratorMacroType.swift +++ b/Sources/EnumeratorMacroImpl/EnumeratorMacroType.swift @@ -121,14 +121,39 @@ extension EnumeratorMacroType: MemberMacro { let decls = SourceFileSyntax.parse( from: &parser ).statements.compactMap { statement -> DeclSyntax? in - let diagnostics = ParseDiagnosticsGenerator.diagnostics(for: statement) - let hasError = diagnostics.contains(where: { $0.diagMessage.severity == .error }) - if hasError { - context.diagnose(.init( - node: codeSyntax, - message: MacroError.renderedSyntaxContainsErrors(statement.description) - )) + var statement = statement + + var diagnostics = ParseDiagnosticsGenerator.diagnostics(for: statement) + if diagnostics.containsError { + /// Try to recover from errors: + let switchRewriter = SwitchErrorsRewriter() + let fixedStatement = switchRewriter.rewrite(statement) + let newDiagnostics = ParseDiagnosticsGenerator.diagnostics(for: fixedStatement) + if !newDiagnostics.containsError { + switch CodeBlockItemSyntax(fixedStatement) { + case let .some(fixedStatement): + statement = fixedStatement + diagnostics = newDiagnostics + case .none: + context.diagnose( + Diagnostic( + node: codeSyntax, + message: MacroError.internalError( + "Could not convert a Syntax to a CodeBlockItemSyntax" + ) + ) + ) + return nil + } + } else { + /// If not recovered, throw a diagnostic error. + context.diagnose(.init( + node: codeSyntax, + message: MacroError.renderedSyntaxContainsErrors(statement.description) + )) + } } + for diagnostic in diagnostics { if diagnostic.diagMessage.severity == .error { context.diagnose(.init( @@ -143,7 +168,7 @@ extension EnumeratorMacroType: MemberMacro { /// TODO: Apply the fixit } } - if hasError { + if diagnostics.containsError { return nil } switch DeclSyntax(statement.item) { @@ -174,7 +199,7 @@ extension EnumeratorMacroType: MemberMacro { let excessiveTriviaRemover = ExcessiveTriviaRemover() processedSyntax = excessiveTriviaRemover.rewrite(processedSyntax) - let switchRewriter = SwitchRewriter() + let switchRewriter = SwitchWarningsRewriter() processedSyntax = switchRewriter.rewrite(processedSyntax) guard let declSyntax = DeclSyntax(processedSyntax) else { @@ -194,3 +219,9 @@ extension EnumeratorMacroType: MemberMacro { return postProcessedSyntaxes } } + +private extension [Diagnostic] { + var containsError: Bool { + self.contains(where: { $0.diagMessage.severity == .error }) + } +} diff --git a/Sources/EnumeratorMacroImpl/SwitchErrorsRewriter.swift b/Sources/EnumeratorMacroImpl/SwitchErrorsRewriter.swift new file mode 100644 index 0000000..b14eceb --- /dev/null +++ b/Sources/EnumeratorMacroImpl/SwitchErrorsRewriter.swift @@ -0,0 +1,47 @@ +import SwiftSyntax + +final class SwitchErrorsRewriter: SyntaxRewriter { + override func visit(_ node: SwitchCaseSyntax) -> SwitchCaseSyntax { + guard let label = node.label.as(SwitchCaseLabelSyntax.self) else { + return node + } + var items = label.caseItems + guard items.count > 1 else { + return node + } + + let lastIndex = items.lastIndex(where: { _ in true })! + var lastIsAMissingExpr: Bool { + if let pattern = items[lastIndex].pattern.as(ExpressionPatternSyntax.self), + pattern.expression.is(MissingExprSyntax.self) { + return true + } else { + return false + } + } + + var oneToLastContainsTrailingComma: Bool { + items[items.index(before: lastIndex)].trailingComma != nil + } + + if lastIsAMissingExpr, + oneToLastContainsTrailingComma { + + items[items.index(before: lastIndex)].trailingComma = nil + items.remove(at: lastIndex) + + let node = node.with( + \.label, + SwitchCaseSyntax.Label( + label.with( + \.caseItems, + items + ) + ) + ) + return node + } else { + return node + } + } +} diff --git a/Sources/EnumeratorMacroImpl/SwitchRewriter.swift b/Sources/EnumeratorMacroImpl/SwitchWarningsRewriter.swift similarity index 99% rename from Sources/EnumeratorMacroImpl/SwitchRewriter.swift rename to Sources/EnumeratorMacroImpl/SwitchWarningsRewriter.swift index ea1dbb7..4f34ec0 100644 --- a/Sources/EnumeratorMacroImpl/SwitchRewriter.swift +++ b/Sources/EnumeratorMacroImpl/SwitchWarningsRewriter.swift @@ -1,6 +1,6 @@ import SwiftSyntax -final class SwitchRewriter: SyntaxRewriter { +final class SwitchWarningsRewriter: SyntaxRewriter { override func visit(_ node: SwitchCaseSyntax) -> SwitchCaseSyntax { self.removeUnusedLet( self.removeUnusedArguments( diff --git a/Tests/EnumeratorMacroTests/EnumeratorMacroTests.swift b/Tests/EnumeratorMacroTests/EnumeratorMacroTests.swift index 76e3a30..92c5596 100644 --- a/Tests/EnumeratorMacroTests/EnumeratorMacroTests.swift +++ b/Tests/EnumeratorMacroTests/EnumeratorMacroTests.swift @@ -295,7 +295,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func removesExcessiveTrivia() { + func testRemovesExcessiveTrivia() { assertMacroExpansion( #""" @Enumerator( @@ -337,6 +337,44 @@ final class EnumeratorMacroTests: XCTestCase { ) } + func testRemovesLastErroneousCommaInCaseSwitch() { + assertMacroExpansion( + #""" + @Enumerator( + """ + public var constant: String { + switch self { + case {{#cases}}.{{name}}, {{/cases}}: + "some constant" + } + } + """ + ) + enum TestEnum { + case a + case b + } + """#, + /// It usually contain `case .a, .b,:` which is an error + /// because `.b` has a trailing comma after it. + /// But the macro should recover from this situation: + expandedSource: #""" + enum TestEnum { + case a + case b + + public var constant: String { + switch self { + case .a, .b: + "some constant" + } + } + } + """#, + macros: EnumeratorMacroEntryPoint.macros + ) + } + func testDiagnosesNotAnEnum() { assertMacroExpansion(