diff --git a/Sources/NIOCore/ChannelHandler.swift b/Sources/NIOCore/ChannelHandler.swift index 93f11f537b..c48b06faff 100644 --- a/Sources/NIOCore/ChannelHandler.swift +++ b/Sources/NIOCore/ChannelHandler.swift @@ -343,3 +343,19 @@ extension RemovableChannelHandler { context.leavePipeline(removalToken: removalToken) } } + +/// A `NIOOutboundByteBufferingChannelHandler` is a `ChannelHandler` that +/// reports the number of bytes buffered for outbound direction. +public protocol NIOOutboundByteBufferingChannelHandler { + /// The number of bytes buffered in the channel handler, which are queued to be sent to + /// the next outbound channel handler. + var outboundBufferedBytes: Int { get } +} + +/// A `NIOInboundByteBufferingChannelHandler` is a `ChannelHandler` that +/// reports the number of bytes buffered for inbound direction. +public protocol NIOInboundByteBufferingChannelHandler { + /// The number of bytes buffered in the channel handler, which are queued to be sent to + /// the next inbound channel handler. + var inboundBufferedBytes: Int { get } +} diff --git a/Sources/NIOCore/ChannelPipeline.swift b/Sources/NIOCore/ChannelPipeline.swift index b2c90c142d..5bad399396 100644 --- a/Sources/NIOCore/ChannelPipeline.swift +++ b/Sources/NIOCore/ChannelPipeline.swift @@ -2089,3 +2089,126 @@ extension ChannelPipeline: CustomDebugStringConvertible { return handlers } } + +extension ChannelPipeline { + private enum BufferingDirection: Equatable { + case inbound + case outbound + } + + /// Retrieve the total number of bytes buffered for outbound. + public func outboundBufferedBytes() -> EventLoopFuture { + let future: EventLoopFuture + + if self.eventLoop.inEventLoop { + future = self.eventLoop.makeSucceededFuture(countAllBufferedBytes(direction: .outbound)) + } else { + future = self.eventLoop.submit { + self.countAllBufferedBytes(direction: .outbound) + } + } + + return future + } + + /// Retrieve the total number of bytes buffered for inbound. + public func inboundBufferedBytes() -> EventLoopFuture { + let future: EventLoopFuture + + if self.eventLoop.inEventLoop { + future = self.eventLoop.makeSucceededFuture(countAllBufferedBytes(direction: .inbound)) + } else { + future = self.eventLoop.submit { + self.countAllBufferedBytes(direction: .inbound) + } + } + + return future + } + + private static func countBufferedBytes(context: ChannelHandlerContext, direction: BufferingDirection) -> Int? { + switch direction { + case .inbound: + guard let handler = context.handler as? NIOInboundByteBufferingChannelHandler else { + return nil + } + return handler.inboundBufferedBytes + case .outbound: + guard let handler = context.handler as? NIOOutboundByteBufferingChannelHandler else { + return nil + } + return handler.outboundBufferedBytes + } + + } + + private func countAllBufferedBytes(direction: BufferingDirection) -> Int { + self.eventLoop.assertInEventLoop() + var total = 0 + var current = self.head?.next + switch direction { + case .inbound: + while let c = current, c !== self.tail { + if let inboundHandler = c.handler as? NIOInboundByteBufferingChannelHandler { + total += inboundHandler.inboundBufferedBytes + } + current = current?.next + } + case .outbound: + while let c = current, c !== self.tail { + if let outboundHandler = c.handler as? NIOOutboundByteBufferingChannelHandler { + total += outboundHandler.outboundBufferedBytes + } + current = current?.next + } + } + + return total + } +} + +extension ChannelPipeline.SynchronousOperations { + /// Retrieve the total number of bytes buffered for outbound. + /// + /// - Important: This *must* be called on the event loop. + public func outboundBufferedBytes() -> Int { + self.eventLoop.assertInEventLoop() + return self._pipeline.countAllBufferedBytes(direction: .outbound) + } + + /// Retrieve the number of outbound bytes buffered in the `ChannelHandler` associated with the given`ChannelHandlerContext`. + /// + /// - Parameters: + /// - in: the `ChannelHandlerContext` from which the outbound buffered bytes of the `ChannelHandler` will be retrieved. + /// - Important: This *must* be called on the event loop. + /// + /// - Returns: The number of bytes currently buffered in the `ChannelHandler` referenced by the `ChannelHandlerContext` parameter `in`. + /// If the `ChannelHandler` in the given `ChannelHandlerContext` does not conform to + /// `NIOOutboundByteBufferingChannelHandler`, this method will return `nil`. + public func outboundBufferedBytes(in context: ChannelHandlerContext) -> Int? { + self.eventLoop.assertInEventLoop() + return ChannelPipeline.countBufferedBytes(context: context, direction: .outbound) + } + + /// Retrieve total number of bytes buffered for inbound. + /// + /// - Important: This *must* be called on the event loop. + public func inboundBufferedBytes() -> Int { + self.eventLoop.assertInEventLoop() + return self._pipeline.countAllBufferedBytes(direction: .inbound) + } + + /// Retrieve the number of inbound bytes buffered in the `ChannelHandler` associated with the given `ChannelHandlerContext`. + /// + /// - Parameters: + /// - in: the `ChannelHandlerContext` from which the inbound buffered bytes of the `handler` will be retrieved. + /// - Important: This *must* be called on the event loop. + /// + /// - Returns: The number of bytes currently buffered in the `ChannelHandler` referenced by the `ChannelHandlerContext` parameter `in`. + /// If the `ChannelHandler` in the given `ChannelHandlerContext` does not conform to + /// `NIOInboundByteBufferingChannelHandler`, this method will return `nil`. + public func inboundBufferedBytes(in context: ChannelHandlerContext) -> Int? { + self.eventLoop.assertInEventLoop() + return ChannelPipeline.countBufferedBytes(context: context, direction: .inbound) + } +} diff --git a/Tests/NIOPosixTests/ChannelPipelineTest.swift b/Tests/NIOPosixTests/ChannelPipelineTest.swift index b84baf2e07..328d390ae7 100644 --- a/Tests/NIOPosixTests/ChannelPipelineTest.swift +++ b/Tests/NIOPosixTests/ChannelPipelineTest.swift @@ -1600,6 +1600,1085 @@ class ChannelPipelineTest: XCTestCase { XCTAssertEqual(eventCounter.userInboundEventTriggeredCalls, 1) XCTAssertEqual(eventCounter.writeCalls, 2) // write, and writeAndFlush } + + func testRetrieveInboundBufferedBytesFromChannelWithZeroHandler() throws { + let channel = EmbeddedChannel() + + let data = ByteBuffer(string: "1234") + for _ in 1...5 { + try channel.writeInbound(data) + let bufferedBytes = try channel.pipeline.inboundBufferedBytes().wait() + XCTAssertEqual(bufferedBytes, 0) + } + + for _ in 1...5 { + _ = try channel.readInbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testRetrieveOutboundBufferedBytesFromChannelWithZeroHandler() throws { + let channel = EmbeddedChannel() + + let data = ByteBuffer(string: "1234") + for _ in 1...5 { + try channel.writeOutbound(data) + let bufferedBytes = try channel.pipeline.outboundBufferedBytes().wait() + XCTAssertEqual(bufferedBytes, 0) + } + + for _ in 1...5 { + _ = try channel.readOutbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testRetrieveInboundBufferedBytesFromChannelWithOneHandler() throws { + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + private var buffer = ByteBuffer() + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + buffer.writeImmutableBuffer(self.unwrapInboundIn(data)) + } + + var inboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers([InboundBufferHandler()]) + + let data = ByteBuffer(string: "1234") + for cnt in 1...5 { + try channel.writeInbound(data) + let bufferedBytes = try channel.pipeline.inboundBufferedBytes().wait() + XCTAssertEqual(bufferedBytes, cnt * data.readableBytes) + } + + for _ in 1...5 { + _ = try channel.readInbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testRetrieveOutboundBufferedBytesFromChannelWithOneHandler() throws { + class OutboundBufferHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + private var buffer = ByteBuffer() + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + buffer.writeImmutableBuffer(self.unwrapOutboundIn(data)) + promise?.succeed() + } + + var outboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers([OutboundBufferHandler()]) + + let data = ByteBuffer(string: "1234") + for cnt in 1...5 { + try channel.writeOutbound(data) + let bufferedBytes = try channel.pipeline.outboundBufferedBytes().wait() + XCTAssertEqual(bufferedBytes, cnt * data.readableBytes) + } + + for _ in 1...5 { + _ = try channel.readOutbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testRetrieveInboundBufferedBytesFromChannelWithEmptyBuffer() throws { + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + context.fireChannelRead(data) + } + + var inboundBufferedBytes: Int { 0 } + } + + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers([InboundBufferHandler(), InboundBufferHandler()]) + + let data = ByteBuffer(string: "1234") + for _ in 1...5 { + try channel.writeInbound(data) + let bufferedBytes = try channel.pipeline.inboundBufferedBytes().wait() + XCTAssertEqual(bufferedBytes, 0) + } + + for _ in 1...5 { + _ = try channel.readInbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testRetrieveOutboundBufferedBytesFromChannelWithEmptyBuffer() throws { + class OutboundBufferHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + context.write(data, promise: promise) + } + + var outboundBufferedBytes: Int { 0 } + } + + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers([OutboundBufferHandler(), OutboundBufferHandler()]) + + let data = ByteBuffer(string: "1234") + for _ in 1...5 { + try channel.writeOutbound(data) + let bufferedBytes = try channel.pipeline.outboundBufferedBytes().wait() + XCTAssertEqual(bufferedBytes, 0) + } + + for _ in 1...5 { + _ = try channel.readOutbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testRetrieveInboundBufferedBytesFromChannelWithMultipleHandlers() throws { + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + private var buffer = ByteBuffer() + private let expectedBufferCount: Int + + init(expectedBufferCount: Int) { + self.expectedBufferCount = expectedBufferCount + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buf = self.unwrapInboundIn(data) + let readSize = min(expectedBufferCount, buf.readableBytes) + if let b = buf.readSlice(length: readSize) { + buffer.writeImmutableBuffer(b) + } + context.fireChannelRead(self.wrapInboundOut(buf)) + } + + var inboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let handlers = (0..<5).map { InboundBufferHandler(expectedBufferCount: $0) } + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers(handlers) + + let data = ByteBuffer(string: "1234") + try channel.writeInbound(data) + let bufferedBytes = try channel.pipeline.inboundBufferedBytes().wait() + XCTAssertEqual(bufferedBytes, data.readableBytes) + + _ = try channel.readInbound(as: ByteBuffer.self) + XCTAssertTrue(try channel.finish().isClean) + } + + func testRetrieveOutboundBufferedBytesFromChannelWithMultipleHandlers() throws { + class OutboundBufferHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler { + + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + private var buffer = ByteBuffer() + private let expectedBufferCount: Int + + init(expectedBufferCount: Int) { + self.expectedBufferCount = expectedBufferCount + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + var buf = self.unwrapOutboundIn(data) + let readSize = min(expectedBufferCount, buf.readableBytes) + if let b = buf.readSlice(length: readSize) { + buffer.writeImmutableBuffer(b) + } + + context.write(self.wrapOutboundOut(buf), promise: promise) + } + + var outboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let handlers = (0..<5).map { OutboundBufferHandler(expectedBufferCount: $0) } + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers(handlers) + + let data = ByteBuffer(string: "1234") + try channel.writeOutbound(data) + let bufferedBytes = try channel.pipeline.outboundBufferedBytes().wait() + XCTAssertEqual(bufferedBytes, data.readableBytes) + + _ = try channel.readOutbound(as: ByteBuffer.self) + XCTAssertTrue(try channel.finish().isClean) + } + + func testRetrieveInboundBufferedBytesFromChannelWithHandlersRemoved() throws { + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler, + RemovableChannelHandler + { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + private var buffer = ByteBuffer() + let expectedBufferCount: Int + + init(expectedBufferCount: Int) { + self.expectedBufferCount = expectedBufferCount + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buf = self.unwrapInboundIn(data) + let readSize = min(expectedBufferCount, buf.readableBytes) + if let b = buf.readBytes(length: readSize) { + buffer.writeBytes(b) + context.fireChannelRead(self.wrapInboundOut(buf)) + } + } + + var inboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let handlers = (0..<5).map { InboundBufferHandler(expectedBufferCount: $0) } + + let channel = EmbeddedChannel() + for handler in handlers { + try channel.pipeline.syncOperations.addHandler(handler, position: .last) + } + + let data = ByteBuffer(string: "1234") + try channel.writeInbound(data) + var total = try channel.pipeline.inboundBufferedBytes().wait() + XCTAssertEqual(total, data.readableBytes) + let expectedBufferedBytes = handlers.map { $0.inboundBufferedBytes } + print(expectedBufferedBytes) + + for (expectedBufferedByte, handler) in zip(expectedBufferedBytes, handlers) { + let expectedRemaining = total - expectedBufferedByte + channel.pipeline.removeHandler(handler).flatMap { _ in + channel.pipeline.inboundBufferedBytes() + }.and(value: expectedRemaining).whenSuccess { (remaining, expectedRemaining) in + XCTAssertEqual(remaining, expectedRemaining) + } + total -= expectedBufferedByte + } + + _ = try channel.readInbound(as: ByteBuffer.self) + XCTAssertTrue(try channel.finish().isClean) + } + + func testRetrieveOutboundBufferedBytesFromChannelWithHandlersRemoved() throws { + class OutboundBufferHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler, + RemovableChannelHandler + { + + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + private var buffer = ByteBuffer() + let expectedBufferCount: Int + + init(expectedBufferCount: Int) { + self.expectedBufferCount = expectedBufferCount + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + var buf = self.unwrapOutboundIn(data) + let readSize = min(expectedBufferCount, buf.readableBytes) + if let b = buf.readBytes(length: readSize) { + buffer.writeBytes(b) + context.write(self.wrapOutboundOut(buf), promise: promise) + } + } + + var outboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let handlers = (0..<5).map { OutboundBufferHandler(expectedBufferCount: $0) } + + let channel = EmbeddedChannel() + for handler in handlers { + try channel.pipeline.syncOperations.addHandler(handler, position: .first) + } + + let data = ByteBuffer(string: "1234") + try channel.writeOutbound(data) + var total = try channel.pipeline.outboundBufferedBytes().wait() + XCTAssertEqual(total, data.readableBytes) + let expectedBufferedBytes = handlers.map { $0.outboundBufferedBytes } + + for (expectedBufferedByte, handler) in zip(expectedBufferedBytes, handlers) { + let expectedRemaining = total - expectedBufferedByte + channel.pipeline.removeHandler(handler).flatMap { _ in + channel.pipeline.outboundBufferedBytes() + }.and(value: expectedRemaining).whenSuccess { (remaining, expectedRemaining) in + XCTAssertEqual(remaining, expectedRemaining) + } + total -= expectedBufferedByte + } + + _ = try channel.readOutbound(as: ByteBuffer.self) + XCTAssertTrue(try channel.finish().isClean) + } + + func testRetrieveBufferedBytesFromChannelWithMixedHandlers() throws { + // A inbound channel handler that buffers incoming byte buffer when the total number of + // calls to the channelRead() is even. + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + var count: Int + var bb: ByteBuffer + + init() { + self.count = 0 + self.bb = ByteBuffer() + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var d = unwrapInboundIn(data) + self.bb.writeBuffer(&d) + + if count % 2 == 1 { + context.fireChannelRead(self.wrapInboundOut(self.bb)) + self.bb.moveReaderIndex(forwardBy: self.bb.readableBytes) + } + + count += 1 + } + + var inboundBufferedBytes: Int { + bb.readableBytes + } + } + + // A outbound channel handler that buffers incoming byte buffer when the total number of + // calls to the write() is odd. + class OutboundBufferedHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + var count: Int + var bb: ByteBuffer + + init() { + self.count = 0 + self.bb = ByteBuffer() + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + var d = unwrapOutboundIn(data) + self.bb.writeBuffer(&d) + if count % 2 == 0 { + promise?.succeed() + } else { + context.write(self.wrapOutboundOut(self.bb), promise: promise) + self.bb.moveWriterIndex(forwardBy: self.bb.writableBytes) + } + count += 1 + } + + var outboundBufferedBytes: Int { + bb.writableBytes + } + } + + let channel = EmbeddedChannel(handlers: [InboundBufferHandler(), OutboundBufferedHandler()]) + + let data = ByteBuffer(string: "123") + try channel.writeAndFlush(data).wait() + + channel.pipeline.outboundBufferedBytes().whenSuccess { result in + XCTAssertEqual(result, data.writableBytes) + } + _ = try channel.readOutbound(as: ByteBuffer.self) + + try channel.writeAndFlush(data).wait() + + channel.pipeline.outboundBufferedBytes().whenSuccess { result in + XCTAssertEqual(result, 0) + } + + _ = try channel.readOutbound(as: ByteBuffer.self) + + try channel.writeInbound(data) + + channel.pipeline.inboundBufferedBytes().whenSuccess { result in + XCTAssertEqual(result, data.readableBytes) + } + + _ = try channel.readInbound(as: ByteBuffer.self) + + try channel.writeInbound(data) + + channel.pipeline.inboundBufferedBytes().whenSuccess { result in + XCTAssertEqual(result, 0) + } + + _ = try channel.readInbound(as: ByteBuffer.self) + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveInboundBufferedBytesWhenChannelHandlerNotConformToProtocol() throws { + class InboundBufferHandler: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + context.fireChannelRead(data) + } + } + + let channel = EmbeddedChannel() + let inboundChannelHandlerName = "InboundBufferHandler" + try channel.pipeline.syncOperations.addHandler(InboundBufferHandler(), name: inboundChannelHandlerName) + let context = try channel.pipeline.syncOperations.context(name: inboundChannelHandlerName) + let bufferedBytes = channel.pipeline.syncOperations.inboundBufferedBytes(in: context) + + XCTAssertNil(bufferedBytes) + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveOutboundBufferedBytesWhenChannelHandlerNotConformToProtocol() throws { + class OutboundBufferHandler: ChannelOutboundHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + context.write(data, promise: promise) + } + } + + let channel = EmbeddedChannel() + let outboundChannelHandlerName = "outboundBufferHandler" + try channel.pipeline.syncOperations.addHandler(OutboundBufferHandler(), name: outboundChannelHandlerName) + let context = try channel.pipeline.syncOperations.context(name: outboundChannelHandlerName) + let bufferedBytes = channel.pipeline.syncOperations.outboundBufferedBytes(in: context) + + XCTAssertNil(bufferedBytes) + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveInboundBufferedBytesFromOneHandler() throws { + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + private var buffer = ByteBuffer() + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + buffer.writeImmutableBuffer(self.unwrapInboundIn(data)) + } + + var inboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let channel = EmbeddedChannel() + let inboundChannelHandlerName = "InboundBufferHandler" + try channel.pipeline.syncOperations.addHandler(InboundBufferHandler(), name: inboundChannelHandlerName) + + let data = ByteBuffer(string: "1234") + for cnt in 1...5 { + try channel.writeInbound(data) + let context = try channel.pipeline.syncOperations.context(name: inboundChannelHandlerName) + let bufferedBytes = channel.pipeline.syncOperations.inboundBufferedBytes(in: context) + XCTAssertNotNil(bufferedBytes) + XCTAssertEqual(bufferedBytes, data.readableBytes * cnt) + } + + for _ in 1...5 { + _ = try channel.readInbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveOutboundBufferedBytesFromOneHandler() throws { + class OutboundBufferHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + private var buffer = ByteBuffer() + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + buffer.writeImmutableBuffer(self.unwrapOutboundIn(data)) + promise?.succeed() + } + + var outboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let channel = EmbeddedChannel() + let outboundChannelHandlerName = "outboundBufferHandler" + try channel.pipeline.syncOperations.addHandler(OutboundBufferHandler(), name: outboundChannelHandlerName) + + let data = ByteBuffer(string: "1234") + for cnt in 1...5 { + try channel.writeOutbound(data) + let context = try channel.pipeline.syncOperations.context(name: outboundChannelHandlerName) + let bufferedBytes = channel.pipeline.syncOperations.outboundBufferedBytes(in: context) + + XCTAssertNotNil(bufferedBytes) + XCTAssertEqual(bufferedBytes, data.readableBytes * cnt) + } + + for _ in 1...5 { + _ = try channel.readOutbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveEmptyInboundBufferedBytes() throws { + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + context.fireChannelRead(data) + } + + var inboundBufferedBytes: Int { 0 } + } + + let channel = EmbeddedChannel() + let inboundChannelHandlerName = "InboundBufferHandler" + try channel.pipeline.syncOperations.addHandler(InboundBufferHandler(), name: inboundChannelHandlerName) + + let data = ByteBuffer(string: "1234") + for _ in 1...5 { + try channel.writeInbound(data) + let context = try channel.pipeline.syncOperations.context(name: inboundChannelHandlerName) + let bufferedBytes = channel.pipeline.syncOperations.inboundBufferedBytes(in: context) + + XCTAssertNotNil(bufferedBytes) + XCTAssertEqual(bufferedBytes, 0) + } + + for _ in 1...5 { + _ = try channel.readInbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveEmptyOutboundBufferedBytes() throws { + class OutboundBufferHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + context.write(data, promise: promise) + } + + var outboundBufferedBytes: Int { 0 } + } + + let channel = EmbeddedChannel() + let outboundChannelHandlerName = "outboundBufferHandler" + try channel.pipeline.syncOperations.addHandler(OutboundBufferHandler(), name: outboundChannelHandlerName) + + let data = ByteBuffer(string: "1234") + for _ in 1...5 { + try channel.writeOutbound(data) + let context = try channel.pipeline.syncOperations.context(name: outboundChannelHandlerName) + let bufferedBytes = channel.pipeline.syncOperations.outboundBufferedBytes(in: context) + + XCTAssertNotNil(bufferedBytes) + XCTAssertEqual(bufferedBytes, 0) + } + + for _ in 1...5 { + _ = try channel.readOutbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveInboundBufferedBytesFromChannelWithZeroHandler() throws { + let channel = EmbeddedChannel() + + let data = ByteBuffer(string: "1234") + for _ in 1...5 { + try channel.writeInbound(data) + let bufferedBytes = channel.pipeline.syncOperations.inboundBufferedBytes() + XCTAssertEqual(bufferedBytes, 0) + } + + for _ in 1...5 { + _ = try channel.readInbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveOutboundBufferedBytesFromChannelWithZeroHandler() throws { + let channel = EmbeddedChannel() + + let data = ByteBuffer(string: "1234") + for _ in 1...5 { + try channel.writeOutbound(data) + let bufferedBytes = channel.pipeline.syncOperations.outboundBufferedBytes() + XCTAssertEqual(bufferedBytes, 0) + } + + for _ in 1...5 { + _ = try channel.readOutbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveInboundBufferedBytesFromChannelWithOneHandler() throws { + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + private var buffer = ByteBuffer() + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + buffer.writeImmutableBuffer(self.unwrapInboundIn(data)) + } + + var inboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers([InboundBufferHandler()]) + + let data = ByteBuffer(string: "1234") + for cnt in 1...5 { + try channel.writeInbound(data) + let bufferedBytes = channel.pipeline.syncOperations.inboundBufferedBytes() + XCTAssertEqual(bufferedBytes, cnt * data.readableBytes) + } + + for _ in 1...5 { + _ = try channel.readInbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveOutboundBufferedBytesFromChannelWithOneHandler() throws { + class OutboundBufferHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + private var buffer = ByteBuffer() + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + buffer.writeImmutableBuffer(self.unwrapOutboundIn(data)) + promise?.succeed() + } + + var outboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers([OutboundBufferHandler()]) + + let data = ByteBuffer(string: "1234") + for cnt in 1...5 { + try channel.writeOutbound(data) + let bufferedBytes = channel.pipeline.syncOperations.outboundBufferedBytes() + XCTAssertEqual(bufferedBytes, cnt * data.readableBytes) + } + + for _ in 1...5 { + _ = try channel.readOutbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveInboundBufferedBytesFromChannelWithEmptyBuffer() throws { + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + context.fireChannelRead(data) + } + + var inboundBufferedBytes: Int { 0 } + } + + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers([InboundBufferHandler(), InboundBufferHandler()]) + + let data = ByteBuffer(string: "1234") + for _ in 1...5 { + try channel.writeInbound(data) + let bufferedBytes = channel.pipeline.syncOperations.inboundBufferedBytes() + XCTAssertEqual(bufferedBytes, 0) + } + + for _ in 1...5 { + _ = try channel.readInbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveOutboundBufferedBytesFromChannelWithEmptyBuffer() throws { + class OutboundBufferHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + context.write(data, promise: promise) + } + + var outboundBufferedBytes: Int { 0 } + } + + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers([OutboundBufferHandler(), OutboundBufferHandler()]) + + let data = ByteBuffer(string: "1234") + for _ in 1...5 { + try channel.writeOutbound(data) + let bufferedBytes = channel.pipeline.syncOperations.outboundBufferedBytes() + XCTAssertEqual(bufferedBytes, 0) + } + + for _ in 1...5 { + _ = try channel.readOutbound(as: ByteBuffer.self) + } + + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveInboundBufferedBytesFromChannelWithMultipleHandlers() throws { + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + private var buffer = ByteBuffer() + private let expectedBufferCount: Int + + init(expectedBufferCount: Int) { + self.expectedBufferCount = expectedBufferCount + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buf = self.unwrapInboundIn(data) + let readSize = min(expectedBufferCount, buf.readableBytes) + if let b = buf.readSlice(length: readSize) { + buffer.writeImmutableBuffer(b) + } + context.fireChannelRead(self.wrapInboundOut(buf)) + } + + var inboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let handlers = (0..<5).map { InboundBufferHandler(expectedBufferCount: $0) } + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers(handlers) + + let data = ByteBuffer(string: "1234") + try channel.writeInbound(data) + let bufferedBytes = channel.pipeline.syncOperations.inboundBufferedBytes() + XCTAssertEqual(bufferedBytes, data.readableBytes) + + _ = try channel.readInbound(as: ByteBuffer.self) + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveOutboundBufferedBytesFromChannelWithMultipleHandlers() throws { + class OutboundBufferHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler { + + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + private var buffer = ByteBuffer() + private let expectedBufferCount: Int + + init(expectedBufferCount: Int) { + self.expectedBufferCount = expectedBufferCount + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + var buf = self.unwrapOutboundIn(data) + let readSize = min(expectedBufferCount, buf.readableBytes) + if let b = buf.readSlice(length: readSize) { + buffer.writeImmutableBuffer(b) + } + + context.write(self.wrapOutboundOut(buf), promise: promise) + } + + var outboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let handlers = (0..<5).map { OutboundBufferHandler(expectedBufferCount: $0) } + let channel = EmbeddedChannel() + try channel.pipeline.syncOperations.addHandlers(handlers) + + let data = ByteBuffer(string: "1234") + try channel.writeOutbound(data) + let bufferedBytes = channel.pipeline.syncOperations.outboundBufferedBytes() + XCTAssertEqual(bufferedBytes, data.readableBytes) + + _ = try channel.readOutbound(as: ByteBuffer.self) + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveInboundBufferedBytesFromChannelWithHandlersRemoved() throws { + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler, + RemovableChannelHandler + { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + private var buffer = ByteBuffer() + let expectedBufferCount: Int + + init(expectedBufferCount: Int) { + self.expectedBufferCount = expectedBufferCount + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buf = self.unwrapInboundIn(data) + let readSize = min(expectedBufferCount, buf.readableBytes) + if let b = buf.readBytes(length: readSize) { + buffer.writeBytes(b) + context.fireChannelRead(self.wrapInboundOut(buf)) + } + } + + var inboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let handlers = (0..<5).map { InboundBufferHandler(expectedBufferCount: $0) } + + let channel = EmbeddedChannel() + for handler in handlers { + try channel.pipeline.syncOperations.addHandler(handler, position: .last) + } + + let data = ByteBuffer(string: "1234") + try channel.writeInbound(data) + var total = channel.pipeline.syncOperations.inboundBufferedBytes() + XCTAssertEqual(total, data.readableBytes) + let expectedBufferedBytes = handlers.map { $0.inboundBufferedBytes } + print(expectedBufferedBytes) + + for (expectedBufferedByte, handler) in zip(expectedBufferedBytes, handlers) { + let expectedRemaining = total - expectedBufferedByte + channel.pipeline.syncOperations + .removeHandler(handler) + .and(value: expectedRemaining) + .whenSuccess { (_, expectedRemaining) in + let remaining = channel.pipeline.syncOperations.inboundBufferedBytes() + XCTAssertEqual(remaining, expectedRemaining) + } + total -= expectedBufferedByte + } + + _ = try channel.readInbound(as: ByteBuffer.self) + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveOutboundBufferedBytesFromChannelWithHandlersRemoved() throws { + class OutboundBufferHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler, + RemovableChannelHandler + { + + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + private var buffer = ByteBuffer() + let expectedBufferCount: Int + + init(expectedBufferCount: Int) { + self.expectedBufferCount = expectedBufferCount + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + var buf = self.unwrapOutboundIn(data) + let readSize = min(expectedBufferCount, buf.readableBytes) + if let b = buf.readBytes(length: readSize) { + buffer.writeBytes(b) + context.write(self.wrapOutboundOut(buf), promise: promise) + } + } + + var outboundBufferedBytes: Int { + self.buffer.readableBytes + } + } + + let handlers = (0..<5).map { OutboundBufferHandler(expectedBufferCount: $0) } + + let channel = EmbeddedChannel() + for handler in handlers { + try channel.pipeline.syncOperations.addHandler(handler, position: .first) + } + + let data = ByteBuffer(string: "1234") + try channel.writeOutbound(data) + var total = channel.pipeline.syncOperations.outboundBufferedBytes() + XCTAssertEqual(total, data.readableBytes) + let expectedBufferedBytes = handlers.map { $0.outboundBufferedBytes } + + for (expectedBufferedByte, handler) in zip(expectedBufferedBytes, handlers) { + let expectedRemaining = total - expectedBufferedByte + channel.pipeline.syncOperations + .removeHandler(handler) + .and(value: expectedRemaining) + .whenSuccess { (_, expectedRemaining) in + let remaining = channel.pipeline.syncOperations.outboundBufferedBytes() + XCTAssertEqual(remaining, expectedRemaining) + } + total -= expectedBufferedByte + } + + _ = try channel.readOutbound(as: ByteBuffer.self) + XCTAssertTrue(try channel.finish().isClean) + } + + func testSynchronouslyRetrieveBufferedBytesFromChannelWithMixedHandlers() throws { + // A inbound channel handler that buffers incoming byte buffer when the total number of + // calls to the channelRead() is even. + class InboundBufferHandler: ChannelInboundHandler, NIOInboundByteBufferingChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + var count: Int + var bb: ByteBuffer + + init() { + self.count = 0 + self.bb = ByteBuffer() + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var d = unwrapInboundIn(data) + self.bb.writeBuffer(&d) + + if count % 2 == 1 { + context.fireChannelRead(self.wrapInboundOut(self.bb)) + self.bb.moveReaderIndex(forwardBy: self.bb.readableBytes) + } + + count += 1 + } + + var inboundBufferedBytes: Int { + bb.readableBytes + } + } + + // A outbound channel handler that buffers incoming byte buffer when the total number of + // calls to the write() is odd. + class OutboundBufferedHandler: ChannelOutboundHandler, NIOOutboundByteBufferingChannelHandler { + typealias OutboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + var count: Int + var bb: ByteBuffer + + init() { + self.count = 0 + self.bb = ByteBuffer() + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + var d = unwrapOutboundIn(data) + self.bb.writeBuffer(&d) + if count % 2 == 0 { + promise?.succeed() + } else { + context.write(self.wrapOutboundOut(self.bb), promise: promise) + self.bb.moveWriterIndex(forwardBy: self.bb.writableBytes) + } + count += 1 + } + + var outboundBufferedBytes: Int { + bb.writableBytes + } + } + + let channel = EmbeddedChannel(handlers: [InboundBufferHandler(), OutboundBufferedHandler()]) + + let data = ByteBuffer(string: "123") + try channel.writeAndFlush(data).wait() + + var result = channel.pipeline.syncOperations.outboundBufferedBytes() + XCTAssertEqual(result, data.writableBytes) + + _ = try channel.readOutbound(as: ByteBuffer.self) + + try channel.writeAndFlush(data).wait() + + result = channel.pipeline.syncOperations.outboundBufferedBytes() + XCTAssertEqual(result, 0) + + _ = try channel.readOutbound(as: ByteBuffer.self) + + try channel.writeInbound(data) + + result = channel.pipeline.syncOperations.inboundBufferedBytes() + XCTAssertEqual(result, data.readableBytes) + + _ = try channel.readInbound(as: ByteBuffer.self) + + try channel.writeInbound(data) + + result = channel.pipeline.syncOperations.inboundBufferedBytes() + XCTAssertEqual(result, 0) + + _ = try channel.readInbound(as: ByteBuffer.self) + + XCTAssertTrue(try channel.finish().isClean) + } } // this should be within `testAddMultipleHandlers` but https://bugs.swift.org/browse/SR-9956