diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index 7de87f55..00367a14 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -439,7 +439,7 @@ extension PostgresConnection { } /// Run a simple text-only query on the Postgres server the connection is connected to. - /// WARNING: This functions is not yet API and is incomplete. + /// WARNING: This function is not yet API and is incomplete. /// The return type will change to another stream. /// /// - Parameters: @@ -460,13 +460,13 @@ extension PostgresConnection { logger[postgresMetadataKey: .connectionID] = "\(self.id)" let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) - let context = ExtendedQueryContext( - simpleQuery: query, + let context = SimpleQueryContext( + query: query, logger: logger, promise: promise ) - self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + self.channel.write(HandlerTask.simpleQuery(context), promise: nil) do { return try await promise.futureResult.map({ $0.asyncSequence() }).get() diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 03e9c7eb..4faf57a7 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -31,6 +31,7 @@ struct ConnectionStateMachine { case readyForQuery(ConnectionContext) case extendedQuery(ExtendedQueryStateMachine, ConnectionContext) + case simpleQuery(SimpleQueryStateMachine, ConnectionContext) case closeCommand(CloseStateMachine, ConnectionContext) case closing(PSQLError?) @@ -157,6 +158,7 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand, .closing, .closed, @@ -214,6 +216,7 @@ struct ConnectionStateMachine { .authenticating, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand: return self.errorHappened(.serverClosedConnection(underlying: nil)) @@ -244,6 +247,7 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand, .closing, .closed: @@ -271,6 +275,7 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand, .closing, .closed: @@ -291,6 +296,7 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand, .closing, .closed: @@ -315,6 +321,7 @@ struct ConnectionStateMachine { .authenticated, .readyForQuery, .extendedQuery, + .simpleQuery, .closeCommand, .closing, .closed: @@ -380,6 +387,12 @@ struct ConnectionStateMachine { self.state = .extendedQuery(query, connectionContext) return .wait + case .simpleQuery(let query, var connectionContext): + self.state = .modifying // avoid CoW + connectionContext.parameters[status.parameter] = status.value + self.state = .simpleQuery(query, connectionContext) + return .wait + case .closeCommand(let closeState, var connectionContext): self.state = .modifying // avoid CoW connectionContext.parameters[status.parameter] = status.value @@ -430,6 +443,15 @@ struct ConnectionStateMachine { self.state = .extendedQuery(extendedQueryState, connectionContext) return self.modify(with: action) + case .simpleQuery(var simpleQueryState, let connectionContext): + if simpleQueryState.isComplete { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.error(errorMessage))) + } + self.state = .modifying // avoid CoW + let action = simpleQueryState.errorReceived(errorMessage) + self.state = .simpleQuery(simpleQueryState, connectionContext) + return self.modify(with: action) + case .closing: // If the state machine is in state `.closing`, the connection shutdown was initiated // by the client. This means a `TERMINATE` message has already been sent and the @@ -464,6 +486,13 @@ struct ConnectionStateMachine { let action = queryState.errorHappened(error) return self.modify(with: action) } + case .simpleQuery(var queryState, _): + if queryState.isComplete { + return self.closeConnectionAndCleanup(error) + } else { + let action = queryState.errorHappened(error) + return self.modify(with: action) + } case .closeCommand(var closeState, _): if closeState.isComplete { return self.closeConnectionAndCleanup(error) @@ -497,6 +526,12 @@ struct ConnectionStateMachine { self.state = .extendedQuery(extendedQuery, connectionContext) return self.modify(with: action) + case .simpleQuery(var simpleQuery, let connectionContext): + self.state = .modifying // avoid CoW + let action = simpleQuery.noticeReceived(notice) + self.state = .simpleQuery(simpleQuery, connectionContext) + return self.modify(with: action) + default: return .wait } @@ -527,6 +562,15 @@ struct ConnectionStateMachine { connectionContext.transactionState = transactionState + self.state = .readyForQuery(connectionContext) + return self.executeNextQueryFromQueue() + case .simpleQuery(let simpleQuery, var connectionContext): + guard simpleQuery.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.readyForQuery(transactionState))) + } + + connectionContext.transactionState = transactionState + self.state = .readyForQuery(connectionContext) return self.executeNextQueryFromQueue() case .closeCommand(let closeStateMachine, var connectionContext): @@ -559,6 +603,7 @@ struct ConnectionStateMachine { .authenticating, .closeCommand, .extendedQuery, + .simpleQuery, .sslNegotiated, .sslHandlerAdded, .sslRequestSent, @@ -586,11 +631,13 @@ struct ConnectionStateMachine { switch task { case .extendedQuery(let queryContext): switch queryContext.query { - case .executeStatement(_, let promise), .unnamed(_, let promise), .simpleQuery(_, let promise): + case .executeStatement(_, let promise), .unnamed(_, let promise): return .failQuery(promise, with: psqlErrror, cleanupContext: nil) case .prepareStatement(_, _, _, let promise): return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } + case .simpleQuery(let queryContext): + return .failQuery(queryContext.promise, with: psqlErrror, cleanupContext: nil) case .closeCommand(let closeContext): return .failClose(closeContext, with: psqlErrror, cleanupContext: nil) } @@ -616,7 +663,13 @@ struct ConnectionStateMachine { let action = extendedQuery.channelReadComplete() self.state = .extendedQuery(extendedQuery, connectionContext) return self.modify(with: action) - + + case .simpleQuery(var simpleQuery, let connectionContext): + self.state = .modifying // avoid CoW + let action = simpleQuery.channelReadComplete() + self.state = .simpleQuery(simpleQuery, connectionContext) + return self.modify(with: action) + case .modifying: preconditionFailure("Invalid state") } @@ -644,6 +697,12 @@ struct ConnectionStateMachine { self.state = .extendedQuery(extendedQuery, connectionContext) return self.modify(with: action) + case .simpleQuery(var simpleQuery, let connectionContext): + self.state = .modifying // avoid CoW + let action = simpleQuery.readEventCaught() + self.state = .simpleQuery(simpleQuery, connectionContext) + return self.modify(with: action) + case .closeCommand(var closeState, let connectionContext): self.state = .modifying // avoid CoW let action = closeState.readEventCaught() @@ -709,6 +768,12 @@ struct ConnectionStateMachine { self.state = .extendedQuery(queryState, connectionContext) return self.modify(with: action) + case .simpleQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.rowDescriptionReceived(description) + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.rowDescription(description))) } @@ -743,60 +808,100 @@ struct ConnectionStateMachine { } mutating func commandCompletedReceived(_ commandTag: String) -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.commandCompletedReceived(commandTag) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + case .simpleQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.commandCompletedReceived(commandTag) + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.commandComplete(commandTag))) } - - self.state = .modifying // avoid CoW - let action = queryState.commandCompletedReceived(commandTag) - self.state = .extendedQuery(queryState, connectionContext) - return self.modify(with: action) } mutating func emptyQueryResponseReceived() -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.emptyQueryResponseReceived() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + case .simpleQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.emptyQueryResponseReceived() + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) } - - self.state = .modifying // avoid CoW - let action = queryState.emptyQueryResponseReceived() - self.state = .extendedQuery(queryState, connectionContext) - return self.modify(with: action) } mutating func dataRowReceived(_ dataRow: DataRow) -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.dataRowReceived(dataRow) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + case .simpleQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.dataRowReceived(dataRow) + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + + default: return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.dataRow(dataRow))) } - - self.state = .modifying // avoid CoW - let action = queryState.dataRowReceived(dataRow) - self.state = .extendedQuery(queryState, connectionContext) - return self.modify(with: action) } // MARK: Consumer mutating func cancelQueryStream() -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + switch self.state { + case .extendedQuery(var queryState, let connectionContext): + self.state = .modifying // avoid CoW + let action = queryState.cancel() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + case .simpleQuery(var queryState, let connectionContext): + self.state = .modifying // avoid CoW + let action = queryState.cancel() + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + + default: preconditionFailure("Tried to cancel stream without active query") } - - self.state = .modifying // avoid CoW - let action = queryState.cancel() - self.state = .extendedQuery(queryState, connectionContext) - return self.modify(with: action) } mutating func requestQueryRows() -> ConnectionAction { - guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + switch self.state { + case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.requestQueryRows() + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + + case .simpleQuery(var queryState, let connectionContext) where !queryState.isComplete: + self.state = .modifying // avoid CoW + let action = queryState.requestQueryRows() + self.state = .simpleQuery(queryState, connectionContext) + return self.modify(with: action) + + default: preconditionFailure("Tried to consume next row, without active query") } - - self.state = .modifying // avoid CoW - let action = queryState.requestQueryRows() - self.state = .extendedQuery(queryState, connectionContext) - return self.modify(with: action) } // MARK: - Private Methods - @@ -856,7 +961,6 @@ struct ConnectionStateMachine { case .sendParseDescribeBindExecuteSync, .sendParseDescribeSync, .sendBindExecuteSync, - .sendQuery, .succeedQuery, .succeedPreparedStatementCreation, .forwardRows, @@ -878,6 +982,36 @@ struct ConnectionStateMachine { return .failPreparedStatementCreation(promise, with: error, cleanupContext: cleanupContext) } + case .simpleQuery(var queryStateMachine, _): + let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) + + if queryStateMachine.isComplete { + // in case the query state machine is complete all necessary actions have already + // been forwarded to the consumer. We can close and cleanup without caring about the + // substate machine. + return .closeConnectionAndCleanup(cleanupContext) + } + + let action = queryStateMachine.errorHappened(error) + switch action { + case .sendQuery, + .succeedQuery, + .forwardRows, + .forwardStreamComplete, + .wait, + .read: + preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)") + + case .evaluateErrorAtConnectionLevel: + return .closeConnectionAndCleanup(cleanupContext) + + case .failQuery(let queryContext, with: let error): + return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) + + case .forwardStreamError(let error, let read): + return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + } + case .closeCommand(var closeStateMachine, _): let cleanupContext = self.setErrorAndCreateCleanupContext(error, closePromise: closePromise) @@ -943,6 +1077,13 @@ struct ConnectionStateMachine { self.state = .extendedQuery(extendedQuery, connectionContext) return self.modify(with: action) + case .simpleQuery(let queryContext): + self.state = .modifying // avoid CoW + var simpleQuery = SimpleQueryStateMachine(queryContext: queryContext) + let action = simpleQuery.start() + self.state = .simpleQuery(simpleQuery, connectionContext) + return self.modify(with: action) + case .closeCommand(let closeContext): self.state = .modifying // avoid CoW var closeStateMachine = CloseStateMachine(closeContext: closeContext) @@ -1037,8 +1178,6 @@ extension ConnectionStateMachine { return .sendParseDescribeBindExecuteSync(query) case .sendBindExecuteSync(let executeStatement): return .sendBindExecuteSync(executeStatement) - case .sendQuery(let query): - return .sendQuery(query) case .failQuery(let requestContext, with: let error): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) @@ -1072,6 +1211,37 @@ extension ConnectionStateMachine { } } +extension ConnectionStateMachine { + mutating func modify(with action: SimpleQueryStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { + switch action { + case .sendQuery(let query): + return .sendQuery(query) + case .failQuery(let requestContext, with: let error): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) + case .succeedQuery(let requestContext, with: let result): + return .succeedQuery(requestContext, with: result) + case .forwardRows(let buffer): + return .forwardRows(buffer) + case .forwardStreamComplete(let buffer, let commandTag): + return .forwardStreamComplete(buffer, commandTag: commandTag) + case .forwardStreamError(let error, let read): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) + + case .evaluateErrorAtConnectionLevel(let error): + if let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) { + return .closeConnectionAndCleanup(cleanupContext) + } + return .wait + case .read: + return .read + case .wait: + return .wait + } + } +} + extension ConnectionStateMachine { mutating func modify(with action: AuthenticationStateMachine.Action) -> ConnectionStateMachine.ConnectionAction { switch action { @@ -1182,6 +1352,8 @@ extension ConnectionStateMachine.State: CustomDebugStringConvertible { return ".readyForQuery(connectionContext: \(String(reflecting: connectionContext)))" case .extendedQuery(let subStateMachine, let connectionContext): return ".extendedQuery(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" + case .simpleQuery(let subStateMachine, let connectionContext): + return ".simpleQuery(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" case .closeCommand(let subStateMachine, let connectionContext): return ".closeCommand(\(String(reflecting: subStateMachine)), connectionContext: \(String(reflecting: connectionContext)))" case .closing: diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 853d3abc..087a6c24 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -29,8 +29,7 @@ struct ExtendedQueryStateMachine { case sendParseDescribeBindExecuteSync(PostgresQuery) case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType]) case sendBindExecuteSync(PSQLExecuteStatement) - case sendQuery(String) - + // --- general actions case failQuery(EventLoopPromise, with: PSQLError) case succeedQuery(EventLoopPromise, with: QueryResult) @@ -86,12 +85,6 @@ struct ExtendedQueryStateMachine { state = .messagesSent(queryContext) return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes) } - - case .simpleQuery(let query, _): - return self.avoidingStateMachineCoW { state -> Action in - state = .messagesSent(queryContext) - return .sendQuery(query) - } } } @@ -112,7 +105,7 @@ struct ExtendedQueryStateMachine { self.isCancelled = true switch queryContext.query { - case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise), .simpleQuery(_, let eventLoopPromise): + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: .queryCancelled) case .prepareStatement(_, _, _, let eventLoopPromise): @@ -178,19 +171,11 @@ struct ExtendedQueryStateMachine { state = .noDataMessageReceived(queryContext) return .succeedPreparedStatementCreation(promise, with: nil) } - - case .simpleQuery: - return self.setAndFireError(.unexpectedBackendMessage(.noData)) } } mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action { - let queryContext: ExtendedQueryContext - switch self.state { - case .messagesSent(let extendedQueryContext), - .parameterDescriptionReceived(let extendedQueryContext): - queryContext = extendedQueryContext - default: + guard case .parameterDescriptionReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) } @@ -213,7 +198,7 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed, .executeStatement, .simpleQuery: + case .unnamed, .executeStatement: return .wait case .prepareStatement(_, _, _, let eventLoopPromise): @@ -234,9 +219,6 @@ struct ExtendedQueryStateMachine { case .prepareStatement: return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete)) - - case .simpleQuery: - return self.setAndFireError(.unexpectedBackendMessage(.bindComplete)) } case .noDataMessageReceived(let queryContext): @@ -276,40 +258,20 @@ struct ExtendedQueryStateMachine { return .wait } - case .rowDescriptionReceived(let queryContext, let columns): - switch queryContext.query { - case .simpleQuery(_, let eventLoopPromise): - // When receiving a data row, we must ensure that the data row column count - // matches the previously received row description column count. - guard dataRow.columnCount == columns.count else { - return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) - } - - return self.avoidingStateMachineCoW { state -> Action in - var demandStateMachine = RowStreamStateMachine() - demandStateMachine.receivedRow(dataRow) - state = .streaming(columns, demandStateMachine) - let result = QueryResult(value: .rowDescription(columns), logger: queryContext.logger) - return .succeedQuery(eventLoopPromise, with: result) - } - - case .unnamed, .executeStatement, .prepareStatement: - return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) - } - case .drain(let columns): guard dataRow.columnCount == columns.count else { return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) } // we ignore all rows and wait for readyForQuery return .wait - + case .initialized, .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, .emptyQueryResponseReceived, + .rowDescriptionReceived, .bindCompleteReceived, .commandComplete, .error: @@ -330,36 +292,10 @@ struct ExtendedQueryStateMachine { return .succeedQuery(eventLoopPromise, with: result) } - case .prepareStatement, .simpleQuery: + case .prepareStatement: preconditionFailure("Invalid state: \(self.state)") } - - case .messagesSent(let context): - switch context.query { - case .simpleQuery(_, let eventLoopGroup): - return self.avoidingStateMachineCoW { state -> Action in - state = .commandComplete(commandTag: commandTag) - let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) - return .succeedQuery(eventLoopGroup, with: result) - } - - case .unnamed, .executeStatement, .prepareStatement: - return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) - } - - case .rowDescriptionReceived(let context, _): - switch context.query { - case .simpleQuery(_, let eventLoopPromise): - return self.avoidingStateMachineCoW { state -> Action in - state = .commandComplete(commandTag: commandTag) - let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) - return .succeedQuery(eventLoopPromise, with: result) - } - - case .unnamed, .executeStatement, .prepareStatement: - return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) - } - + case .streaming(_, var demandStateMachine): return self.avoidingStateMachineCoW { state -> Action in state = .commandComplete(commandTag: commandTag) @@ -370,12 +306,14 @@ struct ExtendedQueryStateMachine { precondition(self.isCancelled) self.state = .commandComplete(commandTag: commandTag) return .wait - + case .initialized, + .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, .emptyQueryResponseReceived, + .rowDescriptionReceived, .commandComplete, .error: return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) @@ -385,32 +323,20 @@ struct ExtendedQueryStateMachine { } mutating func emptyQueryResponseReceived() -> Action { - switch self.state { - case .bindCompleteReceived(let queryContext): - switch queryContext.query { - case .unnamed(_, let eventLoopPromise), - .executeStatement(_, let eventLoopPromise): - return self.avoidingStateMachineCoW { state -> Action in - state = .emptyQueryResponseReceived - let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) - return .succeedQuery(eventLoopPromise, with: result) - } + guard case .bindCompleteReceived(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } - case .prepareStatement, .simpleQuery: - return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) - } - case .messagesSent(let queryContext): - switch queryContext.query { - case .simpleQuery(_, let eventLoopPromise): - return self.avoidingStateMachineCoW { state -> Action in - state = .emptyQueryResponseReceived - let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) - return .succeedQuery(eventLoopPromise, with: result) - } - case .unnamed, .executeStatement, .prepareStatement: - return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), + .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .emptyQueryResponseReceived + let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) + return .succeedQuery(eventLoopPromise, with: result) } - default: + + case .prepareStatement(_, _, _, _): return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) } } @@ -571,7 +497,7 @@ struct ExtendedQueryStateMachine { return .evaluateErrorAtConnectionLevel(error) } else { switch context.query { - case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise), .simpleQuery(_, let eventLoopPromise): + case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: error) case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: error) @@ -610,7 +536,7 @@ struct ExtendedQueryStateMachine { switch context.query { case .prepareStatement: return true - case .unnamed, .executeStatement, .simpleQuery: + case .unnamed, .executeStatement: return false } diff --git a/Sources/PostgresNIO/New/Connection State Machine/SimpleQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/SimpleQueryStateMachine.swift new file mode 100644 index 00000000..f0340679 --- /dev/null +++ b/Sources/PostgresNIO/New/Connection State Machine/SimpleQueryStateMachine.swift @@ -0,0 +1,441 @@ +import NIOCore + +struct SimpleQueryStateMachine { + + private enum State { + case initialized(SimpleQueryContext) + case messagesSent(SimpleQueryContext) + + case rowDescriptionReceived(SimpleQueryContext, [RowDescription.Column]) + case emptyQueryResponseReceived + + case streaming([RowDescription.Column], RowStreamStateMachine) + /// Indicates that the current query was cancelled and we want to drain rows from the connection ASAP + case drain([RowDescription.Column]) + + case commandComplete(commandTag: String) + case error(PSQLError) + + case modifying + } + + enum Action { + case sendQuery(String) + + // --- general actions + case failQuery(EventLoopPromise, with: PSQLError) + case succeedQuery(EventLoopPromise, with: QueryResult) + + case evaluateErrorAtConnectionLevel(PSQLError) + + // --- streaming actions + // actions if query has requested next row but we are waiting for backend + case forwardRows([DataRow]) + case forwardStreamComplete([DataRow], commandTag: String) + case forwardStreamError(PSQLError, read: Bool) + + case read + case wait + } + + private var state: State + private var isCancelled: Bool + + init(queryContext: SimpleQueryContext) { + self.isCancelled = false + self.state = .initialized(queryContext) + } + + mutating func start() -> Action { + guard case .initialized(let queryContext) = self.state else { + preconditionFailure("Start should only be called, if the query has been initialized") + } + + return self.avoidingStateMachineCoW { state -> Action in + state = .messagesSent(queryContext) + return .sendQuery(queryContext.query) + } + } + + mutating func cancel() -> Action { + switch self.state { + case .initialized: + preconditionFailure("Start must be called immediatly after the query was created") + + case .messagesSent(let queryContext): + guard !self.isCancelled else { + return .wait + } + + self.isCancelled = true + return .failQuery(queryContext.promise, with: .queryCancelled) + + case .rowDescriptionReceived(let queryContext, let columns): + guard !self.isCancelled else { + return .wait + } + + self.isCancelled = true + self.state = .drain(columns) + return .failQuery(queryContext.promise, with: .queryCancelled) + + case .streaming(let columns, var streamStateMachine): + precondition(!self.isCancelled) + self.isCancelled = true + self.state = .drain(columns) + switch streamStateMachine.fail() { + case .wait: + return .forwardStreamError(.queryCancelled, read: false) + case .read: + return .forwardStreamError(.queryCancelled, read: true) + } + + case .commandComplete, .emptyQueryResponseReceived, .error, .drain: + // the stream has already finished. + return .wait + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } + + mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action { + let queryContext: SimpleQueryContext + switch self.state { + case .messagesSent(let simpleQueryContext): + queryContext = simpleQueryContext + default: + return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) + } + + // In Postgres extended queries we always request the response rows to be returned in + // `.binary` format. + // However, this is a simple query and almost all responses will be in text format anyway. + let columns = rowDescription.columns.map { column -> RowDescription.Column in + var column = column + // FIXME: .binary is not valid in a simple-query + column.format = .binary + return column + } + + guard !self.isCancelled else { + self.state = .drain(rowDescription.columns) + return .failQuery(queryContext.promise, with: .queryCancelled) + } + + self.avoidingStateMachineCoW { state in + state = .rowDescriptionReceived(queryContext, columns) + } + + return .wait + } + + mutating func dataRowReceived(_ dataRow: DataRow) -> Action { + switch self.state { + case .rowDescriptionReceived(let queryContext, let columns): + // When receiving a data row, we must ensure that the data row column count + // matches the previously received row description column count. + guard dataRow.columnCount == columns.count else { + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + } + + return self.avoidingStateMachineCoW { state -> Action in + var demandStateMachine = RowStreamStateMachine() + demandStateMachine.receivedRow(dataRow) + state = .streaming(columns, demandStateMachine) + let result = QueryResult(value: .rowDescription(columns), logger: queryContext.logger) + return .succeedQuery(queryContext.promise, with: result) + } + + case .streaming(let columns, var demandStateMachine): + // When receiving a data row, we must ensure that the data row column count + // matches the previously received row description column count. + guard dataRow.columnCount == columns.count else { + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + } + + return self.avoidingStateMachineCoW { state -> Action in + demandStateMachine.receivedRow(dataRow) + state = .streaming(columns, demandStateMachine) + return .wait + } + + case .drain(let columns): + guard dataRow.columnCount == columns.count else { + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + } + // we ignore all rows and wait for readyForQuery + return .wait + + case .initialized, + .messagesSent, + .emptyQueryResponseReceived, + .commandComplete, + .error: + return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func commandCompletedReceived(_ commandTag: String) -> Action { + switch self.state { + case .messagesSent(let context): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) + return .succeedQuery(context.promise, with: result) + } + + case .rowDescriptionReceived(let context, _): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) + return .succeedQuery(context.promise, with: result) + } + + case .streaming(_, var demandStateMachine): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + return .forwardStreamComplete(demandStateMachine.end(), commandTag: commandTag) + } + + case .drain: + precondition(self.isCancelled) + self.state = .commandComplete(commandTag: commandTag) + return .wait + + case .initialized, + .emptyQueryResponseReceived, + .commandComplete, + .error: + return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func emptyQueryResponseReceived() -> Action { + switch self.state { + case .messagesSent(let queryContext): + return self.avoidingStateMachineCoW { state -> Action in + state = .emptyQueryResponseReceived + let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) + return .succeedQuery(queryContext.promise, with: result) + } + + default: + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } + } + + mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { + let error = PSQLError.server(errorMessage) + switch self.state { + case .initialized: + return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) + case .messagesSent: + return self.setAndFireError(error) + case .rowDescriptionReceived: + return self.setAndFireError(error) + case .streaming, .drain: + return self.setAndFireError(error) + case .commandComplete, .emptyQueryResponseReceived: + return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) + case .error: + preconditionFailure(""" + This state must not be reached. If the query `.isComplete`, the + ConnectionStateMachine must not send any further events to the substate machine. + """) + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func noticeReceived(_ notice: PostgresBackendMessage.NoticeResponse) -> Action { + //self.queryObject.noticeReceived(notice) + return .wait + } + + mutating func errorHappened(_ error: PSQLError) -> Action { + return self.setAndFireError(error) + } + + // MARK: Customer Actions + + mutating func requestQueryRows() -> Action { + switch self.state { + case .streaming(let columns, var demandStateMachine): + return self.avoidingStateMachineCoW { state -> Action in + let action = demandStateMachine.demandMoreResponseBodyParts() + state = .streaming(columns, demandStateMachine) + switch action { + case .read: + return .read + case .wait: + return .wait + } + } + + case .drain: + return .wait + + case .initialized, + .messagesSent, + .emptyQueryResponseReceived, + .rowDescriptionReceived: + preconditionFailure("Requested to consume next row without anything going on.") + + case .commandComplete, .error: + preconditionFailure("The stream is already closed or in a failure state; rows can not be consumed at this time.") + case .modifying: + preconditionFailure("Invalid state") + } + } + + // MARK: Channel actions + + mutating func channelReadComplete() -> Action { + switch self.state { + case .initialized, + .commandComplete, + .drain, + .error, + .messagesSent, + .emptyQueryResponseReceived, + .rowDescriptionReceived: + return .wait + + case .streaming(let columns, var demandStateMachine): + return self.avoidingStateMachineCoW { state -> Action in + let rows = demandStateMachine.channelReadComplete() + state = .streaming(columns, demandStateMachine) + switch rows { + case .some(let rows): + return .forwardRows(rows) + case .none: + return .wait + } + } + + case .modifying: + preconditionFailure("Invalid state") + } + } + + mutating func readEventCaught() -> Action { + switch self.state { + case .messagesSent, + .rowDescriptionReceived: + return .read + case .streaming(let columns, var demandStateMachine): + precondition(!self.isCancelled) + return self.avoidingStateMachineCoW { state -> Action in + let action = demandStateMachine.read() + state = .streaming(columns, demandStateMachine) + switch action { + case .wait: + return .wait + case .read: + return .read + } + } + case .initialized, + .commandComplete, + .emptyQueryResponseReceived, + .drain, + .error: + // we already have the complete stream received, now we are waiting for a + // `readyForQuery` package. To receive this we need to read! + return .read + case .modifying: + preconditionFailure("Invalid state") + } + } + + // MARK: Private Methods + + private mutating func setAndFireError(_ error: PSQLError) -> Action { + switch self.state { + case .initialized(let context), + .messagesSent(let context), + .rowDescriptionReceived(let context, _): + self.state = .error(error) + if self.isCancelled { + return .evaluateErrorAtConnectionLevel(error) + } else { + return .failQuery(context.promise, with: error) + } + + case .drain: + self.state = .error(error) + return .evaluateErrorAtConnectionLevel(error) + + case .streaming(_, var streamStateMachine): + self.state = .error(error) + switch streamStateMachine.fail() { + case .wait: + return .forwardStreamError(error, read: false) + case .read: + return .forwardStreamError(error, read: true) + } + + case .commandComplete, .emptyQueryResponseReceived, .error: + preconditionFailure(""" + This state must not be reached. If the query `.isComplete`, the + ConnectionStateMachine must not send any further events to the substate machine. + """) + case .modifying: + preconditionFailure("Invalid state") + } + } + + var isComplete: Bool { + switch self.state { + case .commandComplete, .emptyQueryResponseReceived, .error: + return true + + case .rowDescriptionReceived, .initialized, .messagesSent, .streaming, .drain: + return false + + case .modifying: + preconditionFailure("Invalid state: \(self.state)") + } + } +} + +extension SimpleQueryStateMachine { + /// So, uh...this function needs some explaining. + /// + /// While the state machine logic above is great, there is a downside to having all of the state machine data in + /// associated data on enumerations: any modification of that data will trigger copy on write for heap-allocated + /// data. That means that for _every operation on the state machine_ we will CoW our underlying state, which is + /// not good. + /// + /// The way we can avoid this is by using this helper function. It will temporarily set state to a value with no + /// associated data, before attempting the body of the function. It will also verify that the state machine never + /// remains in this bad state. + /// + /// A key note here is that all callers must ensure that they return to a good state before they exit. + /// + /// Sadly, because it's generic and has a closure, we need to force it to be inlined at all call sites, which is + /// not ideal. + @inline(__always) + private mutating func avoidingStateMachineCoW(_ body: (inout State) -> ReturnType) -> ReturnType { + self.state = .modifying + defer { + assert(!self.isModifying) + } + + return body(&self.state) + } + + private var isModifying: Bool { + if case .modifying = self.state { + return true + } else { + return false + } + } +} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index c0923785..62028d6b 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -3,6 +3,7 @@ import NIOCore enum HandlerTask { case extendedQuery(ExtendedQueryContext) + case simpleQuery(SimpleQueryContext) case closeCommand(CloseCommandContext) case startListening(NotificationListener) case cancelListening(String, Int) @@ -11,6 +12,7 @@ enum HandlerTask { enum PSQLTask { case extendedQuery(ExtendedQueryContext) + case simpleQuery(SimpleQueryContext) case closeCommand(CloseCommandContext) func failWithError(_ error: PSQLError) { @@ -23,23 +25,22 @@ enum PSQLTask { eventLoopPromise.fail(error) case .prepareStatement(_, _, _, let eventLoopPromise): eventLoopPromise.fail(error) - case .simpleQuery(_, let eventLoopPromise): - eventLoopPromise.fail(error) } + case .simpleQuery(let simpleQueryContext): + simpleQueryContext.promise.fail(error) + case .closeCommand(let closeCommandContext): closeCommandContext.promise.fail(error) } } } -// FIXME: Either rename all these `ExtendedQuery`s to just like `Query` or pull out `simpleQuery` final class ExtendedQueryContext { enum Query { case unnamed(PostgresQuery, EventLoopPromise) case executeStatement(PSQLExecuteStatement, EventLoopPromise) case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise) - case simpleQuery(String, EventLoopPromise) } let query: Query @@ -73,15 +74,6 @@ final class ExtendedQueryContext { self.query = .prepareStatement(name: name, query: query, bindingDataTypes: bindingDataTypes, promise) self.logger = logger } - - init( - simpleQuery: String, - logger: Logger, - promise: EventLoopPromise - ) { - self.query = .simpleQuery(simpleQuery, promise) - self.logger = logger - } } final class PreparedStatementContext: Sendable { @@ -113,6 +105,22 @@ final class PreparedStatementContext: Sendable { } } +final class SimpleQueryContext { + let query: String + let logger: Logger + let promise: EventLoopPromise + + init( + query: String, + logger: Logger, + promise: EventLoopPromise + ) { + self.query = query + self.logger = logger + self.promise = promise + } +} + final class CloseCommandContext { let target: CloseTarget let logger: Logger diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 7e2203f9..8d31c852 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -205,6 +205,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { psqlTask = .closeCommand(command) case .extendedQuery(let query): psqlTask = .extendedQuery(query) + case .simpleQuery(let query): + psqlTask = .simpleQuery(query) case .startListening(let listener): switch self.listenState.startListening(listener) { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 0da20ed2..ae484acc 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -78,106 +78,6 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } - func testExtendedQueryWithSimpleQueryWithoutDataRowsHappyPath() { - var state = ConnectionStateMachine.readyForQuery() - - let logger = Logger.psqlTest - let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) - promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. - let query = "DELETE FROM table WHERE id=1" - let queryContext = ExtendedQueryContext(simpleQuery: query, logger: logger, promise: promise) - - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendQuery(query)) - XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) - XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) - } - - func testExtendedQueryWithSimpleQueryWithRowDescriptionWithoutDataRowsHappyPath() { - var state = ConnectionStateMachine.readyForQuery() - - let logger = Logger.psqlTest - let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) - promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. - let nonExistentOID = 371280378 - let query = "SELECT * FROM pg_class WHERE oid = \(nonExistentOID)" - let queryContext = ExtendedQueryContext(simpleQuery: query, logger: logger, promise: promise) - - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendQuery(query)) - - let input: [RowDescription.Column] = [ - .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) - ] - XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) - XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) - XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) - } - - func testExtendedQueryWithSimpleQueryWithDataRowsHappyPath() { - var state = ConnectionStateMachine.readyForQuery() - - let logger = Logger.psqlTest - let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) - promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. - let query = "SELECT version()" - let queryContext = ExtendedQueryContext(simpleQuery: query, logger: logger, promise: promise) - - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendQuery(query)) - - // We need to ensure that even though the row description from the wire says that we - // will receive data in `.text` format, we will actually receive it in binary format, - // since we requested it in binary with our bind message. - let input: [RowDescription.Column] = [ - .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) - ] - let expected: [RowDescription.Column] = input.map { - .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, - dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) - } - - XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) - let row1: DataRow = [ByteBuffer(string: "test1")] - let result = QueryResult(value: .rowDescription(expected), logger: queryContext.logger) - XCTAssertEqual(state.dataRowReceived(row1), .succeedQuery(promise, with: result)) - XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) - XCTAssertEqual(state.readEventCaught(), .wait) - XCTAssertEqual(state.requestQueryRows(), .read) - - let row2: DataRow = [ByteBuffer(string: "test2")] - let row3: DataRow = [ByteBuffer(string: "test3")] - let row4: DataRow = [ByteBuffer(string: "test4")] - XCTAssertEqual(state.dataRowReceived(row2), .wait) - XCTAssertEqual(state.dataRowReceived(row3), .wait) - XCTAssertEqual(state.dataRowReceived(row4), .wait) - XCTAssertEqual(state.channelReadComplete(), .forwardRows([row2, row3, row4])) - XCTAssertEqual(state.requestQueryRows(), .wait) - XCTAssertEqual(state.readEventCaught(), .read) - - XCTAssertEqual(state.channelReadComplete(), .wait) - XCTAssertEqual(state.readEventCaught(), .read) - - let row5: DataRow = [ByteBuffer(string: "test5")] - let row6: DataRow = [ByteBuffer(string: "test6")] - XCTAssertEqual(state.dataRowReceived(row5), .wait) - XCTAssertEqual(state.dataRowReceived(row6), .wait) - - XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2")) - XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) - } - - func testExtendedQueryWithSimpleQueryWithNoQuery() { - var state = ConnectionStateMachine.readyForQuery() - - let logger = Logger.psqlTest - let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) - promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. - let query = "-- some comments" - let queryContext = ExtendedQueryContext(simpleQuery: query, logger: logger, promise: promise) - - XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendQuery(query)) - XCTAssertEqual(state.emptyQueryResponseReceived(), .succeedQuery(promise, with: .init(value: .noRows(.emptyResponse), logger: logger))) - XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) - } - func testExtendedQueryWithNoQuery() { var state = ConnectionStateMachine.readyForQuery() diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/SimpleQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/SimpleQueryStateMachineTests.swift new file mode 100644 index 00000000..0f51afd9 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Connection State Machine/SimpleQueryStateMachineTests.swift @@ -0,0 +1,278 @@ +import XCTest +import NIOCore +import NIOEmbedded +import Logging +@testable import PostgresNIO + +class SimpleQueryStateMachineTests: XCTestCase { + + func testQueryWithSimpleQueryWithoutDataRowsHappyPath() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "DELETE FROM table WHERE id=1" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryWithSimpleQueryWithRowDescriptionWithoutDataRowsHappyPath() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let nonExistentOID = 371280378 + let query = "SELECT * FROM pg_class WHERE oid = \(nonExistentOID)" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryWithSimpleQueryWithDataRowsHappyPath() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + let expected: [RowDescription.Column] = input.map { + .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, + dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) + } + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + let row1: DataRow = [ByteBuffer(string: "test1")] + let result = QueryResult(value: .rowDescription(expected), logger: queryContext.logger) + XCTAssertEqual(state.dataRowReceived(row1), .succeedQuery(promise, with: result)) + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.requestQueryRows(), .read) + + let row2: DataRow = [ByteBuffer(string: "test2")] + let row3: DataRow = [ByteBuffer(string: "test3")] + let row4: DataRow = [ByteBuffer(string: "test4")] + XCTAssertEqual(state.dataRowReceived(row2), .wait) + XCTAssertEqual(state.dataRowReceived(row3), .wait) + XCTAssertEqual(state.dataRowReceived(row4), .wait) + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row2, row3, row4])) + XCTAssertEqual(state.requestQueryRows(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + let row5: DataRow = [ByteBuffer(string: "test5")] + let row6: DataRow = [ByteBuffer(string: "test6")] + XCTAssertEqual(state.dataRowReceived(row5), .wait) + XCTAssertEqual(state.dataRowReceived(row6), .wait) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2")) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryWithSimpleQueryWithNoQuery() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "-- some comments" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + XCTAssertEqual(state.emptyQueryResponseReceived(), .succeedQuery(promise, with: .init(value: .noRows(.emptyResponse), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testReceiveTotallyUnexpectedMessageInQuery() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let nonExistentOID = 371280378 + let query = "SELECT * FROM pg_class WHERE oid = \(nonExistentOID)" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + + let psqlError = PSQLError.unexpectedBackendMessage(.authentication(.ok)) + XCTAssertEqual(state.authenticationMessageReceived(.ok), + .failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) + } + + func testQueryIsCancelledImmediatly() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: nil)) + + // The query was cancelled but it also ended anyways, so we accept that the query has succeeded + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .succeedQuery(promise, with: .init(value: .noRows(.tag("SELECT 2")), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryIsCancelledWithReadPending() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: nil)) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test1")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test2")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test3")]), .wait) + XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test4")]), .wait) + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .wait) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testCancelQueryAfterServerError() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + // We need to ensure that even though the row description from the wire says that we + // will receive data in `.text` format, we will actually receive it in binary format, + // since we requested it in binary with our bind message. + let input: [RowDescription.Column] = [ + .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) + ] + let expected: [RowDescription.Column] = input.map { + .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, + dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) + } + + XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) + let result = QueryResult(value: .rowDescription(expected), logger: queryContext.logger) + let row1: DataRow = [ByteBuffer(string: "test1")] + XCTAssertEqual(state.dataRowReceived(row1), .succeedQuery(promise, with: result)) + + let dataRows2: [DataRow] = [ + [ByteBuffer(string: "test2")], + [ByteBuffer(string: "test3")], + [ByteBuffer(string: "test4")] + ] + for row in dataRows2 { + XCTAssertEqual(state.dataRowReceived(row), .wait) + } + + XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1] + dataRows2)) + XCTAssertEqual(state.readEventCaught(), .wait) + XCTAssertEqual(state.requestQueryRows(), .read) + let dataRows3: [DataRow] = [ + [ByteBuffer(string: "test5")], + [ByteBuffer(string: "test6")], + [ByteBuffer(string: "test7")] + ] + for row in dataRows3 { + XCTAssertEqual(state.dataRowReceived(row), .wait) + } + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual(state.errorReceived(serverError), .forwardStreamError(.server(serverError), read: false, cleanupContext: .none)) + + XCTAssertEqual(state.channelReadComplete(), .wait) + XCTAssertEqual(state.readEventCaught(), .read) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryErrorDoesNotKillConnection() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual( + state.errorReceived(serverError), .failQuery(promise, with: .server(serverError), cleanupContext: .none) + ) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + + func testQueryErrorAfterCancelDoesNotKillConnection() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query = "SELECT version()" + let queryContext = SimpleQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .simpleQuery(queryContext)), .sendQuery(query)) + XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none)) + + let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) + XCTAssertEqual(state.errorReceived(serverError), .wait) + + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } +}