Skip to content

Commit

Permalink
Enable generic types to be instantiated (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfed authored Jun 23, 2024
1 parent 4f4ea97 commit 2c3af48
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 54 deletions.
2 changes: 1 addition & 1 deletion Sources/SafeDI/PropertyDecoration/Instantiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
/// }
///
/// An extension declaration decorated with `@Instantiable` makes the extended type capable of having properties of other `@Instantiable`-decorated types injected into it. Decorating extensions with `@Instantiable` enables third-party types to be instantiated by the SafeDI system.
/// Usage of this macro requires the extension to implement a method `public static instantiate() -> ExtendedType` that defines the instantiation logic for the externally defined type.
/// Usage of this macro requires the extension to implement a method `public static func instantiate() -> ExtendedType` that defines the instantiation logic for the externally defined type.
///
/// Example:
///
Expand Down
4 changes: 2 additions & 2 deletions Sources/SafeDICore/Errors/FixableInstantiableError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public enum FixableInstantiableError: DiagnosticError {
case .disallowedEffectSpecifiers:
"@\(InstantiableVisitor.macroName)-decorated extension’s `instantiate()` method must not throw or be async"
case .incorrectReturnType:
"@\(InstantiableVisitor.macroName)-decorated extension’s `instantiate()` method must return the same type as the extended type"
"@\(InstantiableVisitor.macroName)-decorated extension’s `instantiate()` method must return the same base type as the extended type"
case .disallowedGenericWhereClause:
"@\(InstantiableVisitor.macroName)-decorated extension must not have a generic `where` clause"
case .dependencyHasTooManyAttributes:
Expand Down Expand Up @@ -119,7 +119,7 @@ public enum FixableInstantiableError: DiagnosticError {
case .disallowedEffectSpecifiers:
"Remove effect specifiers"
case .incorrectReturnType:
"Make `instantiate()`’s return type the same as the extended type"
"Make `instantiate()`’s return type the same base type as the extended type"
case .disallowedGenericWhereClause:
"Remove generic `where` clause"
case .dependencyHasTooManyAttributes:
Expand Down
11 changes: 11 additions & 0 deletions Sources/SafeDICore/Models/TypeDescription.swift
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,17 @@ public enum TypeDescription: Codable, Hashable, Comparable, Sendable {
}
}

var strippingGenerics: TypeDescription {
switch self {
case let .simple(name, _):
.simple(name: name, generics: [])
case let .nested(name, parentType, _):
.nested(name: name, parentType: parentType, generics: [])
case .array, .attributed, .any, .closure, .composition, .dictionary, .implicitlyUnwrappedOptional, .metatype, .optional, .some, .tuple, .unknown, .void:
self
}
}

/// The receiver as an `@Instantiable` type.
var asInstantiatedType: TypeDescription {
switch self {
Expand Down
4 changes: 2 additions & 2 deletions Sources/SafeDICore/Visitors/FileVisitor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public final class FileVisitor: SyntaxVisitor {
public override func visit(_ node: ExtensionDeclSyntax) -> SyntaxVisitorContinueKind {
let instantiableVisitor = InstantiableVisitor(declarationType: .extensionDecl)
instantiableVisitor.walk(node)
if let instantiable = instantiableVisitor.instantiable {
for instantiable in instantiableVisitor.instantiables {
instantiables.append(instantiable)
}

Expand Down Expand Up @@ -124,7 +124,7 @@ public final class FileVisitor: SyntaxVisitor {

let instantiableVisitor = InstantiableVisitor(declarationType: .concreteDecl)
instantiableVisitor.walk(node)
if let instantiable = instantiableVisitor.instantiable {
for instantiable in instantiableVisitor.instantiables {
instantiables.append(instantiable)
}

Expand Down
73 changes: 35 additions & 38 deletions Sources/SafeDICore/Visitors/InstantiableVisitor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,10 @@ public final class InstantiableVisitor: SyntaxVisitor {
}

public override func visit(_ node: ExtensionDeclSyntax) -> SyntaxVisitorContinueKind {
guard declarationType.isExtension else {
return .skipChildren
}
guard let instantiableMacro = node.attributes.instantiableMacro else {
// Not an instantiable type. We do not care.
return .skipChildren
}

instantiableType = node.extendedType.typeDescription
processAttributes(node.attributes, on: instantiableMacro)
if let instantiableMacro = node.attributes.instantiableMacro {
processAttributes(node.attributes, on: instantiableMacro)
}

return .visitChildren
}
Expand All @@ -163,7 +157,7 @@ public final class InstantiableVisitor: SyntaxVisitor {

if
let returnClause = node.signature.returnClause,
returnClause.type.typeDescription != instantiableType,
returnClause.type.typeDescription.strippingGenerics != instantiableType?.strippingGenerics,
let instantiableType
{
var modifiedSignature = node.signature
Expand All @@ -174,7 +168,7 @@ public final class InstantiableVisitor: SyntaxVisitor {
),
type: IdentifierTypeSyntax(
leadingTrivia: node.signature.leadingTrivia,
name: .identifier(instantiableType.asSource),
name: .identifier(instantiableType.strippingGenerics.asSource),
trailingTrivia: node.signature.trailingTrivia
)
)
Expand All @@ -191,13 +185,19 @@ public final class InstantiableVisitor: SyntaxVisitor {
}

let initializer = Initializer(node)
initializers.append(initializer)
// We should only have a single `instantiate` method, so we set rather than append to dependencies.
dependencies = initializer.arguments.map {
Dependency(
property: $0.asProperty,
source: .received
)
if let instantiableType = node.signature.returnClause?.type.typeDescription {
extensionInstantiables.append(.init(
instantiableType: instantiableType,
initializer: initializer,
additionalInstantiables: additionalInstantiables,
dependencies: initializer.arguments.map {
Dependency(
property: $0.asProperty,
source: .received
)
},
declarationType: .extensionType
))
}

if !initializer.isPublicOrOpen || node.modifiers.staticModifier == nil {
Expand Down Expand Up @@ -317,29 +317,24 @@ public final class InstantiableVisitor: SyntaxVisitor {
}
}

// MARK: Internal

var instantiable: Instantiable? {
guard let instantiableType else { return nil }
public var instantiables: [Instantiable] {
switch declarationType {
case .concreteDecl:
guard let topLevelDeclarationType else { return nil }
return Instantiable(
instantiableType: instantiableType,
initializer: initializers.first(where: { $0.isValid(forFulfilling: dependencies) }) ?? initializerToGenerate(),
additionalInstantiables: additionalInstantiables,
dependencies: dependencies,
declarationType: topLevelDeclarationType.asDeclarationType
)
if let instantiableType, let topLevelDeclarationType {
[
Instantiable(
instantiableType: instantiableType,
initializer: initializers.first(where: { $0.isValid(forFulfilling: dependencies) }) ?? initializerToGenerate(),
additionalInstantiables: additionalInstantiables,
dependencies: dependencies,
declarationType: topLevelDeclarationType.asDeclarationType
),
]
} else {
[]
}
case .extensionDecl:
return Instantiable(
instantiableType: instantiableType,
// If we have more than one initializer this isn't a valid extension.
initializer: initializers.count > 1 ? nil : initializers.first,
additionalInstantiables: additionalInstantiables,
dependencies: dependencies,
declarationType: .extensionType
)
extensionInstantiables
}
}

Expand All @@ -348,6 +343,8 @@ public final class InstantiableVisitor: SyntaxVisitor {
private var isInTopLevelDeclaration = false
private var topLevelDeclarationType: ConcreteDeclType?

private var extensionInstantiables = [Instantiable]()

private let declarationType: DeclarationType

private func visitDecl(_ node: some ConcreteDeclSyntaxProtocol) -> SyntaxVisitorContinueKind {
Expand Down
8 changes: 4 additions & 4 deletions Sources/SafeDIMacros/Macros/InstantiableMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,10 @@ public struct InstantiableMacro: MemberMacro {
context.diagnose(diagnostic)
}

let initializersCount = visitor.initializers.count
if initializersCount > 1 {
let instantiableTypeCount = visitor.instantiables.map(\.instantiableTypes).count
if instantiableTypeCount > 1 {
throw InstantiableError.tooManyInstantiateMethods
} else if initializersCount == 0 {
} else if instantiableTypeCount == 0 {
let extendedTypeName = extensionDeclaration.extendedType.typeDescription.asSource
var membersWithInitializer = declaration.memberBlock.members
membersWithInitializer.insert(
Expand Down Expand Up @@ -348,7 +348,7 @@ public struct InstantiableMacro: MemberMacro {
case .fulfillingAdditionalTypesArgumentInvalid:
"The argument `fulfillingAdditionalTypes` must be an inlined array"
case .tooManyInstantiateMethods:
"@\(InstantiableVisitor.macroName)-decorated extension must have a single `instantiate()` method"
"@\(InstantiableVisitor.macroName)-decorated extension must have a single `instantiate()` method per return type"
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions Tests/SafeDIMacrosTests/InstantiableMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ import SafeDICore
}
}

func test_extension_throwsErrorWhenMoreThanOneInstantiateMethod() {
func test_extension_throwsErrorWhenMoreThanOneInstantiateMethodForSameType() {
assertMacro {
"""
@Instantiable
Expand All @@ -910,7 +910,7 @@ import SafeDICore
"""
@Instantiable
┬────────────
╰─ 🛑 @Instantiable-decorated extension must have a single `instantiate()` method
╰─ 🛑 @Instantiable-decorated extension must have a single `instantiate()` method per return type
extension ExampleService: Instantiable {
public static func instantiate() -> ExampleService { fatalError() }
public static func instantiate(user: User) -> ExampleService { fatalError() }
Expand Down Expand Up @@ -1883,8 +1883,8 @@ import SafeDICore
extension ExampleService: Instantiable {
public static func instantiate() -> OtherExampleService { fatalError() }
┬───────────────────────────────────────────────────────────────────────
╰─ 🛑 @Instantiable-decorated extension’s `instantiate()` method must return the same type as the extended type
✏️ Make `instantiate()`’s return type the same as the extended type
╰─ 🛑 @Instantiable-decorated extension’s `instantiate()` method must return the same base type as the extended type
✏️ Make `instantiate()`’s return type the same base type as the extended type
}
"""
} fixes: {
Expand Down
118 changes: 115 additions & 3 deletions Tests/SafeDIToolTests/SafeDIToolTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4139,6 +4139,118 @@ final class SafeDIToolTests: XCTestCase {
)
}

func test_run_writesConvenienceExtensionOnRootOfTree_whenAGenericTypeIsAnExtendedInstantiableWithMultipleGenericReturnTypes() async throws {
let output = try await executeSystemUnderTest(
swiftFileContent: [
"""
@Instantiable
public struct Root {
@Instantiated let stringContainer: Container<String>
@Instantiated let intContainer: Container<Int>
@Instantiated let floatContainer: Container<Float>
@Instantiated let voidContainer: Container<Void>
}
""",
"""
public struct Container<T> {
let value: T
}
@Instantiable
extension Container<T>: Instantiable {
public static func instantiate() -> Container<String> {
.init(value: "")
}
public static func instantiate() -> Container<Int> {
.init(value: 0)
}
public static func instantiate() -> Container<Float> {
.init(value: 0)
}
public static func instantiate() -> Container<Void> {
.init(value: ())
}
}
""",
],
buildDependencyTreeOutput: true
)

XCTAssertEqual(
try XCTUnwrap(output.dependencyTree),
"""
// This file was generated by the SafeDIGenerateDependencyTree build tool plugin.
// Any modifications made to this file will be overwritten on subsequent builds.
// Please refrain from editing this file directly.
extension Root {
public init() {
let stringContainer = Container<String>.instantiate()
let intContainer = Container<Int>.instantiate()
let floatContainer = Container<Float>.instantiate()
let voidContainer = Container<Void>.instantiate()
self.init(stringContainer: stringContainer, intContainer: intContainer, floatContainer: floatContainer, voidContainer: voidContainer)
}
}
"""
)
}

func test_run_writesConvenienceExtensionOnRootOfTree_whenAGenericTypeIsAnExtendedInstantiableWithMultipleGenericFullyQualifiedReturnTypes() async throws {
let output = try await executeSystemUnderTest(
swiftFileContent: [
"""
@Instantiable
public struct Root {
@Instantiated let stringContainer: MyModule.Container<String>
@Instantiated let intContainer: MyModule.Container<Int>
@Instantiated let floatContainer: MyModule.Container<Float>
@Instantiated let voidContainer: MyModule.Container<Void>
}
""",
"""
public struct Container<T> {
let value: T
}
@Instantiable
extension MyModule.Container<T>: Instantiable {
public static func instantiate() -> MyModule.Container<String> {
.init(value: "")
}
public static func instantiate() -> MyModule.Container<Int> {
.init(value: 0)
}
public static func instantiate() -> MyModule.Container<Float> {
.init(value: 0)
}
public static func instantiate() -> MyModule.Container<Void> {
.init(value: ())
}
}
""",
],
buildDependencyTreeOutput: true
)

XCTAssertEqual(
try XCTUnwrap(output.dependencyTree),
"""
// This file was generated by the SafeDIGenerateDependencyTree build tool plugin.
// Any modifications made to this file will be overwritten on subsequent builds.
// Please refrain from editing this file directly.
extension Root {
public init() {
let stringContainer = MyModule.Container<String>.instantiate()
let intContainer = MyModule.Container<Int>.instantiate()
let floatContainer = MyModule.Container<Float>.instantiate()
let voidContainer = MyModule.Container<Void>.instantiate()
self.init(stringContainer: stringContainer, intContainer: intContainer, floatContainer: floatContainer, voidContainer: voidContainer)
}
}
"""
)
}

// MARK: Error Tests

func test_run_onCodeWithPropertyWithUnknownFulfilledType_throwsError() async {
Expand Down Expand Up @@ -4659,7 +4771,7 @@ final class SafeDIToolTests: XCTestCase {
@Instantiable
extension RootViewController: Instantiable {
public static instantiate() {
public static func instantiate() -> RootViewController {
RootViewController()
}
}
Expand All @@ -4683,7 +4795,7 @@ final class SafeDIToolTests: XCTestCase {
@Instantiable
extension UserDefaults: Instantiable {
public static instantiate() {
public static func instantiate() -> UserDefaults {
.standard
}
}
Expand All @@ -4693,7 +4805,7 @@ final class SafeDIToolTests: XCTestCase {
@Instantiable
extension UserDefaults: Instantiable {
public static instantiate(suiteName: String) {
public static func instantiate(suiteName: String) -> UserDefaults {
UserDefaults(suiteName: suiteName)
}
}
Expand Down

0 comments on commit 2c3af48

Please sign in to comment.