Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for Websocket client upgrade failure #2659

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion Sources/NIOHTTP1/NIOTypedHTTPClientUpgradeHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,19 @@ public final class NIOTypedHTTPClientUpgradeHandler<UpgradeResult: Sendable>: Ch
public func handlerRemoved(context: ChannelHandlerContext) {
switch self.stateMachine.handlerRemoved() {
case .failUpgradePromise:
self.upgradeResultPromise.fail(ChannelError.inappropriateOperationForState)
// Make sure the completion handler is called on the failed upgrade path
self.notUpgradingCompletionHandler(context.channel)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This We cannot run the notUpgradingCompletionHandler in all cases where the state machine returns failUpgradePromise right now. What we should rather do is add a case to the HandlerRemovedAction from the state machine and return that from the states where we know that running the notUpgradingCompletionHandler is correct. In detail, running the notUpgradingCompletionHandler while we are in the state upgrading or unbuffering doesn't seem correct to me.

.hop(to: context.eventLoop)
.whenComplete { result in
switch result {
case .success(let value):
// Expected upgrade failure without error
self.upgradeResultPromise.succeed(value)
case .failure(let error):
// Unexpected upgrade failure with error
self.upgradeResultPromise.fail(error)
}
}
case .none:
break
}
Expand Down
96 changes: 95 additions & 1 deletion Tests/NIOWebSocketTests/WebSocketClientEndToEndTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,26 @@ extension ChannelPipeline {
}
}

private func interactInMemory(_ first: EmbeddedChannel,
_ second: EmbeddedChannel,
eventLoop: EmbeddedEventLoop) throws {
var operated: Bool

repeat {
eventLoop.run()
operated = false

if let data = try first.readOutbound(as: ByteBuffer.self) {
operated = true
try second.writeInbound(data)
}
if let data = try second.readOutbound(as: ByteBuffer.self) {
operated = true
try first.writeInbound(data)
}
} while operated
}

private func setUpClientChannel(clientHTTPHandler: RemovableChannelHandler,
clientUpgraders: [NIOHTTPClientProtocolUpgrader],
_ upgradeCompletionHandler: @escaping (ChannelHandlerContext) -> Void) throws -> EmbeddedChannel {
Expand Down Expand Up @@ -591,7 +611,7 @@ final class TypedWebSocketClientEndToEndTests: WebSocketClientEndToEndTests {
requestKey: "OfS0wDaT5NoxF2gqm7Zj2YtetzM=",
upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in
channel.pipeline.addHandler(handler)
})
})

// The process should kick-off independently by sending the upgrade request to the server.
let (clientChannel, upgradeResult) = try setUpClientChannel(
Expand All @@ -615,5 +635,79 @@ final class TypedWebSocketClientEndToEndTests: WebSocketClientEndToEndTests {

return (clientChannel, handler)
}

func testSimpleUpgradeRejectedWhenServerSendsUpgradeNil() throws {

enum TestUpgradeResult: Int {
case successfulUpgrade
case notUpgraded
}

let serverRecorder = WebSocketRecorderHandler()
let clientRecorder = WebSocketRecorderHandler()
let loop = EmbeddedEventLoop()
let serverChannel = EmbeddedChannel(loop: loop)
let clientChannel = EmbeddedChannel(loop: loop)

let serverUpgrader = NIOTypedWebSocketServerUpgrader(
shouldUpgrade: { (channel, head) in
channel.eventLoop.makeSucceededFuture(nil)
},
upgradePipelineHandler: { (channel, req) in
channel.pipeline.addHandler(serverRecorder)
}
)

XCTAssertNoThrow(try serverChannel.pipeline.syncOperations.configureUpgradableHTTPServerPipeline(
configuration: .init(
upgradeConfiguration: NIOTypedHTTPServerUpgradeConfiguration<Void>(
upgraders: [serverUpgrader],
notUpgradingCompletionHandler: { $0.eventLoop.makeSucceededVoidFuture() }
)
)
))

let basicClientUpgrader = NIOTypedWebSocketClientUpgrader<TestUpgradeResult>(
upgradePipelineHandler: { (channel: Channel, _: HTTPResponseHead) in
channel.eventLoop.makeCompletedFuture {
return TestUpgradeResult.successfulUpgrade
}
})

var headers = HTTPHeaders()
headers.add(name: "Content-Type", value: "text/plain; charset=utf-8")
headers.add(name: "Content-Length", value: "\(0)")
let requestHead = HTTPRequestHead(
version: .http1_1,
method: .GET,
uri: "/",
headers: headers
)
let config = NIOTypedHTTPClientUpgradeConfiguration<TestUpgradeResult>(
upgradeRequestHead: requestHead,
upgraders: [basicClientUpgrader],
notUpgradingCompletionHandler: { channel in
channel.eventLoop.makeCompletedFuture {
return TestUpgradeResult.notUpgraded
}
}
)
let updgradeResult = try clientChannel.pipeline.syncOperations.configureUpgradableHTTPClientPipeline(configuration: .init(upgradeConfiguration: config))

XCTAssertNoThrow(try interactInMemory(clientChannel, serverChannel, eventLoop: loop))
XCTAssertNoThrow(try clientChannel.finish())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this test hits handlerRemoved it hits it for the wrong reason IMO. When calling clientChannel.finish we are closing the channel which will lead to handlerRemoved being called. What we should do instead is construct a test case where the server responds that it is not upgrading by returning a normal HTTP response and then making sure that the notUpgradingHandler is run.

updgradeResult.whenComplete { result in
switch result {
case .success(let value):
XCTAssertTrue(value == .notUpgraded)
case .failure(let error):
XCTFail("There should be no failure here \(error)")
}
}
XCTAssertNoThrow(try serverChannel.finishAcceptingAlreadyClosed())
XCTAssertEqual(clientRecorder.errors.count, 0)
XCTAssertEqual(serverRecorder.errors.count, 0)
XCTAssertNoThrow(try loop.syncShutdownGracefully())
}
}
#endif