Skip to content

Commit

Permalink
Backport: BaseSocketChannel flushNow IONotificationState changes (#2954
Browse files Browse the repository at this point in the history
…) (#2981)

### Motivation:

We could previously hit an assert do to a re-entrancy issue where
channel and buffered write state could change during a call-out leading
to invalid state.

### Modifications:

The decision on whether or not we should be registered for future writes
is now taken after the call outs to `fireChannelWritabilityChanged` and
`fireChannelReadComplete`. The new registration state is set to
`.unregister` if the channel is not open or if there are now flushed
pending writes.

### Result:

Scope for re-entrancy crashes is reduced.

Co-authored-by: Cory Benfield <[email protected]>
(cherry picked from commit fdc3a31)

Co-authored-by: Rick Newton-Rogers <[email protected]>
  • Loading branch information
Lukasa and rnro authored Nov 21, 2024
1 parent 02906a6 commit 7bb2d55
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 27 deletions.
60 changes: 33 additions & 27 deletions Sources/NIOPosix/BaseSocketChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -561,42 +561,48 @@ class BaseSocketChannel<SocketType: BaseSocketProtocol>: SelectableChannel, Chan
}

var newWriteRegistrationState: IONotificationState = .unregister
do {
while newWriteRegistrationState == .unregister && self.hasFlushedPendingWrites() && self.isOpen {
while newWriteRegistrationState == .unregister && self.hasFlushedPendingWrites() && self.isOpen {
let writeResult: OverallWriteResult
do {
assert(self.lifecycleManager.isActive)
let writeResult = try self.writeToSocket()
switch writeResult.writeResult {
case .couldNotWriteEverything:
newWriteRegistrationState = .register
case .writtenCompletely:
newWriteRegistrationState = .unregister
}

writeResult = try self.writeToSocket()
if writeResult.writabilityChange {
// We went from not writable to writable.
self.pipeline.syncOperations.fireChannelWritabilityChanged()
}
}
} catch let err {
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
if self.readIfNeeded0() {
assert(self.lifecycleManager.isActive)

// We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
var readAtLeastOnce = false
while let read = try? self.readFromSocket(), read == .some {
readAtLeastOnce = true
}
if readAtLeastOnce && self.lifecycleManager.isActive {
self.pipeline.fireChannelReadComplete()
} catch let err {
// If there is a write error we should try drain the inbound before closing the socket as there may be some data pending.
// We ignore any error that is thrown as we will use the original err to close the channel and notify the user.
if self.readIfNeeded0() {
assert(self.lifecycleManager.isActive)

// We need to continue reading until there is nothing more to be read from the socket as we will not have another chance to drain it.
var readAtLeastOnce = false
while let read = try? self.readFromSocket(), read == .some {
readAtLeastOnce = true
}
if readAtLeastOnce && self.lifecycleManager.isActive {
self.pipeline.fireChannelReadComplete()
}
}

self.close0(error: err, mode: .all, promise: nil)

// we handled all writes
return .unregister
}

self.close0(error: err, mode: .all, promise: nil)
switch writeResult.writeResult {
case .couldNotWriteEverything:
newWriteRegistrationState = .register
case .writtenCompletely:
newWriteRegistrationState = .unregister
}

// we handled all writes
return .unregister
if !self.isOpen || !self.hasFlushedPendingWrites() {
// No further writes, unregister. We won't re-enter the loop as both of these would have to be true.
newWriteRegistrationState = .unregister
}
}

assert(
Expand Down
115 changes: 115 additions & 0 deletions Tests/NIOPosixTests/SALChannelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1022,4 +1022,119 @@ final class SALChannelTest: XCTestCase, SALTest {
}
)
}

func testBaseSocketChannelFlushNowReentrancyCrash() {
final class TestHandler: ChannelInboundHandler {
typealias InboundIn = Any
typealias OutboundOut = ByteBuffer

private let buffer: ByteBuffer

init(_ buffer: ByteBuffer) {
self.buffer = buffer
}

func channelActive(context: ChannelHandlerContext) {
context.write(self.wrapOutboundOut(buffer), promise: nil)
context.write(self.wrapOutboundOut(buffer), promise: nil)
context.flush()
context.fireChannelActive()
}

func channelWritabilityChanged(context: ChannelHandlerContext) {
if context.channel.isWritable {
context.close(promise: nil)
}
context.fireChannelWritabilityChanged()
}
}
guard let channel = try? self.makeSocketChannel() else {
XCTFail("couldn't make a channel")
return
}
let localAddress = try! SocketAddress(ipAddress: "1.2.3.4", port: 5)
let serverAddress = try! SocketAddress(ipAddress: "9.8.7.6", port: 5)
let buffer = ByteBuffer(repeating: 0, count: 1024)

XCTAssertNoThrow(
try channel.eventLoop.runSAL(syscallAssertions: {
try self.assertSetOption(expectedLevel: .tcp, expectedOption: .tcp_nodelay) { value in
(value as? SocketOptionValue) == 1
}
try self.assertConnect(expectedAddress: serverAddress, result: false)
try self.assertLocalAddress(address: localAddress)
try self.assertRegister { selectable, event, Registration in
XCTAssertEqual([.reset], event)
return true
}
try self.assertReregister { selectable, event in
XCTAssertEqual([.reset, .write], event)
return true
}

let writeEvent = SelectorEvent(
io: [.write],
registration: NIORegistration(
channel: .socketChannel(channel),
interested: [.reset, .write],
registrationID: .initialRegistrationID
)
)
try self.assertWaitingForNotification(result: writeEvent)
try self.assertGetOption(expectedLevel: .socket, expectedOption: .so_error, value: CInt(0))
try self.assertRemoteAddress(address: serverAddress)

try self.assertReregister { selectable, event in
XCTAssertEqual([.reset, .readEOF, .write], event)
return true
}
try self.assertWritev(
expectedFD: .max,
expectedBytes: [buffer, buffer],
return: .wouldBlock(0)
)
try self.assertWritev(
expectedFD: .max,
expectedBytes: [buffer, buffer],
return: .wouldBlock(0)
)

let canWriteEvent = SelectorEvent(
io: [.write],
registration: NIORegistration(
channel: .socketChannel(channel),
interested: [.reset, .readEOF, .write],
registrationID: .initialRegistrationID
)
)
try self.assertWaitingForNotification(result: canWriteEvent)
try self.assertWritev(
expectedFD: .max,
expectedBytes: [buffer, buffer],
return: .processed(buffer.readableBytes)
)

try self.assertDeregister { selectable in
true
}
try self.assertClose(expectedFD: .max)
}) {
ClientBootstrap(group: channel.eventLoop)
.channelOption(.autoRead, value: false)
.channelOption(.writeSpin, value: 0)
.channelOption(
.writeBufferWaterMark,
value: .init(low: buffer.readableBytes + 1, high: buffer.readableBytes + 1)
)
.channelInitializer { channel in
try! channel.pipeline.syncOperations.addHandler(TestHandler(buffer))
return channel.eventLoop.makeSucceededVoidFuture()
}
.testOnly_connect(injectedChannel: channel, to: serverAddress)
.flatMap {
$0.closeFuture
}
}.salWait()
)
}
}

0 comments on commit 7bb2d55

Please sign in to comment.