diff --git a/Cargo.lock b/Cargo.lock index eaaf391e9..99bfc69a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -187,17 +187,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "async-recursion" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.79", -] - [[package]] name = "async-signal" version = "0.2.8" @@ -2309,7 +2298,6 @@ dependencies = [ "async-io", "async-lock", "async-process", - "async-recursion", "async-task", "async-trait", "blocking", diff --git a/zbus/Cargo.toml b/zbus/Cargo.toml index f6760d275..01fee9ca7 100644 --- a/zbus/Cargo.toml +++ b/zbus/Cargo.toml @@ -113,9 +113,6 @@ nix = { version = "0.29", default-features = false, features = [ # Cargo doesn't provide a way to do that for only specific target OS: https://github.com/rust-lang/cargo/issues/1197. async-process = "2.2.2" -[target.'cfg(any(target_os = "macos", windows))'.dependencies] -async-recursion = "1.1.1" - [dev-dependencies] zbus_xml = { path = "../zbus_xml", version = "4.0.0" } doc-comment = "0.3.3" diff --git a/zbus/src/address/address_list.rs b/zbus/src/address/address_list.rs new file mode 100644 index 000000000..8d7d0f8f0 --- /dev/null +++ b/zbus/src/address/address_list.rs @@ -0,0 +1,81 @@ +use std::{borrow::Cow, fmt}; + +use super::{Address, Error, Result, ToAddresses}; + +/// A bus address list. +/// +/// D-Bus addresses are `;`-separated. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct AddressList<'a> { + addr: Cow<'a, str>, +} + +impl<'a> ToAddresses<'a> for AddressList<'a> { + type Iter = AddressListIter<'a>; + + /// Get an iterator over the D-Bus addresses. + fn to_addresses(&'a self) -> Self::Iter { + AddressListIter::new(self) + } +} + +impl<'a> Iterator for AddressListIter<'a> { + type Item = Result>; + + fn next(&mut self) -> Option { + if self.next_index >= self.data.len() { + return None; + } + + let mut addr = &self.data[self.next_index..]; + if let Some(end) = addr.find(';') { + addr = &addr[..end]; + self.next_index += end + 1; + } else { + self.next_index = self.data.len(); + } + + Some(Address::try_from(addr)) + } +} + +/// An iterator of D-Bus addresses. +pub struct AddressListIter<'a> { + data: &'a str, + next_index: usize, +} + +impl<'a> AddressListIter<'a> { + fn new(list: &'a AddressList<'_>) -> Self { + Self { + data: list.addr.as_ref(), + next_index: 0, + } + } +} + +impl<'a> TryFrom for AddressList<'a> { + type Error = Error; + + fn try_from(value: String) -> Result { + Ok(Self { + addr: Cow::Owned(value), + }) + } +} + +impl<'a> TryFrom<&'a str> for AddressList<'a> { + type Error = Error; + + fn try_from(value: &'a str) -> Result { + Ok(Self { + addr: Cow::Borrowed(value), + }) + } +} + +impl fmt::Display for AddressList<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.addr) + } +} diff --git a/zbus/src/address/mod.rs b/zbus/src/address/mod.rs index 40707f48e..c50eb056b 100644 --- a/zbus/src/address/mod.rs +++ b/zbus/src/address/mod.rs @@ -3,475 +3,347 @@ //! Server addresses consist of a transport name followed by a colon, and then an optional, //! comma-separated list of keys and values in the form key=value. //! +//! # Miscellaneous and caveats on D-Bus addresses +//! +//! * Assumes values are UTF-8 encoded. +//! +//! * Duplicated keys are accepted, last pair wins. +//! +//! * Assumes that empty `key=val` is accepted, so `transport:,,guid=...` is valid. +//! +//! * Allows key only, so `transport:foo,bar` is ok. +//! +//! * Accept unknown keys and transports. +//! //! See also: //! //! * [Server addresses] in the D-Bus specification. //! //! [Server addresses]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses -pub mod transport; +use std::{borrow::Cow, env, fmt}; -use crate::{Error, Guid, OwnedGuid, Result}; #[cfg(all(unix, not(target_os = "macos")))] use nix::unistd::Uid; -use std::{collections::HashMap, env, str::FromStr}; -use std::fmt::{Display, Formatter}; +pub mod transport; -use self::transport::Stream; -pub use self::transport::Transport; +mod address_list; +pub use address_list::{AddressList, AddressListIter}; -/// A bus address. -#[derive(Clone, Debug, PartialEq, Eq)] -#[non_exhaustive] -pub struct Address { - guid: Option, - transport: Transport, +mod percent; +pub use percent::*; + +#[cfg(test)] +mod tests; + +/// Error returned when an address is invalid. +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum Error { + UnknownTransport, + MissingTransport, + Encoding(String), + DuplicateKey(String), + MissingKey(String), + MissingValue(String), + InvalidValue(String), + UnknownTcpFamily(String), + Other(String), } -impl Address { - /// Create a new `Address` from a `Transport`. - pub fn new(transport: Transport) -> Self { - Self { - transport, - guid: None, +impl fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Error::UnknownTransport => write!(f, "Unsupported transport in address"), + Error::MissingTransport => write!(f, "Missing transport in address"), + Error::Encoding(e) => write!(f, "Encoding error: {e}"), + Error::DuplicateKey(e) => write!(f, "Duplicate key: `{e}`"), + Error::MissingKey(e) => write!(f, "Missing key: `{e}`"), + Error::MissingValue(e) => write!(f, "Missing value for key: `{e}`"), + Error::InvalidValue(e) => write!(f, "Invalid value for key: `{e}`"), + Error::UnknownTcpFamily(e) => write!(f, "Unknown TCP address family: `{e}`"), + Error::Other(e) => write!(f, "Other error: {e}"), } } +} - /// Set the GUID for this address. - pub fn set_guid(mut self, guid: G) -> Result - where - G: TryInto, - G::Error: Into, - { - self.guid = Some(guid.try_into().map_err(Into::into)?); +impl std::error::Error for Error {} - Ok(self) - } +pub type Result = std::result::Result; - /// The transport details for this address. - pub fn transport(&self) -> &Transport { - &self.transport - } +/// Get the address for session socket respecting the DBUS_SESSION_BUS_ADDRESS environment +/// variable. If we don't recognize the value (or it's not set) we fall back to +/// $XDG_RUNTIME_DIR/bus +pub fn session() -> Result> { + match env::var("DBUS_SESSION_BUS_ADDRESS") { + Ok(val) => AddressList::try_from(val), + _ => { + #[cfg(windows)] + { + AddressList::try_from("autolaunch:scope=*user;autolaunch:") + } - #[cfg_attr(any(target_os = "macos", windows), async_recursion::async_recursion)] - pub(crate) async fn connect(self) -> Result { - self.transport.connect().await - } + #[cfg(all(unix, not(target_os = "macos")))] + { + let runtime_dir = env::var("XDG_RUNTIME_DIR") + .unwrap_or_else(|_| format!("/run/user/{}", Uid::effective())); + let path = format!("unix:path={runtime_dir}/bus"); - /// Get the address for the session socket respecting the `DBUS_SESSION_BUS_ADDRESS` environment - /// variable. If we don't recognize the value (or it's not set) we fall back to - /// `$XDG_RUNTIME_DIR/bus`. - pub fn session() -> Result { - match env::var("DBUS_SESSION_BUS_ADDRESS") { - Ok(val) => Self::from_str(&val), - _ => { - #[cfg(windows)] - return Self::from_str("autolaunch:"); - - #[cfg(all(unix, not(target_os = "macos")))] - { - let runtime_dir = env::var("XDG_RUNTIME_DIR") - .unwrap_or_else(|_| format!("/run/user/{}", Uid::effective())); - let path = format!("unix:path={runtime_dir}/bus"); - - Self::from_str(&path) - } - - #[cfg(target_os = "macos")] - return Self::from_str("launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET"); + AddressList::try_from(path) } - } - } - /// Get the address for the system bus respecting the `DBUS_SYSTEM_BUS_ADDRESS` environment - /// variable. If we don't recognize the value (or it's not set) we fall back to - /// `/var/run/dbus/system_bus_socket`. - pub fn system() -> Result { - match env::var("DBUS_SYSTEM_BUS_ADDRESS") { - Ok(val) => Self::from_str(&val), - _ => { - #[cfg(all(unix, not(target_os = "macos")))] - return Self::from_str("unix:path=/var/run/dbus/system_bus_socket"); - - #[cfg(windows)] - return Self::from_str("autolaunch:"); - - #[cfg(target_os = "macos")] - return Self::from_str("launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET"); + #[cfg(target_os = "macos")] + { + AddressList::try_from("launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET") } } } +} - /// The GUID for this address, if known. - pub fn guid(&self) -> Option<&Guid<'_>> { - self.guid.as_ref().map(|guid| guid.inner()) +/// Get the address for system bus respecting the DBUS_SYSTEM_BUS_ADDRESS environment +/// variable. If we don't recognize the value (or it's not set) we fall back to +/// /var/run/dbus/system_bus_socket +pub fn system() -> Result> { + match env::var("DBUS_SYSTEM_BUS_ADDRESS") { + Ok(val) => AddressList::try_from(val), + _ => { + #[cfg(all(unix, not(target_os = "macos")))] + return AddressList::try_from("unix:path=/var/run/dbus/system_bus_socket"); + + #[cfg(windows)] + return AddressList::try_from("autolaunch:"); + + #[cfg(target_os = "macos")] + return AddressList::try_from("launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET"); + } } } -impl Display for Address { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - self.transport.fmt(f)?; +/// A bus address. +/// +/// Example: +/// ``` +/// use zbus::Address; +/// +/// let _: Address = "unix:path=/tmp/dbus.sock".try_into().unwrap(); +/// ``` +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Address<'a> { + pub(super) addr: Cow<'a, str>, +} + +impl<'a> Address<'a> { + /// The connection GUID if any. + pub fn guid(&self) -> Option> { + self.get_string("guid").and_then(|res| res.ok()) + } + + /// Transport connection details + pub fn transport(&self) -> Result> { + self.try_into() + } + + pub(super) fn key_val_iter(&'a self) -> KeyValIter<'a> { + let mut split = self.addr.splitn(2, ':'); + // skip transport:.. + split.next(); + let kv = split.next().unwrap_or(""); + KeyValIter::new(kv) + } + + fn new>>(addr: A) -> Result { + let addr = addr.into(); + let addr = Self { addr }; + + addr.validate()?; + + Ok(addr) + } - if let Some(guid) = &self.guid { - write!(f, ",guid={}", guid)?; + fn validate(&self) -> Result<()> { + self.transport()?; + for (k, v) in self.key_val_iter() { + let v = match v { + Some(v) => decode_percents(v)?, + _ => Cow::from(b"" as &[_]), + }; + if k == "guid" { + validate_guid(v.as_ref())?; + } } Ok(()) } -} -impl FromStr for Address { - type Err = Error; - - /// Parse the transport part of a D-Bus address into a `Transport`. - fn from_str(address: &str) -> Result { - let col = address - .find(':') - .ok_or_else(|| Error::Address("address has no colon".to_owned()))?; - let transport = &address[..col]; - let mut options = HashMap::new(); - - if address.len() > col + 1 { - for kv in address[col + 1..].split(',') { - let (k, v) = match kv.find('=') { - Some(eq) => (&kv[..eq], &kv[eq + 1..]), - None => { - return Err(Error::Address( - "missing = when parsing key/value".to_owned(), - )) - } - }; - if options.insert(k, v).is_some() { - return Err(Error::Address(format!( - "Key `{k}` specified multiple times" - ))); - } + // the last key=val wins + fn get_string(&'a self, key: &str) -> Option>> { + let mut val = None; + for (k, v) in self.key_val_iter() { + if key == k { + val = v; } } + val.map(decode_percents_str) + } +} - Ok(Self { - guid: options - .remove("guid") - .map(|s| Guid::from_str(s).map(|guid| OwnedGuid::from(guid).to_owned())) - .transpose()?, - transport: Transport::from_options(transport, options)?, - }) +fn validate_guid(value: &[u8]) -> Result<()> { + if value.len() != 32 || value.iter().any(|&c| !c.is_ascii_hexdigit()) { + return Err(Error::InvalidValue("guid".into())); } + + Ok(()) } -impl TryFrom<&str> for Address { +impl Address<'_> { + pub fn to_owned(&self) -> Address<'static> { + let addr = self.addr.to_string(); + Address { addr: addr.into() } + } +} + +impl<'a> TryFrom for Address<'a> { type Error = Error; - fn try_from(value: &str) -> Result { - Self::from_str(value) + fn try_from(addr: String) -> Result { + Self::new(addr) } } -impl From for Address { - fn from(transport: Transport) -> Self { - Self::new(transport) +impl<'a> TryFrom<&'a str> for Address<'a> { + type Error = Error; + + fn try_from(addr: &'a str) -> Result { + Self::new(addr) } } -#[cfg(test)] -mod tests { - use super::{ - transport::{Tcp, TcpTransportFamily, Transport}, - Address, - }; - #[cfg(target_os = "macos")] - use crate::address::transport::Launchd; - #[cfg(windows)] - use crate::address::transport::{Autolaunch, AutolaunchScope}; - use crate::{ - address::transport::{Unix, UnixSocket}, - Error, - }; - use std::str::FromStr; - use test_log::test; - - #[test] - fn parse_dbus_addresses() { - match Address::from_str("").unwrap_err() { - Error::Address(e) => assert_eq!(e, "address has no colon"), - _ => panic!(), - } - match Address::from_str("foo").unwrap_err() { - Error::Address(e) => assert_eq!(e, "address has no colon"), - _ => panic!(), - } - match Address::from_str("foo:opt").unwrap_err() { - Error::Address(e) => assert_eq!(e, "missing = when parsing key/value"), - _ => panic!(), - } - match Address::from_str("foo:opt=1,opt=2").unwrap_err() { - Error::Address(e) => assert_eq!(e, "Key `opt` specified multiple times"), - _ => panic!(), - } - match Address::from_str("tcp:host=localhost").unwrap_err() { - Error::Address(e) => assert_eq!(e, "tcp address is missing `port`"), - _ => panic!(), - } - match Address::from_str("tcp:host=localhost,port=32f").unwrap_err() { - Error::Address(e) => assert_eq!(e, "invalid tcp `port`"), - _ => panic!(), - } - match Address::from_str("tcp:host=localhost,port=123,family=ipv7").unwrap_err() { - Error::Address(e) => assert_eq!(e, "invalid tcp address `family`: ipv7"), - _ => panic!(), - } - match Address::from_str("unix:foo=blah").unwrap_err() { - Error::Address(e) => assert_eq!(e, "unix: address is invalid"), - _ => panic!(), - } - #[cfg(target_os = "linux")] - match Address::from_str("unix:path=/tmp,abstract=foo").unwrap_err() { - Error::Address(e) => { - assert_eq!(e, "unix: address is invalid") - } - _ => panic!(), - } - assert_eq!( - Address::from_str("unix:path=/tmp/dbus-foo").unwrap(), - Transport::Unix(Unix::new(UnixSocket::File("/tmp/dbus-foo".into()))).into(), - ); - #[cfg(target_os = "linux")] - assert_eq!( - Address::from_str("unix:abstract=/tmp/dbus-foo").unwrap(), - Transport::Unix(Unix::new(UnixSocket::Abstract("/tmp/dbus-foo".into()))).into(), - ); - let guid = crate::Guid::generate(); - assert_eq!( - Address::from_str(&format!("unix:path=/tmp/dbus-foo,guid={guid}")).unwrap(), - Address::from(Transport::Unix(Unix::new(UnixSocket::File( - "/tmp/dbus-foo".into() - )))) - .set_guid(guid.clone()) - .unwrap(), - ); - assert_eq!( - Address::from_str("tcp:host=localhost,port=4142").unwrap(), - Transport::Tcp(Tcp::new("localhost", 4142)).into(), - ); - assert_eq!( - Address::from_str("tcp:host=localhost,port=4142,family=ipv4").unwrap(), - Transport::Tcp(Tcp::new("localhost", 4142).set_family(Some(TcpTransportFamily::Ipv4))) - .into(), - ); - assert_eq!( - Address::from_str("tcp:host=localhost,port=4142,family=ipv6").unwrap(), - Transport::Tcp(Tcp::new("localhost", 4142).set_family(Some(TcpTransportFamily::Ipv6))) - .into(), - ); - assert_eq!( - Address::from_str("tcp:host=localhost,port=4142,family=ipv6,noncefile=/a/file/path") - .unwrap(), - Transport::Tcp( - Tcp::new("localhost", 4142) - .set_family(Some(TcpTransportFamily::Ipv6)) - .set_nonce_file(Some(b"/a/file/path".to_vec())) - ) - .into(), - ); - assert_eq!( - Address::from_str( - "nonce-tcp:host=localhost,port=4142,family=ipv6,noncefile=/a/file/path%20to%20file%201234" - ) - .unwrap(), - Transport::Tcp( - Tcp::new("localhost", 4142) - .set_family(Some(TcpTransportFamily::Ipv6)) - .set_nonce_file(Some(b"/a/file/path to file 1234".to_vec())) - ).into() - ); - #[cfg(windows)] - assert_eq!( - Address::from_str("autolaunch:").unwrap(), - Transport::Autolaunch(Autolaunch::new()).into(), - ); - #[cfg(windows)] - assert_eq!( - Address::from_str("autolaunch:scope=*my_cool_scope*").unwrap(), - Transport::Autolaunch( - Autolaunch::new() - .set_scope(Some(AutolaunchScope::Other("*my_cool_scope*".to_string()))) - ) - .into(), - ); - #[cfg(target_os = "macos")] - assert_eq!( - Address::from_str("launchd:env=my_cool_env_key").unwrap(), - Transport::Launchd(Launchd::new("my_cool_env_key")).into(), - ); - - #[cfg(all(feature = "vsock", not(feature = "tokio")))] - assert_eq!( - Address::from_str(&format!("vsock:cid=98,port=2934,guid={guid}")).unwrap(), - Address::from(Transport::Vsock(super::transport::Vsock::new(98, 2934))) - .set_guid(guid) - .unwrap(), - ); - assert_eq!( - Address::from_str("unix:dir=/some/dir").unwrap(), - Transport::Unix(Unix::new(UnixSocket::Dir("/some/dir".into()))).into(), - ); - assert_eq!( - Address::from_str("unix:tmpdir=/some/dir").unwrap(), - Transport::Unix(Unix::new(UnixSocket::TmpDir("/some/dir".into()))).into(), - ); +impl fmt::Display for Address<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let kv = KeyValFmt::new().add("guid", self.guid()); + let t = self.transport().map_err(|_| fmt::Error)?; + let kv = t.key_val_fmt_add(kv); + write!(f, "{t}:{kv}")?; + Ok(()) } +} + +pub(super) struct KeyValIter<'a> { + data: &'a str, + next_index: usize, +} - #[test] - fn stringify_dbus_addresses() { - assert_eq!( - Address::from(Transport::Unix(Unix::new(UnixSocket::File( - "/tmp/dbus-foo".into() - )))) - .to_string(), - "unix:path=/tmp/dbus-foo", - ); - assert_eq!( - Address::from(Transport::Unix(Unix::new(UnixSocket::Dir( - "/tmp/dbus-foo".into() - )))) - .to_string(), - "unix:dir=/tmp/dbus-foo", - ); - assert_eq!( - Address::from(Transport::Unix(Unix::new(UnixSocket::TmpDir( - "/tmp/dbus-foo".into() - )))) - .to_string(), - "unix:tmpdir=/tmp/dbus-foo" - ); - // FIXME: figure out how to handle abstract on Windows - #[cfg(target_os = "linux")] - assert_eq!( - Address::from(Transport::Unix(Unix::new(UnixSocket::Abstract( - "/tmp/dbus-foo".into() - )))) - .to_string(), - "unix:abstract=/tmp/dbus-foo" - ); - assert_eq!( - Address::from(Transport::Tcp(Tcp::new("localhost", 4142))).to_string(), - "tcp:host=localhost,port=4142" - ); - assert_eq!( - Address::from(Transport::Tcp( - Tcp::new("localhost", 4142).set_family(Some(TcpTransportFamily::Ipv4)) - )) - .to_string(), - "tcp:host=localhost,port=4142,family=ipv4" - ); - assert_eq!( - Address::from(Transport::Tcp( - Tcp::new("localhost", 4142).set_family(Some(TcpTransportFamily::Ipv6)) - )) - .to_string(), - "tcp:host=localhost,port=4142,family=ipv6" - ); - assert_eq!( - Address::from(Transport::Tcp(Tcp::new("localhost", 4142) - .set_family(Some(TcpTransportFamily::Ipv6)) - .set_nonce_file(Some(b"/a/file/path to file 1234".to_vec()) - ))) - .to_string(), - "nonce-tcp:noncefile=/a/file/path%20to%20file%201234,host=localhost,port=4142,family=ipv6" - ); - #[cfg(windows)] - assert_eq!( - Address::from(Transport::Autolaunch(Autolaunch::new())).to_string(), - "autolaunch:" - ); - #[cfg(windows)] - assert_eq!( - Address::from(Transport::Autolaunch(Autolaunch::new().set_scope(Some( - AutolaunchScope::Other("*my_cool_scope*".to_string()) - )))) - .to_string(), - "autolaunch:scope=*my_cool_scope*" - ); - #[cfg(target_os = "macos")] - assert_eq!( - Address::from(Transport::Launchd(Launchd::new("my_cool_key"))).to_string(), - "launchd:env=my_cool_key" - ); - - #[cfg(all(feature = "vsock", not(feature = "tokio")))] - { - let guid = crate::Guid::generate(); - assert_eq!( - Address::from(Transport::Vsock(super::transport::Vsock::new(98, 2934))) - .set_guid(guid.clone()) - .unwrap() - .to_string(), - format!("vsock:cid=98,port=2934,guid={guid}"), - ); +impl<'a> KeyValIter<'a> { + fn new(data: &'a str) -> Self { + KeyValIter { + data, + next_index: 0, } } +} - #[test] - fn connect_tcp() { - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let port = listener.local_addr().unwrap().port(); - let addr = Address::from_str(&format!("tcp:host=localhost,port={port}")).unwrap(); - crate::utils::block_on(async { addr.connect().await }).unwrap(); - } +impl<'a> Iterator for KeyValIter<'a> { + type Item = (&'a str, Option<&'a str>); - #[test] - fn connect_nonce_tcp() { - struct PercentEncoded<'a>(&'a [u8]); + fn next(&mut self) -> Option { + if self.next_index >= self.data.len() { + return None; + } - impl std::fmt::Display for PercentEncoded<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - super::transport::encode_percents(f, self.0) - } + let mut pair = &self.data[self.next_index..]; + if let Some(end) = pair.find(',') { + pair = &pair[..end]; + self.next_index += end + 1; + } else { + self.next_index = self.data.len(); } + let mut split = pair.split('='); + // SAFETY: first split always returns something + let key = split.next().unwrap(); + + Some((key, split.next())) + } +} + +pub(crate) trait KeyValFmtAdd { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b>; +} - use std::io::Write; +pub(crate) struct KeyValFmt<'a> { + fields: Vec<(Box, Box)>, +} - const TEST_COOKIE: &[u8] = b"VERILY SECRETIVE"; +impl<'a> KeyValFmt<'a> { + fn new() -> Self { + Self { fields: vec![] } + } - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let port = listener.local_addr().unwrap().port(); + pub(crate) fn add(mut self, key: K, val: Option) -> Self + where + K: fmt::Display + 'a, + V: Encodable + 'a, + { + if let Some(val) = val { + self.fields.push((Box::new(key), Box::new(val))); + } + + self + } +} + +impl fmt::Display for KeyValFmt<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut first = true; + for (k, v) in self.fields.iter() { + if !first { + write!(f, ",")?; + } + write!(f, "{k}=")?; + v.encode(f)?; + first = false; + } - let mut cookie = tempfile::NamedTempFile::new().unwrap(); - cookie.as_file_mut().write_all(TEST_COOKIE).unwrap(); + Ok(()) + } +} - let encoded_path = format!( - "{}", - PercentEncoded(cookie.path().to_str().unwrap().as_ref()) - ); +/// A trait for objects which can be converted or resolved to one or more [`Address`] values. +pub trait ToAddresses<'a> { + type Iter: Iterator>>; - let addr = Address::from_str(&format!( - "nonce-tcp:host=localhost,port={port},noncefile={encoded_path}" - )) - .unwrap(); + fn to_addresses(&'a self) -> Self::Iter; +} - let (sender, receiver) = std::sync::mpsc::sync_channel(1); +impl<'a> ToAddresses<'a> for Address<'a> { + type Iter = std::iter::Once>>; - std::thread::spawn(move || { - use std::io::Read; + /// Get an iterator over the D-Bus addresses. + fn to_addresses(&'a self) -> Self::Iter { + std::iter::once(Ok(self.clone())) + } +} - let mut client = listener.incoming().next().unwrap().unwrap(); +impl<'a> ToAddresses<'a> for str { + type Iter = std::iter::Once>>; - let mut buf = [0u8; 16]; - client.read_exact(&mut buf).unwrap(); + fn to_addresses(&'a self) -> Self::Iter { + std::iter::once(self.try_into()) + } +} - sender.send(buf == TEST_COOKIE).unwrap(); - }); +impl<'a> ToAddresses<'a> for String { + type Iter = std::iter::Once>>; - crate::utils::block_on(addr.connect()).unwrap(); + fn to_addresses(&'a self) -> Self::Iter { + std::iter::once(self.as_str().try_into()) + } +} - let saw_cookie = receiver - .recv_timeout(std::time::Duration::from_millis(100)) - .expect("nonce file content hasn't been received by server thread in time"); +impl<'a> ToAddresses<'a> for Vec>> { + type Iter = std::iter::Cloned>>>; - assert!( - saw_cookie, - "nonce file content has been received, but was invalid" - ); + /// Get an iterator over the D-Bus addresses. + fn to_addresses(&'a self) -> Self::Iter { + self.iter().cloned() } } diff --git a/zbus/src/address/percent.rs b/zbus/src/address/percent.rs new file mode 100644 index 000000000..ccd2c6310 --- /dev/null +++ b/zbus/src/address/percent.rs @@ -0,0 +1,209 @@ +use std::{ + borrow::Cow, + ffi::{OsStr, OsString}, + fmt, +}; + +use super::{Error, Result}; + +// A trait for types that can be percent-encoded and written to a [`fmt::Formatter`]. +pub(crate) trait Encodable { + fn encode(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result; +} + +impl Encodable for T { + fn encode(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + encode_percents(f, self.to_string().as_bytes()) + } +} + +pub(crate) struct EncData(pub T); + +impl> Encodable for EncData { + fn encode(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + encode_percents(f, self.0.as_ref()) + } +} + +pub(crate) struct EncOsStr(pub T); + +impl Encodable for EncOsStr<&Cow<'_, OsStr>> { + fn encode(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + encode_percents(f, self.0.to_string_lossy().as_bytes()) + } +} + +impl Encodable for EncOsStr<&OsStr> { + fn encode(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + encode_percents(f, self.0.to_string_lossy().as_bytes()) + } +} + +/// Percent-encode the value. +pub fn encode_percents(f: &mut dyn fmt::Write, value: &[u8]) -> std::fmt::Result { + for &byte in value { + if matches!(byte, b'-' | b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_' | b'/' | b'.' | b'\\' | b'*') + { + // Write the byte directly if it's in the allowed set + f.write_char(byte as char)?; + } else { + // Otherwise, write its percent-encoded form + write!(f, "%{:02X}", byte)?; + } + } + + Ok(()) +} + +/// Percent-decode the string. +pub fn decode_percents(value: &str) -> Result> { + // Check if decoding is necessary + let needs_decoding = value.chars().any(|c| c == '%' || !is_allowed_char(c)); + + if !needs_decoding { + return Ok(Cow::Borrowed(value.as_bytes())); + } + + let mut decoded = Vec::with_capacity(value.len()); + let mut chars = value.chars(); + + while let Some(c) = chars.next() { + match c { + '%' => { + let high = chars + .next() + .ok_or_else(|| Error::Encoding("Incomplete percent-encoded sequence".into()))?; + let low = chars + .next() + .ok_or_else(|| Error::Encoding("Incomplete percent-encoded sequence".into()))?; + decoded.push(decode_hex_pair(high, low)?); + } + _ if is_allowed_char(c) => decoded.push(c as u8), + _ => return Err(Error::Encoding("Invalid character in address".into())), + } + } + + Ok(Cow::Owned(decoded)) +} + +fn is_allowed_char(c: char) -> bool { + matches!(c, '-' | '0'..='9' | 'A'..='Z' | 'a'..='z' | '_' | '/' | '.' | '\\' | '*') +} + +fn decode_hex_pair(high: char, low: char) -> Result { + let high_digit = decode_hex(high)?; + let low_digit = decode_hex(low)?; + + Ok(high_digit << 4 | low_digit) +} + +fn decode_hex(c: char) -> Result { + match c { + '0'..='9' => Ok(c as u8 - b'0'), + 'a'..='f' => Ok(c as u8 - b'a' + 10), + 'A'..='F' => Ok(c as u8 - b'A' + 10), + + _ => Err(Error::Encoding( + "Invalid hexadecimal character in percent-encoded sequence".into(), + )), + } +} + +pub(super) fn decode_percents_str(value: &str) -> Result> { + cow_bytes_to_str(decode_percents(value)?) +} + +fn cow_bytes_to_str(cow: Cow<'_, [u8]>) -> Result> { + match cow { + Cow::Borrowed(bytes) => Ok(Cow::Borrowed( + std::str::from_utf8(bytes).map_err(|e| Error::Encoding(format!("{e}")))?, + )), + Cow::Owned(bytes) => Ok(Cow::Owned( + String::from_utf8(bytes).map_err(|e| Error::Encoding(format!("{e}")))?, + )), + } +} + +pub(super) fn decode_percents_os_str(value: &str) -> Result> { + cow_bytes_to_os_str(decode_percents(value)?) +} + +fn cow_bytes_to_os_str(cow: Cow<'_, [u8]>) -> Result> { + match cow { + Cow::Borrowed(bytes) => Ok(Cow::Borrowed(OsStr::new( + std::str::from_utf8(bytes).map_err(|e| Error::Encoding(format!("{e}")))?, + ))), + Cow::Owned(bytes) => Ok(Cow::Owned(OsString::from( + String::from_utf8(bytes).map_err(|e| Error::Encoding(format!("{e}")))?, + ))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn simple_ascii() { + const INPUT: &[u8] = "hello".as_bytes(); + + let mut output = String::new(); + encode_percents(&mut output, INPUT).unwrap(); + assert_eq!(output, "hello"); + + let result = decode_percents(&output).unwrap(); + assert!(matches!(result, Cow::Borrowed(_))); + assert_eq!(result, Cow::Borrowed(INPUT)); + } + + #[test] + fn special_characters() { + const INPUT: &[u8] = "hello world!".as_bytes(); + + let mut output = String::new(); + encode_percents(&mut output, INPUT).unwrap(); + assert_eq!(output, "hello%20world%21"); + + let result = decode_percents(&output).unwrap(); + assert!(matches!(result, Cow::Owned(_))); + assert_eq!(result, Cow::Borrowed(INPUT)); + } + + #[test] + fn empty_input() { + const INPUT: &[u8] = "".as_bytes(); + + let mut output = String::new(); + encode_percents(&mut output, INPUT).unwrap(); + assert_eq!(output, ""); + + let result = decode_percents(&output).unwrap(); + assert!(matches!(result, Cow::Borrowed(_))); + assert_eq!(result, Cow::Borrowed(INPUT)); + } + + #[test] + fn non_ascii_characters() { + const INPUT: &[u8] = "😊".as_bytes(); + + let mut output = String::new(); + encode_percents(&mut output, INPUT).unwrap(); + assert_eq!(output, "%F0%9F%98%8A"); + + let result = decode_percents(&output).unwrap(); + assert!(matches!(result, Cow::Owned(_))); + assert_eq!(result, Cow::Borrowed(INPUT)); + } + + #[test] + fn incomplete_encoding() { + let result = decode_percents("incomplete%"); + assert!(result.is_err()); + } + + #[test] + fn invalid_characters() { + let result = decode_percents("invalid%2Gchar"); + assert!(result.is_err()); + } +} diff --git a/zbus/src/address/tests.rs b/zbus/src/address/tests.rs new file mode 100644 index 000000000..f76123a7e --- /dev/null +++ b/zbus/src/address/tests.rs @@ -0,0 +1,185 @@ +use std::{borrow::Cow, ffi::OsStr}; + +#[cfg(target_os = "windows")] +use super::transport::AutolaunchScope; +use super::{ + transport::{TcpFamily, Transport, UnixAddrKind}, + Address, +}; + +#[test] +fn parse_err() { + assert_eq!( + Address::try_from("").unwrap_err().to_string(), + "Missing transport in address" + ); + assert_eq!( + Address::try_from("foo").unwrap_err().to_string(), + "Missing transport in address" + ); + assert_eq!( + Address::try_from("foo:").unwrap_err().to_string(), + "Unsupported transport in address" + ); + assert_eq!( + Address::try_from("tcp:opt=%1").unwrap_err().to_string(), + "Encoding error: Incomplete percent-encoded sequence" + ); + assert_eq!( + Address::try_from("tcp:opt=%1z").unwrap_err().to_string(), + "Encoding error: Invalid hexadecimal character in percent-encoded sequence" + ); + assert_eq!( + Address::try_from("tcp:opt=1\rz").unwrap_err().to_string(), + "Encoding error: Invalid character in address" + ); + assert_eq!( + Address::try_from("tcp:guid=9406e28972c595c590766c9564ce623") + .unwrap_err() + .to_string(), + "Invalid value for key: `guid`" + ); + assert_eq!( + Address::try_from("tcp:guid=9406e28972c595c590766c9564ce623g") + .unwrap_err() + .to_string(), + "Invalid value for key: `guid`" + ); + + let addr = Address::try_from("tcp:guid=9406e28972c595c590766c9564ce623f").unwrap(); + addr.guid().unwrap(); +} + +#[test] +fn parse_unix() { + let addr = + Address::try_from("unix:path=/tmp/dbus-foo,guid=9406e28972c595c590766c9564ce623f").unwrap(); + let Transport::Unix(u) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!( + u.kind(), + &UnixAddrKind::Path(Cow::Borrowed(OsStr::new("/tmp/dbus-foo"))) + ); + + assert_eq!( + Address::try_from("unix:foo=blah").unwrap_err().to_string(), + "Other error: invalid `unix:` address, missing required key" + ); + assert_eq!( + Address::try_from("unix:path=/blah,abstract=foo") + .unwrap_err() + .to_string(), + "Other error: invalid address, only one of `path` `dir` `tmpdir` `abstract` or `runtime` expected" + ); + assert_eq!( + Address::try_from("unix:runtime=no") + .unwrap_err() + .to_string(), + "Invalid value for key: `runtime`" + ); + Address::try_from(String::from("unix:path=/tmp/foo")).unwrap(); +} + +#[cfg(target_os = "macos")] +#[test] +fn parse_launchd() { + let addr = Address::try_from("launchd:env=FOOBAR").unwrap(); + let Transport::Launchd(t) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!(t.env(), "FOOBAR"); + + assert_eq!( + Address::try_from("launchd:weof").unwrap_err().to_string(), + "Missing key: `env`" + ); +} + +#[cfg(target_os = "linux")] +#[test] +fn parse_systemd() { + let addr = Address::try_from("systemd:").unwrap(); + let Transport::Systemd(_) = addr.transport().unwrap() else { + panic!(); + }; +} + +#[test] +fn parse_tcp() { + let addr = Address::try_from("tcp:host=localhost,bind=*,port=0,family=ipv4").unwrap(); + let Transport::Tcp(t) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!(t.host().unwrap(), "localhost"); + assert_eq!(t.bind().unwrap(), "*"); + assert_eq!(t.port().unwrap(), 0); + assert_eq!(t.family().unwrap(), TcpFamily::IPv4); + + let addr = Address::try_from("tcp:").unwrap(); + let Transport::Tcp(t) = addr.transport().unwrap() else { + panic!(); + }; + assert!(t.host().is_none()); + assert!(t.bind().is_none()); + assert!(t.port().is_none()); + assert!(t.family().is_none()); +} + +#[test] +fn parse_nonce_tcp() { + let addr = + Address::try_from("nonce-tcp:host=localhost,bind=*,port=0,family=ipv6,noncefile=foo") + .unwrap(); + let Transport::NonceTcp(t) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!(t.host().unwrap(), "localhost"); + assert_eq!(t.bind().unwrap(), "*"); + assert_eq!(t.port().unwrap(), 0); + assert_eq!(t.family().unwrap(), TcpFamily::IPv6); + assert_eq!(t.noncefile().unwrap(), "foo"); +} + +#[test] +fn parse_unixexec() { + let addr = Address::try_from("unixexec:path=/bin/test,argv2=foo").unwrap(); + let Transport::Unixexec(t) = addr.transport().unwrap() else { + panic!(); + }; + + assert_eq!(t.path(), "/bin/test"); + assert_eq!(t.argv(), &[(2, Cow::from("foo"))]); + + assert_eq!( + Address::try_from("unixexec:weof").unwrap_err().to_string(), + "Missing key: `path`" + ); +} + +#[test] +fn parse_autolaunch() { + let addr = Address::try_from("autolaunch:scope=*user").unwrap(); + #[allow(unused)] + let Transport::Autolaunch(t) = addr.transport().unwrap() else { + panic!(); + }; + #[cfg(target_os = "windows")] + assert_eq!(t.scope().unwrap(), &AutolaunchScope::User); +} + +#[test] +#[cfg(feature = "vsock")] +fn parse_vsock() { + let addr = Address::try_from("vsock:cid=12,port=32").unwrap(); + let Transport::Vsock(t) = addr.transport().unwrap() else { + panic!(); + }; + assert_eq!(t.port(), Some(32)); + assert_eq!(t.cid(), Some(12)); + + assert_eq!( + Address::try_from("vsock:port=abc").unwrap_err().to_string(), + "Invalid value for key: `port`" + ); +} diff --git a/zbus/src/address/transport/autolaunch.rs b/zbus/src/address/transport/autolaunch.rs index 5233672b8..3d56a907f 100644 --- a/zbus/src/address/transport/autolaunch.rs +++ b/zbus/src/address/transport/autolaunch.rs @@ -1,83 +1,91 @@ -use crate::{Error, Result}; -use std::collections::HashMap; +use std::marker::PhantomData; +#[cfg(target_os = "windows")] +use std::{borrow::Cow, fmt}; -/// Transport properties of an autolaunch D-Bus address. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Autolaunch { - pub(super) scope: Option, +#[cfg(target_os = "windows")] +use super::percent::decode_percents_str; +use super::{Address, Error, KeyValFmt, KeyValFmtAdd, Result}; + +/// Scope of autolaunch (Windows only) +#[cfg(target_os = "windows")] +#[derive(Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum AutolaunchScope<'a> { + /// Limit session bus to dbus installation path. + InstallPath, + /// Limit session bus to the recent user. + User, + /// other values - specify dedicated session bus like "release", "debug" or other. + Other(Cow<'a, str>), } -impl std::fmt::Display for Autolaunch { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "autolaunch:")?; - if let Some(scope) = &self.scope { - write!(f, "scope={}", scope)?; +#[cfg(target_os = "windows")] +impl fmt::Display for AutolaunchScope<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InstallPath => write!(f, "*install-path"), + Self::User => write!(f, "*user"), + Self::Other(o) => write!(f, "{o}"), } - - Ok(()) } } -impl Default for Autolaunch { - fn default() -> Self { - Self::new() - } -} +#[cfg(target_os = "windows")] +impl<'a> TryFrom> for AutolaunchScope<'a> { + type Error = Error; -impl Autolaunch { - /// Create a new autolaunch transport. - pub fn new() -> Self { - Self { scope: None } + fn try_from(s: Cow<'a, str>) -> Result { + match s.as_ref() { + "*install-path" => Ok(Self::InstallPath), + "*user" => Ok(Self::User), + _ => Ok(Self::Other(s)), + } } +} - /// Set the `autolaunch:` address `scope` value. - pub fn set_scope(mut self, scope: Option) -> Self { - self.scope = scope; - - self - } +/// `autolaunch:` D-Bus transport. +/// +/// +#[derive(Debug, PartialEq, Eq, Default)] +pub struct Autolaunch<'a> { + #[cfg(target_os = "windows")] + scope: Option>, + phantom: PhantomData<&'a ()>, +} - /// The optional scope. - pub fn scope(&self) -> Option<&AutolaunchScope> { +impl<'a> Autolaunch<'a> { + #[cfg(target_os = "windows")] + /// Scope of autolaunch (Windows only) + pub fn scope(&self) -> Option<&AutolaunchScope<'a>> { self.scope.as_ref() } +} + +impl<'a> TryFrom<&'a Address<'a>> for Autolaunch<'a> { + type Error = Error; + + fn try_from(s: &'a Address<'a>) -> Result { + #[allow(unused_mut)] + let mut res = Autolaunch::default(); - pub(super) fn from_options(opts: HashMap<&str, &str>) -> Result { - opts.get("scope") - .map(|scope| -> Result<_> { - let decoded = super::decode_percents(scope)?; - match decoded.as_slice() { - b"install-path" => Ok(AutolaunchScope::InstallPath), - b"user" => Ok(AutolaunchScope::User), - _ => String::from_utf8(decoded) - .map(AutolaunchScope::Other) - .map_err(|_| { - Error::Address("autolaunch scope is not valid UTF-8".to_owned()) - }), + for (k, v) in s.key_val_iter() { + match (k, v) { + #[cfg(target_os = "windows")] + ("scope", Some(v)) => { + res.scope = Some(decode_percents_str(v)?.try_into()?); } - }) - .transpose() - .map(|scope| Self { scope }) - } -} + _ => continue, + } + } -#[derive(Clone, Debug, PartialEq, Eq)] -#[non_exhaustive] -pub enum AutolaunchScope { - /// Limit session bus to dbus installation path. - InstallPath, - /// Limit session bus to the recent user. - User, - /// Other values - specify dedicated session bus like "release", "debug" or other. - Other(String), + Ok(res) + } } -impl std::fmt::Display for AutolaunchScope { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::InstallPath => write!(f, "*install-path"), - Self::User => write!(f, "*user"), - Self::Other(o) => write!(f, "{o}"), - } +impl KeyValFmtAdd for Autolaunch<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + #[cfg(target_os = "windows")] + let kv = kv.add("scope", self.scope()); + kv } } diff --git a/zbus/src/address/transport/launchd.rs b/zbus/src/address/transport/launchd.rs index 482fcc32f..4f44486eb 100644 --- a/zbus/src/address/transport/launchd.rs +++ b/zbus/src/address/transport/launchd.rs @@ -1,60 +1,46 @@ -use super::{Transport, Unix, UnixSocket}; -use crate::{process::run, Result}; -use std::collections::HashMap; +use std::borrow::Cow; -#[derive(Clone, Debug, PartialEq, Eq)] -#[non_exhaustive] -/// The transport properties of a launchd D-Bus address. -pub struct Launchd { - pub(super) env: String, -} +use super::{percent::decode_percents_str, Address, Error, KeyValFmt, KeyValFmtAdd, Result}; -impl Launchd { - /// Create a new launchd D-Bus address. - pub fn new(env: &str) -> Self { - Self { - env: env.to_string(), - } - } +/// `launchd:` D-Bus transport. +/// +/// +#[derive(Debug, PartialEq, Eq)] +pub struct Launchd<'a> { + env: Cow<'a, str>, +} - /// The path of the unix domain socket for the launchd created dbus-daemon. +impl<'a> Launchd<'a> { + /// Environment variable. + /// + /// Environment variable used to get the path of the unix domain socket for the launchd created + /// dbus-daemon. pub fn env(&self) -> &str { - &self.env + self.env.as_ref() } +} - /// Determine the actual transport details behind a launchd address. - pub(super) async fn bus_address(&self) -> Result { - let output = run("launchctl", ["getenv", self.env()]) - .await - .expect("failed to wait on launchctl output"); - - if !output.status.success() { - return Err(crate::Error::Address(format!( - "launchctl terminated with code: {}", - output.status - ))); +impl<'a> TryFrom<&'a Address<'a>> for Launchd<'a> { + type Error = Error; + + fn try_from(s: &'a Address<'a>) -> Result { + for (k, v) in s.key_val_iter() { + match (k, v) { + ("env", Some(v)) => { + return Ok(Launchd { + env: decode_percents_str(v)?, + }); + } + _ => continue, + } } - let addr = String::from_utf8(output.stdout).map_err(|e| { - crate::Error::Address(format!("Unable to parse launchctl output as UTF-8: {}", e)) - })?; - - Ok(Transport::Unix(Unix::new(UnixSocket::File( - addr.trim().into(), - )))) - } - - pub(super) fn from_options(opts: HashMap<&str, &str>) -> Result { - opts.get("env") - .ok_or_else(|| crate::Error::Address("missing env key".into())) - .map(|env| Self { - env: env.to_string(), - }) + Err(Error::MissingKey("env".into())) } } -impl std::fmt::Display for Launchd { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "launchd:env={}", self.env) +impl KeyValFmtAdd for Launchd<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv.add("env", Some(self.env())) } } diff --git a/zbus/src/address/transport/mod.rs b/zbus/src/address/transport/mod.rs index 47fd200d0..3d8a2585d 100644 --- a/zbus/src/address/transport/mod.rs +++ b/zbus/src/address/transport/mod.rs @@ -1,350 +1,114 @@ -//! D-Bus transport Information module. -//! -//! This module provides the trasport information for D-Bus addresses. +//! D-Bus supported transports. -#[cfg(windows)] -use crate::win32::autolaunch_bus_address; -use crate::{Error, Result}; -#[cfg(not(feature = "tokio"))] -use async_io::Async; -use std::collections::HashMap; -#[cfg(not(feature = "tokio"))] -use std::net::TcpStream; -#[cfg(unix)] -use std::os::unix::net::{SocketAddr, UnixStream}; -#[cfg(feature = "tokio")] -use tokio::net::TcpStream; -#[cfg(feature = "tokio-vsock")] -use tokio_vsock::VsockStream; -#[cfg(windows)] -use uds_windows::UnixStream; -#[cfg(all(feature = "vsock", not(feature = "tokio")))] -use vsock::VsockStream; +use std::fmt; -use std::{ - fmt::{Display, Formatter}, - str::from_utf8_unchecked, -}; +use super::{percent, Address, Error, KeyValFmt, KeyValFmtAdd, Result}; -mod unix; -pub use unix::{Unix, UnixSocket}; -mod tcp; -pub use tcp::{Tcp, TcpTransportFamily}; -#[cfg(windows)] mod autolaunch; -#[cfg(windows)] -pub use autolaunch::{Autolaunch, AutolaunchScope}; +pub use autolaunch::Autolaunch; +#[cfg(target_os = "windows")] +pub use autolaunch::AutolaunchScope; + #[cfg(target_os = "macos")] mod launchd; #[cfg(target_os = "macos")] pub use launchd::Launchd; -#[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" -))] -#[path = "vsock.rs"] -// Gotta rename to avoid name conflict with the `vsock` crate. -mod vsock_transport; -#[cfg(target_os = "linux")] -use std::os::linux::net::SocketAddrExt; -#[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" -))] -pub use vsock_transport::Vsock; - -/// The transport properties of a D-Bus address. -#[derive(Clone, Debug, PartialEq, Eq)] -#[non_exhaustive] -pub enum Transport { - /// A Unix Domain Socket address. - Unix(Unix), - /// A TCP address. - Tcp(Tcp), - /// An autolaunch D-Bus address. - #[cfg(windows)] - Autolaunch(Autolaunch), - /// A launchd D-Bus address. - #[cfg(target_os = "macos")] - Launchd(Launchd), - #[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" - ))] - /// A VSOCK address. - /// - /// This variant is only available when either the `vsock` or `tokio-vsock` feature is enabled. - /// The type of `stream` is `vsock::VsockStream` with the `vsock` feature and - /// `tokio_vsock::VsockStream` with the `tokio-vsock` feature. - Vsock(Vsock), -} - -impl Transport { - #[cfg_attr(any(target_os = "macos", windows), async_recursion::async_recursion)] - pub(super) async fn connect(self) -> Result { - match self { - Transport::Unix(unix) => { - // This is a `path` in case of Windows until uds_windows provides the needed API: - // https://github.com/haraldh/rust_uds_windows/issues/14 - let addr = match unix.take_path() { - #[cfg(unix)] - UnixSocket::File(path) => SocketAddr::from_pathname(path)?, - #[cfg(windows)] - UnixSocket::File(path) => path, - #[cfg(target_os = "linux")] - UnixSocket::Abstract(name) => { - SocketAddr::from_abstract_name(name.as_encoded_bytes())? - } - UnixSocket::Dir(_) | UnixSocket::TmpDir(_) => { - // you can't connect to a unix:dir - return Err(Error::Unsupported); - } - }; - let stream = crate::Task::spawn_blocking( - move || -> Result<_> { - #[cfg(unix)] - let stream = UnixStream::connect_addr(&addr)?; - #[cfg(windows)] - let stream = UnixStream::connect(addr)?; - stream.set_nonblocking(true)?; - - Ok(stream) - }, - "unix stream connection", - ) - .await?; - #[cfg(not(feature = "tokio"))] - { - Async::new(stream) - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } - - #[cfg(feature = "tokio")] - { - #[cfg(unix)] - { - tokio::net::UnixStream::from_std(stream) - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } - - #[cfg(not(unix))] - { - let _ = stream; - Err(Error::Unsupported) - } - } - } - #[cfg(all(feature = "vsock", not(feature = "tokio")))] - Transport::Vsock(addr) => { - let stream = VsockStream::connect_with_cid_port(addr.cid(), addr.port())?; - Async::new(stream).map(Stream::Vsock).map_err(Into::into) - } - #[cfg(feature = "tokio-vsock")] - Transport::Vsock(addr) => VsockStream::connect(addr.cid(), addr.port()) - .await - .map(Stream::Vsock) - .map_err(Into::into), +mod nonce_tcp; +pub use nonce_tcp::NonceTcp; - Transport::Tcp(mut addr) => match addr.take_nonce_file() { - Some(nonce_file) => { - #[allow(unused_mut)] - let mut stream = addr.connect().await?; - - #[cfg(unix)] - let nonce_file = { - use std::os::unix::ffi::OsStrExt; - std::ffi::OsStr::from_bytes(&nonce_file) - }; - - #[cfg(windows)] - let nonce_file = std::str::from_utf8(&nonce_file).map_err(|_| { - Error::Address("nonce file path is invalid UTF-8".to_owned()) - })?; - - #[cfg(not(feature = "tokio"))] - { - let nonce = std::fs::read(nonce_file)?; - let mut nonce = &nonce[..]; +#[cfg(target_os = "linux")] +mod systemd; +#[cfg(target_os = "linux")] +pub use systemd::Systemd; - while !nonce.is_empty() { - let len = stream - .write_with(|mut s| std::io::Write::write(&mut s, nonce)) - .await?; - nonce = &nonce[len..]; - } - } +mod tcp; +pub use tcp::{Tcp, TcpFamily}; - #[cfg(feature = "tokio")] - { - let nonce = tokio::fs::read(nonce_file).await?; - tokio::io::AsyncWriteExt::write_all(&mut stream, &nonce).await?; - } +mod unix; +pub use unix::{Unix, UnixAddrKind}; - Ok(Stream::Tcp(stream)) - } - None => addr.connect().await.map(Stream::Tcp), - }, +mod unixexec; +pub use unixexec::Unixexec; - #[cfg(windows)] - Transport::Autolaunch(Autolaunch { scope }) => match scope { - Some(_) => Err(Error::Address( - "Autolaunch scopes are currently unsupported".to_owned(), - )), - None => { - let addr = autolaunch_bus_address()?; - addr.connect().await - } - }, +mod vsock; +pub use vsock::Vsock; - #[cfg(target_os = "macos")] - Transport::Launchd(launchd) => { - let addr = launchd.bus_address().await?; - addr.connect().await - } - } - } +/// A D-Bus transport. +#[derive(Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum Transport<'a> { + /// Unix Domain Sockets transport. + Unix(unix::Unix<'a>), + #[cfg(target_os = "macos")] + /// launchd transport. + Launchd(launchd::Launchd<'a>), + #[cfg(target_os = "linux")] + /// systemd transport. + Systemd(systemd::Systemd<'a>), + /// TCP Sockets transport. + Tcp(tcp::Tcp<'a>), + /// Nonce-authenticated TCP Sockets transport. + NonceTcp(nonce_tcp::NonceTcp<'a>), + /// Executed Subprocesses on Unix transport. + Unixexec(unixexec::Unixexec<'a>), + /// Autolaunch transport. + Autolaunch(autolaunch::Autolaunch<'a>), + /// VSOCK Sockets transport. + Vsock(vsock::Vsock<'a>), +} - // Helper for `FromStr` impl of `Address`. - pub(super) fn from_options(transport: &str, options: HashMap<&str, &str>) -> Result { - match transport { - "unix" => Unix::from_options(options).map(Self::Unix), - "tcp" => Tcp::from_options(options, false).map(Self::Tcp), - "nonce-tcp" => Tcp::from_options(options, true).map(Self::Tcp), - #[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" - ))] - "vsock" => Vsock::from_options(options).map(Self::Vsock), - #[cfg(windows)] - "autolaunch" => Autolaunch::from_options(options).map(Self::Autolaunch), +impl fmt::Display for Transport<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Unix(_) => write!(f, "unix"), #[cfg(target_os = "macos")] - "launchd" => Launchd::from_options(options).map(Self::Launchd), - - _ => Err(Error::Address(format!( - "unsupported transport '{transport}'" - ))), + Self::Launchd(_) => write!(f, "launchd"), + #[cfg(target_os = "linux")] + Self::Systemd(_) => write!(f, "systemd"), + Self::Tcp(_) => write!(f, "tcp"), + Self::NonceTcp(_) => write!(f, "nonce-tcp"), + Self::Unixexec(_) => write!(f, "unixexec"), + Self::Autolaunch(_) => write!(f, "autolaunch"), + Self::Vsock(_) => write!(f, "vsock"), } } } -#[cfg(not(feature = "tokio"))] -#[derive(Debug)] -pub(crate) enum Stream { - Unix(Async), - Tcp(Async), - #[cfg(feature = "vsock")] - Vsock(Async), -} - -#[cfg(feature = "tokio")] -#[derive(Debug)] -pub(crate) enum Stream { - #[cfg(unix)] - Unix(tokio::net::UnixStream), - Tcp(TcpStream), - #[cfg(feature = "tokio-vsock")] - Vsock(VsockStream), -} - -fn decode_hex(c: char) -> Result { - match c { - '0'..='9' => Ok(c as u8 - b'0'), - 'a'..='f' => Ok(c as u8 - b'a' + 10), - 'A'..='F' => Ok(c as u8 - b'A' + 10), - - _ => Err(Error::Address( - "invalid hexadecimal character in percent-encoded sequence".to_owned(), - )), - } -} - -pub(crate) fn decode_percents(value: &str) -> Result> { - let mut iter = value.chars(); - let mut decoded = Vec::new(); - - while let Some(c) = iter.next() { - if matches!(c, '-' | '0'..='9' | 'A'..='Z' | 'a'..='z' | '_' | '/' | '.' | '\\' | '*') { - decoded.push(c as u8) - } else if c == '%' { - decoded.push( - decode_hex(iter.next().ok_or_else(|| { - Error::Address("incomplete percent-encoded sequence".to_owned()) - })?)? - << 4 - | decode_hex(iter.next().ok_or_else(|| { - Error::Address("incomplete percent-encoded sequence".to_owned()) - })?)?, - ); - } else { - return Err(Error::Address("Invalid character in address".to_owned())); +impl KeyValFmtAdd for Transport<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + match self { + Self::Unix(t) => t.key_val_fmt_add(kv), + #[cfg(target_os = "macos")] + Self::Launchd(t) => t.key_val_fmt_add(kv), + #[cfg(target_os = "linux")] + Self::Systemd(t) => t.key_val_fmt_add(kv), + Self::Tcp(t) => t.key_val_fmt_add(kv), + Self::NonceTcp(t) => t.key_val_fmt_add(kv), + Self::Unixexec(t) => t.key_val_fmt_add(kv), + Self::Autolaunch(t) => t.key_val_fmt_add(kv), + Self::Vsock(t) => t.key_val_fmt_add(kv), } } - - Ok(decoded) } -pub(super) fn encode_percents(f: &mut Formatter<'_>, mut value: &[u8]) -> std::fmt::Result { - const LOOKUP: &str = "\ -%00%01%02%03%04%05%06%07%08%09%0a%0b%0c%0d%0e%0f\ -%10%11%12%13%14%15%16%17%18%19%1a%1b%1c%1d%1e%1f\ -%20%21%22%23%24%25%26%27%28%29%2a%2b%2c%2d%2e%2f\ -%30%31%32%33%34%35%36%37%38%39%3a%3b%3c%3d%3e%3f\ -%40%41%42%43%44%45%46%47%48%49%4a%4b%4c%4d%4e%4f\ -%50%51%52%53%54%55%56%57%58%59%5a%5b%5c%5d%5e%5f\ -%60%61%62%63%64%65%66%67%68%69%6a%6b%6c%6d%6e%6f\ -%70%71%72%73%74%75%76%77%78%79%7a%7b%7c%7d%7e%7f\ -%80%81%82%83%84%85%86%87%88%89%8a%8b%8c%8d%8e%8f\ -%90%91%92%93%94%95%96%97%98%99%9a%9b%9c%9d%9e%9f\ -%a0%a1%a2%a3%a4%a5%a6%a7%a8%a9%aa%ab%ac%ad%ae%af\ -%b0%b1%b2%b3%b4%b5%b6%b7%b8%b9%ba%bb%bc%bd%be%bf\ -%c0%c1%c2%c3%c4%c5%c6%c7%c8%c9%ca%cb%cc%cd%ce%cf\ -%d0%d1%d2%d3%d4%d5%d6%d7%d8%d9%da%db%dc%dd%de%df\ -%e0%e1%e2%e3%e4%e5%e6%e7%e8%e9%ea%eb%ec%ed%ee%ef\ -%f0%f1%f2%f3%f4%f5%f6%f7%f8%f9%fa%fb%fc%fd%fe%ff"; - - loop { - let pos = value.iter().position( - |c| !matches!(c, b'-' | b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_' | b'/' | b'.' | b'\\' | b'*'), - ); - - if let Some(pos) = pos { - // SAFETY: The above `position()` call made sure that only ASCII chars are in the string - // up to `pos` - f.write_str(unsafe { from_utf8_unchecked(&value[..pos]) })?; - - let c = value[pos]; - value = &value[pos + 1..]; +impl<'a> TryFrom<&'a Address<'a>> for Transport<'a> { + type Error = Error; - let pos = c as usize * 3; - f.write_str(&LOOKUP[pos..pos + 3])?; - } else { - // SAFETY: The above `position()` call made sure that only ASCII chars are in the rest - // of the string - f.write_str(unsafe { from_utf8_unchecked(value) })?; - return Ok(()); - } - } -} - -impl Display for Transport { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Tcp(tcp) => write!(f, "{}", tcp)?, - Self::Unix(unix) => write!(f, "{}", unix)?, - #[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" - ))] - Self::Vsock(vsock) => write!(f, "{}", vsock)?, - #[cfg(windows)] - Self::Autolaunch(autolaunch) => write!(f, "{}", autolaunch)?, + fn try_from(s: &'a Address<'a>) -> Result { + let col = s.addr.find(':').ok_or(Error::MissingTransport)?; + match &s.addr[..col] { + "unix" => Ok(Self::Unix(s.try_into()?)), #[cfg(target_os = "macos")] - Self::Launchd(launchd) => write!(f, "{}", launchd)?, + "launchd" => Ok(Self::Launchd(s.try_into()?)), + #[cfg(target_os = "linux")] + "systemd" => Ok(Self::Systemd(s.try_into()?)), + "tcp" => Ok(Self::Tcp(s.try_into()?)), + "nonce-tcp" => Ok(Self::NonceTcp(s.try_into()?)), + "unixexec" => Ok(Self::Unixexec(s.try_into()?)), + "autolaunch" => Ok(Self::Autolaunch(s.try_into()?)), + "vsock" => Ok(Self::Vsock(s.try_into()?)), + _ => Err(Error::UnknownTransport), } - - Ok(()) } } diff --git a/zbus/src/address/transport/nonce_tcp.rs b/zbus/src/address/transport/nonce_tcp.rs new file mode 100644 index 000000000..f0240f50d --- /dev/null +++ b/zbus/src/address/transport/nonce_tcp.rs @@ -0,0 +1,100 @@ +use std::{borrow::Cow, ffi::OsStr}; + +use super::{ + percent::{decode_percents_os_str, decode_percents_str, EncOsStr}, + tcp::TcpFamily, + Address, Error, KeyValFmt, KeyValFmtAdd, Result, +}; + +/// `nonce-tcp:` D-Bus transport. +/// +/// +#[derive(Debug, Default, PartialEq, Eq)] +pub struct NonceTcp<'a> { + host: Option>, + bind: Option>, + port: Option, + family: Option, + noncefile: Option>, +} + +impl<'a> NonceTcp<'a> { + /// If set, the DNS name or IP address. + pub fn host(&self) -> Option<&str> { + self.host.as_ref().map(|v| v.as_ref()) + } + + /// If set, the listenable address. + /// + /// Used in a listenable address to configure the interface on which the server will listen: + /// either the IP address of one of the local machine's interfaces (most commonly `127.0.0.1`), + /// or a DNS name that resolves to one of those IP addresses, or `*` to listen on all interfaces + /// simultaneously. + pub fn bind(&self) -> Option<&str> { + self.bind.as_ref().map(|v| v.as_ref()) + } + + /// If set, the TCP port. + /// + /// The TCP port the server will open. A zero value let the server choose a free port provided + /// from the underlying operating system. + pub fn port(&self) -> Option { + self.port + } + + /// If set, the type of socket family. + pub fn family(&self) -> Option { + self.family + } + + /// If set, the nonce file location. + /// + /// File location containing the secret. This is only meaningful in connectable addresses. + pub fn noncefile(&self) -> Option<&OsStr> { + self.noncefile.as_ref().map(|v| v.as_ref()) + } +} + +impl KeyValFmtAdd for NonceTcp<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv.add("host", self.host()) + .add("bind", self.bind()) + .add("port", self.port()) + .add("family", self.family()) + .add("noncefile", self.noncefile().map(EncOsStr)) + } +} + +impl<'a> TryFrom<&'a Address<'a>> for NonceTcp<'a> { + type Error = Error; + + fn try_from(s: &'a Address<'a>) -> Result { + let mut res = NonceTcp::default(); + for (k, v) in s.key_val_iter() { + match (k, v) { + ("host", Some(v)) => { + res.host = Some(decode_percents_str(v)?); + } + ("bind", Some(v)) => { + res.bind = Some(decode_percents_str(v)?); + } + ("port", Some(v)) => { + res.port = Some( + decode_percents_str(v)? + .parse() + .map_err(|_| Error::InvalidValue("port".into()))?, + ); + } + ("family", Some(v)) => { + res.family = Some(decode_percents_str(v)?.as_ref().try_into()?); + } + ("noncefile", Some(v)) => { + res.noncefile = Some(decode_percents_os_str(v)?); + } + _ => continue, + } + } + + Ok(res) + } +} diff --git a/zbus/src/address/transport/systemd.rs b/zbus/src/address/transport/systemd.rs new file mode 100644 index 000000000..5f6350f86 --- /dev/null +++ b/zbus/src/address/transport/systemd.rs @@ -0,0 +1,28 @@ +use std::marker::PhantomData; + +use super::{Address, Error, KeyValFmt, KeyValFmtAdd, Result}; + +/// `systemd:` D-Bus transport. +/// +/// +#[derive(Debug, PartialEq, Eq)] +pub struct Systemd<'a> { + // use a phantom lifetime for eventually future fields and consistency + phantom: PhantomData<&'a ()>, +} + +impl<'a> TryFrom<&'a Address<'a>> for Systemd<'a> { + type Error = Error; + + fn try_from(_s: &'a Address<'a>) -> Result { + Ok(Systemd { + phantom: PhantomData, + }) + } +} + +impl KeyValFmtAdd for Systemd<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv + } +} diff --git a/zbus/src/address/transport/tcp.rs b/zbus/src/address/transport/tcp.rs index c289fa666..2f1969e2d 100644 --- a/zbus/src/address/transport/tcp.rs +++ b/zbus/src/address/transport/tcp.rs @@ -1,229 +1,115 @@ -use super::encode_percents; -use crate::{Error, Result}; -#[cfg(not(feature = "tokio"))] -use async_io::Async; -#[cfg(not(feature = "tokio"))] -use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; -use std::{ - collections::HashMap, - fmt::{Display, Formatter}, - str::FromStr, -}; -#[cfg(feature = "tokio")] -use tokio::net::TcpStream; - -/// A TCP transport in a D-Bus address. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Tcp { - pub(super) host: String, - pub(super) bind: Option, - pub(super) port: u16, - pub(super) family: Option, - pub(super) nonce_file: Option>, +use std::{borrow::Cow, fmt}; + +use super::{percent::decode_percents_str, Address, Error, KeyValFmt, KeyValFmtAdd, Result}; + +/// TCP IP address family +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[non_exhaustive] +pub enum TcpFamily { + /// IPv4 + IPv4, + /// IPv6 + IPv6, } -impl Tcp { - /// Create a new TCP transport with the given host and port. - pub fn new(host: &str, port: u16) -> Self { - Self { - host: host.to_owned(), - port, - bind: None, - family: None, - nonce_file: None, +impl fmt::Display for TcpFamily { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::IPv4 => write!(f, "ipv4"), + Self::IPv6 => write!(f, "ipv6"), } } +} - /// Set the `tcp:` address `bind` value. - pub fn set_bind(mut self, bind: Option) -> Self { - self.bind = bind; - - self - } - - /// Set the `tcp:` address `family` value. - pub fn set_family(mut self, family: Option) -> Self { - self.family = family; +impl TryFrom<&str> for TcpFamily { + type Error = Error; - self + fn try_from(s: &str) -> Result { + match s { + "ipv4" => Ok(Self::IPv4), + "ipv6" => Ok(Self::IPv6), + _ => Err(Error::UnknownTcpFamily(s.into())), + } } +} - /// Set the `tcp:` address `noncefile` value. - pub fn set_nonce_file(mut self, nonce_file: Option>) -> Self { - self.nonce_file = nonce_file; - - self - } +/// `tcp:` D-Bus transport. +/// +/// +#[derive(Debug, Default, PartialEq, Eq)] +pub struct Tcp<'a> { + host: Option>, + bind: Option>, + port: Option, + family: Option, +} - /// The `tcp:` address `host` value. - pub fn host(&self) -> &str { - &self.host +impl<'a> Tcp<'a> { + /// If set, DNS name or IP address. + pub fn host(&self) -> Option<&str> { + self.host.as_ref().map(|v| v.as_ref()) } - /// The `tcp:` address `bind` value. + /// If set, the listenable address. + /// + /// Used in a listenable address to configure the interface on which the server will listen: + /// either the IP address of one of the local machine's interfaces (most commonly `127.0.0.1`), + /// or a DNS name that resolves to one of those IP addresses, or `*` to listen on all interfaces + /// simultaneously. pub fn bind(&self) -> Option<&str> { - self.bind.as_deref() + self.bind.as_ref().map(|v| v.as_ref()) } - /// The `tcp:` address `port` value. - pub fn port(&self) -> u16 { + /// If set, the TCP port. + /// + /// The TCP port the server will open. A zero value let the server choose a free port provided + /// from the underlying operating system. + pub fn port(&self) -> Option { self.port } - /// The `tcp:` address `family` value. - pub fn family(&self) -> Option { + /// If set, the type of socket family. + pub fn family(&self) -> Option { self.family } - - /// The nonce file path, if any. - pub fn nonce_file(&self) -> Option<&[u8]> { - self.nonce_file.as_deref() - } - - /// Take ownership of the nonce file path, if any. - pub fn take_nonce_file(&mut self) -> Option> { - self.nonce_file.take() - } - - pub(super) fn from_options( - opts: HashMap<&str, &str>, - nonce_tcp_required: bool, - ) -> Result { - let bind = None; - if opts.contains_key("bind") { - return Err(Error::Address("`bind` isn't yet supported".into())); - } - - let host = opts - .get("host") - .ok_or_else(|| Error::Address("tcp address is missing `host`".into()))? - .to_string(); - let port = opts - .get("port") - .ok_or_else(|| Error::Address("tcp address is missing `port`".into()))?; - let port = port - .parse::() - .map_err(|_| Error::Address("invalid tcp `port`".into()))?; - let family = opts - .get("family") - .map(|f| TcpTransportFamily::from_str(f)) - .transpose()?; - let nonce_file = opts - .get("noncefile") - .map(|f| super::decode_percents(f)) - .transpose()?; - if nonce_tcp_required && nonce_file.is_none() { - return Err(Error::Address( - "nonce-tcp address is missing `noncefile`".into(), - )); - } - - Ok(Self { - host, - bind, - port, - family, - nonce_file, - }) - } - - #[cfg(not(feature = "tokio"))] - pub(super) async fn connect(self) -> Result> { - let addrs = crate::Task::spawn_blocking( - move || -> Result> { - let addrs = (self.host(), self.port()).to_socket_addrs()?.filter(|a| { - if let Some(family) = self.family() { - if family == TcpTransportFamily::Ipv4 { - a.is_ipv4() - } else { - a.is_ipv6() - } - } else { - true - } - }); - Ok(addrs.collect()) - }, - "connect tcp", - ) - .await - .map_err(|e| Error::Address(format!("Failed to receive TCP addresses: {e}")))?; - - // we could attempt connections in parallel? - let mut last_err = Error::Address("Failed to connect".into()); - for addr in addrs { - match Async::::connect(addr).await { - Ok(stream) => return Ok(stream), - Err(e) => last_err = e.into(), - } - } - - Err(last_err) - } - - #[cfg(feature = "tokio")] - pub(super) async fn connect(self) -> Result { - TcpStream::connect((self.host(), self.port())) - .await - .map_err(|e| Error::InputOutput(e.into())) - } } -impl Display for Tcp { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self.nonce_file() { - Some(nonce_file) => { - f.write_str("nonce-tcp:noncefile=")?; - encode_percents(f, nonce_file)?; - f.write_str(",")?; - } - None => f.write_str("tcp:")?, - } - f.write_str("host=")?; - - encode_percents(f, self.host().as_bytes())?; - - write!(f, ",port={}", self.port())?; - - if let Some(bind) = self.bind() { - f.write_str(",bind=")?; - encode_percents(f, bind.as_bytes())?; - } - - if let Some(family) = self.family() { - write!(f, ",family={family}")?; - } - - Ok(()) +impl KeyValFmtAdd for Tcp<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv.add("host", self.host()) + .add("bind", self.bind()) + .add("port", self.port()) + .add("family", self.family()) } } -/// A `tcp:` address family. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum TcpTransportFamily { - Ipv4, - Ipv6, -} - -impl FromStr for TcpTransportFamily { - type Err = Error; - - fn from_str(family: &str) -> Result { - match family { - "ipv4" => Ok(Self::Ipv4), - "ipv6" => Ok(Self::Ipv6), - _ => Err(Error::Address(format!( - "invalid tcp address `family`: {family}" - ))), +impl<'a> TryFrom<&'a Address<'a>> for Tcp<'a> { + type Error = Error; + + fn try_from(s: &'a Address<'a>) -> Result { + let mut res = Tcp::default(); + for (k, v) in s.key_val_iter() { + match (k, v) { + ("host", Some(v)) => { + res.host = Some(decode_percents_str(v)?); + } + ("bind", Some(v)) => { + res.bind = Some(decode_percents_str(v)?); + } + ("port", Some(v)) => { + res.port = Some( + decode_percents_str(v)? + .parse() + .map_err(|_| Error::InvalidValue("port".into()))?, + ); + } + ("family", Some(v)) => { + res.family = Some(decode_percents_str(v)?.as_ref().try_into()?); + } + _ => continue, + } } - } -} -impl Display for TcpTransportFamily { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Ipv4 => write!(f, "ipv4"), - Self::Ipv6 => write!(f, "ipv6"), - } + Ok(res) } } diff --git a/zbus/src/address/transport/unix.rs b/zbus/src/address/transport/unix.rs index 9d6f2acf7..d105ce18b 100644 --- a/zbus/src/address/transport/unix.rs +++ b/zbus/src/address/transport/unix.rs @@ -1,129 +1,118 @@ -#[cfg(target_os = "linux")] -use std::ffi::OsString; -use std::{ - ffi::OsStr, - fmt::{Display, Formatter}, - path::PathBuf, -}; +use std::{borrow::Cow, ffi::OsStr}; -#[cfg(unix)] -use super::encode_percents; +use super::{ + percent::{decode_percents, decode_percents_os_str, decode_percents_str, EncData, EncOsStr}, + Address, Error, KeyValFmt, KeyValFmtAdd, Result, +}; -/// A Unix domain socket transport in a D-Bus address. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Unix { - path: UnixSocket, +/// A sub-type of `unix:` transport. +#[derive(Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum UnixAddrKind<'a> { + /// Path of the unix domain socket. + Path(Cow<'a, OsStr>), + /// Directory in which a socket file with a random file name starting with 'dbus-' should be + /// created by a server. + Dir(Cow<'a, OsStr>), + /// The same as "dir", except that on platforms with abstract sockets, a server may attempt to + /// create an abstract socket whose name starts with this directory instead of a path-based + /// socket. + Tmpdir(Cow<'a, OsStr>), + /// Unique string in the abstract namespace, often syntactically resembling a path but + /// unconnected to the filesystem namespace + Abstract(Cow<'a, [u8]>), + /// Listen on $XDG_RUNTIME_DIR/bus. + Runtime, } -impl Unix { - /// Create a new Unix transport with the given path. - pub fn new(path: UnixSocket) -> Self { - Self { path } - } - - /// The path. - pub fn path(&self) -> &UnixSocket { - &self.path - } - - /// Take the path, consuming `self`. - pub fn take_path(self) -> UnixSocket { - self.path - } - - pub(super) fn from_options(opts: std::collections::HashMap<&str, &str>) -> crate::Result { - let path = opts.get("path"); - let abs = opts.get("abstract"); - let dir = opts.get("dir"); - let tmpdir = opts.get("tmpdir"); - let path = match (path, abs, dir, tmpdir) { - (Some(p), None, None, None) => UnixSocket::File(PathBuf::from(p)), - #[cfg(target_os = "linux")] - (None, Some(p), None, None) => UnixSocket::Abstract(OsString::from(p)), - #[cfg(not(target_os = "linux"))] - (None, Some(_), None, None) => { - return Err(crate::Error::Address( - "abstract sockets currently Linux-only".to_owned(), - )); - } - (None, None, Some(p), None) => UnixSocket::Dir(PathBuf::from(p)), - (None, None, None, Some(p)) => UnixSocket::TmpDir(PathBuf::from(p)), - _ => { - return Err(crate::Error::Address("unix: address is invalid".to_owned())); - } - }; - - Ok(Self::new(path)) +impl KeyValFmtAdd for UnixAddrKind<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + match self { + UnixAddrKind::Path(p) => kv.add("path", Some(EncOsStr(p))), + UnixAddrKind::Dir(p) => kv.add("dir", Some(EncOsStr(p))), + UnixAddrKind::Tmpdir(p) => kv.add("tmpdir", Some(EncOsStr(p))), + UnixAddrKind::Abstract(p) => kv.add("abstract", Some(EncData(p))), + UnixAddrKind::Runtime => kv.add("runtime", Some("yes")), + } } } -impl Display for Unix { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "unix:{}", self.path) - } +/// `unix:` D-Bus transport. +/// +/// +#[derive(Debug, PartialEq, Eq)] +pub struct Unix<'a> { + kind: UnixAddrKind<'a>, } -/// A Unix domain socket path in a D-Bus address. -#[derive(Clone, Debug, PartialEq, Eq)] -#[non_exhaustive] -pub enum UnixSocket { - /// A path to a unix domain socket on the filesystem. - File(PathBuf), - /// An abstract unix domain socket name. - #[cfg(target_os = "linux")] - Abstract(OsString), - /// A listenable address using the specified path, in which a socket file with a random file - /// name starting with 'dbus-' will be created by the server. See [UNIX domain socket address] - /// reference documentation. - /// - /// This address is mostly relevant to server (typically bus broker) implementations. - /// - /// [UNIX domain socket address]: https://dbus.freedesktop.org/doc/dbus-specification.html#transports-unix-domain-sockets-addresses - Dir(PathBuf), - /// The same as UnixDir, except that on platforms with abstract sockets, the server may attempt - /// to create an abstract socket whose name starts with this directory instead of a path-based - /// socket. - /// - /// This address is mostly relevant to server (typically bus broker) implementations. - TmpDir(PathBuf), +impl<'a> Unix<'a> { + /// One of the various `unix:` addresses. + pub fn kind(&self) -> &UnixAddrKind<'a> { + &self.kind + } } -impl Display for UnixSocket { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - fn fmt_unix_path(f: &mut Formatter<'_>, path: &OsStr) -> std::fmt::Result { - #[cfg(unix)] - { - use std::os::unix::ffi::OsStrExt; +impl<'a> TryFrom<&'a Address<'a>> for Unix<'a> { + type Error = Error; - encode_percents(f, path.as_bytes())?; - } + fn try_from(s: &'a Address<'a>) -> Result { + let mut kind = None; + let mut iter = s.key_val_iter(); + for (k, v) in &mut iter { + match k { + "path" | "dir" | "tmpdir" => { + let v = v.ok_or_else(|| Error::MissingValue(k.into()))?; + let v = decode_percents_os_str(v)?; + kind = Some(match k { + "path" => UnixAddrKind::Path(v), + "dir" => UnixAddrKind::Dir(v), + "tmpdir" => UnixAddrKind::Tmpdir(v), + // can't happen, we matched those earlier + _ => panic!(), + }); - #[cfg(windows)] - write!(f, "{}", path.to_str().ok_or(std::fmt::Error)?)?; + break; + } + "abstract" => { + let v = v.ok_or_else(|| Error::MissingValue(k.into()))?; + let v = decode_percents(v)?; + kind = Some(UnixAddrKind::Abstract(v)); - Ok(()) - } + break; + } + "runtime" => { + let v = v.ok_or_else(|| Error::MissingValue(k.into()))?; + let v = decode_percents_str(v)?; + if v != "yes" { + return Err(Error::InvalidValue(k.into())); + } + kind = Some(UnixAddrKind::Runtime); - match self { - UnixSocket::File(path) => { - f.write_str("path=")?; - fmt_unix_path(f, path.as_os_str())?; + break; + } + _ => continue, } - #[cfg(target_os = "linux")] - UnixSocket::Abstract(name) => { - f.write_str("abstract=")?; - fmt_unix_path(f, name)?; - } - UnixSocket::Dir(path) => { - f.write_str("dir=")?; - fmt_unix_path(f, path.as_os_str())?; - } - UnixSocket::TmpDir(path) => { - f.write_str("tmpdir=")?; - fmt_unix_path(f, path.as_os_str())?; + } + let Some(kind) = kind else { + return Err(Error::Other( + "invalid `unix:` address, missing required key".into(), + )); + }; + for (k, _) in iter { + match k { + "path" | "dir" | "tmpdir" | "abstract" | "runtime" => { + return Err(Error::Other("invalid address, only one of `path` `dir` `tmpdir` `abstract` or `runtime` expected".into())); + } + _ => (), } } - Ok(()) + Ok(Unix { kind }) + } +} + +impl KeyValFmtAdd for Unix<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + self.kind().key_val_fmt_add(kv) } } diff --git a/zbus/src/address/transport/unixexec.rs b/zbus/src/address/transport/unixexec.rs new file mode 100644 index 000000000..a74137049 --- /dev/null +++ b/zbus/src/address/transport/unixexec.rs @@ -0,0 +1,86 @@ +use std::{borrow::Cow, ffi::OsStr, fmt}; + +use super::{ + percent::{decode_percents_os_str, decode_percents_str, EncOsStr}, + Address, Error, KeyValFmt, KeyValFmtAdd, Result, +}; + +#[derive(Debug, PartialEq, Eq)] +struct Argv(usize); + +impl fmt::Display for Argv { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let n = self.0; + + write!(f, "argv{n}") + } +} + +/// `unixexec:` D-Bus transport. +/// +/// +#[derive(Debug, PartialEq, Eq)] +pub struct Unixexec<'a> { + path: Cow<'a, OsStr>, + argv: Vec<(usize, Cow<'a, str>)>, +} + +impl<'a> Unixexec<'a> { + /// Binary to execute. + /// + /// Path of the binary to execute, either an absolute path or a binary name that is searched for + /// in the default search path of the OS. This corresponds to the first argument of execlp(). + /// This key is mandatory. + pub fn path(&self) -> &OsStr { + self.path.as_ref() + } + + /// Arguments. + /// + /// Arguments to pass to the binary as `[(nth, arg),...]`. + pub fn argv(&self) -> &[(usize, Cow<'a, str>)] { + self.argv.as_ref() + } +} + +impl<'a> TryFrom<&'a Address<'a>> for Unixexec<'a> { + type Error = Error; + + fn try_from(s: &'a Address<'a>) -> Result { + let mut path = None; + let mut argv = Vec::new(); + + for (k, v) in s.key_val_iter() { + match (k, v) { + ("path", Some(v)) => { + path = Some(decode_percents_os_str(v)?); + } + (k, Some(v)) if k.starts_with("argv") => { + let n: usize = k[4..].parse().map_err(|_| Error::InvalidValue(k.into()))?; + let arg = decode_percents_str(v)?; + argv.push((n, arg)); + } + _ => continue, + } + } + + let Some(path) = path else { + return Err(Error::MissingKey("path".into())); + }; + + argv.sort_by_key(|(num, _)| *num); + + Ok(Self { path, argv }) + } +} + +impl KeyValFmtAdd for Unixexec<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, mut kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv = kv.add("path", Some(EncOsStr(self.path()))); + for (n, arg) in self.argv() { + kv = kv.add(Argv(*n), Some(arg)); + } + + kv + } +} diff --git a/zbus/src/address/transport/vsock.rs b/zbus/src/address/transport/vsock.rs index f18c60df6..a160f00ba 100644 --- a/zbus/src/address/transport/vsock.rs +++ b/zbus/src/address/transport/vsock.rs @@ -1,49 +1,67 @@ -use crate::{Error, Result}; -use std::collections::HashMap; - -/// A VSOCK D-Bus address. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Vsock { - pub(super) cid: u32, - pub(super) port: u32, +use std::marker::PhantomData; + +use super::{percent::decode_percents_str, Address, Error, KeyValFmt, KeyValFmtAdd, Result}; + +/// `vsock:` D-Bus transport. +#[derive(Debug, PartialEq, Eq)] +pub struct Vsock<'a> { + // no cid means ANY + cid: Option, + // no port means ANY + port: Option, + // use a phantom lifetime for eventually future fields and consistency + phantom: PhantomData<&'a ()>, } -impl Vsock { - /// Create a new VSOCK address. - pub fn new(cid: u32, port: u32) -> Self { - Self { cid, port } +impl<'a> Vsock<'a> { + /// The VSOCK port. + pub fn port(&self) -> Option { + self.port } - /// The Client ID. - pub fn cid(&self) -> u32 { + /// The VSOCK CID. + pub fn cid(&self) -> Option { self.cid } +} - /// The port. - pub fn port(&self) -> u32 { - self.port - } +impl<'a> TryFrom<&'a Address<'a>> for Vsock<'a> { + type Error = Error; + + fn try_from(s: &'a Address<'a>) -> Result { + let mut port = None; + let mut cid = None; + + for (k, v) in s.key_val_iter() { + match (k, v) { + ("port", Some(v)) => { + port = Some( + decode_percents_str(v)? + .parse() + .map_err(|_| Error::InvalidValue(k.into()))?, + ); + } + ("cid", Some(v)) => { + cid = Some( + decode_percents_str(v)? + .parse() + .map_err(|_| Error::InvalidValue(k.into()))?, + ) + } + _ => continue, + } + } - pub(super) fn from_options(opts: HashMap<&str, &str>) -> Result { - let cid = opts - .get("cid") - .ok_or_else(|| Error::Address("VSOCK address is missing cid=".into()))?; - let cid = cid - .parse::() - .map_err(|e| Error::Address(format!("Failed to parse VSOCK cid `{}`: {}", cid, e)))?; - let port = opts - .get("port") - .ok_or_else(|| Error::Address("VSOCK address is missing port=".into()))?; - let port = port - .parse::() - .map_err(|e| Error::Address(format!("Failed to parse VSOCK port `{}`: {}", port, e)))?; - - Ok(Self { cid, port }) + Ok(Vsock { + port, + cid, + phantom: PhantomData, + }) } } -impl std::fmt::Display for Vsock { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "vsock:cid={},port={}", self.cid, self.port) +impl KeyValFmtAdd for Vsock<'_> { + fn key_val_fmt_add<'a: 'b, 'b>(&'a self, kv: KeyValFmt<'b>) -> KeyValFmt<'b> { + kv.add("cid", self.cid()).add("port", self.port()) } } diff --git a/zbus/src/blocking/connection/builder.rs b/zbus/src/blocking/connection/builder.rs index 17beb196b..b6d86d08f 100644 --- a/zbus/src/blocking/connection/builder.rs +++ b/zbus/src/blocking/connection/builder.rs @@ -15,8 +15,8 @@ use zvariant::{ObjectPath, Str}; #[cfg(feature = "p2p")] use crate::Guid; use crate::{ - address::Address, blocking::Connection, connection::socket::BoxedSplit, names::WellKnownName, - object_server::Interface, utils::block_on, AuthMechanism, Error, Result, + address::ToAddresses, blocking::Connection, connection::socket::BoxedSplit, + names::WellKnownName, object_server::Interface, utils::block_on, AuthMechanism, Error, Result, }; /// A builder for [`zbus::blocking::Connection`]. @@ -40,10 +40,9 @@ impl<'a> Builder<'a> { /// Create a builder for a connection that will use the given [D-Bus bus address]. /// /// [D-Bus bus address]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses - pub fn address(address: A) -> Result + pub fn address<'t, A>(address: &'t A) -> Result where - A: TryInto
, - A::Error: Into, + A: ToAddresses<'t> + ?Sized, { crate::connection::Builder::address(address).map(Self) } diff --git a/zbus/src/connection/builder.rs b/zbus/src/connection/builder.rs index fcf56257b..9e7083f44 100644 --- a/zbus/src/connection/builder.rs +++ b/zbus/src/connection/builder.rs @@ -24,13 +24,14 @@ use vsock::VsockStream; use zvariant::{ObjectPath, Str}; use crate::{ - address::{self, Address}, + address::{Address, ToAddresses}, names::{InterfaceName, WellKnownName}, object_server::{ArcInterface, Interface}, Connection, Error, Executor, Guid, OwnedGuid, Result, }; use super::{ + connect::connect_address, handshake::{AuthMechanism, Authenticated}, socket::{BoxedSplit, ReadHalf, Split, WriteHalf}, }; @@ -47,7 +48,7 @@ enum Target { feature = "tokio-vsock" ))] VsockStream(VsockStream), - Address(Address), + Address(Vec>), Socket(Split, Box>), AuthenticatedSocket(Split, Box>), } @@ -79,12 +80,12 @@ assert_impl_all!(Builder<'_>: Send, Sync, Unpin); impl<'a> Builder<'a> { /// Create a builder for the session/user message bus connection. pub fn session() -> Result { - Ok(Self::new(Target::Address(Address::session()?))) + Self::address(&crate::address::session()?) } /// Create a builder for the system-wide message bus connection. pub fn system() -> Result { - Ok(Self::new(Target::Address(Address::system()?))) + Self::address(&crate::address::system()?) } /// Create a builder for a connection that will use the given [D-Bus bus address]. @@ -118,14 +119,17 @@ impl<'a> Builder<'a> { /// current session using `ibus address` command. /// /// [D-Bus bus address]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses - pub fn address(address: A) -> Result + pub fn address<'t, A>(address: &'t A) -> Result where - A: TryInto
, - A::Error: Into, + A: ToAddresses<'t> + ?Sized, { - Ok(Self::new(Target::Address( - address.try_into().map_err(Into::into)?, - ))) + let addr = address + .to_addresses() + .filter_map(std::result::Result::ok) + .map(|a| a.to_owned()) + .collect(); + + Ok(Builder::new(Target::Address(addr))) } /// Create a builder for a connection that will use the given unix stream. @@ -533,17 +537,9 @@ impl<'a> Builder<'a> { #[cfg(feature = "tokio-vsock")] Target::VsockStream(stream) => stream.into(), Target::Address(address) => { - guid = address.guid().map(|g| g.to_owned().into()); - match address.connect().await? { - #[cfg(any(unix, not(feature = "tokio")))] - address::transport::Stream::Unix(stream) => stream.into(), - address::transport::Stream::Tcp(stream) => stream.into(), - #[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" - ))] - address::transport::Stream::Vsock(stream) => stream.into(), - } + return connect_address(&address) + .await + .map(|(split, guid)| (split, guid, false)); } Target::Socket(stream) => stream, Target::AuthenticatedSocket(stream) => { diff --git a/zbus/src/connection/connect/macos.rs b/zbus/src/connection/connect/macos.rs new file mode 100644 index 000000000..21db9b20c --- /dev/null +++ b/zbus/src/connection/connect/macos.rs @@ -0,0 +1,54 @@ +#![cfg(target_os = "macos")] + +use super::socket; +use crate::{ + address::{transport::Transport, Address}, + process::run, + Error, Result, +}; + +async fn launchd_bus_address(env_key: &str) -> Result> { + let output = run("launchctl", ["getenv", env_key]) + .await + .expect("failed to wait on launchctl output"); + + if !output.status.success() { + return Err(Error::Address(format!( + "launchctl terminated with code: {}", + output.status + ))); + } + + let addr = String::from_utf8(output.stdout) + .map_err(|e| Error::Address(format!("Unable to parse launchctl output as UTF-8: {}", e)))?; + + Ok(format!("unix:path={}", addr.trim()).try_into()?) +} + +pub(crate) async fn connect( + l: &crate::address::transport::Launchd<'_>, +) -> Result { + let addr = launchd_bus_address(l.env()).await?; + + match addr.transport()? { + Transport::Unix(t) => socket::unix::connect(&t).await, + _ => Err(Error::Address(format!("Address is unsupported: {}", addr))), + } +} + +#[cfg(test)] +mod tests { + use crate::address::{transport::Transport, Address}; + + #[test] + fn connect_launchd_session_bus() { + let addr: Address<'_> = "launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET" + .try_into() + .unwrap(); + let launchd = match addr.transport().unwrap() { + Transport::Launchd(l) => l, + _ => unreachable!(), + }; + crate::utils::block_on(super::connect(&launchd)).unwrap(); + } +} diff --git a/zbus/src/connection/connect/mod.rs b/zbus/src/connection/connect/mod.rs new file mode 100644 index 000000000..e137ba751 --- /dev/null +++ b/zbus/src/connection/connect/mod.rs @@ -0,0 +1,62 @@ +use std::{future::Future, pin::Pin}; +use tracing::debug; + +use crate::{ + address::{transport::Transport, Address}, + Error, Guid, OwnedGuid, Result, +}; + +use super::socket::{self, BoxedSplit}; + +mod macos; +mod win32; + +type ConnectResult = Result<(BoxedSplit, Option)>; + +fn connect(addr: &Address<'_>) -> Pin>> { + let addr = addr.to_owned(); + Box::pin(async move { + let guid = match addr.guid() { + Some(g) => Some(Guid::try_from(g.as_ref())?.into()), + _ => None, + }; + let split = match addr.transport()? { + Transport::Tcp(t) => socket::tcp::connect(&t).await?.into(), + Transport::NonceTcp(t) => socket::tcp::connect_nonce(&t).await?.into(), + #[cfg(any(unix, not(feature = "tokio")))] + Transport::Unix(u) => socket::unix::connect(&u).await?.into(), + #[cfg(any( + all(feature = "vsock", not(feature = "tokio")), + feature = "tokio-vsock" + ))] + Transport::Vsock(v) => socket::vsock::connect(&v).await?.into(), + #[cfg(target_os = "macos")] + Transport::Launchd(l) => macos::connect(&l).await?.into(), + #[cfg(target_os = "windows")] + Transport::Autolaunch(l) => { + return win32::connect(&l).await; + } + _ => { + return Err(Error::Address(format!("Unhandled address: {}", addr))); + } + }; + Ok((split, guid)) + }) +} + +pub(crate) async fn connect_address( + address: &[Address<'_>], +) -> Result<(BoxedSplit, Option)> { + for addr in address { + match connect(addr).await { + Ok(res) => { + return Ok(res); + } + Err(e) => { + debug!("Failed to connect to: {}", e); + continue; + } + } + } + Err(Error::Address("No connectable address".into())) +} diff --git a/zbus/src/connection/connect/win32.rs b/zbus/src/connection/connect/win32.rs new file mode 100644 index 000000000..481e4098a --- /dev/null +++ b/zbus/src/connection/connect/win32.rs @@ -0,0 +1,41 @@ +#![cfg(target_os = "windows")] + +use super::BoxedSplit; +use crate::{ + address::{transport::Transport, Address}, + win32::autolaunch_bus_address, + Error, OwnedGuid, Result, +}; + +pub(crate) async fn connect( + l: &crate::address::transport::Autolaunch<'_>, +) -> Result<(BoxedSplit, Option)> { + if l.scope().is_some() { + return Err(Error::Address( + "autolaunch with scope isn't supported yet".into(), + )); + } + + let addr: Address<'_> = autolaunch_bus_address()?.try_into()?; + + if let Transport::Autolaunch(_) = addr.transport()? { + return Err(Error::Address("Recursive autolaunch: address".into())); + } + + super::connect(&addr).await +} + +#[cfg(test)] +mod tests { + #[test] + fn connect_autolaunch_session_bus() { + use crate::address::{transport::Transport, Address}; + + let addr: Address<'_> = "autolaunch:".try_into().unwrap(); + let autolaunch = match addr.transport().unwrap() { + Transport::Autolaunch(l) => l, + _ => unreachable!(), + }; + crate::utils::block_on(super::connect(&autolaunch)).unwrap(); + } +} diff --git a/zbus/src/connection/mod.rs b/zbus/src/connection/mod.rs index b1e45ca58..69fff600c 100644 --- a/zbus/src/connection/mod.rs +++ b/zbus/src/connection/mod.rs @@ -43,6 +43,8 @@ use socket_reader::SocketReader; pub(crate) mod handshake; use handshake::Authenticated; +mod connect; + const DEFAULT_MAX_QUEUED: usize = 64; const DEFAULT_MAX_METHOD_RETURN_QUEUED: usize = 8; @@ -1374,28 +1376,6 @@ mod tests { use std::{pin::pin, time::Duration}; use test_log::test; - #[cfg(windows)] - #[test] - fn connect_autolaunch_session_bus() { - let addr = - crate::win32::autolaunch_bus_address().expect("Unable to get session bus address"); - - crate::block_on(async { addr.connect().await }).expect("Unable to connect to session bus"); - } - - #[cfg(target_os = "macos")] - #[test] - fn connect_launchd_session_bus() { - use crate::address::{transport::Launchd, Address, Transport}; - crate::block_on(async { - let addr = Address::from(Transport::Launchd(Launchd::new( - "DBUS_LAUNCHD_SESSION_BUS_SOCKET", - ))); - addr.connect().await - }) - .expect("Unable to connect to session bus"); - } - #[test] #[timeout(15000)] fn disconnect_on_drop() { diff --git a/zbus/src/connection/socket/mod.rs b/zbus/src/connection/socket/mod.rs index d0a17eac6..9d9062c0a 100644 --- a/zbus/src/connection/socket/mod.rs +++ b/zbus/src/connection/socket/mod.rs @@ -6,9 +6,9 @@ pub use channel::Channel; mod split; pub use split::{BoxedSplit, Split}; -mod tcp; -mod unix; -mod vsock; +pub(crate) mod tcp; +pub(crate) mod unix; +pub(crate) mod vsock; #[cfg(not(feature = "tokio"))] use async_io::Async; diff --git a/zbus/src/connection/socket/tcp.rs b/zbus/src/connection/socket/tcp.rs index 144bd2eae..d66db8108 100644 --- a/zbus/src/connection/socket/tcp.rs +++ b/zbus/src/connection/socket/tcp.rs @@ -6,6 +6,8 @@ use std::os::fd::BorrowedFd; #[cfg(not(feature = "tokio"))] use std::{net::TcpStream, sync::Arc}; +use crate::{address::transport::TcpFamily, Error, Result}; + use super::{ReadHalf, RecvmsgResult, WriteHalf}; #[cfg(feature = "tokio")] use super::{Socket, Split}; @@ -170,3 +172,183 @@ fn win32_credentials_from_addr( .set_process_id(pid) .set_windows_sid(sid)) } + +#[cfg(not(feature = "tokio"))] +type Stream = Async; +#[cfg(feature = "tokio")] +type Stream = tokio::net::TcpStream; + +async fn connect_with(host: &str, port: u16, family: Option) -> Result { + #[cfg(not(feature = "tokio"))] + { + use std::net::ToSocketAddrs; + + let host = host.to_string(); + let addrs = crate::Task::spawn_blocking( + move || -> Result> { + let addrs = (host, port).to_socket_addrs()?.filter(|a| { + if let Some(family) = family { + if family == TcpFamily::IPv4 { + a.is_ipv4() + } else { + a.is_ipv6() + } + } else { + true + } + }); + Ok(addrs.collect()) + }, + "connect tcp", + ) + .await + .map_err(|e| Error::Address(format!("Failed to receive TCP addresses: {e}")))?; + + // we could attempt connections in parallel? + let mut last_err = Error::Address("Failed to connect".into()); + for addr in addrs { + match Stream::connect(addr).await { + Ok(stream) => return Ok(stream), + Err(e) => last_err = e.into(), + } + } + + Err(last_err) + } + + #[cfg(feature = "tokio")] + { + // FIXME: doesn't handle family + let _ = family; + Stream::connect((host, port)) + .await + .map_err(|e| Error::InputOutput(e.into())) + } +} + +pub(crate) async fn connect(addr: &crate::address::transport::Tcp<'_>) -> Result { + let Some(host) = addr.host() else { + return Err(Error::Address("No host in address".into())); + }; + let Some(port) = addr.port() else { + return Err(Error::Address("No port in address".into())); + }; + + connect_with(host, port, addr.family()).await +} + +pub(crate) async fn connect_nonce( + addr: &crate::address::transport::NonceTcp<'_>, +) -> Result { + let Some(host) = addr.host() else { + return Err(Error::Address("No host in address".into())); + }; + let Some(port) = addr.port() else { + return Err(Error::Address("No port in address".into())); + }; + let Some(noncefile) = addr.noncefile() else { + return Err(Error::Address("No noncefile in address".into())); + }; + + #[allow(unused_mut)] + let mut stream = connect_with(host, port, addr.family()).await?; + + #[cfg(not(feature = "tokio"))] + { + use std::io::prelude::*; + + let nonce = std::fs::read(noncefile)?; + let mut nonce = &nonce[..]; + + while !nonce.is_empty() { + let len = stream.write_with(|mut s| s.write(nonce)).await?; + nonce = &nonce[len..]; + } + } + + #[cfg(feature = "tokio")] + { + let nonce = tokio::fs::read(noncefile).await?; + tokio::io::AsyncWriteExt::write_all(&mut stream, &nonce).await?; + } + + Ok(stream) +} + +#[cfg(test)] +mod tests { + use crate::address::{transport::Transport, Address}; + + #[test] + fn connect() { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + let addr: Address<'_> = format!("tcp:host=localhost,port={port}") + .try_into() + .unwrap(); + let tcp = match addr.transport().unwrap() { + Transport::Tcp(tcp) => tcp, + _ => unreachable!(), + }; + crate::utils::block_on(super::connect(&tcp)).unwrap(); + } + + #[test] + fn connect_nonce_tcp() { + struct PercentEncoded<'a>(&'a [u8]); + + impl std::fmt::Display for PercentEncoded<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + crate::address::encode_percents(f, self.0) + } + } + + use std::io::Write; + + const TEST_COOKIE: &[u8] = b"VERILY SECRETIVE"; + + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + + let mut cookie = tempfile::NamedTempFile::new().unwrap(); + cookie.as_file_mut().write_all(TEST_COOKIE).unwrap(); + + let encoded_path = format!( + "{}", + PercentEncoded(cookie.path().to_str().unwrap().as_ref()) + ); + + let addr: Address<'_> = + format!("nonce-tcp:host=localhost,port={port},noncefile={encoded_path}") + .try_into() + .unwrap(); + let tcp = match addr.transport().unwrap() { + Transport::NonceTcp(tcp) => tcp, + _ => unreachable!(), + }; + + let (sender, receiver) = std::sync::mpsc::sync_channel(1); + + std::thread::spawn(move || { + use std::io::Read; + + let mut client = listener.incoming().next().unwrap().unwrap(); + + let mut buf = [0u8; 16]; + client.read_exact(&mut buf).unwrap(); + + sender.send(buf == TEST_COOKIE).unwrap(); + }); + + crate::utils::block_on(super::connect_nonce(&tcp)).unwrap(); + + let saw_cookie = receiver + .recv_timeout(std::time::Duration::from_millis(100)) + .expect("nonce file content hasn't been received by server thread in time"); + + assert!( + saw_cookie, + "nonce file content has been received, but was invalid" + ); + } +} diff --git a/zbus/src/connection/socket/unix.rs b/zbus/src/connection/socket/unix.rs index f93329767..72edff430 100644 --- a/zbus/src/connection/socket/unix.rs +++ b/zbus/src/connection/socket/unix.rs @@ -2,7 +2,7 @@ use async_io::Async; #[cfg(unix)] use std::os::unix::io::{AsRawFd, BorrowedFd, FromRawFd, RawFd}; -#[cfg(all(unix, not(feature = "tokio")))] +#[cfg(unix)] use std::os::unix::net::UnixStream; #[cfg(not(feature = "tokio"))] use std::sync::Arc; @@ -24,6 +24,8 @@ use nix::{ #[cfg(unix)] use crate::utils::FDS_MAX; +#[cfg(any(unix, not(feature = "tokio")))] +use crate::{Error, Result}; #[cfg(all(unix, not(feature = "tokio")))] #[async_trait::async_trait] @@ -378,3 +380,64 @@ fn send_zero_byte_blocking(fd: RawFd) -> io::Result { ) .map_err(|e| e.into()) } + +#[cfg(not(feature = "tokio"))] +pub(crate) type Stream = Async; +#[cfg(all(unix, feature = "tokio"))] +pub(crate) type Stream = tokio::net::UnixStream; + +#[cfg(any(unix, not(feature = "tokio")))] +pub(crate) async fn connect(addr: &crate::address::transport::Unix<'_>) -> Result { + use crate::address::transport::UnixAddrKind; + #[cfg(target_os = "linux")] + use std::os::linux::net::SocketAddrExt; + #[cfg(unix)] + use std::os::unix::net::SocketAddr; + + let kind = addr.kind(); + + // This is a `path` in case of Windows until uds_windows provides the needed API: + // https://github.com/haraldh/rust_uds_windows/issues/14 + let addr = match kind { + #[cfg(unix)] + UnixAddrKind::Path(p) => SocketAddr::from_pathname(std::path::Path::new(p))?, + #[cfg(windows)] + UnixAddrKind::Path(p) => p.clone().into_owned(), + #[cfg(target_os = "linux")] + UnixAddrKind::Abstract(name) => SocketAddr::from_abstract_name(name)?, + _ => return Err(Error::Address("Address is not connectable".into())), + }; + + let stream = crate::Task::spawn_blocking( + move || -> Result<_> { + #[cfg(unix)] + let stream = UnixStream::connect_addr(&addr)?; + #[cfg(windows)] + let stream = UnixStream::connect(addr)?; + stream.set_nonblocking(true)?; + + Ok(stream) + }, + "unix stream connection", + ) + .await?; + + #[cfg(not(feature = "tokio"))] + { + Async::new(stream).map_err(|e| Error::InputOutput(e.into())) + } + + #[cfg(feature = "tokio")] + { + #[cfg(unix)] + { + tokio::net::UnixStream::from_std(stream).map_err(|e| Error::InputOutput(e.into())) + } + + #[cfg(not(unix))] + { + let _ = stream; + Err(Error::Unsupported) + } + } +} diff --git a/zbus/src/connection/socket/vsock.rs b/zbus/src/connection/socket/vsock.rs index ec26e487b..4889167c2 100644 --- a/zbus/src/connection/socket/vsock.rs +++ b/zbus/src/connection/socket/vsock.rs @@ -1,3 +1,13 @@ +#[cfg(all(feature = "vsock", not(feature = "tokio")))] +#[cfg(not(feature = "tokio"))] +use async_io::Async; + +#[cfg(any( + all(feature = "vsock", not(feature = "tokio")), + feature = "tokio-vsock" +))] +use crate::{Error, Result}; + #[cfg(feature = "tokio-vsock")] use super::{Socket, Split}; @@ -104,3 +114,37 @@ impl super::WriteHalf for tokio_vsock::WriteHalf { tokio::io::AsyncWriteExt::shutdown(self).await } } + +#[cfg(all(feature = "vsock", not(feature = "tokio")))] +type Stream = Async; +#[cfg(feature = "tokio-vsock")] +type Stream = tokio_vsock::VsockStream; + +#[cfg(any( + all(feature = "vsock", not(feature = "tokio")), + feature = "tokio-vsock" +))] +pub(crate) async fn connect(addr: &crate::address::transport::Vsock<'_>) -> Result { + let Some(cid) = addr.cid() else { + return Err(Error::Address("No cid in address".into())); + }; + let Some(port) = addr.port() else { + return Err(Error::Address("No port in address".into())); + }; + + #[cfg(all(feature = "vsock", not(feature = "tokio")))] + { + let stream = crate::Task::spawn_blocking( + move || vsock::VsockStream::connect_with_cid_port(cid, port), + "connect vsock", + ) + .await + .map_err(|e| Error::Address(format!("Failed to connect: {e}")))?; + Ok(Async::new(stream).map_err(|e| Error::InputOutput(e.into()))?) + } + + #[cfg(feature = "tokio-vsock")] + Stream::connect(cid, port) + .await + .map_err(|e| Error::InputOutput(e.into())) +} diff --git a/zbus/src/error.rs b/zbus/src/error.rs index 2980a492b..4adb4ed30 100644 --- a/zbus/src/error.rs +++ b/zbus/src/error.rs @@ -199,6 +199,12 @@ impl From for Error { } } +impl From for Error { + fn from(val: crate::address::Error) -> Self { + Error::Address(val.to_string()) + } +} + impl From for Error { fn from(val: VariantError) -> Self { Error::Variant(val) diff --git a/zbus/src/win32.rs b/zbus/src/win32.rs index 664fe1071..2d4d47653 100644 --- a/zbus/src/win32.rs +++ b/zbus/src/win32.rs @@ -30,7 +30,6 @@ use windows_sys::Win32::{ }, }; -use crate::Address; #[cfg(not(feature = "tokio"))] use uds_windows::UnixStream; @@ -305,7 +304,7 @@ fn read_shm(name: &str) -> Result, crate::Error> { Ok(data.to_bytes().to_owned()) } -pub fn autolaunch_bus_address() -> Result { +pub fn autolaunch_bus_address() -> Result { let mutex = Mutex::new("DBusAutolaunchMutex")?; let _guard = mutex.lock(); @@ -313,7 +312,7 @@ pub fn autolaunch_bus_address() -> Result { let addr = String::from_utf8(addr) .map_err(|e| crate::Error::Address(format!("Unable to parse address as UTF-8: {}", e)))?; - addr.parse() + Ok(addr) } #[cfg(test)]