diff --git a/src/network.zig b/src/network.zig index 5301739316..0ea5a65a32 100644 --- a/src/network.zig +++ b/src/network.zig @@ -47,10 +47,10 @@ const Socket = struct { posix.close(self.socketID); } - fn send(self: Socket, data: []const u8, destination: Address) void { + fn send(self: Socket, data: []const u8, destination: SocketAddress) void { const addr = posix.sockaddr.in{ .port = @byteSwap(destination.port), - .addr = destination.ip, + .addr = destination.ip.address, }; if(builtin.os.tag == .windows) { // TODO: Upstream error, fix after next Zig update after #24466 is merged const sendto = struct { @@ -70,7 +70,7 @@ const Socket = struct { } } - fn receive(self: Socket, buffer: []u8, timeout: i32, resultAddress: *Address) ![]u8 { + fn receive(self: Socket, buffer: []u8, timeout: i32, resultAddress: *SocketAddress) ![]u8 { if(builtin.os.tag == .windows) { // Of course Windows always has it's own special thing. var pfd = [1]posix.pollfd{ .{.fd = self.socketID, .events = std.c.POLL.RDNORM | std.c.POLL.RDBAND, .revents = undefined}, @@ -98,17 +98,11 @@ const Socket = struct { var addr: posix.sockaddr.in = undefined; var addrLen: posix.socklen_t = @sizeOf(posix.sockaddr.in); const length = try posix.recvfrom(self.socketID, buffer, 0, @ptrCast(&addr), &addrLen); - resultAddress.ip = addr.addr; + resultAddress.ip = .{.address = addr.addr}; resultAddress.port = @byteSwap(addr.port); return buffer[0..length]; } - fn resolveIP(addr: []const u8) !u32 { - const list = try std.net.getAddressList(main.stackAllocator.allocator, addr, settings.defaultPort); - defer list.deinit(); - return list.addrs[0].in.sa.addr; - } - fn getPort(self: Socket) !u16 { var addr: posix.sockaddr.in = undefined; var addrLen: posix.socklen_t = @sizeOf(posix.sockaddr.in); @@ -122,24 +116,88 @@ pub fn init() void { protocols.init(); } -pub const Address = struct { - ip: u32, +pub const IpAddress = struct { + address: u32, + + pub const localhost = parse("127.0.0.1") catch unreachable; + + pub fn format(self: IpAddress, writer: anytype) !void { + try writer.print("{}.{}.{}.{}", .{self.address & 255, self.address >> 8 & 255, self.address >> 16 & 255, self.address >> 24}); + } + + fn resolve(addr: []const u8, port: ?u16) !IpAddress { + const allocator = if(builtin.is_test) main.heap.testingAllocator else main.stackAllocator; + const list = try std.net.getAddressList(allocator.allocator, addr, port orelse settings.defaultPort); + defer list.deinit(); + return .{.address = list.addrs[0].in.sa.addr}; + } + + pub fn parse(addr: []const u8) !IpAddress { + var parts = std.mem.splitScalar(u8, addr, '.'); + var address: u32 = 0; + while(parts.next()) |part| { + const octet = try std.fmt.parseInt(u8, part, 10); + address >>= 8; + if(address >> 24 > 0) return error.TooManyOctets; + address |= @as(u32, octet) << 24; + } + return .{.address = address}; + } +}; + +pub const SocketAddress = struct { + ip: IpAddress, port: u16, isSymmetricNAT: bool = false, - pub const localHost = 0x0100007f; - - pub fn format(self: Address, writer: anytype) !void { + pub fn format(self: SocketAddress, writer: anytype) !void { if(self.isSymmetricNAT) { - try writer.print("{}.{}.{}.{}:?{}", .{self.ip & 255, self.ip >> 8 & 255, self.ip >> 16 & 255, self.ip >> 24, self.port}); + try writer.print("{f}:?{}", .{self.ip, self.port}); } else { - try writer.print("{}.{}.{}.{}:{}", .{self.ip & 255, self.ip >> 8 & 255, self.ip >> 16 & 255, self.ip >> 24, self.port}); + try writer.print("{f}:{}", .{self.ip, self.port}); } } + + fn parseInner(string: []const u8, defaultPort: ?u16) !struct {ip: []const u8, isSymmetricNAT: bool, port: u16} { + var parts = std.mem.splitScalar(u8, string, ':'); + const ip = parts.first(); + var portString = parts.next() orelse return error.MissingColon; + if(parts.next() != null) return error.MultipleColons; + var isSymmetricNAT = false; + if(portString.len == 0) return error.EmptyPort; + if(portString[0] == '?') { + isSymmetricNAT = true; + portString = portString[1..]; + } + const port = try (std.fmt.parseInt(u16, portString, 10) catch |err| (defaultPort orelse err)); + return .{ + .ip = ip, + .isSymmetricNAT = isSymmetricNAT, + .port = port, + }; + } + + pub fn parse(string: []const u8, defaultPort: ?u16) !SocketAddress { + const innerResult = try parseInner(string, defaultPort); + return .{ + .ip = try IpAddress.parse(innerResult.ip), + .isSymmetricNAT = innerResult.isSymmetricNAT, + .port = innerResult.port, + }; + } + + pub fn resolve(string: []const u8, defaultPort: ?u16) !SocketAddress { + const innerResult = try parseInner(string, defaultPort); + return .{ + .ip = try IpAddress.resolve(innerResult.ip, innerResult.port), + .isSymmetricNAT = innerResult.isSymmetricNAT, + .port = innerResult.port, + }; + } }; const Request = struct { - address: Address, + address: SocketAddress, data: []const u8, requestNotifier: std.Thread.Condition = std.Thread.Condition{}, }; @@ -249,8 +307,8 @@ const stun = struct { // MARK: stun const XOR_MAPPED_ADDRESS: u16 = 0x0020; const MAGIC_COOKIE = [_]u8{0x21, 0x12, 0xA4, 0x42}; - fn requestAddress(connection: *ConnectionManager) Address { - var oldAddress: ?Address = null; + fn requestAddress(connection: *ConnectionManager) SocketAddress { + var oldAddress: ?SocketAddress = null; var seed: [std.Random.DefaultCsprng.secret_seed_length]u8 = @splat(0); std.mem.writeInt(i128, seed[0..16], std.time.nanoTimestamp(), builtin.cpu.arch.endian()); // Not the best seed, but it's not that important. var random = std.Random.DefaultCsprng.init(seed); @@ -265,14 +323,9 @@ const stun = struct { // MARK: stun }; random.fill(data[8..]); // Fill the transaction ID. - var splitter = std.mem.splitScalar(u8, server, ':'); - const ip = splitter.first(); - const serverAddress = Address{ - .ip = Socket.resolveIP(ip) catch |err| { - std.log.warn("Cannot resolve STUN server address: {s}, error: {s}", .{ip, @errorName(err)}); - continue; - }, - .port = std.fmt.parseUnsigned(u16, splitter.rest(), 10) catch 3478, + const serverAddress = SocketAddress.resolve(server, null) catch |err| { + std.log.warn("Cannot resolve STUN server address: {s}, error: {s}", .{server, @errorName(err)}); + continue; }; if(connection.sendRequest(main.globalAllocator, &data, serverAddress, 500*1000000)) |answer| { defer main.globalAllocator.free(answer); @@ -286,7 +339,7 @@ const stun = struct { // MARK: stun }; if(oldAddress) |other| { std.log.info("{f}", .{result}); - if(other.ip == result.ip and other.port == result.port) { + if(other.ip.address == result.ip.address and other.port == result.port) { return result; } else { result.isSymmetricNAT = true; @@ -299,10 +352,10 @@ const stun = struct { // MARK: stun std.log.warn("Couldn't reach STUN server: {s}", .{server}); } } - return Address{.ip = Socket.resolveIP("127.0.0.1") catch unreachable, .port = settings.defaultPort}; // TODO: Return ip address in LAN. + return SocketAddress{.ip = .localhost, .port = settings.defaultPort}; // TODO: Return ip address in LAN. } - fn findIPPort(_data: []const u8) !Address { + fn findIPPort(_data: []const u8) !SocketAddress { var data = _data[20..]; // Skip the header. while(data.len > 0) { const typ = std.mem.readInt(u16, data[0..2], .big); @@ -322,9 +375,9 @@ const stun = struct { // MARK: stun addressData[4] ^= MAGIC_COOKIE[2]; addressData[5] ^= MAGIC_COOKIE[3]; } - return Address{ + return SocketAddress{ .port = std.mem.readInt(u16, addressData[0..2], .big), - .ip = std.mem.readInt(u32, addressData[2..6], builtin.cpu.arch.endian()), // Needs to stay in big endian → native. + .ip = .{.address = std.mem.readInt(u32, addressData[2..6], builtin.cpu.arch.endian())}, // Needs to stay in big endian → native. }; } else if(data[1] == 0x02) { data = data[(len + 3) & ~@as(usize, 3) ..]; // Pad to 32 Bit. @@ -357,7 +410,7 @@ pub const ConnectionManager = struct { // MARK: ConnectionManager socket: Socket = undefined, thread: std.Thread = undefined, threadId: std.Thread.Id = undefined, - externalAddress: Address = undefined, + externalAddress: SocketAddress = undefined, online: Atomic(bool) = .init(false), running: Atomic(bool) = .init(true), @@ -378,7 +431,7 @@ pub const ConnectionManager = struct { // MARK: ConnectionManager const PacketSendRequest = struct { data: []const u8, - target: Address, + target: SocketAddress, time: i64, fn compare(_: void, a: PacketSendRequest, b: PacketSendRequest) std.math.Order { @@ -444,7 +497,7 @@ pub const ConnectionManager = struct { // MARK: ConnectionManager } } - pub fn send(self: *ConnectionManager, data: []const u8, target: Address, nanoTime: ?i64) void { + pub fn send(self: *ConnectionManager, data: []const u8, target: SocketAddress, nanoTime: ?i64) void { if(nanoTime) |time| { self.mutex.lock(); defer self.mutex.unlock(); @@ -458,7 +511,7 @@ pub const ConnectionManager = struct { // MARK: ConnectionManager } } - pub fn sendRequest(self: *ConnectionManager, allocator: NeverFailingAllocator, data: []const u8, target: Address, timeout_ns: u64) ?[]const u8 { + pub fn sendRequest(self: *ConnectionManager, allocator: NeverFailingAllocator, data: []const u8, target: SocketAddress, timeout_ns: u64) ?[]const u8 { self.socket.send(data, target); var request = Request{.address = target, .data = data}; { @@ -494,7 +547,7 @@ pub const ConnectionManager = struct { // MARK: ConnectionManager self.mutex.lock(); defer self.mutex.unlock(); for(self.connections.items) |other| { - if(other.remoteAddress.ip == conn.remoteAddress.ip and other.remoteAddress.port == conn.remoteAddress.port) return error.AlreadyConnected; + if(other.remoteAddress.ip.address == conn.remoteAddress.ip.address and other.remoteAddress.port == conn.remoteAddress.port) return error.AlreadyConnected; } self.connections.append(conn); } @@ -518,12 +571,12 @@ pub const ConnectionManager = struct { // MARK: ConnectionManager } } - fn onReceive(self: *ConnectionManager, data: []const u8, source: Address) void { + fn onReceive(self: *ConnectionManager, data: []const u8, source: SocketAddress) void { std.debug.assert(self.threadId == std.Thread.getCurrentId()); self.mutex.lock(); for(self.connections.items) |conn| { - if(conn.remoteAddress.ip == source.ip) { + if(conn.remoteAddress.ip.address == source.ip.address) { if(conn.bruteforcingPort) { conn.remoteAddress.port = source.port; conn.bruteforcingPort = false; @@ -539,15 +592,15 @@ pub const ConnectionManager = struct { // MARK: ConnectionManager defer self.mutex.unlock(); // Check if it's part of an active request: for(self.requests.items) |request| { - if(request.address.ip == source.ip and request.address.port == source.port) { + if(request.address.ip.address == source.ip.address and request.address.port == source.port) { request.data = main.globalAllocator.dupe(u8, data); request.requestNotifier.signal(); return; } } - if(self.online.load(.acquire) and source.ip == self.externalAddress.ip and source.port == self.externalAddress.port) return; + if(self.online.load(.acquire) and source.ip.address == self.externalAddress.ip.address and source.port == self.externalAddress.port) return; } - if(self.allowNewConnections.load(.monotonic) or source.ip == Address.localHost) { + if(self.allowNewConnections.load(.monotonic) or source.ip.address == IpAddress.localhost.address) { if(data.len != 0 and data[0] == @intFromEnum(Connection.ChannelId.init)) { const ip = std.fmt.allocPrint(main.stackAllocator.allocator, "{f}", .{source}) catch unreachable; defer main.stackAllocator.free(ip); @@ -573,7 +626,7 @@ pub const ConnectionManager = struct { // MARK: ConnectionManager while(self.running.load(.monotonic)) { main.heap.GarbageCollection.syncPoint(); self.waitingToFinishReceive.broadcast(); - var source: Address = undefined; + var source: SocketAddress = undefined; if(self.socket.receive(&self.receiveBuffer, 1, &source)) |data| { self.onReceive(data, source); } else |err| { @@ -1099,7 +1152,7 @@ pub const Connection = struct { // MARK: Connection manager: *ConnectionManager, user: ?*main.server.User, - remoteAddress: Address, + remoteAddress: SocketAddress, bruteforcingPort: bool = false, bruteForcedPortRange: u16 = 0, @@ -1158,20 +1211,8 @@ pub const Connection = struct { // MARK: Connection result.queuedConfirmations.deinit(); } if(result.connectionIdentifier == 0) result.connectionIdentifier = 1; - - var splitter = std.mem.splitScalar(u8, ipPort, ':'); - const ip = splitter.first(); - result.remoteAddress.ip = try Socket.resolveIP(ip); - var port = splitter.rest(); - if(port.len != 0 and port[0] == '?') { - result.remoteAddress.isSymmetricNAT = true; - result.bruteforcingPort = true; - port = port[1..]; - } - result.remoteAddress.port = std.fmt.parseUnsigned(u16, port, 10) catch blk: { - if(ip.len != ipPort.len) std.log.err("Could not parse port \"{s}\". Using default port instead.", .{port}); - break :blk settings.defaultPort; - }; + result.remoteAddress = try SocketAddress.resolve(ipPort, settings.defaultPort); + result.bruteforcingPort = result.remoteAddress.isSymmetricNAT; try result.manager.addConnection(result); return result; @@ -1505,3 +1546,110 @@ pub const Connection = struct { // MARK: Connection std.log.info("Disconnected", .{}); } }; + +const ProtocolTask = struct { + conn: *Connection, + protocol: u8, + data: []const u8, + + const vtable = utils.ThreadPool.VTable{ + .getPriority = main.utils.castFunctionSelfToAnyopaque(getPriority), + .isStillNeeded = main.utils.castFunctionSelfToAnyopaque(isStillNeeded), + .run = main.utils.castFunctionSelfToAnyopaque(run), + .clean = main.utils.castFunctionSelfToAnyopaque(clean), + .taskType = .misc, + }; + + pub fn schedule(conn: *Connection, protocol: u8, data: []const u8) void { + const task = main.globalAllocator.create(ProtocolTask); + task.* = ProtocolTask{ + .conn = conn, + .protocol = protocol, + .data = main.globalAllocator.dupe(u8, data), + }; + main.threadPool.addTask(task, &vtable); + } + + pub fn getPriority(_: *ProtocolTask) f32 { + return std.math.floatMax(f32); + } + + pub fn isStillNeeded(_: *ProtocolTask) bool { + return true; + } + + pub fn run(self: *ProtocolTask) void { + defer self.clean(); + var reader = utils.BinaryReader.init(self.data); + protocols.list[self.protocol].?(self.conn, &reader) catch |err| { + std.log.err("Got error {s} while executing protocol {} with data {any}", .{@errorName(err), self.protocol, self.data}); // TODO: Maybe disconnect on error + }; + } + + pub fn clean(self: *ProtocolTask) void { + main.globalAllocator.free(self.data); + main.globalAllocator.destroy(self); + } +}; + +test "Parse address" { + const localhost: u32 = 0x0100007f; + try std.testing.expectEqual(localhost, IpAddress.localhost.address); + const socketAddress = try SocketAddress.parse("127.0.0.1:1234", null); + try std.testing.expectEqual(SocketAddress{.ip = .{.address = localhost}, .port = 1234}, socketAddress); + + const symmetricSocketAddress = try SocketAddress.parse("127.0.0.1:?1234", null); + try std.testing.expectEqual(SocketAddress{.ip = .{.address = localhost}, .isSymmetricNAT = true, .port = 1234}, symmetricSocketAddress); +} + +test "Resolve address" { + const addresses: [4][]const u8 = .{ + "127.0.0.1", + "123.1.111.222", + "1.1.1.1", + "0.0.0.0", + }; + for(addresses) |addressStr| { + const parsedAddress = try IpAddress.parse(addressStr); + const resolvedAddress = try IpAddress.resolve(addressStr, null); + try std.testing.expectEqualDeep(parsedAddress, resolvedAddress); + } + const socketAddresses: [4][]const u8 = .{ + "127.0.0.1:1234", + "123.1.111.222:?11111", + "1.1.1.1:255", + "0.0.0.0:?3333", + }; + for(socketAddresses) |addressStr| { + const parsedAddress = try SocketAddress.parse(addressStr, null); + const resolvedAddress = try SocketAddress.resolve(addressStr, null); + try std.testing.expectEqualDeep(parsedAddress, resolvedAddress); + } +} + +test "Format address" { + const addresses: [4][]const u8 = .{ + "127.0.0.1", + "123.1.111.222", + "1.1.1.1", + "0.0.0.0", + }; + for(addresses) |addressStr| { + const address = try IpAddress.parse(addressStr); + const reformattedAddress = std.fmt.allocPrint(main.heap.testingAllocator.allocator, "{f}", .{address}) catch unreachable; + defer main.heap.testingAllocator.free(reformattedAddress); + try std.testing.expectEqualStrings(addressStr, reformattedAddress); + } + const socketAddresses: [4][]const u8 = .{ + "127.0.0.1:1234", + "123.1.111.222:?11111", + "1.1.1.1:255", + "0.0.0.0:?3333", + }; + for(socketAddresses) |addressStr| { + const address = try SocketAddress.parse(addressStr, null); + const reformattedAddress = std.fmt.allocPrint(main.heap.testingAllocator.allocator, "{f}", .{address}) catch unreachable; + defer main.heap.testingAllocator.free(reformattedAddress); + try std.testing.expectEqualStrings(addressStr, reformattedAddress); + } +}