diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index 61fabbeb..8625e098 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -41,7 +41,7 @@ public enum SocketType: Sendable { case datagram } -extension SocketType { +package extension SocketType { var rawValue: Int32 { switch self { case .stream: @@ -50,6 +50,17 @@ extension SocketType { Socket.datagram } } + + init(rawValue: Int32) throws { + switch rawValue { + case Socket.stream: + self = .stream + case Socket.datagram: + self = .datagram + default: + throw SocketError.makeFailed("Invalid SocketType") + } + } } public struct Socket: Sendable, Hashable { @@ -74,20 +85,20 @@ public struct Socket: Sendable, Hashable { @available(*, deprecated, message: "type is now SocketType") public init(domain: Int32, type: Int32) throws { - let descriptor = FileDescriptor(rawValue: Socket.socket(domain, type, 0)) + try self.init(domain: domain, type: SocketType(rawValue: type)) + } + + public init(domain: Int32, type: SocketType) throws { + let descriptor = FileDescriptor(rawValue: Socket.socket(domain, type.rawValue, 0)) guard descriptor != .invalid else { throw SocketError.makeFailed("CreateSocket") } self.file = descriptor - if type == SocketType.datagram.rawValue { + if type == .datagram { try setPktInfo(domain: domain) } } - public init(domain: Int32, type: SocketType) throws { - try self.init(domain: domain, type: type.rawValue) - } - public var flags: Flags { get throws { let flags = Socket.fcntl(file.rawValue, F_GETFL) diff --git a/FlyingSocks/Tests/SocketTests.swift b/FlyingSocks/Tests/SocketTests.swift index 68b6dc01..511c7117 100644 --- a/FlyingSocks/Tests/SocketTests.swift +++ b/FlyingSocks/Tests/SocketTests.swift @@ -35,6 +35,21 @@ import Testing struct SocketTests { + @Test + func socketType_init() throws { + #expect(try SocketType(rawValue: Socket.stream) == .stream) + #expect(try SocketType(rawValue: Socket.datagram) == .datagram) + #expect(throws: (any Error).self) { + try SocketType(rawValue: -1) + } + } + + @Test + func socketType_rawValue() { + #expect(SocketType.stream.rawValue == Socket.stream) + #expect(SocketType.datagram.rawValue == Socket.datagram) + } + @Test func socketEvents() { let events: Set = [.read, .write] diff --git a/FlyingSocks/XCTests/SocketTests.swift b/FlyingSocks/XCTests/SocketTests.swift index b34063df..e555e459 100644 --- a/FlyingSocks/XCTests/SocketTests.swift +++ b/FlyingSocks/XCTests/SocketTests.swift @@ -34,6 +34,17 @@ import XCTest final class SocketTests: XCTestCase { + func testSocketType_init() { + XCTAssertEqual(try SocketType(rawValue: Socket.stream), .stream) + XCTAssertEqual(try SocketType(rawValue: Socket.datagram), .datagram) + XCTAssertThrowsError(try SocketType(rawValue: -1)) + } + + func funcSocketType_rawValue() { + XCTAssertEqual(SocketType.stream.rawValue, Socket.stream) + XCTAssertEqual(SocketType.datagram.rawValue, Socket.datagram) + } + func testSocketEvents() { let events: Set = [.read, .write]