Skip to content

Commit

Permalink
Add support for UNION queries (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
finestructure authored Dec 6, 2021
1 parent bfcaa63 commit d2027b4
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
60 changes: 60 additions & 0 deletions Sources/SQLKit/Builders/SQLUnionBuilder.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
public final class SQLUnionBuilder: SQLQueryBuilder {
public var query: SQLExpression { self.union }

public var union: SQLUnion
public var database: SQLDatabase

public init(on database: SQLDatabase, initialQuery: SQLSelect) {
self.union = .init(initialQuery: initialQuery)
self.database = database
}

public func union(distinct query: SQLSelect) -> Self {
self.union.add(query, all: false)
return self
}

public func union(all query: SQLSelect) -> Self {
self.union.add(query, all: true)
return self
}
}

extension SQLDatabase {
public func union(_ predicate: (SQLSelectBuilder) -> SQLSelectBuilder) -> SQLUnionBuilder {
return SQLUnionBuilder(on: self, initialQuery: predicate(.init(on: self)).select)
}
}

extension SQLUnionBuilder {
public func union(distinct predicate: (SQLSelectBuilder) -> SQLSelectBuilder) -> Self {
return self.union(distinct: predicate(.init(on: self.database)).select)
}

/// Alias the `distinct` variant so it acts as the "default".
public func union(_ predicate: (SQLSelectBuilder) -> SQLSelectBuilder) -> Self {
return self.union(distinct: predicate)
}

public func union(all predicate: (SQLSelectBuilder) -> SQLSelectBuilder) -> Self {
return self.union(all: predicate(.init(on: self.database)).select)
}
}

extension SQLSelectBuilder {
public func union(distinct predicate: (SQLSelectBuilder) -> SQLSelectBuilder) -> SQLUnionBuilder {
return SQLUnionBuilder(on: self.database, initialQuery: self.select)
.union(distinct: predicate(.init(on: self.database)).select)
}

/// Alias the `distinct` variant so it acts as the "default".
public func union(_ predicate: (SQLSelectBuilder) -> SQLSelectBuilder) -> SQLUnionBuilder {
return SQLUnionBuilder(on: self.database, initialQuery: self.select)
.union(distinct: predicate(.init(on: self.database)).select)
}

public func union(all predicate: (SQLSelectBuilder) -> SQLSelectBuilder) -> SQLUnionBuilder {
return SQLUnionBuilder(on: self.database, initialQuery: self.select)
.union(all: predicate(.init(on: self.database)).select)
}
}
36 changes: 36 additions & 0 deletions Sources/SQLKit/Query/SQLUnion.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
public struct SQLUnion: SQLExpression {
public var initialQuery: SQLSelect
public var unions: [(SQLUnionJoiner, SQLSelect)]

public init(initialQuery: SQLSelect, unions: [(SQLUnionJoiner, SQLSelect)] = []) {
self.initialQuery = initialQuery
self.unions = unions
}

public mutating func add(_ query: SQLSelect, all: Bool) {
self.unions.append((.init(all: all), query))
}

public func serialize(to serializer: inout SQLSerializer) {
assert(!self.unions.isEmpty, "Serializing a union with only one query is invalid.")
SQLGroupExpression(self.initialQuery).serialize(to: &serializer)
self.unions
.forEach { (joiner, select) in
joiner.serialize(to: &serializer)
SQLGroupExpression(select).serialize(to: &serializer)
}
}
}

public struct SQLUnionJoiner: SQLExpression {
public var all: Bool

public init(all: Bool) {
self.all = all
}

public func serialize(to serializer: inout SQLSerializer) {
serializer.write(" UNION\(self.all ? " ALL" : "") ")
}
}

36 changes: 36 additions & 0 deletions Tests/SQLKitTests/SQLKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -829,4 +829,40 @@ CREATE TABLE `planets`(`id` BIGINT, `name` TEXT, `diameter` INTEGER, `galaxy_nam
XCTFail("Could not decode row with keyDecodingStrategy \(error)")
}
}

func testUnion() throws {
try db.select()
.column("id")
.from("t1")
.where("f1", .equal, "foo")
.limit(1)
.union({
$0.column("id")
.from("t2")
.where("f2", .equal, "bar")
.limit(2)
}).union({
$0.column("id")
.from("t3")
.where("f3", .equal, "baz")
.limit(3)
})
.run().wait()

XCTAssertEqual(db.results[0], "(SELECT `id` FROM `t1` WHERE `f1` = ? LIMIT 1) UNION (SELECT `id` FROM `t2` WHERE `f2` = ? LIMIT 2) UNION (SELECT `id` FROM `t3` WHERE `f3` = ? LIMIT 3)")
}

func testUnionAll() throws {
try db.select()
.column("id")
.from("t1")
.union(all: {
$0.column("id")
.from("t2")
})
.run().wait()

XCTAssertEqual(db.results[0], "(SELECT `id` FROM `t1`) UNION ALL (SELECT `id` FROM `t2`)")
}

}

0 comments on commit d2027b4

Please sign in to comment.