diff --git a/api/axfeat/Cargo.toml b/api/axfeat/Cargo.toml index 875d302d0a..08c4110df7 100644 --- a/api/axfeat/Cargo.toml +++ b/api/axfeat/Cargo.toml @@ -68,6 +68,7 @@ fs-times = ["fs", "axfs/times"] # Networking net = ["alloc", "paging", "axdriver/virtio-net", "dep:axnet", "axruntime/net"] vsock = ["net", "axdriver/virtio-socket", "axruntime/vsock", "axnet/vsock"] +netlink = ["net", "axnet/netlink"] # Display display = [ diff --git a/modules/axfs/src/highlevel/file.rs b/modules/axfs/src/highlevel/file.rs index 7cd360f9d5..0d92565569 100644 --- a/modules/axfs/src/highlevel/file.rs +++ b/modules/axfs/src/highlevel/file.rs @@ -219,16 +219,13 @@ impl OpenOptions { } pub fn open_loc(&self, loc: Location) -> VfsResult { - if !self.is_valid() { - return Err(VfsError::InvalidInput); - } + self.check_options()?; + self._open(loc) } pub fn open(&self, context: &FsContext, path: impl AsRef) -> VfsResult { - if !self.is_valid() { - return Err(VfsError::InvalidInput); - } + self.check_options()?; let loc = match context.resolve_parent(path.as_ref()) { Ok((parent, name)) => { @@ -273,24 +270,14 @@ impl OpenOptions { }) } - pub(crate) fn is_valid(&self) -> bool { + pub(crate) fn check_options(&self) -> VfsResult<()> { if !self.read && !self.write && !self.append { - return true; + return Err(VfsError::InvalidInput); } - match (self.write, self.append) { - (true, false) => {} - (false, false) => { - if self.truncate || self.create || self.create_new { - return false; - } - } - (_, true) => { - if self.truncate && !self.create_new { - return false; - } - } + if self.create_new && !self.create { + return Err(VfsError::AlreadyExists); } - true + Ok(()) } } diff --git a/modules/axnet/Cargo.toml b/modules/axnet/Cargo.toml index 3a3d1ec44b..7c5b18dc2f 100644 --- a/modules/axnet/Cargo.toml +++ b/modules/axnet/Cargo.toml @@ -11,6 +11,7 @@ documentation = "https://arceos-org.github.io/arceos/axnet/index.html" [features] vsock = ["axdriver/vsock"] +netlink = [] [dependencies] 
async-channel = { version = "2.5", default-features = false } @@ -25,13 +26,20 @@ axio = { workspace = true } axpoll = { workspace = true } axsync = { workspace = true } axtask = { workspace = true } -bitflags = "2.9.1" +bitflags = { version = "2.9.1", features = ["bytemuck"] } +bytemuck = { version = "1.23", features = ["derive"] } cfg-if = { workspace = true } enum_dispatch = { workspace = true } event-listener = { version = "5.4", default-features = false } hashbrown = "0.16" lazy_static = { workspace = true } log = { workspace = true } +memory_addr = { workspace = true } +num_enum = { version = "0.7", default-features = false } +rand = { version = "0.9", default-features = false, features = [ + "alloc", + "small_rng", +] } ringbuf = { version = "0.4.8", default-features = false, features = ["alloc"] } spin = { workspace = true } diff --git a/modules/axnet/src/consts.rs b/modules/axnet/src/consts.rs index d229c5eb8b..2d7a7a496a 100644 --- a/modules/axnet/src/consts.rs +++ b/modules/axnet/src/consts.rs @@ -21,3 +21,5 @@ pub const LISTEN_QUEUE_SIZE: usize = 512; pub const SOCKET_BUFFER_SIZE: usize = 64; pub const ETHERNET_MAX_PENDING_PACKETS: usize = 32; + +pub const NETLINK_DEFAULT_BUF_SIZE: usize = 65536; diff --git a/modules/axnet/src/device/ethernet.rs b/modules/axnet/src/device/ethernet.rs index 359d186bfa..338233cf04 100644 --- a/modules/axnet/src/device/ethernet.rs +++ b/modules/axnet/src/device/ethernet.rs @@ -9,13 +9,13 @@ use smoltcp::{ time::{Duration, Instant}, wire::{ ArpOperation, ArpPacket, ArpRepr, EthernetAddress, EthernetFrame, EthernetProtocol, - EthernetRepr, IpAddress, Ipv4Cidr, + EthernetRepr, IpAddress, Ipv4Address, Ipv4Cidr, }, }; use crate::{ consts::{ETHERNET_MAX_PENDING_PACKETS, STANDARD_MTU}, - device::Device, + device::{Device, DeviceFlags, DeviceType}, }; const EMPTY_MAC: EthernetAddress = EthernetAddress([0; 6]); @@ -26,6 +26,7 @@ struct Neighbor { } pub struct EthernetDevice { + index: u32, name: String, inner: AxNetDevice, 
neighbors: HashMap>, @@ -36,7 +37,7 @@ pub struct EthernetDevice { impl EthernetDevice { const NEIGHBOR_TTL: Duration = Duration::from_secs(60); - pub fn new(name: String, inner: AxNetDevice, ip: Ipv4Cidr) -> Self { + pub fn new(index: u32, name: String, inner: AxNetDevice, ip: Ipv4Cidr) -> Self { let pending_packets = PacketBuffer::new( vec![PacketMetadata::EMPTY; ETHERNET_MAX_PENDING_PACKETS], vec![ @@ -46,6 +47,7 @@ impl EthernetDevice { ], ); Self { + index, name, inner, neighbors: HashMap::new(), @@ -261,6 +263,30 @@ impl Device for EthernetDevice { &self.name } + fn get_type(&self) -> DeviceType { + DeviceType::ETHER + } + + fn get_flags(&self) -> DeviceFlags { + DeviceFlags::UP + | DeviceFlags::BROADCAST + | DeviceFlags::RUNNING + | DeviceFlags::LOWER_UP + | DeviceFlags::MULTICAST + } + + fn get_index(&self) -> u32 { + self.index + } + + fn ipv4_addr(&self) -> Option { + Some(self.ip.address()) + } + + fn prefix_len(&self) -> Option { + Some(self.ip.prefix_len()) + } + fn recv(&mut self, buffer: &mut PacketBuffer<()>, timestamp: Instant) -> bool { loop { let rx_buf = match self.inner.receive() { diff --git a/modules/axnet/src/device/loopback.rs b/modules/axnet/src/device/loopback.rs index af0eba22bf..1f8c8eaaaf 100644 --- a/modules/axnet/src/device/loopback.rs +++ b/modules/axnet/src/device/loopback.rs @@ -1,5 +1,5 @@ use alloc::vec; -use core::task::Waker; +use core::{net::Ipv4Addr, task::Waker}; use axpoll::PollSet; use smoltcp::{ @@ -10,20 +10,22 @@ use smoltcp::{ use crate::{ consts::{SOCKET_BUFFER_SIZE, STANDARD_MTU}, - device::Device, + device::{Device, DeviceFlags, DeviceType}, }; pub struct LoopbackDevice { + index: u32, buffer: PacketBuffer<'static, ()>, poll: PollSet, } impl LoopbackDevice { - pub fn new() -> Self { + pub fn new(index: u32) -> Self { let buffer = PacketBuffer::new( vec![PacketMetadata::EMPTY; SOCKET_BUFFER_SIZE], vec![0u8; STANDARD_MTU * SOCKET_BUFFER_SIZE], ); Self { + index, buffer, poll: PollSet::new(), } @@ -35,6 +37,26 @@ impl 
Device for LoopbackDevice { "lo" } + fn get_type(&self) -> DeviceType { + DeviceType::LOOPBACK + } + + fn get_flags(&self) -> DeviceFlags { + DeviceFlags::UP | DeviceFlags::LOOPBACK | DeviceFlags::RUNNING + } + + fn get_index(&self) -> u32 { + self.index + } + + fn ipv4_addr(&self) -> Option { + Some(Ipv4Addr::new(127, 0, 0, 1)) + } + + fn prefix_len(&self) -> Option { + Some(8) + } + fn recv(&mut self, buffer: &mut PacketBuffer<()>, _timestamp: Instant) -> bool { self.buffer.dequeue().ok().is_some_and(|(_, rx_buf)| { buffer diff --git a/modules/axnet/src/device/mod.rs b/modules/axnet/src/device/mod.rs index a28a59e628..884cc4314d 100644 --- a/modules/axnet/src/device/mod.rs +++ b/modules/axnet/src/device/mod.rs @@ -1,6 +1,12 @@ use core::task::Waker; -use smoltcp::{storage::PacketBuffer, time::Instant, wire::IpAddress}; +use bitflags::bitflags; +use num_enum::TryFromPrimitive; +use smoltcp::{ + storage::PacketBuffer, + time::Instant, + wire::{IpAddress, Ipv4Address}, +}; mod ethernet; mod loopback; @@ -15,6 +21,16 @@ pub use vsock::*; pub trait Device: Send + Sync { fn name(&self) -> &str; + fn get_type(&self) -> DeviceType; + + fn get_flags(&self) -> DeviceFlags; + + fn get_index(&self) -> u32; + + fn ipv4_addr(&self) -> Option; + + fn prefix_len(&self) -> Option; + fn recv(&mut self, buffer: &mut PacketBuffer<()>, timestamp: Instant) -> bool; /// Sends a packet to the next hop. /// @@ -25,3 +41,81 @@ pub trait Device: Send + Sync { fn register_waker(&self, waker: &Waker); } + +/// Device type. 
+/// +/// Reference: +#[repr(u16)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)] +#[allow(dead_code)] +pub enum DeviceType { + // Arp protocol hardware identifiers + /// from KA9Q: NET/ROM pseudo + NETROM = 0, + /// Ethernet 10Mbps + ETHER = 1, + /// Experimental Ethernet + EETHER = 2, + + // Dummy types for non ARP hardware + /// IPIP tunnel + TUNNEL = 768, + /// IP6IP6 tunnel + TUNNEL6 = 769, + /// Frame Relay Access Device + FRAD = 770, + /// SKIP vif + SKIP = 771, + /// Loopback device + LOOPBACK = 772, + /// Localtalk device + LOCALTALK = 773, + // TODO: This enum is not exhaustive +} + +bitflags! { + /// Device flags. + /// + /// Reference: + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub struct DeviceFlags: u32 { + /// Device is up + const UP = 1<<0; + /// Broadcast address valid + const BROADCAST = 1<<1; + /// Turn on debugging + const DEBUG = 1<<2; + /// Loopback net + const LOOPBACK = 1<<3; + /// Device is has p-p link + const POINTOPOINT = 1<<4; + /// Avoid use of trailers + const NOTRAILERS = 1<<5; + /// Device RFC2863 OPER_UP + const RUNNING = 1<<6; + /// No ARP protocol + const NOARP = 1<<7; + /// Receive all packets + const PROMISC = 1<<8; + /// Receive all multicast packets + const ALLMULTI = 1<<9; + /// Master of a load balancer + const MASTER = 1<<10; + /// Slave of a load balancer + const SLAVE = 1<<11; + /// Supports multicast + const MULTICAST = 1<<12; + /// Can set media type + const PORTSEL = 1<<13; + /// Auto media select active + const AUTOMEDIA = 1<<14; + /// Dialup device with changing addresses + const DYNAMIC = 1<<15; + /// Driver signals L1 up + const LOWER_UP = 1<<16; + /// Driver signals dormant + const DORMANT = 1<<17; + /// Echo sent packets + const ECHO = 1<<18; + } +} diff --git a/modules/axnet/src/lib.rs b/modules/axnet/src/lib.rs index b5de464a43..49c9f5a282 100644 --- a/modules/axnet/src/lib.rs +++ b/modules/axnet/src/lib.rs @@ -13,6 +13,7 @@ //! 
[smoltcp]: https://github.com/smoltcp-rs/smoltcp #![no_std] +#![feature(associated_type_defaults)] #[macro_use] extern crate log; @@ -32,9 +33,13 @@ pub mod udp; pub mod unix; #[cfg(feature = "vsock")] pub mod vsock; +#[cfg(feature = "netlink")] +pub mod netlink; + mod wrapper; use alloc::{borrow::ToOwned, boxed::Box}; +use core::sync::atomic::{AtomicU32, Ordering}; use axdriver::{AxDeviceContainer, prelude::*}; use axsync::Mutex; @@ -56,6 +61,8 @@ static SOCKET_SET: Lazy = Lazy::new(SocketSetWrapper::new); static SERVICE: Once> = Once::new(); +static DEVICE_INDEX_COUNTER: AtomicU32 = AtomicU32::new(1); + fn get_service() -> axsync::MutexGuard<'static, Service> { SERVICE .get() @@ -68,7 +75,8 @@ pub fn init_network(mut net_devs: AxDeviceContainer) { info!("Initialize network subsystem..."); let mut router = Router::new(); - let lo_dev = router.add_device(Box::new(LoopbackDevice::new())); + let index = DEVICE_INDEX_COUNTER.fetch_add(1, Ordering::Relaxed); + let lo_dev = router.add_device(Box::new(LoopbackDevice::new(index))); let lo_ip = Ipv4Cidr::new(Ipv4Address::new(127, 0, 0, 1), 8); router.add_rule(Rule::new( @@ -85,6 +93,7 @@ pub fn init_network(mut net_devs: AxDeviceContainer) { let eth0_ip = Ipv4Cidr::new(IP.parse().expect("Invalid IPv4 address"), IP_PREFIX); let eth0_dev = router.add_device(Box::new(EthernetDevice::new( + DEVICE_INDEX_COUNTER.fetch_add(1, Ordering::Relaxed), "eth0".to_owned(), dev, eth0_ip, diff --git a/modules/axnet/src/netlink/addr.rs b/modules/axnet/src/netlink/addr.rs new file mode 100644 index 0000000000..acdfbe4467 --- /dev/null +++ b/modules/axnet/src/netlink/addr.rs @@ -0,0 +1,126 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NetlinkSocketAddr { + port: u32, + groups: GroupIdSet, +} + +impl NetlinkSocketAddr { + /// Creates a new netlink address. + pub const fn new(port: u32, groups: GroupIdSet) -> Self { + Self { port, groups } + } + + /// Creates a new unspecified address. 
+ /// + /// Both the port ID and group numbers are left unspecified. + /// + /// Note that an unspecified address can also represent the kernel socket + /// address. + pub const fn new_unspecified() -> Self { + Self { + port: 0, + groups: GroupIdSet::new_empty(), + } + } + + pub fn is_unspecified(&self) -> bool { + self.port == 0 && self.groups.is_empty() + } + + /// Returns the port number. + pub const fn port(&self) -> u32 { + self.port + } + + pub fn set_port(&mut self, port: u32) { + self.port = port; + } + + /// Returns the group ID set. + pub const fn groups(&self) -> GroupIdSet { + self.groups + } + + /// Adds some new groups to the address. + pub fn add_groups(&mut self, groups: GroupIdSet) { + self.groups.add_groups(groups); + } +} + +/// A set of group IDs. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct GroupIdSet(u32); + +impl GroupIdSet { + /// Creates a new empty `GroupIdSet`. + pub const fn new_empty() -> Self { + Self(0) + } + + /// Creates a new `GroupIdSet` with multiple groups. + /// + /// Each 1 bit in `groups` represent a group. + pub const fn new(groups: u32) -> Self { + Self(groups) + } + + /// Creates an iterator over all group IDs. + pub const fn ids_iter(&self) -> GroupIdIter { + GroupIdIter::new(self) + } + + /// Adds some new groups. + pub fn add_groups(&mut self, groups: GroupIdSet) { + self.0 |= groups.0; + } + + /// Drops some groups. + pub fn drop_groups(&mut self, groups: GroupIdSet) { + self.0 &= !groups.0; + } + + /// Sets new groups. + pub fn set_groups(&mut self, new_groups: u32) { + self.0 = new_groups; + } + + /// Clears all groups. + pub fn clear(&mut self) { + self.0 = 0; + } + + /// Checks if the set of group IDs is empty. + pub fn is_empty(&self) -> bool { + self.0 == 0 + } + + /// Returns the group IDs as a u32. + pub fn as_u32(&self) -> u32 { + self.0 + } +} + +/// Iterator over a set of group IDs. 
+pub struct GroupIdIter { + groups: u32, +} + +impl GroupIdIter { + const fn new(groups: &GroupIdSet) -> Self { + Self { groups: groups.0 } + } +} + +impl Iterator for GroupIdIter { + type Item = u32; + + fn next(&mut self) -> Option { + if self.groups > 0 { + let group_id = self.groups.trailing_zeros(); + self.groups &= self.groups - 1; + return Some(group_id); + } + + None + } +} diff --git a/modules/axnet/src/netlink/message/attr/mod.rs b/modules/axnet/src/netlink/message/attr/mod.rs new file mode 100644 index 0000000000..7df433182f --- /dev/null +++ b/modules/axnet/src/netlink/message/attr/mod.rs @@ -0,0 +1,151 @@ +mod noattr; + +use alloc::vec::Vec; +use core::mem::size_of; + +use axerrno::{AxError, AxResult}; +use axio::{BufRead, Write}; +use bytemuck::{Pod, Zeroable, bytes_of}; +use memory_addr::align_up; +pub use noattr::NoAttr; + +use crate::netlink::message::{ContinueRead, NLMSG_ALIGN}; + +/// Netlink attribute header. +/// +/// Reference: . +#[repr(C)] +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +pub struct AttrHeader { + len: u16, + type_: u16, +} + +impl AttrHeader { + /// Creates a new `AttrHeader` from the given type and payload length. + pub fn from_payload_len(type_: u16, payload_len: usize) -> Self { + let total_len = payload_len + size_of::(); + debug_assert!(total_len <= u16::MAX as usize); + + Self { + len: total_len as u16, + type_, + } + } + + /// Returns the attribute type, masking the nested and net byteorder flags. + pub fn type_(&self) -> u16 { + self.type_ & ATTRIBUTE_TYPE_MASK + } + + /// Returns the length of the attribute payload. + pub fn payload_len(&self) -> usize { + self.len as usize - size_of::() + } + + /// Returns the total length of the attribute, including header and payload. + pub fn total_len(&self) -> usize { + self.len as usize + } + + /// Returns the total length of the attribute, including padding. 
+ pub fn total_len_with_padding(&self) -> usize { + align_up(self.len as usize, NLMSG_ALIGN) + } + + /// Returns the length of the padding after the attribute payload. + pub fn padding_len(&self) -> usize { + self.total_len_with_padding() - self.total_len() + } +} + +const IS_NESTED_MASK: u16 = 1u16 << 15; +const IS_NET_BYTEORDER_MASK: u16 = 1u16 << 14; +const ATTRIBUTE_TYPE_MASK: u16 = !(IS_NESTED_MASK | IS_NET_BYTEORDER_MASK); + +/// Trait for Netlink attributes. +pub trait Attribute: Send + Sync { + /// Returns the attribute type. + fn type_(&self) -> u16; + + /// Returns the attribute payload as bytes. + fn payload_as_bytes(&self) -> &[u8]; + + /// Returns the total length of the attribute, including padding. + fn total_len_with_padding(&self) -> usize { + const DUMMY_TYPE: u16 = 0; + + AttrHeader::from_payload_len(DUMMY_TYPE, self.payload_as_bytes().len()) + .total_len_with_padding() + } + + /// Reads the attribute from the given header and reader. + fn read_from(header: &AttrHeader, reader: &mut impl BufRead) -> AxResult> + where + Self: Sized; + + /// Reads all attributes from the `reader` until `total_len` bytes are read. + fn read_all_from( + reader: &mut impl BufRead, + mut total_len: usize, + ) -> AxResult>> + where + Self: Sized, + { + let mut res = Vec::new(); + + while total_len > 0 { + if total_len < size_of::() { + reader.consume(total_len); + return Ok(ContinueRead::SkippedErr(AxError::InvalidInput)); + } + + let mut buf = [0u8; size_of::()]; + reader.read_exact(&mut buf)?; + let header: AttrHeader = *bytemuck::from_bytes(&buf); + total_len -= size_of::(); + if header.total_len() < size_of::() { + reader.consume(total_len); + return Ok(ContinueRead::SkippedErr(AxError::InvalidInput)); + } + + if header.payload_len() > total_len { + reader.consume(total_len); + return Ok(ContinueRead::SkippedErr(AxError::InvalidInput)); + } + total_len -= header.payload_len(); + + match Self::read_from(&header, reader)? 
{ + ContinueRead::Parsed(attr) => res.push(attr), + ContinueRead::Skipped => (), + ContinueRead::SkippedErr(err) => { + reader.consume(total_len); + return Ok(ContinueRead::SkippedErr(err)); + } + } + + let padding_len = total_len.min(header.padding_len()); + reader.consume(padding_len); + total_len -= padding_len; + } + + Ok(ContinueRead::Parsed(res)) + } + + /// Writes the attribute to the `writer`. + fn write_to(&self, writer: &mut impl Write) -> AxResult { + let type_ = self.type_(); + let payload = self.payload_as_bytes(); + + let header = AttrHeader::from_payload_len(type_, payload.len()); + writer.write_all(bytes_of(&header))?; + writer.write_all(payload)?; + + let padding_len = header.padding_len(); + if padding_len > 0 { + writer.write_all(&[0u8; 8][..padding_len])?; + } + + Ok(()) + } +} diff --git a/modules/axnet/src/netlink/message/attr/noattr.rs b/modules/axnet/src/netlink/message/attr/noattr.rs new file mode 100644 index 0000000000..10311b7ca5 --- /dev/null +++ b/modules/axnet/src/netlink/message/attr/noattr.rs @@ -0,0 +1,43 @@ +use alloc::vec::Vec; + +use axerrno::AxResult; +use axio::BufRead; + +use super::{Attribute, AttrHeader}; +use crate::netlink::message::ContinueRead; + +/// An attribute type that represents no attribute. 
+#[derive(Debug, Clone)] +pub enum NoAttr {} + +impl Attribute for NoAttr { + fn type_(&self) -> u16 { + match *self {} + } + + fn payload_as_bytes(&self) -> &[u8] { + match *self {} + } + + fn read_from(header: &AttrHeader, reader: &mut impl BufRead) -> AxResult> + where + Self: Sized, + { + let payload_len = header.payload_len(); + reader.consume(payload_len); + + Ok(ContinueRead::Skipped) + } + + fn read_all_from( + reader: &mut impl BufRead, + total_len: usize, + ) -> AxResult>> + where + Self: Sized, + { + reader.consume(total_len); + + Ok(ContinueRead::Skipped) + } +} diff --git a/modules/axnet/src/netlink/message/mod.rs b/modules/axnet/src/netlink/message/mod.rs new file mode 100644 index 0000000000..4a7e8a3834 --- /dev/null +++ b/modules/axnet/src/netlink/message/mod.rs @@ -0,0 +1,56 @@ +pub mod attr; +pub mod result; +pub mod segment; + +use alloc::vec::Vec; + +use axerrno::AxResult; +use axio::{BufRead, Write}; + +use self::{ + result::ContinueRead, + segment::{ErrorSegment, SegmentHeader}, +}; +use crate::netlink::receiver::QueueableMessage; + +/// Netlink message with protocol-specific segments. +#[derive(Debug, Clone)] +pub struct Message { + segments: Vec, +} + +impl Message { + /// Creates a new message with the given segments. + pub fn new(segments: Vec) -> Self { + Self { segments } + } + + /// Writes the message to the given writer. + pub fn write_to(&self, writer: &mut impl Write) -> AxResult<()> { + for segment in &self.segments { + segment.write_to(writer)?; + } + Ok(()) + } +} + +impl QueueableMessage for Message { + /// Returns the total length of the message, including all segments. + fn total_len(&self) -> usize { + self.segments + .iter() + .map(|segment| segment.header().len as usize) + .sum() + } +} + +/// Protocol-specific segment trait. 
+pub trait ProtocolSegment: Sized { + fn header(&self) -> &SegmentHeader; + fn header_mut(&mut self) -> &mut SegmentHeader; + fn read_from(reader: &mut impl BufRead) -> AxResult>; + fn write_to(&self, writer: &mut impl Write) -> AxResult<()>; +} + +/// Alignment for netlink messages. +pub(super) const NLMSG_ALIGN: usize = 4; diff --git a/modules/axnet/src/netlink/message/result.rs b/modules/axnet/src/netlink/message/result.rs new file mode 100644 index 0000000000..9795002470 --- /dev/null +++ b/modules/axnet/src/netlink/message/result.rs @@ -0,0 +1,43 @@ +use axerrno::{AxError, ax_err_type}; + +/// Result of reading a protocol-specific segment. +#[derive(Debug)] +pub enum ContinueRead { + Parsed(T), + Skipped, + SkippedErr(E), +} + +impl ContinueRead { + /// Creates a [`SkippedErr`] variant with the given error information. + /// + /// [`SkippedErr`]: Self::SkippedErr + pub fn skipped_with_error(errno: AxError, msg: &'static str) -> Self { + Self::SkippedErr(ax_err_type!(errno, msg)) + } +} + +impl ContinueRead { + /// Maps a `ContinueRead` to `ContinueRead` by applying a + /// function to a contained `Parsed` value. + pub fn map U>(self, f: F) -> ContinueRead { + match self { + ContinueRead::Parsed(t) => ContinueRead::Parsed(f(t)), + ContinueRead::Skipped => ContinueRead::Skipped, + ContinueRead::SkippedErr(e) => ContinueRead::SkippedErr(e), + } + } + + /// Maps a `ContinueRead` to `ContinueRead` by applying a + /// function to a contained `SkippedErr` value. 
+ pub fn map_err(self, f: F) -> ContinueRead + where + F: FnOnce(E) -> U, + { + match self { + Self::Parsed(val) => ContinueRead::Parsed(val), + Self::Skipped => ContinueRead::Skipped, + Self::SkippedErr(err) => ContinueRead::SkippedErr(f(err)), + } + } +} diff --git a/modules/axnet/src/netlink/message/segment/ack.rs b/modules/axnet/src/netlink/message/segment/ack.rs new file mode 100644 index 0000000000..e53dafb502 --- /dev/null +++ b/modules/axnet/src/netlink/message/segment/ack.rs @@ -0,0 +1,81 @@ +use alloc::vec::Vec; + +use axerrno::AxError; +use bytemuck::{Pod, Zeroable}; + +use crate::netlink::message::{ + attr::NoAttr, + segment::{SegHdrCommonFlags, SegmentBody, SegmentCommon, SegmentHeader, SegmentType}, +}; + +/// Acknowledgment segment without attributes. +pub type DoneSegment = SegmentCommon; + +/// Body of a Done segment. +#[repr(C)] +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +pub struct DoneSegmentBody { + error_code: i32, +} + +impl SegmentBody for DoneSegmentBody { + type CType = DoneSegmentBody; +} + +impl DoneSegment { + /// Creates a new Done segment from the given request header and optional error. + pub fn new_from_request(request_header: &SegmentHeader, error: Option) -> Self { + let header = SegmentHeader { + len: 0, + type_: SegmentType::DONE as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + let body = { + let error_code = error.map_or(0, |e| -(e.code() as i32)); + DoneSegmentBody { error_code } + }; + + Self::new(header, body, Vec::new()) + } +} + +/// Error segment without attributes. +pub type ErrorSegment = SegmentCommon; + +/// Body of an Error segment. 
+#[repr(C)] +#[derive(Debug, Pod, Clone, Copy, Zeroable)] +pub struct ErrorSegmentBody { + error_code: i32, + request_header: SegmentHeader, +} + +impl SegmentBody for ErrorSegmentBody { + type CType = ErrorSegmentBody; +} + +impl ErrorSegment { + /// Creates a new Error segment from the given request header and optional error. + pub fn new_from_request(request_header: &SegmentHeader, error: Option) -> Self { + let header = SegmentHeader { + len: 0, + type_: SegmentType::ERROR as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + let body = { + let error_code = error.map_or(0, |e| -(e.code() as i32)); + ErrorSegmentBody { + error_code, + request_header: *request_header, + } + }; + + Self::new(header, body, Vec::new()) + } +} diff --git a/modules/axnet/src/netlink/message/segment/common.rs b/modules/axnet/src/netlink/message/segment/common.rs new file mode 100644 index 0000000000..2975b0b24c --- /dev/null +++ b/modules/axnet/src/netlink/message/segment/common.rs @@ -0,0 +1,108 @@ +use alloc::vec::Vec; +use core::mem::size_of; + +use axerrno::AxResult; +use axio::{BufRead, Write}; +use bytemuck::bytes_of; + +use crate::netlink::message::{ + attr::Attribute, + result::ContinueRead, + segment::{SegmentBody, SegmentHeader}, +}; + +/// Common segment structure with body and attributes. +#[derive(Debug, Clone)] +pub struct SegmentCommon { + header: SegmentHeader, + body: Body, + attrs: Vec, +} + +impl SegmentCommon { + pub const HEADER_LEN: usize = size_of::(); + + /// Returns a reference to the segment header. + pub fn header(&self) -> &SegmentHeader { + &self.header + } + + /// Returns a mutable reference to the segment header. + pub fn header_mut(&mut self) -> &mut SegmentHeader { + &mut self.header + } + + /// Returns a reference to the segment body. + pub fn body(&self) -> &Body { + &self.body + } + + /// Returns a reference to the segment attributes. 
+ pub fn attrs(&self) -> &[Attr] { + &self.attrs + } +} + +impl SegmentCommon { + pub const BODY_LEN: usize = size_of::(); + + /// Creates a new segment with the given header, body, and attributes. + pub fn new(header: SegmentHeader, body: Body, attrs: Vec) -> Self { + let mut res = Self { + header, + body, + attrs, + }; + res.header.len = res.total_len() as u32; + res + } + + /// Reads a segment from the given header and reader. + pub fn read_from( + header: &SegmentHeader, + reader: &mut impl BufRead, + ) -> AxResult> { + let (body, remain_len) = match Body::read_from(header, reader)? { + ContinueRead::Parsed(parsed) => parsed, + ContinueRead::Skipped => return Ok(ContinueRead::Skipped), + ContinueRead::SkippedErr(err) => return Ok(ContinueRead::SkippedErr(err)), + }; + + let attrs = match Attr::read_all_from(reader, remain_len)? { + ContinueRead::Parsed(attrs) => attrs, + ContinueRead::Skipped => Vec::new(), + ContinueRead::SkippedErr(err) => return Ok(ContinueRead::SkippedErr(err)), + }; + + Ok(ContinueRead::Parsed(Self { + header: *header, + body, + attrs, + })) + } + + /// Writes the segment to the given writer. + pub fn write_to(&self, writer: &mut impl Write) -> AxResult { + writer.write_all(bytes_of(&self.header))?; + self.body.write_to(writer)?; + for attr in &self.attrs { + attr.write_to(writer)?; + } + Ok(()) + } + + /// Returns the total length of the segment, including header, body, and attributes. + pub fn total_len(&self) -> usize { + Self::HEADER_LEN + Self::BODY_LEN + self.attrs_len() + } +} + +impl SegmentCommon { + /// Returns the total length of the segment attributes, including padding. 
+ pub fn attrs_len(&self) -> usize { + self.attrs + .iter() + .map(|attr| attr.total_len_with_padding()) + .sum() + } +} diff --git a/modules/axnet/src/netlink/message/segment/header.rs b/modules/axnet/src/netlink/message/segment/header.rs new file mode 100644 index 0000000000..40e9314b61 --- /dev/null +++ b/modules/axnet/src/netlink/message/segment/header.rs @@ -0,0 +1,112 @@ +use core::mem::size_of; + +use axerrno::{AxResult, ax_err_type}; +use bitflags::bitflags; +use bytemuck::{Pod, Zeroable}; +use memory_addr::align_up; + +use crate::netlink::message::NLMSG_ALIGN; + +/// `nlmsghdr` in Linux. +/// +/// Reference: . +#[repr(C)] +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +pub struct SegmentHeader { + /// Length of the message, including the header + pub len: u32, + /// Type of message content + pub type_: u16, + /// Additional flags + pub flags: u16, + /// Sequence number + pub seq: u32, + /// Sending process port ID + pub pid: u32, +} + +impl SegmentHeader { + /// Returns the payload length (including padding). + pub fn padded_payload_len(&self) -> AxResult { + // Validate `self.len`. + let payload_len = (self.len as usize) + .checked_sub(size_of::()) + .ok_or_else(|| ax_err_type!(InvalidInput, "the message length is too small"))?; + + Ok(align_up(payload_len, NLMSG_ALIGN)) + } +} + +bitflags! { + /// Common flags used in [`SegmentHeader`]. + /// + /// Reference: . + pub struct SegHdrCommonFlags: u16 { + /// Indicates a request message + const REQUEST = 0x01; + /// Multipart message, terminated by NLMSG_DONE + const MULTI = 0x02; + /// Reply with an acknowledgment, with zero or an error code + const ACK = 0x04; + /// Echo this request + const ECHO = 0x08; + /// Dump was inconsistent due to sequence change + const DUMP_INTR = 0x10; + /// Dump was filtered as requested + const DUMP_FILTERED = 0x20; + } +} + +bitflags! { + /// Modifiers for GET requests. + /// + /// Reference: . 
+ pub struct GetRequestFlags: u16 { + /// Specify the tree root + const ROOT = 0x100; + /// Return all matching results + const MATCH = 0x200; + /// Atomic get request + const ATOMIC = 0x400; + /// Combination flag for root and match + const DUMP = Self::ROOT.bits() | Self::MATCH.bits(); + } +} + +bitflags! { + /// Modifiers for NEW requests. + /// + /// Reference: . + pub struct NewRequestFlags: u16 { + /// Override existing entries + const REPLACE = 0x100; + /// Do not modify if it exists + const EXCL = 0x200; + /// Create if it does not exist + const CREATE = 0x400; + /// Add to the end of the list + const APPEND = 0x800; + } +} + +bitflags! { + /// Modifiers for DELETE requests. + /// + /// Reference: . + pub struct DeleteRequestFlags: u16 { + /// Do not delete recursively + const NONREC = 0x100; + /// Delete multiple objects + const BULK = 0x200; + } +} + +bitflags! { + /// Flags for ACK messages. + /// + /// Reference: . + pub struct AckFlags: u16 { + const CAPPED = 0x100; + const ACK_TLVS = 0x200; + } +} diff --git a/modules/axnet/src/netlink/message/segment/mod.rs b/modules/axnet/src/netlink/message/segment/mod.rs new file mode 100644 index 0000000000..6f3ba1c58d --- /dev/null +++ b/modules/axnet/src/netlink/message/segment/mod.rs @@ -0,0 +1,110 @@ +mod ack; +mod common; +mod header; + +use core::mem::size_of; + +use axerrno::{AxError, AxResult}; +use axio::{BufRead, Write}; +use bytemuck::{Pod, bytes_of}; +use memory_addr::align_up; +use num_enum::TryFromPrimitive; + +pub use self::{ + ack::{DoneSegment, ErrorSegment}, + common::SegmentCommon, + header::*, +}; +use super::{ContinueRead, NLMSG_ALIGN}; +use crate::netlink::read_pod; + +pub trait SegmentBody: Sized + Copy + Clone { + // The actual message body should be `Self::CType`, + // but older versions of Linux use a legacy type (usually `CRtGenMsg` here). + // Reference: . + // FIXME: Verify whether the legacy type includes any types other than `CRtGenMsg`. 
+ type CLegacyType: Pod = Self::CType; + type CType: Pod + TryInto + From + From; + + fn read_from( + header: &SegmentHeader, + reader: &mut impl BufRead, + ) -> AxResult> { + let mut remaining_len = header.padded_payload_len()?; + + let (c_type, padding_len) = if remaining_len >= size_of::() { + let c_type = read_pod::(reader)?; + remaining_len -= size_of::(); + + (c_type, Self::padding_len()) + } else if remaining_len >= size_of::() { + let legacy = read_pod::(reader)?; + remaining_len -= size_of::(); + + (Self::CType::from(legacy), Self::legacy_padding_len()) + } else { + reader.consume(remaining_len); + return Ok(ContinueRead::SkippedErr(AxError::InvalidInput)); + }; + + let padding_len = padding_len.min(remaining_len); + reader.consume(padding_len); + remaining_len -= padding_len; + + match c_type.try_into() { + Ok(body) => Ok(ContinueRead::Parsed((body, remaining_len))), + Err(_err) => { + reader.consume(remaining_len); + Ok(ContinueRead::SkippedErr(AxError::InvalidInput)) + } + } + } + + fn write_to(&self, writer: &mut impl Write) -> AxResult { + let c_body = Self::CType::from(*self); + writer.write_all(bytes_of(&c_body))?; + + let padding_len = Self::padding_len(); + if padding_len > 0 { + writer.write_all(&[0u8; 8][..padding_len])?; + } + + Ok(()) + } + + fn padding_len() -> usize { + let payload_len = size_of::(); + align_up(payload_len, NLMSG_ALIGN) - payload_len + } + + fn legacy_padding_len() -> usize { + let payload_len = size_of::(); + align_up(payload_len, NLMSG_ALIGN) - payload_len + } +} + +#[repr(u16)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive)] +#[expect(clippy::upper_case_acronyms)] +pub enum SegmentType { + // Standard netlink message types + NOOP = 1, + ERROR = 2, + DONE = 3, + OVERRUN = 4, + + // protocol-level types + NEWLINK = 16, + DELLINK = 17, + GETLINK = 18, + SETLINK = 19, + + NEWADDR = 20, + DELADDR = 21, + GETADDR = 22, + + NEWROUTE = 24, + DELROUTE = 25, + GETROUTE = 26, + // TODO: The list is 
not exhaustive. +} diff --git a/modules/axnet/src/netlink/mod.rs b/modules/axnet/src/netlink/mod.rs new file mode 100644 index 0000000000..0e69180387 --- /dev/null +++ b/modules/axnet/src/netlink/mod.rs @@ -0,0 +1,201 @@ +mod addr; +mod message; +mod receiver; +mod route; +mod table; + +use core::{mem::size_of, task::Context}; + +use axerrno::{AxError, AxResult, LinuxError, ax_bail}; +use axio::prelude::*; +use axpoll::{IoEvents, Pollable}; +use bytemuck::Pod; +use enum_dispatch::enum_dispatch; +use spin::RwLock; + +pub use self::{ + addr::{GroupIdSet, NetlinkSocketAddr}, + route::RouteTransport, +}; +use crate::{ + RecvOptions, SendOptions, Shutdown, SocketAddrEx, SocketOps, + options::{Configurable, GetSocketOption, SetSocketOption}, +}; + +/// Reads a `Pod` value from a `BufRead` source. +pub(crate) fn read_pod(reader: &mut impl BufRead) -> AxResult { + let mut buf = alloc::vec![0u8; size_of::()]; + reader.read_exact(&mut buf)?; + Ok(bytemuck::pod_read_unaligned(&buf)) +} + +/// Trait for Netlink transport operations +#[enum_dispatch] +pub trait NetlinkTransportOps: Configurable + Pollable + Send + Sync { + fn bind(&self, local_addr: &mut NetlinkSocketAddr) -> AxResult; + fn send(&self, src: impl Read + IoBuf, port: u32, options: SendOptions) -> AxResult; + fn recv(&self, dst: impl Write + IoBufMut, options: RecvOptions<'_>) -> AxResult; + fn shutdown(&self, _how: Shutdown, _local_addr: Option<&NetlinkSocketAddr>) -> AxResult { + Ok(()) + } +} + +/// Enum for different Netlink transport implementations +#[enum_dispatch(Configurable, NetlinkTransportOps)] +pub enum NetlinkTransport { + Route(RouteTransport), + // TODO: more netlink transport support +} + +impl Pollable for NetlinkTransport { + fn poll(&self) -> IoEvents { + match self { + NetlinkTransport::Route(route) => route.poll(), + } + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + match self { + NetlinkTransport::Route(route) => route.register(context, events), + } + } +} + +/// 
Netlink socket implementation +pub struct NetlinkSocket { + transport: NetlinkTransport, + local_addr: RwLock>, + remote_addr: RwLock>, +} + +impl NetlinkSocket { + /// Creates a new Netlink socket with the specified transport. + pub fn new(transport: impl Into) -> Self { + Self { + transport: transport.into(), + local_addr: RwLock::new(None), + remote_addr: RwLock::new(Some(NetlinkSocketAddr::new_unspecified())), + } + } +} + +impl Configurable for NetlinkSocket { + fn get_option_inner(&self, opt: &mut GetSocketOption) -> AxResult { + self.transport.get_option_inner(opt) + } + + fn set_option_inner(&self, opt: SetSocketOption) -> AxResult { + self.transport.set_option_inner(opt) + } +} + +impl SocketOps for NetlinkSocket { + /// Binds the socket to a local Netlink address. + fn bind(&self, local_addr: SocketAddrEx) -> AxResult { + let mut local_addr = local_addr.into_netlink()?; + let mut guard = self.local_addr.write(); + if guard.is_some() { + ax_bail!(InvalidInput, "already bound"); + } + self.transport.bind(&mut local_addr)?; + *guard = Some(local_addr); + info!("Netlink socket bound to {:?}", local_addr); + Ok(()) + } + + /// Connects the socket to a remote Netlink address. + fn connect(&self, remote_addr: SocketAddrEx) -> AxResult { + // Ensures the socket is bound before connecting. + if self.local_addr.read().is_none() { + self.bind(SocketAddrEx::Netlink(NetlinkSocketAddr::new_unspecified()))?; + } + let remote_addr = remote_addr.into_netlink()?; + let mut guard = self.remote_addr.write(); + *guard = Some(remote_addr); + info!("Netlink socket connected to {:?}", remote_addr); + Ok(()) + } + + /// Sends data through the Netlink socket. + fn send(&self, src: impl Read + IoBuf, options: SendOptions) -> AxResult { + // Ensure the socket is bound before sending. 
+ if self.local_addr.read().is_none() { + self.bind(SocketAddrEx::Netlink(NetlinkSocketAddr::new_unspecified()))?; + } + let remote_addr = options.to.clone().map_or_else( + || { + // If no remote address is specified, use the connected address. + self.remote_addr + .read() + .clone() + .ok_or_else(|| AxError::from(LinuxError::EDESTADDRREQ)) + }, + |addr| addr.into_netlink(), + )?; + if !remote_addr.is_unspecified() { + ax_bail!( + NotConnected, + "sending netlink route messages to user space is not supported" + ); + } + if !options.cmsg.is_empty() { + ax_bail!( + InvalidInput, + "control messages are not supported for netlink sockets" + ); + } + + self.transport.send( + src, + self.local_addr.read().as_ref().unwrap().port(), + options, + ) + } + + /// Receives data from the Netlink socket. + fn recv(&self, dst: impl Write + IoBufMut, options: RecvOptions<'_>) -> AxResult { + self.transport.recv(dst, options) + } + + /// Gets the local address of the Netlink socket. + /// + /// Returns `NotConnected` while no local address has been stored. + fn local_addr(&self) -> AxResult { + // Wait for the lock like the other methods of this impl do; `try_read` + // would report `NotConnected` whenever a writer momentarily holds it. + let guard = self.local_addr.read(); + guard.map(SocketAddrEx::Netlink).ok_or(AxError::NotConnected) + } + + /// Gets the peer (remote) address of the Netlink socket. + /// + /// Returns `NotConnected` while no remote address is stored. + fn peer_addr(&self) -> AxResult { + // Wait for the lock like the other methods of this impl do; `try_read` + // would report `NotConnected` whenever a writer momentarily holds it. + let guard = self.remote_addr.read(); + guard.map(SocketAddrEx::Netlink).ok_or(AxError::NotConnected) + } + + /// Shuts down the Netlink socket. 
+ fn shutdown(&self, how: Shutdown) -> AxResult { + self.transport + .shutdown(how, self.local_addr.read().as_ref()) + } +} + +impl Pollable for NetlinkSocket { + fn poll(&self) -> IoEvents { + if self.local_addr.read().is_none() { + return IoEvents::empty(); + } + self.transport.poll() + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + self.transport.register(context, events) + } +} + +impl Drop for NetlinkSocket { + fn drop(&mut self) { + trace!("Dropping netlink socket"); + self.shutdown(Shutdown::Both).ok(); + } +} diff --git a/modules/axnet/src/netlink/receiver.rs b/modules/axnet/src/netlink/receiver.rs new file mode 100644 index 0000000000..efd05d59ca --- /dev/null +++ b/modules/axnet/src/netlink/receiver.rs @@ -0,0 +1,100 @@ +use alloc::{collections::VecDeque, sync::Arc}; + +use axerrno::{AxError, AxResult, LinuxError}; +use axpoll::PollSet; +use axsync::Mutex; + +/// Receiver for Netlink messages. +#[derive(Clone)] +pub struct MessageReceiver { + message_queue: Arc>>, + poller: Arc, +} + +/// Queue for Netlink messages. +pub(super) struct MessageQueue { + messages: VecDeque, + total_length: usize, + error: Option, +} + +impl MessageQueue { + /// Creates a new MessageQueue and its corresponding MessageReceiver. + pub(super) fn new_pair(poller: Arc) -> (Arc>, MessageReceiver) { + let queue = Arc::new(Mutex::new(Self { + messages: VecDeque::new(), + total_length: 0, + error: None, + })); + let receiver = MessageReceiver { + message_queue: queue.clone(), + poller: poller.clone(), + }; + (queue, receiver) + } + + /// Checks if the message queue is empty. + pub(super) fn is_empty(&self) -> bool { + self.messages.is_empty() + } +} + +/// Trait for messages that can be queued. +pub trait QueueableMessage { + fn total_len(&self) -> usize; +} + +impl MessageQueue { + /// Dequeues a message if the provided function indicates to do so. 
+ pub(super) fn dequeue_if(&mut self, f: F) -> AxResult + where + F: FnOnce(&Message, usize) -> AxResult<(bool, R)>, + { + if let Some(error) = self.error.take() { + return Err(error); + } + + let Some(message) = self.messages.front() else { + debug!("No message to dequeue"); + return Err(AxError::WouldBlock); + }; + + let length = message.total_len(); + let (should_pop, result) = f(message, length)?; + if should_pop { + self.messages.pop_front().unwrap(); + self.total_length -= length; + } + + Ok(result) + } + + /// Enqueues a new message into the queue. + #[must_use] + fn enqueue(&mut self, message: Message) -> bool { + let length = message.total_len(); + + if self.total_length.saturating_add(length) > crate::consts::NETLINK_DEFAULT_BUF_SIZE { + self.error = Some(AxError::from(LinuxError::ENOBUFS)); + return false; + } + + self.messages.push_back(message); + self.total_length += length; + + true + } +} + +impl MessageReceiver { + /// Enqueues a message into the receiver's message queue. + pub(super) fn enqueue_message(&self, message: Message) { + let is_ok = self.message_queue.lock().enqueue(message); + if is_ok { + trace!("Message enqueued successfully"); + self.poller.wake(); + } else { + warn!("Failed to enqueue message"); + } + } +} diff --git a/modules/axnet/src/netlink/route/handle.rs b/modules/axnet/src/netlink/route/handle.rs new file mode 100644 index 0000000000..cde2dcdbf4 --- /dev/null +++ b/modules/axnet/src/netlink/route/handle.rs @@ -0,0 +1,224 @@ +use alloc::{boxed::Box, ffi::CString, vec, vec::Vec}; +use core::num::{NonZero, NonZeroU32}; + +use axerrno::{AxError, AxResult, ax_bail}; + +use crate::{ + device::Device, + get_service, + netlink::{ + message::{ProtocolSegment, segment::*}, + route::{RouteTransport, message::*}, + table::{NETLINK_BIND_TABLE, ProtocolBindTable}, + }, +}; + +/// An helper function to access the route protocol bind table. 
+pub(crate) fn with_route_table( + f: impl FnOnce(&mut ProtocolBindTable) -> AxResult, +) -> AxResult { + let mut table = NETLINK_BIND_TABLE.route.write(); + f(&mut table) +} + +impl RouteTransport { + /// Handles a GetLink request segment. + pub fn get_link(request_segment: &LinkSegment) -> AxResult> { + let filter_by = FilterBy::from_request(request_segment)?; + + // Generate response segments based on the filter. + let mut response_segments: Vec = get_service() + .iter_devices() + .filter(|dev| match &filter_by { + FilterBy::Dump => true, + FilterBy::Index(index) => dev.get_index() == *index, + FilterBy::Name(name) => dev.name() == *name, + }) + .map(|dev| dev_to_new_link(request_segment.header(), dev)) + .map(RouteSegment::NewLink) + .collect(); + + let dump_all = matches!(filter_by, FilterBy::Dump); + + if !dump_all && response_segments.is_empty() { + ax_bail!(NoSuchDevice, "no matching link found"); + } + + finish_response(request_segment.header(), dump_all, &mut response_segments); + Ok(response_segments) + } + + /// Handles a GetAddr request segment. + pub fn get_addr(request_segment: &AddrSegment) -> AxResult> { + let dump_all = { + let flags = GetRequestFlags::from_bits_truncate(request_segment.header().flags); + flags.contains(GetRequestFlags::DUMP) + }; + if !dump_all { + ax_bail!(Unsupported, "GETADDR only supports dump requests"); + } + + let mut response_segments: Vec = get_service() + .iter_devices() + // Get_addr only support dump, so no filtering needed + .filter_map(|iface| dev_to_new_addr(request_segment.header(), iface)) + .map(RouteSegment::NewAddr) + .collect(); + + finish_response(request_segment.header(), dump_all, &mut response_segments); + + Ok(response_segments) + } + + /// Handles a route request segment. 
+ pub fn handle_request(&self, request: &RouteSegment, dst_port: u32) { + trace!("Handling request: {:?} for port {}", request, dst_port); + let response_segments = match request { + RouteSegment::GetLink(request_segment) => RouteTransport::get_link(request_segment), + RouteSegment::GetAddr(request_segment) => RouteTransport::get_addr(request_segment), + _ => Err(AxError::Unsupported), + }; + trace!("Response segments: {:?}", response_segments); + let response = match response_segments { + Ok(segments) => RouteMessage::new(segments), + Err(_) => { + // TODO: Future should build RouteSegment::Error(ErrorSegment::new_from_request(...)) and unicast it to `dst_port` instead of returning. + return; + } + }; + + if let Err(e) = with_route_table(|table| table.unicast(dst_port, response)) { + warn!( + "Failed to unicast netlink response to port {}: {:?}", + dst_port, e + ); + } + } +} + +/// Get a new link segment from a device. +fn dev_to_new_link(request_header: &SegmentHeader, dev: &Box) -> LinkSegment { + let header = SegmentHeader { + len: 0, + type_: SegmentType::NEWLINK as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + const AF_UNSPEC: u8 = 0; + let link_message = LinkSegmentBody { + family: AF_UNSPEC, + type_: dev.get_type(), + index: NonZero::new(dev.get_index()), + flags: dev.get_flags(), + }; + + let name = CString::new(dev.name()).unwrap_or_default(); + let attrs = vec![ + LinkAttr::Name(name), + LinkAttr::Mtu(crate::consts::STANDARD_MTU as u32), + ]; + + LinkSegment::new(header, link_message, attrs) +} + +/// Get a new address segment from a device. 
+fn dev_to_new_addr(request_header: &SegmentHeader, dev: &Box) -> Option { + let ipv4_addr = dev.ipv4_addr()?; + let prefix_len = dev.prefix_len()?; + + let header = SegmentHeader { + len: 0, + type_: SegmentType::NEWADDR as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + const AF_INET: u8 = 2; + let addr_message = AddrSegmentBody { + family: AF_INET as _, + prefix_len, + flags: AddrMessageFlags::PERMANENT, + scope: RtScope::HOST, + index: NonZeroU32::new(dev.get_index()), + }; + + let label = CString::new(dev.name()).unwrap_or_default(); + let attrs = vec![ + AddrAttr::Address(ipv4_addr.octets()), + AddrAttr::Label(label), + AddrAttr::Local(ipv4_addr.octets()), + ]; + + Some(AddrSegment::new(header, addr_message, attrs)) +} + +/// Finalizes the response segments. +pub fn finish_response( + request_header: &SegmentHeader, + dump_all: bool, + response_segments: &mut Vec, +) { + if !dump_all { + debug_assert_eq!( + response_segments.len(), + 1, + "non-dump response should have exactly one segment" + ); + return; + } + append_done_segment(request_header, response_segments); + add_multi_flag(response_segments); +} + +/// Appends a done segment as the last segment of the provided segments. +fn append_done_segment(request_header: &SegmentHeader, response_segments: &mut Vec) { + let done_segment = DoneSegment::new_from_request(request_header, None); + response_segments.push(RouteSegment::Done(done_segment)); +} + +/// Adds the `MULTI` flag to all segments in `segments`. +fn add_multi_flag(response_segments: &mut [RouteSegment]) { + for segment in response_segments.iter_mut() { + let header = segment.header_mut(); + let mut flags = SegHdrCommonFlags::from_bits_truncate(header.flags); + flags |= SegHdrCommonFlags::MULTI; + header.flags = flags.bits(); + } +} + +/// Filter criteria for GetLink requests. 
+enum FilterBy<'a> { + Index(u32), + Name(&'a str), + Dump, +} + +impl<'a> FilterBy<'a> { + /// Creates a FilterBy instance from a LinkSegment request. + fn from_request(segment: &'a LinkSegment) -> AxResult { + // Dump has the highest priority. + if GetRequestFlags::from_bits_truncate(segment.header().flags) + .contains(GetRequestFlags::DUMP) + { + return Ok(Self::Dump); + } + if let Some(required_index) = segment.body().index { + return Ok(Self::Index(required_index.get())); + } + let required_name = segment.attrs().iter().find_map(|attr| { + if let LinkAttr::Name(name) = attr { + name.to_str().ok() + } else { + None + } + }); + if let Some(required_name) = required_name { + return Ok(Self::Name(required_name)); + } + + Err(AxError::InvalidInput) + } +} diff --git a/modules/axnet/src/netlink/route/message.rs b/modules/axnet/src/netlink/route/message.rs new file mode 100644 index 0000000000..4246884b78 --- /dev/null +++ b/modules/axnet/src/netlink/route/message.rs @@ -0,0 +1,12 @@ +mod attr; +mod segment; + +pub use attr::{addr::AddrAttr, link::LinkAttr}; +pub use segment::{ + RouteSegment, + addr::{AddrMessageFlags, AddrSegment, AddrSegmentBody, RtScope}, + link::{LinkSegment, LinkSegmentBody}, +}; + +/// Route message type. +pub type RouteMessage = crate::netlink::message::Message; diff --git a/modules/axnet/src/netlink/route/message/attr/addr.rs b/modules/axnet/src/netlink/route/message/attr/addr.rs new file mode 100644 index 0000000000..a6b3d9b011 --- /dev/null +++ b/modules/axnet/src/netlink/route/message/attr/addr.rs @@ -0,0 +1,94 @@ +use alloc::{ffi::CString, vec::Vec}; + +use axerrno::AxResult; +use axio::BufRead; + +use crate::netlink::message::{ + attr::{Attribute, AttrHeader}, + result::ContinueRead, +}; + +/// Address-related attributes. +/// +/// Reference: . 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u16)] +#[expect(non_camel_case_types)] +#[expect(clippy::upper_case_acronyms)] +#[allow(dead_code)] +enum AddrAttrClass { + UNSPEC = 0, + ADDRESS = 1, + LOCAL = 2, + LABEL = 3, + BROADCAST = 4, + ANYCAST = 5, + CACHEINFO = 6, + MULTICAST = 7, + FLAGS = 8, + RT_PRIORITY = 9, + TARGET_NETNSID = 10, +} + +/// Address attribute. +#[derive(Debug, Clone)] +pub enum AddrAttr { + Address([u8; 4]), + Local([u8; 4]), + Label(CString), +} + +impl AddrAttr { + /// Returns the class of the address attribute. + fn class(&self) -> AddrAttrClass { + match self { + AddrAttr::Address(_) => AddrAttrClass::ADDRESS, + AddrAttr::Local(_) => AddrAttrClass::LOCAL, + AddrAttr::Label(_) => AddrAttrClass::LABEL, + } + } +} + +impl Attribute for AddrAttr { + fn type_(&self) -> u16 { + self.class() as u16 + } + + fn payload_as_bytes(&self) -> &[u8] { + match self { + AddrAttr::Address(address) => address, + AddrAttr::Local(local) => local, + AddrAttr::Label(label) => label.as_bytes_with_nul(), + } + } + + fn read_from(header: &AttrHeader, reader: &mut impl BufRead) -> AxResult> + where + Self: Sized, + { + let payload_len = header.payload_len(); + reader.consume(payload_len); + + // GETADDR only supports dump requests. These requests do not have any + // attributes. According to the Linux behavior, we should just ignore + // all the attributes. + + Ok(ContinueRead::Skipped) + } + + fn read_all_from( + reader: &mut impl BufRead, + total_len: usize, + ) -> AxResult>> + where + Self: Sized, + { + reader.consume(total_len); + + // GETADDR only supports dump requests. These requests do not have any + // attributes. According to the Linux behavior, we should just ignore + // all the attributes. 
+ + Ok(ContinueRead::Skipped) + } +} diff --git a/modules/axnet/src/netlink/route/message/attr/link.rs b/modules/axnet/src/netlink/route/message/attr/link.rs new file mode 100644 index 0000000000..3fbe72b468 --- /dev/null +++ b/modules/axnet/src/netlink/route/message/attr/link.rs @@ -0,0 +1,209 @@ +use alloc::ffi::CString; +use core::mem::size_of; + +use axerrno::{AxError, AxResult}; +use axio::BufRead; +use bitflags::bitflags; +use bytemuck::{Pod, Zeroable, bytes_of}; +use num_enum::TryFromPrimitive; + +use super::IFNAME_SIZE; +use crate::netlink::{ + message::{ + attr::{AttrHeader, Attribute}, + result::ContinueRead, + }, + read_pod, +}; + +/// Link-level attributes. +/// +/// Reference: . +#[derive(Debug, Clone, Copy, PartialEq, Eq, TryFromPrimitive)] +#[repr(u16)] +#[expect(non_camel_case_types)] +#[expect(clippy::upper_case_acronyms)] +enum LinkAttrClass { + UNSPEC = 0, + ADDRESS = 1, + BROADCAST = 2, + IFNAME = 3, + MTU = 4, + LINK = 5, + QDISC = 6, + STATS = 7, + COST = 8, + PRIORITY = 9, + MASTER = 10, + /// Wireless Extension event + WIRELESS = 11, + /// Protocol specific information for a link + PROTINFO = 12, + TXQLEN = 13, + MAP = 14, + WEIGHT = 15, + OPERSTATE = 16, + LINKMODE = 17, + LINKINFO = 18, + NET_NS_PID = 19, + IFALIAS = 20, + /// Number of VFs if device is SR-IOV PF + NUM_VF = 21, + VFINFO_LIST = 22, + STATS64 = 23, + VF_PORTS = 24, + PORT_SELF = 25, + AF_SPEC = 26, + /// Group the device belongs to + GROUP = 27, + NET_NS_FD = 28, + /// Extended info mask, VFs, etc. 
+ EXT_MASK = 29, + /// Promiscuity count: > 0 means acts PROMISC + PROMISCUITY = 30, + NUM_TX_QUEUES = 31, + NUM_RX_QUEUES = 32, + CARRIER = 33, + PHYS_PORT_ID = 34, + CARRIER_CHANGES = 35, + PHYS_SWITCH_ID = 36, + LINK_NETNSID = 37, + PHYS_PORT_NAME = 38, + PROTO_DOWN = 39, + GSO_MAX_SEGS = 40, + GSO_MAX_SIZE = 41, + PAD = 42, + XDP = 43, + EVENT = 44, + NEW_NETNSID = 45, + IF_NETNSID = 46, + CARRIER_UP_COUNT = 47, + CARRIER_DOWN_COUNT = 48, + NEW_IFINDEX = 49, + MIN_MTU = 50, + MAX_MTU = 51, + PROP_LIST = 52, + /// Alternative ifname + ALT_IFNAME = 53, + PERM_ADDRESS = 54, + PROTO_DOWN_REASON = 55, + PARENT_DEV_NAME = 56, + PARENT_DEV_BUS_NAME = 57, +} + +#[derive(Debug, Clone)] +pub enum LinkAttr { + Name(CString), + Mtu(u32), + TxqLen(u32), + LinkMode(u8), + ExtMask(RtExtFilter), +} + +impl LinkAttr { + fn class(&self) -> LinkAttrClass { + match self { + LinkAttr::Name(_) => LinkAttrClass::IFNAME, + LinkAttr::Mtu(_) => LinkAttrClass::MTU, + LinkAttr::TxqLen(_) => LinkAttrClass::TXQLEN, + LinkAttr::LinkMode(_) => LinkAttrClass::LINKMODE, + LinkAttr::ExtMask(_) => LinkAttrClass::EXT_MASK, + } + } +} + +impl Attribute for LinkAttr { + fn type_(&self) -> u16 { + self.class() as u16 + } + + fn payload_as_bytes(&self) -> &[u8] { + match self { + LinkAttr::Name(name) => name.as_bytes_with_nul(), + LinkAttr::Mtu(mtu) => bytes_of(mtu), + LinkAttr::TxqLen(txq_len) => bytes_of(txq_len), + LinkAttr::LinkMode(link_mode) => bytes_of(link_mode), + LinkAttr::ExtMask(ext_filter) => bytes_of(ext_filter), + } + } + + fn read_from(header: &AttrHeader, reader: &mut impl BufRead) -> AxResult> + where + Self: Sized, + { + let payload_len = header.payload_len(); + + // TODO: Currently, `IS_NET_BYTEORDER_MASK` and `IS_NESTED_MASK` are ignored. + let Ok(class) = LinkAttrClass::try_from(header.type_()) else { + // Unknown attributes should be ignored. + // Reference: . 
+ reader.consume(payload_len); + return Ok(ContinueRead::Skipped); + }; + + let res = match (class, payload_len) { + (LinkAttrClass::IFNAME, 1..=IFNAME_SIZE) => { + let mut buf = alloc::vec![0u8; payload_len]; + reader.read_exact(&mut buf)?; + let nul_pos = buf.iter().position(|&b| b == 0).unwrap_or(buf.len()); + let name = CString::new(&buf[..nul_pos]).map_err(|_| AxError::InvalidInput)?; + if name.as_bytes().len() == IFNAME_SIZE { + return Ok(ContinueRead::SkippedErr(AxError::OutOfRange)); + } + Self::Name(name) + } + (LinkAttrClass::MTU, 4) => Self::Mtu(read_pod::(reader)?), + (LinkAttrClass::TXQLEN, 4) => Self::TxqLen(read_pod::(reader)?), + (LinkAttrClass::LINKMODE, 1) => Self::LinkMode(read_pod::(reader)?), + (LinkAttrClass::EXT_MASK, 4) => { + const { assert!(size_of::() == 4) }; + Self::ExtMask(read_pod::(reader)?) + } + + ( + LinkAttrClass::IFNAME + | LinkAttrClass::MTU + | LinkAttrClass::TXQLEN + | LinkAttrClass::LINKMODE + | LinkAttrClass::EXT_MASK, + _, + ) => { + warn!("link attribute `{:?}` contains invalid payload", class); + reader.consume(payload_len); + return Ok(ContinueRead::SkippedErr( + if class == LinkAttrClass::IFNAME { + AxError::OutOfRange + } else { + AxError::InvalidInput + }, + )); + } + + (..) => { + warn!("link attribute `{:?}` is not supported", class); + reader.consume(payload_len); + return Ok(ContinueRead::Skipped); + } + }; + + Ok(ContinueRead::Parsed(res)) + } +} + +bitflags! { + /// New extended info filters for [`NlLinkAttr::ExtMask`]. + /// + /// Reference: . 
+ #[repr(C)] + #[derive(Clone, Copy, PartialEq, Eq, Debug, Pod, Zeroable)] + pub struct RtExtFilter: u32 { + const VF = 1 << 0; + const BRVLAN = 1 << 1; + const BRVLAN_COMPRESSED = 1 << 2; + const SKIP_STATS = 1 << 3; + const MRP = 1 << 4; + const CFM_CONFIG = 1 << 5; + const CFM_STATUS = 1 << 6; + const MST = 1 << 7; + } +} diff --git a/modules/axnet/src/netlink/route/message/attr/mod.rs b/modules/axnet/src/netlink/route/message/attr/mod.rs new file mode 100644 index 0000000000..7358c2a536 --- /dev/null +++ b/modules/axnet/src/netlink/route/message/attr/mod.rs @@ -0,0 +1,4 @@ +pub mod addr; +pub mod link; + +const IFNAME_SIZE: usize = 16; diff --git a/modules/axnet/src/netlink/route/message/segment/addr.rs b/modules/axnet/src/netlink/route/message/segment/addr.rs new file mode 100644 index 0000000000..1e0c2808d5 --- /dev/null +++ b/modules/axnet/src/netlink/route/message/segment/addr.rs @@ -0,0 +1,107 @@ +use core::num::NonZeroU32; + +use axerrno::AxError; +use bitflags::bitflags; +use bytemuck::{Pod, Zeroable}; +use num_enum::TryFromPrimitive; + +use super::legacy::CRtGenMsg; +use crate::netlink::{ + message::segment::{SegmentBody, SegmentCommon}, + route::message::AddrAttr, +}; + +/// Address segment type. +pub type AddrSegment = SegmentCommon; + +impl SegmentBody for AddrSegmentBody { + type CLegacyType = CRtGenMsg; + type CType = CIfaddrMsg; +} + +/// `ifaddrmsg` in Linux. +#[repr(C)] +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +pub struct CIfaddrMsg { + pub family: u8, + pub prefix_len: u8, + pub flags: u8, + pub scope: u8, + pub index: u32, +} + +/// Address segment body. 
+#[derive(Debug, Clone, Copy)] +pub struct AddrSegmentBody { + pub family: i32, + pub prefix_len: u8, + pub flags: AddrMessageFlags, + pub scope: RtScope, + pub index: Option, +} + +impl TryFrom for AddrSegmentBody { + type Error = AxError; + + fn try_from(value: CIfaddrMsg) -> Result { + let flags = AddrMessageFlags::from_bits_truncate(value.flags as u32); + let scope = RtScope::try_from(value.scope).map_err(|_| AxError::InvalidInput)?; + let index = NonZeroU32::new(value.index); + + Ok(Self { + family: value.family as i32, + prefix_len: value.prefix_len, + flags, + scope, + index, + }) + } +} + +impl From for CIfaddrMsg { + fn from(value: AddrSegmentBody) -> Self { + let index = if let Some(index) = value.index { + index.get() + } else { + 0 + }; + CIfaddrMsg { + family: value.family as u8, + prefix_len: value.prefix_len, + flags: value.flags.bits() as u8, + scope: value.scope as _, + index, + } + } +} + +bitflags! { + /// Flags for address messages. + #[derive(Debug, Clone, Copy)] + pub struct AddrMessageFlags: u32 { + const SECONDARY = 0x01; + const NODAD = 0x02; + const OPTIMISTIC = 0x04; + const DADFAILED = 0x08; + const HOMEADDRESS = 0x10; + const DEPRECATED = 0x20; + const TENTATIVE = 0x40; + const PERMANENT = 0x80; + const MANAGETEMPADDR = 0x100; + const NOPREFIXROUTE = 0x200; + const MCAUTOJOIN = 0x400; + const STABLE_PRIVACY = 0x800; + } +} + +/// Route scope. +#[repr(u8)] +#[derive(Debug, Clone, Copy, TryFromPrimitive)] +#[allow(non_camel_case_types)] +pub enum RtScope { + UNIVERSE = 0, + SITE = 200, + LINK = 253, + HOST = 254, + NOWHERE = 255, +} diff --git a/modules/axnet/src/netlink/route/message/segment/legacy.rs b/modules/axnet/src/netlink/route/message/segment/legacy.rs new file mode 100644 index 0000000000..6a31f95abc --- /dev/null +++ b/modules/axnet/src/netlink/route/message/segment/legacy.rs @@ -0,0 +1,35 @@ +use bytemuck::{Pod, Zeroable}; + +use super::{addr::CIfaddrMsg, link::CIfinfoMsg}; + +/// `rtgenmsg` in Linux. 
+#[repr(C)] +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +pub struct CRtGenMsg { + pub family: u8, +} + +impl From for CIfinfoMsg { + fn from(value: CRtGenMsg) -> Self { + Self { + family: value.family, + _pad: 0, + type_: 0, + index: 0, + flags: 0, + change: 0, + } + } +} + +impl From for CIfaddrMsg { + fn from(value: CRtGenMsg) -> Self { + Self { + family: value.family, + prefix_len: 0, + flags: 0, + scope: 0, + index: 0, + } + } +} diff --git a/modules/axnet/src/netlink/route/message/segment/link.rs b/modules/axnet/src/netlink/route/message/segment/link.rs new file mode 100644 index 0000000000..8264456bc0 --- /dev/null +++ b/modules/axnet/src/netlink/route/message/segment/link.rs @@ -0,0 +1,79 @@ +use core::num::NonZeroU32; + +use axerrno::AxError; +use bytemuck::{Pod, Zeroable}; + +use super::legacy::CRtGenMsg; +use crate::{ + device::{DeviceFlags, DeviceType}, + netlink::{ + message::segment::{SegmentBody, SegmentCommon}, + route::message::LinkAttr, + }, +}; + +pub type LinkSegment = SegmentCommon; + +impl SegmentBody for LinkSegmentBody { + type CLegacyType = CRtGenMsg; + type CType = CIfinfoMsg; +} +/// `ifinfomsg` in Linux. +/// +/// Reference: . +#[repr(C)] +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +pub struct CIfinfoMsg { + /// AF_UNSPEC + pub family: u8, + /// Padding byte + pub _pad: u8, + /// Device type + pub type_: u16, + /// Interface index + pub index: u32, + /// Device flags + pub flags: u32, + /// Change mask + pub change: u32, +} + +/// Link segment body. 
+#[derive(Debug, Clone, Copy)] +pub struct LinkSegmentBody { + pub family: u8, + pub type_: DeviceType, + pub index: Option, + pub flags: DeviceFlags, +} + +impl TryFrom for LinkSegmentBody { + type Error = AxError; + + fn try_from(value: CIfinfoMsg) -> Result { + let family = value.family; + let type_ = DeviceType::try_from(value.type_).map_err(|_| AxError::InvalidInput)?; + let index = NonZeroU32::new(value.index); + let flags = DeviceFlags::from_bits_truncate(value.flags); + + Ok(Self { + family, + type_, + index, + flags, + }) + } +} + +impl From for CIfinfoMsg { + fn from(value: LinkSegmentBody) -> Self { + CIfinfoMsg { + family: value.family, + _pad: 0, + type_: value.type_ as _, + index: value.index.map(NonZeroU32::get).unwrap_or(0), + flags: value.flags.bits(), + change: 0, + } + } +} diff --git a/modules/axnet/src/netlink/route/message/segment/mod.rs b/modules/axnet/src/netlink/route/message/segment/mod.rs new file mode 100644 index 0000000000..1d3969918c --- /dev/null +++ b/modules/axnet/src/netlink/route/message/segment/mod.rs @@ -0,0 +1,82 @@ +pub mod addr; +mod legacy; +pub mod link; + +use axerrno::{AxError, AxResult}; +use axio::{BufRead, Write}; + +use self::{addr::AddrSegment, link::LinkSegment}; +use crate::netlink::{ + message::{ + ProtocolSegment, + result::ContinueRead, + segment::{DoneSegment, ErrorSegment, SegmentHeader, SegmentType}, + }, + read_pod, +}; + +/// Routing segment enumeration. 
+#[derive(Debug, Clone)] +pub enum RouteSegment { + NewLink(LinkSegment), + GetLink(LinkSegment), + NewAddr(AddrSegment), + GetAddr(AddrSegment), + Done(DoneSegment), + Error(ErrorSegment), +} + +impl ProtocolSegment for RouteSegment { + fn header(&self) -> &SegmentHeader { + match self { + RouteSegment::NewLink(s) | RouteSegment::GetLink(s) => s.header(), + RouteSegment::NewAddr(s) | RouteSegment::GetAddr(s) => s.header(), + RouteSegment::Done(s) => s.header(), + RouteSegment::Error(s) => s.header(), + } + } + + fn header_mut(&mut self) -> &mut SegmentHeader { + match self { + RouteSegment::NewLink(s) | RouteSegment::GetLink(s) => s.header_mut(), + RouteSegment::NewAddr(s) | RouteSegment::GetAddr(s) => s.header_mut(), + RouteSegment::Done(s) => s.header_mut(), + RouteSegment::Error(s) => s.header_mut(), + } + } + + fn read_from(reader: &mut impl BufRead) -> AxResult> { + let header = read_pod::(reader)?; + + let segment = match SegmentType::try_from(header.type_) { + Ok(SegmentType::GETLINK) => { + LinkSegment::read_from(&header, reader)?.map(RouteSegment::GetLink) + } + Ok(SegmentType::GETADDR) => { + AddrSegment::read_from(&header, reader)?.map(RouteSegment::GetAddr) + } + _ => { + let payload_len = header.padded_payload_len()?; + reader.consume(payload_len); + ContinueRead::skipped_with_error( + AxError::Unsupported, + "the segment type is not supported", + ) + } + }; + + Ok(segment.map_err(|error| ErrorSegment::new_from_request(&header, Some(error)))) + } + + fn write_to(&self, writer: &mut impl Write) -> AxResult { + match self { + RouteSegment::NewLink(s) => s.write_to(writer), + RouteSegment::NewAddr(s) => s.write_to(writer), + RouteSegment::Done(s) => s.write_to(writer), + RouteSegment::Error(s) => s.write_to(writer), + RouteSegment::GetAddr(_) | RouteSegment::GetLink(_) => { + unreachable!("kernel should not write get requests to user space"); + } + } + } +} diff --git a/modules/axnet/src/netlink/route/mod.rs b/modules/axnet/src/netlink/route/mod.rs new 
file mode 100644 index 0000000000..7e53f0451e --- /dev/null +++ b/modules/axnet/src/netlink/route/mod.rs @@ -0,0 +1,150 @@ +pub mod handle; +pub mod message; + +use alloc::sync::Arc; +use core::task::Context; + +use axerrno::{AxError, AxResult, LinuxError}; +use axio::{BufReader, prelude::*}; +use axpoll::{IoEvents, PollSet, Pollable}; +use axsync::Mutex; +use handle::with_route_table; +use message::{RouteMessage, RouteSegment}; + +use crate::{ + RecvOptions, SendOptions, + general::GeneralOptions, + netlink::{ + NetlinkTransportOps, + addr::{GroupIdSet, NetlinkSocketAddr}, + message::ProtocolSegment, + receiver::{MessageQueue, MessageReceiver}, + }, + options::{Configurable, GetSocketOption, SetSocketOption}, +}; + +/// Netlink transport implementation for routing messages +pub struct RouteTransport { + general: GeneralOptions, + poller: Arc, + message_queue: Arc>>, + receiver: MessageReceiver, + groups: GroupIdSet, +} + +impl RouteTransport { + /// Creates a new RouteTransport instance. + pub fn new() -> Self { + let poller = Arc::new(PollSet::new()); + let (message_queue, receiver) = MessageQueue::::new_pair(poller.clone()); + RouteTransport { + general: GeneralOptions::default(), + poller, + message_queue, + receiver, + groups: GroupIdSet::new_empty(), + } + } +} + +impl Default for RouteTransport { + fn default() -> Self { + Self::new() + } +} + +impl Configurable for RouteTransport { + fn get_option_inner(&self, opt: &mut GetSocketOption) -> AxResult { + self.general.get_option_inner(opt) + } + + fn set_option_inner(&self, opt: SetSocketOption) -> AxResult { + self.general.set_option_inner(opt) + } +} + +impl NetlinkTransportOps for RouteTransport { + /// Binds the RouteTransport to the specified local address. + fn bind(&self, local_addr: &mut NetlinkSocketAddr) -> AxResult { + local_addr.add_groups(self.groups); + with_route_table(|table| table.bind(local_addr, self.receiver.clone())) + } + + /// Sends a RouteMessage to the specified port. 
+ fn send(&self, src: impl Read + IoBuf, port: u32, _options: SendOptions) -> AxResult { + use crate::netlink::message::result::ContinueRead; + + let initial_remaining = src.remaining(); + let mut reader = BufReader::new(src); + loop { + let mut segment = match RouteSegment::read_from(&mut reader)? { + ContinueRead::Parsed(seg) => seg, + ContinueRead::Skipped => continue, + ContinueRead::SkippedErr(_error_segment) => { + // TODO: Should unicast `_error_segment` as NLMSG_ERROR with original seq/pid to `port` instead of silently continuing. + continue; + } + }; + let header = segment.header_mut(); + // Set the pid to the sender's port if it's zero + if header.pid == 0 { + header.pid = port + } + self.handle_request(&segment, port); + return Ok(initial_remaining); + } + } + + /// Receives a RouteMessage from the message queue. + fn recv( + &self, + mut dst: impl Write + IoBufMut, + mut options: RecvOptions<'_>, + ) -> AxResult { + self.general.recv_poller(self, || { + let mut message_queue = self.message_queue.lock(); + message_queue.dequeue_if(|msg, len| { + if dst.remaining_mut() < len { + return Err(AxError::from(LinuxError::ENOBUFS)); + } + trace!("recv message {:?}", msg); + if let Some(from) = options.from.as_mut() { + **from = crate::SocketAddrEx::Netlink(NetlinkSocketAddr::new_unspecified()); + } + + msg.write_to(&mut dst)?; + Ok((true, len)) + }) + }) + } + + /// Shuts down the RouteTransport, removing any bindings. 
+ fn shutdown(&self, _how: crate::Shutdown, local_addr: Option<&NetlinkSocketAddr>) -> AxResult { + with_route_table(|table| { + if let Some(addr) = local_addr { + table.unicast_sockets.remove(&addr.port()); + + for group_id in addr.groups().ids_iter() { + let group = &mut table.multicast_groups[group_id as usize]; + group.remove_member(addr.port()); + } + } + Ok(()) + }) + } +} + +impl Pollable for RouteTransport { + fn poll(&self) -> IoEvents { + let mut events = IoEvents::OUT; + let message_queue = self.message_queue.lock(); + events.set(IoEvents::IN, !message_queue.is_empty()); + events + } + + fn register(&self, context: &mut Context<'_>, events: IoEvents) { + if events.contains(IoEvents::IN) { + self.poller.register(context.waker()); + } + } +} diff --git a/modules/axnet/src/netlink/table.rs b/modules/axnet/src/netlink/table.rs new file mode 100644 index 0000000000..1d5838b45e --- /dev/null +++ b/modules/axnet/src/netlink/table.rs @@ -0,0 +1,202 @@ +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, +}; + +use axerrno::{AxError, AxResult}; +use axsync::Mutex; +use axtask::current; +use lazy_static::lazy_static; +use rand::{RngCore, SeedableRng, rngs::SmallRng}; +use spin::RwLock; + +use crate::netlink::{ + addr::{GroupIdSet, NetlinkSocketAddr}, + receiver::{MessageReceiver, QueueableMessage}, + route::message::RouteMessage, +}; + +const MAX_GROUPS: u32 = 32; + +lazy_static! { + /// Global Netlink bind table. + pub static ref NETLINK_BIND_TABLE: NetlinkBindTable = NetlinkBindTable::new(); + /// Random number generator for assigning port numbers. + static ref RANDOM: Random = Random::new(); +} + +const RANDOM_SEED: &[u8; 32] = b"0123456789abcdef0123456789abcdef"; + +/// A simple random number generator wrapper. +struct Random { + rng: Mutex, +} + +impl Random { + pub fn new() -> Self { + Self { + rng: Mutex::new(SmallRng::from_seed(*RANDOM_SEED)), + } + } + + /// Generates a random u32 number. 
+ pub fn gen_u32(&self) -> u32 { + let mut rng = self.rng.lock(); + rng.next_u32() + } +} + +/// Netlink protocol bind table. +pub struct NetlinkBindTable { + pub route: RwLock>, + // TODO: more protocol bind tables +} + +impl NetlinkBindTable { + pub fn new() -> Self { + Self { + route: RwLock::new(ProtocolBindTable::new()), + } + } +} + +impl Default for NetlinkBindTable { + fn default() -> Self { + Self::new() + } +} + +/// A protocol-specific bind table for Netlink sockets. +pub struct ProtocolBindTable { + pub unicast_sockets: BTreeMap>, + pub multicast_groups: Box<[MulticastGroup]>, +} + +impl ProtocolBindTable { + /// Creates a new protocol bind table. + pub fn new() -> Self { + let multicast_groups = (0u32..MAX_GROUPS).map(|_| MulticastGroup::new()).collect(); + Self { + unicast_sockets: BTreeMap::new(), + multicast_groups, + } + } +} + +impl Default for ProtocolBindTable { + fn default() -> Self { + Self::new() + } +} + +impl ProtocolBindTable { + + /// Binds a Netlink socket to the specified address. + pub fn bind( + &mut self, + addr: &mut NetlinkSocketAddr, + receiver: MessageReceiver, + ) -> AxResult { + let port = if addr.port() != 0 { + addr.port() + } else { + let mut random_port = current().id().as_u64() as u32; + while random_port == 0 || self.unicast_sockets.contains_key(&random_port) { + random_port = RANDOM.gen_u32(); + } + random_port + }; + addr.set_port(port); + + if self.unicast_sockets.contains_key(&port) { + return Err(AxError::AlreadyExists); + } + + info!("Binding netlink socket to port {}", port); + self.unicast_sockets.insert(port, receiver); + + for group_id in addr.groups().ids_iter() { + let group = &mut self.multicast_groups[group_id as usize]; + group.add_member(port); + } + Ok(()) + } + + /// Sends a Message to the specified port. 
/// A netlink multicast group.
///
/// A group can contain multiple sockets,
/// each identified by its bound port number.
#[derive(Debug, Default)]
pub struct MulticastGroup {
    /// Port numbers of the member sockets, kept sorted and free of duplicates.
    members: BTreeSet<u32>,
}

impl MulticastGroup {
    /// Creates a new multicast group with no members.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a new member to the multicast group.
    ///
    /// Adding a port that is already a member has no effect.
    pub fn add_member(&mut self, port_num: u32) {
        self.members.insert(port_num);
    }

    /// Removes a member from the multicast group.
    ///
    /// Removing a port that is not a member has no effect.
    pub fn remove_member(&mut self, port_num: u32) {
        self.members.remove(&port_num);
    }

    /// Returns an iterator over all member port numbers in this group,
    /// in ascending order.
    pub fn members(&self) -> impl Iterator<Item = &u32> {
        self.members.iter()
    }
}
+ #[cfg(feature = "netlink")] + SocketAddrEx::Netlink(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), } } @@ -56,6 +64,19 @@ impl SocketAddrEx { SocketAddrEx::Ip(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), SocketAddrEx::Unix(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), SocketAddrEx::Vsock(addr) => Ok(addr), + #[cfg(feature = "netlink")] + SocketAddrEx::Netlink(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), + } + } + + #[cfg(feature = "netlink")] + pub fn into_netlink(self) -> AxResult { + match self { + SocketAddrEx::Ip(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), + SocketAddrEx::Unix(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), + #[cfg(feature = "vsock")] + SocketAddrEx::Vsock(_) => Err(AxError::from(LinuxError::EAFNOSUPPORT)), + SocketAddrEx::Netlink(addr) => Ok(addr), } } } @@ -170,6 +191,8 @@ pub enum Socket { Unix(UnixSocket), #[cfg(feature = "vsock")] Vsock(VsockSocket), + #[cfg(feature = "netlink")] + Netlink(NetlinkSocket), } impl Pollable for Socket { @@ -180,6 +203,8 @@ impl Pollable for Socket { Socket::Unix(unix) => unix.poll(), #[cfg(feature = "vsock")] Socket::Vsock(vsock) => vsock.poll(), + #[cfg(feature = "netlink")] + Socket::Netlink(netlink) => netlink.poll(), } } @@ -190,6 +215,8 @@ impl Pollable for Socket { Socket::Unix(unix) => unix.register(context, events), #[cfg(feature = "vsock")] Socket::Vsock(vsock) => vsock.register(context, events), + #[cfg(feature = "netlink")] + Socket::Netlink(netlink) => netlink.register(context, events), } } } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 20da71dd9a..58ed10fc43 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,6 +1,6 @@ [toolchain] profile = "minimal" -channel = "nightly-2025-12-12" +channel = "nightly-2026-02-25" components = ["rust-src", "llvm-tools", "rustfmt", "clippy"] targets = [ "x86_64-unknown-none",