diff --git a/src/create/tokio.rs b/src/create/tokio.rs index ff2a5494..f5402679 100644 --- a/src/create/tokio.rs +++ b/src/create/tokio.rs @@ -27,7 +27,7 @@ use tokio_util::compat::{ use crate::{ create::{unbuffered_stdout, Spawner}, - error::LoopError, + error::{HandshakeError, LoopError}, neovim::Neovim, Handler, }; @@ -213,3 +213,46 @@ where Ok((neovim, io_handle)) } + +/// Connect to a neovim instance by spawning a new one and send a handshake +/// message. Unlike `new_child_cmd`, this function is tolerant to extra +/// data in the reader before the handshake response is received. +/// +/// `message` should be a unique string that is normally not found in the +/// stdout. Due to the way Neovim packs strings, the length has to be either +/// less than 20 characters or more than 31 characters long. +/// See https://github.com/neovim/neovim/issues/32784 for more information. +pub async fn new_child_handshake_cmd( + cmd: &mut Command, + handler: H, + message: &str, +) -> Result< + ( + Neovim>, + JoinHandle>>, + Child, + ), + Box, +> +where + H: Handler> + Send + 'static, +{ + let mut child = cmd.stdin(Stdio::piped()).stdout(Stdio::piped()).spawn()?; + let stdout = child + .stdout + .take() + .ok_or_else(|| Error::new(ErrorKind::Other, "Can't open stdout"))? + .compat(); + let stdin = child + .stdin + .take() + .ok_or_else(|| Error::new(ErrorKind::Other, "Can't open stdin"))? + .compat_write(); + + let (neovim, io) = + Neovim::>::handshake(stdout, stdin, handler, message) + .await?; + let io_handle = spawn(io); + + Ok((neovim, io_handle, child)) +} diff --git a/src/error.rs b/src/error.rs index d7e5e9e0..8ac4f555 100644 --- a/src/error.rs +++ b/src/error.rs @@ -426,3 +426,81 @@ impl From for Box { Box::new(LoopError::MsgidNotFound(i)) } } + +#[derive(Debug)] +pub enum HandshakeError { + /// Sending the request to neovim has failed. + /// + /// Fields: + /// + /// 0. The underlying error + SendError(EncodeError), + /// Sending the request to neovim has failed. + /// + /// Fields: + /// + /// 0. The underlying error + /// 1. The data read so far + RecvError(io::Error, String), + /// Unexpected response received + /// + /// Fields: + /// + /// 0. The data read so far + UnexpectedResponse(String), + /// The launch of Neovim failed + /// + /// Fields: + /// + /// 0. The underlying error + LaunchError(io::Error), +} + +impl From> for Box { + fn from(v: Box) -> Box { + Box::new(HandshakeError::SendError(*v)) + } +} + +impl From<(io::Error, String)> for Box { + fn from(v: (io::Error, String)) -> Box { + Box::new(HandshakeError::RecvError(v.0, v.1)) + } +} + +impl From for Box { + fn from(v: io::Error) -> Box { + Box::new(HandshakeError::LaunchError(v)) + } +} + +impl Error for HandshakeError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match *self { + Self::SendError(ref s) => Some(s), + Self::RecvError(ref s, _) => Some(s), + Self::LaunchError(ref s) => Some(s), + Self::UnexpectedResponse(_) => None, + } + } +} + +impl Display for HandshakeError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { + match *self { + Self::SendError(ref s) => write!(fmt, "Error sending handshake '{s}'"), + Self::RecvError(ref s, ref output) => { + write!( + fmt, + "Error receiving handshake response '{s}'\n\ + Unexpected output:\n{output}" + ) + } + Self::LaunchError(ref s) => write!(fmt, "Error launching nvim '{s}'"), + Self::UnexpectedResponse(ref output) => write!( + fmt, + "Error receiving handshake response, unexpected output:\n{output}" + ), + } + } +} diff --git a/src/neovim.rs b/src/neovim.rs index 6825a23c..d421b696 100644 --- a/src/neovim.rs +++ b/src/neovim.rs @@ -12,16 +12,17 @@ use futures::{ mpsc::{unbounded, UnboundedReceiver, UnboundedSender}, oneshot, }, - io::{AsyncRead, AsyncWrite,}, + future, + io::{AsyncRead, AsyncReadExt, AsyncWrite}, lock::Mutex, sink::SinkExt, stream::StreamExt, - future, TryFutureExt, + TryFutureExt, }; use crate::{ create::Spawner, - error::{CallError, DecodeError, EncodeError, LoopError}, + error::{CallError, DecodeError, EncodeError, HandshakeError, LoopError}, rpc::{ handler::Handler, model, @@ -108,13 +109,110 @@ where let (sender, receiver) = unbounded(); let fut = future::try_join( req.clone().io_loop(reader, sender), - req.clone().handler_loop(handler, receiver) + req.clone().handler_loop(handler, receiver), ) .map_ok(|_| ()); (req, fut) } + /// Create a new instance, immediately send a handshake message and + /// wait for the response. Unlike `new`, this function is tolerant to extra + /// data in the reader before the handshake response is received. + /// + /// `message` should be a unique string that is normally not found in the + /// stdout. Due to the way Neovim packs strings, the length has to be either + /// less than 20 characters or more than 31 characters long. + /// See https://github.com/neovim/neovim/issues/32784 for more information. + pub async fn handshake( + mut reader: R, + writer: W, + handler: H, + message: &str, + ) -> Result< + ( + Neovim<::Writer>, + impl Future>>, + ), + Box, + > + where + R: AsyncRead + Send + Unpin + 'static, + H: Handler + Spawner, + { + let instance = Neovim { + writer: Arc::new(Mutex::new(writer)), + msgid_counter: Arc::new(AtomicU64::new(0)), + queue: Arc::new(Mutex::new(Vec::new())), + }; + + let msgid = instance.msgid_counter.fetch_add(1, Ordering::SeqCst); + // Nvim encodes fixed size strings with a length of 20-31 bytes wrong, so + // avoid that + let msg_len = message.len(); + assert!( + !(20..=31).contains(&msg_len), + "The message should be less than 20 characters or more than 31 characters + long, but the length is {msg_len}." + ); + + let req = RpcMessage::RpcRequest { + msgid, + method: "nvim_exec_lua".to_owned(), + params: call_args![format!("return '{message}'"), Vec::::new()], + }; + model::encode(instance.writer.clone(), req).await?; + + let expected_resp = RpcMessage::RpcResponse { + msgid, + error: rmpv::Value::Nil, + result: rmpv::Value::String(message.into()), + }; + let mut expected_data = Vec::new(); + model::encode_sync(&mut expected_data, expected_resp) + .expect("Encoding static data can't fail"); + let mut actual_data = Vec::new(); + let mut start = 0; + let mut end = 0; + while end - start != expected_data.len() { + actual_data.resize(start + expected_data.len(), 0); + + let bytes_read = + reader + .read(&mut actual_data[start..]) + .await + .map_err(|err| { + ( + err, + String::from_utf8_lossy(&actual_data[..end]).to_string(), + ) + })?; + if bytes_read == 0 { + // The end of the stream has been reached when the reader returns Ok(0). + // Since we haven't detected a suitable response yet, return an error. + return Err(Box::new(HandshakeError::UnexpectedResponse( + String::from_utf8_lossy(&actual_data[..end]).to_string(), + ))); + } + end += bytes_read; + while end - start > 0 { + if actual_data[start..end] == expected_data[..end - start] { + break; + } + start += 1; + } + } + + let (sender, receiver) = unbounded(); + let fut = future::try_join( + instance.clone().io_loop(reader, sender), + instance.clone().handler_loop(handler, receiver), + ) + .map_ok(|_| ()); + + Ok((instance, fut)) + } + async fn send_msg( &self, method: &str, diff --git a/src/rpc/model.rs b/src/rpc/model.rs index bf5f0ff8..a97c0eb6 100644 --- a/src/rpc/model.rs +++ b/src/rpc/model.rs @@ -2,12 +2,12 @@ use std::{ self, convert::TryInto, - io::{self, Cursor, ErrorKind, Read}, + io::{self, Cursor, ErrorKind, Read, Write}, sync::Arc, }; use futures::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, lock::Mutex, }; use rmpv::{decode::read_value, encode::write_value, Value}; @@ -162,13 +162,11 @@ fn decode_buffer( } } -/// Encode the given message into the `BufWriter`. Flushes the writer when -/// finished. -pub async fn encode( - writer: Arc>, +/// Encode the given message into the `writer`. +pub fn encode_sync( + writer: &mut W, msg: RpcMessage, ) -> std::result::Result<(), Box> { - let mut v: Vec = vec![]; match msg { RpcMessage::RpcRequest { msgid, @@ -176,7 +174,7 @@ pub async fn encode( params, } => { let val = rpc_args!(0, msgid, method, params); - write_value(&mut v, &val)?; + write_value(writer, &val)?; } RpcMessage::RpcResponse { msgid, @@ -184,14 +182,26 @@ pub async fn encode( result, } => { let val = rpc_args!(1, msgid, error, result); - write_value(&mut v, &val)?; + write_value(writer, &val)?; } RpcMessage::RpcNotification { method, params } => { let val = rpc_args!(2, method, params); - write_value(&mut v, &val)?; + write_value(writer, &val)?; } }; + Ok(()) +} + +/// Encode the given message into the `BufWriter`. Flushes the writer when +/// finished. +pub async fn encode( + writer: Arc>, + msg: RpcMessage, +) -> std::result::Result<(), Box> { + let mut v: Vec = vec![]; + encode_sync(&mut v, msg)?; + let mut writer = writer.lock().await; writer.write_all(&v).await?; writer.flush().await?; diff --git a/tests/connecting/handshake.rs b/tests/connecting/handshake.rs new file mode 100644 index 00000000..2d29263e --- /dev/null +++ b/tests/connecting/handshake.rs @@ -0,0 +1,88 @@ +use nvim_rs::rpc::handler::Dummy as DummyHandler; + +#[cfg(feature = "use_tokio")] +use nvim_rs::create::tokio as create; +#[cfg(feature = "use_tokio")] +use tokio::process::Command; +#[cfg(feature = "use_tokio")] +use tokio::test as atest; + +#[cfg(feature = "use_async-std")] +use async_std::test as atest; +#[cfg(feature = "use_async-std")] +use nvim_rs::create::async_std as create; +#[cfg(feature = "use_async-std")] +use std::process::Command; + +#[path = "../common/mod.rs"] +mod common; +use common::*; + +use nvim_rs::error::HandshakeError; + +#[atest] +async fn successful_handshake() { + let handler = DummyHandler::new(); + + create::new_child_handshake_cmd( + Command::new(nvim_path()).args(&["-u", "NONE", "--embed"]), + handler, + "handshake_message", + ) + .await + .expect("Should launch correctly"); +} + +#[cfg(unix)] +#[atest] +async fn successful_handshake_with_extra_output() { + let handler = DummyHandler::new(); + let nvim = nvim_path(); + + create::new_child_handshake_cmd( + Command::new("/bin/sh").args(&[ + "-c", + &format!( + "echo 'extra output';{} -u NONE --embed", + nvim.to_string_lossy() + ), + ]), + handler, + "handshake_message", + ) + .await + .expect("Should launch correctly"); +} + +#[cfg(unix)] +#[atest] +async fn unsuccessful_handshake_with_wrong_output() { + let handler = DummyHandler::new(); + + // NOTE: This has to match the exact length of the message sent + let expected_request_len = 46; + + // Make sure that the command is alive for long enough by reading the request + // message from stdin with dd + let res = create::new_child_handshake_cmd( + Command::new("/bin/sh").args(&[ + "-c", + &format!("echo 'wrong output'; + timeout 5 dd bs=1 count={expected_request_len} > /dev/null 2>&1")]), + handler, + "handshake_message", + ) + .await; + + match res { + Err(err) => match *err { + HandshakeError::UnexpectedResponse(output) => { + assert_eq!(output, "wrong output\n"); + } + _ => { + panic!("Unexpected error returned {}", err); + } + }, + _ => panic!("No error returned"), + } +} diff --git a/tests/connecting/mod.rs b/tests/connecting/mod.rs index d40592d1..7fd970c6 100644 --- a/tests/connecting/mod.rs +++ b/tests/connecting/mod.rs @@ -2,3 +2,6 @@ pub mod conns; #[cfg(feature = "use_async-std")] pub mod conns; + +#[cfg(feature = "use_tokio")] +pub mod handshake;