Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add segment size to addressed envelope metadata #2390

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions Sources/NIOCore/AddressedEnvelope.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,40 @@ public struct AddressedEnvelope<DataType> {
self.remoteAddress = remoteAddress
self.data = data
}

public init(remoteAddress: SocketAddress, data: DataType, metadata: Metadata?) {
self.remoteAddress = remoteAddress
self.data = data
self.metadata = metadata
}

/// Any metadata associated with an `AddressedEnvelope`
public struct Metadata: Hashable, Sendable {
/// Details of any congestion state.
public var ecnState: NIOExplicitCongestionNotificationState
public var packetInfo: NIOPacketInfo?


/// The size of data segments.
///
/// For outbound messages setting this option informs the kernel to split the data from the
/// addressed envelope into segments of this size. Note that not all platforms support
/// this option and support should be checked with ``System/supportsUDPSegmentationOffload``.
///
/// For inbound messages this value may be set with the segment size set by the sender if
/// the ``ChannelOptions/Types/DatagramReceiveOffload`` option is set. Support for that
/// option should be checked with ``System/supportsUDPReceiveOffload``.
public var segmentSize: Int?

public init() {
self.ecnState = .transportNotCapable
self.packetInfo = nil
self.segmentSize = nil
}

public init(ecnState: NIOExplicitCongestionNotificationState) {
self.ecnState = ecnState
self.packetInfo = nil
self.segmentSize = nil
}

public init(ecnState: NIOExplicitCongestionNotificationState, packetInfo: NIOPacketInfo?) {
Expand Down
60 changes: 43 additions & 17 deletions Sources/NIOPosix/ControlMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ struct UnsafeControlMessageCollection {
// Add the `Collection` functionality to UnsafeControlMessageCollection.
extension UnsafeControlMessageCollection: Collection {
typealias Element = UnsafeControlMessage

struct Index: Equatable, Comparable {
fileprivate var cmsgPointer: UnsafeMutablePointer<cmsghdr>?

static func < (lhs: UnsafeControlMessageCollection.Index,
rhs: UnsafeControlMessageCollection.Index) -> Bool {
// nil is high, as that's the end of the collection.
Expand All @@ -105,30 +105,30 @@ extension UnsafeControlMessageCollection: Collection {
return false
}
}

fileprivate init(cmsgPointer: UnsafeMutablePointer<cmsghdr>?) {
self.cmsgPointer = cmsgPointer
}
}

var startIndex: Index {
var messageHeader = self.messageHeader
return withUnsafePointer(to: &messageHeader) { messageHeaderPtr in
let firstCMsg = NIOBSDSocketControlMessage.firstHeader(inside: messageHeaderPtr)
return Index(cmsgPointer: firstCMsg)
}
}

var endIndex: Index { return Index(cmsgPointer: nil) }

func index(after: Index) -> Index {
var msgHdr = messageHeader
return withUnsafeMutablePointer(to: &msgHdr) { messageHeaderPtr in
return Index(cmsgPointer: NIOBSDSocketControlMessage.nextHeader(inside: messageHeaderPtr,
after: after.cmsgPointer!))
}
}

public subscript(position: Index) -> Element {
let cmsg = position.cmsgPointer!
return UnsafeControlMessage(level: cmsg.pointee.cmsg_level,
Expand All @@ -152,13 +152,14 @@ struct UnsafeReceivedControlBytes {
struct ControlMessageParser {
var ecnValue: NIOExplicitCongestionNotificationState = .transportNotCapable // Default
var packetInfo: NIOPacketInfo? = nil
var segmentSize: Int? = nil

init(parsing controlMessagesReceived: UnsafeControlMessageCollection) {
for controlMessage in controlMessagesReceived {
self.receiveMessage(controlMessage)
}
}

#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
private static let ipv4TosType = IP_RECVTOS
#else
Expand All @@ -174,12 +175,14 @@ struct ControlMessageParser {
}
return readValue
}

private mutating func receiveMessage(_ controlMessage: UnsafeControlMessage) {
if controlMessage.level == IPPROTO_IP {
self.receiveIPv4Message(controlMessage)
} else if controlMessage.level == IPPROTO_IPV6 {
self.receiveIPv6Message(controlMessage)
} else if controlMessage.level == IPPROTO_UDP {
self.receiveUDPMessage(controlMessage)
}
}

Expand Down Expand Up @@ -225,6 +228,17 @@ struct ControlMessageParser {
}
}
}

private mutating func receiveUDPMessage(_ controlMessage: UnsafeControlMessage) {
#if os(Linux)
if controlMessage.type == NIOBSDSocket.Option.udp_gro.rawValue {
if let data = controlMessage.data {
let readValue = ControlMessageParser._readCInt(data: data)
self.segmentSize = Int(readValue)
}
}
#endif
}
}

extension NIOExplicitCongestionNotificationState {
Expand Down Expand Up @@ -262,7 +276,7 @@ extension CInt {
struct UnsafeOutboundControlBytes {
private var controlBytes: UnsafeMutableRawBufferPointer
private var writePosition: UnsafeMutableRawBufferPointer.Index

/// This structure must not outlive `controlBytes`
init(controlBytes: UnsafeMutableRawBufferPointer) {
self.controlBytes = controlBytes
Expand All @@ -279,32 +293,32 @@ struct UnsafeOutboundControlBytes {
type: CInt,
payload: PayloadType) {
let writableBuffer = UnsafeMutableRawBufferPointer(fastRebase: self.controlBytes[writePosition...])

let requiredSize = NIOBSDSocketControlMessage.space(payloadSize: MemoryLayout.stride(ofValue: payload))
precondition(writableBuffer.count >= requiredSize, "Insufficient size for cmsghdr and data")

let bufferBase = writableBuffer.baseAddress!
// Binding to cmsghdr is safe here as this is the only place where we bind to non-Raw.
let cmsghdrPtr = bufferBase.bindMemory(to: cmsghdr.self, capacity: 1)
cmsghdrPtr.pointee.cmsg_level = level
cmsghdrPtr.pointee.cmsg_type = type
cmsghdrPtr.pointee.cmsg_len = .init(NIOBSDSocketControlMessage.length(payloadSize: MemoryLayout.size(ofValue: payload)))

let dataPointer = NIOBSDSocketControlMessage.data(for: cmsghdrPtr)!
precondition(dataPointer.count >= MemoryLayout<PayloadType>.stride)
dataPointer.storeBytes(of: payload, as: PayloadType.self)

self.writePosition += requiredSize
}

/// The result is only valid while this is valid.
var validControlBytes: UnsafeMutableRawBufferPointer {
if writePosition == 0 {
return UnsafeMutableRawBufferPointer(start: nil, count: 0)
}
return UnsafeMutableRawBufferPointer(fastRebase: self.controlBytes[0 ..< self.writePosition])
}

}

extension UnsafeOutboundControlBytes {
Expand All @@ -330,12 +344,24 @@ extension UnsafeOutboundControlBytes {
break
}
}

internal mutating func appendSegmentSize(metadata: AddressedEnvelope<ByteBuffer>.Metadata?) {
#if os(Linux)
guard let segmentSize = metadata?.segmentSize, System.supportsUDPSegmentationOffload else { return }
self.appendGenericControlMessage(level: .init(IPPROTO_UDP),
type: NIOBSDSocket.Option.udp_segment.rawValue,
payload: UInt16(segmentSize))
#endif
}
}

extension AddressedEnvelope.Metadata {
/// It's assumed the caller has checked that congestion information is required before calling.
internal init(from controlMessagesReceived: UnsafeControlMessageCollection) {
let controlMessageReceiver = ControlMessageParser(parsing: controlMessagesReceived)
self.init(ecnState: controlMessageReceiver.ecnValue, packetInfo: controlMessageReceiver.packetInfo)
self.init()
self.ecnState = controlMessageReceiver.ecnValue
self.packetInfo = controlMessageReceiver.packetInfo
self.segmentSize = controlMessageReceiver.segmentSize
}
}
7 changes: 3 additions & 4 deletions Sources/NIOPosix/DatagramVectorReadManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ struct DatagramVectorReadManager {

// First we set up the iovec and save it off.
self.ioVector[i] = IOVector(iov_base: bufferPointer.baseAddress! + (i * messageSize), iov_len: numericCast(messageSize))

let controlBytes: UnsafeMutableRawBufferPointer
if parseControlMessages {
// This will be used in buildMessages below but should not be used beyond return of this function.
Expand Down Expand Up @@ -178,12 +178,11 @@ struct DatagramVectorReadManager {
precondition(self.messageVector[i].msg_hdr.msg_namelen != 0, "Unexpected zero length peer name")
#endif
let address: SocketAddress = self.sockaddrVector[i].convert()

// Extract congestion information if requested.
let metadata: AddressedEnvelope<ByteBuffer>.Metadata?
if parseControlMessages {
let controlMessagesReceived =
UnsafeControlMessageCollection(messageHeader: self.messageVector[i].msg_hdr)
let controlMessagesReceived = UnsafeControlMessageCollection(messageHeader: self.messageVector[i].msg_hdr)
metadata = .init(from: controlMessagesReceived)
} else {
metadata = nil
Expand Down
1 change: 1 addition & 0 deletions Sources/NIOPosix/PendingDatagramWritesManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ private func doPendingDatagramWriteVectorOperation(pending: PendingDatagramWrite

var controlBytes = UnsafeOutboundControlBytes(controlBytes: controlMessageStorage[c])
controlBytes.appendExplicitCongestionState(metadata: p.metadata, protocolFamily: protocolFamily)
controlBytes.appendSegmentSize(metadata: p.metadata)
let controlMessageBytePointer = controlBytes.validControlBytes

let msg = msghdr(msg_name: address,
Expand Down
14 changes: 10 additions & 4 deletions Sources/NIOPosix/SocketChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,11 @@ final class ServerSocketChannel: BaseSocketChannel<ServerSocket> {
final class DatagramChannel: BaseSocketChannel<Socket> {
private var reportExplicitCongestionNotifications = false
private var receivePacketInfo = false
private var receiveSegmentSize = false

private var parseControlMessages: Bool {
return self.reportExplicitCongestionNotifications || self.receivePacketInfo || self.receiveSegmentSize
}

// Guard against re-entrance of flushNow() method.
private let pendingWrites: PendingDatagramWritesManager
Expand Down Expand Up @@ -517,6 +522,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
throw ChannelError.operationUnsupported
}
let enable = value as! ChannelOptions.Types.DatagramReceiveOffload.Value
self.receiveSegmentSize = enable
try self.socket.setUDPReceiveOffload(enable)
default:
try super.setOption0(option, value: value)
Expand Down Expand Up @@ -619,7 +625,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {

// These control bytes must not escape the current call stack
let controlBytesBuffer: UnsafeMutableRawBufferPointer
if self.reportExplicitCongestionNotifications || self.receivePacketInfo {
if self.parseControlMessages {
controlBytesBuffer = self.selectableEventLoop.controlMessageStorage[0]
} else {
controlBytesBuffer = UnsafeMutableRawBufferPointer(start: nil, count: 0)
Expand Down Expand Up @@ -648,8 +654,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
readPending = false

let metadata: AddressedEnvelope<ByteBuffer>.Metadata?
if self.reportExplicitCongestionNotifications || self.receivePacketInfo,
let controlMessagesReceived = controlBytes.receivedControlMessages {
if self.parseControlMessages, let controlMessagesReceived = controlBytes.receivedControlMessages {
metadata = .init(from: controlMessagesReceived)
} else {
metadata = nil
Expand Down Expand Up @@ -688,7 +693,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
try vectorReadManager.readFromSocket(
socket: self.socket,
buffer: &buffer,
parseControlMessages: self.reportExplicitCongestionNotifications || self.receivePacketInfo)
parseControlMessages: self.parseControlMessages)
}

switch result {
Expand Down Expand Up @@ -810,6 +815,7 @@ final class DatagramChannel: BaseSocketChannel<Socket> {
controlBytes: self.selectableEventLoop.controlMessageStorage[0])
controlBytes.appendExplicitCongestionState(metadata: metadata,
protocolFamily: self.localAddress?.protocol)
controlBytes.appendSegmentSize(metadata: metadata)
return try self.socket.sendmsg(pointer: ptr,
destinationPtr: destinationPtr,
destinationSize: destinationSize,
Expand Down
3 changes: 3 additions & 0 deletions Tests/NIOPosixTests/DatagramChannelTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ extension DatagramChannelTests {
("testChannelCanReceiveLargeBufferWithGROUsingVectorReads", testChannelCanReceiveLargeBufferWithGROUsingVectorReads),
("testChannelCanReceiveMultipleLargeBuffersWithGROUsingScalarReads", testChannelCanReceiveMultipleLargeBuffersWithGROUsingScalarReads),
("testChannelCanReceiveMultipleLargeBuffersWithGROUsingVectorReads", testChannelCanReceiveMultipleLargeBuffersWithGROUsingVectorReads),
("testSegmentSizeSetViaMetadataOnUnsupportedPlatform", testSegmentSizeSetViaMetadataOnUnsupportedPlatform),
("testSegmentSizeSetViaMetadata", testSegmentSizeSetViaMetadata),
("testSegmentSizeFromMetadataTakesPrecedence", testSegmentSizeFromMetadataTakesPrecedence),
]
}
}
Expand Down
Loading