Skip to content

Commit e2122d7

Browse files
Add support for Void-returning database functions (#155)
* Allow for void returning database functions. * wip * wip * fix --------- Co-authored-by: Brandon Williams <[email protected]>
1 parent 14f79c6 commit e2122d7

File tree

4 files changed

+131
-77
lines changed

4 files changed

+131
-77
lines changed

Sources/StructuredQueriesSQLite/Macros.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,22 @@ public macro DatabaseFunction<each T: QueryBindable, R: QueryBindable>(
3535
module: "StructuredQueriesSQLiteMacros",
3636
type: "DatabaseFunctionMacro"
3737
)
38+
39+
/// Defines and implements a conformance to the ``/StructuredQueriesSQLiteCore/DatabaseFunction``
40+
/// protocol.
41+
///
42+
/// - Parameters
43+
/// - name: The function's name. Defaults to the name of the function the macro is applied to.
44+
/// - representableFunctionType: The function as represented in a query.
45+
/// - isDeterministic: Whether or not the function is deterministic (or "pure" or "referentially
46+
/// transparent"), _i.e._ given an input it will always return the same output.
47+
@attached(peer, names: overloaded, prefixed(`$`))
48+
public macro DatabaseFunction<each T: QueryBindable>(
49+
_ name: String = "",
50+
as representableFunctionType: ((repeat each T) -> Void).Type,
51+
isDeterministic: Bool = false
52+
) =
53+
#externalMacro(
54+
module: "StructuredQueriesSQLiteMacros",
55+
type: "DatabaseFunctionMacro"
56+
)

Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,11 @@ extension DatabaseFunctionMacro: PeerMacro {
2525
return []
2626
}
2727

