From 7c54454407c15f218f6ea116ee9c3115d16e551b Mon Sep 17 00:00:00 2001 From: sodiboo Date: Wed, 3 Apr 2024 00:48:58 +0200 Subject: [PATCH 01/10] implement version checking; streamed IPC streamed IPC will allow multiple requests per connection --- Cargo.lock | 1 + Cargo.toml | 3 +- niri-ipc/Cargo.toml | 1 + niri-ipc/src/lib.rs | 9 +++- niri-ipc/src/socket.rs | 73 ++++++++++++++++++++++++++ src/cli.rs | 2 + src/ipc/client.rs | 114 ++++++++++++++++++++++++++++++----------- src/ipc/server.rs | 32 +++++++----- 8 files changed, 188 insertions(+), 47 deletions(-) create mode 100644 niri-ipc/src/socket.rs diff --git a/Cargo.lock b/Cargo.lock index 8292783edb..76a82a6e87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2202,6 +2202,7 @@ version = "0.1.4" dependencies = [ "clap", "serde", + "serde_json", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 4997c42084..7f6dfa323f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ anyhow = "1.0.81" bitflags = "2.5.0" clap = { version = "~4.4.18", features = ["derive"] } serde = { version = "1.0.197", features = ["derive"] } +serde_json = "1.0.115" tracing = { version = "0.1.40", features = ["max_level_trace", "release_max_level_debug"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } tracy-client = { version = "0.17.0", default-features = false } @@ -67,7 +68,7 @@ portable-atomic = { version = "1.6.0", default-features = false, features = ["fl profiling = "1.0.15" sd-notify = "0.4.1" serde.workspace = true -serde_json = "1.0.115" +serde_json.workspace = true smithay-drm-extras.workspace = true tracing-subscriber.workspace = true tracing.workspace = true diff --git a/niri-ipc/Cargo.toml b/niri-ipc/Cargo.toml index 9bd2e83387..e0963116ee 100644 --- a/niri-ipc/Cargo.toml +++ b/niri-ipc/Cargo.toml @@ -10,6 +10,7 @@ repository.workspace = true [dependencies] clap = { workspace = true, optional = true } serde.workspace = true +serde_json.workspace = true [features] clap = ["dep:clap"] diff --git a/niri-ipc/src/lib.rs b/niri-ipc/src/lib.rs index fb515e5de2..7b3a2ff0e0 100644 --- a/niri-ipc/src/lib.rs +++ b/niri-ipc/src/lib.rs @@ -6,12 +6,15 @@ use std::str::FromStr; use serde::{Deserialize, Serialize}; -/// Name of the environment variable containing the niri IPC socket path. -pub const SOCKET_PATH_ENV: &str = "NIRI_SOCKET"; +mod socket; + +pub use socket::{NiriSocket, SOCKET_PATH_ENV}; /// Request from client to niri. #[derive(Debug, Serialize, Deserialize, Clone)] pub enum Request { + /// Request the version string for the running niri instance. + Version, /// Request information about connected outputs. Outputs, /// Request information about the focused window. @@ -35,6 +38,8 @@ pub type Reply = Result; pub enum Response { /// A request that does not need a response was handled successfully. Handled, + /// The version string for the running niri instance. + Version(String), /// Information about connected outputs. /// /// Map from connector name to output info. diff --git a/niri-ipc/src/socket.rs b/niri-ipc/src/socket.rs new file mode 100644 index 0000000000..9bbdbf0609 --- /dev/null +++ b/niri-ipc/src/socket.rs @@ -0,0 +1,73 @@ +use std::io::{self, Write}; +use std::os::unix::net::UnixStream; +use std::path::Path; + +use serde_json::de::IoRead; +use serde_json::StreamDeserializer; + +use crate::{Reply, Request}; + +/// Name of the environment variable containing the niri IPC socket path. +pub const SOCKET_PATH_ENV: &str = "NIRI_SOCKET"; + +/// A client for the niri IPC server. +/// +/// This struct is used to communicate with the niri IPC server. It handles the socket connection +/// and serialization/deserialization of messages. +pub struct NiriSocket { + stream: UnixStream, + responses: StreamDeserializer<'static, IoRead, Reply>, +} + +impl TryFrom for NiriSocket { + type Error = io::Error; + fn try_from(stream: UnixStream) -> io::Result { + let responses = serde_json::Deserializer::from_reader(stream.try_clone()?).into_iter(); + Ok(Self { stream, responses }) + } +} + +impl NiriSocket { + /// Connects to the default niri IPC socket + /// + /// This is equivalent to calling [Self::connect] with the value of the [SOCKET_PATH_ENV] + /// environment variable. + pub fn new() -> io::Result { + let socket_path = std::env::var_os(SOCKET_PATH_ENV).ok_or_else(|| { + io::Error::new( + io::ErrorKind::NotFound, + format!("{SOCKET_PATH_ENV} is not set, are you running this within niri?"), + ) + })?; + Self::connect(socket_path) + } + + /// Connect to the socket at the given path + /// + /// See also: [UnixStream::connect] + pub fn connect(path: impl AsRef) -> io::Result { + Self::try_from(UnixStream::connect(path.as_ref())?) + } + + /// Handle a request to the niri IPC server + /// + /// # Returns + /// Ok(Ok([Response](crate::Response))) corresponds to a successful response from the running + /// niri instance. Ok(Err([String])) corresponds to an error received from the running niri + /// instance. Err([std::io::Error]) corresponds to an error in the IPC communication. + pub fn send(&mut self, request: Request) -> io::Result { + let mut buf = serde_json::to_vec(&request).unwrap(); + writeln!(buf).unwrap(); + self.stream.write_all(&buf)?; // .context("error writing IPC request")?; + self.stream.flush()?; + + if let Some(next) = self.responses.next() { + next.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + } else { + Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "no response from server", + )) + } + } +} diff --git a/src/cli.rs b/src/cli.rs index 1a1b3eb99a..2678a536c0 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -52,6 +52,8 @@ pub enum Sub { #[derive(Subcommand)] pub enum Msg { + /// Print the version string of the running niri instance. + Version, /// List connected outputs. Outputs, /// Print information about the focused window. diff --git a/src/ipc/client.rs b/src/ipc/client.rs index 97413d0b19..2cda4a9f35 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -1,54 +1,106 @@ -use std::env; -use std::io::{Read, Write}; -use std::net::Shutdown; -use std::os::unix::net::UnixStream; - use anyhow::{anyhow, bail, Context}; -use niri_ipc::{LogicalOutput, Mode, Output, Reply, Request, Response}; +use niri_ipc::{LogicalOutput, Mode, NiriSocket, Output, Request, Response}; use crate::cli::Msg; +use crate::utils::version; pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { - let socket_path = env::var_os(niri_ipc::SOCKET_PATH_ENV).with_context(|| { - format!( - "{} is not set, are you running this within niri?", - niri_ipc::SOCKET_PATH_ENV - ) - })?; + let mut client = NiriSocket::new().context("error initializing the niri ipc client")?; - let mut stream = - UnixStream::connect(socket_path).context("error connecting to {socket_path}")?; + // Default SIGPIPE so that our prints don't panic on stdout closing. + unsafe { + libc::signal(libc::SIGPIPE, libc::SIG_DFL); + } let request = match &msg { + Msg::Version => Request::Version, Msg::Outputs => Request::Outputs, Msg::FocusedWindow => Request::FocusedWindow, Msg::Action { action } => Request::Action(action.clone()), }; - let mut buf = serde_json::to_vec(&request).unwrap(); - stream - .write_all(&buf) - .context("error writing IPC request")?; - stream - .shutdown(Shutdown::Write) - .context("error closing IPC stream for writing")?; - buf.clear(); - stream - .read_to_end(&mut buf) - .context("error reading IPC response")?; + let version_reply = client + .send(Request::Version) + .context("error sending version request to niri")?; - let reply: Reply = serde_json::from_slice(&buf).context("error parsing IPC reply")?; + match version_reply.clone() { + Ok(response) => 'a: { + if matches!(msg, Msg::Version) && !json { + // Print a nicer warning for human consumers. + break 'a; + } + let Response::Version(server_version) = response else { + bail!("unexpected response: expected Version, got {response:?}"); + }; + + let my_version = version(); + + if my_version != server_version { + eprintln!("Warning: niri msg was invoked with a different version of niri than the running compositor."); + eprintln!("niri msg: {my_version}"); + eprintln!("compositor: {server_version}"); + eprintln!("Did you forget to restart niri after an update?"); + eprintln!(); + } + } + Err(_) => { + eprintln!("Warning: unable to get server version."); + eprintln!("Did you forget to restart niri after an update?"); + eprintln!(); + } + } + + let reply = match msg { + Msg::Version => version_reply, + _ => { + if version_reply.is_err() { + eprintln!("Assuming niri does not support streaming IPC. Reconnecting..."); + eprintln!(); + client = NiriSocket::new().context("error initializing the niri ipc client")?; + } + + client + .send(request) + .context("error sending request to niri")? + } + }; let response = reply .map_err(|msg| anyhow!(msg)) .context("niri could not handle the request")?; - // Default SIGPIPE so that our prints don't panic on stdout closing. - unsafe { - libc::signal(libc::SIGPIPE, libc::SIG_DFL); - } - match msg { + Msg::Version => { + let Response::Version(server_version) = response else { + bail!("unexpected response: expected Version, got {response:?}"); + }; + + if json { + #[derive(serde::Serialize)] + struct Versions { + cli: String, + compositor: String, + } + let version = serde_json::to_string(&Versions { + cli: version(), + compositor: server_version, + }) + .context("error formatting response")?; + println!("{version}"); + return Ok(()); + } + + let client_version = version(); + + println!("niri msg is {client_version}"); + println!("the compositor is {server_version}"); + if client_version != server_version { + eprintln!(); + eprintln!("These are different"); + eprintln!("Did you forget to restart niri after an update?"); + } + println!(); + } Msg::Outputs => { let Response::Outputs(outputs) = response else { bail!("unexpected response: expected Outputs, got {response:?}"); diff --git a/src/ipc/server.rs b/src/ipc/server.rs index 38a9b3dd39..9fb1613f7f 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -1,3 +1,4 @@ +use std::io::Write; use std::os::unix::net::{UnixListener, UnixStream}; use std::path::PathBuf; use std::sync::{Arc, Mutex}; @@ -7,7 +8,7 @@ use anyhow::Context; use calloop::io::Async; use directories::BaseDirs; use futures_util::io::{AsyncReadExt, BufReader}; -use futures_util::{AsyncBufReadExt, AsyncWriteExt}; +use futures_util::{AsyncBufReadExt, AsyncWriteExt, StreamExt}; use niri_ipc::{Request, Response}; use smithay::desktop::Window; use smithay::reexports::calloop::generic::Generic; @@ -18,6 +19,7 @@ use smithay::wayland::shell::xdg::XdgToplevelSurfaceData; use crate::backend::IpcOutputMap; use crate::niri::State; +use crate::utils::version; pub struct IpcServer { pub socket_path: PathBuf, @@ -106,21 +108,24 @@ fn on_new_ipc_client(state: &mut State, stream: UnixStream) { async fn handle_client(ctx: ClientCtx, stream: Async<'_, UnixStream>) -> anyhow::Result<()> { let (read, mut write) = stream.split(); - let mut buf = String::new(); - // Read a single line to allow extensibility in the future to keep reading. - BufReader::new(read) - .read_line(&mut buf) - .await - .context("error reading request")?; + // note that we can't use the stream json deserializer here + // because the stream is asynchronous and the deserializer doesn't support that + // https://github.com/serde-rs/json/issues/575 - let reply = process(&ctx, &buf).map_err(|err| { - warn!("error processing IPC request: {err:?}"); - err.to_string() - }); + let mut lines = BufReader::new(read).lines(); - let buf = serde_json::to_vec(&reply).context("error formatting reply")?; - write.write_all(&buf).await.context("error writing reply")?; + while let Some(line) = lines.next().await { + let reply = process(&ctx, &line?).map_err(|err| { + warn!("error processing IPC request: {err:?}"); + err.to_string() + }); + + let mut buf = serde_json::to_vec(&reply).context("error formatting reply")?; + writeln!(buf).unwrap(); + write.write_all(&buf).await.context("error writing reply")?; + write.flush().await.context("error flushing reply")?; + } Ok(()) } @@ -129,6 +134,7 @@ fn process(ctx: &ClientCtx, buf: &str) -> anyhow::Result { let request: Request = serde_json::from_str(buf).context("error parsing request")?; let response = match request { + Request::Version => Response::Version(version()), Request::Outputs => { let ipc_outputs = ctx.ipc_outputs.lock().unwrap().clone(); Response::Outputs(ipc_outputs) From d4e32e226611a6d46631b727935f27bea1cad52a Mon Sep 17 00:00:00 2001 From: sodiboo Date: Sun, 7 Apr 2024 06:19:47 +0200 Subject: [PATCH 02/10] add nonsense request --- niri-ipc/src/lib.rs | 2 ++ src/cli.rs | 2 ++ src/ipc/client.rs | 4 ++++ src/ipc/server.rs | 18 ++++++++++++------ 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/niri-ipc/src/lib.rs b/niri-ipc/src/lib.rs index 7b3a2ff0e0..28f09b1244 100644 --- a/niri-ipc/src/lib.rs +++ b/niri-ipc/src/lib.rs @@ -13,6 +13,8 @@ pub use socket::{NiriSocket, SOCKET_PATH_ENV}; /// Request from client to niri. #[derive(Debug, Serialize, Deserialize, Clone)] pub enum Request { + /// Always responds with an error. + Nonsense, /// Request the version string for the running niri instance. Version, /// Request information about connected outputs. diff --git a/src/cli.rs b/src/cli.rs index 2678a536c0..922eabb2b0 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -52,6 +52,8 @@ pub enum Sub { #[derive(Subcommand)] pub enum Msg { + /// Print an error message. + Nonsense, /// Print the version string of the running niri instance. Version, /// List connected outputs. diff --git a/src/ipc/client.rs b/src/ipc/client.rs index 2cda4a9f35..7f37a05a71 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -13,6 +13,7 @@ pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { } let request = match &msg { + Msg::Nonsense => Request::Nonsense, Msg::Version => Request::Version, Msg::Outputs => Request::Outputs, Msg::FocusedWindow => Request::FocusedWindow, @@ -70,6 +71,9 @@ pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { .context("niri could not handle the request")?; match msg { + Msg::Nonsense => { + bail!("unexpected response: expected an error, got {response:?}"); + } Msg::Version => { let Response::Version(server_version) = response else { bail!("unexpected response: expected Version, got {response:?}"); diff --git a/src/ipc/server.rs b/src/ipc/server.rs index 9fb1613f7f..39af175965 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -9,7 +9,7 @@ use calloop::io::Async; use directories::BaseDirs; use futures_util::io::{AsyncReadExt, BufReader}; use futures_util::{AsyncBufReadExt, AsyncWriteExt, StreamExt}; -use niri_ipc::{Request, Response}; +use niri_ipc::{Reply, Request, Response}; use smithay::desktop::Window; use smithay::reexports::calloop::generic::Generic; use smithay::reexports::calloop::{Interest, LoopHandle, Mode, PostAction}; @@ -116,9 +116,16 @@ async fn handle_client(ctx: ClientCtx, stream: Async<'_, UnixStream>) -> anyhow: let mut lines = BufReader::new(read).lines(); while let Some(line) = lines.next().await { - let reply = process(&ctx, &line?).map_err(|err| { + let reply: Reply = serde_json::from_str(&match line { + Ok(line) => line, + // ConnectionReset is expected when the client disconnects. + Err(err) if err.kind() == io::ErrorKind::ConnectionReset => break, + Err(err) => return Err(err).context("error reading line"), + }) + .map_err(|err| format!("error parsing request: {err}")) + .and_then(|req| process(&ctx, req)) + .inspect_err(|err| { warn!("error processing IPC request: {err:?}"); - err.to_string() }); let mut buf = serde_json::to_vec(&reply).context("error formatting reply")?; @@ -130,10 +137,9 @@ async fn handle_client(ctx: ClientCtx, stream: Async<'_, UnixStream>) -> anyhow: Ok(()) } -fn process(ctx: &ClientCtx, buf: &str) -> anyhow::Result { - let request: Request = serde_json::from_str(buf).context("error parsing request")?; - +fn process(ctx: &ClientCtx, request: Request) -> Reply { let response = match request { + Request::Nonsense => return Err("nonsense request".into()), Request::Version => Response::Version(version()), Request::Outputs => { let ipc_outputs = ctx.ipc_outputs.lock().unwrap().clone(); From 44d13219b650e35a5a6460484c5db18ef3c5a450 Mon Sep 17 00:00:00 2001 From: sodiboo Date: Sun, 7 Apr 2024 06:23:56 +0200 Subject: [PATCH 03/10] change inline struct to json macro --- src/ipc/client.rs | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/ipc/client.rs b/src/ipc/client.rs index 7f37a05a71..b3d5381aa5 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -1,5 +1,6 @@ use anyhow::{anyhow, bail, Context}; use niri_ipc::{LogicalOutput, Mode, NiriSocket, Output, Request, Response}; +use serde_json::json; use crate::cli::Msg; use crate::utils::version; @@ -80,17 +81,13 @@ pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { }; if json { - #[derive(serde::Serialize)] - struct Versions { - cli: String, - compositor: String, - } - let version = serde_json::to_string(&Versions { - cli: version(), - compositor: server_version, - }) - .context("error formatting response")?; - println!("{version}"); + println!( + "{}", + json!({ + "cli": version(), + "compositor": server_version, + }) + ); return Ok(()); } From 095f2f1f8d38f9a2ce5c31b110a54d44716dd47c Mon Sep 17 00:00:00 2001 From: sodiboo Date: Sun, 7 Apr 2024 06:25:52 +0200 Subject: [PATCH 04/10] only check version if request actually fails --- src/ipc/client.rs | 85 +++++++++++++++++++++++------------------------ 1 file changed, 42 insertions(+), 43 deletions(-) diff --git a/src/ipc/client.rs b/src/ipc/client.rs index b3d5381aa5..6041285603 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, bail, Context}; +use anyhow::{bail, Context}; use niri_ipc::{LogicalOutput, Mode, NiriSocket, Output, Request, Response}; use serde_json::json; @@ -6,7 +6,8 @@ use crate::cli::Msg; use crate::utils::version; pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { - let mut client = NiriSocket::new().context("error initializing the niri ipc client")?; + let mut client = NiriSocket::new() + .context("a communication error occured while trying to initialize the socket")?; // Default SIGPIPE so that our prints don't panic on stdout closing. unsafe { @@ -21,56 +22,54 @@ pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { Msg::Action { action } => Request::Action(action.clone()), }; - let version_reply = client - .send(Request::Version) - .context("error sending version request to niri")?; + let reply = client + .send(request) + .context("a communication error occurred while sending request to niri")?; - match version_reply.clone() { - Ok(response) => 'a: { - if matches!(msg, Msg::Version) && !json { - // Print a nicer warning for human consumers. - break 'a; - } - let Response::Version(server_version) = response else { - bail!("unexpected response: expected Version, got {response:?}"); - }; - - let my_version = version(); - - if my_version != server_version { - eprintln!("Warning: niri msg was invoked with a different version of niri than the running compositor."); - eprintln!("niri msg: {my_version}"); - eprintln!("compositor: {server_version}"); - eprintln!("Did you forget to restart niri after an update?"); - eprintln!(); - } - } - Err(_) => { - eprintln!("Warning: unable to get server version."); - eprintln!("Did you forget to restart niri after an update?"); + let response = match reply { + Ok(r) => r, + Err(err_msg) => { + eprintln!("The compositor returned an error:"); eprintln!(); - } - } + eprintln!("{err_msg}"); - let reply = match msg { - Msg::Version => version_reply, - _ => { - if version_reply.is_err() { - eprintln!("Assuming niri does not support streaming IPC. Reconnecting..."); + if matches!(msg, Msg::Version) { eprintln!(); - client = NiriSocket::new().context("error initializing the niri ipc client")?; + eprintln!("Note: unable to get the compositor's version."); + eprintln!("Did you forget to restart niri after an update?"); + } else { + // We're making a new client here just for some vague notion of + // backwards compatibility. + // It is in general not necessary to do so. + match NiriSocket::new().and_then(|mut client| client.send(Request::Version)) { + Ok(Ok(Response::Version(server_version))) => { + let my_version = version(); + if my_version != server_version { + eprintln!(); + eprintln!("Note: niri msg was invoked with a different version of niri than the running compositor."); + eprintln!("niri msg: {my_version}"); + eprintln!("compositor: {server_version}"); + eprintln!("Did you forget to restart niri after an update?"); + } + } + Ok(Ok(_)) => { + // nonsensical response, do not add confusing context + } + Ok(Err(_)) => { + eprintln!(); + eprintln!("Note: unable to get the compositor's version."); + eprintln!("Did you forget to restart niri after an update?"); + } + Err(_) => { + // communication error, do not add irrelevant context + } + } } - client - .send(request) - .context("error sending request to niri")? + return Ok(()); } }; - let response = reply - .map_err(|msg| anyhow!(msg)) - .context("niri could not handle the request")?; - match msg { Msg::Nonsense => { bail!("unexpected response: expected an error, got {response:?}"); From 29439eb8efeea23ea9d1c77626089b7880ca291f Mon Sep 17 00:00:00 2001 From: sodiboo Date: Sun, 7 Apr 2024 06:43:54 +0200 Subject: [PATCH 05/10] fix usage of inspect_err (MSRV 1.72.0; stabilized 1.76.0) --- src/ipc/server.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/ipc/server.rs b/src/ipc/server.rs index 39af175965..7aa2d8dc35 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -118,15 +118,16 @@ async fn handle_client(ctx: ClientCtx, stream: Async<'_, UnixStream>) -> anyhow: while let Some(line) = lines.next().await { let reply: Reply = serde_json::from_str(&match line { Ok(line) => line, - // ConnectionReset is expected when the client disconnects. + // ConnectionReset is expected when the client disconnects Err(err) if err.kind() == io::ErrorKind::ConnectionReset => break, Err(err) => return Err(err).context("error reading line"), }) .map_err(|err| format!("error parsing request: {err}")) - .and_then(|req| process(&ctx, req)) - .inspect_err(|err| { + .and_then(|req| process(&ctx, req)); + + if let Err(err) = &reply { warn!("error processing IPC request: {err:?}"); - }); + } let mut buf = serde_json::to_vec(&reply).context("error formatting reply")?; writeln!(buf).unwrap(); From d491258ae4a138717f86b685edc5dd76307e8762 Mon Sep 17 00:00:00 2001 From: sodiboo Date: Mon, 15 Apr 2024 22:36:01 +0200 Subject: [PATCH 06/10] "nonsense request" -> "return error" --- niri-ipc/src/lib.rs | 4 ++-- src/cli.rs | 2 +- src/ipc/client.rs | 4 ++-- src/ipc/server.rs | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/niri-ipc/src/lib.rs b/niri-ipc/src/lib.rs index 28f09b1244..8911e76c01 100644 --- a/niri-ipc/src/lib.rs +++ b/niri-ipc/src/lib.rs @@ -13,8 +13,8 @@ pub use socket::{NiriSocket, SOCKET_PATH_ENV}; /// Request from client to niri. #[derive(Debug, Serialize, Deserialize, Clone)] pub enum Request { - /// Always responds with an error. - Nonsense, + /// Always responds with an error. (For testing error handling) + ReturnError, /// Request the version string for the running niri instance. Version, /// Request information about connected outputs. diff --git a/src/cli.rs b/src/cli.rs index 922eabb2b0..6296b1b3d5 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -53,7 +53,7 @@ pub enum Sub { #[derive(Subcommand)] pub enum Msg { /// Print an error message. - Nonsense, + Error, /// Print the version string of the running niri instance. Version, /// List connected outputs. diff --git a/src/ipc/client.rs b/src/ipc/client.rs index 6041285603..a1297cc029 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -15,7 +15,7 @@ pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { } let request = match &msg { - Msg::Nonsense => Request::Nonsense, + Msg::Error => Request::ReturnError, Msg::Version => Request::Version, Msg::Outputs => Request::Outputs, Msg::FocusedWindow => Request::FocusedWindow, @@ -71,7 +71,7 @@ pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { }; match msg { - Msg::Nonsense => { + Msg::Error => { bail!("unexpected response: expected an error, got {response:?}"); } Msg::Version => { diff --git a/src/ipc/server.rs b/src/ipc/server.rs index 7aa2d8dc35..7372daf434 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -140,7 +140,7 @@ async fn handle_client(ctx: ClientCtx, stream: Async<'_, UnixStream>) -> anyhow: fn process(ctx: &ClientCtx, request: Request) -> Reply { let response = match request { - Request::Nonsense => return Err("nonsense request".into()), + Request::ReturnError => return Err("client wanted an error".into()), Request::Version => Response::Version(version()), Request::Outputs => { let ipc_outputs = ctx.ipc_outputs.lock().unwrap().clone(); From 50b38d4accdd46a10da3f53a272a8449901fe804 Mon Sep 17 00:00:00 2001 From: sodiboo Date: Mon, 15 Apr 2024 22:36:50 +0200 Subject: [PATCH 07/10] oneshot connections --- niri-ipc/src/socket.rs | 2 +- src/ipc/client.rs | 7 ++----- src/ipc/server.rs | 33 ++++++++++++++++++--------------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/niri-ipc/src/socket.rs b/niri-ipc/src/socket.rs index 9bbdbf0609..60abd3f9d2 100644 --- a/niri-ipc/src/socket.rs +++ b/niri-ipc/src/socket.rs @@ -55,7 +55,7 @@ impl NiriSocket { /// Ok(Ok([Response](crate::Response))) corresponds to a successful response from the running /// niri instance. Ok(Err([String])) corresponds to an error received from the running niri /// instance. Err([std::io::Error]) corresponds to an error in the IPC communication. - pub fn send(&mut self, request: Request) -> io::Result { + pub fn send(mut self, request: Request) -> io::Result { let mut buf = serde_json::to_vec(&request).unwrap(); writeln!(buf).unwrap(); self.stream.write_all(&buf)?; // .context("error writing IPC request")?; diff --git a/src/ipc/client.rs b/src/ipc/client.rs index a1297cc029..bf0f0dad4c 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -6,7 +6,7 @@ use crate::cli::Msg; use crate::utils::version; pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { - let mut client = NiriSocket::new() + let client = NiriSocket::new() .context("a communication error occured while trying to initialize the socket")?; // Default SIGPIPE so that our prints don't panic on stdout closing. @@ -38,10 +38,7 @@ pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { eprintln!("Note: unable to get the compositor's version."); eprintln!("Did you forget to restart niri after an update?"); } else { - // We're making a new client here just for some vague notion of - // backwards compatibility. - // It is in general not necessary to do so. - match NiriSocket::new().and_then(|mut client| client.send(Request::Version)) { + match NiriSocket::new().and_then(|client| client.send(Request::Version)) { Ok(Ok(Response::Version(server_version))) => { let my_version = version(); if my_version != server_version { diff --git a/src/ipc/server.rs b/src/ipc/server.rs index 7372daf434..d3c2de6006 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -115,26 +115,29 @@ async fn handle_client(ctx: ClientCtx, stream: Async<'_, UnixStream>) -> anyhow: let mut lines = BufReader::new(read).lines(); - while let Some(line) = lines.next().await { - let reply: Reply = serde_json::from_str(&match line { - Ok(line) => line, - // ConnectionReset is expected when the client disconnects - Err(err) if err.kind() == io::ErrorKind::ConnectionReset => break, - Err(err) => return Err(err).context("error reading line"), - }) + let line = match lines.next().await.unwrap_or(Err(io::Error::new(io::ErrorKind::UnexpectedEof, "Unreachable; BufReader returned None but when the stream ends, the connection should be reset"))) { + Ok(line) => line, + Err(err) if err.kind() == io::ErrorKind::ConnectionReset => return Ok(()), + Err(err) => return Err(err).context("error reading line"), + }; + + let reply: Reply = serde_json::from_str(&line) .map_err(|err| format!("error parsing request: {err}")) .and_then(|req| process(&ctx, req)); - if let Err(err) = &reply { - warn!("error processing IPC request: {err:?}"); - } - - let mut buf = serde_json::to_vec(&reply).context("error formatting reply")?; - writeln!(buf).unwrap(); - write.write_all(&buf).await.context("error writing reply")?; - write.flush().await.context("error flushing reply")?; + if let Err(err) = &reply { + warn!("error processing IPC request: {err:?}"); } + let mut buf = serde_json::to_vec(&reply).context("error formatting reply")?; + writeln!(buf).unwrap(); + write.write_all(&buf).await.context("error writing reply")?; + write.flush().await.context("error flushing reply")?; + + // We do not check for more lines at this moment. + // Dropping the stream will reset the connection before we read them. + // For now, a client should not be sending more than one request per connection. + Ok(()) } From f1bfef9a36337696b57efa3e4e41ca04dece1f5b Mon Sep 17 00:00:00 2001 From: sodiboo Date: Wed, 17 Apr 2024 22:43:17 +0200 Subject: [PATCH 08/10] ipc: implement Request trait with statically-guaranteed Response type --- niri-ipc/src/lib.rs | 221 +++++++++++++-- niri-ipc/src/socket.rs | 12 +- src/ipc/client.rs | 618 ++++++++++++++++++++++++++++------------- src/ipc/server.rs | 115 ++++++-- 4 files changed, 710 insertions(+), 256 deletions(-) diff --git a/niri-ipc/src/lib.rs b/niri-ipc/src/lib.rs index 8911e76c01..bbf3a6621b 100644 --- a/niri-ipc/src/lib.rs +++ b/niri-ipc/src/lib.rs @@ -1,7 +1,7 @@ //! Types for communicating with niri via IPC. #![warn(missing_docs)] -use std::collections::HashMap; +use std::collections::BTreeMap; use std::str::FromStr; use serde::{Deserialize, Serialize}; @@ -10,21 +10,199 @@ mod socket; pub use socket::{NiriSocket, SOCKET_PATH_ENV}; -/// Request from client to niri. +mod private { + pub trait Sealed {} +} + +// TODO: remove ResponseDecoder and AnyRequest? + +#[allow(missing_docs)] +pub trait ResponseDecoder { + type Output: for<'de> Deserialize<'de>; + fn decode(&self, value: serde_json::Value) -> serde_json::Result; +} + +#[derive(Debug, Clone, Copy)] +#[allow(missing_docs)] +pub struct TrivialDecoder Deserialize<'de>>(std::marker::PhantomData); + +impl Deserialize<'de>> Default for TrivialDecoder { + fn default() -> Self { + Self(std::marker::PhantomData) + } +} + +impl Deserialize<'de>> ResponseDecoder for TrivialDecoder { + type Output = T; + + fn decode(&self, value: serde_json::Value) -> serde_json::Result { + serde_json::from_value(value) + } +} + +/// A request that can be sent to niri. +pub trait Request: + Serialize + for<'de> Deserialize<'de> + private::Sealed + Into +{ + /// The type of the response that niri sends for this request. + type Response: Serialize + for<'de> Deserialize<'de>; + + #[allow(missing_docs)] + fn decoder(&self) -> impl ResponseDecoder> + 'static; + + /// Convert the request into a RequestMessage (for serialization). + fn into_message(self) -> RequestMessage; +} + +macro_rules! requests { + (@$item:item$(;)?) => { $item }; + ($($(#[$m:meta])*$variant:ident($v:vis struct $request:ident$($p:tt)?) -> $response:ty;)*) => { + #[derive(Debug, Serialize, Deserialize, Clone)] + /// A plain tag for each request type. + pub enum RequestType { + $( + $(#[$m])* + $variant, + )* + } + + #[derive(Debug, Serialize, Deserialize, Clone)] + enum AnyRequest { + $( + $(#[$m])* + $variant($request), + )* + } + + #[derive(Debug, Serialize, Deserialize, Clone)] + enum AnyResponse { + $( + $(#[$m])* + $variant($response), + )* + } + + impl private::Sealed for AnyRequest {} + + struct AnyResponseDecoder(RequestType); + + impl ResponseDecoder for AnyResponseDecoder { + type Output = Reply; + + fn decode(&self, value: serde_json::Value) -> serde_json::Result { + match self.0 { + $( + RequestType::$variant => TrivialDecoder::>::default().decode(value).map(|r| r.map(AnyResponse::$variant)), + )* + } + } + } + + impl TryFrom for AnyRequest { + type Error = serde_json::Error; + + fn try_from(message: RequestMessage) -> serde_json::Result { + match message.request_type { + $( + RequestType::$variant => serde_json::from_value(message.request_body).map(AnyRequest::$variant), + )* + } + } + } + + impl Request for AnyRequest { + type Response = AnyResponse; + + fn decoder(&self) -> impl ResponseDecoder> + 'static { + match self { + $( + AnyRequest::$variant(_) => AnyResponseDecoder(RequestType::$variant), + )* + } + } + + fn into_message(self) -> RequestMessage { + match self { + $( + AnyRequest::$variant(request) => request.into_message(), + )* + } + } + } + + + $( + requests!(@ + $(#[$m])* + #[derive(Debug, Serialize, Deserialize, Clone)] + $v struct $request $($p)?; + ); + + impl From<$request> for AnyRequest { + fn from(request: $request) -> Self { + AnyRequest::$variant(request) + } + } + + impl crate::private::Sealed for $request {} + + impl crate::Request for $request { + type Response = $response; + + fn decoder(&self) -> impl crate::ResponseDecoder> + 'static { + TrivialDecoder::>::default() + } + + fn into_message(self) -> RequestMessage { + RequestMessage { + request_type: RequestType::$variant, + request_body: serde_json::to_value(self).unwrap(), + } + } + } + )* + } +} + +/// The message format for IPC communication. +/// +/// This is mainly to avoid using sum types in IPC communication, which are more annoying to use +/// with non-Rust tooling. #[derive(Debug, Serialize, Deserialize, Clone)] -pub enum Request { - /// Always responds with an error. (For testing error handling) - ReturnError, - /// Request the version string for the running niri instance. - Version, - /// Request information about connected outputs. - Outputs, - /// Request information about the focused window. - FocusedWindow, - /// Perform an action. - Action(Action), +pub struct RequestMessage { + /// The type of the request. + pub request_type: RequestType, + /// The raw JSON body of the request. + pub request_body: serde_json::Value, } +impl From for RequestMessage { + fn from(value: R) -> Self { + value.into_message() + } +} + +/// Uninstantiable +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum Never {} + +requests!( + /// Always responds with an error (for testing error handling). + ReturnError(pub struct ErrorRequest) -> Never; + + /// Requests the version string for the running niri instance. + Version(pub struct VersionRequest) -> String; + + /// Requests information about connected outputs. + Outputs(pub struct OutputRequest) -> BTreeMap; + + /// Requests information about the focused window. + FocusedWindow(pub struct FocusedWindowRequest) -> Option; + + /// Requests that the compositor perform an action. + Action(pub struct ActionRequest(pub Action)) -> (); +); + /// Reply from niri to client. /// /// Every request gets one reply. @@ -33,22 +211,7 @@ pub enum Request { /// * If the request does not need any particular response, it will be /// `Reply::Ok(Response::Handled)`. Kind of like an `Ok(())`. /// * Otherwise, it will be `Reply::Ok(response)` with one of the other [`Response`] variants. -pub type Reply = Result; - -/// Successful response from niri to client. -#[derive(Debug, Serialize, Deserialize, Clone)] -pub enum Response { - /// A request that does not need a response was handled successfully. - Handled, - /// The version string for the running niri instance. - Version(String), - /// Information about connected outputs. - /// - /// Map from connector name to output info. - Outputs(HashMap), - /// Information about the focused window. - FocusedWindow(Option), -} +pub type Reply = Result; /// Actions that niri can perform. // Variants in this enum should match the spelling of the ones in niri-config. Most, but not all, diff --git a/niri-ipc/src/socket.rs b/niri-ipc/src/socket.rs index 60abd3f9d2..650d8e3d67 100644 --- a/niri-ipc/src/socket.rs +++ b/niri-ipc/src/socket.rs @@ -5,7 +5,7 @@ use std::path::Path; use serde_json::de::IoRead; use serde_json::StreamDeserializer; -use crate::{Reply, Request}; +use crate::{Reply, Request, ResponseDecoder}; /// Name of the environment variable containing the niri IPC socket path. pub const SOCKET_PATH_ENV: &str = "NIRI_SOCKET"; @@ -16,7 +16,7 @@ pub const SOCKET_PATH_ENV: &str = "NIRI_SOCKET"; /// and serialization/deserialization of messages. pub struct NiriSocket { stream: UnixStream, - responses: StreamDeserializer<'static, IoRead, Reply>, + responses: StreamDeserializer<'static, IoRead, serde_json::Value>, } impl TryFrom for NiriSocket { @@ -55,14 +55,16 @@ impl NiriSocket { /// Ok(Ok([Response](crate::Response))) corresponds to a successful response from the running /// niri instance. Ok(Err([String])) corresponds to an error received from the running niri /// instance. Err([std::io::Error]) corresponds to an error in the IPC communication. - pub fn send(mut self, request: Request) -> io::Result { - let mut buf = serde_json::to_vec(&request).unwrap(); + pub fn send_request(mut self, request: R) -> io::Result> { + let decoder = request.decoder(); + let mut buf = serde_json::to_vec(&request.into_message()).unwrap(); writeln!(buf).unwrap(); self.stream.write_all(&buf)?; // .context("error writing IPC request")?; self.stream.flush()?; if let Some(next) = self.responses.next() { - next.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + next.and_then(|v| decoder.decode(v)) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) } else { Err(io::Error::new( io::ErrorKind::UnexpectedEof, diff --git a/src/ipc/client.rs b/src/ipc/client.rs index bf0f0dad4c..60779047d5 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -1,242 +1,470 @@ -use anyhow::{bail, Context}; -use niri_ipc::{LogicalOutput, Mode, NiriSocket, Output, Request, Response}; +use anyhow::Context; +use niri_ipc::{ + ActionRequest, ErrorRequest, FocusedWindowRequest, LogicalOutput, Mode, NiriSocket, Output, + OutputRequest, Request, VersionRequest, +}; use serde_json::json; use crate::cli::Msg; use crate::utils::version; +struct CompositorError { + message: String, + version: Option, +} + +type MsgResult = Result; + +trait MsgRequest: Request { + fn json(response: Self::Response) -> serde_json::Value { + json!(response) + } + + fn show_to_human(response: Self::Response); + + fn check_version() -> Option { + if let Ok(Ok(version)) = VersionRequest.send() { + Some(version) + } else { + None + } + } + + fn send(self) -> anyhow::Result> { + let socket = NiriSocket::new().context("trying to initialize the socket")?; + let reply = socket + .send_request(self) + .context("while sending request to niri")?; + Ok(reply.map_err(|message| CompositorError { + message, + version: Self::check_version(), + })) + } +} + pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { - let client = NiriSocket::new() - .context("a communication error occured while trying to initialize the socket")?; + match &msg { + Msg::Error => run(json, ErrorRequest), + Msg::Version => run(json, VersionRequest), + Msg::Outputs => run(json, OutputRequest), + Msg::FocusedWindow => run(json, FocusedWindowRequest), + Msg::Action { action } => run(json, ActionRequest(action.clone())), + } +} + +fn run(json: bool, request: R) -> anyhow::Result<()> { + let reply = request.send().context("a communication error occurred")?; // Default SIGPIPE so that our prints don't panic on stdout closing. unsafe { libc::signal(libc::SIGPIPE, libc::SIG_DFL); } - let request = match &msg { - Msg::Error => Request::ReturnError, - Msg::Version => Request::Version, - Msg::Outputs => Request::Outputs, - Msg::FocusedWindow => Request::FocusedWindow, - Msg::Action { action } => Request::Action(action.clone()), - }; - - let reply = client - .send(request) - .context("a communication error occurred while sending request to niri")?; - - let response = match reply { - Ok(r) => r, - Err(err_msg) => { + match reply { + Ok(response) => { + if json { + println!("{}", R::json(response)); + } else { + R::show_to_human(response); + } + } + Err(CompositorError { + message, + version: server_version, + }) => { eprintln!("The compositor returned an error:"); eprintln!(); - eprintln!("{err_msg}"); + eprintln!("{message}"); - if matches!(msg, Msg::Version) { + if let Some(server_version) = server_version { + let my_version = version(); + if my_version != server_version { + eprintln!(); + eprintln!("Note: niri msg was invoked with a different version of niri than the running compositor."); + eprintln!("niri msg: {my_version}"); + eprintln!("compositor: {server_version}"); + eprintln!("Did you forget to restart niri after an update?"); + } + } else { eprintln!(); eprintln!("Note: unable to get the compositor's version."); eprintln!("Did you forget to restart niri after an update?"); - } else { - match NiriSocket::new().and_then(|client| client.send(Request::Version)) { - Ok(Ok(Response::Version(server_version))) => { - let my_version = version(); - if my_version != server_version { - eprintln!(); - eprintln!("Note: niri msg was invoked with a different version of niri than the running compositor."); - eprintln!("niri msg: {my_version}"); - eprintln!("compositor: {server_version}"); - eprintln!("Did you forget to restart niri after an update?"); - } - } - Ok(Ok(_)) => { - // nonsensical response, do not add confusing context - } - Ok(Err(_)) => { - eprintln!(); - eprintln!("Note: unable to get the compositor's version."); - eprintln!("Did you forget to restart niri after an update?"); - } - Err(_) => { - // communication error, do not add irrelevant context - } - } } - - return Ok(()); } - }; + } - match msg { - Msg::Error => { - bail!("unexpected response: expected an error, got {response:?}"); - } - Msg::Version => { - let Response::Version(server_version) = response else { - bail!("unexpected response: expected Version, got {response:?}"); - }; + Ok(()) +} - if json { - println!( - "{}", - json!({ - "cli": version(), - "compositor": server_version, - }) - ); - return Ok(()); - } +// fn _a() { - let client_version = version(); +// let reply = client +// .send(request) +// .context("a communication error occurred while sending request to niri")?; - println!("niri msg is {client_version}"); - println!("the compositor is {server_version}"); - if client_version != server_version { - eprintln!(); - eprintln!("These are different"); - eprintln!("Did you forget to restart niri after an update?"); - } - println!(); +// let response = match reply { +// Ok(r) => r, +// Err(err_msg) => { +// eprintln!("The compositor returned an error:"); +// eprintln!(); +// eprintln!("{err_msg}"); + +// if matches!(msg, Msg::Version) { +// eprintln!(); +// eprintln!("Note: unable to get the compositor's version."); +// eprintln!("Did you forget to restart niri after an update?"); +// } else { +// match NiriSocket::new().and_then(|client| client.send(RequestEnum::Version)) { +// Ok(Ok(Response::Version(server_version))) => { +// let my_version = version(); +// if my_version != server_version { +// eprintln!(); +// eprintln!("Note: niri msg was invoked with a different version of +// niri than the running compositor."); eprintln!("niri msg: +// {my_version}"); eprintln!("compositor: {server_version}"); +// eprintln!("Did you forget to restart niri after an update?"); +// } +// } +// Ok(Ok(_)) => { +// // nonsensical response, do not add confusing context +// } +// Ok(Err(_)) => { +// eprintln!(); +// eprintln!("Note: unable to get the compositor's version."); +// eprintln!("Did you forget to restart niri after an update?"); +// } +// Err(_) => { +// // communication error, do not add irrelevant context +// } +// } +// } + +// return Ok(()); +// } +// }; + +// match msg { +// Msg::Error => { +// bail!("unexpected response: expected an error, got {response:?}"); +// } +// Msg::Version => { +// let Response::Version(server_version) = response else { +// bail!("unexpected response: expected Version, got {response:?}"); +// }; + +// if json { +// println!( +// "{}", +// json!({ +// "cli": version(), +// "compositor": server_version, +// }) +// ); +// return Ok(()); +// } + +// let client_version = version(); + +// println!("niri msg is {client_version}"); +// println!("the compositor is {server_version}"); +// if client_version != server_version { +// eprintln!(); +// eprintln!("These are different"); +// eprintln!("Did you forget to restart niri after an update?"); +// } +// println!(); +// } +// Msg::Outputs => { +// let Response::Outputs(outputs) = response else { +// bail!("unexpected response: expected Outputs, got {response:?}"); +// }; + +// if json { +// let output = +// serde_json::to_string(&outputs).context("error formatting response")?; +// println!("{output}"); +// return Ok(()); +// } + +// let mut outputs = outputs.into_iter().collect::>(); +// outputs.sort_unstable_by(|a, b| a.0.cmp(&b.0)); + +// for (connector, output) in outputs.into_iter() { +// let Output { +// name, +// make, +// model, +// physical_size, +// modes, +// current_mode, +// logical, +// } = output; + +// println!(r#"Output "{connector}" ({make} - {model} - {name})"#); + +// if let Some(current) = current_mode { +// let mode = *modes +// .get(current) +// .context("invalid response: current mode does not exist")?; +// let Mode { +// width, +// height, +// refresh_rate, +// is_preferred, +// } = mode; +// let refresh = refresh_rate as f64 / 1000.; +// let preferred = if is_preferred { " (preferred)" } else { "" }; +// println!(" Current mode: {width}x{height} @ {refresh:.3} Hz{preferred}"); +// } else { +// println!(" Disabled"); +// } + +// if let Some((width, height)) = physical_size { +// println!(" Physical size: {width}x{height} mm"); +// } else { +// println!(" Physical size: unknown"); +// } + +// if let Some(logical) = logical { +// let LogicalOutput { +// x, +// y, +// width, +// height, +// scale, +// transform, +// } = logical; +// println!(" Logical position: {x}, {y}"); +// println!(" Logical size: {width}x{height}"); +// println!(" Scale: {scale}"); + +// let transform = match transform { +// niri_ipc::Transform::Normal => "normal", +// niri_ipc::Transform::_90 => "90° counter-clockwise", +// niri_ipc::Transform::_180 => "180°", +// niri_ipc::Transform::_270 => "270° counter-clockwise", +// niri_ipc::Transform::Flipped => "flipped horizontally", +// niri_ipc::Transform::Flipped90 => { +// "90° counter-clockwise, flipped horizontally" +// } +// niri_ipc::Transform::Flipped180 => "flipped vertically", +// niri_ipc::Transform::Flipped270 => { +// "270° counter-clockwise, flipped horizontally" +// } +// }; +// println!(" Transform: {transform}"); +// } + +// println!(" Available modes:"); +// for (idx, mode) in modes.into_iter().enumerate() { +// let Mode { +// width, +// height, +// refresh_rate, +// is_preferred, +// } = mode; +// let refresh = refresh_rate as f64 / 1000.; + +// let is_current = Some(idx) == current_mode; +// let qualifier = match (is_current, is_preferred) { +// (true, true) => " (current, preferred)", +// (true, false) => " (current)", +// (false, true) => " (preferred)", +// (false, false) => "", +// }; + +// println!(" {width}x{height}@{refresh:.3}{qualifier}"); +// } +// println!(); +// } +// } +// Msg::FocusedWindow => { +// let Response::FocusedWindow(window) = response else { +// bail!("unexpected response: expected FocusedWindow, got {response:?}"); +// }; + +// if json { +// let window = serde_json::to_string(&window).context("error formatting +// response")?; println!("{window}"); +// return Ok(()); +// } + +// if let Some(window) = window { +// println!("Focused window:"); + +// if let Some(title) = window.title { +// println!(" Title: \"{title}\""); +// } else { +// println!(" Title: (unset)"); +// } + +// if let Some(app_id) = window.app_id { +// println!(" App ID: \"{app_id}\""); +// } else { +// println!(" App ID: (unset)"); +// } +// } else { +// println!("No window is focused."); +// } +// } +// Msg::Action { .. } => { +// let Response::Handled = response else { +// bail!("unexpected response: expected Handled, got {response:?}"); +// }; +// } +// } + +// Ok(()) +// } + +impl MsgRequest for ErrorRequest { + fn json(response: Self::Response) -> serde_json::Value { + match response {} + } + + fn show_to_human(response: Self::Response) { + match response {} + } +} + +impl MsgRequest for VersionRequest { + fn check_version() -> Option { + // If the version request fails, we can't exactly try again. + None + } + fn json(response: Self::Response) -> serde_json::Value { + json!({ + "cli": version(), + "compositor": response, + }) + } + fn show_to_human(response: Self::Response) { + let client_version = version(); + let server_version = response; + println!("niri msg is {client_version}"); + println!("the compositor is {server_version}"); + if client_version != server_version { + eprintln!(); + eprintln!("These are different"); + eprintln!("Did you forget to restart niri after an update?"); } - Msg::Outputs => { - let Response::Outputs(outputs) = response else { - bail!("unexpected response: expected Outputs, got {response:?}"); - }; + println!(); + } +} - if json { - let output = - serde_json::to_string(&outputs).context("error formatting response")?; - println!("{output}"); - return Ok(()); - } +impl MsgRequest for OutputRequest { + fn show_to_human(response: Self::Response) { + for (connector, output) in response { + let Output { + name, + make, + model, + physical_size, + modes, + current_mode, + logical, + } = output; + + println!(r#"Output "{connector}" ({make} - {model} - {name})"#); - let mut outputs = outputs.into_iter().collect::>(); - outputs.sort_unstable_by(|a, b| a.0.cmp(&b.0)); - - for (connector, output) in outputs.into_iter() { - let Output { - name, - make, - model, - physical_size, - modes, - current_mode, - logical, - } = output; - - println!(r#"Output "{connector}" ({make} - {model} - {name})"#); - - if let Some(current) = current_mode { - let mode = *modes - .get(current) - .context("invalid response: current mode does not exist")?; - let Mode { - width, - height, - refresh_rate, - is_preferred, - } = mode; + match current_mode.map(|idx| modes.get(idx)) { + None => println!(" Disabled"), + Some(None) => println!(" Current mode: (invalid index)"), + Some(Some(&Mode { + width, + height, + refresh_rate, + is_preferred, + })) => { let refresh = refresh_rate as f64 / 1000.; let preferred = if is_preferred { " (preferred)" } else { "" }; println!(" Current mode: {width}x{height} @ {refresh:.3} Hz{preferred}"); - } else { - println!(" Disabled"); } + } - if let Some((width, height)) = physical_size { - println!(" Physical size: {width}x{height} mm"); - } else { - println!(" Physical size: unknown"); - } + if let Some((width, height)) = physical_size { + println!(" Physical size: {width}x{height} mm"); + } else { + println!(" Physical size: unknown"); + } - if let Some(logical) = logical { - let LogicalOutput { - x, - y, - width, - height, - scale, - transform, - } = logical; - println!(" Logical position: {x}, {y}"); - println!(" Logical size: {width}x{height}"); - println!(" Scale: {scale}"); - - let transform = match transform { - niri_ipc::Transform::Normal => "normal", - niri_ipc::Transform::_90 => "90° counter-clockwise", - niri_ipc::Transform::_180 => "180°", - niri_ipc::Transform::_270 => "270° counter-clockwise", - niri_ipc::Transform::Flipped => "flipped horizontally", - niri_ipc::Transform::Flipped90 => { - "90° counter-clockwise, flipped horizontally" - } - niri_ipc::Transform::Flipped180 => "flipped vertically", - niri_ipc::Transform::Flipped270 => { - "270° counter-clockwise, flipped horizontally" - } - }; - println!(" Transform: {transform}"); - } + if let Some(logical) = logical { + let LogicalOutput { + x, + y, + width, + height, + scale, + transform, + } = logical; + println!(" Logical position: {x}, {y}"); + println!(" Logical size: {width}x{height}"); + println!(" Scale: {scale}"); - println!(" Available modes:"); - for (idx, mode) in modes.into_iter().enumerate() { - let Mode { - width, - height, - refresh_rate, - is_preferred, - } = mode; - let refresh = refresh_rate as f64 / 1000.; + let transform = match transform { + niri_ipc::Transform::Normal => "normal", + niri_ipc::Transform::_90 => "90° counter-clockwise", + niri_ipc::Transform::_180 => "180°", + niri_ipc::Transform::_270 => "270° counter-clockwise", + niri_ipc::Transform::Flipped => "flipped horizontally", + niri_ipc::Transform::Flipped90 => "90° counter-clockwise, flipped horizontally", + niri_ipc::Transform::Flipped180 => "flipped vertically", + niri_ipc::Transform::Flipped270 => { + "270° counter-clockwise, flipped horizontally" + } + }; + println!(" Transform: {transform}"); + } - let is_current = Some(idx) == current_mode; - let qualifier = match (is_current, is_preferred) { - (true, true) => " (current, preferred)", - (true, false) => " (current)", - (false, true) => " (preferred)", - (false, false) => "", - }; + println!(" Available modes:"); + for (idx, mode) in modes.into_iter().enumerate() { + let Mode { + width, + height, + refresh_rate, + is_preferred, + } = mode; + let refresh = refresh_rate as f64 / 1000.; - println!(" {width}x{height}@{refresh:.3}{qualifier}"); - } - println!(); - } - } - Msg::FocusedWindow => { - let Response::FocusedWindow(window) = response else { - bail!("unexpected response: expected FocusedWindow, got {response:?}"); - }; + let is_current = Some(idx) == current_mode; + let qualifier = match (is_current, is_preferred) { + (true, true) => " (current, preferred)", + (true, false) => " (current)", + (false, true) => " (preferred)", + (false, false) => "", + }; - if json { - let window = serde_json::to_string(&window).context("error formatting response")?; - println!("{window}"); - return Ok(()); + println!(" {width}x{height}@{refresh:.3}{qualifier}"); } + println!(); + } + } +} - if let Some(window) = window { - println!("Focused window:"); +impl MsgRequest for FocusedWindowRequest { + fn show_to_human(response: Self::Response) { + if let Some(window) = response { + println!("Focused window:"); - if let Some(title) = window.title { - println!(" Title: \"{title}\""); - } else { - println!(" Title: (unset)"); - } + if let Some(title) = window.title { + println!(" Title: \"{title}\""); + } else { + println!(" Title: (unset)"); + } - if let Some(app_id) = window.app_id { - println!(" App ID: \"{app_id}\""); - } else { - println!(" App ID: (unset)"); - } + if let Some(app_id) = window.app_id { + println!(" App ID: \"{app_id}\""); } else { - println!("No window is focused."); + println!(" App ID: (unset)"); } - } - Msg::Action { .. } => { - let Response::Handled = response else { - bail!("unexpected response: expected Handled, got {response:?}"); - }; + } else { + println!("No window is focused."); } } +} - Ok(()) +impl MsgRequest for ActionRequest { + fn show_to_human(response: Self::Response) { + response + } } diff --git a/src/ipc/server.rs b/src/ipc/server.rs index d3c2de6006..e10a001ea3 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -9,7 +9,10 @@ use calloop::io::Async; use directories::BaseDirs; use futures_util::io::{AsyncReadExt, BufReader}; use futures_util::{AsyncBufReadExt, AsyncWriteExt, StreamExt}; -use niri_ipc::{Reply, Request, Response}; +use niri_ipc::{ + ActionRequest, ErrorRequest, FocusedWindowRequest, OutputRequest, Reply, Request, + RequestMessage, RequestType, VersionRequest, +}; use smithay::desktop::Window; use smithay::reexports::calloop::generic::Generic; use smithay::reexports::calloop::{Interest, LoopHandle, Mode, PostAction}; @@ -121,9 +124,7 @@ async fn handle_client(ctx: ClientCtx, stream: Async<'_, UnixStream>) -> anyhow: Err(err) => return Err(err).context("error reading line"), }; - let reply: Reply = serde_json::from_str(&line) - .map_err(|err| format!("error parsing request: {err}")) - .and_then(|req| process(&ctx, req)); + let reply = process(&ctx, line); if let Err(err) = &reply { warn!("error processing IPC request: {err:?}"); @@ -141,17 +142,79 @@ async fn handle_client(ctx: ClientCtx, stream: Async<'_, UnixStream>) -> anyhow: Ok(()) } -fn process(ctx: &ClientCtx, request: Request) -> Reply { - let response = match request { - Request::ReturnError => return Err("client wanted an error".into()), - Request::Version => Response::Version(version()), - Request::Outputs => { - let ipc_outputs = ctx.ipc_outputs.lock().unwrap().clone(); - Response::Outputs(ipc_outputs) - } - Request::FocusedWindow => { - let window = ctx.ipc_focused_window.lock().unwrap().clone(); - let window = window.map(|window| { +trait HandleRequest: Request { + fn handle(self, ctx: &ClientCtx) -> Reply; +} + +fn process(ctx: &ClientCtx, line: String) -> Reply { + let request: RequestMessage = serde_json::from_str(&line).map_err(|err| { + warn!("error parsing IPC request: {err:?}"); + "error parsing request" + })?; + + macro_rules! handle { + ($($variant:ident => $type:ty,)*) => { + match request.request_type { + $( + RequestType::$variant => { + let request = serde_json::from_value::<$type>(request.request_body).map_err(|err| + { + warn!("error parsing IPC request: {err:?}"); + "error parsing request" + })?; + HandleRequest::handle(request, ctx).and_then(|v| { + serde_json::to_value(v).map_err(|err| { + warn!("error serializing response to IPC request: {err:?}"); + "error serializing response".into() + }) + }) + } + )* + } + }; + } + + handle!( + ReturnError => ErrorRequest, + Version => VersionRequest, + Outputs => OutputRequest, + FocusedWindow => FocusedWindowRequest, + Action => ActionRequest, + ) +} + +impl HandleRequest for ErrorRequest { + fn handle(self, _ctx: &ClientCtx) -> Reply { + Err("client wanted an error".into()) + } +} + +impl HandleRequest for VersionRequest { + fn handle(self, _ctx: &ClientCtx) -> Reply { + Ok(version()) + } +} + +impl HandleRequest for OutputRequest { + fn handle(self, ctx: &ClientCtx) -> Reply { + Ok(ctx + .ipc_outputs + .lock() + .unwrap() + .clone() + .into_iter() + .collect()) + } +} + +impl HandleRequest for FocusedWindowRequest { + fn handle(self, ctx: &ClientCtx) -> Reply { + Ok(ctx + .ipc_focused_window + .lock() + .unwrap() + .clone() + .map(|window| { let wl_surface = window.toplevel().expect("no X11 support").wl_surface(); with_states(wl_surface, |states| { let role = states @@ -166,17 +229,15 @@ fn process(ctx: &ClientCtx, request: Request) -> Reply { app_id: role.app_id.clone(), } }) - }); - Response::FocusedWindow(window) - } - Request::Action(action) => { - let action = niri_config::Action::from(action); - ctx.event_loop.insert_idle(move |state| { - state.do_action(action); - }); - Response::Handled - } - }; + })) + } +} - Ok(response) +impl HandleRequest for ActionRequest { + fn handle(self, ctx: &ClientCtx) -> Reply { + ctx.event_loop.insert_idle(move |state| { + state.do_action(self.0.into()); + }); + Ok(()) + } } From edb1e85358672ff8dd08d4d892eb661dc2538ebe Mon Sep 17 00:00:00 2001 From: sodiboo Date: Thu, 18 Apr 2024 09:56:27 +0200 Subject: [PATCH 09/10] remove commented old impl and better describe why we override SIGPIPE --- src/ipc/client.rs | 236 +++------------------------------------------- 1 file changed, 14 insertions(+), 222 deletions(-) diff --git a/src/ipc/client.rs b/src/ipc/client.rs index 60779047d5..6ec2b7779a 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -55,7 +55,20 @@ pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { fn run(json: bool, request: R) -> anyhow::Result<()> { let reply = request.send().context("a communication error occurred")?; - // Default SIGPIPE so that our prints don't panic on stdout closing. + // Piping `niri msg` into a command like `jq invalid` will cause jq to exit early + // from the invalid expression. That also causes the pipe to close, and the piped process + // receives a SIGPIPE. Normally, this would cause println! to panic, but because the error + // ultimately doesn't originate in niri, and it's not a bug in niri, the resulting backtrace is + // quite unhelpful to the user considering that the actual error (invalid jq expression) is + // already shown on the terminal. + // + // To avoid this, we ignore any SIGPIPE we receive from here on out. This can potentially + // interfere with IPC code, so we ensure that it is already finished by the time we reach this + // point. Actual errors with the IPC code are not handled by us; they're bubbled up to + // main() as Err(_). Those are separate from the pipe closing; and should be printed anyways. + // But after this point, we only really print things, so it's safe to ignore SIGPIPE. + // And since stdio panics are the *only* error path, we can be confident that there is actually + // no error path from this point on. unsafe { libc::signal(libc::SIGPIPE, libc::SIG_DFL); } @@ -96,227 +109,6 @@ fn run(json: bool, request: R) -> anyhow::Result<()> { Ok(()) } -// fn _a() { - -// let reply = client -// .send(request) -// .context("a communication error occurred while sending request to niri")?; - -// let response = match reply { -// Ok(r) => r, -// Err(err_msg) => { -// eprintln!("The compositor returned an error:"); -// eprintln!(); -// eprintln!("{err_msg}"); - -// if matches!(msg, Msg::Version) { -// eprintln!(); -// eprintln!("Note: unable to get the compositor's version."); -// eprintln!("Did you forget to restart niri after an update?"); -// } else { -// match NiriSocket::new().and_then(|client| client.send(RequestEnum::Version)) { -// Ok(Ok(Response::Version(server_version))) => { -// let my_version = version(); -// if my_version != server_version { -// eprintln!(); -// eprintln!("Note: niri msg was invoked with a different version of -// niri than the running compositor."); eprintln!("niri msg: -// {my_version}"); eprintln!("compositor: {server_version}"); -// eprintln!("Did you forget to restart niri after an update?"); -// } -// } -// Ok(Ok(_)) => { -// // nonsensical response, do not add confusing context -// } -// Ok(Err(_)) => { -// eprintln!(); -// eprintln!("Note: unable to get the compositor's version."); -// eprintln!("Did you forget to restart niri after an update?"); -// } -// Err(_) => { -// // communication error, do not add irrelevant context -// } -// } -// } - -// return Ok(()); -// } -// }; - -// match msg { -// Msg::Error => { -// bail!("unexpected response: expected an error, got {response:?}"); -// } -// Msg::Version => { -// let Response::Version(server_version) = response else { -// bail!("unexpected response: expected Version, got {response:?}"); -// }; - -// if json { -// println!( -// "{}", -// json!({ -// "cli": version(), -// "compositor": server_version, -// }) -// ); -// return Ok(()); -// } - -// let client_version = version(); - -// println!("niri msg is {client_version}"); -// println!("the compositor is {server_version}"); -// if client_version != server_version { -// eprintln!(); -// eprintln!("These are different"); -// eprintln!("Did you forget to restart niri after an update?"); -// } -// println!(); -// } -// Msg::Outputs => { -// let Response::Outputs(outputs) = response else { -// bail!("unexpected response: expected Outputs, got {response:?}"); -// }; - -// if json { -// let output = -// serde_json::to_string(&outputs).context("error formatting response")?; -// println!("{output}"); -// return Ok(()); -// } - -// let mut outputs = outputs.into_iter().collect::>(); -// outputs.sort_unstable_by(|a, b| a.0.cmp(&b.0)); - -// for (connector, output) in outputs.into_iter() { -// let Output { -// name, -// make, -// model, -// physical_size, -// modes, -// current_mode, -// logical, -// } = output; - -// println!(r#"Output "{connector}" ({make} - {model} - {name})"#); - -// if let Some(current) = current_mode { -// let mode = *modes -// .get(current) -// .context("invalid response: current mode does not exist")?; -// let Mode { -// width, -// height, -// refresh_rate, -// is_preferred, -// } = mode; -// let refresh = refresh_rate as f64 / 1000.; -// let preferred = if is_preferred { " (preferred)" } else { "" }; -// println!(" Current mode: {width}x{height} @ {refresh:.3} Hz{preferred}"); -// } else { -// println!(" Disabled"); -// } - -// if let Some((width, height)) = physical_size { -// println!(" Physical size: {width}x{height} mm"); -// } else { -// println!(" Physical size: unknown"); -// } - -// if let Some(logical) = logical { -// let LogicalOutput { -// x, -// y, -// width, -// height, -// scale, -// transform, -// } = logical; -// println!(" Logical position: {x}, {y}"); -// println!(" Logical size: {width}x{height}"); -// println!(" Scale: {scale}"); - -// let transform = match transform { -// niri_ipc::Transform::Normal => "normal", -// niri_ipc::Transform::_90 => "90° counter-clockwise", -// niri_ipc::Transform::_180 => "180°", -// niri_ipc::Transform::_270 => "270° counter-clockwise", -// niri_ipc::Transform::Flipped => "flipped horizontally", -// niri_ipc::Transform::Flipped90 => { -// "90° counter-clockwise, flipped horizontally" -// } -// niri_ipc::Transform::Flipped180 => "flipped vertically", -// niri_ipc::Transform::Flipped270 => { -// "270° counter-clockwise, flipped horizontally" -// } -// }; -// println!(" Transform: {transform}"); -// } - -// println!(" Available modes:"); -// for (idx, mode) in modes.into_iter().enumerate() { -// let Mode { -// width, -// height, -// refresh_rate, -// is_preferred, -// } = mode; -// let refresh = refresh_rate as f64 / 1000.; - -// let is_current = Some(idx) == current_mode; -// let qualifier = match (is_current, is_preferred) { -// (true, true) => " (current, preferred)", -// (true, false) => " (current)", -// (false, true) => " (preferred)", -// (false, false) => "", -// }; - -// println!(" {width}x{height}@{refresh:.3}{qualifier}"); -// } -// println!(); -// } -// } -// Msg::FocusedWindow => { -// let Response::FocusedWindow(window) = response else { -// bail!("unexpected response: expected FocusedWindow, got {response:?}"); -// }; - -// if json { -// let window = serde_json::to_string(&window).context("error formatting -// response")?; println!("{window}"); -// return Ok(()); -// } - -// if let Some(window) = window { -// println!("Focused window:"); - -// if let Some(title) = window.title { -// println!(" Title: \"{title}\""); -// } else { -// println!(" Title: (unset)"); -// } - -// if let Some(app_id) = window.app_id { -// println!(" App ID: \"{app_id}\""); -// } else { -// println!(" App ID: (unset)"); -// } -// } else { -// println!("No window is focused."); -// } -// } -// Msg::Action { .. } => { -// let Response::Handled = response else { -// bail!("unexpected response: expected Handled, got {response:?}"); -// }; -// } -// } - -// Ok(()) -// } - impl MsgRequest for ErrorRequest { fn json(response: Self::Response) -> serde_json::Value { match response {} From 24e344f0990f802179f4064e888c26f7a331e2fd Mon Sep 17 00:00:00 2001 From: sodiboo Date: Sun, 21 Apr 2024 21:07:58 +0200 Subject: [PATCH 10/10] various changes to data model. mostly resilience of errors --- niri-ipc/src/lib.rs | 270 ++++++++++++++++++++++------------------- niri-ipc/src/socket.rs | 17 ++- src/cli.rs | 2 +- src/ipc/client.rs | 51 +++++--- src/ipc/server.rs | 51 +++----- 5 files changed, 205 insertions(+), 186 deletions(-) diff --git a/niri-ipc/src/lib.rs b/niri-ipc/src/lib.rs index bbf3a6621b..784b5eb1a7 100644 --- a/niri-ipc/src/lib.rs +++ b/niri-ipc/src/lib.rs @@ -10,34 +10,19 @@ mod socket; pub use socket::{NiriSocket, SOCKET_PATH_ENV}; -mod private { - pub trait Sealed {} -} - -// TODO: remove ResponseDecoder and AnyRequest? - -#[allow(missing_docs)] -pub trait ResponseDecoder { - type Output: for<'de> Deserialize<'de>; - fn decode(&self, value: serde_json::Value) -> serde_json::Result; -} - -#[derive(Debug, Clone, Copy)] -#[allow(missing_docs)] -pub struct TrivialDecoder Deserialize<'de>>(std::marker::PhantomData); - -impl Deserialize<'de>> Default for TrivialDecoder { - fn default() -> Self { - Self(std::marker::PhantomData) - } +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +#[doc(hidden)] +pub enum MaybeUnknown { + Known(T), + Unknown(U), } -impl Deserialize<'de>> ResponseDecoder for TrivialDecoder { - type Output = T; +#[doc(hidden)] +pub type MaybeJson = MaybeUnknown; - fn decode(&self, value: serde_json::Value) -> serde_json::Result { - serde_json::from_value(value) - } +mod private { + pub trait Sealed {} } /// A request that can be sent to niri. @@ -47,90 +32,56 @@ pub trait Request: /// The type of the response that niri sends for this request. type Response: Serialize + for<'de> Deserialize<'de>; - #[allow(missing_docs)] - fn decoder(&self) -> impl ResponseDecoder> + 'static; - /// Convert the request into a RequestMessage (for serialization). fn into_message(self) -> RequestMessage; } +impl From for RequestMessage { + fn from(value: R) -> Self { + value.into_message() + } +} + macro_rules! requests { (@$item:item$(;)?) => { $item }; - ($($(#[$m:meta])*$variant:ident($v:vis struct $request:ident$($p:tt)?) -> $response:ty;)*) => { - #[derive(Debug, Serialize, Deserialize, Clone)] - /// A plain tag for each request type. - pub enum RequestType { - $( - $(#[$m])* - $variant, - )* - } - + ($dollar:tt; $($(#[$m:meta])*$variant:ident($v:vis struct $request:ident$($p:tt)?) -> $response:ty;)*) => { #[derive(Debug, Serialize, Deserialize, Clone)] - enum AnyRequest { + #[doc(hidden)] + pub enum RequestMessage { $( $(#[$m])* $variant($request), )* } - #[derive(Debug, Serialize, Deserialize, Clone)] - enum AnyResponse { - $( - $(#[$m])* - $variant($response), - )* - } - - impl private::Sealed for AnyRequest {} - - struct AnyResponseDecoder(RequestType); - - impl ResponseDecoder for AnyResponseDecoder { - type Output = Reply; - - fn decode(&self, value: serde_json::Value) -> serde_json::Result { - match self.0 { - $( - RequestType::$variant => TrivialDecoder::>::default().decode(value).map(|r| r.map(AnyResponse::$variant)), - )* - } - } - } - - impl TryFrom for AnyRequest { - type Error = serde_json::Error; - - fn try_from(message: RequestMessage) -> serde_json::Result { - match message.request_type { - $( - RequestType::$variant => serde_json::from_value(message.request_body).map(AnyRequest::$variant), - )* - } - } - } - - impl Request for AnyRequest { - type Response = AnyResponse; - - fn decoder(&self) -> impl ResponseDecoder> + 'static { - match self { - $( - AnyRequest::$variant(_) => AnyResponseDecoder(RequestType::$variant), - )* - } - } - - fn into_message(self) -> RequestMessage { - match self { + // This is for use in server code. + // Essentially just equivalent to the following: + // fn dispatch(message: RequestMessage, f: impl FnOnce(R) -> T) -> T; + // except: + // (a) rust doesn't quite support this kind of higher-order generic functions + // (b) even if it did, it would have to be sound by the Request bound, which isn't possible + // because the inherent usage is a per-type implementation which can only be sound by-example + // + // essentially this just cuts down on the server needing to enumerate all request types + #[macro_export] + #[doc(hidden)] + macro_rules! dispatch { + ($dollar message:expr, $dollar f:expr) => {{ + let message: RequestMessage = $dollar message; + match message { $( - AnyRequest::$variant(request) => request.into_message(), + RequestMessage::$variant(request) => { + const fn ascribe(f: F) -> F where F: FnOnce($crate::$request) -> R { + f + } + let f = ascribe($dollar f); + f(request) + } )* } - } + }}; } - $( requests!(@ $(#[$m])* @@ -138,48 +89,17 @@ macro_rules! requests { $v struct $request $($p)?; ); - impl From<$request> for AnyRequest { - fn from(request: $request) -> Self { - AnyRequest::$variant(request) - } - } - impl crate::private::Sealed for $request {} impl crate::Request for $request { type Response = $response; - fn decoder(&self) -> impl crate::ResponseDecoder> + 'static { - TrivialDecoder::>::default() - } - - fn into_message(self) -> RequestMessage { - RequestMessage { - request_type: RequestType::$variant, - request_body: serde_json::to_value(self).unwrap(), - } + fn into_message(self) -> crate::RequestMessage { + RequestMessage::$variant(self) } } )* - } -} - -/// The message format for IPC communication. -/// -/// This is mainly to avoid using sum types in IPC communication, which are more annoying to use -/// with non-Rust tooling. -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct RequestMessage { - /// The type of the request. - pub request_type: RequestType, - /// The raw JSON body of the request. - pub request_body: serde_json::Value, -} - -impl From for RequestMessage { - fn from(value: R) -> Self { - value.into_message() - } + }; } /// Uninstantiable @@ -187,8 +107,10 @@ impl From for RequestMessage { pub enum Never {} requests!( + $; + /// Always responds with an error (for testing error handling). - ReturnError(pub struct ErrorRequest) -> Never; + ReturnError(pub struct ErrorRequest(pub String)) -> Never; /// Requests the version string for the running niri instance. Version(pub struct VersionRequest) -> String; @@ -203,6 +125,95 @@ requests!( Action(pub struct ActionRequest(pub Action)) -> (); ); +#[derive(Debug, Serialize, Deserialize, Clone)] +struct ErrorRepr { + #[serde(rename = "error_type")] + tag: String, + #[serde(rename = "error_message")] + message: String, +} + +macro_rules! error { + ( + $(#[$meta_enum:meta])* + $v:vis enum $name:ident { + $(#[$meta_end:meta])* + $other:ident (String)$(,)? + $( + $(#[$meta_variant:meta])* + $variant:ident = $msg:literal, + )* + } + ) => { + $(#[$meta_enum])* + $v enum $name { + $( + $(#[$meta_variant])* + $variant, + )* + $(#[$meta_end])* $other(String), + } + + impl Serialize for $name { + fn serialize(&self, serializer: S) -> Result { + match self { + $( + $name::$variant => ErrorRepr { + tag: String::from(stringify!($variant)), + message: String::from($msg), + }, + )* + $name::$other(msg) => ErrorRepr { + tag: String::from(stringify!($other)), + message: msg.clone(), + }, + }.serialize(serializer) + } + } + + impl<'de> Deserialize<'de> for $name { + fn deserialize>(deserializer: D) -> Result { + let repr = ErrorRepr::deserialize(deserializer)?; + match repr.tag.as_str() { + $( + stringify!($variant) => Ok(Error::$variant), + )* + _ => Ok(Error::$other(repr.message)), + } + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + $( + $name::$variant => write!(f, $msg), + )* + $name::$other(msg) => write!(f, "{}", msg), + } + } + } + }; +} + +error! { + /// Errors that can occur when sending a request to niri. + #[derive(Debug, Clone, PartialEq, Eq)] + pub enum Error { + /// An error occurred that doesn't have a specific variant. + /// This occurs when the compositor sends an error that this client doesn't know about. + Other(String), + /// The client didn't send valid JSON. + ClientBadJson = "the client didn't send valid JSON", + /// The compositor didn't understand our request. + ClientBadProtocol = "the client didn't follow the protocol; this may be caused by mismatched versions", + /// The compositor sent a request we didn't understand. + CompositorBadProtocol = "the compositor didn't follow the protocol; this may be caused by mismatched versions", + /// There is + InternalError = "an internal error occurred in the compositor", + } +} + /// Reply from niri to client. /// /// Every request gets one reply. @@ -211,7 +222,7 @@ requests!( /// * If the request does not need any particular response, it will be /// `Reply::Ok(Response::Handled)`. Kind of like an `Ok(())`. /// * Otherwise, it will be `Reply::Ok(response)` with one of the other [`Response`] variants. -pub type Reply = Result; +pub type Reply = Result; /// Actions that niri can perform. // Variants in this enum should match the spelling of the ones in niri-config. Most, but not all, @@ -462,6 +473,7 @@ pub struct LogicalOutput { #[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)] pub enum Transform { /// Untransformed. + #[serde(rename = "normal")] Normal, /// Rotated by 90°. #[serde(rename = "90")] @@ -473,12 +485,16 @@ pub enum Transform { #[serde(rename = "270")] _270, /// Flipped horizontally. + #[serde(rename = "flipped")] Flipped, /// Rotated by 90° and flipped horizontally. + #[serde(rename = "flipped-90")] Flipped90, /// Flipped vertically. + #[serde(rename = "flipped-180")] Flipped180, /// Rotated by 270° and flipped horizontally. + #[serde(rename = "flipped-270")] Flipped270, } diff --git a/niri-ipc/src/socket.rs b/niri-ipc/src/socket.rs index 650d8e3d67..a6b348527c 100644 --- a/niri-ipc/src/socket.rs +++ b/niri-ipc/src/socket.rs @@ -5,7 +5,7 @@ use std::path::Path; use serde_json::de::IoRead; use serde_json::StreamDeserializer; -use crate::{Reply, Request, ResponseDecoder}; +use crate::{MaybeJson, Reply, Request}; /// Name of the environment variable containing the niri IPC socket path. pub const SOCKET_PATH_ENV: &str = "NIRI_SOCKET"; @@ -52,18 +52,23 @@ impl NiriSocket { /// Handle a request to the niri IPC server /// /// # Returns - /// Ok(Ok([Response](crate::Response))) corresponds to a successful response from the running - /// niri instance. Ok(Err([String])) corresponds to an error received from the running niri - /// instance. Err([std::io::Error]) corresponds to an error in the IPC communication. + /// - Ok(Ok([Response](crate::Response))) corresponds to a successful response from the running + /// niri instance. + /// - Ok(Err([String])) corresponds to an error received from the running niri + /// instance. + /// - Err([std::io::Error]) corresponds to an error in the IPC communication. pub fn send_request(mut self, request: R) -> io::Result> { - let decoder = request.decoder(); let mut buf = serde_json::to_vec(&request.into_message()).unwrap(); writeln!(buf).unwrap(); self.stream.write_all(&buf)?; // .context("error writing IPC request")?; self.stream.flush()?; if let Some(next) = self.responses.next() { - next.and_then(|v| decoder.decode(v)) + next.and_then(serde_json::from_value) + .map(|v| match v { + MaybeJson::Known(reply) => reply, + MaybeJson::Unknown(_) => Err(crate::Error::CompositorBadProtocol), + }) .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) } else { Err(io::Error::new( diff --git a/src/cli.rs b/src/cli.rs index 6296b1b3d5..53d27bf067 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -53,7 +53,7 @@ pub enum Sub { #[derive(Subcommand)] pub enum Msg { /// Print an error message. - Error, + Error { message: String }, /// Print the version string of the running niri instance. Version, /// List connected outputs. diff --git a/src/ipc/client.rs b/src/ipc/client.rs index 6ec2b7779a..b284a224af 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -1,7 +1,7 @@ use anyhow::Context; use niri_ipc::{ - ActionRequest, ErrorRequest, FocusedWindowRequest, LogicalOutput, Mode, NiriSocket, Output, - OutputRequest, Request, VersionRequest, + ActionRequest, Error, ErrorRequest, FocusedWindowRequest, LogicalOutput, Mode, NiriSocket, + Output, OutputRequest, Request, VersionRequest, }; use serde_json::json; @@ -9,12 +9,10 @@ use crate::cli::Msg; use crate::utils::version; struct CompositorError { - message: String, + error: niri_ipc::Error, version: Option, } -type MsgResult = Result; - trait MsgRequest: Request { fn json(response: Self::Response) -> serde_json::Value { json!(response) @@ -30,25 +28,23 @@ trait MsgRequest: Request { } } - fn send(self) -> anyhow::Result> { - let socket = NiriSocket::new().context("trying to initialize the socket")?; - let reply = socket - .send_request(self) - .context("while sending request to niri")?; - Ok(reply.map_err(|message| CompositorError { - message, + fn send(self) -> anyhow::Result> { + let socket = NiriSocket::new().context("problem initializing the socket")?; + let reply = socket.send_request(self).context("problem ")?; + Ok(reply.map_err(|error| CompositorError { + error, version: Self::check_version(), })) } } pub fn handle_msg(msg: Msg, json: bool) -> anyhow::Result<()> { - match &msg { - Msg::Error => run(json, ErrorRequest), + match msg { + Msg::Error { message } => run(json, ErrorRequest(message)), Msg::Version => run(json, VersionRequest), Msg::Outputs => run(json, OutputRequest), Msg::FocusedWindow => run(json, FocusedWindowRequest), - Msg::Action { action } => run(json, ActionRequest(action.clone())), + Msg::Action { action } => run(json, ActionRequest(action)), } } @@ -82,12 +78,28 @@ fn run(json: bool, request: R) -> anyhow::Result<()> { } } Err(CompositorError { - message, + error, version: server_version, }) => { - eprintln!("The compositor returned an error:"); - eprintln!(); - eprintln!("{message}"); + match error { + Error::ClientBadJson => { + eprintln!("Something went wrong in the CLI; the compositor says the JSON it sent was invalid") + } + Error::ClientBadProtocol => { + eprintln!("The compositor didn't understand the request sent by the CLI.") + } + Error::CompositorBadProtocol => { + eprintln!("The compositor returned a response that the CLI didn't understand.") + } + Error::InternalError => { + eprintln!("Something went wrong in the compositor. I don't know what.") + } + Error::Other(msg) => { + eprintln!("The compositor returned an error:"); + eprintln!(); + eprintln!("{msg}"); + } + } if let Some(server_version) = server_version { let my_version = version(); @@ -121,6 +133,7 @@ impl MsgRequest for ErrorRequest { impl MsgRequest for VersionRequest { fn check_version() -> Option { + eprintln!("version"); // If the version request fails, we can't exactly try again. None } diff --git a/src/ipc/server.rs b/src/ipc/server.rs index e10a001ea3..4a31f100bf 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -10,8 +10,8 @@ use directories::BaseDirs; use futures_util::io::{AsyncReadExt, BufReader}; use futures_util::{AsyncBufReadExt, AsyncWriteExt, StreamExt}; use niri_ipc::{ - ActionRequest, ErrorRequest, FocusedWindowRequest, OutputRequest, Reply, Request, - RequestMessage, RequestType, VersionRequest, + ActionRequest, Error, ErrorRequest, FocusedWindowRequest, MaybeJson, MaybeUnknown, + OutputRequest, Reply, Request, RequestMessage, VersionRequest, }; use smithay::desktop::Window; use smithay::reexports::calloop::generic::Generic; @@ -147,45 +147,30 @@ trait HandleRequest: Request { } fn process(ctx: &ClientCtx, line: String) -> Reply { - let request: RequestMessage = serde_json::from_str(&line).map_err(|err| { + let deserialized: MaybeJson = serde_json::from_str(&line).map_err(|err| { warn!("error parsing IPC request: {err:?}"); - "error parsing request" + Error::ClientBadJson })?; - macro_rules! handle { - ($($variant:ident => $type:ty,)*) => { - match request.request_type { - $( - RequestType::$variant => { - let request = serde_json::from_value::<$type>(request.request_body).map_err(|err| - { - warn!("error parsing IPC request: {err:?}"); - "error parsing request" - })?; - HandleRequest::handle(request, ctx).and_then(|v| { - serde_json::to_value(v).map_err(|err| { - warn!("error serializing response to IPC request: {err:?}"); - "error serializing response".into() - }) - }) - } - )* - } - }; + match deserialized { + MaybeUnknown::Known(request) => niri_ipc::dispatch!(request, |req| { + req.handle(ctx).and_then(|v| { + serde_json::to_value(v).map_err(|err| { + warn!("error serializing response to IPC request: {err:?}"); + Error::InternalError + }) + }) + }), + MaybeUnknown::Unknown(payload) => { + warn!("client sent an invalid payload: {payload}"); + Err(Error::ClientBadProtocol) + } } - - handle!( - ReturnError => ErrorRequest, - Version => VersionRequest, - Outputs => OutputRequest, - FocusedWindow => FocusedWindowRequest, - Action => ActionRequest, - ) } impl HandleRequest for ErrorRequest { fn handle(self, _ctx: &ClientCtx) -> Reply { - Err("client wanted an error".into()) + Err(Error::Other(self.0)) } }