Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion src/create/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use tokio_util::compat::{

use crate::{
create::{unbuffered_stdout, Spawner},
error::LoopError,
error::{HandshakeError, LoopError},
neovim::Neovim,
Handler,
};
Expand Down Expand Up @@ -213,3 +213,46 @@ where

Ok((neovim, io_handle))
}

/// Connect to a neovim instance by spawning a new one and send a handshake
/// message. Unlike `new_child_cmd`, this function is tolerant to extra
/// data in the reader before the handshake response is received.
///
/// `message` should be a unique string that is normally not found in the
/// stdout. Due to the way Neovim packs strings, the length has to be either
/// less than 20 characters or more than 31 characters long.
/// See https://github.com/neovim/neovim/issues/32784 for more information.
pub async fn new_child_handshake_cmd<H>(
cmd: &mut Command,
handler: H,
message: &str,
) -> Result<
(
Neovim<Compat<ChildStdin>>,
JoinHandle<Result<(), Box<LoopError>>>,
Child,
),
Box<HandshakeError>,
>
where
H: Handler<Writer = Compat<ChildStdin>> + Send + 'static,
{
let mut child = cmd.stdin(Stdio::piped()).stdout(Stdio::piped()).spawn()?;
let stdout = child
.stdout
.take()
.ok_or_else(|| Error::new(ErrorKind::Other, "Can't open stdout"))?
.compat();
let stdin = child
.stdin
.take()
.ok_or_else(|| Error::new(ErrorKind::Other, "Can't open stdin"))?
.compat_write();

let (neovim, io) =
Neovim::<Compat<ChildStdin>>::handshake(stdout, stdin, handler, message)
.await?;
let io_handle = spawn(io);

Ok((neovim, io_handle, child))
}
78 changes: 78 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,81 @@ impl From<u64> for Box<LoopError> {
Box::new(LoopError::MsgidNotFound(i))
}
}

#[derive(Debug)]
pub enum HandshakeError {
/// Sending the request to neovim has failed.
///
/// Fields:
///
/// 0. The underlying error
SendError(EncodeError),
/// Sending the request to neovim has failed.
///
/// Fields:
///
/// 0. The underlying error
/// 1. The data read so far
RecvError(io::Error, String),
/// Unexpected response received
///
/// Fields:
///
/// 0. The data read so far
UnexpectedResponse(String),
/// The launch of Neovim failed
///
/// Fields:
///
/// 0. The underlying error
LaunchError(io::Error),
}

impl From<Box<EncodeError>> for Box<HandshakeError> {
fn from(v: Box<EncodeError>) -> Box<HandshakeError> {
Box::new(HandshakeError::SendError(*v))
}
}

impl From<(io::Error, String)> for Box<HandshakeError> {
fn from(v: (io::Error, String)) -> Box<HandshakeError> {
Box::new(HandshakeError::RecvError(v.0, v.1))
}
}

impl From<io::Error> for Box<HandshakeError> {
fn from(v: io::Error) -> Box<HandshakeError> {
Box::new(HandshakeError::LaunchError(v))
}
}

impl Error for HandshakeError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match *self {
Self::SendError(ref s) => Some(s),
Self::RecvError(ref s, _) => Some(s),
Self::LaunchError(ref s) => Some(s),
Self::UnexpectedResponse(_) => None,
}
}
}

impl Display for HandshakeError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
match *self {
Self::SendError(ref s) => write!(fmt, "Error sending handshake '{s}'"),
Self::RecvError(ref s, ref output) => {
write!(
fmt,
"Error receiving handshake response '{s}'\n\
Unexpected output:\n{output}"
)
}
Self::LaunchError(ref s) => write!(fmt, "Error launching nvim '{s}'"),
Self::UnexpectedResponse(ref output) => write!(
fmt,
"Error receiving handshake response, unexpected output:\n{output}"
),
}
}
}
106 changes: 102 additions & 4 deletions src/neovim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@ use futures::{
mpsc::{unbounded, UnboundedReceiver, UnboundedSender},
oneshot,
},
io::{AsyncRead, AsyncWrite,},
future,
io::{AsyncRead, AsyncReadExt, AsyncWrite},
lock::Mutex,
sink::SinkExt,
stream::StreamExt,
future, TryFutureExt,
TryFutureExt,
};

