Skip to content

Commit

Permalink
Properly fulfill write promises
Browse files Browse the repository at this point in the history
  • Loading branch information
MahdiBM committed Aug 8, 2024
1 parent d18b137 commit 8b944b2
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct CloseStateMachine {
}

enum Action {
case sendCloseSync(CloseTarget)
case sendCloseSync(CloseTarget, promise: EventLoopPromise<Void>?)
case succeedClose(CloseCommandContext)
case failClose(CloseCommandContext, with: PSQLError)

Expand All @@ -24,14 +24,14 @@ struct CloseStateMachine {
self.state = .initialized(closeContext)
}

mutating func start() -> Action {
mutating func start(_ promise: EventLoopPromise<Void>?) -> Action {
guard case .initialized(let closeContext) = self.state else {
preconditionFailure("Start should only be called, if the query has been initialized")
}

self.state = .closeSyncSent(closeContext)

return .sendCloseSync(closeContext.target)
return .sendCloseSync(closeContext.target, promise: promise)
}

mutating func closeCompletedReceived() -> Action {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ struct ConnectionStateMachine {
// Connection Actions

// --- general actions
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendBindExecuteSync(PSQLExecuteStatement)
case sendParseDescribeBindExecuteSync(PostgresQuery, promise: EventLoopPromise<Void>?)
case sendBindExecuteSync(PSQLExecuteStatement, promise: EventLoopPromise<Void>?)
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)

Expand All @@ -97,12 +97,12 @@ struct ConnectionStateMachine {
case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?)

// Prepare statement actions
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType], promise: EventLoopPromise<Void>?)
case succeedPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: RowDescription?)
case failPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: PSQLError, cleanupContext: CleanUpContext?)

// Close actions
case sendCloseSync(CloseTarget)
case sendCloseSync(CloseTarget, promise: EventLoopPromise<Void>?)
case succeedClose(CloseCommandContext)
case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?)
}
Expand Down Expand Up @@ -234,7 +234,7 @@ struct ConnectionStateMachine {
}
self.state = .sslNegotiated
return .establishSSLConnection

case .initialized,
.sslNegotiated,
.sslHandlerAdded,
Expand All @@ -247,7 +247,7 @@ struct ConnectionStateMachine {
.closing,
.closed:
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported))

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand Down Expand Up @@ -583,14 +583,16 @@ struct ConnectionStateMachine {
}

switch task {
case .extendedQuery(let queryContext):
case .extendedQuery(let queryContext, let writePromise):
writePromise?.fail(psqlErrror) /// Use `cleanupContext` or not?
switch queryContext.query {
case .executeStatement(_, let promise), .unnamed(_, let promise):
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
case .prepareStatement(_, _, _, let promise):
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
}
case .closeCommand(let closeContext):
case .closeCommand(let closeContext, let writePromise):
writePromise?.fail(psqlErrror) /// Use `cleanupContext` or not?
return .failClose(closeContext, with: psqlErrror, cleanupContext: nil)
}
}
Expand Down Expand Up @@ -934,17 +936,17 @@ struct ConnectionStateMachine {
}