28-
guard declaration.signature.returnClause != nil else {
29-
context.diagnose(
30-
Diagnostic(
31-
node: declaration.signature,
32-
position: declaration.signature.endPositionBeforeTrailingTrivia,
33-
message: MacroExpansionErrorMessage(
34-
"Missing required return type"
35-
),
36-
fixIt: .replace(
37-
message: MacroExpansionFixItMessage("Insert '-> <#QueryBindable#>'"),
38-
oldNode: declaration.signature,
39-
newNode: declaration.signature.with(
40-
\.returnClause,
41-
ReturnClauseSyntax(
42-
type: IdentifierTypeSyntax(name: "<#QueryBindable#>")
43-
.with(\.leadingTrivia, .space)
44-
.with(\.trailingTrivia, .space)
45-
)
46-
)
47-
)
48-
)
28+
let returnClause =
29+
declaration.signature.returnClause
30+
?? ReturnClauseSyntax(
31+
type: "Swift.Void" as TypeSyntax
4932
)
50-
return []
51-
}
52-
5333
let declarationName = declaration.name.trimmedDescription.trimmingBackticks()
5434
var functionName = declarationName
5535
var functionRepresentation: FunctionTypeSyntax?
@@ -158,43 +138,45 @@ extension DatabaseFunctionMacro: PeerMacro {
158138
argumentBindings.append((parameterName, "\(type)(queryBinding: arguments[\(offset)])"))
159139
}
160140
var inputType = bodyArguments.joined(separator: ", ")
161-
let bodyReturnClause: String
162-
let outputType: TypeSyntax
163-
if let returnClause = signature.returnClause {
164-
outputType = returnClause.type.trimmed
165-
signature.returnClause?.type = (functionRepresentation?.returnClause ?? returnClause).type
166-
.asQueryExpression()
167-
bodyReturnClause = " \(returnClause.trimmedDescription)"
168-
} else {
169-
outputType = "Void"
170-
bodyReturnClause = " -> Void"
171-
}
141+
let isVoidReturning = signature.returnClause == nil
142+
let outputType = returnClause.type.trimmed
143+
signature.returnClause = returnClause
144+
signature.returnClause?.type = (functionRepresentation?.returnClause ?? returnClause).type
145+
.asQueryExpression()
146+
let bodyReturnClause = " \(returnClause.trimmedDescription)"
172147
let bodyType = """
173148
(\(inputType))\
174149
\(declaration.signature.effectSpecifiers?.trimmedDescription ?? "")\
175150
\(bodyReturnClause)
176151
"""
152+
let bodyInvocation = """
153+
\(declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")self.body(\
154+
\(argumentBindings.map { name, _ in "\(name).queryOutput" }.joined(separator: ", "))\
155+
)
156+
"""
177157
// TODO: Diagnose 'asyncClause'?
178158
signature.effectSpecifiers?.throwsClause = nil
179159

180-
var invocationBody = """
181-
\(functionRepresentation?.returnClause.type ?? outputType)(
182-
queryOutput: self.body(\
183-
\(argumentBindings.map { name, _ in "\(name).queryOutput" }.joined(separator: ", "))\
184-
)
160+
var invocationBody =
161+
isVoidReturning
162+
? """
163+
\(bodyInvocation)
164+
return .null
165+
"""
166+
: """
167+
return \(functionRepresentation?.returnClause.type ?? outputType)(
168+
queryOutput: \(bodyInvocation)
185169
)
186170
.queryBinding
187171
"""
188172
if declaration.signature.effectSpecifiers?.throwsClause != nil {
189173
invocationBody = """
190174
do {
191-
return try \(invocationBody)
175+
\(invocationBody)
192176
} catch {
193177
return .invalid(error)
194178
}
195179
"""
196-
} else {
197-
invocationBody = "return \(invocationBody)"
198180
}
199181

200182
var attributes = declaration.attributes

Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,8 @@ extension SnapshotTests {
572572
return .invalid(InvalidInvocation())
573573
}
574574
do {
575-
return try Date(
576-
queryOutput: self.body()
575+
return Date(
576+
queryOutput: try self.body()
577577
)
578578
.queryBinding
579579
} catch {
@@ -627,8 +627,8 @@ extension SnapshotTests {
627627
return .invalid(InvalidInvocation())
628628
}
629629
do {
630-
return try Date(
631-
queryOutput: self.body()
630+
return Date(
631+
queryOutput: try self.body()
632632
)
633633
.queryBinding
634634
} catch {
@@ -869,35 +869,64 @@ extension SnapshotTests {
869869
}
870870
}
871871

872-
@Test func returnTypeDiagnostic() {
872+
@Test func voidReturnType() {
873873
assertMacro {
874874
"""
875875
@DatabaseFunction
876876
public func void() {
877877
print("...")
878878
}
879879
"""
880-
} diagnostics: {
881-
"""
882-
@DatabaseFunction
880+
} expansion: {
881+
#"""
883882
public func void() {
884-
──┬
885-
╰─ 🛑 Missing required return type
886-
✏️ Insert '-> <#QueryBindable#>'
887883
print("...")
888884
}
889-
"""
890-
} fixes: {
885+
886+
public var $void: __macro_local_4voidfMu_ {
887+
__macro_local_4voidfMu_(void)
888+
}
889+
890+
public struct __macro_local_4voidfMu_: StructuredQueriesSQLiteCore.ScalarDatabaseFunction {
891+
public typealias Input = ()
892+
public typealias Output = Swift.Void
893+
public let name = "void"
894+
public let argumentCount: Int? = 0
895+
public let isDeterministic = false
896+
public let body: () -> Swift.Void
897+
public init(_ body: @escaping () -> Swift.Void) {
898+
self.body = body
899+
}
900+
public func callAsFunction() -> some StructuredQueriesCore.QueryExpression<Swift.Void> {
901+
StructuredQueriesCore.SQLQueryExpression(
902+
"\(quote: self.name)()"
903+
)
904+
}
905+
public func invoke(
906+
_ arguments: [StructuredQueriesCore.QueryBinding]
907+
) -> StructuredQueriesCore.QueryBinding {
908+
guard self.argumentCount == nil || self.argumentCount == arguments.count else {
909+
return .invalid(InvalidInvocation())
910+
}
911+
self.body()
912+
return .null
913+
}
914+
private struct InvalidInvocation: Error {
915+
}
916+
}
917+
"""#
918+
}
919+
assertMacro {
891920
"""
892921
@DatabaseFunction
893-
public func void() -> <#QueryBindable#> {
894-
print("...")
922+
public func void() throws {
923+
throw Failure()
895924
}
896925
"""
897926
} expansion: {
898927
#"""
899-
public func void() -> <#QueryBindable#> {
900-
print("...")
928+
public func void() throws {
929+
throw Failure()
901930
}
902931
903932
public var $void: __macro_local_4voidfMu_ {
@@ -906,15 +935,15 @@ extension SnapshotTests {
906935
907936
public struct __macro_local_4voidfMu_: StructuredQueriesSQLiteCore.ScalarDatabaseFunction {
908937
public typealias Input = ()
909-
public typealias Output = <#QueryBindable#>
938+
public typealias Output = Swift.Void
910939
public let name = "void"
911940
public let argumentCount: Int? = 0
912941
public let isDeterministic = false
913-
public let body: () -> <#QueryBindable#>
914-
public init(_ body: @escaping () -> <#QueryBindable#>) {
942+
public let body: () throws -> Swift.Void
943+
public init(_ body: @escaping () throws -> Swift.Void) {
915944
self.body = body
916945
}
917-
public func callAsFunction() -> some StructuredQueriesCore.QueryExpression<<#QueryBindable#>> {
946+
public func callAsFunction() -> some StructuredQueriesCore.QueryExpression<Swift.Void> {
918947
StructuredQueriesCore.SQLQueryExpression(
919948
"\(quote: self.name)()"
920949
)
@@ -925,10 +954,12 @@ extension SnapshotTests {
925954
guard self.argumentCount == nil || self.argumentCount == arguments.count else {
926955
return .invalid(InvalidInvocation())
927956
}
928-
return <#QueryBindable#>(
929-
queryOutput: self.body()
930-
)
931-
.queryBinding
957+
do {
958+
try self.body()
959+
return .null
960+
} catch {
961+
return .invalid(error)
962+
}
932963
}
933964
private struct InvalidInvocation: Error {
934965
}

Tests/StructuredQueriesTests/DatabaseFunctionTests.swift

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ import _StructuredQueriesSQLite
1010

1111
extension SnapshotTests {
1212
@Suite struct DatabaseFunctionTests {
13+
@Dependency(\.defaultDatabase) var database
14+
1315
@DatabaseFunction
1416
func isEnabled() -> Bool {
1517
true
1618
}
1719
@Test func customIsEnabled() {
18-
@Dependency(\.defaultDatabase) var database
1920
$isEnabled.install(database.handle)
2021
assertQuery(
2122
Values($isEnabled())
@@ -37,7 +38,6 @@ extension SnapshotTests {
3738
Date(timeIntervalSince1970: 0)
3839
}
3940
@Test func customDateTime() {
40-
@Dependency(\.defaultDatabase) var database
4141
$dateTime.install(database.handle)
4242
assertQuery(
4343
Values($dateTime())
@@ -59,7 +59,6 @@ extension SnapshotTests {
5959
first + second
6060
}
6161
@Test func customConcat() {
62-
@Dependency(\.defaultDatabase) var database
6362
$concat.install(database.handle)
6463
assertQuery(
6564
Values($concat(first: "foo", second: "bar"))
@@ -77,7 +76,6 @@ extension SnapshotTests {
7776
}
7877

7978
@Test func erasedConcat() {
80-
@Dependency(\.defaultDatabase) var database
8179
$concat.install(database.handle)
8280
assertQuery(
8381
Values($concat("foo", "bar"))
@@ -104,7 +102,6 @@ extension SnapshotTests {
104102
throw Failure()
105103
}
106104
@Test func customThrowing() {
107-
@Dependency(\.defaultDatabase) var database
108105
$throwing.install(database.handle)
109106
assertQuery(
110107
Values($throwing())
@@ -132,7 +129,6 @@ extension SnapshotTests {
132129
completion == .incomplete ? .completing : .incomplete
133130
}
134131
@Test func customToggle() {
135-
@Dependency(\.defaultDatabase) var database
136132
$toggle.install(database.handle)
137133
assertQuery(
138134
Values($toggle(Completion.incomplete))
@@ -155,7 +151,6 @@ extension SnapshotTests {
155151
}
156152

157153
@Test func customRepresentation() {
158-
@Dependency(\.defaultDatabase) var database
159154
$jsonCapitalize.install(database.handle)
160155
assertQuery(
161156
Values($jsonCapitalize(#bind(["hello", "world"])))
@@ -184,7 +179,6 @@ extension SnapshotTests {
184179
}
185180

186181
@Test func customMixedRepresentation() {
187-
@Dependency(\.defaultDatabase) var database
188182
$jsonDropFirst.install(database.handle)
189183
assertQuery(
190184
Values($jsonDropFirst(#bind(["hello", "world", "goodnight", "moon"]), 2))
@@ -215,7 +209,6 @@ extension SnapshotTests {
215209
}
216210

217211
@Test func customNilRepresentation() {
218-
@Dependency(\.defaultDatabase) var database
219212
$jsonCount.install(database.handle)
220213
assertQuery(
221214
Values($jsonCount(#bind(["hello", "world", "goodnight", "moon"])))
@@ -249,5 +242,34 @@ extension SnapshotTests {
249242
"""
250243
}
251244
}
245+
246+
final class Logger {
247+
var messages: [String] = []
248+
249+
@DatabaseFunction
250+
func log(_ message: String) {
251+
messages.append(message)
252+
}
253+
}
254+
255+
@Test func voidState() {
256+
let logger = Logger()
257+
logger.$log.install(database.handle)
258+
259+
assertQuery(
260+
Values(logger.$log("Hello, world!"))
261+
) {
262+
"""
263+
SELECT "log"('Hello, world!')
264+
"""
265+
} results: {
266+
"""
267+
┌──┐
268+
└──┘
269+
"""
270+
}
271+
272+
#expect(logger.messages == ["Hello, world!"])
273+
}
252274
}
253275
}

0 commit comments

Comments
 (0)