use crate::{
create::Spawner,
error::{CallError, DecodeError, EncodeError, LoopError},
error::{CallError, DecodeError, EncodeError, HandshakeError, LoopError},
rpc::{
handler::Handler,
model,
Expand Down Expand Up @@ -108,13 +109,110 @@ where
let (sender, receiver) = unbounded();
let fut = future::try_join(
req.clone().io_loop(reader, sender),
req.clone().handler_loop(handler, receiver)
req.clone().handler_loop(handler, receiver),
)
.map_ok(|_| ());

(req, fut)
}

/// Create a new instance, immediately send a handshake message and
/// wait for the response. Unlike `new`, this function is tolerant to extra
/// data in the reader before the handshake response is received.
///
/// `message` should be a unique string that is normally not found in the
/// stdout. Due to the way Neovim packs strings, the length has to be either
/// less than 20 characters or more than 31 characters long.
/// See https://github.com/neovim/neovim/issues/32784 for more information.
pub async fn handshake<H, R>(
mut reader: R,
writer: W,
handler: H,
message: &str,
) -> Result<
(
Neovim<<H as Handler>::Writer>,
impl Future<Output = Result<(), Box<LoopError>>>,
),
Box<HandshakeError>,
>
where
R: AsyncRead + Send + Unpin + 'static,
H: Handler<Writer = W> + Spawner,
{
let instance = Neovim {
writer: Arc::new(Mutex::new(writer)),
msgid_counter: Arc::new(AtomicU64::new(0)),
queue: Arc::new(Mutex::new(Vec::new())),
};

let msgid = instance.msgid_counter.fetch_add(1, Ordering::SeqCst);
// Nvim encodes fixed size strings with a length of 20-31 bytes wrong, so
// avoid that
let msg_len = message.len();
assert!(
!(20..=31).contains(&msg_len),
"The message should be less than 20 characters or more than 31 characters
long, but the length is {msg_len}."
);

let req = RpcMessage::RpcRequest {
msgid,
method: "nvim_exec_lua".to_owned(),
params: call_args![format!("return '{message}'"), Vec::<Value>::new()],
};
model::encode(instance.writer.clone(), req).await?;

let expected_resp = RpcMessage::RpcResponse {
msgid,
error: rmpv::Value::Nil,
result: rmpv::Value::String(message.into()),
};
let mut expected_data = Vec::new();
model::encode_sync(&mut expected_data, expected_resp)
.expect("Encoding static data can't fail");
let mut actual_data = Vec::new();
let mut start = 0;
let mut end = 0;
while end - start != expected_data.len() {
actual_data.resize(start + expected_data.len(), 0);

let bytes_read =
reader
.read(&mut actual_data[start..])
.await
.map_err(|err| {
(
err,
String::from_utf8_lossy(&actual_data[..end]).to_string(),
)
})?;
if bytes_read == 0 {
// The end of the stream has been reached when the reader returns Ok(0).
// Since we haven't detected a suitable response yet, return an error.
return Err(Box::new(HandshakeError::UnexpectedResponse(
String::from_utf8_lossy(&actual_data[..end]).to_string(),
)));
}
end += bytes_read;
while end - start > 0 {
if actual_data[start..end] == expected_data[..end - start] {
break;
}
start += 1;
}
}

let (sender, receiver) = unbounded();
let fut = future::try_join(
instance.clone().io_loop(reader, sender),
instance.clone().handler_loop(handler, receiver),
)
.map_ok(|_| ());

Ok((instance, fut))
}

async fn send_msg(
&self,
method: &str,
Expand Down
30 changes: 20 additions & 10 deletions src/rpc/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
use std::{
self,
convert::TryInto,
io::{self, Cursor, ErrorKind, Read},
io::{self, Cursor, ErrorKind, Read, Write},
sync::Arc,
};

use futures::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,},
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
lock::Mutex,
};
use rmpv::{decode::read_value, encode::write_value, Value};
Expand Down Expand Up @@ -162,36 +162,46 @@ fn decode_buffer<R: Read>(
}
}

/// Encode the given message into the `BufWriter`. Flushes the writer when
/// finished.
pub async fn encode<W: AsyncWrite + Send + Unpin + 'static>(
writer: Arc<Mutex<W>>,
/// Encode the given message into the `writer`.
pub fn encode_sync<W: Write>(
writer: &mut W,
msg: RpcMessage,
) -> std::result::Result<(), Box<EncodeError>> {
let mut v: Vec<u8> = vec![];
match msg {
RpcMessage::RpcRequest {
msgid,
method,
params,
} => {
let val = rpc_args!(0, msgid, method, params);
write_value(&mut v, &val)?;
write_value(writer, &val)?;
}
RpcMessage::RpcResponse {
msgid,
error,
result,
} => {
let val = rpc_args!(1, msgid, error, result);
write_value(&mut v, &val)?;
write_value(writer, &val)?;
}
RpcMessage::RpcNotification { method, params } => {
let val = rpc_args!(2, method, params);
write_value(&mut v, &val)?;
write_value(writer, &val)?;
}
};

Ok(())
}

/// Encode the given message into the `BufWriter`. Flushes the writer when
/// finished.
pub async fn encode<W: AsyncWrite + Send + Unpin + 'static>(
writer: Arc<Mutex<W>>,
msg: RpcMessage,
) -> std::result::Result<(), Box<EncodeError>> {
let mut v: Vec<u8> = vec![];
encode_sync(&mut v, msg)?;

let mut writer = writer.lock().await;
writer.write_all(&v).await?;
writer.flush().await?;
Expand Down
Loading
Loading