diff --git a/Sources/NIOCore/AddressedEnvelope.swift b/Sources/NIOCore/AddressedEnvelope.swift index d705009a0e..857e186f14 100644 --- a/Sources/NIOCore/AddressedEnvelope.swift +++ b/Sources/NIOCore/AddressedEnvelope.swift @@ -27,22 +27,40 @@ public struct AddressedEnvelope { self.remoteAddress = remoteAddress self.data = data } - + public init(remoteAddress: SocketAddress, data: DataType, metadata: Metadata?) { self.remoteAddress = remoteAddress self.data = data self.metadata = metadata } - + /// Any metadata associated with an `AddressedEnvelope` public struct Metadata: Hashable, Sendable { /// Details of any congestion state. public var ecnState: NIOExplicitCongestionNotificationState public var packetInfo: NIOPacketInfo? - + + /// The size of data segments. + /// + /// For outbound messages setting this option informs the kernel to split the data from the + /// addressed envelope into segments of this size. Note that not all platforms support + /// this option and support should be checked with ``System/supportsUDPSegmentationOffload``. + /// + /// For inbound messages this value may be set with the segment size set by the sender if + /// the ``ChannelOptions/Types/DatagramReceiveOffload`` option is set. Support for that + /// option should be checked with ``System/supportsUDPReceiveOffload``. + public var segmentSize: Int? + + public init() { + self.ecnState = .transportNotCapable + self.packetInfo = nil + self.segmentSize = nil + } + public init(ecnState: NIOExplicitCongestionNotificationState) { self.ecnState = ecnState self.packetInfo = nil + self.segmentSize = nil } public init(ecnState: NIOExplicitCongestionNotificationState, packetInfo: NIOPacketInfo?) { diff --git a/Sources/NIOPosix/ControlMessage.swift b/Sources/NIOPosix/ControlMessage.swift index 4c87483a48..aefd5505c9 100644 --- a/Sources/NIOPosix/ControlMessage.swift +++ b/Sources/NIOPosix/ControlMessage.swift @@ -89,10 +89,10 @@ struct UnsafeControlMessageCollection { // Add the `Collection` functionality to UnsafeControlMessageCollection. extension UnsafeControlMessageCollection: Collection { typealias Element = UnsafeControlMessage - + struct Index: Equatable, Comparable { fileprivate var cmsgPointer: UnsafeMutablePointer? - + static func < (lhs: UnsafeControlMessageCollection.Index, rhs: UnsafeControlMessageCollection.Index) -> Bool { // nil is high, as that's the end of the collection. @@ -105,12 +105,12 @@ extension UnsafeControlMessageCollection: Collection { return false } } - + fileprivate init(cmsgPointer: UnsafeMutablePointer?) { self.cmsgPointer = cmsgPointer } } - + var startIndex: Index { var messageHeader = self.messageHeader return withUnsafePointer(to: &messageHeader) { messageHeaderPtr in @@ -118,9 +118,9 @@ extension UnsafeControlMessageCollection: Collection { return Index(cmsgPointer: firstCMsg) } } - + var endIndex: Index { return Index(cmsgPointer: nil) } - + func index(after: Index) -> Index { var msgHdr = messageHeader return withUnsafeMutablePointer(to: &msgHdr) { messageHeaderPtr in @@ -128,7 +128,7 @@ extension UnsafeControlMessageCollection: Collection { after: after.cmsgPointer!)) } } - + public subscript(position: Index) -> Element { let cmsg = position.cmsgPointer! return UnsafeControlMessage(level: cmsg.pointee.cmsg_level, @@ -152,13 +152,14 @@ struct UnsafeReceivedControlBytes { struct ControlMessageParser { var ecnValue: NIOExplicitCongestionNotificationState = .transportNotCapable // Default var packetInfo: NIOPacketInfo? = nil + var segmentSize: Int? = nil init(parsing controlMessagesReceived: UnsafeControlMessageCollection) { for controlMessage in controlMessagesReceived { self.receiveMessage(controlMessage) } } - + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) private static let ipv4TosType = IP_RECVTOS #else @@ -174,12 +175,14 @@ struct ControlMessageParser { } return readValue } - + private mutating func receiveMessage(_ controlMessage: UnsafeControlMessage) { if controlMessage.level == IPPROTO_IP { self.receiveIPv4Message(controlMessage) } else if controlMessage.level == IPPROTO_IPV6 { self.receiveIPv6Message(controlMessage) + } else if controlMessage.level == IPPROTO_UDP { + self.receiveUDPMessage(controlMessage) } } @@ -225,6 +228,17 @@ struct ControlMessageParser { } } } + + private mutating func receiveUDPMessage(_ controlMessage: UnsafeControlMessage) { + #if os(Linux) + if controlMessage.type == NIOBSDSocket.Option.udp_gro.rawValue { + if let data = controlMessage.data { + let readValue = ControlMessageParser._readCInt(data: data) + self.segmentSize = Int(readValue) + } + } + #endif + } } extension NIOExplicitCongestionNotificationState { @@ -262,7 +276,7 @@ extension CInt { struct UnsafeOutboundControlBytes { private var controlBytes: UnsafeMutableRawBufferPointer private var writePosition: UnsafeMutableRawBufferPointer.Index - + /// This structure must not outlive `controlBytes` init(controlBytes: UnsafeMutableRawBufferPointer) { self.controlBytes = controlBytes @@ -279,24 +293,24 @@ struct UnsafeOutboundControlBytes { type: CInt, payload: PayloadType) { let writableBuffer = UnsafeMutableRawBufferPointer(fastRebase: self.controlBytes[writePosition...]) - + let requiredSize = NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout.stride(ofValue: payload)) precondition(writableBuffer.count >= requiredSize, "Insufficient size for cmsghdr and data") - + let bufferBase = writableBuffer.baseAddress! // Binding to cmsghdr is safe here as this is the only place where we bind to non-Raw. let cmsghdrPtr = bufferBase.bindMemory(to: cmsghdr.self, capacity: 1) cmsghdrPtr.pointee.cmsg_level = level cmsghdrPtr.pointee.cmsg_type = type cmsghdrPtr.pointee.cmsg_len = .init(NIOBSDSocketControlMessage.length(payloadSize: MemoryLayout.size(ofValue: payload))) - + let dataPointer = NIOBSDSocketControlMessage.data(for: cmsghdrPtr)! precondition(dataPointer.count >= MemoryLayout.stride) dataPointer.storeBytes(of: payload, as: PayloadType.self) - + self.writePosition += requiredSize } - + /// The result is only valid while this is valid. var validControlBytes: UnsafeMutableRawBufferPointer { if writePosition == 0 { @@ -304,7 +318,7 @@ struct UnsafeOutboundControlBytes { } return UnsafeMutableRawBufferPointer(fastRebase: self.controlBytes[0 ..< self.writePosition]) } - + } extension UnsafeOutboundControlBytes { @@ -330,12 +344,24 @@ extension UnsafeOutboundControlBytes { break } } + + internal mutating func appendSegmentSize(metadata: AddressedEnvelope.Metadata?) { + #if os(Linux) + guard let segmentSize = metadata?.segmentSize else { return } + self.appendGenericControlMessage(level: .init(IPPROTO_UDP), + type: NIOBSDSocket.Option.udp_segment.rawValue, + payload: UInt16(segmentSize)) + #endif + } } extension AddressedEnvelope.Metadata { /// It's assumed the caller has checked that congestion information is required before calling. internal init(from controlMessagesReceived: UnsafeControlMessageCollection) { let controlMessageReceiver = ControlMessageParser(parsing: controlMessagesReceived) - self.init(ecnState: controlMessageReceiver.ecnValue, packetInfo: controlMessageReceiver.packetInfo) + self.init() + self.ecnState = controlMessageReceiver.ecnValue + self.packetInfo = controlMessageReceiver.packetInfo + self.segmentSize = controlMessageReceiver.segmentSize } } diff --git a/Sources/NIOPosix/DatagramVectorReadManager.swift b/Sources/NIOPosix/DatagramVectorReadManager.swift index 014dc6c6fa..0a0cac8108 100644 --- a/Sources/NIOPosix/DatagramVectorReadManager.swift +++ b/Sources/NIOPosix/DatagramVectorReadManager.swift @@ -102,7 +102,7 @@ struct DatagramVectorReadManager { // First we set up the iovec and save it off. self.ioVector[i] = IOVector(iov_base: bufferPointer.baseAddress! + (i * messageSize), iov_len: numericCast(messageSize)) - + let controlBytes: UnsafeMutableRawBufferPointer if parseControlMessages { // This will be used in buildMessages below but should not be used beyond return of this function. @@ -178,12 +178,11 @@ struct DatagramVectorReadManager { precondition(self.messageVector[i].msg_hdr.msg_namelen != 0, "Unexpected zero length peer name") #endif let address: SocketAddress = self.sockaddrVector[i].convert() - + // Extract congestion information if requested. let metadata: AddressedEnvelope.Metadata? if parseControlMessages { - let controlMessagesReceived = - UnsafeControlMessageCollection(messageHeader: self.messageVector[i].msg_hdr) + let controlMessagesReceived = UnsafeControlMessageCollection(messageHeader: self.messageVector[i].msg_hdr) metadata = .init(from: controlMessagesReceived) } else { metadata = nil diff --git a/Sources/NIOPosix/PendingDatagramWritesManager.swift b/Sources/NIOPosix/PendingDatagramWritesManager.swift index 315ec03327..de14bfa935 100644 --- a/Sources/NIOPosix/PendingDatagramWritesManager.swift +++ b/Sources/NIOPosix/PendingDatagramWritesManager.swift @@ -133,6 +133,7 @@ private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWrite var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[c]) controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: protocolFamily) + controlBytes.appendSegmentSize(metadata: p.metadata) let controlMessageBytePointer = controlBytes.validControlBytes let msg = msghdr(msg_name: address, diff --git a/Sources/NIOPosix/SocketChannel.swift b/Sources/NIOPosix/SocketChannel.swift index 6a0f3201c5..0796a82378 100644 --- a/Sources/NIOPosix/SocketChannel.swift +++ b/Sources/NIOPosix/SocketChannel.swift @@ -370,6 +370,11 @@ final class ServerSocketChannel: BaseSocketChannel { final class DatagramChannel: BaseSocketChannel { private var reportExplicitCongestionNotifications = false private var receivePacketInfo = false + private var receiveSegmentSize = false + + private var parseControlMessages: Bool { + return self.reportExplicitCongestionNotifications || self.receivePacketInfo || self.receiveSegmentSize + } // Guard against re-entrance of flushNow() method. private let pendingWrites: PendingDatagramWritesManager @@ -517,6 +522,7 @@ final class DatagramChannel: BaseSocketChannel { throw ChannelError.operationUnsupported } let enable = value as! ChannelOptions.Types.DatagramReceiveOffload.Value + self.receiveSegmentSize = enable try self.socket.setUDPReceiveOffload(enable) default: try super.setOption0(option, value: value) @@ -619,7 +625,7 @@ final class DatagramChannel: BaseSocketChannel { // These control bytes must not escape the current call stack let controlBytesBuffer: UnsafeMutableRawBufferPointer - if self.reportExplicitCongestionNotifications || self.receivePacketInfo { + if self.parseControlMessages { controlBytesBuffer = self.selectableEventLoop.controlMessageStorage[0] } else { controlBytesBuffer = UnsafeMutableRawBufferPointer(start: nil, count: 0) @@ -648,8 +654,7 @@ final class DatagramChannel: BaseSocketChannel { readPending = false let metadata: AddressedEnvelope.Metadata? - if self.reportExplicitCongestionNotifications || self.receivePacketInfo, - let controlMessagesReceived = controlBytes.receivedControlMessages { + if self.parseControlMessages, let controlMessagesReceived = controlBytes.receivedControlMessages { metadata = .init(from: controlMessagesReceived) } else { metadata = nil @@ -688,7 +693,7 @@ final class DatagramChannel: BaseSocketChannel { try vectorReadManager.readFromSocket( socket: self.socket, buffer: &buffer, - parseControlMessages: self.reportExplicitCongestionNotifications || self.receivePacketInfo) + parseControlMessages: self.parseControlMessages) } switch result { @@ -810,6 +815,7 @@ final class DatagramChannel: BaseSocketChannel { controlBytes: self.selectableEventLoop.controlMessageStorage[0]) controlBytes.appendExplicitCongestionState(metadata: metadata, protocolFamily: self.localAddress?.protocol) + controlBytes.appendSegmentSize(metadata: metadata) return try self.socket.sendmsg(pointer: ptr, destinationPtr: destinationPtr, destinationSize: destinationSize, diff --git a/Tests/NIOPosixTests/DatagramChannelTests+XCTest.swift b/Tests/NIOPosixTests/DatagramChannelTests+XCTest.swift index 21a00c2aaa..2464448f99 100644 --- a/Tests/NIOPosixTests/DatagramChannelTests+XCTest.swift +++ b/Tests/NIOPosixTests/DatagramChannelTests+XCTest.swift @@ -89,6 +89,9 @@ extension DatagramChannelTests { ("testChannelCanReceiveLargeBufferWithGROUsingVectorReads", testChannelCanReceiveLargeBufferWithGROUsingVectorReads), ("testChannelCanReceiveMultipleLargeBuffersWithGROUsingScalarReads", testChannelCanReceiveMultipleLargeBuffersWithGROUsingScalarReads), ("testChannelCanReceiveMultipleLargeBuffersWithGROUsingVectorReads", testChannelCanReceiveMultipleLargeBuffersWithGROUsingVectorReads), + ("testSegmentSizeSetViaMetadataOnUnsupportedPlatform", testSegmentSizeSetViaMetadataOnUnsupportedPlatform), + ("testSegmentSizeSetViaMetadata", testSegmentSizeSetViaMetadata), + ("testSegmentSizeFromMetadataTakesPrecedence", testSegmentSizeFromMetadataTakesPrecedence), ] } } diff --git a/Tests/NIOPosixTests/DatagramChannelTests.swift b/Tests/NIOPosixTests/DatagramChannelTests.swift index 110034164e..df4b1bbecd 100644 --- a/Tests/NIOPosixTests/DatagramChannelTests.swift +++ b/Tests/NIOPosixTests/DatagramChannelTests.swift @@ -1380,6 +1380,7 @@ class DatagramChannelTests: XCTestCase { let datagrams = try self.secondChannel.waitForDatagrams(count: writes) for datagram in datagrams { XCTAssertEqual(datagram.data.readableBytes, segments * segmentSize) + XCTAssertEqual(datagram.metadata?.segmentSize, segmentSize) } } @@ -1394,6 +1395,7 @@ class DatagramChannelTests: XCTestCase { let datagrams = try self.thirdChannel.waitForDatagrams(count: writes * segments) for datagram in datagrams { XCTAssertEqual(datagram.data.readableBytes, segmentSize) + XCTAssertNil(datagram.metadata?.segmentSize) } } } @@ -1414,6 +1416,90 @@ class DatagramChannelTests: XCTestCase { try self.testReceiveLargeBufferWithGRO(segments: 10, segmentSize: 1000, writes: 4, vectorReads: 4) } + func testSegmentSizeSetViaMetadataOnUnsupportedPlatform() throws { + try XCTSkipIf(System.supportsUDPSegmentationOffload, "UDP_SEGMENT (GSO) is supported on this platform") + + let buffer = ByteBuffer(repeating: 1, count: 10_000) + var metadata = AddressedEnvelope.Metadata() + metadata.segmentSize = 1000 + let writeData = AddressedEnvelope(remoteAddress: self.secondChannel.localAddress!, data: buffer, metadata: metadata) + + XCTAssertThrowsError(try self.firstChannel.writeAndFlush(NIOAny(writeData)).wait()) + } + + func testSegmentSizeSetViaMetadata() throws { + try XCTSkipUnless(System.supportsUDPSegmentationOffload, "UDP_SEGMENT (GSO) is not supported on this platform") + try XCTSkipUnless(System.supportsUDPReceiveOffload, "UDP_GRO is not supported on this platform") + try XCTSkipUnless(try self.hasGoodGROSupport()) + + // GSO can be enabled ad-hoc by setting the segment size in the metadata. This will also + // populate the metadata on the receive side. + + // Set GRO and a larger receive allocator on the receiver. + XCTAssertNoThrow(try self.secondChannel.setOption(ChannelOptions.datagramReceiveOffload, value: true).wait()) + let fixed = FixedSizeRecvByteBufferAllocator(capacity: 1 << 16) + XCTAssertNoThrow(try self.secondChannel.setOption(ChannelOptions.recvAllocator, value: fixed).wait()) + + let buffer = ByteBuffer(repeating: 1, count: 10_000) + var writeData = AddressedEnvelope(remoteAddress: self.secondChannel.localAddress!, data: buffer) + var metadata = AddressedEnvelope.Metadata() + + let segmentSizes = [1000, 500, 250] + + for segmentSize in segmentSizes { + metadata.segmentSize = segmentSize + writeData.metadata = metadata + XCTAssertNoThrow(try self.firstChannel.writeAndFlush(NIOAny(writeData)).wait()) + } + + let datagrams = try self.secondChannel.waitForDatagrams(count: segmentSizes.count) + for (datagram, segmentSize) in zip(datagrams, segmentSizes) { + XCTAssertEqual(datagram.data.readableBytes, 10_000) + XCTAssertEqual(datagram.metadata?.segmentSize, segmentSize) + } + } + + func testSegmentSizeFromMetadataTakesPrecedence() throws { + try XCTSkipUnless(System.supportsUDPSegmentationOffload, "UDP_SEGMENT (GSO) is not supported on this platform") + try XCTSkipUnless(System.supportsUDPReceiveOffload, "UDP_GRO is not supported on this platform") + try XCTSkipUnless(try self.hasGoodGROSupport()) + + // Set GSO on the socket. + let socketOptionSegmentSize = 500 + XCTAssertNoThrow(try self.firstChannel.setOption(ChannelOptions.datagramSegmentSize, value: CInt(socketOptionSegmentSize)).wait()) + // Set GRO and a larger receive allocator on the receiver. + XCTAssertNoThrow(try self.secondChannel.setOption(ChannelOptions.datagramReceiveOffload, value: true).wait()) + let fixed = FixedSizeRecvByteBufferAllocator(capacity: 1 << 16) + XCTAssertNoThrow(try self.secondChannel.setOption(ChannelOptions.recvAllocator, value: fixed).wait()) + + let buffer = ByteBuffer(repeating: 1, count: 10_000) + var writeData = AddressedEnvelope(remoteAddress: self.secondChannel.localAddress!, data: buffer) + var metadata = AddressedEnvelope.Metadata() + + // nil means defer to the size set on the socket. + let segmentSizes = [1000, nil, 250, nil] + + for segmentSize in segmentSizes { + if let segmentSize = segmentSize { + metadata.segmentSize = segmentSize + writeData.metadata = metadata + } else { + writeData.metadata = nil + } + XCTAssertNoThrow(try self.firstChannel.writeAndFlush(NIOAny(writeData)).wait()) + } + + let datagrams = try self.secondChannel.waitForDatagrams(count: segmentSizes.count) + for (datagram, segmentSize) in zip(datagrams, segmentSizes) { + XCTAssertEqual(datagram.data.readableBytes, 10_000) + if let segmentSize = segmentSize { + XCTAssertEqual(datagram.metadata?.segmentSize, segmentSize) + } else { + XCTAssertEqual(datagram.metadata?.segmentSize, socketOptionSegmentSize) + } + } + } + private func hasGoodGROSupport() throws -> Bool { // Source code for UDP_GRO was added in Linux 5.0. However, this support is somewhat limited // and some sources indicate support was actually added in 5.10 (perhaps more widely