diff --git a/Sources/AWSLambdaRuntime/HTTPClient/LambdaRuntimeClient.swift b/Sources/AWSLambdaRuntime/HTTPClient/LambdaRuntimeClient.swift index 2a29a7821..f0be97413 100644 --- a/Sources/AWSLambdaRuntime/HTTPClient/LambdaRuntimeClient.swift +++ b/Sources/AWSLambdaRuntime/HTTPClient/LambdaRuntimeClient.swift @@ -265,9 +265,55 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { } private func channelClosed(_ channel: any Channel) { + // Check if this is an old channel that we're already tracking as closed + // This handles the race condition where: + // 1. connectionWillClose() is called, adding the channel to closingConnections + // 2. A new connection is established (connectionState = .connected with new channel) + // 3. The old channel's closeFuture fires (closingState might be .closed) + // 4. We receive channelClosed() for the OLD channel while NEW channel is connected + if self.closingConnections.contains(where: { $0 === channel }) { + // If this channel is still the currently connected channel, let the main + // state-handling logic below run instead of treating it as an old channel. + if case .connected(let stateChannel, _) = self.connectionState, channel === stateChannel { + // Remove from tracking and fall through to the main switch statement + if let index = self.closingConnections.firstIndex(where: { $0 === channel }) { + self.closingConnections.remove(at: index) + } + } else { + // This is an old channel that's finishing its close operation + if let index = self.closingConnections.firstIndex(where: { $0 === channel }) { + self.closingConnections.remove(at: index) + } + + // If we're in closing state and all connections are now closed, complete the close + if case .closing(let continuation) = self.closingState, + self.closingConnections.isEmpty + { + self.closingState = .closed + continuation.resume() + } + + self.logger.trace( + "Old channel closed after new connection established", + metadata: ["channel": "\(channel)"] + ) + return + } + } + switch (self.connectionState, self.closingState) { case (_, .closed): - fatalError("Invalid state: \(self.connectionState), \(self.closingState)") + // This should not happen, but if it does, it means we're receiving a close + // notification for a channel after the runtime client has fully closed. + // Log it but don't crash - this could be a legitimate race condition. + self.logger.warning( + "Received channelClosed after closingState is .closed", + metadata: [ + "channel": "\(channel)", + "connectionState": "\(self.connectionState)", + ] + ) + return case (.disconnected, .notClosing): if let index = self.closingConnections.firstIndex(where: { $0 === channel }) { @@ -298,11 +344,35 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { continuation.resume() } - case (.connected, .notClosing): - self.connectionState = .disconnected + case (.connected(let currentChannel, _), .notClosing): + // Only transition to disconnected if this is the CURRENT channel closing + if currentChannel === channel { + self.connectionState = .disconnected + } else { + // This is an old channel closing - remove from tracking + if let index = self.closingConnections.firstIndex(where: { $0 === channel }) { + self.closingConnections.remove(at: index) + } - case (.connected, .closing(let continuation)): - self.connectionState = .disconnected + self.logger.trace( + "Old channel closing while new connection is active", + metadata: [ + "closingChannel": "\(channel)", + "currentChannel": "\(currentChannel)", + ] + ) + } + + case (.connected(let currentChannel, _), .closing(let continuation)): + // Only transition to disconnected if this is the CURRENT channel closing + if currentChannel === channel { + self.connectionState = .disconnected + } else { + // This is an old channel closing - remove from tracking + if let index = self.closingConnections.firstIndex(where: { $0 === channel }) { + self.closingConnections.remove(at: index) + } + } if self.closingConnections.isEmpty { self.closingState = .closed @@ -369,7 +439,9 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { self.assumeIsolated { runtimeClient in // close the channel runtimeClient.channelClosed(channel) - runtimeClient.connectionState = .disconnected + // Note: Do NOT set connectionState = .disconnected here! + // The channelClosed() method handles state transitions properly, + // checking if this is the current channel or an old one. } } @@ -431,7 +503,15 @@ extension LambdaRuntimeClient: LambdaChannelHandlerDelegate { case .connected(let stateChannel, _): guard channel === stateChannel else { + // This is an old channel closing - add to tracking isolated.closingConnections.append(channel) + isolated.logger.trace( + "Old channel will close while new connection is active", + metadata: [ + "closingChannel": "\(channel)", + "currentChannel": "\(stateChannel)", + ] + ) return } diff --git a/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientChannelLifecycleTests.swift b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientChannelLifecycleTests.swift new file mode 100644 index 000000000..d838577c3 --- /dev/null +++ b/Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientChannelLifecycleTests.swift @@ -0,0 +1,484 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright SwiftAWSLambdaRuntime project authors +// Copyright (c) Amazon.com, Inc. or its affiliates. +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOHTTP1 +import NIOPosix +import Testing + +import struct Foundation.UUID + +@testable import AWSLambdaRuntime + +/// Tests for channel lifecycle race condition fixes (Bug #624) +/// These tests verify: +/// 1. Old channels are properly removed from closingConnections +/// 2. Shutdown completes even with old channels closing +/// 3. Current channel always goes through proper state transitions +@Suite +struct LambdaRuntimeClientChannelLifecycleTests { + + let logger = { + var logger = Logger(label: "ChannelLifecycleTest") + // Uncomment the line below to enable trace-level logging for debugging purposes. + // logger.logLevel = .trace + return logger + }() + + // MARK: - Test Fix #1: Old channels removed from closingConnections in .connected + .notClosing + + @Test("Old channel closing in connected state is properly cleaned up") + @available(LambdaSwift 2.0, *) + func testOldChannelCleanupInConnectedNotClosing() async throws { + // This test simulates the scenario where: + // 1. A connection is established + // 2. Server closes the connection (old channel) + // 3. Client reconnects (new channel) + // 4. Old channel's closeFuture fires while new connection is active + // Expected: Old channel should be removed from closingConnections + + struct ReconnectBehavior: LambdaServerBehavior { + let requestId = UUID().uuidString + var invocationCount = 0 + + func getInvocation() -> GetInvocationResult { + // First invocation succeeds, then trigger disconnect + if invocationCount == 0 { + return .success((requestId, "first")) + } else { + return .success(("disconnect", "0")) + } + } + + func processResponse(requestId: String, response: String?) -> Result { + // After first response, signal delayed disconnect + if response == "first" { + return .success("delayed-disconnect") + } + return .success(nil) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + try await withMockServer(behaviour: ReconnectBehavior()) { port in + let configuration = LambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port) + + try await LambdaRuntimeClient.withRuntimeClient( + configuration: configuration, + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + do { + // First invocation - will succeed + let (invocation, writer) = try await runtimeClient.nextInvocation() + #expect(invocation.event == ByteBuffer(string: "first")) + try await writer.writeAndFinish(ByteBuffer(string: "first")) + + // Give time for server to close connection + try await Task.sleep(for: .milliseconds(100)) + + // Second invocation - should reconnect and then get disconnect signal + // This tests that old channel cleanup doesn't interfere with new connection + let _ = try? await runtimeClient.nextInvocation() + + // If we reach here without freezing, the test passes + // The old channel was properly cleaned up + } catch { + // Expected to fail with connection error, which is fine + // The important part is that we don't freeze + } + } + } + } + + // MARK: - Test Fix #2: Old channels removed from closingConnections in .connected + .closing + + @Test("Old channel closing during shutdown doesn't prevent completion") + @available(LambdaSwift 2.0, *) + func testOldChannelCleanupDuringShutdown() async throws { + // This test simulates the scenario where: + // 1. A connection is established + // 2. Server closes connection (old channel starts closing) + // 3. Client reconnects (new channel) + // 4. Client initiates shutdown + // 5. Old channel's closeFuture fires during shutdown + // Expected: Shutdown should complete without freezeing + + struct ShutdownWithOldChannelBehavior: LambdaServerBehavior { + let requestId = UUID().uuidString + var responseCount = 0 + + func getInvocation() -> GetInvocationResult { + .success((requestId, "event")) + } + + func processResponse(requestId: String, response: String?) -> Result { + // First response triggers delayed disconnect + if responseCount == 0 { + return .success("delayed-disconnect") + } + return .success(nil) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + // Use a timeout to ensure shutdown completes + try await withTimeout(deadline: .seconds(5)) { + try await withMockServer(behaviour: ShutdownWithOldChannelBehavior()) { port in + let configuration = LambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port) + + try await LambdaRuntimeClient.withRuntimeClient( + configuration: configuration, + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + // First invocation + let (_, writer) = try await runtimeClient.nextInvocation() + try await writer.writeAndFinish(ByteBuffer(string: "response")) + + // Give time for server to close connection + try await Task.sleep(for: .milliseconds(100)) + + // Try to get next invocation (will reconnect) + let _ = try? await runtimeClient.nextInvocation() + + // Shutdown is triggered automatically when exiting withRuntimeClient + // If old channel cleanup is working, this should complete quickly + } + } + } + // If we reach here without timeout, the test passes + } + + @Test("Shutdown completes when multiple old channels are closing") + @available(LambdaSwift 2.0, *) + func testShutdownWithMultipleOldChannels() async throws { + // This test simulates multiple reconnections followed by shutdown + // to ensure all old channels are properly tracked and cleaned up + + struct MultipleReconnectBehavior: LambdaServerBehavior { + var invocationCount = 0 + + func getInvocation() -> GetInvocationResult { + let requestId = UUID().uuidString + return .success((requestId, "event-\(invocationCount)")) + } + + func processResponse(requestId: String, response: String?) -> Result { + // Trigger disconnect after each response to force reconnection + .success("delayed-disconnect") + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + // Use a timeout to ensure shutdown completes + try await withTimeout(deadline: .seconds(5)) { + try await withMockServer(behaviour: MultipleReconnectBehavior()) { port in + let configuration = LambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port) + + try await LambdaRuntimeClient.withRuntimeClient( + configuration: configuration, + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + // Perform multiple invocations with reconnections + for _ in 0..<3 { + do { + let (_, writer) = try await runtimeClient.nextInvocation() + try await writer.writeAndFinish(ByteBuffer(string: "response")) + // Give time for disconnect + try await Task.sleep(for: .milliseconds(50)) + } catch { + // Connection errors are expected + break + } + } + + // Shutdown happens automatically + // All old channels should be cleaned up properly + } + } + } + // If we reach here without timeout, the test passes + } + + // MARK: - Test Fix #3: Current channel goes through proper state transitions + + @Test("Current channel closing triggers proper state transition") + @available(LambdaSwift 2.0, *) + func testCurrentChannelStateTransition() async throws { + // This test verifies that when the current channel closes, + // it goes through the main switch statement and properly + // transitions to disconnected state (not taking early return) + + struct CurrentChannelCloseBehavior: LambdaServerBehavior { + let requestId = UUID().uuidString + + func getInvocation() -> GetInvocationResult { + .success((requestId, "event")) + } + + func processResponse(requestId: String, response: String?) -> Result { + // Close connection after response + .success("delayed-disconnect") + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + try await withMockServer(behaviour: CurrentChannelCloseBehavior()) { port in + let configuration = LambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port) + + try await LambdaRuntimeClient.withRuntimeClient( + configuration: configuration, + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + // First invocation + let (_, writer) = try await runtimeClient.nextInvocation() + try await writer.writeAndFinish(ByteBuffer(string: "response")) + + // Give time for connection to close + try await Task.sleep(for: .milliseconds(100)) + + // Try next invocation - should reconnect successfully + // This verifies that the state transition happened correctly + do { + let (invocation, writer2) = try await runtimeClient.nextInvocation() + #expect(invocation.event == ByteBuffer(string: "event")) + try await writer2.writeAndFinish(ByteBuffer(string: "response")) + } catch { + // Connection error is acceptable here + // The important part is that we didn't freeze or crash + } + } + } + } + + @Test("Current channel in channelsBeingClosed doesn't take early return") + @available(LambdaSwift 2.0, *) + func testCurrentChannelNotTreatedAsOld() async throws { + // This test specifically verifies the fix for Comment #3: + // When a channel is in channelsBeingClosed but is still the current channel, + // it should NOT take the early return path + + struct ImmediateCloseBehavior: LambdaServerBehavior { + var invocationCount = 0 + + func getInvocation() -> GetInvocationResult { + let requestId = UUID().uuidString + // Return disconnect on second call to close current channel + if invocationCount > 0 { + return .success(("disconnect", "0")) + } + return .success((requestId, "event")) + } + + func processResponse(requestId: String, response: String?) -> Result { + .success(nil) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + try await withMockServer(behaviour: ImmediateCloseBehavior()) { port in + let configuration = LambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port) + + try await LambdaRuntimeClient.withRuntimeClient( + configuration: configuration, + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + // First invocation succeeds + let (_, writer) = try await runtimeClient.nextInvocation() + try await writer.writeAndFinish(ByteBuffer(string: "response")) + + // Second invocation will trigger disconnect + // The current channel will be added to channelsBeingClosed + // but should still go through proper state transition + do { + let _ = try await runtimeClient.nextInvocation() + // Note: The disconnect might not happen immediately, + // so we may get a successful invocation here. + // The important part is that we don't freeze or crash. + } catch let error as LambdaRuntimeError { + // Expected error - connection was closed + #expect(error.code == .connectionToControlPlaneLost) + } catch { + // Other errors are also acceptable (ChannelError, IOError, etc.) + } + + // If we reach here without freezeing, the state transition worked correctly + } + } + } + + // MARK: - Integration Tests + + @Test("Rapid reconnections with old channels closing don't cause issues") + @available(LambdaSwift 2.0, *) + func testRapidReconnections() async throws { + // This integration test simulates the exact race condition from Bug #624: + // Rapid connection recycling where old channels close while new ones connect + + struct RapidReconnectBehavior: LambdaServerBehavior { + var count = 0 + + func getInvocation() -> GetInvocationResult { + let requestId = UUID().uuidString + return .success((requestId, "event")) + } + + func processResponse(requestId: String, response: String?) -> Result { + // Alternate between normal and delayed disconnect + // to create timing variations + if count % 2 == 0 { + return .success("delayed-disconnect") + } + return .success(nil) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + try await withTimeout(deadline: .seconds(10)) { + try await withMockServer(behaviour: RapidReconnectBehavior()) { port in + let configuration = LambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port) + + try await LambdaRuntimeClient.withRuntimeClient( + configuration: configuration, + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + // Perform many rapid invocations + for i in 0..<10 { + do { + let (_, writer) = try await runtimeClient.nextInvocation() + try await writer.writeAndFinish(ByteBuffer(string: "response-\(i)")) + + // Very short delay to create race conditions + try await Task.sleep(for: .milliseconds(10)) + } catch { + // Connection errors are expected and acceptable + // The important part is we don't crash or freeze + break + } + } + } + } + } + // If we complete without timeout or crash, all fixes are working + } + + @Test("Concurrent operations during channel lifecycle transitions") + @available(LambdaSwift 2.0, *) + func testConcurrentOperationsDuringTransitions() async throws { + // This test verifies that concurrent operations don't interfere + // with channel lifecycle management + + struct ConcurrentBehavior: LambdaServerBehavior { + func getInvocation() -> GetInvocationResult { + let requestId = UUID().uuidString + return .success((requestId, "event")) + } + + func processResponse(requestId: String, response: String?) -> Result { + .success(nil) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + .success(()) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + try await withTimeout(deadline: .seconds(5)) { + try await withMockServer(behaviour: ConcurrentBehavior()) { port in + let configuration = LambdaRuntimeClient.Configuration(ip: "127.0.0.1", port: port) + + try await LambdaRuntimeClient.withRuntimeClient( + configuration: configuration, + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + // Perform a few invocations to establish connection + for _ in 0..<3 { + let (_, writer) = try await runtimeClient.nextInvocation() + try await writer.writeAndFinish(ByteBuffer(string: "response")) + } + + // Shutdown happens automatically + // All channels should be properly cleaned up + } + } + } + } +}