switch task {
case .extendedQuery(let queryContext):
case .extendedQuery(let queryContext, let promise):
self.state = .modifying // avoid CoW
var extendedQuery = ExtendedQueryStateMachine(queryContext: queryContext)
let action = extendedQuery.start()
let action = extendedQuery.start(promise)
self.state = .extendedQuery(extendedQuery, connectionContext)
return self.modify(with: action)

case .closeCommand(let closeContext):
case .closeCommand(let closeContext, let promise):
self.state = .modifying // avoid CoW
var closeStateMachine = CloseStateMachine(closeContext: closeContext)
let action = closeStateMachine.start()
let action = closeStateMachine.start(promise)
self.state = .closeCommand(closeStateMachine, connectionContext)
return self.modify(with: action)
}
Expand Down Expand Up @@ -1031,10 +1033,10 @@ extension ConnectionStateMachine {
extension ConnectionStateMachine {
mutating func modify(with action: ExtendedQueryStateMachine.Action) -> ConnectionStateMachine.ConnectionAction {
switch action {
case .sendParseDescribeBindExecuteSync(let query):
return .sendParseDescribeBindExecuteSync(query)
case .sendBindExecuteSync(let executeStatement):
return .sendBindExecuteSync(executeStatement)
case .sendParseDescribeBindExecuteSync(let query, let promise):
return .sendParseDescribeBindExecuteSync(query, promise: promise)
case .sendBindExecuteSync(let executeStatement, let promise):
return .sendBindExecuteSync(executeStatement, promise: promise)
case .failQuery(let requestContext, with: let error):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
Expand All @@ -1057,8 +1059,8 @@ extension ConnectionStateMachine {
return .read
case .wait:
return .wait
case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes):
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes, let promise):
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes, promise: promise)
case .succeedPreparedStatementCreation(let promise, with: let rowDescription):
return .succeedPreparedStatementCreation(promise, with: rowDescription)
case .failPreparedStatementCreation(let promise, with: let error):
Expand Down Expand Up @@ -1094,8 +1096,8 @@ extension ConnectionStateMachine {
extension ConnectionStateMachine {
mutating func modify(with action: CloseStateMachine.Action) -> ConnectionStateMachine.ConnectionAction {
switch action {
case .sendCloseSync(let sendClose):
return .sendCloseSync(sendClose)
case .sendCloseSync(let sendClose, let promise):
return .sendCloseSync(sendClose, promise: promise)
case .succeedClose(let closeContext):
return .succeedClose(closeContext)
case .failClose(let closeContext, with: let error):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ struct ExtendedQueryStateMachine {
}

enum Action {
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
case sendBindExecuteSync(PSQLExecuteStatement)
case sendParseDescribeBindExecuteSync(PostgresQuery, promise: EventLoopPromise<Void>?)
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType], promise: EventLoopPromise<Void>?)
case sendBindExecuteSync(PSQLExecuteStatement, promise: EventLoopPromise<Void>?)

// --- general actions
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError)
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
Expand Down Expand Up @@ -56,7 +56,7 @@ struct ExtendedQueryStateMachine {
self.state = .initialized(queryContext)
}

mutating func start() -> Action {
mutating func start(_ promise: EventLoopPromise<Void>?) -> Action {
guard case .initialized(let queryContext) = self.state else {
preconditionFailure("Start should only be called, if the query has been initialized")
}
Expand All @@ -65,7 +65,7 @@ struct ExtendedQueryStateMachine {
case .unnamed(let query, _):
return self.avoidingStateMachineCoW { state -> Action in
state = .messagesSent(queryContext)
return .sendParseDescribeBindExecuteSync(query)
return .sendParseDescribeBindExecuteSync(query, promise: promise)
}

case .executeStatement(let prepared, _):
Expand All @@ -76,13 +76,14 @@ struct ExtendedQueryStateMachine {
case .none:
state = .noDataMessageReceived(queryContext)
}
return .sendBindExecuteSync(prepared)
return .sendBindExecuteSync(prepared, promise: promise)
}

/// Not my code, but this is ignoring the last argument which is a promise? is that fine?
case .prepareStatement(let name, let query, let bindingDataTypes, _):
return self.avoidingStateMachineCoW { state -> Action in
state = .messagesSent(queryContext)
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes, promise: promise)
}
}
}
Expand Down
11 changes: 8 additions & 3 deletions Sources/PostgresNIO/New/NotificationListener.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ final class NotificationListener: @unchecked Sendable {
self.state = .closure(context, closure)
}

func startListeningSucceeded(handler: PostgresChannelHandler) {
func startListeningSucceeded(
handler: PostgresChannelHandler,
writePromise: EventLoopPromise<Void>?
) {
self.eventLoop.preconditionInEventLoop()
let handlerLoopBound = NIOLoopBound(handler, eventLoop: self.eventLoop)

Expand All @@ -56,7 +59,7 @@ final class NotificationListener: @unchecked Sendable {
switch reason {
case .cancelled:
eventLoop.execute {
handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID)
handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID, writePromise: nil)
}

case .finished:
Expand All @@ -70,12 +73,14 @@ final class NotificationListener: @unchecked Sendable {

let notificationSequence = PostgresNotificationSequence(base: stream)
checkedContinuation.resume(returning: notificationSequence)
writePromise?.succeed(())

case .streamListening, .done:
fatalError("Invalid state: \(self.state)")

case .closure:
break // ignore
writePromise?.succeed(())
// ignore
}
}

Expand Down
10 changes: 6 additions & 4 deletions Sources/PostgresNIO/New/PSQLTask.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ enum HandlerTask {
}

enum PSQLTask {
case extendedQuery(ExtendedQueryContext)
case closeCommand(CloseCommandContext)
case extendedQuery(ExtendedQueryContext, writePromise: EventLoopPromise<Void>?)
case closeCommand(CloseCommandContext, writePromise: EventLoopPromise<Void>?)

func failWithError(_ error: PSQLError) {
switch self {
case .extendedQuery(let extendedQueryContext):
case .extendedQuery(let extendedQueryContext, let writePromise):
switch extendedQueryContext.query {
case .unnamed(_, let eventLoopPromise):
eventLoopPromise.fail(error)
Expand All @@ -24,9 +24,11 @@ enum PSQLTask {
case .prepareStatement(_, _, _, let eventLoopPromise):
eventLoopPromise.fail(error)
}
writePromise?.fail(error)

case .closeCommand(let closeCommandContext):
case .closeCommand(let closeCommandContext, let writePromise):
closeCommandContext.promise.fail(error)
writePromise?.fail(error)
}
}
}
Expand Down
Loading

0 comments on commit 8b944b2

Please sign in to comment.