From 7c3c1d19789cf4e2456e1ac6a74e1f5fe4e07816 Mon Sep 17 00:00:00 2001 From: Arnaud Date: Mon, 23 Dec 2024 18:37:29 +0100 Subject: [PATCH 1/9] Making BHTTPSerializer compliant to known-length messages Motivation: As defined in [RFC9292](https://www.rfc-editor.org/rfc/rfc9292), you can choose to use either the known-length format or an indeterminate-length format. The purpose of this PR is to give the possibility for users to use the known-length format. Modifications: - Finite State Machine (FSM): Added an FSM to the serializer to define allowed state transitions. This prevents invalid sequences, such as having a Body HTTP Part followed by a Header HTTP Part. - BHTTPSerializer: Introduced a flag within BHTTPSerializer to determine whether the known-length or indeterminate-length format will be used. - Serialization of Known-Length Sections: Added methods to BHTTPSerializer specifically for serializing known-length sections. - Unit-tests: I updated unit tests to cover scenarios where we have out-of-order HTTP parts but also to verify the known-length implementation. Design Choices: - Buffer Reference: The output buffer is kept as a reference to optimize performance. Copying the entire buffer at the end of the serialization would be inefficient. - Inlining serializeContentChunk: I inlined the serializeContentChunk function to minimize overhead from function prologues and epilogues, as it is expected to be used frequently. - Buffers for Known-Length Format: For known-length formats, two buffers were introduced: chunkBuffer and fieldSectionBuffer. These buffers are used when the request/response consists of multiple parts. I did some researchs to find what could be the optimal initial capacities for those 2 buffers. I found out that I could set them to 500 and 700, respectively, based on data from the following sources: - SPDY Whitepaper for header field sizes: https://www.chromium.org/spdy/spdy-whitepaper - Arxiv paper on HTTP body size for body content size: https://arxiv.org/pdf/1405.2330 (If needed, I can provide more details on how these values were derived.) However, as underlined in the ByteBuffer implementation: https://github.com/apple/swift-nio/blob/main/Sources/NIOCore/ByteBuffer-core.swift, the initial capacity is set to: let newCapacity = capacity == 0 ? 0 : capacity.nextPowerOf2ClampedToMax() . Therefore, I decided to not use those values as it would very likely to be way too much. Given this, I decided not to use the predefined values, as they would likely be too large. Instead, I opted to rely on the initial size encountered when initializing the two buffers. - FSM Transition Definitions: The state transition definitions in the FSM are declared as static, with the intention that these definitions will be placed in the data section of the compiled binary. This should save memory but as I only used Swift for a couple of days now, I prefer to take distance. I tried to verify this by reviewing the assembly output here: https://godbolt.org/z/s88Kaxqn8 but I wasn't able to confirm with certainty due to the noise in the output. --- README.md | 3 +- Sources/ObliviousHTTP/BHTTPSerializer.swift | 200 +++++++++++--- Sources/ObliviousHTTP/Errors.swift | 21 +- .../BHTTPSerializerTests.swift | 255 ++++++++++++++++-- 4 files changed, 427 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index bbedd73..85b10e2 100644 --- a/README.md +++ b/README.md @@ -47,8 +47,9 @@ dependencies: [ ### Binary HTTP Encoding To serialise binary HTTP messages use `BHTTPSerializer.serialize(message, buffer)`. +As defined in [RFC9292](https://www.rfc-editor.org/rfc/rfc9292), you can choose to use either the known-length format or an indeterminate-length format. This choice can be configured during the initialization of the serializer by passing the desired type: `BHTTPSerializer(type: .knownLength)`. -To deserialise binary HTTP messages use `BHTTPParser`, adding received data with `append()`, then calling `completeBodyRecieved()`. The read the message parts received call `nextMessage()`. +To deserialise binary HTTP messages use `BHTTPParser`, adding received data with `append()`, then calling `completeBodyReceived()`. The read the message parts received call `nextMessage()`. ### Oblivious Encapsulation diff --git a/Sources/ObliviousHTTP/BHTTPSerializer.swift b/Sources/ObliviousHTTP/BHTTPSerializer.swift index c4f8821..3bee8bb 100644 --- a/Sources/ObliviousHTTP/BHTTPSerializer.swift +++ b/Sources/ObliviousHTTP/BHTTPSerializer.swift @@ -14,45 +14,86 @@ import NIOCore import NIOHTTP1 -// For now this type is entirely stateless, which is achieved by using the indefinite-length encoding. -// It also means it does not enforce correctness, and so can produce invalid encodings if a user holds -// it wrong. -// -// Later optimizations can be made by adding more state into this type. /// Binary HTTP serialiser as described in [RFC9292](https://www.rfc-editor.org/rfc/rfc9292). -/// Currently only indeterminate-length encoding is supported. public struct BHTTPSerializer { + + private var fsm: BHTTPSerializerFSM + private var type: BHTTPSerializerType + private var chunkBuffer: ByteBuffer + private var fieldSectionBuffer: ByteBuffer + /// Initialise a Binary HTTP Serialiser. - public init() {} + public init( + type: BHTTPSerializerType = .indeterminateLength, + allocator: ByteBufferAllocator = ByteBufferAllocator() + ) { + self.type = type + self.chunkBuffer = allocator.buffer(capacity: 0) + self.fieldSectionBuffer = allocator.buffer(capacity: 0) + self.fsm = BHTTPSerializerFSM(initialState: BHTTPSerializerState.HEADER) + } /// Serialise a message into a buffer using binary HTTP encoding. /// - Parameters: - /// - message: The message to serialise. File regions are currently not supported. + /// - message: The message to serialise. File regions are currently not supported. /// - buffer: Destination buffer to serialise into. - public func serialize(_ message: Message, into buffer: inout ByteBuffer) { + public mutating func serialize(_ message: Message, into buffer: inout ByteBuffer) throws { switch message { case .request(.head(let requestHead)): - Self.serializeRequestHead(requestHead, into: &buffer) + try self.fsm.ensureState([BHTTPSerializerState.HEADER]) + self.serializeRequestHead(requestHead, into: &buffer) + try self.fsm.transition(to: BHTTPSerializerState.CHUNK) + case .response(.head(let responseHead)): - Self.serializeResponseHead(responseHead, into: &buffer) + try self.fsm.ensureState([BHTTPSerializerState.HEADER]) + self.serializeResponseHead(responseHead, into: &buffer) + try self.fsm.transition(to: BHTTPSerializerState.CHUNK) + case .request(.body(.byteBuffer(let body))), .response(.body(.byteBuffer(let body))): - Self.serializeContentChunk(body, into: &buffer) + try self.fsm.ensureState([BHTTPSerializerState.CHUNK]) + switch self.type { + case .indeterminateLength: + Self.serializeContentChunk(body, into: &buffer) + case .knownLength: + self.stackContentChunk(body) + } + case .request(.body(.fileRegion)), .response(.body(.fileRegion)): - fatalError("fileregion unsupported") + throw ObliviousHTTPError.unsupportedOption(reason: "fileregion unsupported") + case .request(.end(.some(let trailers))), .response(.end(.some(let trailers))): - // Send a 0 to terminate the body, then a field section. - buffer.writeInteger(UInt8(0)) - Self.serializeIndeterminateLengthFieldSection(trailers, into: &buffer) + try self.fsm.ensureState([BHTTPSerializerState.CHUNK, BHTTPSerializerState.HEADER]) + switch self.type { + case .indeterminateLength: + // Send a 0 to terminate the body, then a field section. + buffer.writeInteger(UInt8(0)) + Self.serializeIndeterminateLengthFieldSection(trailers, into: &buffer) + case .knownLength: + self.serializeContent(into: &buffer) + self.stackKnownLengthFieldSection(trailers) + } + try self.fsm.transition(to: BHTTPSerializerState.TRAILERS) + case .request(.end(.none)), .response(.end(.none)): - // We can omit the trailers in this context, but we will always send a zero - // byte, either to communicate no trailers or no body. - buffer.writeInteger(UInt8(0)) + try self.fsm.ensureState([BHTTPSerializerState.CHUNK, BHTTPSerializerState.TRAILERS]) + switch self.type { + case .indeterminateLength: + buffer.writeInteger(UInt8(0)) + case .knownLength: + self.serializeContent(into: &buffer) + self.serializeKnownLengthFieldSection(into: &buffer) + } + try self.fsm.transition(to: BHTTPSerializerState.END) } } - private static func serializeRequestHead(_ head: HTTPRequestHead, into buffer: inout ByteBuffer) { - // First, the framing indicator. 2 for indeterminate length request. - buffer.writeVarint(2) + private mutating func serializeRequestHead(_ head: HTTPRequestHead, into buffer: inout ByteBuffer) { + // First, the framing indicator + buffer.writeVarint( + self.type == .indeterminateLength + ? BHTTPFramingIndicator.requestIndeterminateLength.rawValue + : BHTTPFramingIndicator.requestKnownLength.rawValue + ) let method = head.method let scheme = "https" // Hardcoded for now, but not really the right option. @@ -64,22 +105,50 @@ public struct BHTTPSerializer { buffer.writeVarintPrefixedString(authority) buffer.writeVarintPrefixedString(path) - Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) + switch self.type { + case .indeterminateLength: + Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) + case .knownLength: + self.stackKnownLengthFieldSection(head.headers) + self.serializeKnownLengthFieldSection(into: &buffer) + } } - private static func serializeResponseHead(_ head: HTTPResponseHead, into buffer: inout ByteBuffer) { - // First, the framing indicator. 3 for indeterminate length response. - buffer.writeVarint(3) + private mutating func serializeResponseHead(_ head: HTTPResponseHead, into buffer: inout ByteBuffer) { + // First, the framing indicator + buffer.writeVarint( + self.type == .indeterminateLength + ? BHTTPFramingIndicator.responseInderterminateLength.rawValue + : BHTTPFramingIndicator.responseKnownLength.rawValue + ) + buffer.writeVarint(Int(head.status.code)) - Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) + + switch self.type { + case .indeterminateLength: + Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) + case .knownLength: + self.stackKnownLengthFieldSection(head.headers) + self.serializeKnownLengthFieldSection(into: &buffer) + } } + @inline(__always) private static func serializeContentChunk(_ chunk: ByteBuffer, into buffer: inout ByteBuffer) { - // Omit zero-length chunks. if chunk.readableBytes == 0 { return } buffer.writeVarintPrefixedImmutableBuffer(chunk) } + private mutating func serializeContent(into buffer: inout ByteBuffer) { + if self.chunkBuffer.readableBytes == 0 { return } + buffer.writeVarintPrefixedImmutableBuffer(self.chunkBuffer) + self.chunkBuffer.clear() + } + + private mutating func stackContentChunk(_ chunk: ByteBuffer) { + self.chunkBuffer.writeImmutableBuffer(chunk) + } + private static func serializeIndeterminateLengthFieldSection( _ fields: HTTPHeaders, into buffer: inout ByteBuffer @@ -88,18 +157,83 @@ public struct BHTTPSerializer { buffer.writeVarintPrefixedString(name) buffer.writeVarintPrefixedString(value) } - // This is technically a varint but we can skip the check there because we know it can always encode in one byte. - buffer.writeInteger(UInt8(0)) + buffer.writeInteger(UInt8(0)) // End of field section } + private mutating func serializeKnownLengthFieldSection(into buffer: inout ByteBuffer) { + buffer.writeVarintPrefixedImmutableBuffer(self.fieldSectionBuffer) + self.fieldSectionBuffer.clear() + } + + private mutating func stackKnownLengthFieldSection(_ fields: HTTPHeaders) { + for (name, value) in fields { + self.fieldSectionBuffer.writeVarintPrefixedString(name) + self.fieldSectionBuffer.writeVarintPrefixedString(value) + } + } } +// Enum definitions for message, states, and types. extension BHTTPSerializer { - /// Types of message for binary http serilaisation + // Finite State Machine for managing transitions in BHTTPSerializer. + private class BHTTPSerializerFSM { + var currentState: BHTTPSerializerState + + init(initialState: BHTTPSerializerState) { + self.currentState = initialState + } + + // Transition to a new state, respecting the state machine constraints. + func transition(to state: BHTTPSerializerState) throws { + guard let allowedTransitions = Self.validTransitions[currentState] else { + throw ObliviousHTTPError.unexpectedHTTPMessageSection(state: currentState.rawValue) + } + + guard allowedTransitions.contains(state) else { + throw ObliviousHTTPError.unexpectedHTTPMessageSection(state: currentState.rawValue) + } + + currentState = state + } + + // Define a dictionary to map current states to allowed transitions + private static let validTransitions: [BHTTPSerializerState: Set] = [ + .HEADER: [.CHUNK, .TRAILERS], + .CHUNK: [.TRAILERS, .END], + .TRAILERS: [.END], + .END: [] + ] + + + // Ensure that the current state is one of the allowed states. + func ensureState(_ allowedStates: [BHTTPSerializerState]) throws { + if !allowedStates.contains(self.currentState) { + throw ObliviousHTTPError.unexpectedHTTPMessageSection(state: self.currentState.rawValue) + } + } + } + public enum Message { - /// Part of an HTTP request. case request(HTTPClientRequestPart) - /// Part of an HTTP response. case response(HTTPServerResponsePart) } + + public enum BHTTPSerializerType { + case knownLength + case indeterminateLength + } + + public enum BHTTPFramingIndicator: Int { + case requestKnownLength = 0 + case responseKnownLength = 1 + case requestIndeterminateLength = 2 + case responseInderterminateLength = 3 + } + + public enum BHTTPSerializerState: String { + case HEADER = "Header" + case CHUNK = "Chunk" + case TRAILERS = "Trailers" + case END = "End" + } } diff --git a/Sources/ObliviousHTTP/Errors.swift b/Sources/ObliviousHTTP/Errors.swift index 62910d5..bcfddb9 100644 --- a/Sources/ObliviousHTTP/Errors.swift +++ b/Sources/ObliviousHTTP/Errors.swift @@ -44,13 +44,30 @@ public struct ObliviousHTTPError: Error, Hashable { Self.init(backing: .truncatedEncoding(reason: reason)) } - /// Create an error indicating that parsing faileud due to an unexpected HTTP status code. + /// Create an error indicating that parsing failed due to an unexpected HTTP status code. /// - Parameter status: The status code encountered. /// - Returns: An Error representing this failure. @inline(never) public static func invalidStatus(status: Int) -> ObliviousHTTPError { Self.init(backing: .invalidStatus(status: status)) } + + /// Create an error indicating that serializing failed due to an unexpected HTTP section. + /// - Parameter status: The state encountered. + /// - Returns: An Error representing this failure. + @inline(never) + public static func unexpectedHTTPMessageSection(state: String) -> ObliviousHTTPError { + Self.init(backing: .unexpectedHTTPMessageSection(state: "\(state) section was not expected.")) + } + + /// Create an error indicating that serializing failed due to an unsupported option. + /// - Parameter status: The unsupported option details. + /// - Returns: An Error representing this failure. + @inline(never) + public static func unsupportedOption(reason: String) -> ObliviousHTTPError { + Self.init(backing: .unsupportedOption(reason: reason)) + } + } extension ObliviousHTTPError { @@ -59,5 +76,7 @@ extension ObliviousHTTPError { case invalidFieldSection(reason: String) case truncatedEncoding(reason: String) case invalidStatus(status: Int) + case unexpectedHTTPMessageSection(state: String) + case unsupportedOption(reason: String) } } diff --git a/Tests/ObliviousHTTPTests/BHTTPSerializerTests.swift b/Tests/ObliviousHTTPTests/BHTTPSerializerTests.swift index b090c90..a61f6cc 100644 --- a/Tests/ObliviousHTTPTests/BHTTPSerializerTests.swift +++ b/Tests/ObliviousHTTPTests/BHTTPSerializerTests.swift @@ -30,12 +30,46 @@ final class BHTTPSerializerTests: XCTestCase { .end(nil), ] - let serializer = BHTTPSerializer() + var buffer = ByteBuffer() + var serializer = BHTTPSerializer() var parser = BHTTPParser(role: .server) + + for message in request { + try serializer.serialize(.request(message), into: &buffer) + } + + parser.append(buffer) + parser.completeBodyReceived() + var received: [HTTPServerRequestPart] = [] + + while let next = try parser.nextMessage(), case .request(let request) = next { + received.append(request) + } + + let expectedRequest: [HTTPServerRequestPart] = [ + .head(.init(version: .http1_1, method: .GET, uri: "/example", headers: expectedHeaders)), + .end(nil), + ] + XCTAssertEqual(expectedRequest, received) + } + + func testSimpleGetRequestRoundTripsWithKnownLengthSerializer() throws { + let expectedHeaders = HTTPHeaders([ + ("user-agent", "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"), + ("host", "www.example.com"), + ("accept-language", "en, mi"), + ]) + let request: [HTTPClientRequestPart] = [ + .head(.init(version: .http1_1, method: .GET, uri: "/example", headers: expectedHeaders)), + .end(nil), + ] + var buffer = ByteBuffer() + var serializer = BHTTPSerializer(type: .knownLength) + var parser = BHTTPParser(role: .server) for message in request { - serializer.serialize(.request(message), into: &buffer) + try serializer.serialize(.request(message), into: &buffer) } parser.append(buffer) @@ -53,6 +87,71 @@ final class BHTTPSerializerTests: XCTestCase { XCTAssertEqual(expectedRequest, received) } + func testSerializerThrowsForMultipleHeadPartsInRequest() throws { + let expectedHeaders = HTTPHeaders([ + ("user-agent", "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"), + ("host", "www.example.com"), + ("accept-language", "en, mi"), + ]) + let request: [HTTPClientRequestPart] = [ + .head(.init(version: .http1_1, method: .GET, uri: "/example", headers: expectedHeaders)), + .head(.init(version: .http1_1, method: .GET, uri: "/example", headers: expectedHeaders)), + .end(nil), + ] + + var buffer = ByteBuffer() + var serializer = BHTTPSerializer() + + var didThrowError = false + + for message in request { + do { + try serializer.serialize(.request(message), into: &buffer) + } catch { + didThrowError = true + XCTAssertNotNil(error, "Chunk section was not expected.") + break + } + } + + XCTAssertTrue(didThrowError, "Expected serializer to throw an error because the request has 2 heads part.") + } + + func testSerializerThrowsForOutOfOrderRequestParts() throws { + let expectedHeaders = HTTPHeaders([ + ("user-agent", "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"), + ("host", "www.example.com"), + ("accept-language", "en, mi"), + ("content-length", "5"), + ]) + let request: [HTTPClientRequestPart] = [ + .body(.byteBuffer(.init(string: "he"))), + .head(.init(version: .http1_1, method: .POST, uri: "/example", headers: expectedHeaders)), + .body(.byteBuffer(.init(string: "llo"))), + .end(nil), + ] + + var buffer = ByteBuffer() + var serializer = BHTTPSerializer() + + var didThrowError = false + + for message in request { + do { + try serializer.serialize(.request(message), into: &buffer) + } catch { + didThrowError = true + XCTAssertNotNil(error, "Header section was not expected.") + break + } + } + + XCTAssertTrue( + didThrowError, + "Expected serializer to throw an error because the request has out of order parts." + ) + } + func testSimplePOSTRequestRoundTrips() throws { let expectedHeaders = HTTPHeaders([ ("user-agent", "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"), @@ -67,12 +166,12 @@ final class BHTTPSerializerTests: XCTestCase { .end(nil), ] - let serializer = BHTTPSerializer() - var parser = BHTTPParser(role: .server) var buffer = ByteBuffer() + var serializer = BHTTPSerializer() + var parser = BHTTPParser(role: .server) for message in request { - serializer.serialize(.request(message), into: &buffer) + try serializer.serialize(.request(message), into: &buffer) } parser.append(buffer) @@ -92,6 +191,44 @@ final class BHTTPSerializerTests: XCTestCase { XCTAssertEqual(expectedRequest, received) } + func testSimplePOSTRequestRoundTripsWithKnownLengthSerialiser() throws { + let expectedHeaders = HTTPHeaders([ + ("user-agent", "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"), + ("host", "www.example.com"), + ("accept-language", "en, mi"), + ("content-length", "5"), + ]) + let request: [HTTPClientRequestPart] = [ + .head(.init(version: .http1_1, method: .POST, uri: "/example", headers: expectedHeaders)), + .body(.byteBuffer(.init(string: "he"))), + .body(.byteBuffer(.init(string: "llo"))), + .end(nil), + ] + + var buffer = ByteBuffer() + var serializer = BHTTPSerializer(type: .knownLength) + var parser = BHTTPParser(role: .server) + + for message in request { + try serializer.serialize(.request(message), into: &buffer) + } + + parser.append(buffer) + parser.completeBodyReceived() + var received: [HTTPServerRequestPart] = [] + + while let next = try parser.nextMessage(), case .request(let request) = next { + received.append(request) + } + + let expectedRequest: [HTTPServerRequestPart] = [ + .head(.init(version: .http1_1, method: .POST, uri: "/example", headers: expectedHeaders)), + .body(.init(string: "hello")), + .end(nil), + ] + XCTAssertEqual(expectedRequest, received) + } + func testSimplePOSTRequestWithTrailersRoundTrips() throws { let expectedHeaders = HTTPHeaders([ ("user-agent", "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"), @@ -110,12 +247,12 @@ final class BHTTPSerializerTests: XCTestCase { .end(expectedTrailers), ] - let serializer = BHTTPSerializer() - var parser = BHTTPParser(role: .server) var buffer = ByteBuffer() + var serializer = BHTTPSerializer() + var parser = BHTTPParser(role: .server) for message in request { - serializer.serialize(.request(message), into: &buffer) + try serializer.serialize(.request(message), into: &buffer) } parser.append(buffer) @@ -135,6 +272,49 @@ final class BHTTPSerializerTests: XCTestCase { XCTAssertEqual(expectedRequest, received) } + func testSimplePOSTRequestWithTrailersRoundTripsAndKnownLengthSerializer() throws { + let expectedHeaders = HTTPHeaders([ + ("user-agent", "curl/7.16.3 libcurl/7.16.3 OpenSSL/0.9.7l zlib/1.2.3"), + ("host", "www.example.com"), + ("accept-language", "en, mi"), + ("content-length", "5"), + ]) + let expectedTrailers = HTTPHeaders([ + ("foo", "bar"), + ("froo", "brar"), + ]) + let request: [HTTPClientRequestPart] = [ + .head(.init(version: .http1_1, method: .POST, uri: "/example", headers: expectedHeaders)), + .body(.byteBuffer(.init(string: "he"))), + .body(.byteBuffer(.init(string: "llo"))), + .end(expectedTrailers), + .end(nil), + ] + + var buffer = ByteBuffer() + var serializer = BHTTPSerializer(type: .knownLength) + var parser = BHTTPParser(role: .server) + + for message in request { + try serializer.serialize(.request(message), into: &buffer) + } + + parser.append(buffer) + parser.completeBodyReceived() + var received: [HTTPServerRequestPart] = [] + + while let next = try parser.nextMessage(), case .request(let request) = next { + received.append(request) + } + + let expectedRequest: [HTTPServerRequestPart] = [ + .head(.init(version: .http1_1, method: .POST, uri: "/example", headers: expectedHeaders)), + .body(.init(string: "hello")), + .end(expectedTrailers) + ] + XCTAssertEqual(expectedRequest, received) + } + func testSimple201ResponseRoundTrips() throws { let expectedHeaders = HTTPHeaders([ ("server", "apache"), @@ -145,12 +325,12 @@ final class BHTTPSerializerTests: XCTestCase { .end(nil), ] - let serializer = BHTTPSerializer() - var parser = BHTTPParser(role: .client) var buffer = ByteBuffer() + var serializer = BHTTPSerializer() + var parser = BHTTPParser(role: .client) for message in response { - serializer.serialize(.response(message), into: &buffer) + try serializer.serialize(.response(message), into: &buffer) } parser.append(buffer) @@ -179,13 +359,12 @@ final class BHTTPSerializerTests: XCTestCase { .body(.byteBuffer(ByteBuffer(string: "hello"))), .end(nil), ] - - let serializer = BHTTPSerializer() - var parser = BHTTPParser(role: .client) var buffer = ByteBuffer() + var serializer = BHTTPSerializer() + var parser = BHTTPParser(role: .client) for message in response { - serializer.serialize(.response(message), into: &buffer) + try serializer.serialize(.response(message), into: &buffer) } parser.append(buffer) @@ -220,12 +399,54 @@ final class BHTTPSerializerTests: XCTestCase { .end(expectedTrailers), ] - let serializer = BHTTPSerializer() + var buffer = ByteBuffer() + var serializer = BHTTPSerializer() var parser = BHTTPParser(role: .client) + + for message in response { + try serializer.serialize(.response(message), into: &buffer) + } + + parser.append(buffer) + parser.completeBodyReceived() + var received: [HTTPClientResponsePart] = [] + + while let next = try parser.nextMessage(), case .response(let response) = next { + received.append(response) + } + + let expectedResponse: [HTTPClientResponsePart] = [ + .head(.init(version: .http1_1, status: .noContent, headers: expectedHeaders)), + .body(ByteBuffer(string: "hello")), + .end(expectedTrailers), + ] + XCTAssertEqual(expectedResponse, received) + } + + + func testSimple200ResponseWithBodyAndTrailersRoundTripsAndKnownLengthSerializer() throws { + let expectedHeaders = HTTPHeaders([ + ("server", "apache"), + ("other-header", "its value"), + ("content-length", "5"), + ]) + let expectedTrailers = HTTPHeaders([ + ("foo", "bar"), + ("froo", "brar"), + ]) + let response: [HTTPServerResponsePart] = [ + .head(.init(version: .http1_1, status: .noContent, headers: expectedHeaders)), + .body(.byteBuffer(ByteBuffer(string: "hello"))), + .end(expectedTrailers), + .end(nil) + ] + var buffer = ByteBuffer() + var serializer = BHTTPSerializer(type: .knownLength) + var parser = BHTTPParser(role: .client) for message in response { - serializer.serialize(.response(message), into: &buffer) + try serializer.serialize(.response(message), into: &buffer) } parser.append(buffer) From 0d675264b05d2a2fd909cf0582c0848aceb8bdcf Mon Sep 17 00:00:00 2001 From: Arnaud Date: Tue, 24 Dec 2024 17:56:54 +0100 Subject: [PATCH 2/9] Adding more documentation for the BHTTPSerializer init --- Sources/ObliviousHTTP/BHTTPSerializer.swift | 6 ++++-- Sources/ObliviousHTTP/Errors.swift | 4 ++-- Tests/ObliviousHTTPTests/BHTTPSerializerTests.swift | 6 ++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Sources/ObliviousHTTP/BHTTPSerializer.swift b/Sources/ObliviousHTTP/BHTTPSerializer.swift index 3bee8bb..7c7ee4c 100644 --- a/Sources/ObliviousHTTP/BHTTPSerializer.swift +++ b/Sources/ObliviousHTTP/BHTTPSerializer.swift @@ -23,6 +23,9 @@ public struct BHTTPSerializer { private var fieldSectionBuffer: ByteBuffer /// Initialise a Binary HTTP Serialiser. + /// - Parameters: + /// - type: The type of BHTTPSerializer you want: either known or indeterminate length. + /// - allocator: Byte buffer allocator used. public init( type: BHTTPSerializerType = .indeterminateLength, allocator: ByteBufferAllocator = ByteBufferAllocator() @@ -201,10 +204,9 @@ extension BHTTPSerializer { .HEADER: [.CHUNK, .TRAILERS], .CHUNK: [.TRAILERS, .END], .TRAILERS: [.END], - .END: [] + .END: [], ] - // Ensure that the current state is one of the allowed states. func ensureState(_ allowedStates: [BHTTPSerializerState]) throws { if !allowedStates.contains(self.currentState) { diff --git a/Sources/ObliviousHTTP/Errors.swift b/Sources/ObliviousHTTP/Errors.swift index bcfddb9..e0f67db 100644 --- a/Sources/ObliviousHTTP/Errors.swift +++ b/Sources/ObliviousHTTP/Errors.swift @@ -53,7 +53,7 @@ public struct ObliviousHTTPError: Error, Hashable { } /// Create an error indicating that serializing failed due to an unexpected HTTP section. - /// - Parameter status: The state encountered. + /// - Parameter state: The state encountered. /// - Returns: An Error representing this failure. @inline(never) public static func unexpectedHTTPMessageSection(state: String) -> ObliviousHTTPError { @@ -61,7 +61,7 @@ public struct ObliviousHTTPError: Error, Hashable { } /// Create an error indicating that serializing failed due to an unsupported option. - /// - Parameter status: The unsupported option details. + /// - Parameter reason: The unsupported option details. /// - Returns: An Error representing this failure. @inline(never) public static func unsupportedOption(reason: String) -> ObliviousHTTPError { diff --git a/Tests/ObliviousHTTPTests/BHTTPSerializerTests.swift b/Tests/ObliviousHTTPTests/BHTTPSerializerTests.swift index a61f6cc..fae3b6f 100644 --- a/Tests/ObliviousHTTPTests/BHTTPSerializerTests.swift +++ b/Tests/ObliviousHTTPTests/BHTTPSerializerTests.swift @@ -29,7 +29,6 @@ final class BHTTPSerializerTests: XCTestCase { .head(.init(version: .http1_1, method: .GET, uri: "/example", headers: expectedHeaders)), .end(nil), ] - var buffer = ByteBuffer() var serializer = BHTTPSerializer() var parser = BHTTPParser(role: .server) @@ -310,7 +309,7 @@ final class BHTTPSerializerTests: XCTestCase { let expectedRequest: [HTTPServerRequestPart] = [ .head(.init(version: .http1_1, method: .POST, uri: "/example", headers: expectedHeaders)), .body(.init(string: "hello")), - .end(expectedTrailers) + .end(expectedTrailers), ] XCTAssertEqual(expectedRequest, received) } @@ -423,7 +422,6 @@ final class BHTTPSerializerTests: XCTestCase { XCTAssertEqual(expectedResponse, received) } - func testSimple200ResponseWithBodyAndTrailersRoundTripsAndKnownLengthSerializer() throws { let expectedHeaders = HTTPHeaders([ ("server", "apache"), @@ -438,7 +436,7 @@ final class BHTTPSerializerTests: XCTestCase { .head(.init(version: .http1_1, status: .noContent, headers: expectedHeaders)), .body(.byteBuffer(ByteBuffer(string: "hello"))), .end(expectedTrailers), - .end(nil) + .end(nil), ] var buffer = ByteBuffer() From 8e9351f7c528987e27c0f5de0dca17f5207a5732 Mon Sep 17 00:00:00 2001 From: Arnaud Date: Thu, 2 Jan 2025 16:47:31 +0100 Subject: [PATCH 3/9] Switch enum BHTTPSerializerType to a struct called SerializerType --- Sources/ObliviousHTTP/BHTTPSerializer.swift | 54 ++++++++++++++------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/Sources/ObliviousHTTP/BHTTPSerializer.swift b/Sources/ObliviousHTTP/BHTTPSerializer.swift index 7c7ee4c..6d83881 100644 --- a/Sources/ObliviousHTTP/BHTTPSerializer.swift +++ b/Sources/ObliviousHTTP/BHTTPSerializer.swift @@ -18,7 +18,7 @@ import NIOHTTP1 public struct BHTTPSerializer { private var fsm: BHTTPSerializerFSM - private var type: BHTTPSerializerType + private var type: SerializerType private var chunkBuffer: ByteBuffer private var fieldSectionBuffer: ByteBuffer @@ -27,7 +27,7 @@ public struct BHTTPSerializer { /// - type: The type of BHTTPSerializer you want: either known or indeterminate length. /// - allocator: Byte buffer allocator used. public init( - type: BHTTPSerializerType = .indeterminateLength, + type: SerializerType = .indeterminateLength, allocator: ByteBufferAllocator = ByteBufferAllocator() ) { self.type = type @@ -55,10 +55,11 @@ public struct BHTTPSerializer { case .request(.body(.byteBuffer(let body))), .response(.body(.byteBuffer(let body))): try self.fsm.ensureState([BHTTPSerializerState.CHUNK]) switch self.type { - case .indeterminateLength: - Self.serializeContentChunk(body, into: &buffer) case .knownLength: self.stackContentChunk(body) + break + default: + Self.serializeContentChunk(body, into: &buffer) } case .request(.body(.fileRegion)), .response(.body(.fileRegion)): @@ -67,24 +68,26 @@ public struct BHTTPSerializer { case .request(.end(.some(let trailers))), .response(.end(.some(let trailers))): try self.fsm.ensureState([BHTTPSerializerState.CHUNK, BHTTPSerializerState.HEADER]) switch self.type { - case .indeterminateLength: - // Send a 0 to terminate the body, then a field section. - buffer.writeInteger(UInt8(0)) - Self.serializeIndeterminateLengthFieldSection(trailers, into: &buffer) case .knownLength: self.serializeContent(into: &buffer) self.stackKnownLengthFieldSection(trailers) + break + default: + // Send a 0 to terminate the body, then a field section. + buffer.writeInteger(UInt8(0)) + Self.serializeIndeterminateLengthFieldSection(trailers, into: &buffer) } try self.fsm.transition(to: BHTTPSerializerState.TRAILERS) case .request(.end(.none)), .response(.end(.none)): try self.fsm.ensureState([BHTTPSerializerState.CHUNK, BHTTPSerializerState.TRAILERS]) switch self.type { - case .indeterminateLength: - buffer.writeInteger(UInt8(0)) case .knownLength: self.serializeContent(into: &buffer) self.serializeKnownLengthFieldSection(into: &buffer) + break + default: + buffer.writeInteger(UInt8(0)) } try self.fsm.transition(to: BHTTPSerializerState.END) } @@ -109,11 +112,12 @@ public struct BHTTPSerializer { buffer.writeVarintPrefixedString(path) switch self.type { - case .indeterminateLength: - Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) case .knownLength: self.stackKnownLengthFieldSection(head.headers) self.serializeKnownLengthFieldSection(into: &buffer) + break + default: + Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) } } @@ -128,11 +132,12 @@ public struct BHTTPSerializer { buffer.writeVarint(Int(head.status.code)) switch self.type { - case .indeterminateLength: - Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) case .knownLength: self.stackKnownLengthFieldSection(head.headers) self.serializeKnownLengthFieldSection(into: &buffer) + break + default: + Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) } } @@ -220,9 +225,24 @@ extension BHTTPSerializer { case response(HTTPServerResponsePart) } - public enum BHTTPSerializerType { - case knownLength - case indeterminateLength + public struct SerializerType: Equatable { + private enum InternalType: Equatable { + case knownLength + case indeterminateLength + } + + private let type: InternalType + + public static let knownLength = SerializerType(type: .knownLength) + public static let indeterminateLength = SerializerType(type: .indeterminateLength) + + private init(type: InternalType) { + self.type = type + } + + public static func == (lop: SerializerType, rop: SerializerType) -> Bool { + lop.type == rop.type + } } public enum BHTTPFramingIndicator: Int { From c1524f327c7b327afa19d3aca8fe1339f17bb7af Mon Sep 17 00:00:00 2001 From: Arnaud Date: Sun, 5 Jan 2025 15:45:49 +0100 Subject: [PATCH 4/9] Refactor BHTTPSerializer: removing the ensureState function on the FSM and better framing indicator computation --- Sources/ObliviousHTTP/BHTTPSerializer.swift | 109 ++++++++++---------- Sources/ObliviousHTTP/Errors.swift | 4 +- 2 files changed, 56 insertions(+), 57 deletions(-) diff --git a/Sources/ObliviousHTTP/BHTTPSerializer.swift b/Sources/ObliviousHTTP/BHTTPSerializer.swift index 6d83881..1aad885 100644 --- a/Sources/ObliviousHTTP/BHTTPSerializer.swift +++ b/Sources/ObliviousHTTP/BHTTPSerializer.swift @@ -33,7 +33,25 @@ public struct BHTTPSerializer { self.type = type self.chunkBuffer = allocator.buffer(capacity: 0) self.fieldSectionBuffer = allocator.buffer(capacity: 0) - self.fsm = BHTTPSerializerFSM(initialState: BHTTPSerializerState.HEADER) + self.fsm = BHTTPSerializerFSM(initialState: BHTTPSerializerState.start) + } + + private var requestFramingIndicator: Int { + switch self.type { + case .knownLength: + return FramingIndicator.requestKnownLength + default: + return FramingIndicator.requestIndeterminateLength + } + } + + private var responseFramingIndicator: Int { + switch self.type { + case .knownLength: + return FramingIndicator.responseKnownLength + default: + return FramingIndicator.responseIndeterminateLength + } } /// Serialise a message into a buffer using binary HTTP encoding. @@ -43,17 +61,14 @@ public struct BHTTPSerializer { public mutating func serialize(_ message: Message, into buffer: inout ByteBuffer) throws { switch message { case .request(.head(let requestHead)): - try self.fsm.ensureState([BHTTPSerializerState.HEADER]) self.serializeRequestHead(requestHead, into: &buffer) - try self.fsm.transition(to: BHTTPSerializerState.CHUNK) + try self.fsm.transition(to: BHTTPSerializerState.header) case .response(.head(let responseHead)): - try self.fsm.ensureState([BHTTPSerializerState.HEADER]) self.serializeResponseHead(responseHead, into: &buffer) - try self.fsm.transition(to: BHTTPSerializerState.CHUNK) + try self.fsm.transition(to: BHTTPSerializerState.header) case .request(.body(.byteBuffer(let body))), .response(.body(.byteBuffer(let body))): - try self.fsm.ensureState([BHTTPSerializerState.CHUNK]) switch self.type { case .knownLength: self.stackContentChunk(body) @@ -61,12 +76,12 @@ public struct BHTTPSerializer { default: Self.serializeContentChunk(body, into: &buffer) } + try self.fsm.transition(to: BHTTPSerializerState.chunk) case .request(.body(.fileRegion)), .response(.body(.fileRegion)): throw ObliviousHTTPError.unsupportedOption(reason: "fileregion unsupported") case .request(.end(.some(let trailers))), .response(.end(.some(let trailers))): - try self.fsm.ensureState([BHTTPSerializerState.CHUNK, BHTTPSerializerState.HEADER]) switch self.type { case .knownLength: self.serializeContent(into: &buffer) @@ -77,10 +92,10 @@ public struct BHTTPSerializer { buffer.writeInteger(UInt8(0)) Self.serializeIndeterminateLengthFieldSection(trailers, into: &buffer) } - try self.fsm.transition(to: BHTTPSerializerState.TRAILERS) + try self.fsm.transition(to: BHTTPSerializerState.trailers) + case .request(.end(.none)), .response(.end(.none)): - try self.fsm.ensureState([BHTTPSerializerState.CHUNK, BHTTPSerializerState.TRAILERS]) switch self.type { case .knownLength: self.serializeContent(into: &buffer) @@ -89,17 +104,13 @@ public struct BHTTPSerializer { default: buffer.writeInteger(UInt8(0)) } - try self.fsm.transition(to: BHTTPSerializerState.END) + try self.fsm.transition(to: BHTTPSerializerState.end) } } private mutating func serializeRequestHead(_ head: HTTPRequestHead, into buffer: inout ByteBuffer) { // First, the framing indicator - buffer.writeVarint( - self.type == .indeterminateLength - ? BHTTPFramingIndicator.requestIndeterminateLength.rawValue - : BHTTPFramingIndicator.requestKnownLength.rawValue - ) + buffer.writeVarint(requestFramingIndicator) let method = head.method let scheme = "https" // Hardcoded for now, but not really the right option. @@ -123,11 +134,7 @@ public struct BHTTPSerializer { private mutating func serializeResponseHead(_ head: HTTPResponseHead, into buffer: inout ByteBuffer) { // First, the framing indicator - buffer.writeVarint( - self.type == .indeterminateLength - ? BHTTPFramingIndicator.responseInderterminateLength.rawValue - : BHTTPFramingIndicator.responseKnownLength.rawValue - ) + buffer.writeVarint(responseFramingIndicator) buffer.writeVarint(Int(head.status.code)) @@ -141,7 +148,6 @@ public struct BHTTPSerializer { } } - @inline(__always) private static func serializeContentChunk(_ chunk: ByteBuffer, into buffer: inout ByteBuffer) { if chunk.readableBytes == 0 { return } buffer.writeVarintPrefixedImmutableBuffer(chunk) @@ -184,40 +190,32 @@ public struct BHTTPSerializer { // Enum definitions for message, states, and types. extension BHTTPSerializer { // Finite State Machine for managing transitions in BHTTPSerializer. - private class BHTTPSerializerFSM { + public struct BHTTPSerializerFSM { var currentState: BHTTPSerializerState init(initialState: BHTTPSerializerState) { self.currentState = initialState } - // Transition to a new state, respecting the state machine constraints. - func transition(to state: BHTTPSerializerState) throws { - guard let allowedTransitions = Self.validTransitions[currentState] else { - throw ObliviousHTTPError.unexpectedHTTPMessageSection(state: currentState.rawValue) + mutating func transition(to state: BHTTPSerializerState) throws { + let allowedNextStates: Set + switch currentState { + case .start: + allowedNextStates = [.header] + case .header: + allowedNextStates = [.chunk, .trailers, .end] + case .chunk: + allowedNextStates = [.trailers, .end, .chunk] + case .trailers: + allowedNextStates = [.trailers, .end] + case .end: + allowedNextStates = [] } - - guard allowedTransitions.contains(state) else { - throw ObliviousHTTPError.unexpectedHTTPMessageSection(state: currentState.rawValue) + guard allowedNextStates.contains(state) else { + throw ObliviousHTTPError.unexpectedHTTPMessageSection() } - currentState = state } - - // Define a dictionary to map current states to allowed transitions - private static let validTransitions: [BHTTPSerializerState: Set] = [ - .HEADER: [.CHUNK, .TRAILERS], - .CHUNK: [.TRAILERS, .END], - .TRAILERS: [.END], - .END: [], - ] - - // Ensure that the current state is one of the allowed states. - func ensureState(_ allowedStates: [BHTTPSerializerState]) throws { - if !allowedStates.contains(self.currentState) { - throw ObliviousHTTPError.unexpectedHTTPMessageSection(state: self.currentState.rawValue) - } - } } public enum Message { @@ -245,17 +243,18 @@ extension BHTTPSerializer { } } - public enum BHTTPFramingIndicator: Int { - case requestKnownLength = 0 - case responseKnownLength = 1 - case requestIndeterminateLength = 2 - case responseInderterminateLength = 3 + internal struct FramingIndicator { + static var requestKnownLength: Int { 0 } + static var responseKnownLength: Int { 1 } + static var requestIndeterminateLength: Int { 2 } + static var responseIndeterminateLength: Int { 3 } } - public enum BHTTPSerializerState: String { - case HEADER = "Header" - case CHUNK = "Chunk" - case TRAILERS = "Trailers" - case END = "End" + public enum BHTTPSerializerState { + case start + case header + case chunk + case trailers + case end } } diff --git a/Sources/ObliviousHTTP/Errors.swift b/Sources/ObliviousHTTP/Errors.swift index e0f67db..cb56f53 100644 --- a/Sources/ObliviousHTTP/Errors.swift +++ b/Sources/ObliviousHTTP/Errors.swift @@ -56,8 +56,8 @@ public struct ObliviousHTTPError: Error, Hashable { /// - Parameter state: The state encountered. /// - Returns: An Error representing this failure. @inline(never) - public static func unexpectedHTTPMessageSection(state: String) -> ObliviousHTTPError { - Self.init(backing: .unexpectedHTTPMessageSection(state: "\(state) section was not expected.")) + public static func unexpectedHTTPMessageSection() -> ObliviousHTTPError { + Self.init(backing: .unexpectedHTTPMessageSection(state: "An unexpected HTTP message section was encountered.")) } /// Create an error indicating that serializing failed due to an unsupported option. From 1fa44b7453e5c8f164fd2a60e1117ce4f8158de9 Mon Sep 17 00:00:00 2001 From: Arnaud Date: Sat, 11 Jan 2025 13:36:29 +0100 Subject: [PATCH 5/9] Merge serializations and transitions logic into the serializer FSM. --- Sources/ObliviousHTTP/BHTTPSerializer.swift | 124 ++++++++++++++------ 1 file changed, 87 insertions(+), 37 deletions(-) diff --git a/Sources/ObliviousHTTP/BHTTPSerializer.swift b/Sources/ObliviousHTTP/BHTTPSerializer.swift index 1aad885..e2bea8c 100644 --- a/Sources/ObliviousHTTP/BHTTPSerializer.swift +++ b/Sources/ObliviousHTTP/BHTTPSerializer.swift @@ -18,7 +18,7 @@ import NIOHTTP1 public struct BHTTPSerializer { private var fsm: BHTTPSerializerFSM - private var type: SerializerType + public var type: SerializerType private var chunkBuffer: ByteBuffer private var fieldSectionBuffer: ByteBuffer @@ -61,50 +61,22 @@ public struct BHTTPSerializer { public mutating func serialize(_ message: Message, into buffer: inout ByteBuffer) throws { switch message { case .request(.head(let requestHead)): - self.serializeRequestHead(requestHead, into: &buffer) - try self.fsm.transition(to: BHTTPSerializerState.header) + try self.fsm.writeRequestHead(requestHead, into: &buffer, using: &self) case .response(.head(let responseHead)): - self.serializeResponseHead(responseHead, into: &buffer) - try self.fsm.transition(to: BHTTPSerializerState.header) + try self.fsm.writeResponseHead(responseHead, into: &buffer, using: &self) case .request(.body(.byteBuffer(let body))), .response(.body(.byteBuffer(let body))): - switch self.type { - case .knownLength: - self.stackContentChunk(body) - break - default: - Self.serializeContentChunk(body, into: &buffer) - } - try self.fsm.transition(to: BHTTPSerializerState.chunk) + try self.fsm.writeBodyChunk(body, into: &buffer, using: &self) case .request(.body(.fileRegion)), .response(.body(.fileRegion)): throw ObliviousHTTPError.unsupportedOption(reason: "fileregion unsupported") case .request(.end(.some(let trailers))), .response(.end(.some(let trailers))): - switch self.type { - case .knownLength: - self.serializeContent(into: &buffer) - self.stackKnownLengthFieldSection(trailers) - break - default: - // Send a 0 to terminate the body, then a field section. - buffer.writeInteger(UInt8(0)) - Self.serializeIndeterminateLengthFieldSection(trailers, into: &buffer) - } - try self.fsm.transition(to: BHTTPSerializerState.trailers) - + try self.fsm.writeTrailers(trailers, into: &buffer, using: &self) case .request(.end(.none)), .response(.end(.none)): - switch self.type { - case .knownLength: - self.serializeContent(into: &buffer) - self.serializeKnownLengthFieldSection(into: &buffer) - break - default: - buffer.writeInteger(UInt8(0)) - } - try self.fsm.transition(to: BHTTPSerializerState.end) + try self.fsm.writeRequestEnd(into: &buffer, using: &self) } } @@ -148,6 +120,16 @@ public struct BHTTPSerializer { } } + private mutating func serializeChunk(_ chunk: ByteBuffer, into buffer: inout ByteBuffer) { + switch self.type { + case .knownLength: + self.stackContentChunk(chunk) + break + default: + Self.serializeContentChunk(chunk, into: &buffer) + } + } + private static func serializeContentChunk(_ chunk: ByteBuffer, into buffer: inout ByteBuffer) { if chunk.readableBytes == 0 { return } buffer.writeVarintPrefixedImmutableBuffer(chunk) @@ -174,6 +156,19 @@ public struct BHTTPSerializer { buffer.writeInteger(UInt8(0)) // End of field section } + private mutating func serializeTrailers(_ trailers: HTTPHeaders, into buffer: inout ByteBuffer) { + switch self.type { + case .knownLength: + self.serializeContent(into: &buffer) + self.stackKnownLengthFieldSection(trailers) + break + default: + // Send a 0 to terminate the body, then a field section. + buffer.writeInteger(UInt8(0)) + Self.serializeIndeterminateLengthFieldSection(trailers, into: &buffer) + } + } + private mutating func serializeKnownLengthFieldSection(into buffer: inout ByteBuffer) { buffer.writeVarintPrefixedImmutableBuffer(self.fieldSectionBuffer) self.fieldSectionBuffer.clear() @@ -185,19 +180,74 @@ public struct BHTTPSerializer { self.fieldSectionBuffer.writeVarintPrefixedString(value) } } + + private mutating func endRequest(into buffer: inout ByteBuffer) { + switch self.type { + case .knownLength: + self.serializeContent(into: &buffer) + self.serializeKnownLengthFieldSection(into: &buffer) + break + default: + buffer.writeInteger(UInt8(0)) + } + } } // Enum definitions for message, states, and types. extension BHTTPSerializer { // Finite State Machine for managing transitions in BHTTPSerializer. - public struct BHTTPSerializerFSM { - var currentState: BHTTPSerializerState + public class BHTTPSerializerFSM { + private(set) var currentState: BHTTPSerializerState init(initialState: BHTTPSerializerState) { self.currentState = initialState } - mutating func transition(to state: BHTTPSerializerState) throws { + func writeRequestHead( + _ requestHead: HTTPRequestHead, + into buffer: inout ByteBuffer, + using serializer: inout BHTTPSerializer + ) throws { + try self.transition(to: .header) + serializer.serializeRequestHead(requestHead, into: &buffer) + } + + func writeResponseHead( + _ responseHead: HTTPResponseHead, + into buffer: inout ByteBuffer, + using serializer: inout BHTTPSerializer + ) throws { + try self.transition(to: .header) + serializer.serializeResponseHead(responseHead, into: &buffer) + } + + + func writeRequestEnd( + into buffer: inout ByteBuffer, + using serializer: inout BHTTPSerializer + ) throws { + serializer.endRequest(into: &buffer) + try self.transition(to: .end) + } + func writeBodyChunk( + _ body: ByteBuffer, + into buffer: inout ByteBuffer, + using serializer: inout BHTTPSerializer + ) throws { + serializer.serializeChunk(body, into: &buffer) + try self.transition(to: .chunk) + } + + func writeTrailers( + _ trailers: HTTPHeaders, + into buffer: inout ByteBuffer, + using serializer: inout BHTTPSerializer + ) throws { + serializer.serializeTrailers(trailers, into: &buffer) + try self.transition(to: .trailers) + } + + func transition(to state: BHTTPSerializerState) throws { let allowedNextStates: Set switch currentState { case .start: From 878902296be7344df9ddd81c0a13250dc10a4c35 Mon Sep 17 00:00:00 2001 From: Arnaud Date: Fri, 17 Jan 2025 11:12:16 +0100 Subject: [PATCH 6/9] Switch BHTTP Serializer type and state to private --- Sources/ObliviousHTTP/BHTTPSerializer.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Sources/ObliviousHTTP/BHTTPSerializer.swift b/Sources/ObliviousHTTP/BHTTPSerializer.swift index e2bea8c..42eb8f1 100644 --- a/Sources/ObliviousHTTP/BHTTPSerializer.swift +++ b/Sources/ObliviousHTTP/BHTTPSerializer.swift @@ -18,7 +18,7 @@ import NIOHTTP1 public struct BHTTPSerializer { private var fsm: BHTTPSerializerFSM - public var type: SerializerType + private var type: SerializerType private var chunkBuffer: ByteBuffer private var fieldSectionBuffer: ByteBuffer @@ -197,7 +197,7 @@ public struct BHTTPSerializer { extension BHTTPSerializer { // Finite State Machine for managing transitions in BHTTPSerializer. public class BHTTPSerializerFSM { - private(set) var currentState: BHTTPSerializerState + private var currentState: BHTTPSerializerState init(initialState: BHTTPSerializerState) { self.currentState = initialState @@ -247,7 +247,7 @@ extension BHTTPSerializer { try self.transition(to: .trailers) } - func transition(to state: BHTTPSerializerState) throws { + private func transition(to state: BHTTPSerializerState) throws { let allowedNextStates: Set switch currentState { case .start: @@ -300,7 +300,7 @@ extension BHTTPSerializer { static var responseIndeterminateLength: Int { 3 } } - public enum BHTTPSerializerState { + internal enum BHTTPSerializerState { case start case header case chunk From fc9daf6ffd2207bce8cafe919ae6760d5648eae6 Mon Sep 17 00:00:00 2001 From: Arnaud Date: Fri, 17 Jan 2025 19:46:12 +0100 Subject: [PATCH 7/9] Add few stylistic fixs for BHTTPSerialzier --- Sources/ObliviousHTTP/BHTTPSerializer.swift | 27 ++++++++++----------- Sources/ObliviousHTTP/Errors.swift | 1 - 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/Sources/ObliviousHTTP/BHTTPSerializer.swift b/Sources/ObliviousHTTP/BHTTPSerializer.swift index 42eb8f1..cb166f0 100644 --- a/Sources/ObliviousHTTP/BHTTPSerializer.swift +++ b/Sources/ObliviousHTTP/BHTTPSerializer.swift @@ -98,9 +98,9 @@ public struct BHTTPSerializer { case .knownLength: self.stackKnownLengthFieldSection(head.headers) self.serializeKnownLengthFieldSection(into: &buffer) - break - default: + case .indeterminateLength: Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) + default: break } } @@ -114,9 +114,9 @@ public struct BHTTPSerializer { case .knownLength: self.stackKnownLengthFieldSection(head.headers) self.serializeKnownLengthFieldSection(into: &buffer) - break - default: + case .indeterminateLength: Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) + default: break } } @@ -124,9 +124,9 @@ public struct BHTTPSerializer { switch self.type { case .knownLength: self.stackContentChunk(chunk) - break - default: + case .indeterminateLength: Self.serializeContentChunk(chunk, into: &buffer) + default: break } } @@ -135,7 +135,7 @@ public struct BHTTPSerializer { buffer.writeVarintPrefixedImmutableBuffer(chunk) } - private mutating func serializeContent(into buffer: inout ByteBuffer) { + private mutating func serializeStackedContent(into buffer: inout ByteBuffer) { if self.chunkBuffer.readableBytes == 0 { return } buffer.writeVarintPrefixedImmutableBuffer(self.chunkBuffer) self.chunkBuffer.clear() @@ -159,13 +159,13 @@ public struct BHTTPSerializer { private mutating func serializeTrailers(_ trailers: HTTPHeaders, into buffer: inout ByteBuffer) { switch self.type { case .knownLength: - self.serializeContent(into: &buffer) + self.serializeStackedContent(into: &buffer) self.stackKnownLengthFieldSection(trailers) - break - default: + case .indeterminateLength: // Send a 0 to terminate the body, then a field section. buffer.writeInteger(UInt8(0)) Self.serializeIndeterminateLengthFieldSection(trailers, into: &buffer) + default: break } } @@ -184,11 +184,11 @@ public struct BHTTPSerializer { private mutating func endRequest(into buffer: inout ByteBuffer) { switch self.type { case .knownLength: - self.serializeContent(into: &buffer) + self.serializeStackedContent(into: &buffer) self.serializeKnownLengthFieldSection(into: &buffer) - break - default: + case .indeterminateLength: buffer.writeInteger(UInt8(0)) + default: break } } } @@ -221,7 +221,6 @@ extension BHTTPSerializer { serializer.serializeResponseHead(responseHead, into: &buffer) } - func writeRequestEnd( into buffer: inout ByteBuffer, using serializer: inout BHTTPSerializer diff --git a/Sources/ObliviousHTTP/Errors.swift b/Sources/ObliviousHTTP/Errors.swift index cb56f53..4794990 100644 --- a/Sources/ObliviousHTTP/Errors.swift +++ b/Sources/ObliviousHTTP/Errors.swift @@ -53,7 +53,6 @@ public struct ObliviousHTTPError: Error, Hashable { } /// Create an error indicating that serializing failed due to an unexpected HTTP section. - /// - Parameter state: The state encountered. /// - Returns: An Error representing this failure. @inline(never) public static func unexpectedHTTPMessageSection() -> ObliviousHTTPError { From 8dfb9b833c015cae955ffc9ef4c3998e51850874 Mon Sep 17 00:00:00 2001 From: Arnaud Date: Tue, 21 Jan 2025 17:47:30 +0100 Subject: [PATCH 8/9] Change SerializerType protocol from Equatable to Hashble, Sendable --- Sources/ObliviousHTTP/BHTTPSerializer.swift | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/Sources/ObliviousHTTP/BHTTPSerializer.swift b/Sources/ObliviousHTTP/BHTTPSerializer.swift index cb166f0..98523ac 100644 --- a/Sources/ObliviousHTTP/BHTTPSerializer.swift +++ b/Sources/ObliviousHTTP/BHTTPSerializer.swift @@ -196,7 +196,7 @@ public struct BHTTPSerializer { // Enum definitions for message, states, and types. extension BHTTPSerializer { // Finite State Machine for managing transitions in BHTTPSerializer. - public class BHTTPSerializerFSM { + private class BHTTPSerializerFSM { private var currentState: BHTTPSerializerState init(initialState: BHTTPSerializerState) { @@ -272,8 +272,8 @@ extension BHTTPSerializer { case response(HTTPServerResponsePart) } - public struct SerializerType: Equatable { - private enum InternalType: Equatable { + public struct SerializerType: Hashable, Sendable { + private enum InternalType: Hashable { case knownLength case indeterminateLength } @@ -287,9 +287,6 @@ extension BHTTPSerializer { self.type = type } - public static func == (lop: SerializerType, rop: SerializerType) -> Bool { - lop.type == rop.type - } } internal struct FramingIndicator { From 2010ef9f1a5aa3f7bcfdaff73d86da07e47bc56c Mon Sep 17 00:00:00 2001 From: Arnaud Lecomte Date: Wed, 22 Jan 2025 11:47:50 +0100 Subject: [PATCH 9/9] Move serialization process from BHTTPSerializer to a new struct to ensure exclusive access on buffers --- Sources/ObliviousHTTP/BHTTPSerializer.swift | 301 +++++++++++--------- 1 file changed, 159 insertions(+), 142 deletions(-) diff --git a/Sources/ObliviousHTTP/BHTTPSerializer.swift b/Sources/ObliviousHTTP/BHTTPSerializer.swift index 98523ac..06941a9 100644 --- a/Sources/ObliviousHTTP/BHTTPSerializer.swift +++ b/Sources/ObliviousHTTP/BHTTPSerializer.swift @@ -18,9 +18,7 @@ import NIOHTTP1 public struct BHTTPSerializer { private var fsm: BHTTPSerializerFSM - private var type: SerializerType - private var chunkBuffer: ByteBuffer - private var fieldSectionBuffer: ByteBuffer + private var context: SerializerContext /// Initialise a Binary HTTP Serialiser. /// - Parameters: @@ -30,28 +28,8 @@ public struct BHTTPSerializer { type: SerializerType = .indeterminateLength, allocator: ByteBufferAllocator = ByteBufferAllocator() ) { - self.type = type - self.chunkBuffer = allocator.buffer(capacity: 0) - self.fieldSectionBuffer = allocator.buffer(capacity: 0) self.fsm = BHTTPSerializerFSM(initialState: BHTTPSerializerState.start) - } - - private var requestFramingIndicator: Int { - switch self.type { - case .knownLength: - return FramingIndicator.requestKnownLength - default: - return FramingIndicator.requestIndeterminateLength - } - } - - private var responseFramingIndicator: Int { - switch self.type { - case .knownLength: - return FramingIndicator.responseKnownLength - default: - return FramingIndicator.responseIndeterminateLength - } + self.context = SerializerContext(type: type, allocator: allocator) } /// Serialise a message into a buffer using binary HTTP encoding. @@ -61,192 +39,231 @@ public struct BHTTPSerializer { public mutating func serialize(_ message: Message, into buffer: inout ByteBuffer) throws { switch message { case .request(.head(let requestHead)): - try self.fsm.writeRequestHead(requestHead, into: &buffer, using: &self) + try self.fsm.writeRequestHead(requestHead, into: &buffer, using: &self.context) case .response(.head(let responseHead)): - try self.fsm.writeResponseHead(responseHead, into: &buffer, using: &self) + try self.fsm.writeResponseHead(responseHead, into: &buffer, using: &self.context) case .request(.body(.byteBuffer(let body))), .response(.body(.byteBuffer(let body))): - try self.fsm.writeBodyChunk(body, into: &buffer, using: &self) + try self.fsm.writeBodyChunk(body, into: &buffer, using: &self.context) case .request(.body(.fileRegion)), .response(.body(.fileRegion)): throw ObliviousHTTPError.unsupportedOption(reason: "fileregion unsupported") case .request(.end(.some(let trailers))), .response(.end(.some(let trailers))): - try self.fsm.writeTrailers(trailers, into: &buffer, using: &self) + try self.fsm.writeTrailers(trailers, into: &buffer, using: &self.context) case .request(.end(.none)), .response(.end(.none)): - try self.fsm.writeRequestEnd(into: &buffer, using: &self) + try self.fsm.writeRequestEnd(into: &buffer, using: &self.context) } } - private mutating func serializeRequestHead(_ head: HTTPRequestHead, into buffer: inout ByteBuffer) { - // First, the framing indicator - buffer.writeVarint(requestFramingIndicator) - - let method = head.method - let scheme = "https" // Hardcoded for now, but not really the right option. - let path = head.uri - let authority = head.headers["Host"].first ?? "" - - buffer.writeVarintPrefixedString(method.rawValue) - buffer.writeVarintPrefixedString(scheme) - buffer.writeVarintPrefixedString(authority) - buffer.writeVarintPrefixedString(path) - - switch self.type { - case .knownLength: - self.stackKnownLengthFieldSection(head.headers) - self.serializeKnownLengthFieldSection(into: &buffer) - case .indeterminateLength: - Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) - default: break + +} + +// Enum definitions for message, states, and types. +extension BHTTPSerializer { + + private struct SerializerContext { + private var chunkBuffer: ByteBuffer + private var fieldSectionBuffer: ByteBuffer + private var type: SerializerType + + + public init( + type: SerializerType, + allocator: ByteBufferAllocator = ByteBufferAllocator() + ) { + self.chunkBuffer = allocator.buffer(capacity: 0) + self.fieldSectionBuffer = allocator.buffer(capacity: 0) + self.type = type + } + private var requestFramingIndicator: Int { + switch self.type { + case .knownLength: + return FramingIndicator.requestKnownLength + default: + return FramingIndicator.requestIndeterminateLength + } } - } - private mutating func serializeResponseHead(_ head: HTTPResponseHead, into buffer: inout ByteBuffer) { - // First, the framing indicator - buffer.writeVarint(responseFramingIndicator) + private var responseFramingIndicator: Int { + switch self.type { + case .knownLength: + return FramingIndicator.responseKnownLength + default: + return FramingIndicator.responseIndeterminateLength + } + } + + + mutating func serializeRequestHead(_ head: HTTPRequestHead, into buffer: inout ByteBuffer) { + // First, the framing indicator + buffer.writeVarint(requestFramingIndicator) + + let method = head.method + let scheme = "https" // Hardcoded for now, but not really the right option. + let path = head.uri + let authority = head.headers["Host"].first ?? "" + + buffer.writeVarintPrefixedString(method.rawValue) + buffer.writeVarintPrefixedString(scheme) + buffer.writeVarintPrefixedString(authority) + buffer.writeVarintPrefixedString(path) + + switch self.type { + case .knownLength: + self.stackKnownLengthFieldSection(head.headers) + self.serializeKnownLengthFieldSection(into: &buffer) + case .indeterminateLength: + Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) + default: break + } + } + + mutating func serializeResponseHead(_ head: HTTPResponseHead, into buffer: inout ByteBuffer) { + // First, the framing indicator + buffer.writeVarint(responseFramingIndicator) - buffer.writeVarint(Int(head.status.code)) + buffer.writeVarint(Int(head.status.code)) - switch self.type { - case .knownLength: - self.stackKnownLengthFieldSection(head.headers) - self.serializeKnownLengthFieldSection(into: &buffer) - case .indeterminateLength: - Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) - default: break + switch self.type { + case .knownLength: + self.stackKnownLengthFieldSection(head.headers) + self.serializeKnownLengthFieldSection(into: &buffer) + case .indeterminateLength: + Self.serializeIndeterminateLengthFieldSection(head.headers, into: &buffer) + default: break + } } - } - private mutating func serializeChunk(_ chunk: ByteBuffer, into buffer: inout ByteBuffer) { - switch self.type { - case .knownLength: - self.stackContentChunk(chunk) - case .indeterminateLength: - Self.serializeContentChunk(chunk, into: &buffer) - default: break + mutating func serializeChunk(_ chunk: ByteBuffer, into buffer: inout ByteBuffer) { + switch self.type { + case .knownLength: + self.stackContentChunk(chunk) + case .indeterminateLength: + Self.serializeContentChunk(chunk, into: &buffer) + default: break + } } - } - private static func serializeContentChunk(_ chunk: ByteBuffer, into buffer: inout ByteBuffer) { - if chunk.readableBytes == 0 { return } - buffer.writeVarintPrefixedImmutableBuffer(chunk) - } + static func serializeContentChunk(_ chunk: ByteBuffer, into buffer: inout ByteBuffer) { + if chunk.readableBytes == 0 { return } + buffer.writeVarintPrefixedImmutableBuffer(chunk) + } - private mutating func serializeStackedContent(into buffer: inout ByteBuffer) { - if self.chunkBuffer.readableBytes == 0 { return } - buffer.writeVarintPrefixedImmutableBuffer(self.chunkBuffer) - self.chunkBuffer.clear() - } + mutating func serializeStackedContent(into buffer: inout ByteBuffer) { + if self.chunkBuffer.readableBytes == 0 { return } + buffer.writeVarintPrefixedImmutableBuffer(self.chunkBuffer) + self.chunkBuffer.clear() + } - private mutating func stackContentChunk(_ chunk: ByteBuffer) { - self.chunkBuffer.writeImmutableBuffer(chunk) - } + mutating func stackContentChunk(_ chunk: ByteBuffer) { + self.chunkBuffer.writeImmutableBuffer(chunk) + } - private static func serializeIndeterminateLengthFieldSection( - _ fields: HTTPHeaders, - into buffer: inout ByteBuffer - ) { - for (name, value) in fields { - buffer.writeVarintPrefixedString(name) - buffer.writeVarintPrefixedString(value) + static func serializeIndeterminateLengthFieldSection( + _ fields: HTTPHeaders, + into buffer: inout ByteBuffer + ) { + for (name, value) in fields { + buffer.writeVarintPrefixedString(name) + buffer.writeVarintPrefixedString(value) + } + buffer.writeInteger(UInt8(0)) // End of field section } - buffer.writeInteger(UInt8(0)) // End of field section - } - private mutating func serializeTrailers(_ trailers: HTTPHeaders, into buffer: inout ByteBuffer) { - switch self.type { - case .knownLength: - self.serializeStackedContent(into: &buffer) - self.stackKnownLengthFieldSection(trailers) - case .indeterminateLength: - // Send a 0 to terminate the body, then a field section. - buffer.writeInteger(UInt8(0)) - Self.serializeIndeterminateLengthFieldSection(trailers, into: &buffer) - default: break + mutating func serializeTrailers(_ trailers: HTTPHeaders, into buffer: inout ByteBuffer) { + switch self.type { + case .knownLength: + self.serializeStackedContent(into: &buffer) + self.stackKnownLengthFieldSection(trailers) + case .indeterminateLength: + // Send a 0 to terminate the body, then a field section. + buffer.writeInteger(UInt8(0)) + Self.serializeIndeterminateLengthFieldSection(trailers, into: &buffer) + default: break + } } - } - private mutating func serializeKnownLengthFieldSection(into buffer: inout ByteBuffer) { - buffer.writeVarintPrefixedImmutableBuffer(self.fieldSectionBuffer) - self.fieldSectionBuffer.clear() - } - private mutating func stackKnownLengthFieldSection(_ fields: HTTPHeaders) { - for (name, value) in fields { - self.fieldSectionBuffer.writeVarintPrefixedString(name) - self.fieldSectionBuffer.writeVarintPrefixedString(value) + mutating func endRequest(into buffer: inout ByteBuffer) { + switch self.type { + case .knownLength: + self.serializeStackedContent(into: &buffer) + self.serializeKnownLengthFieldSection(into: &buffer) + case .indeterminateLength: + buffer.writeInteger(UInt8(0)) + default: break + } } - } - - private mutating func endRequest(into buffer: inout ByteBuffer) { - switch self.type { - case .knownLength: - self.serializeStackedContent(into: &buffer) - self.serializeKnownLengthFieldSection(into: &buffer) - case .indeterminateLength: - buffer.writeInteger(UInt8(0)) - default: break + + mutating func stackKnownLengthFieldSection(_ fields: HTTPHeaders) { + for (name, value) in fields { + self.fieldSectionBuffer.writeVarintPrefixedString(name) + self.fieldSectionBuffer.writeVarintPrefixedString(value) + } + } + + + mutating func serializeKnownLengthFieldSection(into buffer: inout ByteBuffer) { + buffer.writeVarintPrefixedImmutableBuffer(self.fieldSectionBuffer) + self.fieldSectionBuffer.clear() } } -} - -// Enum definitions for message, states, and types. -extension BHTTPSerializer { + // Finite State Machine for managing transitions in BHTTPSerializer. - private class BHTTPSerializerFSM { + private struct BHTTPSerializerFSM { private var currentState: BHTTPSerializerState init(initialState: BHTTPSerializerState) { self.currentState = initialState } - func writeRequestHead( + mutating func writeRequestHead( _ requestHead: HTTPRequestHead, into buffer: inout ByteBuffer, - using serializer: inout BHTTPSerializer + using context: inout SerializerContext ) throws { try self.transition(to: .header) - serializer.serializeRequestHead(requestHead, into: &buffer) + context.serializeRequestHead(requestHead, into: &buffer) } - func writeResponseHead( + mutating func writeResponseHead( _ responseHead: HTTPResponseHead, into buffer: inout ByteBuffer, - using serializer: inout BHTTPSerializer + using context: inout SerializerContext ) throws { try self.transition(to: .header) - serializer.serializeResponseHead(responseHead, into: &buffer) + context.serializeResponseHead(responseHead, into: &buffer) } - func writeRequestEnd( + mutating func writeRequestEnd( into buffer: inout ByteBuffer, - using serializer: inout BHTTPSerializer + using context: inout SerializerContext ) throws { - serializer.endRequest(into: &buffer) + context.endRequest(into: &buffer) try self.transition(to: .end) } - func writeBodyChunk( + mutating func writeBodyChunk( _ body: ByteBuffer, into buffer: inout ByteBuffer, - using serializer: inout BHTTPSerializer + using context: inout SerializerContext ) throws { - serializer.serializeChunk(body, into: &buffer) + context.serializeChunk(body, into: &buffer) try self.transition(to: .chunk) } - func writeTrailers( + mutating func writeTrailers( _ trailers: HTTPHeaders, into buffer: inout ByteBuffer, - using serializer: inout BHTTPSerializer + using context: inout SerializerContext ) throws { - serializer.serializeTrailers(trailers, into: &buffer) + context.serializeTrailers(trailers, into: &buffer) try self.transition(to: .trailers) } - private func transition(to state: BHTTPSerializerState) throws { + private mutating func transition(to state: BHTTPSerializerState) throws { let allowedNextStates: Set switch currentState { case .start: