diff --git a/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift index 791cebdd..e55f6fae 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/CloseStateMachine.swift @@ -10,7 +10,7 @@ struct CloseStateMachine { } enum Action { - case sendCloseSync(CloseTarget) + case sendCloseSync(CloseTarget, promise: EventLoopPromise?) case succeedClose(CloseCommandContext) case failClose(CloseCommandContext, with: PSQLError) @@ -24,14 +24,14 @@ struct CloseStateMachine { self.state = .initialized(closeContext) } - mutating func start() -> Action { + mutating func start(_ promise: EventLoopPromise?) -> Action { guard case .initialized(let closeContext) = self.state else { preconditionFailure("Start should only be called, if the query has been initialized") } self.state = .closeSyncSent(closeContext) - return .sendCloseSync(closeContext.target) + return .sendCloseSync(closeContext.target, promise: promise) } mutating func closeCompletedReceived() -> Action { diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9d264bcc..17039df2 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -85,8 +85,8 @@ struct ConnectionStateMachine { // Connection Actions // --- general actions - case sendParseDescribeBindExecuteSync(PostgresQuery) - case sendBindExecuteSync(PSQLExecuteStatement) + case sendParseDescribeBindExecuteSync(PostgresQuery, promise: EventLoopPromise?) + case sendBindExecuteSync(PSQLExecuteStatement, promise: EventLoopPromise?) case failQuery(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) case succeedQuery(EventLoopPromise, with: QueryResult) @@ -97,12 +97,12 @@ struct ConnectionStateMachine { case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?) // Prepare statement actions - case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) + case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType], promise: EventLoopPromise?) case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) case failPreparedStatementCreation(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) // Close actions - case sendCloseSync(CloseTarget) + case sendCloseSync(CloseTarget, promise: EventLoopPromise?) case succeedClose(CloseCommandContext) case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?) } @@ -234,7 +234,7 @@ struct ConnectionStateMachine { } self.state = .sslNegotiated return .establishSSLConnection - + case .initialized, .sslNegotiated, .sslHandlerAdded, @@ -247,7 +247,7 @@ struct ConnectionStateMachine { .closing, .closed: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported)) - + case .modifying: preconditionFailure("Invalid state: \(self.state)") } @@ -583,14 +583,16 @@ struct ConnectionStateMachine { } switch task { - case .extendedQuery(let queryContext): + case .extendedQuery(let queryContext, let writePromise): + writePromise?.fail(psqlErrror) /// Use `cleanupContext` or not? switch queryContext.query { case .executeStatement(_, let promise), .unnamed(_, let promise): return .failQuery(promise, with: psqlErrror, cleanupContext: nil) case .prepareStatement(_, _, _, let promise): return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } - case .closeCommand(let closeContext): + case .closeCommand(let closeContext, let writePromise): + writePromise?.fail(psqlErrror) /// Use `cleanupContext` or not? return .failClose(closeContext, with: psqlErrror, cleanupContext: nil) } } @@ -934,17 +936,17 @@ struct ConnectionStateMachine { } switch task { - case .extendedQuery(let queryContext): + case .extendedQuery(let queryContext, let promise): self.state = .modifying // avoid CoW var extendedQuery = ExtendedQueryStateMachine(queryContext: queryContext) - let action = extendedQuery.start() + let action = extendedQuery.start(promise) self.state = .extendedQuery(extendedQuery, connectionContext) return self.modify(with: action) - case .closeCommand(let closeContext): + case .closeCommand(let closeContext, let promise): self.state = .modifying // avoid CoW var closeStateMachine = CloseStateMachine(closeContext: closeContext) - let action = closeStateMachine.start() + let action = closeStateMachine.start(promise) self.state = .closeCommand(closeStateMachine, connectionContext) return self.modify(with: action) } @@ -1031,10 +1033,10 @@ extension ConnectionStateMachine { extension ConnectionStateMachine { mutating func modify(with action: ExtendedQueryStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { switch action { - case .sendParseDescribeBindExecuteSync(let query): - return .sendParseDescribeBindExecuteSync(query) - case .sendBindExecuteSync(let executeStatement): - return .sendBindExecuteSync(executeStatement) + case .sendParseDescribeBindExecuteSync(let query, let promise): + return .sendParseDescribeBindExecuteSync(query, promise: promise) + case .sendBindExecuteSync(let executeStatement, let promise): + return .sendBindExecuteSync(executeStatement, promise: promise) case .failQuery(let requestContext, with: let error): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) @@ -1057,8 +1059,8 @@ extension ConnectionStateMachine { return .read case .wait: return .wait - case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes): - return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) + case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes, let promise): + return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes, promise: promise) case .succeedPreparedStatementCreation(let promise, with: let rowDescription): return .succeedPreparedStatementCreation(promise, with: rowDescription) case .failPreparedStatementCreation(let promise, with: let error): @@ -1094,8 +1096,8 @@ extension ConnectionStateMachine { extension ConnectionStateMachine { mutating func modify(with action: CloseStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { switch action { - case .sendCloseSync(let sendClose): - return .sendCloseSync(sendClose) + case .sendCloseSync(let sendClose, let promise): + return .sendCloseSync(sendClose, promise: promise) case .succeedClose(let closeContext): return .succeedClose(closeContext) case .failClose(let closeContext, with: let error): diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 78f0d202..40940ce4 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -25,10 +25,10 @@ struct ExtendedQueryStateMachine { } enum Action { - case sendParseDescribeBindExecuteSync(PostgresQuery) - case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) - case sendBindExecuteSync(PSQLExecuteStatement) - + case sendParseDescribeBindExecuteSync(PostgresQuery, promise: EventLoopPromise?) + case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType], promise: EventLoopPromise?) + case sendBindExecuteSync(PSQLExecuteStatement, promise: EventLoopPromise?) + // --- general actions case failQuery(EventLoopPromise, with: PSQLError) case succeedQuery(EventLoopPromise, with: QueryResult) @@ -56,7 +56,7 @@ struct ExtendedQueryStateMachine { self.state = .initialized(queryContext) } - mutating func start() -> Action { + mutating func start(_ promise: EventLoopPromise?) -> Action { guard case .initialized(let queryContext) = self.state else { preconditionFailure("Start should only be called, if the query has been initialized") } @@ -65,7 +65,7 @@ struct ExtendedQueryStateMachine { case .unnamed(let query, _): return self.avoidingStateMachineCoW { state -> Action in state = .messagesSent(queryContext) - return .sendParseDescribeBindExecuteSync(query) + return .sendParseDescribeBindExecuteSync(query, promise: promise) } case .executeStatement(let prepared, _): @@ -76,13 +76,14 @@ struct ExtendedQueryStateMachine { case .none: state = .noDataMessageReceived(queryContext) } - return .sendBindExecuteSync(prepared) + return .sendBindExecuteSync(prepared, promise: promise) } + /// Not my code, but this is ignoring the last argument which is a promise? is that fine? case .prepareStatement(let name, let query, let bindingDataTypes, _): return self.avoidingStateMachineCoW { state -> Action in state = .messagesSent(queryContext) - return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) + return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes, promise: promise) } } } diff --git a/Sources/PostgresNIO/New/NotificationListener.swift b/Sources/PostgresNIO/New/NotificationListener.swift index 4982b8ad..f69a4a55 100644 --- a/Sources/PostgresNIO/New/NotificationListener.swift +++ b/Sources/PostgresNIO/New/NotificationListener.swift @@ -42,7 +42,10 @@ final class NotificationListener: @unchecked Sendable { self.state = .closure(context, closure) } - func startListeningSucceeded(handler: PostgresChannelHandler) { + func startListeningSucceeded( + handler: PostgresChannelHandler, + writePromise: EventLoopPromise? + ) { self.eventLoop.preconditionInEventLoop() let handlerLoopBound = NIOLoopBound(handler, eventLoop: self.eventLoop) @@ -56,7 +59,7 @@ final class NotificationListener: @unchecked Sendable { switch reason { case .cancelled: eventLoop.execute { - handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID) + handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID, writePromise: nil) } case .finished: @@ -70,12 +73,14 @@ final class NotificationListener: @unchecked Sendable { let notificationSequence = PostgresNotificationSequence(base: stream) checkedContinuation.resume(returning: notificationSequence) + writePromise?.succeed(()) case .streamListening, .done: fatalError("Invalid state: \(self.state)") case .closure: - break // ignore + writePromise?.succeed(()) + // ignore } } diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 363f9394..a7dbb8db 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -10,12 +10,12 @@ enum HandlerTask { } enum PSQLTask { - case extendedQuery(ExtendedQueryContext) - case closeCommand(CloseCommandContext) + case extendedQuery(ExtendedQueryContext, writePromise: EventLoopPromise?) + case closeCommand(CloseCommandContext, writePromise: EventLoopPromise?) func failWithError(_ error: PSQLError) { switch self { - case .extendedQuery(let extendedQueryContext): + case .extendedQuery(let extendedQueryContext, let writePromise): switch extendedQueryContext.query { case .unnamed(_, let eventLoopPromise): eventLoopPromise.fail(error) @@ -24,9 +24,11 @@ enum PSQLTask { case .prepareStatement(_, _, _, let eventLoopPromise): eventLoopPromise.fail(error) } + writePromise?.fail(error) - case .closeCommand(let closeCommandContext): + case .closeCommand(let closeCommandContext, let writePromise): closeCommandContext.promise.fail(error) + writePromise?.fail(error) } } } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index a3190aa7..1f02a768 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -73,7 +73,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } // MARK: Channel handler incoming - + func channelActive(context: ChannelHandlerContext) { // `fireChannelActive` needs to be called BEFORE we set the state machine to connected, // since we want to make sure that upstream handlers know about the active connection before @@ -202,34 +202,37 @@ final class PostgresChannelHandler: ChannelDuplexHandler { switch handlerTask { case .closeCommand(let command): - psqlTask = .closeCommand(command) + psqlTask = .closeCommand(command, writePromise: promise) case .extendedQuery(let query): - psqlTask = .extendedQuery(query) + psqlTask = .extendedQuery(query, writePromise: promise) case .startListening(let listener): switch self.listenState.startListening(listener) { case .startListening(let channel): - psqlTask = self.makeStartListeningQuery(channel: channel, context: context) + psqlTask = self.makeStartListeningQuery(channel: channel, context: context, writePromise: promise) case .none: + promise?.succeed(()) return case .succeedListenStart(let listener): - listener.startListeningSucceeded(handler: self) + listener.startListeningSucceeded(handler: self, writePromise: promise) return } case .cancelListening(let channel, let id): switch self.listenState.cancelNotificationListener(channel: channel, id: id) { case .none: + promise?.succeed(()) return case .stopListening(let channel, let listener): - psqlTask = self.makeUnlistenQuery(channel: channel, context: context) + psqlTask = self.makeUnlistenQuery(channel: channel, context: context, writePromise: promise) listener.failed(CancellationError()) case .cancelListener(let listener): listener.failed(CancellationError()) + promise?.fail(CancellationError()) return } case .executePreparedStatement(let preparedStatement): @@ -240,19 +243,23 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .prepareStatement: psqlTask = self.makePrepareStatementTask( preparedStatement: preparedStatement, - context: context + context: context, + writePromise: promise ) case .waitForAlreadyInFlightPreparation: // The state machine already keeps track of this // and will execute the statement as soon as it's prepared + promise?.succeed(()) return case .executeStatement(let rowDescription): psqlTask = self.makeExecutePreparedStatementTask( preparedStatement: preparedStatement, - rowDescription: rowDescription + rowDescription: rowDescription, + writePromise: promise ) case .returnError(let error): preparedStatement.promise.fail(error) + promise?.fail(error) return } } @@ -280,6 +287,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case PSQLOutgoingEvent.authenticate(let authContext): let action = self.state.provideAuthenticationContext(authContext) self.run(action, with: context) + promise?.succeed(()) case PSQLOutgoingEvent.gracefulShutdown: let action = self.state.gracefulClose(promise) @@ -292,12 +300,13 @@ final class PostgresChannelHandler: ChannelDuplexHandler { // MARK: Listening - func cancelNotificationListener(channel: String, id: Int) { + func cancelNotificationListener(channel: String, id: Int, writePromise: EventLoopPromise?) { self.eventLoop.preconditionInEventLoop() switch self.listenState.cancelNotificationListener(channel: channel, id: id) { case .cancelListener(let listener): listener.cancelled() + writePromise?.succeed(()) case .stopListening(let channel, cancelListener: let listener): listener.cancelled() @@ -306,18 +315,21 @@ final class PostgresChannelHandler: ChannelDuplexHandler { return } - let query = self.makeUnlistenQuery(channel: channel, context: context) + let query = self.makeUnlistenQuery(channel: channel, context: context, writePromise: writePromise) let action = self.state.enqueue(task: query) self.run(action, with: context) case .none: - break + writePromise?.succeed(()) } } // MARK: Channel handler actions - private func run(_ action: ConnectionStateMachine.ConnectionAction, with context: ChannelHandlerContext) { + private func run( + _ action: ConnectionStateMachine.ConnectionAction, + with context: ChannelHandlerContext + ) { self.logger.trace("Run action", metadata: [.connectionAction: "\(action)"]) switch action { @@ -345,12 +357,18 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.closeConnectionAndCleanup(cleanupContext, context: context) case .fireChannelInactive: context.fireChannelInactive() - case .sendParseDescribeSync(let name, let query, let bindingDataTypes): - self.sendParseDescribeAndSyncMessage(statementName: name, query: query, bindingDataTypes: bindingDataTypes, context: context) - case .sendBindExecuteSync(let executeStatement): - self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context) - case .sendParseDescribeBindExecuteSync(let query): - self.sendParseDescribeBindExecuteAndSyncMessage(query: query, context: context) + case .sendParseDescribeSync(let name, let query, let bindingDataTypes, let promise): + self.sendParseDescribeAndSyncMessage( + statementName: name, + query: query, + bindingDataTypes: bindingDataTypes, + context: context, + promise: promise + ) + case .sendBindExecuteSync(let executeStatement, let promise): + self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context, promise: promise) + case .sendParseDescribeBindExecuteSync(let query, let promise): + self.sendParseDescribeBindExecuteAndSyncMessage(query: query, context: context, promise: promise) case .succeedQuery(let promise, with: let result): self.succeedQuery(promise, result: result, context: context) case .failQuery(let promise, with: let error, let cleanupContext): @@ -361,7 +379,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .forwardRows(let rows): self.rowStream!.receive(rows) - + case .forwardStreamComplete(let buffer, let commandTag): guard let rowStream = self.rowStream else { // if the stream was cancelled we don't have it here anymore. @@ -372,8 +390,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { rowStream.receive(buffer) } rowStream.receive(completion: .success(commandTag)) - - + case .forwardStreamError(let error, let read, let cleanupContext): self.rowStream!.receive(completion: .failure(error)) self.rowStream = nil @@ -382,7 +399,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } else if read { context.read() } - + case .provideAuthenticationContext: context.fireUserInboundEventTriggered(PSQLEvent.readyForStartup) @@ -414,10 +431,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler { if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } - case .sendCloseSync(let sendClose): - self.sendCloseAndSyncMessage(sendClose, context: context) + case .sendCloseSync(let sendClose, let promise): + self.sendCloseAndSyncMessage(sendClose, context: context, promise: promise) case .succeedClose(let closeContext): - closeContext.promise.succeed(Void()) + closeContext.promise.succeed(()) case .failClose(let closeContext, with: let error, let cleanupContext): closeContext.promise.fail(error) if let cleanupContext = cleanupContext { @@ -476,17 +493,21 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } - private func sendCloseAndSyncMessage(_ sendClose: CloseTarget, context: ChannelHandlerContext) { + private func sendCloseAndSyncMessage( + _ sendClose: CloseTarget, + context: ChannelHandlerContext, + promise: EventLoopPromise? + ) { switch sendClose { case .preparedStatement(let name): self.encoder.closePreparedStatement(name) self.encoder.sync() - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) - + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: promise) + case .portal(let name): self.encoder.closePortal(name) self.encoder.sync() - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: promise) } } @@ -494,18 +515,20 @@ final class PostgresChannelHandler: ChannelDuplexHandler { statementName: String, query: String, bindingDataTypes: [PostgresDataType], - context: ChannelHandlerContext + context: ChannelHandlerContext, + promise: EventLoopPromise? ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") self.encoder.parse(preparedStatementName: statementName, query: query, parameters: bindingDataTypes) self.encoder.describePreparedStatement(statementName) self.encoder.sync() - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: promise) } private func sendBindExecuteAndSyncMessage( executeStatement: PSQLExecuteStatement, - context: ChannelHandlerContext + context: ChannelHandlerContext, + promise: EventLoopPromise? ) { self.encoder.bind( portalName: "", @@ -514,12 +537,13 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) self.encoder.execute(portalName: "") self.encoder.sync() - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: promise) } private func sendParseDescribeBindExecuteAndSyncMessage( query: PostgresQuery, - context: ChannelHandlerContext + context: ChannelHandlerContext, + promise: EventLoopPromise? ) { precondition(self.rowStream == nil, "Expected to not have an open stream at this point") let unnamedStatementName = "" @@ -532,7 +556,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.encoder.bind(portalName: "", preparedStatementName: unnamedStatementName, bind: query.binds) self.encoder.execute(portalName: "") self.encoder.sync() - context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: promise) } private func succeedQuery( @@ -592,7 +616,11 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } } - private func makeStartListeningQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { + private func makeStartListeningQuery( + channel: String, + context: ChannelHandlerContext, + writePromise: EventLoopPromise? + ) -> PSQLTask { let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) let query = ExtendedQueryContext( query: PostgresQuery(unsafeSQL: #"LISTEN "\#(channel)";"#), @@ -602,23 +630,28 @@ final class PostgresChannelHandler: ChannelDuplexHandler { let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) promise.futureResult.whenComplete { result in let (selfTransferred, context) = loopBound.value - selfTransferred.startListenCompleted(result, for: channel, context: context) + selfTransferred.startListenCompleted(result, for: channel, context: context, writePromise: nil) } - return .extendedQuery(query) + return .extendedQuery(query, writePromise: writePromise) } - private func startListenCompleted(_ result: Result, for channel: String, context: ChannelHandlerContext) { + private func startListenCompleted( + _ result: Result, + for channel: String, + context: ChannelHandlerContext, + writePromise: EventLoopPromise? + ) { switch result { case .success: switch self.listenState.startListeningSucceeded(channel: channel) { case .activateListeners(let listeners): for list in listeners { - list.startListeningSucceeded(handler: self) + list.startListeningSucceeded(handler: self, writePromise: nil) } - + writePromise?.succeed(()) /// Should we instead do smth like "whenAllSucceed"? case .stopListening: - let task = self.makeUnlistenQuery(channel: channel, context: context) + let task = self.makeUnlistenQuery(channel: channel, context: context, writePromise: writePromise) let action = self.state.enqueue(task: task) self.run(action, with: context) } @@ -637,10 +670,15 @@ final class PostgresChannelHandler: ChannelDuplexHandler { for list in listeners { list.failed(finalError) } + writePromise?.fail(finalError) } } - private func makeUnlistenQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { + private func makeUnlistenQuery( + channel: String, + context: ChannelHandlerContext, + writePromise: EventLoopPromise? + ) -> PSQLTask { let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) let query = ExtendedQueryContext( query: PostgresQuery(unsafeSQL: #"UNLISTEN "\#(channel)";"#), @@ -650,25 +688,26 @@ final class PostgresChannelHandler: ChannelDuplexHandler { let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) promise.futureResult.whenComplete { result in let (selfTransferred, context) = loopBound.value - selfTransferred.stopListenCompleted(result, for: channel, context: context) + selfTransferred.stopListenCompleted(result, for: channel, context: context, writePromise: nil) } - return .extendedQuery(query) + return .extendedQuery(query, writePromise: writePromise) } private func stopListenCompleted( _ result: Result, for channel: String, - context: ChannelHandlerContext + context: ChannelHandlerContext, + writePromise: EventLoopPromise? ) { switch result { case .success: switch self.listenState.stopListeningSucceeded(channel: channel) { case .none: - break + writePromise?.succeed() case .startListening: - let task = self.makeStartListeningQuery(channel: channel, context: context) + let task = self.makeStartListeningQuery(channel: channel, context: context, writePromise: writePromise) let action = self.state.enqueue(task: task) self.run(action, with: context) } @@ -676,6 +715,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { case .failure(let error): let action = self.state.errorHappened(.unlistenError(underlying: error)) self.run(action, with: context) + writePromise?.fail(error) /// Should I pass the promise to the action? seemed troublesome } } @@ -696,7 +736,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func makePrepareStatementTask( preparedStatement: PreparedStatementContext, - context: ChannelHandlerContext + context: ChannelHandlerContext, + writePromise: EventLoopPromise? ) -> PSQLTask { let promise = self.eventLoop.makePromise(of: RowDescription?.self) let loopBound = NIOLoopBound((self, context), eventLoop: self.eventLoop) @@ -723,28 +764,35 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) } } - return .extendedQuery(.init( - name: preparedStatement.name, - query: preparedStatement.sql, - bindingDataTypes: preparedStatement.bindingDataTypes, - logger: preparedStatement.logger, - promise: promise - )) + return .extendedQuery( + .init( + name: preparedStatement.name, + query: preparedStatement.sql, + bindingDataTypes: preparedStatement.bindingDataTypes, + logger: preparedStatement.logger, + promise: promise + ), + writePromise: writePromise + ) } private func makeExecutePreparedStatementTask( preparedStatement: PreparedStatementContext, - rowDescription: RowDescription? + rowDescription: RowDescription?, + writePromise: EventLoopPromise? ) -> PSQLTask { - return .extendedQuery(.init( - executeStatement: .init( - name: preparedStatement.name, - binds: preparedStatement.bindings, - rowDescription: rowDescription + return .extendedQuery( + .init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise ), - logger: preparedStatement.logger, - promise: preparedStatement.promise - )) + writePromise: writePromise + ) } private func prepareStatementComplete( @@ -757,16 +805,18 @@ final class PostgresChannelHandler: ChannelDuplexHandler { rowDescription: rowDescription ) for preparedStatement in action.statements { - let action = self.state.enqueue(task: .extendedQuery(.init( - executeStatement: .init( - name: preparedStatement.name, - binds: preparedStatement.bindings, - rowDescription: action.rowDescription + let action = self.state.enqueue(task: .extendedQuery( + .init( + executeStatement: .init( + name: preparedStatement.name, + binds: preparedStatement.bindings, + rowDescription: action.rowDescription + ), + logger: preparedStatement.logger, + promise: preparedStatement.promise ), - logger: preparedStatement.logger, - promise: preparedStatement.promise + writePromise: nil // Ignore )) - ) self.run(action, with: context) } } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift index df881f90..67134329 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/AuthenticationStateMachineTests.swift @@ -9,7 +9,7 @@ class AuthenticationStateMachineTests: XCTestCase { var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) - + XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext)) XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait) @@ -50,7 +50,7 @@ class AuthenticationStateMachineTests: XCTestCase { var state = ConnectionStateMachine(requireBackendKeyData: true) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) - + let saslResponse = state.authenticationMessageReceived(.sasl(names: ["SCRAM-SHA-256"])) guard case .sendSaslInitialResponse(name: let name, initialResponse: let responseData) = saslResponse else { return XCTFail("\(saslResponse) is not .sendSaslInitialResponse") diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift index f3d72a5e..876c6b54 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ConnectionStateMachineTests.swift @@ -164,7 +164,7 @@ class ConnectionStateMachineTests: XCTestCase { logger: .psqlTest, promise: queryPromise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(extendedQueryContext)), .wait) + XCTAssertEqual(state.enqueue(task: .extendedQuery(extendedQueryContext, writePromise: nil)), .wait) XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext) XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext)) XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext)) @@ -178,7 +178,7 @@ class ConnectionStateMachineTests: XCTestCase { .file: "auth.c" ] XCTAssertEqual(state.errorReceived(.init(fields: fields)), - .closeConnectionAndCleanup(.init(action: .close, tasks: [.extendedQuery(extendedQueryContext)], error: .server(.init(fields: fields)), closePromise: nil))) + .closeConnectionAndCleanup(.init(action: .close, tasks: [.extendedQuery(extendedQueryContext, writePromise: nil)], error: .server(.init(fields: fields)), closePromise: nil))) XCTAssertNil(queryPromise.futureResult._value) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 40e32468..283db3d0 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -15,7 +15,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let query: PostgresQuery = "DELETE FROM table WHERE id=\(1)" let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext, writePromise: nil)), .sendParseDescribeBindExecuteSync(query, promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.noDataReceived(), .wait) @@ -33,7 +33,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let query: PostgresQuery = "SELECT version()" let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext, writePromise: nil)), .sendParseDescribeBindExecuteSync(query, promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -87,7 +87,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let query: PostgresQuery = "DELETE FROM table WHERE id=\(1)" let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext, writePromise: nil)), .sendParseDescribeBindExecuteSync(query, promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -105,7 +105,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let query: PostgresQuery = "SELECT version()" let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext, writePromise: nil)), .sendParseDescribeBindExecuteSync(query, promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -149,7 +149,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let query: PostgresQuery = "SELECT version()" let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext, writePromise: nil)), .sendParseDescribeBindExecuteSync(query, promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -191,7 +191,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let query: PostgresQuery = "SELECT version()" let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext, writePromise: nil)), .sendParseDescribeBindExecuteSync(query, promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -245,7 +245,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let query: PostgresQuery = "SELECT version()" let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext, writePromise: nil)), .sendParseDescribeBindExecuteSync(query, promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -266,7 +266,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { let query: PostgresQuery = "SELECT version()" let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext, writePromise: nil)), .sendParseDescribeBindExecuteSync(query, promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none)) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index 547f5cdf..3957f880 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -15,8 +15,8 @@ class PrepareStatementStateMachineTests: XCTestCase { name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) - XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext, writePromise: nil)), + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [], promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -41,8 +41,8 @@ class PrepareStatementStateMachineTests: XCTestCase { name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) - XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext, writePromise: nil)), + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [], promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) @@ -63,8 +63,8 @@ class PrepareStatementStateMachineTests: XCTestCase { name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise ) - XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)), - .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [])) + XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext, writePromise: nil)), + .sendParseDescribeSync(name: name, query: query, bindingDataTypes: [], promise: nil)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 769bde4b..3f406598 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -14,7 +14,7 @@ class Date_PSQLCodableTests: XCTestCase { var result: Date? XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) - XCTAssertEqual(value, result) + XCTAssertEqual(value.timeIntervalSince1970, result?.timeIntervalSince1970 ?? 0, accuracy: 0.001) } func testDecodeRandomDate() { diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index 9a1224d8..d4007600 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -22,7 +22,7 @@ extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable { return lhs == rhs case (.sendPasswordMessage(let lhsMethod, let lhsAuthContext), sendPasswordMessage(let rhsMethod, let rhsAuthContext)): return lhsMethod == rhsMethod && lhsAuthContext == rhsAuthContext - case (.sendParseDescribeBindExecuteSync(let lquery), sendParseDescribeBindExecuteSync(let rquery)): + case (.sendParseDescribeBindExecuteSync(let lquery, _), sendParseDescribeBindExecuteSync(let rquery, _)): return lquery == rquery case (.fireEventReadyForQuery, .fireEventReadyForQuery): return true @@ -36,7 +36,7 @@ extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable { return lhsBuffer == rhsBuffer && lhsCommandTag == rhsCommandTag case (.forwardStreamError(let lhsError, let lhsRead, let lhsCleanupContext), .forwardStreamError(let rhsError , let rhsRead, let rhsCleanupContext)): return lhsError == rhsError && lhsRead == rhsRead && lhsCleanupContext == rhsCleanupContext - case (.sendParseDescribeSync(let lhsName, let lhsQuery, let lhsDataTypes), .sendParseDescribeSync(let rhsName, let rhsQuery, let rhsDataTypes)): + case (.sendParseDescribeSync(let lhsName, let lhsQuery, let lhsDataTypes, _), .sendParseDescribeSync(let rhsName, let rhsQuery, let rhsDataTypes, _)): return lhsName == rhsName && lhsQuery == rhsQuery && lhsDataTypes == rhsDataTypes case (.succeedPreparedStatementCreation(let lhsPromise, let lhsRowDescription), .succeedPreparedStatementCreation(let rhsPromise, let rhsRowDescription)): return lhsPromise.futureResult === rhsPromise.futureResult && lhsRowDescription == rhsRowDescription @@ -109,9 +109,9 @@ extension PostgresNIO.PSQLError: Swift.Equatable { extension PostgresNIO.PSQLTask: Swift.Equatable { public static func == (lhs: PSQLTask, rhs: PSQLTask) -> Bool { switch (lhs, rhs) { - case (.extendedQuery(let lhs), .extendedQuery(let rhs)): + case (.extendedQuery(let lhs, _), .extendedQuery(let rhs, _)): return lhs === rhs - case (.closeCommand(let lhs), .closeCommand(let rhs)): + case (.closeCommand(let lhs, _), .closeCommand(let rhs, _)): return lhs === rhs default: return false