Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix leakage of promises #497

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct CloseStateMachine {
}

enum Action {
case sendCloseSync(CloseTarget)
case sendCloseSync(CloseTarget, promise: EventLoopPromise<Void>?)
case succeedClose(CloseCommandContext)
case failClose(CloseCommandContext, with: PSQLError)

Expand All @@ -24,14 +24,14 @@ struct CloseStateMachine {
self.state = .initialized(closeContext)
}

mutating func start() -> Action {
mutating func start(_ promise: EventLoopPromise<Void>?) -> Action {
guard case .initialized(let closeContext) = self.state else {
preconditionFailure("Start should only be called, if the query has been initialized")
}

self.state = .closeSyncSent(closeContext)

return .sendCloseSync(closeContext.target)
return .sendCloseSync(closeContext.target, promise: promise)
}

mutating func closeCompletedReceived() -> Action {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ struct ConnectionStateMachine {
// Connection Actions

// --- general actions
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendBindExecuteSync(PSQLExecuteStatement)
case sendParseDescribeBindExecuteSync(PostgresQuery, promise: EventLoopPromise<Void>?)
case sendBindExecuteSync(PSQLExecuteStatement, promise: EventLoopPromise<Void>?)
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError, cleanupContext: CleanUpContext?)
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)

Expand All @@ -97,12 +97,12 @@ struct ConnectionStateMachine {
case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?)

// Prepare statement actions
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType], promise: EventLoopPromise<Void>?)
case succeedPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: RowDescription?)
case failPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: PSQLError, cleanupContext: CleanUpContext?)

// Close actions
case sendCloseSync(CloseTarget)
case sendCloseSync(CloseTarget, promise: EventLoopPromise<Void>?)
case succeedClose(CloseCommandContext)
case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?)
}
Expand Down Expand Up @@ -234,7 +234,7 @@ struct ConnectionStateMachine {
}
self.state = .sslNegotiated
return .establishSSLConnection

case .initialized,
.sslNegotiated,
.sslHandlerAdded,
Expand All @@ -247,7 +247,7 @@ struct ConnectionStateMachine {
.closing,
.closed:
return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.sslSupported))

case .modifying:
preconditionFailure("Invalid state: \(self.state)")
}
Expand Down Expand Up @@ -583,14 +583,16 @@ struct ConnectionStateMachine {
}

switch task {
case .extendedQuery(let queryContext):
case .extendedQuery(let queryContext, let writePromise):
writePromise?.fail(psqlErrror) /// Use `cleanupContext` or not?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment

switch queryContext.query {
case .executeStatement(_, let promise), .unnamed(_, let promise):
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
case .prepareStatement(_, _, _, let promise):
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
}
case .closeCommand(let closeContext):
case .closeCommand(let closeContext, let writePromise):
writePromise?.fail(psqlErrror) /// Use `cleanupContext` or not?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment

return .failClose(closeContext, with: psqlErrror, cleanupContext: nil)
}
}
Expand Down Expand Up @@ -934,17 +936,17 @@ struct ConnectionStateMachine {
}

