Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 86 additions & 9 deletions Sources/AWSLambdaRuntime/HTTPClient/LambdaRuntimeClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
// being fully closed before we can return from it.
private var closingConnections: [any Channel] = []

// Track channels that are in the process of closing to handle race conditions
// where an old channel's closeFuture fires after a new connection is established
private var channelsBeingClosed: Set<ObjectIdentifier> = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between closingConnections and channelsBeingClosed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmmh, in my view, channelsBeingClosed provides early detection at the top of channelClosed() to identify old channels and handle them gracefully without touching the main state machine. closingConnections ensures we wait for all channels (both current and old) to fully close before completing shutdown.

But your question makes me think more about this (thank you), and maybe these can be merged to keep only one source of truth.

If we keep only closingConnections, we could write:

// Instead of:
if channelsBeingClosed.contains(channelID) { ... }

// We could do:
if self.closingConnections.contains(where: { $0 === channel }) { ... }

Let me test that

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adam-fowler Thank you for the question. I simplified and now use only closingConnections to keep track of closing connections. One single source of truth and a simpler mental model.


@inlinable
static func withRuntimeClient<Result>(
configuration: Configuration,
Expand Down Expand Up @@ -265,9 +269,52 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
}

private func channelClosed(_ channel: any Channel) {
let channelID = ObjectIdentifier(channel)

// Check if this is an old channel that we're already tracking as closed
// This handles the race condition where:
// 1. connectionWillClose() is called, marking the channel as closing
// 2. A new connection is established (connectionState = .connected with new channel)
// 3. The old channel's closeFuture fires (closingState might be .closed)
// 4. We receive channelClosed() for the OLD channel while NEW channel is connected
if channelsBeingClosed.contains(channelID) {
// This is an old channel that's finishing its close operation
channelsBeingClosed.remove(channelID)

// Also remove from closingConnections if present
if let index = self.closingConnections.firstIndex(where: { $0 === channel }) {
self.closingConnections.remove(at: index)
}

// If we're in closing state and all connections are now closed, complete the close
if case .closing(let continuation) = self.closingState,
self.closingConnections.isEmpty,
channelsBeingClosed.isEmpty
{
self.closingState = .closed
continuation.resume()
}

self.logger.trace(
"Old channel closed after new connection established",
metadata: ["channel": "\(channel)"]
)
return
}

switch (self.connectionState, self.closingState) {
case (_, .closed):
fatalError("Invalid state: \(self.connectionState), \(self.closingState)")
// This should not happen, but if it does, it means we're receiving a close
// notification for a channel after the runtime client has fully closed.
// Log it but don't crash - this could be a legitimate race condition.
self.logger.warning(
"Received channelClosed after closingState is .closed",
metadata: [
"channel": "\(channel)",
"connectionState": "\(self.connectionState)",
]
)
return

case (.disconnected, .notClosing):
if let index = self.closingConnections.firstIndex(where: { $0 === channel }) {
Expand All @@ -279,7 +326,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
self.closingConnections.remove(at: index)
}

if self.closingConnections.isEmpty {
if self.closingConnections.isEmpty && channelsBeingClosed.isEmpty {
self.closingState = .closed
continuation.resume()
}
Expand All @@ -293,18 +340,33 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
case (.connecting(let array), .closing(let continuation)):
self.connectionState = .disconnected
precondition(array.isEmpty, "If we are closing we should have failed all connection attempts already")
if self.closingConnections.isEmpty {
if self.closingConnections.isEmpty && channelsBeingClosed.isEmpty {
self.closingState = .closed
continuation.resume()
}

case (.connected, .notClosing):
self.connectionState = .disconnected
case (.connected(let currentChannel, _), .notClosing):
// Only transition to disconnected if this is the CURRENT channel closing
if currentChannel === channel {
self.connectionState = .disconnected
} else {
// This is an old channel closing - just track it
self.logger.trace(
"Old channel closing while new connection is active",
metadata: [
"closingChannel": "\(channel)",
"currentChannel": "\(currentChannel)",
]
)
}

case (.connected, .closing(let continuation)):
self.connectionState = .disconnected
case (.connected(let currentChannel, _), .closing(let continuation)):
// Only transition to disconnected if this is the CURRENT channel closing
if currentChannel === channel {
self.connectionState = .disconnected
}

if self.closingConnections.isEmpty {
if self.closingConnections.isEmpty && channelsBeingClosed.isEmpty {
self.closingState = .closed
continuation.resume()
}
Expand Down Expand Up @@ -369,7 +431,9 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
self.assumeIsolated { runtimeClient in
// close the channel
runtimeClient.channelClosed(channel)
runtimeClient.connectionState = .disconnected
// Note: Do NOT set connectionState = .disconnected here!
// The channelClosed() method handles state transitions properly,
// checking if this is the current channel or an old one.
}
}

Expand Down Expand Up @@ -412,6 +476,11 @@ extension LambdaRuntimeClient: LambdaChannelHandlerDelegate {

nonisolated func connectionWillClose(channel: any Channel) {
self.assumeIsolated { isolated in
let channelID = ObjectIdentifier(channel)

// Mark this channel as being closed to track it through the close lifecycle
isolated.channelsBeingClosed.insert(channelID)

switch isolated.connectionState {
case .disconnected:
// this case should never happen. But whatever
Expand All @@ -431,7 +500,15 @@ extension LambdaRuntimeClient: LambdaChannelHandlerDelegate {

case .connected(let stateChannel, _):
guard channel === stateChannel else {
// This is an old channel closing - add to tracking
isolated.closingConnections.append(channel)
isolated.logger.trace(
"Old channel will close while new connection is active",
metadata: [
"closingChannel": "\(channel)",
"currentChannel": "\(stateChannel)",
]
)
return
}

Expand Down