diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index e267d8f9..e1a9da7e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -438,6 +438,47 @@ extension PostgresConnection { } } + /// Run a simple text-only query on the Postgres server the connection is connected to. + /// WARNING: This function is not yet API and is incomplete. + /// The return type will change to another stream. + /// + /// - Parameters: + /// - query: The simple query to run + /// - logger: The `Logger` to log into for the query + /// - file: The file, the query was started in. Used for better error reporting. + /// - line: The line, the query was started in. Used for better error reporting. + /// - Returns: A ``PostgresRowSequence`` containing the rows the server sent as the query result. + /// The sequence be discarded. + @discardableResult + public func __simpleQuery( + _ query: String, + logger: Logger, + file: String = #fileID, + line: Int = #line + ) async throws -> PostgresRowSequence { + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(self.id)" + + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = SimpleQueryContext( + query: query, + logger: logger, + promise: promise + ) + + self.channel.write(HandlerTask.simpleQuery(context), promise: nil) + + do { + return try await promise.futureResult.map({ $0.asyncSequence() }).get() + } catch var error as PSQLError { + error.file = file + error.line = line + // FIXME: just pass the string as a simple query, instead of acting like this is a PostgresQuery. + error.query = PostgresQuery(unsafeSQL: query) + throw error // rethrow with more metadata + } + } + /// Start listening for a channel public func listen(_ channel: String) async throws -> PostgresNotificationSequence { let id = self.internalListenID.loadThenWrappingIncrement(ordering: .relaxed) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9d264bcc..4faf57a7 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -31,6 +31,7 @@ struct ConnectionStateMachine { case readyForQuery(ConnectionContext) case extendedQuery(ExtendedQueryStateMachine, ConnectionContext) + case simpleQuery(SimpleQueryStateMachine, ConnectionContext) case closeCommand(CloseStateMachine, ConnectionContext) case closing(PSQLError?) @@ -87,6 +88,7 @@ struct ConnectionStateMachine { // --- general actions case sendParseDescribeBindExecuteSync(PostgresQuery) case sendBindExecuteSync(PSQLExecuteStatement) + case sendQuery(String) case failQuery(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) case succeedQuery(EventLoopPromise, with: QueryResult) @@ -156,6 +158,7 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand, .closing, .closed, @@ -213,6 +216,7 @@ struct ConnectionStateMachine { .authenticating, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand: return self.errorHappened(.serverClosedConnection(underlying: nil)) @@ -243,6 +247,7 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand, .closing, .closed: @@ -270,6 +275,7 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand, .closing, .closed: @@ -290,6 +296,7 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand, .closing, .closed: @@ -314,6 +321,7 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand, .closing, .closed: @@ -379,6 +387,12 @@ struct ConnectionStateMachine { self.state = .extendedQuery(query, connectionContext) return .wait + case .simpleQuery(let query, var connectionContext): + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .simpleQuery(query, connectionContext) + return .wait + case .closeCommand(let closeState, var connectionContext): self.state = .modifying // avoid CoW connectionContext.parameters[status.parameter] = status.value @@ -429,6 +443,15 @@ struct ConnectionStateMachine { self.state = .extendedQuery(extendedQueryState, connectionContext) return self.modify(with: action) + case .simpleQuery(var simpleQueryState, let connectionContext): + if simpleQueryState.isComplete { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) + } + self.state = .modifying // avoid CoW + let action = simpleQueryState.errorReceived(errorMessage) + self.state = .simpleQuery(simpleQueryState, connectionContext) + return self.modify(with: action) + case .closing: // If the state machine is in state `.closing`, the connection shutdown was initiated // by the client. This means a `TERMINATE` message has already been sent and the @@ -463,6 +486,13 @@ struct ConnectionStateMachine { let action = queryState.errorHappened(error) return self.modify(with: action) } + case .simpleQuery(var queryState, _): + if queryState.isComplete { + return self.closeConnectionAndCleanup(error) + } else { + let action = queryState.errorHappened(error) + return self.modify(with: action) + } case .closeCommand(var closeState, _): if closeState.isComplete { return self.closeConnectionAndCleanup(error) @@ -496,6 +526,12 @@ struct ConnectionStateMachine { self.state = .extendedQuery(extendedQuery, connectionContext) return self.modify(with: action) + case .simpleQuery(var simpleQuery, let connectionContext): + self.state = .modifying // avoid CoW + let action = simpleQuery.noticeReceived(notice) + self.state = .simpleQuery(simpleQuery, connectionContext) + return self.modify(with: action) + default: return .wait } @@ -526,6 +562,15 @@ struct ConnectionStateMachine { connectionContext.transactionState = transactionState + self.state = .readyForQuery(connectionContext) + return self.executeNextQueryFromQueue() + case .simpleQuery(let simpleQuery, var connectionContext): + guard simpleQuery.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) + } + + connectionContext.transactionState = transactionState + self.state = .readyForQuery(connectionContext) return self.executeNextQueryFromQueue() case .closeCommand(let closeStateMachine, var connectionContext): @@ -537,7 +582,7 @@ struct ConnectionStateMachine { self.state = .readyForQuery(connectionContext) return self.executeNextQueryFromQueue() - + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) } @@ -558,6 +603,7 @@ struct ConnectionStateMachine { .authenticating, .closeCommand, .extendedQuery, + .simpleQuery, .sslNegotiated, .sslHandlerAdded, .sslRequestSent, @@ -590,6 +636,8 @@ struct ConnectionStateMachine { case .prepareStatement(_, _, _, let promise): return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } + case .simpleQuery(let queryContext): + return .failQuery(queryContext.promise, with: psqlErrror, cleanupContext: nil) case .closeCommand(let closeContext): return .failClose(closeContext, with: psqlErrror, cleanupContext: nil) } @@ -615,7 +663,13 @@ struct ConnectionStateMachine { let action = extendedQuery.channelReadComplete() self.state = .extendedQuery(extendedQuery, connectionContext) return self.modify(with: action) - + + case .simpleQuery(var simpleQuery, let connectionContext): + self.state = .modifying // avoid CoW + let action = simpleQuery.channelReadComplete() + self.state = .simpleQuery(simpleQuery, connectionContext) + return self.modify(with: action) + case .modifying: preconditionFailure("Invalid state") } @@ -643,6 +697,12 @@ struct ConnectionStateMachine { self.state = .extendedQuery(extendedQuery, connectionContext) return self.modify(with: action) + case .simpleQuery(var simpleQuery, let connectionContext): + self.state = .modifying // avoid CoW + let action = simpleQuery.readEventCaught() + self.state = .simpleQuery(simpleQuery, connectionContext) + return self.modify(with: action) + case .closeCommand(var closeState, let connectionContext): self.state = .modifying // avoid CoW let action = closeState.readEventCaught() @@ -708,6 +768,12 @@ struct ConnectionStateMachine { self.state = .extendedQuery(queryState, connectionContext) return self.modify(with: action) + case .simpleQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.rowDescriptionReceived(description) + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.rowDescription(description))) } @@ -742,60 +808,100 @@ struct ConnectionStateMachine { } mutating func commandCompletedReceived(_ commandTag: String) -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.commandCompletedReceived(commandTag) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + case .simpleQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.commandCompletedReceived(commandTag) + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.commandComplete(commandTag))) } - - self.state = .modifying // avoid CoW - let action = queryState.commandCompletedReceived(commandTag) - self.state = .extendedQuery(queryState, connectionContext) - return self.modify(with: action) } mutating func emptyQueryResponseReceived() -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.emptyQueryResponseReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + case .simpleQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.emptyQueryResponseReceived() + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) } - - self.state = .modifying // avoid CoW - let action = queryState.emptyQueryResponseReceived() - self.state = .extendedQuery(queryState, connectionContext) - return self.modify(with: action) } mutating func dataRowReceived(_ dataRow: DataRow) -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.dataRowReceived(dataRow) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + case .simpleQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.dataRowReceived(dataRow) + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.dataRow(dataRow))) } - - self.state = .modifying // avoid CoW - let action = queryState.dataRowReceived(dataRow) - self.state = .extendedQuery(queryState, connectionContext) - return self.modify(with: action) } // MARK: Consumer mutating func cancelQueryStream() -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + switch self.state { + case .extendedQuery(var queryState, let connectionContext): + self.state = .modifying // avoid CoW + let action = queryState.cancel() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + case .simpleQuery(var queryState, let connectionContext): + self.state = .modifying // avoid CoW + let action = queryState.cancel() + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + + default: preconditionFailure("Tried to cancel stream without active query") } - - self.state = .modifying // avoid CoW - let action = queryState.cancel() - self.state = .extendedQuery(queryState, connectionContext) - return self.modify(with: action) } mutating func requestQueryRows() -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.requestQueryRows() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + case .simpleQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.requestQueryRows() + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + + default: preconditionFailure("Tried to consume next row, without active query") } - - self.state = .modifying // avoid CoW - let action = queryState.requestQueryRows() - self.state = .extendedQuery(queryState, connectionContext) - return self.modify(with: action) } // MARK: - Private Methods - @@ -876,6 +982,36 @@ struct ConnectionStateMachine { return .failPreparedStatementCreation(promise, with: error, cleanupContext: cleanupContext) } + case .simpleQuery(var queryStateMachine, _): + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + + if queryStateMachine.isComplete { + // in case the query state machine is complete all necessary actions have already + // been forwarded to the consumer. We can close and cleanup without caring about the + // substate machine. + return .closeConnectionAndCleanup(cleanupContext) + } + + let action = queryStateMachine.errorHappened(error) + switch action { + case .sendQuery, + .succeedQuery, + .forwardRows, + .forwardStreamComplete, + .wait, + .read: + preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)") + + case .evaluateErrorAtConnectionLevel: + return .closeConnectionAndCleanup(cleanupContext) + + case .failQuery(let queryContext, with: let error): + return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) + + case .forwardStreamError(let error, let read): + return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + } + case .closeCommand(var closeStateMachine, _): let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) @@ -941,6 +1077,13 @@ struct ConnectionStateMachine { self.state = .extendedQuery(extendedQuery, connectionContext) return self.modify(with: action) + case .simpleQuery(let queryContext): + self.state = .modifying // avoid CoW + var simpleQuery = SimpleQueryStateMachine(queryContext: queryContext) + let action = simpleQuery.start() + self.state = .simpleQuery(simpleQuery, connectionContext) + return self.modify(with: action) + case .closeCommand(let closeContext): self.state = .modifying // avoid CoW var closeStateMachine = CloseStateMachine(closeContext: closeContext) @@ -1068,6 +1211,37 @@ extension ConnectionStateMachine { } } +extension ConnectionStateMachine { + mutating func modify(with action: SimpleQueryStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { + switch action { + case .sendQuery(let query): + return .sendQuery(query) + case .failQuery(let requestContext, with: let error): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) + case .succeedQuery(let requestContext, with: let result): + return .succeedQuery(requestContext, with: result) + case .forwardRows(let buffer): + return .forwardRows(buffer) + case .forwardStreamComplete(let buffer, let commandTag): + return .forwardStreamComplete(buffer, commandTag: commandTag) + case .forwardStreamError(let error, let read): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + + case .evaluateErrorAtConnectionLevel(let error): + if let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) { + return .closeConnectionAndCleanup(cleanupContext) + } + return .wait + case .read: + return .read + case .wait: + return .wait + } + } +} + extension ConnectionStateMachine { mutating func modify(with action: AuthenticationStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { switch action { @@ -1178,6 +1352,8 @@ extension ConnectionStateMachine.State: CustomDebugStringConvertible { return ".readyForQuery(connectionContext: \(String(reflecting: connectionContext)))" case .extendedQuery(let subStateMachine, let connectionContext): return ".extendedQuery(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" + case .simpleQuery(let subStateMachine, let connectionContext): + return ".simpleQuery(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" case .closeCommand(let subStateMachine, let connectionContext): return ".closeCommand(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" case .closing: diff --git a/Sources/PostgresNIO/New/Connection State Machine/SimpleQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/SimpleQueryStateMachine.swift new file mode 100644 index 00000000..15b4dc6b --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/SimpleQueryStateMachine.swift @@ -0,0 +1,432 @@ +import NIOCore + +struct SimpleQueryStateMachine { + + private enum State { + case initialized(SimpleQueryContext) + case messagesSent(SimpleQueryContext) + + case rowDescriptionReceived(SimpleQueryContext, [RowDescription.Column]) + case emptyQueryResponseReceived + + case streaming([RowDescription.Column], RowStreamStateMachine) + /// Indicates that the current query was cancelled and we want to drain rows from the connection ASAP + case drain([RowDescription.Column]) + + case commandComplete(commandTag: String) + case error(PSQLError) + + case modifying + } + + enum Action { + case sendQuery(String) + + // --- general actions + case failQuery(EventLoopPromise, with: PSQLError) + case succeedQuery(EventLoopPromise, with: QueryResult) + + case evaluateErrorAtConnectionLevel(PSQLError) + + // --- streaming actions + // actions if query has requested next row but we are waiting for backend + case forwardRows([DataRow]) + case forwardStreamComplete([DataRow], commandTag: String) + case forwardStreamError(PSQLError, read: Bool) + + case read + case wait + } + + private var state: State + private var isCancelled: Bool + + init(queryContext: SimpleQueryContext) { + self.isCancelled = false + self.state = .initialized(queryContext) + } + + mutating func start() -> Action { + guard case .initialized(let queryContext) = self.state else { + preconditionFailure("Start should only be called, if the query has been initialized") + } + + return self.avoidingStateMachineCoW { state -> Action in + state = .messagesSent(queryContext) + return .sendQuery(queryContext.query) + } + } + + mutating func cancel() -> Action { + switch self.state { + case .initialized: + preconditionFailure("Start must be called immediatly after the query was created") + + case .messagesSent(let queryContext): + guard !self.isCancelled else { + return .wait + } + + self.isCancelled = true + return .failQuery(queryContext.promise, with: .queryCancelled) + + case .rowDescriptionReceived(let queryContext, let columns): + guard !self.isCancelled else { + return .wait + } + + self.isCancelled = true + self.state = .drain(columns) + return .failQuery(queryContext.promise, with: .queryCancelled) + + case .streaming(let columns, var streamStateMachine): + precondition(!self.isCancelled) + self.isCancelled = true + self.state = .drain(columns) + switch streamStateMachine.fail() { + case .wait: + return .forwardStreamError(.queryCancelled, read: false) + case .read: + return .forwardStreamError(.queryCancelled, read: true) + } + + case .commandComplete, .emptyQueryResponseReceived, .error, .drain: + // the stream has already finished. + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action { + let queryContext: SimpleQueryContext + switch self.state { + case .messagesSent(let simpleQueryContext): + queryContext = simpleQueryContext + default: + return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) + } + + guard !self.isCancelled else { + self.state = .drain(rowDescription.columns) + return .failQuery(queryContext.promise, with: .queryCancelled) + } + + self.avoidingStateMachineCoW { state in + // In a simple query almost all responses/columns will be in text format. + state = .rowDescriptionReceived(queryContext, rowDescription.columns) + } + + return .wait + } + + mutating func dataRowReceived(_ dataRow: DataRow) -> Action { + switch self.state { + case .rowDescriptionReceived(let queryContext, let columns): + // When receiving a data row, we must ensure that the data row column count + // matches the previously received row description column count. + guard dataRow.columnCount == columns.count else { + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + } + + return self.avoidingStateMachineCoW { state -> Action in + var demandStateMachine = RowStreamStateMachine() + demandStateMachine.receivedRow(dataRow) + state = .streaming(columns, demandStateMachine) + let result = QueryResult(value: .rowDescription(columns), logger: queryContext.logger) + return .succeedQuery(queryContext.promise, with: result) + } + + case .streaming(let columns, var demandStateMachine): + // When receiving a data row, we must ensure that the data row column count + // matches the previously received row description column count. + guard dataRow.columnCount == columns.count else { + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + } + + return self.avoidingStateMachineCoW { state -> Action in + demandStateMachine.receivedRow(dataRow) + state = .streaming(columns, demandStateMachine) + return .wait + } + + case .drain(let columns): + guard dataRow.columnCount == columns.count else { + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + } + // we ignore all rows and wait for readyForQuery + return .wait + + case .initialized, + .messagesSent, + .emptyQueryResponseReceived, + .commandComplete, + .error: + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func commandCompletedReceived(_ commandTag: String) -> Action { + switch self.state { + case .messagesSent(let context): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) + return .succeedQuery(context.promise, with: result) + } + + case .rowDescriptionReceived(let context, _): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) + return .succeedQuery(context.promise, with: result) + } + + case .streaming(_, var demandStateMachine): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + return .forwardStreamComplete(demandStateMachine.end(), commandTag: commandTag) + } + + case .drain: + precondition(self.isCancelled) + self.state = .commandComplete(commandTag: commandTag) + return .wait + + case .initialized, + .emptyQueryResponseReceived, + .commandComplete, + .error: + return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func emptyQueryResponseReceived() -> Action { + switch self.state { + case .messagesSent(let queryContext): + return self.avoidingStateMachineCoW { state -> Action in + state = .emptyQueryResponseReceived + let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) + return .succeedQuery(queryContext.promise, with: result) + } + + default: + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } + } + + mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { + let error = PSQLError.server(errorMessage) + switch self.state { + case .initialized: + return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) + case .messagesSent: + return self.setAndFireError(error) + case .rowDescriptionReceived: + return self.setAndFireError(error) + case .streaming, .drain: + return self.setAndFireError(error) + case .commandComplete, .emptyQueryResponseReceived: + return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) + case .error: + preconditionFailure(""" + This state must not be reached. If the query `.isComplete`, the + ConnectionStateMachine must not send any further events to the substate machine. + """) + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) -> Action { + //self.queryObject.noticeReceived(notice) + return .wait + } + + mutating func errorHappened(_ error: PSQLError) -> Action { + return self.setAndFireError(error) + } + + // MARK: Customer Actions + + mutating func requestQueryRows() -> Action { + switch self.state { + case .streaming(let columns, var demandStateMachine): + return self.avoidingStateMachineCoW { state -> Action in + let action = demandStateMachine.demandMoreResponseBodyParts() + state = .streaming(columns, demandStateMachine) + switch action { + case .read: + return .read + case .wait: + return .wait + } + } + + case .drain: + return .wait + + case .initialized, + .messagesSent, + .emptyQueryResponseReceived, + .rowDescriptionReceived: + preconditionFailure("Requested to consume next row without anything going on.") + + case .commandComplete, .error: + preconditionFailure("The stream is already closed or in a failure state; rows can not be consumed at this time.") + case .modifying: + preconditionFailure("Invalid state") + } + } + + // MARK: Channel actions + + mutating func channelReadComplete() -> Action { + switch self.state { + case .initialized, + .commandComplete, + .drain, + .error, + .messagesSent, + .emptyQueryResponseReceived, + .rowDescriptionReceived: + return .wait + + case .streaming(let columns, var demandStateMachine): + return self.avoidingStateMachineCoW { state -> Action in + let rows = demandStateMachine.channelReadComplete() + state = .streaming(columns, demandStateMachine) + switch rows { + case .some(let rows): + return .forwardRows(rows) + case .none: + return .wait + } + } + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func readEventCaught() -> Action { + switch self.state { + case .messagesSent, + .rowDescriptionReceived: + return .read + case .streaming(let columns, var demandStateMachine): + precondition(!self.isCancelled) + return self.avoidingStateMachineCoW { state -> Action in + let action = demandStateMachine.read() + state = .streaming(columns, demandStateMachine) + switch action { + case .wait: + return .wait + case .read: + return .read + } + } + case .initialized, + .commandComplete, + .emptyQueryResponseReceived, + .drain, + .error: + // we already have the complete stream received, now we are waiting for a + // `readyForQuery` package. To receive this we need to read! + return .read + case .modifying: + preconditionFailure("Invalid state") + } + } + + // MARK: Private Methods + + private mutating func setAndFireError(_ error: PSQLError) -> Action { + switch self.state { + case .initialized(let context), + .messagesSent(let context), + .rowDescriptionReceived(let context, _): + self.state = .error(error) + if self.isCancelled { + return .evaluateErrorAtConnectionLevel(error) + } else { + return .failQuery(context.promise, with: error) + } + + case .drain: + self.state = .error(error) + return .evaluateErrorAtConnectionLevel(error) + + case .streaming(_, var streamStateMachine): + self.state = .error(error) + switch streamStateMachine.fail() { + case .wait: + return .forwardStreamError(error, read: false) + case .read: + return .forwardStreamError(error, read: true) + } + + case .commandComplete, .emptyQueryResponseReceived, .error: + preconditionFailure(""" + This state must not be reached. If the query `.isComplete`, the + ConnectionStateMachine must not send any further events to the substate machine. + """) + case .modifying: + preconditionFailure("Invalid state") + } + } + + var isComplete: Bool { + switch self.state { + case .commandComplete, .emptyQueryResponseReceived, .error: + return true + + case .rowDescriptionReceived, .initialized, .messagesSent, .streaming, .drain: + return false + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } +} + +extension SimpleQueryStateMachine { + /// So, uh...this function needs some explaining. + /// + /// While the state machine logic above is great, there is a downside to having all of the state machine data in + /// associated data on enumerations: any modification of that data will trigger copy on write for heap-allocated + /// data. That means that for _every operation on the state machine_ we will CoW our underlying state, which is + /// not good. + /// + /// The way we can avoid this is by using this helper function. It will temporarily set state to a value with no + /// associated data, before attempting the body of the function. It will also verify that the state machine never + /// remains in this bad state. + /// + /// A key note here is that all callers must ensure that they return to a good state before they exit. + /// + /// Sadly, because it's generic and has a closure, we need to force it to be inlined at all call sites, which is + /// not ideal. + @inline(__always) + private mutating func avoidingStateMachineCoW(_ body: (inout State) -> ReturnType) -> ReturnType { + self.state = .modifying + defer { + assert(!self.isModifying) + } + + return body(&self.state) + } + + private var isModifying: Bool { + if case .modifying = self.state { + return true + } else { + return false + } + } +} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 6106fd21..0e8879f8 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -3,6 +3,7 @@ import NIOCore enum HandlerTask: Sendable { case extendedQuery(ExtendedQueryContext) + case simpleQuery(SimpleQueryContext) case closeCommand(CloseCommandContext) case startListening(NotificationListener) case cancelListening(String, Int) @@ -11,6 +12,7 @@ enum HandlerTask: Sendable { enum PSQLTask { case extendedQuery(ExtendedQueryContext) + case simpleQuery(SimpleQueryContext) case closeCommand(CloseCommandContext) func failWithError(_ error: PSQLError) { @@ -25,6 +27,9 @@ enum PSQLTask { eventLoopPromise.fail(error) } + case .simpleQuery(let simpleQueryContext): + simpleQueryContext.promise.fail(error) + case .closeCommand(let closeCommandContext): closeCommandContext.promise.fail(error) } @@ -40,7 +45,7 @@ final class ExtendedQueryContext: Sendable { let query: Query let logger: Logger - + init( query: PostgresQuery, logger: Logger, @@ -100,6 +105,22 @@ final class PreparedStatementContext: Sendable { } } +final class SimpleQueryContext { + let query: String + let logger: Logger + let promise: EventLoopPromise + + init( + query: String, + logger: Logger, + promise: EventLoopPromise + ) { + self.query = query + self.logger = logger + self.promise = promise + } +} + final class CloseCommandContext: Sendable { let target: CloseTarget let logger: Logger diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 0a14849a..bf036737 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -205,6 +205,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { psqlTask = .closeCommand(command) case .extendedQuery(let query): psqlTask = .extendedQuery(query) + case .simpleQuery(let query): + psqlTask = .simpleQuery(query) case .startListening(let listener): switch self.listenState.startListening(listener) { @@ -319,7 +321,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func run(_ action: ConnectionStateMachine.ConnectionAction, with context: ChannelHandlerContext) { self.logger.trace("Run action", metadata: [.connectionAction: "\(action)"]) - + switch action { case .establishSSLConnection: self.establishSSLConnection(context: context) @@ -351,6 +353,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context) case .sendParseDescribeBindExecuteSync(let query): self.sendParseDescribeBindExecuteAndSyncMessage(query: query, context: context) + case .sendQuery(let query): + self.sendQuery(query: query, context: context) case .succeedQuery(let promise, with: let result): self.succeedQuery(promise, result: result, context: context) case .failQuery(let promise, with: let error, let cleanupContext): @@ -534,7 +538,16 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.encoder.sync() context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) } - + + private func sendQuery( + query: String, + context: ChannelHandlerContext + ) { + precondition(self.rowStream == nil, "Expected to not have an open stream at this point") + self.encoder.query(query) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + private func succeedQuery( _ promise: EventLoopPromise, result: QueryResult, diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index 97805418..fe9d46a3 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -117,6 +117,12 @@ struct PostgresFrontendMessageEncoder { self.buffer.writeInteger(maxNumberOfRows) } + mutating func query(_ query: String) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .query, length: UInt32(1 + query.utf8.count)) + self.buffer.writeNullTerminatedString(query) + } + mutating func parse(preparedStatementName: String, query: String, parameters: Parameters) where Parameters.Element == PostgresDataType { self.clearIfNeeded() self.buffer.psqlWriteMultipleIntegers( @@ -202,6 +208,7 @@ private enum FrontendMessageID: UInt8, Hashable, Sendable { case flush = 72 // H case parse = 80 // P case password = 112 // p - also both sasl values + case query = 81 // Q case sync = 83 // S case terminate = 88 // X } diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index b4c8e93f..f1ab8c0f 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -25,6 +25,39 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func test1kRoundTripsSimpleQuery() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await withTestConnection(on: eventLoop) { connection in + for _ in 0..<1_000 { + let rows = try await connection.__simpleQuery("SELECT version()", logger: .psqlTest) + var iterator = rows.makeAsyncIterator() + let firstRow = try await iterator.next() + XCTAssertEqual(try firstRow?.decode(String.self, context: .default).contains("PostgreSQL"), true) + let done = try await iterator.next() + XCTAssertNil(done) + } + } + } + + func test1kRoundTripsSimpleQueryNoRows() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + try await withTestConnection(on: eventLoop) { connection in + for _ in 0..<1_000 { + let nonExistentOID = 1928394819 + let rows = try await connection.__simpleQuery("SELECT * FROM pg_class WHERE oid = \(nonExistentOID)", logger: .psqlTest) + var iterator = rows.makeAsyncIterator() + let first = try await iterator.next() + XCTAssertEqual(first, nil) + } + } + } + func testSelect10kRows() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -46,6 +79,27 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testSelect10kRowsSimpleQuery() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let start = 1 + let end = 10000 + + try await withTestConnection(on: eventLoop) { connection in + let rows = try await connection.__simpleQuery("SELECT generate_series(\(start), \(end));", logger: .psqlTest) + var counter = 0 + for try await row in rows { + let element = try row.decode(Int.self) + XCTAssertEqual(element, counter + 1) + counter += 1 + } + + XCTAssertEqual(counter, end) + } + } + func testSelectActiveConnection() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/SimpleQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/SimpleQueryStateMachineTests.swift new file mode 100644 index 00000000..16441dfc --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/SimpleQueryStateMachineTests.swift @@ -0,0 +1,310 @@ +import XCTest +import NIOCore +import NIOEmbedded +import Logging +@testable import PostgresNIO + +class SimpleQueryStateMachineTests: XCTestCase { + + func testQueryWithSimpleQueryWithoutDataRowsHappyPath() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "DELETE FROM table WHERE id=1" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryWithSimpleQueryWithRowDescriptionWithoutDataRowsHappyPath() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let nonExistentOID = 371280378 + let query = "SELECT * FROM pg_class WHERE oid = \(nonExistentOID)" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryWithSimpleQueryWithDataRowsHappyPath() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + let row1: DataRow = [ByteBuffer(string: "test1")] + let result = QueryResult(value: .rowDescription(input), logger: queryContext.logger) + XCTAssertEqual(state.dataRowReceived(row1), .succeedQuery(promise, with: result)) + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.requestQueryRows(), .read) + + let row2: DataRow = [ByteBuffer(string: "test2")] + let row3: DataRow = [ByteBuffer(string: "test3")] + let row4: DataRow = [ByteBuffer(string: "test4")] + XCTAssertEqual(state.dataRowReceived(row2), .wait) + XCTAssertEqual(state.dataRowReceived(row3), .wait) + XCTAssertEqual(state.dataRowReceived(row4), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row2, row3, row4])) + XCTAssertEqual(state.requestQueryRows(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + let row5: DataRow = [ByteBuffer(string: "test5")] + let row6: DataRow = [ByteBuffer(string: "test6")] + XCTAssertEqual(state.dataRowReceived(row5), .wait) + XCTAssertEqual(state.dataRowReceived(row6), .wait) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2")) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryWithSimpleQueryWithNoQuery() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "-- some comments" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + XCTAssertEqual(state.emptyQueryResponseReceived(), .succeedQuery(promise, with: .init(value: .noRows(.emptyResponse), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testReceiveTotallyUnexpectedMessageInQuery() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let nonExistentOID = 371280378 + let query = "SELECT * FROM pg_class WHERE oid = \(nonExistentOID)" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + + let psqlError = PSQLError.unexpectedBackendMessage(.authentication(.ok)) + XCTAssertEqual(state.authenticationMessageReceived(.ok), + .failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) + } + + func testQueryIsCancelledImmediately() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: nil)) + + // The query was cancelled but it also ended anyways, so we accept that the query has succeeded + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .succeedQuery(promise, with: .init(value: .noRows(.tag("SELECT 2")), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryIsCancelledWithReadPending() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: nil)) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test1")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test2")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test3")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test4")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .wait) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryIsCancelledWithReadPendingWhileStreaming() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + let row1: DataRow = [ByteBuffer(string: "test1")] + let result = QueryResult(value: .rowDescription(input), logger: queryContext.logger) + XCTAssertEqual(state.dataRowReceived(row1), .succeedQuery(promise, with: result)) + XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil)) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test2")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test3")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test4")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .wait) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + + func testCancelQueryAfterServerError() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + let result = QueryResult(value: .rowDescription(input), logger: queryContext.logger) + let row1: DataRow = [ByteBuffer(string: "test1")] + XCTAssertEqual(state.dataRowReceived(row1), .succeedQuery(promise, with: result)) + + let dataRows2: [DataRow] = [ + [ByteBuffer(string: "test2")], + [ByteBuffer(string: "test3")], + [ByteBuffer(string: "test4")] + ] + for row in dataRows2 { + XCTAssertEqual(state.dataRowReceived(row), .wait) + } + + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1] + dataRows2)) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.requestQueryRows(), .read) + let dataRows3: [DataRow] = [ + [ByteBuffer(string: "test5")], + [ByteBuffer(string: "test6")], + [ByteBuffer(string: "test7")] + ] + for row in dataRows3 { + XCTAssertEqual(state.dataRowReceived(row), .wait) + } + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual(state.errorReceived(serverError), .forwardStreamError(.server(serverError), read: false, cleanupContext: .none)) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryErrorDoesNotKillConnection() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual( + state.errorReceived(serverError), .failQuery(promise, with: .server(serverError), cleanupContext: .none) + ) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryErrorAfterCancelDoesNotKillConnection() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none)) + + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual(state.errorReceived(serverError), .wait) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } +} diff --git a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift index 9a1224d8..2cf7f11a 100644 --- a/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift +++ b/Tests/PostgresNIOTests/New/Extensions/ConnectionAction+TestUtils.swift @@ -24,6 +24,8 @@ extension PostgresNIO.ConnectionStateMachine.ConnectionAction: Swift.Equatable { return lhsMethod == rhsMethod && lhsAuthContext == rhsAuthContext case (.sendParseDescribeBindExecuteSync(let lquery), sendParseDescribeBindExecuteSync(let rquery)): return lquery == rquery + case (.sendQuery(let lquery), sendQuery(let rquery)): + return lquery == rquery case (.fireEventReadyForQuery, .fireEventReadyForQuery): return true case (.succeedQuery(let lhsPromise, let lhsResult), .succeedQuery(let rhsPromise, let rhsResult)): diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 55ccd0a9..a4e0992e 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -217,6 +217,11 @@ extension PostgresFrontendMessage { preconditionFailure("TODO: Unimplemented") case .saslResponse: preconditionFailure("TODO: Unimplemented") + case .query: + guard let query = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + return .query(.init(query: query)) case .sync: return .sync case .terminate: diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 2532959a..a7564d33 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -59,6 +59,11 @@ enum PostgresFrontendMessage: Equatable { } } + struct Query: Hashable { + /// The query string. + let query: String + } + struct Parse: Hashable { /// The name of the destination prepared statement (an empty string selects the unnamed prepared statement). let preparedStatementName: String @@ -179,6 +184,7 @@ enum PostgresFrontendMessage: Equatable { case saslInitialResponse(SASLInitialResponse) case saslResponse(SASLResponse) case sslRequest + case query(Query) case sync case startup(Startup) case terminate @@ -194,6 +200,7 @@ enum PostgresFrontendMessage: Equatable { case password case saslInitialResponse case saslResponse + case query case sync case terminate @@ -217,6 +224,8 @@ enum PostgresFrontendMessage: Equatable { self = .saslInitialResponse case UInt8(ascii: "p"): self = .saslResponse + case UInt8(ascii: "Q"): + self = .query case UInt8(ascii: "S"): self = .sync case UInt8(ascii: "X"): @@ -246,6 +255,8 @@ enum PostgresFrontendMessage: Equatable { return UInt8(ascii: "p") case .saslResponse: return UInt8(ascii: "p") + case .query: + return UInt8(ascii: "Q") case .sync: return UInt8(ascii: "S") case .terminate: @@ -283,6 +294,8 @@ extension PostgresFrontendMessage { preconditionFailure("SSL requests don't have an identifier") case .startup: preconditionFailure("Startup messages don't have an identifier") + case .query: + return .query case .sync: return .sync case .terminate: diff --git a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift index 33afbe0d..c68ea540 100644 --- a/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLFrontendMessageTests.swift @@ -31,7 +31,19 @@ class PSQLFrontendMessageTests: XCTestCase { XCTAssertEqual(PostgresFrontendMessage.ID.flush.rawValue, byteBuffer.readInteger(as: UInt8.self)) XCTAssertEqual(4, byteBuffer.readInteger(as: Int32.self)) // payload length } - + + func testEncodeQuery() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + let query = "SELECT * FROM foo" + encoder.query(query) + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 23) + XCTAssertEqual(PostgresFrontendMessage.ID.query.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(22, byteBuffer.readInteger(as: Int32.self)) // payload length + XCTAssertEqual([UInt8](query.utf8), byteBuffer.readBytes(length: 17)) + } + func testEncodeSync() { var encoder = PostgresFrontendMessageEncoder(buffer: .init()) encoder.sync() diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index d0f8e2b0..6d615aeb 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -317,6 +317,40 @@ class PostgresConnectionTests: XCTestCase { } } + func testCloseImmediatelyWithSimpleQuery() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + try await connection.__simpleQuery("SELECT 1;", logger: logger) + } + } + + let query = try await channel.waitForSimpleQueryRequest() + XCTAssertEqual(query.query, "SELECT 1;") + + async let close: () = connection.close() + + try await channel.closeFuture.get() + XCTAssertEqual(channel.isActive, false) + + try await close + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + XCTFail("Expected queries to fail") + case .failure(let failure): + guard let error = failure as? PSQLError else { + return XCTFail("Unexpected error type: \(failure)") + } + XCTAssertEqual(error.code, .clientClosedConnection) + } + } + } + } + func testIfServerJustClosesTheErrorReflectsThat() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() let logger = self.logger @@ -346,6 +380,35 @@ class PostgresConnectionTests: XCTestCase { } } + func testIfServerJustClosesTheErrorReflectsThatInSimpleQuery() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + let logger = self.logger + + async let response = try await connection.__simpleQuery("SELECT 1;", logger: logger) + + let query = try await channel.waitForSimpleQueryRequest() + XCTAssertEqual(query.query, "SELECT 1;") + + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } + + do { + _ = try await response + XCTFail("Expected to throw") + } catch { + XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + } + + // retry on same connection + + do { + _ = try await connection.__simpleQuery("SELECT 1;", logger: self.logger) + XCTFail("Expected to throw") + } catch { + XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + } + } + struct TestPrepareStatement: PostgresPreparedStatement { static let sql = "SELECT datname FROM pg_stat_activity WHERE state = $1" typealias Row = String @@ -692,6 +755,14 @@ extension NIOAsyncTestingChannel { return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) } + func waitForSimpleQueryRequest() async throws -> PostgresFrontendMessage.Query { + let query = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + guard case .query(let query) = query else { + fatalError() + } + return query + } + func waitForPrepareRequest() async throws -> PrepareRequest { let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self)