switch task {
case .extendedQuery(let queryContext):
case .extendedQuery(let queryContext, let promise):
self.state = .modifying // avoid CoW
var extendedQuery = ExtendedQueryStateMachine(queryContext: queryContext)
let action = extendedQuery.start()
let action = extendedQuery.start(promise)
self.state = .extendedQuery(extendedQuery, connectionContext)
return self.modify(with: action)

case .closeCommand(let closeContext):
case .closeCommand(let closeContext, let promise):
self.state = .modifying // avoid CoW
var closeStateMachine = CloseStateMachine(closeContext: closeContext)
let action = closeStateMachine.start()
let action = closeStateMachine.start(promise)
self.state = .closeCommand(closeStateMachine, connectionContext)
return self.modify(with: action)
}
Expand Down Expand Up @@ -1031,10 +1033,10 @@ extension ConnectionStateMachine {
extension ConnectionStateMachine {
mutating func modify(with action: ExtendedQueryStateMachine.Action) -> ConnectionStateMachine.ConnectionAction {
switch action {
case .sendParseDescribeBindExecuteSync(let query):
return .sendParseDescribeBindExecuteSync(query)
case .sendBindExecuteSync(let executeStatement):
return .sendBindExecuteSync(executeStatement)
case .sendParseDescribeBindExecuteSync(let query, let promise):
return .sendParseDescribeBindExecuteSync(query, promise: promise)
case .sendBindExecuteSync(let executeStatement, let promise):
return .sendBindExecuteSync(executeStatement, promise: promise)
case .failQuery(let requestContext, with: let error):
let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error)
return .failQuery(requestContext, with: error, cleanupContext: cleanupContext)
Expand All @@ -1057,8 +1059,8 @@ extension ConnectionStateMachine {
return .read
case .wait:
return .wait
case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes):
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes, let promise):
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes, promise: promise)
case .succeedPreparedStatementCreation(let promise, with: let rowDescription):
return .succeedPreparedStatementCreation(promise, with: rowDescription)
case .failPreparedStatementCreation(let promise, with: let error):
Expand Down Expand Up @@ -1094,8 +1096,8 @@ extension ConnectionStateMachine {
extension ConnectionStateMachine {
mutating func modify(with action: CloseStateMachine.Action) -> ConnectionStateMachine.ConnectionAction {
switch action {
case .sendCloseSync(let sendClose):
return .sendCloseSync(sendClose)
case .sendCloseSync(let sendClose, let promise):
return .sendCloseSync(sendClose, promise: promise)
case .succeedClose(let closeContext):
return .succeedClose(closeContext)
case .failClose(let closeContext, with: let error):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ struct ExtendedQueryStateMachine {
}

enum Action {
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
case sendBindExecuteSync(PSQLExecuteStatement)
case sendParseDescribeBindExecuteSync(PostgresQuery, promise: EventLoopPromise<Void>?)
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType], promise: EventLoopPromise<Void>?)
case sendBindExecuteSync(PSQLExecuteStatement, promise: EventLoopPromise<Void>?)

// --- general actions
case failQuery(EventLoopPromise<PSQLRowStream>, with: PSQLError)
case succeedQuery(EventLoopPromise<PSQLRowStream>, with: QueryResult)
Expand Down Expand Up @@ -57,7 +57,7 @@ struct ExtendedQueryStateMachine {
self.state = .initialized(queryContext)
}

mutating func start() -> Action {
mutating func start(_ promise: EventLoopPromise<Void>?) -> Action {
guard case .initialized(let queryContext) = self.state else {
preconditionFailure("Start should only be called, if the query has been initialized")
}
Expand All @@ -66,7 +66,7 @@ struct ExtendedQueryStateMachine {
case .unnamed(let query, _):
return self.avoidingStateMachineCoW { state -> Action in
state = .messagesSent(queryContext)
return .sendParseDescribeBindExecuteSync(query)
return .sendParseDescribeBindExecuteSync(query, promise: promise)
}

case .executeStatement(let prepared, _):
Expand All @@ -77,13 +77,14 @@ struct ExtendedQueryStateMachine {
case .none:
state = .noDataMessageReceived(queryContext)
}
return .sendBindExecuteSync(prepared)
return .sendBindExecuteSync(prepared, promise: promise)
}

/// Not my code, but this is ignoring the last argument which is a promise? is that fine?
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment

case .prepareStatement(let name, let query, let bindingDataTypes, _):
return self.avoidingStateMachineCoW { state -> Action in
state = .messagesSent(queryContext)
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes, promise: promise)
}
}
}
Expand Down
11 changes: 8 additions & 3 deletions Sources/PostgresNIO/New/NotificationListener.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ final class NotificationListener: @unchecked Sendable {
self.state = .closure(context, closure)
}

func startListeningSucceeded(handler: PostgresChannelHandler) {
func startListeningSucceeded(
handler: PostgresChannelHandler,
writePromise: EventLoopPromise<Void>?
) {
self.eventLoop.preconditionInEventLoop()
let handlerLoopBound = NIOLoopBound(handler, eventLoop: self.eventLoop)

Expand All @@ -56,7 +59,7 @@ final class NotificationListener: @unchecked Sendable {
switch reason {
case .cancelled:
eventLoop.execute {
handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID)
handlerLoopBound.value.cancelNotificationListener(channel: channel, id: listenerID, writePromise: nil)
}

case .finished:
Expand All @@ -70,12 +73,14 @@ final class NotificationListener: @unchecked Sendable {

let notificationSequence = PostgresNotificationSequence(base: stream)
checkedContinuation.resume(returning: notificationSequence)
writePromise?.succeed(())

case .streamListening, .done:
fatalError("Invalid state: \(self.state)")

case .closure:
break // ignore
writePromise?.succeed(())
// ignore
}
}

Expand Down
10 changes: 6 additions & 4 deletions Sources/PostgresNIO/New/PSQLTask.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ enum HandlerTask {
}

enum PSQLTask {
case extendedQuery(ExtendedQueryContext)
case closeCommand(CloseCommandContext)
case extendedQuery(ExtendedQueryContext, writePromise: EventLoopPromise<Void>?)
case closeCommand(CloseCommandContext, writePromise: EventLoopPromise<Void>?)

func failWithError(_ error: PSQLError) {
switch self {
case .extendedQuery(let extendedQueryContext):
case .extendedQuery(let extendedQueryContext, let writePromise):
switch extendedQueryContext.query {
case .unnamed(_, let eventLoopPromise):
eventLoopPromise.fail(error)
Expand All @@ -24,9 +24,11 @@ enum PSQLTask {
case .prepareStatement(_, _, _, let eventLoopPromise):
eventLoopPromise.fail(error)
}
writePromise?.fail(error)

case .closeCommand(let closeCommandContext):
case .closeCommand(let closeCommandContext, let writePromise):
closeCommandContext.promise.fail(error)
writePromise?.fail(error)
}
}
}
Expand Down
Loading
Loading