Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codec: check serialization result length #496

Open
wants to merge 7 commits into
base: v0.5.x
Choose a base branch
from
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions benchmark/src/bench.rs
Original file line number Diff line number Diff line change
@@ -100,8 +100,9 @@ impl Generic {

#[inline]
#[allow(clippy::ptr_arg)]
pub fn bin_ser(t: &Vec<u8>, buf: &mut Vec<u8>) {
buf.extend_from_slice(t)
pub fn bin_ser(t: &Vec<u8>, buf: &mut Vec<u8>) -> grpc::Result<()> {
buf.extend_from_slice(t);
Ok(())
}

#[inline]
4 changes: 2 additions & 2 deletions src/call/client.rs
Original file line number Diff line number Diff line change
@@ -103,7 +103,7 @@ impl Call {
) -> Result<ClientUnaryReceiver<Resp>> {
let call = channel.create_call(method, &opt)?;
let mut payload = vec![];
(method.req_ser())(req, &mut payload);
(method.req_ser())(req, &mut payload)?;
let cq_f = check_run(BatchType::CheckRead, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_start_unary(
call.call,
@@ -157,7 +157,7 @@ impl Call {
) -> Result<ClientSStreamReceiver<Resp>> {
let call = channel.create_call(method, &opt)?;
let mut payload = vec![];
(method.req_ser())(req, &mut payload);
(method.req_ser())(req, &mut payload)?;
let cq_f = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_start_server_streaming(
call.call,
2 changes: 1 addition & 1 deletion src/call/mod.rs
Original file line number Diff line number Diff line change
@@ -657,7 +657,7 @@ impl SinkBase {
}

self.buf.clear();
ser(t, &mut self.buf);
ser(t, &mut self.buf)?;
if flags.get_buffer_hint() && self.send_metadata {
// temporary fix: buffer hint with send meta will not send out any metadata.
flags = flags.buffer_hint(false);
19 changes: 14 additions & 5 deletions src/call/server.rs
Original file line number Diff line number Diff line change
@@ -330,11 +330,20 @@ macro_rules! impl_unary_sink {
}

fn complete(mut self, status: RpcStatus, t: Option<T>) -> $rt {
let data = t.as_ref().map(|t| {
let mut buf = vec![];
(self.ser)(t, &mut buf);
buf
});
let data = match t {
Some(t) => {
let mut buf = vec![];
match (self.ser)(&t, &mut buf) {
Ok(()) => Some(buf),
Err(e) => return $rt {
call: self.call.take().unwrap(),
cq_f: None,
err: Some(e),
}
}
},
None => None,
};

let write_flags = self.write_flags;
let res = self.call.as_mut().unwrap().call(|c| {
37 changes: 27 additions & 10 deletions src/codec.rs
Original file line number Diff line number Diff line change
@@ -4,7 +4,11 @@ use crate::call::MessageReader;
use crate::error::Result;

pub type DeserializeFn<T> = fn(MessageReader) -> Result<T>;
pub type SerializeFn<T> = fn(&T, &mut Vec<u8>);
pub type SerializeFn<T> = fn(&T, &mut Vec<u8>) -> Result<()>;

/// According to https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md, grpc uses
/// a four bytes to describe the length of a message, so it should not exceed u32::MAX.
const MAX_MESSAGE_SIZE: usize = u32::MAX as usize;

/// Defines how to serialize and deserialize between the specialized type and byte slice.
pub struct Marshaller<T> {
@@ -28,12 +32,19 @@ pub struct Marshaller<T> {
pub mod pb_codec {
use protobuf::{CodedInputStream, Message};

use super::MessageReader;
use crate::error::Result;
use super::{MessageReader, MAX_MESSAGE_SIZE};
use crate::error::{Error, Result};

#[inline]
pub fn ser<T: Message>(t: &T, buf: &mut Vec<u8>) {
t.write_to_vec(buf).unwrap()
pub fn ser<T: Message>(t: &T, buf: &mut Vec<u8>) -> Result<()> {
t.write_to_vec(buf)?;
if buf.len() <= MAX_MESSAGE_SIZE as usize {
Ok(())
} else {
Err(Error::Codec(
format!("message is too large: {} > {}", buf.len(), MAX_MESSAGE_SIZE).into(),
))
}
}

#[inline]
@@ -47,15 +58,21 @@ pub mod pb_codec {

#[cfg(feature = "prost-codec")]
pub mod pr_codec {
use bytes::buf::BufMut;
use prost::Message;

use super::MessageReader;
use crate::error::Result;
use super::{MessageReader, MAX_MESSAGE_SIZE};
use crate::error::{Error, Result};

#[inline]
pub fn ser<M: Message, B: BufMut>(msg: &M, buf: &mut B) {
msg.encode(buf).expect("Writing message to buffer failed");
pub fn ser<M: Message>(msg: &M, buf: &mut Vec<u8>) -> Result<()> {
msg.encode(buf)?;
if buf.len() <= MAX_MESSAGE_SIZE as usize {
Ok(())
} else {
Err(Error::Codec(
format!("message is too large: {} > {}", buf.len(), MAX_MESSAGE_SIZE).into(),
))
}
}

#[inline]
13 changes: 9 additions & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -5,8 +5,6 @@ use std::{error, fmt, result};
use crate::call::RpcStatus;
use crate::grpc_sys::grpc_call_error;

#[cfg(feature = "prost-codec")]
use prost::DecodeError;
#[cfg(feature = "protobuf-codec")]
use protobuf::ProtobufError;

@@ -64,8 +62,15 @@ impl From<ProtobufError> for Error {
}

#[cfg(feature = "prost-codec")]
impl From<DecodeError> for Error {
fn from(e: DecodeError) -> Error {
impl From<prost::DecodeError> for Error {
fn from(e: prost::DecodeError) -> Error {
Error::Codec(Box::new(e))
}
}

#[cfg(feature = "prost-codec")]
impl From<prost::EncodeError> for Error {
fn from(e: prost::EncodeError) -> Error {
Error::Codec(Box::new(e))
}
}