Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 112 additions & 73 deletions Sources/StructuredQueriesCore/Operators.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ extension QueryExpression where QueryValue: QueryBindable {
/// - rhs: Another expression to compare.
/// - Returns: A predicate expression.
public static func == (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<Bool> {
lhs.eq(rhs)
}
Expand All @@ -34,27 +35,12 @@ extension QueryExpression where QueryValue: QueryBindable {
/// - rhs: Another expression to compare.
/// - Returns: A predicate expression.
public static func != (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<Bool> {
lhs.neq(rhs)
}

@_disfavoredOverload
@_documentation(visibility: private)
public static func == (
lhs: Self, rhs: some QueryExpression<QueryValue?>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(rhs) ? "IS" : "=", rhs: rhs)
}

@_disfavoredOverload
@_documentation(visibility: private)
public static func != (
lhs: Self, rhs: some QueryExpression<QueryValue?>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(rhs) ? "IS NOT" : "<>", rhs: rhs)
}

/// Returns a predicate expression indicating whether two query expressions are equal.
///
/// ```swift
Expand Down Expand Up @@ -119,34 +105,6 @@ private func isNull<Value>(_ expression: some QueryExpression<Value>) -> Bool {
}

extension QueryExpression where QueryValue: QueryBindable & _OptionalProtocol {
@_documentation(visibility: private)
public static func == (
lhs: Self, rhs: some QueryExpression<QueryValue.Wrapped>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(lhs) ? "IS" : "=", rhs: rhs)
}

@_documentation(visibility: private)
public static func != (
lhs: Self, rhs: some QueryExpression<QueryValue.Wrapped>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(lhs) ? "IS NOT" : "<>", rhs: rhs)
}

@_documentation(visibility: private)
public static func == (
lhs: Self, rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(lhs) || isNull(rhs) ? "IS" : "=", rhs: rhs)
}

@_documentation(visibility: private)
public static func != (
lhs: Self, rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(lhs) || isNull(rhs) ? "IS NOT" : "<>", rhs: rhs)
}

@_documentation(visibility: private)
public func eq(_ other: some QueryExpression<QueryValue.Wrapped>) -> some QueryExpression<Bool> {
BinaryOperator(lhs: self, operator: "=", rhs: other)
Expand Down Expand Up @@ -183,16 +141,6 @@ extension QueryExpression where QueryValue: QueryBindable & _OptionalProtocol {
}

extension QueryExpression where QueryValue: QueryBindable {
@_documentation(visibility: private)
public static func == (lhs: Self, rhs: _Null<QueryValue>) -> some QueryExpression<Bool> {
lhs.is(rhs)
}

@_documentation(visibility: private)
public static func != (lhs: Self, rhs: _Null<QueryValue>) -> some QueryExpression<Bool> {
lhs.isNot(rhs)
}

@_documentation(visibility: private)
public func `is`(
_ other: _Null<QueryValue>
Expand All @@ -217,6 +165,80 @@ extension _Null: ExpressibleByNilLiteral {
public init(nilLiteral: ()) {}
}

// NB: This overload is required due to an overload resolution bug of 'Updates[dynamicMember:]'.
@_disfavoredOverload
@_documentation(visibility: private)
public func == <QueryValue>(
lhs: any QueryExpression<QueryValue>,
rhs: some QueryExpression<QueryValue?>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(rhs) ? "IS" : "=", rhs: rhs)
}

// NB: This overload is required due to an overload resolution bug of 'Updates[dynamicMember:]'.
@_disfavoredOverload
@_documentation(visibility: private)
public func != <QueryValue>(
lhs: any QueryExpression<QueryValue>,
rhs: some QueryExpression<QueryValue?>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(rhs) ? "IS NOT" : "<>", rhs: rhs)
}

// NB: This overload is required due to an overload resolution bug of 'Updates[dynamicMember:]'.
@_documentation(visibility: private)
public func == <QueryValue: _OptionalProtocol>(
lhs: any QueryExpression<QueryValue>,
rhs: some QueryExpression<QueryValue.Wrapped>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(lhs) ? "IS" : "=", rhs: rhs)
}

// NB: This overload is required due to an overload resolution bug of 'Updates[dynamicMember:]'.
@_documentation(visibility: private)
public func != <QueryValue: _OptionalProtocol>(
lhs: any QueryExpression<QueryValue>,
rhs: some QueryExpression<QueryValue.Wrapped>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(lhs) ? "IS NOT" : "<>", rhs: rhs)
}

