Skip to content

Commit

Permalink
test: update tests for message serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
lbennett-stacki committed Jun 17, 2024
1 parent 16d0e47 commit 2967efd
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 29 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions packages/providers/core/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ edition = "2021"
async-trait = "0.1.80"
bytes = "1.6.0"
ctrlc = "3.4.4"
env_logger = "0.11.3"
insta = "1.39.0"
log = "0.4.21"
tokio = { version = "1.38.0", features = [
"net",
Expand Down
73 changes: 59 additions & 14 deletions packages/providers/core/lib/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tokio::{io, net::UnixStream};
use crate::{messages::MessageSlice, Message};

#[repr(u8)]
#[derive(Debug)]
#[derive(Debug, PartialEq, Clone)]
pub enum Opcode {
Initialize = 0x1,
Destroy = 0x2,
Expand All @@ -17,10 +17,22 @@ pub enum Opcode {

impl Opcode {
fn as_byte(&self) -> u8 {
self as *const _ as u8
log::debug!("Converting opcode {:?} to opcode byte", self);
let byte = self.clone() as u8;

log::debug!(
"TEMP self:{:?} ---- getvalue:{:?} ---- static0x3:{:?} ---- byte:{:?}",
self,
Opcode::GetValue,
0x3,
byte
);

byte
}

fn from_byte(byte: u8) -> Result<Opcode, ProtocolError> {
log::debug!("Converting opcode byte {} to opcode", byte);
// TODO: has to be a better way, also see similar conversions to improve in LSP, lexer and parser
match byte {
0x1 => Ok(Opcode::Initialize),
Expand All @@ -40,10 +52,10 @@ type Payload = Vec<u8>;

#[derive(Debug)]
pub struct Header {
version: u8,
opcode: Opcode,
checksum: u16,
payload_length: u32,
pub version: u8,
pub opcode: Opcode,
pub checksum: u16,
pub payload_length: u32,
}

impl Header {
Expand All @@ -54,6 +66,12 @@ impl Header {
let checksum_bytes = self.checksum.to_be_bytes();
let payload_length_bytes = self.payload_length.to_be_bytes();

log::debug!(
"HeaderBytes to Header, found payload length {:?} from {:?}",
payload_length_bytes,
self.payload_length
);

bytes = bytes
.iter()
.chain(checksum_bytes.iter())
Expand All @@ -66,17 +84,25 @@ impl Header {

fn from_message(message: MessageSlice) -> Result<Header, ProtocolError> {
let version = *message.first().ok_or(ProtocolError::InvalidHeader)?;
log::debug!("Header from message, version:{:?}", version);
let opcode = *message.get(1).ok_or(ProtocolError::InvalidHeader)?;
log::debug!("Header from message, opcode:{:?}", opcode);
let checksum: [u8; 2] = message
.get(1..3)
.get(2..=3)
.ok_or(ProtocolError::InvalidHeader)?
.try_into()
.or(Err(ProtocolError::InvalidHeader))?;
log::debug!("Header from message, checksum:{:?}", checksum);
let payload_length: [u8; 4] = message
.get(3..7)
.get(4..=7)
.ok_or(ProtocolError::InvalidHeader)?
.try_into()
.or(Err(ProtocolError::InvalidHeader))?;
log::debug!("Header from message, payload_length:{:?}", payload_length);
log::debug!(
"Header from message, payload_length AS BE BYTES:{:?}",
u32::from_be_bytes(payload_length)
);

Ok(Header {
version,
Expand All @@ -89,6 +115,8 @@ impl Header {

impl MessageSerializer {
pub fn serialize(&self, payload: Payload) -> Message {
log::debug!("Serializing payload {:?}", payload);

let header = self.generate_header(&payload);

header.iter().chain(payload.iter()).copied().collect()
Expand Down Expand Up @@ -117,20 +145,37 @@ pub enum ProtocolError {

pub struct MessageDeserializer {}

pub struct DeserializedMessage {
pub header: Header,
pub payload: Payload,
}

impl MessageDeserializer {
pub fn deserialize(message: MessageSlice) -> Result<(Header, Payload), ProtocolError> {
// TODO: use a struct to return instead of tuple
pub fn deserialize(message: MessageSlice) -> Result<DeserializedMessage, ProtocolError> {
log::debug!("Deserialize on message called: {:?}", message);

let header = Header::from_message(message)?;
let max_payload_index: usize = (7 + header.payload_length)
let max_payload_index: usize = (8 + header.payload_length)
.try_into()
.or(Err(ProtocolError::UnreadableStream))?;
let payload: &[u8] = message
.get(7..max_payload_index)
.ok_or(ProtocolError::InvalidPayload)?;

let payload = message.get(8..max_payload_index);
log::debug!("TEMP Payload check {:?}", payload);

let payload: Vec<u8> = payload.ok_or(ProtocolError::InvalidPayload)?.to_vec();

log::debug!(
"Deserialize found a payload of len {}: {:?} between indexes 7..{}",
payload.len(),
payload,
max_payload_index
);

// TODO: message may have left over bytes from another/the next message.
// We need to handle this.

Ok((header, payload.to_vec()))
Ok(DeserializedMessage { header, payload })
}
}

Expand Down
55 changes: 45 additions & 10 deletions packages/providers/core/lib/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,16 @@ impl Server {
#[cfg(test)]
mod tests {
use super::Controller;
use crate::ProviderError;
use crate::protocol::Opcode;
use crate::{messages::Message, Server};
use crate::{MessageDeserializer, MessageStreamReader, ProviderError};
use async_trait::async_trait;
use insta::assert_snapshot;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::io::AsyncWriteExt;
use tokio::net::UnixStream;
use tokio::sync::Mutex;
use tokio::time::sleep;

#[derive(Debug)]
Expand All @@ -90,12 +93,16 @@ mod tests {
#[async_trait]
impl Controller for TestController {
async fn action(&self, message: &Message) -> Result<Message, ProviderError> {
Ok(message.clone())
log::debug!("TestController received message {:?}", message);
let response = "nout".as_bytes().to_vec();
Ok(response)
}
}

#[tokio::test]
async fn test_server() {
env_logger::init();

let controller = Arc::new(TestController {});
let path = "/tmp/nv-provider.sock";

Expand All @@ -104,18 +111,46 @@ mod tests {
tokio::spawn(async move {
let _ = server.start().await;
});
// TODO: no sleepy
// TODO: no sleepy, receive signal that server is ready instead
sleep(Duration::from_millis(100)).await;

let mut client = UnixStream::connect(path).await.unwrap();
let client = Arc::new(Mutex::new(UnixStream::connect(path).await.unwrap()));
let msg = b"who";
client.write_all(msg).await.unwrap();

let mut buf = [0; 1024];
let n = client.read(&mut buf).await.unwrap();
let writable_client = client.clone();
let mut writable_client = writable_client.lock().await;
writable_client.write_all(msg).await.unwrap();
std::mem::drop(writable_client);

log::debug!("Reading response...");
let message = MessageStreamReader::read_message(&client).await.unwrap();

log::debug!("Message read: {:?}", message);
log::debug!("Length read: {:?}", message.len());

let deserialized = MessageDeserializer::deserialize(&message);

let response = String::from_utf8((buf[..n]).to_vec());
match deserialized {
Ok(res) => {
log::debug!("Deserialized header {:?}", res.header);
log::debug!("Deserialized payload {:?}", res.payload);

assert_eq!(response, Ok("who".to_owned()));
let header = res.header;
let payload = res.payload;
assert_eq!(header.version, 0x0);
assert_eq!(header.opcode, Opcode::GetValue);
assert_snapshot!(header.checksum);
assert_eq!(header.payload_length, 4);
assert_eq!(String::from_utf8(payload), Ok("nout".to_owned()));
}

Err(err) => {
log::error!(
"Error deserializing message response in server test {:?}",
err
);
panic!("Deserialize should not error");
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: packages/providers/core/lib/src/server.rs
expression: header.checksum
---
0
7 changes: 3 additions & 4 deletions packages/providers/core/load-testing/src/load.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,12 @@ pub async fn generate() -> Result<DataCollection, ServerError> {

match deserialized {
Ok(res) => {
log::debug!("Deserialized header {:?}", res.0);
log::debug!("Deserialized payload {:?}", res.1);
log::debug!("Deserialized header {:?}", res.header);
log::debug!("Deserialized payload {:?}", res.payload);
log::debug!(
"Deserialized payload encoded UTF8 {:?}",
String::from_utf8(res.1)
String::from_utf8(res.payload)
);
// TODO: assert utf8 success?
}

Err(err) => {
Expand Down
2 changes: 1 addition & 1 deletion packages/providers/providers/env/lib/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl Provider for EnvProvider {
async fn get_value(&self, key: &str) -> Result<Vec<u8>, ProviderError> {
log::debug!("Getting EnvProvider value for key {}", key);
let result = env::var(key).map_err(|error| {
log::error!("ENV PROV ERR {:?}", error);
log::error!("Env prov err {:?}", error);

match error {
env::VarError::NotPresent => ProviderError::NoValueForKey,
Expand Down

0 comments on commit 2967efd

Please sign in to comment.