From d09b5527fb0c341f6993e4f5233fadbb7dee1c76 Mon Sep 17 00:00:00 2001 From: Tanner Date: Fri, 13 Dec 2019 17:01:47 -0500 Subject: [PATCH] model coding (#79) * add SQLRow decoding support * change spelling to model:, add update builder helper * row coder updates * fix tests --- Sources/SQLKit/Builders/SQLQueryFetcher.swift | 36 ++++++++ .../SQLKit/Builders/SQLUpdateBuilder.swift | 12 ++- Sources/SQLKit/SQLRow.swift | 11 +++ Sources/SQLKit/SQLRowDecoder.swift | 92 +++++++++++++++++++ Tests/SQLKitTests/SQLKitTests.swift | 83 +++++++++++++++++ 5 files changed, 231 insertions(+), 3 deletions(-) create mode 100644 Sources/SQLKit/SQLRowDecoder.swift diff --git a/Sources/SQLKit/Builders/SQLQueryFetcher.swift b/Sources/SQLKit/Builders/SQLQueryFetcher.swift index 97deac3f..9674d845 100644 --- a/Sources/SQLKit/Builders/SQLQueryFetcher.swift +++ b/Sources/SQLKit/Builders/SQLQueryFetcher.swift @@ -8,6 +8,18 @@ public protocol SQLQueryFetcher: SQLQueryBuilder { } extension SQLQueryFetcher { // MARK: First + + + public func first(decoding: D.Type) -> EventLoopFuture + where D: Decodable + { + self.first().flatMapThrowing { + guard let row = $0 else { + return nil + } + return try row.decode(model: D.self) + } + } /// Collects the first raw output and returns it. /// @@ -18,6 +30,17 @@ extension SQLQueryFetcher { } // MARK: All + + + public func all(decoding: D.Type) -> EventLoopFuture<[D]> + where D: Decodable + { + self.all().flatMapThrowing { + try $0.map { + try $0.decode(model: D.self) + } + } + } /// Collects all raw output into an array and returns it. /// @@ -31,6 +54,19 @@ extension SQLQueryFetcher { } // MARK: Run + + + public func run(decoding: D.Type, _ handler: @escaping (Result) -> ()) -> EventLoopFuture + where D: Decodable + { + self.run { + do { + try handler(.success($0.decode(model: D.self))) + } catch { + handler(.failure(error)) + } + } + } /// Runs the query, passing output to the supplied closure as it is recieved. diff --git a/Sources/SQLKit/Builders/SQLUpdateBuilder.swift b/Sources/SQLKit/Builders/SQLUpdateBuilder.swift index 72446e90..43aef3e3 100644 --- a/Sources/SQLKit/Builders/SQLUpdateBuilder.swift +++ b/Sources/SQLKit/Builders/SQLUpdateBuilder.swift @@ -29,11 +29,17 @@ public final class SQLUpdateBuilder: SQLQueryBuilder, SQLPredicateBuilder { self.update = update self.database = database } + + public func set(model: E) throws -> Self where E: Encodable { + let row = try SQLQueryEncoder().encode(model) + row.forEach { column, value in + _ = self.set(SQLColumn(column), to: value) + } + return self + } /// Sets a column (specified by an identifier) to an expression. - public func set(_ column: String, to bind: T) -> Self - where T: Encodable - { + public func set(_ column: String, to bind: Encodable) -> Self { return self.set(SQLIdentifier(column), to: SQLBind(bind)) } diff --git a/Sources/SQLKit/SQLRow.swift b/Sources/SQLKit/SQLRow.swift index 23bac3ab..cf770091 100644 --- a/Sources/SQLKit/SQLRow.swift +++ b/Sources/SQLKit/SQLRow.swift @@ -1,4 +1,15 @@ public protocol SQLRow { + var allColumns: [String] { get } + func contains(column: String) -> Bool + func decodeNil(column: String) throws -> Bool func decode(column: String, as type: D.Type) throws -> D where D: Decodable } + +extension SQLRow { + public func decode(model type: D.Type, prefix: String? = nil) throws -> D + where D: Decodable + { + try SQLRowDecoder().decode(D.self, from: self, prefix: prefix) + } +} diff --git a/Sources/SQLKit/SQLRowDecoder.swift b/Sources/SQLKit/SQLRowDecoder.swift new file mode 100644 index 00000000..8793e4d3 --- /dev/null +++ b/Sources/SQLKit/SQLRowDecoder.swift @@ -0,0 +1,92 @@ +struct SQLRowDecoder { + func decode(_ type: T.Type, from row: SQLRow, prefix: String? = nil) throws -> T + where T: Decodable + { + return try T.init(from: _Decoder(prefix: prefix, row: row)) + } + + enum _Error: Error { + case nesting + case unkeyedContainer + case singleValueContainer + } + + struct _Decoder: Decoder { + let prefix: String? + let row: SQLRow + var codingPath: [CodingKey] = [] + var userInfo: [CodingUserInfoKey : Any] { + [:] + } + + func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer + where Key: CodingKey + { + .init(_KeyedDecoder(prefix: self.prefix, row: self.row, codingPath: self.codingPath)) + } + + func unkeyedContainer() throws -> UnkeyedDecodingContainer { + throw _Error.unkeyedContainer + } + + func singleValueContainer() throws -> SingleValueDecodingContainer { + throw _Error.singleValueContainer + } + } + + struct _KeyedDecoder: KeyedDecodingContainerProtocol + where Key: CodingKey + { + let prefix: String? + let row: SQLRow + var codingPath: [CodingKey] = [] + var allKeys: [Key] { + self.row.allColumns.compactMap { + Key.init(stringValue: $0) + } + } + + func column(for key: Key) -> String { + if let prefix = self.prefix { + return prefix + key.stringValue + } else { + return key.stringValue + } + } + + func contains(_ key: Key) -> Bool { + self.row.contains(column: self.column(for: key)) + } + + func decodeNil(forKey key: Key) throws -> Bool { + try self.row.decodeNil(column: self.column(for: key)) + } + + func decode(_ type: T.Type, forKey key: Key) throws -> T + where T : Decodable + { + try self.row.decode(column: self.column(for: key), as: T.self) + } + + func nestedContainer( + keyedBy type: NestedKey.Type, + forKey key: Key + ) throws -> KeyedDecodingContainer + where NestedKey : CodingKey + { + throw _Error.nesting + } + + func nestedUnkeyedContainer(forKey key: Key) throws -> UnkeyedDecodingContainer { + throw _Error.nesting + } + + func superDecoder() throws -> Decoder { + _Decoder(prefix: self.prefix, row: self.row, codingPath: self.codingPath) + } + + func superDecoder(forKey key: Key) throws -> Decoder { + throw _Error.nesting + } + } +} diff --git a/Tests/SQLKitTests/SQLKitTests.swift b/Tests/SQLKitTests/SQLKitTests.swift index 189e3f81..a5250b0e 100644 --- a/Tests/SQLKitTests/SQLKitTests.swift +++ b/Tests/SQLKitTests/SQLKitTests.swift @@ -247,4 +247,87 @@ CREATE TABLE `planets`(`id` BIGINT, `name` TEXT, `diameter` INTEGER, `galaxy_nam XCTAssertEqual(db.results[2], "CREATE TABLE `planets3`(`galaxy_id` BIGINT, FOREIGN KEY (`galaxy_id`) REFERENCES `galaxies` (`id`) ON DELETE RESTRICT ON UPDATE CASCADE)") } + + func testSQLRowDecoder() throws { + struct Foo: Codable { + let id: UUID + let foo: Int + let bar: Double? + let baz: String + } + + do { + let row = TestRow(data: [ + "id": UUID(), + "foo": 42, + "bar": Double?.none as Any, + "baz": "vapor" + ]) + + let foo = try row.decode(model: Foo.self) + XCTAssertEqual(foo.foo, 42) + XCTAssertEqual(foo.bar, nil) + XCTAssertEqual(foo.baz, "vapor") + } + do { + let row = TestRow(data: [ + "foos_id": UUID(), + "foos_foo": 42, + "foos_bar": Double?.none as Any, + "foos_baz": "vapor" + ]) + + let foo = try row.decode(model: Foo.self, prefix: "foos_") + XCTAssertEqual(foo.foo, 42) + XCTAssertEqual(foo.bar, nil) + XCTAssertEqual(foo.baz, "vapor") + } + } +} + +struct TestRow: SQLRow { + var data: [String: Any] + + enum _Error: Error { + case missingColumn(String) + case typeMismatch(Any, Any.Type) + } + + var allColumns: [String] { + .init(self.data.keys) + } + + func contains(column: String) -> Bool { + self.data.keys.contains(column) + } + + func decodeNil(column: String) throws -> Bool { + if let value = self.data[column], let optional = value as? OptionalType { + return optional.isNil + } else { + return false + } + } + + func decode(column: String, as type: D.Type) throws -> D + where D : Decodable + { + guard let value = self.data[column] else { + throw _Error.missingColumn(column) + } + guard let cast = value as? D else { + throw _Error.typeMismatch(value, D.self) + } + return cast + } +} + +protocol OptionalType { + var isNil: Bool { get } +} + +extension Optional: OptionalType { + var isNil: Bool { + self == nil + } }