// NB: This overload is required due to an overload resolution bug of 'Updates[dynamicMember:]'.
@_documentation(visibility: private)
public func == <QueryValue: _OptionalProtocol>(
lhs: any QueryExpression<QueryValue>,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(lhs) || isNull(rhs) ? "IS" : "=", rhs: rhs)
}

// NB: This overload is required due to an overload resolution bug of 'Updates[dynamicMember:]'.
@_documentation(visibility: private)
public func != <QueryValue: _OptionalProtocol>(
lhs: any QueryExpression<QueryValue>,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<Bool> {
BinaryOperator(lhs: lhs, operator: isNull(lhs) || isNull(rhs) ? "IS NOT" : "<>", rhs: rhs)
}

// NB: This overload is required due to an overload resolution bug of 'Updates[dynamicMember:]'.
@_documentation(visibility: private)
public func == <QueryValue: QueryBindable>(
lhs: any QueryExpression<QueryValue>,
rhs: _Null<QueryValue>
) -> some QueryExpression<Bool> {
SQLQueryExpression(lhs).is(rhs)
}

// NB: This overload is required due to an overload resolution bug of 'Updates[dynamicMember:]'.
@_documentation(visibility: private)
public func != <QueryValue: QueryBindable>(
lhs: any QueryExpression<QueryValue>,
rhs: _Null<QueryValue>
) -> some QueryExpression<Bool> {
SQLQueryExpression(lhs).isNot(rhs)
}

extension QueryExpression where QueryValue: QueryBindable {
/// Returns a predicate expression indicating whether the value of the first expression is less
/// than that of the second expression.
Expand All @@ -229,7 +251,8 @@ extension QueryExpression where QueryValue: QueryBindable {
/// - rhs: Another expression to compare.
/// - Returns: A predicate expression.
public static func < (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<Bool> {
lhs.lt(rhs)
}
Expand All @@ -245,7 +268,8 @@ extension QueryExpression where QueryValue: QueryBindable {
/// - rhs: Another expression to compare.
/// - Returns: A predicate expression.
public static func > (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<Bool> {
lhs.gt(rhs)
}
Expand All @@ -261,7 +285,8 @@ extension QueryExpression where QueryValue: QueryBindable {
/// - rhs: Another expression to compare.
/// - Returns: A predicate expression.
public static func <= (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<Bool> {
lhs.lte(rhs)
}
Expand All @@ -277,7 +302,8 @@ extension QueryExpression where QueryValue: QueryBindable {
/// - rhs: Another expression to compare.
/// - Returns: A predicate expression.
public static func >= (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<Bool> {
lhs.gte(rhs)
}
Expand Down Expand Up @@ -338,7 +364,8 @@ extension QueryExpression where QueryValue == Bool {
/// - rhs: The right-hand side of the operation.
/// - Returns: A predicate expression.
public static func && (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
lhs.and(rhs)
}
Expand All @@ -353,7 +380,8 @@ extension QueryExpression where QueryValue == Bool {
/// - rhs: The right-hand side of the operation.
/// - Returns: A predicate expression.
public static func || (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
lhs.or(rhs)
}
Expand Down Expand Up @@ -415,7 +443,8 @@ extension QueryExpression where QueryValue: Numeric {
/// - rhs: The second expression to add.
/// - Returns: A sum expression.
public static func + (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
BinaryOperator(lhs: lhs, operator: "+", rhs: rhs)
}
Expand All @@ -427,7 +456,8 @@ extension QueryExpression where QueryValue: Numeric {
/// - rhs: The second expression to subtract.
/// - Returns: A difference expression.
public static func - (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
BinaryOperator(lhs: lhs, operator: "-", rhs: rhs)
}
Expand All @@ -439,7 +469,8 @@ extension QueryExpression where QueryValue: Numeric {
/// - rhs: The second expression to multiply.
/// - Returns: A product expression.
public static func * (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
BinaryOperator(lhs: lhs, operator: "*", rhs: rhs)
}
Expand All @@ -451,7 +482,8 @@ extension QueryExpression where QueryValue: Numeric {
/// - rhs: The second expression to divide.
/// - Returns: A quotient expression.
public static func / (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
BinaryOperator(lhs: lhs, operator: "/", rhs: rhs)
}
Expand Down Expand Up @@ -546,7 +578,8 @@ extension QueryExpression where QueryValue: BinaryInteger {
/// - rhs: The value to divide `lhs` by.
/// - Returns: An expression representing the remainder, or `NULL` if `rhs` is zero.
public static func % (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue?> {
BinaryOperator(lhs: lhs, operator: "%", rhs: rhs)
}
Expand All @@ -558,7 +591,8 @@ extension QueryExpression where QueryValue: BinaryInteger {
/// - rhs: Another integer expression.
/// - Returns: An expression representing a bitwise AND operation on the two given expressions.
public static func & (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
BinaryOperator(lhs: lhs, operator: "&", rhs: rhs)
}
Expand All @@ -570,7 +604,8 @@ extension QueryExpression where QueryValue: BinaryInteger {
/// - rhs: Another integer expression.
/// - Returns: An expression representing a bitwise OR operation on the two given expressions.
public static func | (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
BinaryOperator(lhs: lhs, operator: "|", rhs: rhs)
}
Expand All @@ -583,7 +618,8 @@ extension QueryExpression where QueryValue: BinaryInteger {
/// - rhs: Another integer expression.
/// - Returns: An expression representing a left bitshift operation on the two given expressions.
public static func << (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
BinaryOperator(lhs: lhs, operator: "<<", rhs: rhs)
}
Expand All @@ -596,7 +632,8 @@ extension QueryExpression where QueryValue: BinaryInteger {
/// - rhs: Another integer expression.
/// - Returns: An expression representing a right bitshift operation on the two given expressions.
public static func >> (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
BinaryOperator(lhs: lhs, operator: ">>", rhs: rhs)
}
Expand Down Expand Up @@ -647,7 +684,8 @@ extension QueryExpression where QueryValue == String {
/// - rhs: The second string expression.
/// - Returns: An expression concatenating the first expression with the second.
public static func + (
lhs: Self, rhs: some QueryExpression<QueryValue>
lhs: Self,
rhs: some QueryExpression<QueryValue>
) -> some QueryExpression<QueryValue> {
BinaryOperator(lhs: lhs, operator: "||", rhs: rhs)
}
Expand Down Expand Up @@ -762,7 +800,8 @@ extension SQLQueryExpression<String> {
/// - lhs: The column to append.
/// - rhs: The appended text.
public static func += (
lhs: inout Self, rhs: some QueryExpression<QueryValue>
lhs: inout Self,
rhs: some QueryExpression<QueryValue>
) {
lhs = Self(lhs + rhs)
}
Expand Down
12 changes: 0 additions & 12 deletions Tests/StructuredQueriesTests/SQLMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,6 @@ extension SnapshotTests {
"""
}
}

func foo() {
let searchText = "get"
#sql(
"""
SELECT \(Reminder.columns)
FROM \(Reminder.self)
WHERE \(Reminder.title) COLLATE NOCASE LIKE \(searchText)
""",
as: Reminder.self
)
}
}
}

Expand Down
11 changes: 5 additions & 6 deletions Tests/StructuredQueriesTests/UpdateTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,7 @@ extension SnapshotTests {
.find(1)
.update {
$0.dueDate = Case()
.when($0.dueDate.is(nil), then: #sql("'2018-01-29 00:08:00.000'"))
.else(#sql("NULL"))
.when($0.dueDate == nil, then: #sql("'2018-01-29 00:08:00.000'"))
}

assertQuery(
Expand All @@ -362,11 +361,11 @@ extension SnapshotTests {
) {
"""
UPDATE "reminders"
SET "dueDate" = CASE WHEN ("reminders"."dueDate" IS NULL) THEN '2018-01-29 00:08:00.000' ELSE NULL END
SET "dueDate" = CASE WHEN ("reminders"."dueDate" IS NULL) THEN '2018-01-29 00:08:00.000' END
WHERE ("reminders"."id" = 1)
RETURNING "dueDate"
"""
} results: {
}results: {
"""
┌─────┐
│ nil │
Expand All @@ -380,11 +379,11 @@ extension SnapshotTests {
) {
"""
UPDATE "reminders"
SET "dueDate" = CASE WHEN ("reminders"."dueDate" IS NULL) THEN '2018-01-29 00:08:00.000' ELSE NULL END
SET "dueDate" = CASE WHEN ("reminders"."dueDate" IS NULL) THEN '2018-01-29 00:08:00.000' END
WHERE ("reminders"."id" = 1)
RETURNING "dueDate"
"""
} results: {
}results: {
"""
┌────────────────────────────────┐
│ Date(2018-01-29T00:08:00.000Z) │
Expand Down