From 8d9f44b7838750b103ad2ad49055e8333d719e8a Mon Sep 17 00:00:00 2001 From: tomer doron Date: Thu, 18 Jan 2024 13:48:10 -0800 Subject: [PATCH] allow custom initialization of the HandlerType of the LambdaRuntime (#310) Motivation: Provide the flexibility for custom initialization of the HandlerType as this will often be required by higher level frameworks. Modifications: * Modify the LambdaRuntime type to accept a closure to provide the handler rather than requiring that it is provided by a static method on the Handler type * Update downstream code to use HandlerProvider * Update upstream code to support passing Handler Type of Handler Provider * Add and update tests Originally suggested and coded by @tachyonics in https://github.com/swift-server/swift-aws-lambda-runtime/pull/308 --- Sources/AWSLambdaRuntimeCore/Lambda.swift | 26 +++- .../AWSLambdaRuntimeCore/LambdaHandler.swift | 24 ++- .../AWSLambdaRuntimeCore/LambdaRunner.swift | 10 +- .../AWSLambdaRuntimeCore/LambdaRuntime.swift | 138 ++++++++++++++++-- .../LambdaRunnerTest.swift | 125 ++++++++++++++++ Tests/AWSLambdaRuntimeCoreTests/Utils.swift | 62 +++++++- 6 files changed, 359 insertions(+), 26 deletions(-) diff --git a/Sources/AWSLambdaRuntimeCore/Lambda.swift b/Sources/AWSLambdaRuntimeCore/Lambda.swift index ae225d01..f23dc57c 100644 --- a/Sources/AWSLambdaRuntimeCore/Lambda.swift +++ b/Sources/AWSLambdaRuntimeCore/Lambda.swift @@ -38,7 +38,7 @@ public enum Lambda { configuration: LambdaConfiguration = .init(), handlerType: Handler.Type ) -> Result { - Self.run(configuration: configuration, handlerType: CodableSimpleLambdaHandler.self) + Self.run(configuration: configuration, handlerProvider: CodableSimpleLambdaHandler.makeHandler(context:)) } /// Run a Lambda defined by implementing the ``LambdaHandler`` protocol. @@ -54,7 +54,7 @@ public enum Lambda { configuration: LambdaConfiguration = .init(), handlerType: Handler.Type ) -> Result { - Self.run(configuration: configuration, handlerType: CodableLambdaHandler.self) + Self.run(configuration: configuration, handlerProvider: CodableLambdaHandler.makeHandler(context:)) } /// Run a Lambda defined by implementing the ``EventLoopLambdaHandler`` protocol. @@ -70,7 +70,7 @@ public enum Lambda { configuration: LambdaConfiguration = .init(), handlerType: Handler.Type ) -> Result { - Self.run(configuration: configuration, handlerType: CodableEventLoopLambdaHandler.self) + Self.run(configuration: configuration, handlerProvider: CodableEventLoopLambdaHandler.makeHandler(context:)) } /// Run a Lambda defined by implementing the ``ByteBufferLambdaHandler`` protocol. @@ -85,6 +85,19 @@ public enum Lambda { internal static func run( configuration: LambdaConfiguration = .init(), handlerType: (some ByteBufferLambdaHandler).Type + ) -> Result { + Self.run(configuration: configuration, handlerProvider: handlerType.makeHandler(context:)) + } + + /// Run a Lambda defined by implementing the ``LambdaRuntimeHandler`` protocol. + /// - parameters: + /// - configuration: A Lambda runtime configuration object + /// - handlerProvider: A provider of the ``LambdaRuntimeHandler`` to invoke. + /// + /// - note: This is a blocking operation that will run forever, as its lifecycle is managed by the AWS Lambda Runtime Engine. + internal static func run( + configuration: LambdaConfiguration = .init(), + handlerProvider: @escaping (LambdaInitializationContext) -> EventLoopFuture ) -> Result { let _run = { (configuration: LambdaConfiguration) -> Result in #if swift(<5.9) @@ -95,7 +108,12 @@ public enum Lambda { var result: Result! MultiThreadedEventLoopGroup.withCurrentThreadAsEventLoop { eventLoop in - let runtime = LambdaRuntime(handlerType: handlerType, eventLoop: eventLoop, logger: logger, configuration: configuration) + let runtime = LambdaRuntime( + handlerProvider: handlerProvider, + eventLoop: eventLoop, + logger: logger, + configuration: configuration + ) #if DEBUG let signalSource = trap(signal: configuration.lifecycle.stopSignal) { signal in logger.info("intercepted signal: \(signal)") diff --git a/Sources/AWSLambdaRuntimeCore/LambdaHandler.swift b/Sources/AWSLambdaRuntimeCore/LambdaHandler.swift index fc3611ba..3a7e3c27 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaHandler.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaHandler.swift @@ -398,7 +398,7 @@ extension EventLoopLambdaHandler { /// - note: This is a low level protocol designed to power the higher level ``EventLoopLambdaHandler`` and /// ``LambdaHandler`` based APIs. /// Most users are not expected to use this protocol. -public protocol ByteBufferLambdaHandler { +public protocol ByteBufferLambdaHandler: LambdaRuntimeHandler { /// Create a Lambda handler for the runtime. /// /// Use this to initialize all your resources that you want to cache between invocations. This could be database @@ -433,6 +433,28 @@ extension ByteBufferLambdaHandler { } } +// MARK: - LambdaRuntimeHandler + +/// An `EventLoopFuture` based processing protocol for a Lambda that takes a `ByteBuffer` and returns +/// an optional `ByteBuffer` asynchronously. +/// +/// - note: This is a low level protocol designed to enable use cases where a frameworks initializes the +/// runtime with a handler outside the normal initialization of +/// ``ByteBufferLambdaHandler``, ``EventLoopLambdaHandler`` and ``LambdaHandler`` based APIs. +/// Most users are not expected to use this protocol. +public protocol LambdaRuntimeHandler { + /// The Lambda handling method. + /// Concrete Lambda handlers implement this method to provide the Lambda functionality. + /// + /// - parameters: + /// - context: Runtime ``LambdaContext``. + /// - event: The event or input payload encoded as `ByteBuffer`. + /// + /// - Returns: An `EventLoopFuture` to report the result of the Lambda back to the runtime engine. + /// The `EventLoopFuture` should be completed with either a response encoded as `ByteBuffer` or an `Error`. + func handle(_ buffer: ByteBuffer, context: LambdaContext) -> EventLoopFuture +} + // MARK: - Other @usableFromInline diff --git a/Sources/AWSLambdaRuntimeCore/LambdaRunner.swift b/Sources/AWSLambdaRuntimeCore/LambdaRunner.swift index 898fc3e3..9557f41f 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaRunner.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaRunner.swift @@ -33,7 +33,11 @@ internal final class LambdaRunner { /// Run the user provided initializer. This *must* only be called once. /// /// - Returns: An `EventLoopFuture` fulfilled with the outcome of the initialization. - func initialize(handlerType: Handler.Type, logger: Logger, terminator: LambdaTerminator) -> EventLoopFuture { + func initialize( + handlerProvider: @escaping (LambdaInitializationContext) -> EventLoopFuture, + logger: Logger, + terminator: LambdaTerminator + ) -> EventLoopFuture { logger.debug("initializing lambda") // 1. create the handler from the factory // 2. report initialization error if one occurred @@ -44,7 +48,7 @@ internal final class LambdaRunner { terminator: terminator ) - return handlerType.makeHandler(context: context) + return handlerProvider(context) // Hopping back to "our" EventLoop is important in case the factory returns a future // that originated from a foreign EventLoop/EventLoopGroup. // This can happen if the factory uses a library (let's say a database client) that manages its own threads/loops @@ -59,7 +63,7 @@ internal final class LambdaRunner { } } - func run(handler: some ByteBufferLambdaHandler, logger: Logger) -> EventLoopFuture { + func run(handler: some LambdaRuntimeHandler, logger: Logger) -> EventLoopFuture { logger.debug("lambda invocation sequence starting") // 1. request invocation from lambda runtime engine self.isGettingNextInvocation = true diff --git a/Sources/AWSLambdaRuntimeCore/LambdaRuntime.swift b/Sources/AWSLambdaRuntimeCore/LambdaRuntime.swift index 96b77489..c570a0b3 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaRuntime.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaRuntime.swift @@ -19,12 +19,14 @@ import NIOCore /// `LambdaRuntime` manages the Lambda process lifecycle. /// /// Use this API, if you build a higher level web framework which shall be able to run inside the Lambda environment. -public final class LambdaRuntime { +public final class LambdaRuntime { private let eventLoop: EventLoop private let shutdownPromise: EventLoopPromise private let logger: Logger private let configuration: LambdaConfiguration + private let handlerProvider: (LambdaInitializationContext) -> EventLoopFuture + private var state = State.idle { willSet { self.eventLoop.assertInEventLoop() @@ -35,18 +37,41 @@ public final class LambdaRuntime { /// Create a new `LambdaRuntime`. /// /// - parameters: - /// - handlerType: The ``ByteBufferLambdaHandler`` type the `LambdaRuntime` shall create and manage. + /// - handlerProvider: A provider of the ``Handler`` the `LambdaRuntime` will manage. /// - eventLoop: An `EventLoop` to run the Lambda on. /// - logger: A `Logger` to log the Lambda events. - public convenience init(_ handlerType: Handler.Type, eventLoop: EventLoop, logger: Logger) { - self.init(handlerType: handlerType, eventLoop: eventLoop, logger: logger, configuration: .init()) + @usableFromInline + convenience init( + handlerProvider: @escaping (LambdaInitializationContext) -> EventLoopFuture, + eventLoop: EventLoop, + logger: Logger + ) { + self.init( + handlerProvider: handlerProvider, + eventLoop: eventLoop, + logger: logger, + configuration: .init() + ) } - init(handlerType: Handler.Type, eventLoop: EventLoop, logger: Logger, configuration: LambdaConfiguration) { + /// Create a new `LambdaRuntime`. + /// + /// - parameters: + /// - handlerProvider: A provider of the ``Handler`` the `LambdaRuntime` will manage. + /// - eventLoop: An `EventLoop` to run the Lambda on. + /// - logger: A `Logger` to log the Lambda events. + init( + handlerProvider: @escaping (LambdaInitializationContext) -> EventLoopFuture, + eventLoop: EventLoop, + logger: Logger, + configuration: LambdaConfiguration + ) { self.eventLoop = eventLoop self.shutdownPromise = eventLoop.makePromise(of: Int.self) self.logger = logger self.configuration = configuration + + self.handlerProvider = handlerProvider } deinit { @@ -85,7 +110,7 @@ public final class LambdaRuntime { let terminator = LambdaTerminator() let runner = LambdaRunner(eventLoop: self.eventLoop, configuration: self.configuration) - let startupFuture = runner.initialize(handlerType: Handler.self, logger: logger, terminator: terminator) + let startupFuture = runner.initialize(handlerProvider: self.handlerProvider, logger: logger, terminator: terminator) startupFuture.flatMap { handler -> EventLoopFuture> in // after the startup future has succeeded, we have a handler that we can use // to `run` the lambda. @@ -175,7 +200,7 @@ public final class LambdaRuntime { private enum State { case idle case initializing - case active(LambdaRunner, any ByteBufferLambdaHandler) + case active(LambdaRunner, any LambdaRuntimeHandler) case shuttingdown case shutdown @@ -204,8 +229,16 @@ public enum LambdaRuntimeFactory { /// - eventLoop: An `EventLoop` to run the Lambda on. /// - logger: A `Logger` to log the Lambda events. @inlinable - public static func makeRuntime(_ handlerType: H.Type, eventLoop: any EventLoop, logger: Logger) -> LambdaRuntime { - LambdaRuntime>(CodableSimpleLambdaHandler.self, eventLoop: eventLoop, logger: logger) + public static func makeRuntime( + _ handlerType: Handler.Type, + eventLoop: any EventLoop, + logger: Logger + ) -> LambdaRuntime { + LambdaRuntime>( + handlerProvider: CodableSimpleLambdaHandler.makeHandler(context:), + eventLoop: eventLoop, + logger: logger + ) } /// Create a new `LambdaRuntime`. @@ -215,8 +248,16 @@ public enum LambdaRuntimeFactory { /// - eventLoop: An `EventLoop` to run the Lambda on. /// - logger: A `Logger` to log the Lambda events. @inlinable - public static func makeRuntime(_ handlerType: H.Type, eventLoop: any EventLoop, logger: Logger) -> LambdaRuntime { - LambdaRuntime>(CodableLambdaHandler.self, eventLoop: eventLoop, logger: logger) + public static func makeRuntime( + _ handlerType: Handler.Type, + eventLoop: any EventLoop, + logger: Logger + ) -> LambdaRuntime { + LambdaRuntime>( + handlerProvider: CodableLambdaHandler.makeHandler(context:), + eventLoop: eventLoop, + logger: logger + ) } /// Create a new `LambdaRuntime`. @@ -226,8 +267,79 @@ public enum LambdaRuntimeFactory { /// - eventLoop: An `EventLoop` to run the Lambda on. /// - logger: A `Logger` to log the Lambda events. @inlinable - public static func makeRuntime(_ handlerType: H.Type, eventLoop: any EventLoop, logger: Logger) -> LambdaRuntime { - LambdaRuntime>(CodableEventLoopLambdaHandler.self, eventLoop: eventLoop, logger: logger) + public static func makeRuntime( + _ handlerType: Handler.Type, + eventLoop: any EventLoop, + logger: Logger + ) -> LambdaRuntime { + LambdaRuntime>( + handlerProvider: CodableEventLoopLambdaHandler.makeHandler(context:), + eventLoop: eventLoop, + logger: logger + ) + } + + /// Create a new `LambdaRuntime`. + /// + /// - parameters: + /// - handlerType: The ``ByteBufferLambdaHandler`` type the `LambdaRuntime` shall create and manage. + /// - eventLoop: An `EventLoop` to run the Lambda on. + /// - logger: A `Logger` to log the Lambda events. + @inlinable + public static func makeRuntime( + _ handlerType: Handler.Type, + eventLoop: any EventLoop, + logger: Logger + ) -> LambdaRuntime { + LambdaRuntime( + handlerProvider: Handler.makeHandler(context:), + eventLoop: eventLoop, + logger: logger + ) + } + + /// Create a new `LambdaRuntime`. + /// + /// - parameters: + /// - handlerProvider: A provider of the ``Handler`` the `LambdaRuntime` will manage. + /// - eventLoop: An `EventLoop` to run the Lambda on. + /// - logger: A `Logger` to log the Lambda events. + @inlinable + public static func makeRuntime( + handlerProvider: @escaping (LambdaInitializationContext) -> EventLoopFuture, + eventLoop: any EventLoop, + logger: Logger + ) -> LambdaRuntime { + LambdaRuntime( + handlerProvider: handlerProvider, + eventLoop: eventLoop, + logger: logger + ) + } + + /// Create a new `LambdaRuntime`. + /// + /// - parameters: + /// - handlerProvider: A provider of the ``Handler`` the `LambdaRuntime` will manage. + /// - eventLoop: An `EventLoop` to run the Lambda on. + /// - logger: A `Logger` to log the Lambda events. + @inlinable + public static func makeRuntime( + handlerProvider: @escaping (LambdaInitializationContext) async throws -> Handler, + eventLoop: any EventLoop, + logger: Logger + ) -> LambdaRuntime { + LambdaRuntime( + handlerProvider: { context in + let promise = eventLoop.makePromise(of: Handler.self) + promise.completeWithTask { + try await handlerProvider(context) + } + return promise.futureResult + }, + eventLoop: eventLoop, + logger: logger + ) } } diff --git a/Tests/AWSLambdaRuntimeCoreTests/LambdaRunnerTest.swift b/Tests/AWSLambdaRuntimeCoreTests/LambdaRunnerTest.swift index a561dbf8..6fd91aec 100644 --- a/Tests/AWSLambdaRuntimeCoreTests/LambdaRunnerTest.swift +++ b/Tests/AWSLambdaRuntimeCoreTests/LambdaRunnerTest.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// @testable import AWSLambdaRuntimeCore +import NIOCore import XCTest class LambdaRunnerTest: XCTestCase { @@ -68,4 +69,128 @@ class LambdaRunnerTest: XCTestCase { } XCTAssertNoThrow(try runLambda(behavior: Behavior(), handlerType: RuntimeErrorHandler.self)) } + + func testCustomProviderSuccess() { + struct Behavior: LambdaServerBehavior { + let requestId = UUID().uuidString + let event = "hello" + func getInvocation() -> GetInvocationResult { + .success((self.requestId, self.event)) + } + + func processResponse(requestId: String, response: String?) -> Result { + XCTAssertEqual(self.requestId, requestId, "expecting requestId to match") + XCTAssertEqual(self.event, response, "expecting response to match") + return .success(()) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + XCTFail("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + XCTFail("should not report init error") + return .failure(.internalServerError) + } + } + XCTAssertNoThrow(try runLambda(behavior: Behavior(), handlerProvider: { context in + context.eventLoop.makeSucceededFuture(EchoHandler()) + })) + } + + func testCustomProviderFailure() { + struct Behavior: LambdaServerBehavior { + let requestId = UUID().uuidString + let event = "hello" + func getInvocation() -> GetInvocationResult { + .success((self.requestId, self.event)) + } + + func processResponse(requestId: String, response: String?) -> Result { + XCTFail("should not report processing") + return .failure(.internalServerError) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + XCTFail("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + XCTAssertEqual(String(describing: CustomError()), error.errorMessage, "expecting error to match") + return .success(()) + } + } + + struct CustomError: Error {} + + XCTAssertThrowsError(try runLambda(behavior: Behavior(), handlerProvider: { context -> EventLoopFuture in + context.eventLoop.makeFailedFuture(CustomError()) + })) { error in + XCTAssertNotNil(error as? CustomError, "expecting error to match") + } + } + + func testCustomAsyncProviderSuccess() { + struct Behavior: LambdaServerBehavior { + let requestId = UUID().uuidString + let event = "hello" + func getInvocation() -> GetInvocationResult { + .success((self.requestId, self.event)) + } + + func processResponse(requestId: String, response: String?) -> Result { + XCTAssertEqual(self.requestId, requestId, "expecting requestId to match") + XCTAssertEqual(self.event, response, "expecting response to match") + return .success(()) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + XCTFail("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + XCTFail("should not report init error") + return .failure(.internalServerError) + } + } + XCTAssertNoThrow(try runLambda(behavior: Behavior(), handlerProvider: { _ async throws -> EchoHandler in + EchoHandler() + })) + } + + func testCustomAsyncProviderFailure() { + struct Behavior: LambdaServerBehavior { + let requestId = UUID().uuidString + let event = "hello" + func getInvocation() -> GetInvocationResult { + .success((self.requestId, self.event)) + } + + func processResponse(requestId: String, response: String?) -> Result { + XCTFail("should not report processing") + return .failure(.internalServerError) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + XCTFail("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + XCTAssertEqual(String(describing: CustomError()), error.errorMessage, "expecting error to match") + return .success(()) + } + } + + struct CustomError: Error {} + + XCTAssertThrowsError(try runLambda(behavior: Behavior(), handlerProvider: { _ async throws -> EchoHandler in + throw CustomError() + })) { error in + XCTAssertNotNil(error as? CustomError, "expecting error to match") + } + } } diff --git a/Tests/AWSLambdaRuntimeCoreTests/Utils.swift b/Tests/AWSLambdaRuntimeCoreTests/Utils.swift index aecd3186..8e2d3c38 100644 --- a/Tests/AWSLambdaRuntimeCoreTests/Utils.swift +++ b/Tests/AWSLambdaRuntimeCoreTests/Utils.swift @@ -19,18 +19,59 @@ import NIOPosix import XCTest func runLambda(behavior: LambdaServerBehavior, handlerType: Handler.Type) throws { - try runLambda(behavior: behavior, handlerType: CodableSimpleLambdaHandler.self) + try runLambda(behavior: behavior, handlerProvider: CodableSimpleLambdaHandler.makeHandler(context:)) } func runLambda(behavior: LambdaServerBehavior, handlerType: Handler.Type) throws { - try runLambda(behavior: behavior, handlerType: CodableLambdaHandler.self) + try runLambda(behavior: behavior, handlerProvider: CodableLambdaHandler.makeHandler(context:)) } func runLambda(behavior: LambdaServerBehavior, handlerType: Handler.Type) throws { - try runLambda(behavior: behavior, handlerType: CodableEventLoopLambdaHandler.self) + try runLambda(behavior: behavior, handlerProvider: CodableEventLoopLambdaHandler.makeHandler(context:)) } -func runLambda(behavior: LambdaServerBehavior, handlerType: (some ByteBufferLambdaHandler).Type) throws { +func runLambda( + behavior: LambdaServerBehavior, + handlerProvider: @escaping (LambdaInitializationContext) -> EventLoopFuture +) throws { + try runLambda(behavior: behavior, handlerProvider: { context in + handlerProvider(context).map { + CodableEventLoopLambdaHandler(handler: $0, allocator: context.allocator) + } + }) +} + +func runLambda( + behavior: LambdaServerBehavior, + handlerProvider: @escaping (LambdaInitializationContext) async throws -> Handler +) throws { + try runLambda(behavior: behavior, handlerProvider: { context in + let handler = try await handlerProvider(context) + return CodableEventLoopLambdaHandler(handler: handler, allocator: context.allocator) + }) +} + +func runLambda( + behavior: LambdaServerBehavior, + handlerProvider: @escaping (LambdaInitializationContext) async throws -> Handler +) throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + try runLambda( + behavior: behavior, + handlerProvider: { context in + let promise = eventLoopGroup.next().makePromise(of: Handler.self) + promise.completeWithTask { + try await handlerProvider(context) + } + return promise.futureResult + } + ) +} + +func runLambda( + behavior: LambdaServerBehavior, + handlerProvider: @escaping (LambdaInitializationContext) -> EventLoopFuture +) throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let logger = Logger(label: "TestLogger") @@ -39,7 +80,7 @@ func runLambda(behavior: LambdaServerBehavior, handlerType: (some ByteBufferLamb let runner = LambdaRunner(eventLoop: eventLoopGroup.next(), configuration: configuration) let server = try MockLambdaServer(behavior: behavior).start().wait() defer { XCTAssertNoThrow(try server.stop().wait()) } - try runner.initialize(handlerType: handlerType, logger: logger, terminator: terminator).flatMap { handler in + try runner.initialize(handlerProvider: handlerProvider, logger: logger, terminator: terminator).flatMap { handler in runner.run(handler: handler, logger: logger) }.wait() } @@ -89,3 +130,14 @@ extension LambdaTerminator.TerminationError: Equatable { return String(describing: lhs) == String(describing: rhs) } } + +// for backward compatibility in tests +extension LambdaRunner { + func initialize( + handlerType: Handler.Type, + logger: Logger, + terminator: LambdaTerminator + ) -> EventLoopFuture { + self.initialize(handlerProvider: handlerType.makeHandler(context:), logger: logger, terminator: terminator) + } +}