Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: HTTPDecoder: Reenable Parsing After Failed Upgrade #2570

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions Sources/NIOHTTP1/HTTPDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ public typealias HTTPResponseDecoder = HTTPDecoder<HTTPClientResponsePart, HTTPC
///
/// While the `HTTPRequestDecoder` does not currently have a specific ordering requirement in the
/// `ChannelPipeline` (unlike `HTTPResponseDecoder`), it is possible that it will develop one. For
/// that reason, applications should try to ensure that the `HTTPRequestDecoder` *later* in the
/// that reason, applications should try to ensure that the `HTTPRequestDecoder` is *later* in the
/// `ChannelPipeline` than the `HTTPResponseEncoder`.
///
/// Rather than set this up manually, consider using `ChannelPipeline.configureHTTPServerPipeline`.
Expand All @@ -484,12 +484,16 @@ public enum HTTPDecoderKind: Sendable {
case response
}

extension HTTPDecoder: WriteObservingByteToMessageDecoder where In == HTTPClientResponsePart, Out == HTTPClientRequestPart {
extension HTTPDecoder: WriteObservingByteToMessageDecoder {
public typealias OutboundIn = Out

public func write(data: HTTPClientRequestPart) {
if case .head(let head) = data {
public func write(data: Out) {
if kind == .response, case .head(let head) = data as? HTTPClientRequestPart {
self.parser.requestHeads.append(head)
} else if kind == .request, case let .head(head) = data as? HTTPServerResponsePart {
if head.isKeepAlive, head.status != .switchingProtocols, self.stopParsing {
self.stopParsing = false
}
}
}
}
Expand Down Expand Up @@ -567,7 +571,6 @@ public final class HTTPDecoder<In, Out>: ByteToMessageDecoder, HTTPDecoderDelega
case .response:
self.context!.fireChannelRead(NIOAny(HTTPClientResponsePart.body(self.buffer!.readSlice(length: bytes.count)!)))
}

}

func didReceiveHeaderName(_ bytes: UnsafeRawBufferPointer) throws {
Expand Down
59 changes: 59 additions & 0 deletions Tests/NIOHTTP1Tests/HTTPDecoderTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,65 @@ class HTTPDecoderTest: XCTestCase {
status: .ok)), try channel.readInbound()))
}

func testHTTPRequestParsingReenabledAfterFailedUpgrade() {
let channel = EmbeddedChannel()
defer {
XCTAssertNoThrow(try channel.finish())
}
var inBuffer = channel.allocator.buffer(capacity: 128)
inBuffer.writeStaticString("GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")

XCTAssertNoThrow(try channel.pipeline.addHandler(ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .fireError))).wait())
// Server receives an upgrade request.
XCTAssertNoThrow(try channel.writeInbound(inBuffer))
XCTAssertNoThrow(XCTAssertEqual(HTTPServerRequestPart.head(.init(version: .http1_1,
method: .GET, uri: "/",
headers: .init([("Host", "localhost"),
("Connection", "Upgrade"),
("Upgrade", "websocket")]))),
try channel.readInbound()))
XCTAssertNoThrow(XCTAssertEqual(HTTPServerRequestPart.end(nil), try channel.readInbound()))
// Server sees a non upgrade response come through.
XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(.init(version: .http1_1, status: .internalServerError))))
inBuffer.clear()
// Server receives another request.
inBuffer.writeStaticString("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n")
XCTAssertNoThrow(try channel.writeInbound(inBuffer))
// Server should properly parse that request.
XCTAssertNoThrow(XCTAssertEqual(HTTPServerRequestPart.head(.init(version: .http1_1,
method: .GET, uri: "/",
headers: .init([("Host", "localhost")]))),
try channel.readInbound()))
}

func testHTTPRequestParsingStopsAfterSuccessfulUpgrade() {
let channel = EmbeddedChannel()
defer {
XCTAssertNoThrow(try channel.finish())
}
var inBuffer = channel.allocator.buffer(capacity: 128)
inBuffer.writeStaticString("GET / HTTP/1.1\r\nHost: localhost\r\nConnection: Upgrade\r\nUpgrade: websocket\r\n\r\n")

XCTAssertNoThrow(try channel.pipeline.addHandler(ByteToMessageHandler(HTTPRequestDecoder(leftOverBytesStrategy: .fireError))).wait())
// Server receives an upgrade request.
XCTAssertNoThrow(try channel.writeInbound(inBuffer))
XCTAssertNoThrow(XCTAssertEqual(HTTPServerRequestPart.head(.init(version: .http1_1,
method: .GET, uri: "/",
headers: .init([("Host", "localhost"),
("Connection", "Upgrade"),
("Upgrade", "websocket")]))),
try channel.readInbound()))
XCTAssertNoThrow(XCTAssertEqual(HTTPServerRequestPart.end(nil), try channel.readInbound()))
// Server sees successful upgrade response come through.
XCTAssertNoThrow(try channel.writeOutbound(HTTPServerResponsePart.head(.init(version: .http1_1, status: .switchingProtocols))))
inBuffer.clear()
// Server receives another request.
inBuffer.writeStaticString("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n")
XCTAssertNoThrow(try channel.writeInbound(inBuffer))
// Server should not parse that request.
XCTAssertNoThrow(XCTAssertEqual(nil, try channel.readInbound(as: HTTPServerRequestPart.self)))
}

func testBasicVerifications() {
let byteBufferContainingJustAnX = ByteBuffer(string: "X")
let expectedInOuts: [(String, [HTTPServerRequestPart])] = [
Expand Down