Skip to content

Commit

Permalink
Automatically recover from trailing-comma-in-switch-case errors
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM committed Jul 18, 2024
1 parent e046d96 commit 9de9afa
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 12 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ It's rather simple.

`EnumeratorMacro` will:
* Remove empty lines from the final generated code, to get rid of possible excessive empty lines.
*
* Remove last trailing comma in a case switch, which is an error. For easier templating.
* Search for the `testRemovesLastErroneousCommaInCaseSwitch()` test for an example.

</details>

Expand Down
49 changes: 40 additions & 9 deletions Sources/EnumeratorMacroImpl/EnumeratorMacroType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,39 @@ extension EnumeratorMacroType: MemberMacro {
let decls = SourceFileSyntax.parse(
from: &parser
).statements.compactMap { statement -> DeclSyntax? in
let diagnostics = ParseDiagnosticsGenerator.diagnostics(for: statement)
let hasError = diagnostics.contains(where: { $0.diagMessage.severity == .error })
if hasError {
context.diagnose(.init(
node: codeSyntax,
message: MacroError.renderedSyntaxContainsErrors(statement.description)
))
var statement = statement

var diagnostics = ParseDiagnosticsGenerator.diagnostics(for: statement)
if diagnostics.containsError {
/// Try to recover from errors:
let switchRewriter = SwitchErrorsRewriter()
let fixedStatement = switchRewriter.rewrite(statement)
let newDiagnostics = ParseDiagnosticsGenerator.diagnostics(for: fixedStatement)
if !newDiagnostics.containsError {
switch CodeBlockItemSyntax(fixedStatement) {
case let .some(fixedStatement):
statement = fixedStatement
diagnostics = newDiagnostics
case .none:
context.diagnose(
Diagnostic(
node: codeSyntax,
message: MacroError.internalError(
"Could not convert a Syntax to a CodeBlockItemSyntax"
)
)
)
return nil
}
} else {
/// If not recovered, throw a diagnostic error.
context.diagnose(.init(
node: codeSyntax,
message: MacroError.renderedSyntaxContainsErrors(statement.description)
))
}
}

for diagnostic in diagnostics {
if diagnostic.diagMessage.severity == .error {
context.diagnose(.init(
Expand All @@ -143,7 +168,7 @@ extension EnumeratorMacroType: MemberMacro {
/// TODO: Apply the fixit
}
}
if hasError {
if diagnostics.containsError {
return nil
}
switch DeclSyntax(statement.item) {
Expand Down Expand Up @@ -174,7 +199,7 @@ extension EnumeratorMacroType: MemberMacro {
let excessiveTriviaRemover = ExcessiveTriviaRemover()
processedSyntax = excessiveTriviaRemover.rewrite(processedSyntax)

let switchRewriter = SwitchRewriter()
let switchRewriter = SwitchWarningsRewriter()
processedSyntax = switchRewriter.rewrite(processedSyntax)

guard let declSyntax = DeclSyntax(processedSyntax) else {
Expand All @@ -194,3 +219,9 @@ extension EnumeratorMacroType: MemberMacro {
return postProcessedSyntaxes
}
}

private extension [Diagnostic] {
var containsError: Bool {
self.contains(where: { $0.diagMessage.severity == .error })
}
}
47 changes: 47 additions & 0 deletions Sources/EnumeratorMacroImpl/SwitchErrorsRewriter.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import SwiftSyntax

final class SwitchErrorsRewriter: SyntaxRewriter {
override func visit(_ node: SwitchCaseSyntax) -> SwitchCaseSyntax {
guard let label = node.label.as(SwitchCaseLabelSyntax.self) else {
return node
}
var items = label.caseItems
guard items.count > 1 else {
return node
}

let lastIndex = items.lastIndex(where: { _ in true })!
var lastIsAMissingExpr: Bool {
if let pattern = items[lastIndex].pattern.as(ExpressionPatternSyntax.self),
pattern.expression.is(MissingExprSyntax.self) {
return true
} else {
return false
}
}

var oneToLastContainsTrailingComma: Bool {
items[items.index(before: lastIndex)].trailingComma != nil
}

if lastIsAMissingExpr,
oneToLastContainsTrailingComma {

items[items.index(before: lastIndex)].trailingComma = nil
items.remove(at: lastIndex)

let node = node.with(
\.label,
SwitchCaseSyntax.Label(
label.with(
\.caseItems,
items
)
)
)
return node
} else {
return node
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import SwiftSyntax

final class SwitchRewriter: SyntaxRewriter {
final class SwitchWarningsRewriter: SyntaxRewriter {
override func visit(_ node: SwitchCaseSyntax) -> SwitchCaseSyntax {
self.removeUnusedLet(
self.removeUnusedArguments(
Expand Down
40 changes: 39 additions & 1 deletion Tests/EnumeratorMacroTests/EnumeratorMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func removesExcessiveTrivia() {
func testRemovesExcessiveTrivia() {
assertMacroExpansion(
#"""
@Enumerator(
Expand Down Expand Up @@ -337,6 +337,44 @@ final class EnumeratorMacroTests: XCTestCase {
)
}

func testRemovesLastErroneousCommaInCaseSwitch() {
assertMacroExpansion(
#"""
@Enumerator(
"""
public var constant: String {
switch self {
case {{#cases}}.{{name}}, {{/cases}}:
"some constant"
}
}
"""
)
enum TestEnum {
case a
case b
}
"""#,
/// It usually contain `case .a, .b,:` which is an error
/// because `.b` has a trailing comma after it.
/// But the macro should recover from this situation:
expandedSource: #"""
enum TestEnum {
case a
case b
public var constant: String {
switch self {
case .a, .b:
"some constant"
}
}
}
"""#,
macros: EnumeratorMacroEntryPoint.macros
)
}


func testDiagnosesNotAnEnum() {
assertMacroExpansion(
Expand Down

0 comments on commit 9de9afa

Please sign in to comment.