From 3a2cff36ccd3b051963eb5e9b7bfde0d1e635ad1 Mon Sep 17 00:00:00 2001 From: jbesraa Date: Thu, 4 Jul 2024 13:45:49 +0300 Subject: [PATCH 1/8] Constrain `T` and `B` generics `Frame` is defined with `T` and `B` generics but the constraints are only introduced in the impl level which makes it harder to read the enum and understand whats those generics are about. As those generics are a key part of the `Frame` enum, it makes more sense to introduce the constraints in the enum level. --- protocols/v2/framing-sv2/src/framing.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/protocols/v2/framing-sv2/src/framing.rs b/protocols/v2/framing-sv2/src/framing.rs index 616d53354..b96535a6b 100644 --- a/protocols/v2/framing-sv2/src/framing.rs +++ b/protocols/v2/framing-sv2/src/framing.rs @@ -12,7 +12,11 @@ type Slice = buffer_sv2::Slice; /// A wrapper to be used in a context we need a generic reference to a frame /// but it doesn't matter which kind of frame it is (`Sv2Frame` or `HandShakeFrame`) #[derive(Debug)] -pub enum Frame { +pub enum Frame +where + T: Serialize + GetSize, + B: AsMut<[u8]> + AsRef<[u8]>, +{ HandShake(HandShakeFrame), Sv2(Sv2Frame), } @@ -26,13 +30,13 @@ impl + AsRef<[u8]>> Frame { } } -impl From for Frame { +impl + AsRef<[u8]>> From for Frame { fn from(v: HandShakeFrame) -> Self { Self::HandShake(v) } } -impl From> for Frame { +impl + AsRef<[u8]>> From> for Frame { fn from(v: Sv2Frame) -> Self { Self::Sv2(v) } @@ -175,7 +179,7 @@ impl + AsRef<[u8]>> Sv2Frame { } } -impl Sv2Frame { +impl + AsRef<[u8]>> Sv2Frame { /// Maps a `Sv2Frame` to `Sv2Frame` by applying `fun`, /// which is assumed to be a closure that converts `A` to `C` pub fn map(self, fun: fn(A) -> C) -> Sv2Frame { @@ -190,7 +194,7 @@ impl Sv2Frame { } } -impl TryFrom> for Sv2Frame { +impl + AsRef<[u8]>> TryFrom> for Sv2Frame { type Error = Error; fn try_from(v: Frame) -> Result { @@ -232,7 +236,7 @@ impl HandShakeFrame { } } -impl TryFrom> for HandShakeFrame { +impl + AsRef<[u8]>> TryFrom> for HandShakeFrame { type Error = Error; fn try_from(v: Frame) -> Result { From 86f085b782edb20f936e5773a93ada1ba960c38a Mon Sep 17 00:00:00 2001 From: jbesraa Date: Thu, 4 Jul 2024 14:28:45 +0300 Subject: [PATCH 2/8] Change `payload` fn signature to return Option ..instead of panicking --- .../src/sv2/criterion_sv2_benchmark.rs | 27 ++++++++++--------- benches/benches/src/sv2/iai_sv2_benchmark.rs | 6 ++--- examples/interop-cpp/src/main.rs | 2 +- examples/ping-pong-with-noise/src/node.rs | 4 +-- examples/ping-pong-without-noise/src/node.rs | 4 +-- protocols/v2/framing-sv2/src/framing.rs | 7 +++-- protocols/v2/sv2-ffi/src/lib.rs | 27 ++++++++++--------- roles/jd-client/src/lib/downstream.rs | 4 +-- roles/jd-client/src/lib/job_declarator/mod.rs | 2 +- .../lib/job_declarator/setup_connection.rs | 2 +- .../src/lib/template_receiver/mod.rs | 2 +- .../lib/template_receiver/setup_connection.rs | 2 +- .../src/lib/upstream_sv2/upstream.rs | 9 +++++-- roles/jd-server/src/lib/job_declarator/mod.rs | 10 ++++++- .../mining-proxy/src/lib/downstream_mining.rs | 4 +-- roles/mining-proxy/src/lib/upstream_mining.rs | 6 ++--- roles/pool/src/lib/mining_pool/mod.rs | 5 +++- .../src/lib/mining_pool/setup_connection.rs | 5 +++- roles/pool/src/lib/template_receiver/mod.rs | 8 +++++- .../lib/template_receiver/setup_connection.rs | 5 +++- roles/test-utils/mining-device/src/main.rs | 4 +-- .../src/lib/upstream_sv2/upstream.rs | 7 +++-- utils/message-generator/src/executor.rs | 1 - utils/message-generator/src/main.rs | 4 +-- 24 files changed, 96 insertions(+), 61 deletions(-) diff --git a/benches/benches/src/sv2/criterion_sv2_benchmark.rs b/benches/benches/src/sv2/criterion_sv2_benchmark.rs index 18fab853d..c39126d3d 100644 --- a/benches/benches/src/sv2/criterion_sv2_benchmark.rs +++ b/benches/benches/src/sv2/criterion_sv2_benchmark.rs @@ -53,10 +53,11 @@ fn client_sv2_setup_connection_serialize_deserialize(c: &mut Criterion) { let mut dst = vec![0; size]; let _serialized = frame.serialize(&mut dst); b.iter(|| { - let mut frame = StdFrame::from_bytes(black_box(dst.clone().into())).unwrap(); - let type_ = frame.get_header().unwrap().msg_type().clone(); - let payload = frame.payload(); - let _ = AnyMessage::try_from((type_, payload)).unwrap(); + if let Ok(mut frame) = StdFrame::from_bytes(black_box(dst.clone().into())) { + let msg_type = frame.header().msg_type().clone(); + let payload = frame.payload().unwrap(); + let _ = AnyMessage::try_from((msg_type, payload)).unwrap(); + } }); }); } @@ -94,10 +95,11 @@ fn client_sv2_open_channel_serialize_deserialize(c: &mut Criterion) { let mut dst = vec![0; size]; frame.serialize(&mut dst); b.iter(|| { - let mut frame = StdFrame::from_bytes(black_box(dst.clone().into())).unwrap(); - let type_ = frame.get_header().unwrap().msg_type().clone(); - let payload = frame.payload(); - black_box(AnyMessage::try_from((type_, payload)).unwrap()); + if let Ok(mut frame) = StdFrame::from_bytes(black_box(dst.clone().into())) { + let msg_type = frame.header().msg_type().clone(); + let payload = frame.payload().unwrap(); + black_box(AnyMessage::try_from((msg_type, payload)).unwrap()); + } }); }); } @@ -150,10 +152,11 @@ fn client_sv2_mining_message_submit_standard_serialize_deserialize(c: &mut Crite "client_sv2_mining_message_submit_standard_serialize_deserialize", |b| { b.iter(|| { - let mut frame = StdFrame::from_bytes(black_box(dst.clone().into())).unwrap(); - let type_ = frame.get_header().unwrap().msg_type().clone(); - let payload = frame.payload(); - black_box(AnyMessage::try_from((type_, payload)).unwrap()); + if let Ok(mut frame) = StdFrame::from_bytes(black_box(dst.clone().into())) { + let msg_type = frame.header().msg_type().clone(); + let payload = frame.payload().unwrap(); + black_box(AnyMessage::try_from((msg_type, payload)).unwrap()); + } }); }, ); diff --git a/benches/benches/src/sv2/iai_sv2_benchmark.rs b/benches/benches/src/sv2/iai_sv2_benchmark.rs index b049b9dc4..9965f9f01 100644 --- a/benches/benches/src/sv2/iai_sv2_benchmark.rs +++ b/benches/benches/src/sv2/iai_sv2_benchmark.rs @@ -47,7 +47,7 @@ fn client_sv2_setup_connection_serialize_deserialize() { frame.serialize(&mut dst); let mut frame = StdFrame::from_bytes(black_box(dst.clone().into())).unwrap(); let type_ = frame.get_header().unwrap().msg_type().clone(); - let payload = frame.payload(); + let payload = frame.payload().unwrap(); black_box(AnyMessage::try_from((type_, payload))); } @@ -78,7 +78,7 @@ fn client_sv2_open_channel_serialize_deserialize() { frame.serialize(&mut dst); let mut frame = StdFrame::from_bytes(black_box(dst.clone().into())).unwrap(); let type_ = frame.get_header().unwrap().msg_type().clone(); - let payload = frame.payload(); + let payload = frame.payload().unwrap(); black_box(AnyMessage::try_from((type_, payload))); } @@ -128,7 +128,7 @@ fn client_sv2_mining_message_submit_standard_serialize_deserialize() { frame.serialize(&mut dst); let mut frame = StdFrame::from_bytes(black_box(dst.clone().into())).unwrap(); let type_ = frame.get_header().unwrap().msg_type().clone(); - let payload = frame.payload(); + let payload = frame.payload().unwrap(); black_box(AnyMessage::try_from((type_, payload))); } diff --git a/examples/interop-cpp/src/main.rs b/examples/interop-cpp/src/main.rs index 09950e94b..90fd0e788 100644 --- a/examples/interop-cpp/src/main.rs +++ b/examples/interop-cpp/src/main.rs @@ -127,7 +127,7 @@ mod main_ { stream.read_exact(buffer).unwrap(); if let Ok(mut f) = decoder.next_frame() { let msg_type = f.get_header().unwrap().msg_type(); - let payload = f.payload(); + let payload = f.payload().unwrap(); let message: Sv2Message = (msg_type, payload).try_into().unwrap(); match message { Sv2Message::SetupConnection(_) => panic!(), diff --git a/examples/ping-pong-with-noise/src/node.rs b/examples/ping-pong-with-noise/src/node.rs index 1ae042aa8..f8d1792fd 100644 --- a/examples/ping-pong-with-noise/src/node.rs +++ b/examples/ping-pong-with-noise/src/node.rs @@ -99,7 +99,7 @@ impl Node { ) -> Message<'static> { match self.expected { Expected::Ping => { - let ping: Result = from_bytes(frame.payload()); + let ping: Result = from_bytes(frame.payload().unwrap()); match ping { Ok(ping) => { println!("Node {} received:", self.name); @@ -118,7 +118,7 @@ impl Node { } } Expected::Pong => { - let pong: Result = from_bytes(frame.payload()); + let pong: Result = from_bytes(frame.payload().unwrap()); match pong { Ok(pong) => { println!("Node {} received:", self.name); diff --git a/examples/ping-pong-without-noise/src/node.rs b/examples/ping-pong-without-noise/src/node.rs index 21edf617e..64c9c7415 100644 --- a/examples/ping-pong-without-noise/src/node.rs +++ b/examples/ping-pong-without-noise/src/node.rs @@ -87,7 +87,7 @@ impl Node { ) -> Message<'static> { match self.expected { Expected::Ping => { - let ping: Result = from_bytes(frame.payload()); + let ping: Result = from_bytes(frame.payload().unwrap()); match ping { Ok(ping) => { println!("Node {} received:", self.name); @@ -107,7 +107,7 @@ impl Node { } } Expected::Pong => { - let pong: Result = from_bytes(frame.payload()); + let pong: Result = from_bytes(frame.payload().unwrap()); match pong { Ok(pong) => { println!("Node {} received:", self.name); diff --git a/protocols/v2/framing-sv2/src/framing.rs b/protocols/v2/framing-sv2/src/framing.rs index b96535a6b..af8a22db5 100644 --- a/protocols/v2/framing-sv2/src/framing.rs +++ b/protocols/v2/framing-sv2/src/framing.rs @@ -82,12 +82,11 @@ impl + AsRef<[u8]>> Sv2Frame { /// This function is only intended as a fast way to get a reference to an /// already serialized payload. If the frame has not yet been /// serialized, this function should never be used (it will panic). - pub fn payload(&mut self) -> &mut [u8] { + pub fn payload(&mut self) -> Option<&mut [u8]> { if let Some(serialized) = self.serialized.as_mut() { - &mut serialized.as_mut()[Header::SIZE..] + Some(&mut serialized.as_mut()[Header::SIZE..]) } else { - // panic here is the expected behaviour - panic!("Sv2Frame is not yet serialized.") + None } } diff --git a/protocols/v2/sv2-ffi/src/lib.rs b/protocols/v2/sv2-ffi/src/lib.rs index 9befa0ca7..02a3a2236 100644 --- a/protocols/v2/sv2-ffi/src/lib.rs +++ b/protocols/v2/sv2-ffi/src/lib.rs @@ -469,7 +469,10 @@ pub extern "C" fn next_frame(decoder: *mut DecoderWrapper) -> CResult header.msg_type(), None => return CResult::Err(Sv2Error::InvalidSv2Frame), }; - let payload = f.payload(); + let payload = match f.payload() { + Some(payload) => payload, + None => return CResult::Err(Sv2Error::InvalidSv2Frame), + }; let len = payload.len(); let ptr = payload.as_mut_ptr(); let payload = unsafe { std::slice::from_raw_parts_mut(ptr, len) }; @@ -761,7 +764,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::CoinbaseOutputDataSize(m) => m, @@ -813,7 +816,7 @@ mod tests { // Extract payload of the frame which is the NewTemplate message let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::NewTemplate(m) => m, @@ -861,7 +864,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::RequestTransactionData(m) => m, @@ -911,7 +914,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::RequestTransactionDataError(m) => m, @@ -961,7 +964,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::RequestTransactionDataSuccess(m) => m, @@ -1006,7 +1009,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::SetNewPrevHash(m) => m, @@ -1051,7 +1054,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::SubmitSolution(m) => m, @@ -1109,7 +1112,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::ChannelEndpointChanged(m) => m, @@ -1145,7 +1148,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::SetupConnection(m) => m, @@ -1194,7 +1197,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::SetupConnectionError(m) => m, @@ -1243,7 +1246,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); let msg_type = decoded.get_header().unwrap().msg_type(); - let payload = decoded.payload(); + let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { Sv2Message::SetupConnectionSuccess(m) => m, diff --git a/roles/jd-client/src/lib/downstream.rs b/roles/jd-client/src/lib/downstream.rs index 5b26cef2f..57b1b3bca 100644 --- a/roles/jd-client/src/lib/downstream.rs +++ b/roles/jd-client/src/lib/downstream.rs @@ -253,7 +253,7 @@ impl DownstreamMiningNode { /// Parse the received message and relay it to the right upstream pub async fn next(self_mutex: &Arc>, mut incoming: StdFrame) { let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().unwrap(); let routing_logic = roles_logic_sv2::routing_logic::MiningRoutingLogic::None; @@ -698,7 +698,7 @@ pub async fn listen_for_downstream_mining( let mut incoming: StdFrame = node.receiver.recv().await.unwrap().try_into().unwrap(); let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().unwrap(); let routing_logic = roles_logic_sv2::routing_logic::CommonRoutingLogic::None; let node = Arc::new(Mutex::new(node)); if let Some(upstream) = upstream { diff --git a/roles/jd-client/src/lib/job_declarator/mod.rs b/roles/jd-client/src/lib/job_declarator/mod.rs index 29fb2e4f2..b76e471ee 100644 --- a/roles/jd-client/src/lib/job_declarator/mod.rs +++ b/roles/jd-client/src/lib/job_declarator/mod.rs @@ -293,7 +293,7 @@ impl JobDeclarator { loop { let mut incoming: StdFrame = receiver.recv().await.unwrap().try_into().unwrap(); let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().unwrap(); let next_message_to_send = ParseServerJobDeclarationMessages::handle_message_job_declaration( self_mutex.clone(), diff --git a/roles/jd-client/src/lib/job_declarator/setup_connection.rs b/roles/jd-client/src/lib/job_declarator/setup_connection.rs index 0e7b6fd8a..a1ddc613c 100644 --- a/roles/jd-client/src/lib/job_declarator/setup_connection.rs +++ b/roles/jd-client/src/lib/job_declarator/setup_connection.rs @@ -58,7 +58,7 @@ impl SetupConnectionHandler { let mut incoming: StdFrame = receiver.recv().await.unwrap().try_into().unwrap(); let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().unwrap(); ParseUpstreamCommonMessages::handle_message_common( Arc::new(Mutex::new(SetupConnectionHandler {})), message_type, diff --git a/roles/jd-client/src/lib/template_receiver/mod.rs b/roles/jd-client/src/lib/template_receiver/mod.rs index f418318a8..51dd8fbd6 100644 --- a/roles/jd-client/src/lib/template_receiver/mod.rs +++ b/roles/jd-client/src/lib/template_receiver/mod.rs @@ -187,7 +187,7 @@ impl TemplateRx { let mut frame: StdFrame = handle_result!(tx_status.clone(), received.try_into()); let message_type = frame.get_header().unwrap().msg_type(); - let payload = frame.payload(); + let payload = frame.payload().expect("No payload set"); let next_message_to_send = ParseServerTemplateDistributionMessages::handle_message_template_distribution( diff --git a/roles/jd-client/src/lib/template_receiver/setup_connection.rs b/roles/jd-client/src/lib/template_receiver/setup_connection.rs index 505b945c3..81fb0166c 100644 --- a/roles/jd-client/src/lib/template_receiver/setup_connection.rs +++ b/roles/jd-client/src/lib/template_receiver/setup_connection.rs @@ -54,7 +54,7 @@ impl SetupConnectionHandler { .try_into() .expect("Failed to parse incoming SetupConnectionResponse"); let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().unwrap(); ParseUpstreamCommonMessages::handle_message_common( Arc::new(Mutex::new(SetupConnectionHandler {})), message_type, diff --git a/roles/jd-client/src/lib/upstream_sv2/upstream.rs b/roles/jd-client/src/lib/upstream_sv2/upstream.rs index b04efa335..2228ae9b0 100644 --- a/roles/jd-client/src/lib/upstream_sv2/upstream.rs +++ b/roles/jd-client/src/lib/upstream_sv2/upstream.rs @@ -237,7 +237,12 @@ impl Upstream { return Err(framing_sv2::Error::ExpectedHandshakeFrame.into()); }; // Gets the message payload - let payload = incoming.payload(); + let payload = match incoming.payload() { + Some(payload) => payload, + None => { + return Err(framing_sv2::Error::ExpectedHandshakeFrame.into()); + } + }; // Handle the incoming message (should be either `SetupConnectionSuccess` or // `SetupConnectionError`) @@ -333,7 +338,7 @@ impl Upstream { let message_type = handle_result!(tx_status, message_type).msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().expect("Payload not found"); // Since this is not communicating with an SV2 proxy, but instead a custom SV1 // proxy where the routing logic is handled via the `Upstream`'s communication diff --git a/roles/jd-server/src/lib/job_declarator/mod.rs b/roles/jd-server/src/lib/job_declarator/mod.rs index 34d9e66de..051e1e795 100644 --- a/roles/jd-server/src/lib/job_declarator/mod.rs +++ b/roles/jd-server/src/lib/job_declarator/mod.rs @@ -206,7 +206,15 @@ impl JobDeclaratorDownstream { .ok_or_else(|| JdsError::Custom(String::from("No header set"))); let header = handle_result!(tx_status, header); let message_type = header.msg_type(); - let payload = frame.payload(); + let payload = match frame.payload() { + Some(p) => p, + None => { + handle_result!( + tx_status, + Err(JdsError::Custom("No payload set".to_string())) + ) + } + }; let next_message_to_send = ParseClientJobDeclarationMessages::handle_message_job_declaration( self_mutex.clone(), diff --git a/roles/mining-proxy/src/lib/downstream_mining.rs b/roles/mining-proxy/src/lib/downstream_mining.rs index 188055119..c810de0e9 100644 --- a/roles/mining-proxy/src/lib/downstream_mining.rs +++ b/roles/mining-proxy/src/lib/downstream_mining.rs @@ -229,7 +229,7 @@ impl DownstreamMiningNode { /// Parse the received message and relay it to the right upstream pub async fn next(self_mutex: Arc>, mut incoming: StdFrame) { let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().unwrap(); let routing_logic = super::get_routing_logic(); @@ -500,7 +500,7 @@ pub async fn listen_for_downstream_mining(address: SocketAddr) { task::spawn(async move { let mut incoming: StdFrame = node.receiver.recv().await.unwrap().try_into().unwrap(); let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().unwrap(); let routing_logic = super::get_common_routing_logic(); let node = Arc::new(Mutex::new(node)); diff --git a/roles/mining-proxy/src/lib/upstream_mining.rs b/roles/mining-proxy/src/lib/upstream_mining.rs index e3f6eef99..864ea5813 100644 --- a/roles/mining-proxy/src/lib/upstream_mining.rs +++ b/roles/mining-proxy/src/lib/upstream_mining.rs @@ -412,7 +412,7 @@ impl UpstreamMiningNode { .unwrap(); let message_type = response.get_header().unwrap().msg_type(); - let payload = response.payload(); + let payload = response.payload().unwrap(); match (message_type, payload).try_into() { Ok(CommonMessages::SetupConnectionSuccess(_)) => { let receiver = self_mutex @@ -578,7 +578,7 @@ impl UpstreamMiningNode { pub async fn next(self_mutex: Arc>, mut incoming: StdFrame) { let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().unwrap(); let routing_logic = super::get_routing_logic(); @@ -616,7 +616,7 @@ impl UpstreamMiningNode { .unwrap(); let message_type = response.get_header().unwrap().msg_type(); - let payload = response.payload(); + let payload = response.payload().unwrap(); match (message_type, payload).try_into() { Ok(CommonMessages::SetupConnectionSuccess(m)) => { let receiver = self_mutex diff --git a/roles/pool/src/lib/mining_pool/mod.rs b/roles/pool/src/lib/mining_pool/mod.rs index 4b9d10b18..adc2a83d0 100644 --- a/roles/pool/src/lib/mining_pool/mod.rs +++ b/roles/pool/src/lib/mining_pool/mod.rs @@ -204,7 +204,10 @@ impl Downstream { .get_header() .ok_or_else(|| PoolError::Custom(String::from("No header set")))? .msg_type(); - let payload = incoming.payload(); + let payload = match incoming.payload() { + Some(p) => p, + None => return Err(PoolError::Custom(String::from("No payload set"))), + }; debug!( "Received downstream message type: {:?}, payload: {:?}", message_type, payload diff --git a/roles/pool/src/lib/mining_pool/setup_connection.rs b/roles/pool/src/lib/mining_pool/setup_connection.rs index f0c47e9a8..80babbd57 100644 --- a/roles/pool/src/lib/mining_pool/setup_connection.rs +++ b/roles/pool/src/lib/mining_pool/setup_connection.rs @@ -62,7 +62,10 @@ impl SetupConnectionHandler { .get_header() .ok_or_else(|| PoolError::Custom(String::from("No header set")))? .msg_type(); - let payload = incoming.payload(); + let payload = match incoming.payload() { + Some(p) => p, + None => return Err(PoolError::Custom(String::from("No payload set"))), + }; let response = ParseDownstreamCommonMessages::handle_message_common( self_.clone(), message_type, diff --git a/roles/pool/src/lib/template_receiver/mod.rs b/roles/pool/src/lib/template_receiver/mod.rs index 2eeaa554f..fb3222a3e 100644 --- a/roles/pool/src/lib/template_receiver/mod.rs +++ b/roles/pool/src/lib/template_receiver/mod.rs @@ -112,7 +112,13 @@ impl TemplateRx { .get_header() .ok_or_else(|| PoolError::Custom(String::from("No header set"))); let message_type = handle_result!(status_tx, message_type_res).msg_type(); - let payload = message_from_tp.payload(); + let payload = match message_from_tp.payload() { + Some(p) => p, + None => { + let err = PoolError::Custom(String::from("No payload set")); + handle_result!(status_tx, Err(err)) + } + }; let msg = handle_result!( status_tx, ParseServerTemplateDistributionMessages::handle_message_template_distribution( diff --git a/roles/pool/src/lib/template_receiver/setup_connection.rs b/roles/pool/src/lib/template_receiver/setup_connection.rs index 6687eadc6..684937d78 100644 --- a/roles/pool/src/lib/template_receiver/setup_connection.rs +++ b/roles/pool/src/lib/template_receiver/setup_connection.rs @@ -57,7 +57,10 @@ impl SetupConnectionHandler { .get_header() .ok_or_else(|| PoolError::Custom(String::from("No header set")))? .msg_type(); - let payload = incoming.payload(); + let payload = match incoming.payload() { + Some(p) => p, + None => return Err(PoolError::Custom(String::from("No payload set"))), + }; ParseUpstreamCommonMessages::handle_message_common( Arc::new(Mutex::new(SetupConnectionHandler {})), diff --git a/roles/test-utils/mining-device/src/main.rs b/roles/test-utils/mining-device/src/main.rs index 763f83af5..4251d293d 100644 --- a/roles/test-utils/mining-device/src/main.rs +++ b/roles/test-utils/mining-device/src/main.rs @@ -181,7 +181,7 @@ impl SetupConnectionHandler { let mut incoming: StdFrame = receiver.recv().await.unwrap().try_into().unwrap(); let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().unwrap(); ParseUpstreamCommonMessages::handle_message_common( self_, message_type, @@ -314,7 +314,7 @@ impl Device { loop { let mut incoming: StdFrame = receiver.recv().await.unwrap().try_into().unwrap(); let message_type = incoming.get_header().unwrap().msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().unwrap(); let next = Device::handle_message_mining( self_mutex.clone(), message_type, diff --git a/roles/translator/src/lib/upstream_sv2/upstream.rs b/roles/translator/src/lib/upstream_sv2/upstream.rs index 6aab5978e..4854682a1 100644 --- a/roles/translator/src/lib/upstream_sv2/upstream.rs +++ b/roles/translator/src/lib/upstream_sv2/upstream.rs @@ -211,7 +211,10 @@ impl Upstream { return Err(framing_sv2::Error::ExpectedHandshakeFrame.into()); }; // Gets the message payload - let payload = incoming.payload(); + let payload = match incoming.payload() { + Some(payload) => payload, + None => return Err(framing_sv2::Error::ExpectedHandshakeFrame.into()), + }; // Handle the incoming message (should be either `SetupConnectionSuccess` or // `SetupConnectionError`) @@ -306,7 +309,7 @@ impl Upstream { let message_type = handle_result!(tx_status, message_type).msg_type(); - let payload = incoming.payload(); + let payload = incoming.payload().expect("Payload is None"); // Since this is not communicating with an SV2 proxy, but instead a custom SV1 // proxy where the routing logic is handled via the `Upstream`'s communication diff --git a/utils/message-generator/src/executor.rs b/utils/message-generator/src/executor.rs index 22843e03a..2e3e306d1 100644 --- a/utils/message-generator/src/executor.rs +++ b/utils/message-generator/src/executor.rs @@ -198,7 +198,6 @@ impl Executor { action.result.len(), result ); - match result { ActionResult::MatchMessageType(message_type) => { let message = match recv.recv().await { diff --git a/utils/message-generator/src/main.rs b/utils/message-generator/src/main.rs index 327d50cbc..dc0ff2a52 100644 --- a/utils/message-generator/src/main.rs +++ b/utils/message-generator/src/main.rs @@ -661,8 +661,8 @@ mod test { (EitherFrame::Sv2(mut frame1), EitherFrame::Sv2(mut frame2)) => { let mt1 = frame1.get_header().unwrap().msg_type(); let mt2 = frame2.get_header().unwrap().msg_type(); - let p1 = frame1.payload(); - let p2 = frame2.payload(); + let p1 = frame1.payload().unwrap(); + let p2 = frame2.payload().unwrap(); let message1: Mining = (mt1, p1).try_into().unwrap(); let message2: Mining = (mt2, p2).try_into().unwrap(); match (message1, message2) { From 33dc060c95992bf9b4f53ec18d76408aa99a5834 Mon Sep 17 00:00:00 2001 From: jbesraa Date: Thu, 4 Jul 2024 15:02:35 +0300 Subject: [PATCH 3/8] Change `Sv2Frame::get_header` to `header` and .. Remove `Option` from its return type. --- benches/benches/src/sv2/iai_sv2_benchmark.rs | 6 ++--- examples/interop-cpp/src/main.rs | 2 +- protocols/v2/framing-sv2/src/framing.rs | 4 +-- protocols/v2/sv2-ffi/src/lib.rs | 27 +++++++++---------- roles/jd-client/src/lib/downstream.rs | 4 +-- roles/jd-client/src/lib/job_declarator/mod.rs | 2 +- .../lib/job_declarator/setup_connection.rs | 2 +- .../src/lib/template_receiver/mod.rs | 8 +++--- .../lib/template_receiver/setup_connection.rs | 2 +- .../src/lib/upstream_sv2/upstream.rs | 15 ++--------- roles/jd-server/src/lib/job_declarator/mod.rs | 5 +--- .../mining-proxy/src/lib/downstream_mining.rs | 4 +-- roles/mining-proxy/src/lib/upstream_mining.rs | 8 +++--- roles/pool/src/lib/mining_pool/mod.rs | 5 +--- .../src/lib/mining_pool/setup_connection.rs | 5 +--- roles/pool/src/lib/template_receiver/mod.rs | 5 +--- .../lib/template_receiver/setup_connection.rs | 5 +--- roles/test-utils/mining-device/src/main.rs | 4 +-- .../src/lib/upstream_sv2/upstream.rs | 15 ++--------- utils/message-generator/src/executor.rs | 14 +++++----- utils/message-generator/src/main.rs | 4 +-- 21 files changed, 53 insertions(+), 93 deletions(-) diff --git a/benches/benches/src/sv2/iai_sv2_benchmark.rs b/benches/benches/src/sv2/iai_sv2_benchmark.rs index 9965f9f01..d052981f2 100644 --- a/benches/benches/src/sv2/iai_sv2_benchmark.rs +++ b/benches/benches/src/sv2/iai_sv2_benchmark.rs @@ -46,7 +46,7 @@ fn client_sv2_setup_connection_serialize_deserialize() { let mut dst = vec![0; size]; frame.serialize(&mut dst); let mut frame = StdFrame::from_bytes(black_box(dst.clone().into())).unwrap(); - let type_ = frame.get_header().unwrap().msg_type().clone(); + let type_ = frame.header().msg_type().clone(); let payload = frame.payload().unwrap(); black_box(AnyMessage::try_from((type_, payload))); } @@ -77,7 +77,7 @@ fn client_sv2_open_channel_serialize_deserialize() { let mut dst = vec![0; size]; frame.serialize(&mut dst); let mut frame = StdFrame::from_bytes(black_box(dst.clone().into())).unwrap(); - let type_ = frame.get_header().unwrap().msg_type().clone(); + let type_ = frame.header().msg_type().clone(); let payload = frame.payload().unwrap(); black_box(AnyMessage::try_from((type_, payload))); } @@ -127,7 +127,7 @@ fn client_sv2_mining_message_submit_standard_serialize_deserialize() { let mut dst = vec![0; size]; frame.serialize(&mut dst); let mut frame = StdFrame::from_bytes(black_box(dst.clone().into())).unwrap(); - let type_ = frame.get_header().unwrap().msg_type().clone(); + let type_ = frame.header().msg_type().clone(); let payload = frame.payload().unwrap(); black_box(AnyMessage::try_from((type_, payload))); } diff --git a/examples/interop-cpp/src/main.rs b/examples/interop-cpp/src/main.rs index 90fd0e788..c03f01d38 100644 --- a/examples/interop-cpp/src/main.rs +++ b/examples/interop-cpp/src/main.rs @@ -126,7 +126,7 @@ mod main_ { let buffer = decoder.writable(); stream.read_exact(buffer).unwrap(); if let Ok(mut f) = decoder.next_frame() { - let msg_type = f.get_header().unwrap().msg_type(); + let msg_type = f.header().msg_type(); let payload = f.payload().unwrap(); let message: Sv2Message = (msg_type, payload).try_into().unwrap(); match message { diff --git a/protocols/v2/framing-sv2/src/framing.rs b/protocols/v2/framing-sv2/src/framing.rs index af8a22db5..1070478f4 100644 --- a/protocols/v2/framing-sv2/src/framing.rs +++ b/protocols/v2/framing-sv2/src/framing.rs @@ -91,8 +91,8 @@ impl + AsRef<[u8]>> Sv2Frame { } /// `Sv2Frame` always returns `Some(self.header)`. - pub fn get_header(&self) -> Option { - Some(self.header) + pub fn header(&self) -> crate::header::Header { + self.header } /// Tries to build a `Sv2Frame` from raw bytes, assuming they represent a serialized `Sv2Frame` frame (`Self.serialized`). diff --git a/protocols/v2/sv2-ffi/src/lib.rs b/protocols/v2/sv2-ffi/src/lib.rs index 02a3a2236..fb1b55777 100644 --- a/protocols/v2/sv2-ffi/src/lib.rs +++ b/protocols/v2/sv2-ffi/src/lib.rs @@ -465,10 +465,7 @@ pub extern "C" fn next_frame(decoder: *mut DecoderWrapper) -> CResult { - let msg_type = match f.get_header() { - Some(header) => header.msg_type(), - None => return CResult::Err(Sv2Error::InvalidSv2Frame), - }; + let msg_type = f.header().msg_type(); let payload = match f.payload() { Some(payload) => payload, None => return CResult::Err(Sv2Error::InvalidSv2Frame), @@ -763,7 +760,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { @@ -815,7 +812,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); // Extract payload of the frame which is the NewTemplate message - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { @@ -863,7 +860,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { @@ -913,7 +910,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { @@ -963,7 +960,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { @@ -1008,7 +1005,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { @@ -1053,7 +1050,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { @@ -1111,7 +1108,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { @@ -1147,7 +1144,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { @@ -1196,7 +1193,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { @@ -1245,7 +1242,7 @@ mod tests { let mut decoded = decoder.next_frame().unwrap(); - let msg_type = decoded.get_header().unwrap().msg_type(); + let msg_type = decoded.header().msg_type(); let payload = decoded.payload().unwrap(); let decoded_message: Sv2Message = (msg_type, payload).try_into().unwrap(); let decoded_message = match decoded_message { diff --git a/roles/jd-client/src/lib/downstream.rs b/roles/jd-client/src/lib/downstream.rs index 57b1b3bca..4d8d60fa0 100644 --- a/roles/jd-client/src/lib/downstream.rs +++ b/roles/jd-client/src/lib/downstream.rs @@ -252,7 +252,7 @@ impl DownstreamMiningNode { /// Parse the received message and relay it to the right upstream pub async fn next(self_mutex: &Arc>, mut incoming: StdFrame) { - let message_type = incoming.get_header().unwrap().msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().unwrap(); let routing_logic = roles_logic_sv2::routing_logic::MiningRoutingLogic::None; @@ -697,7 +697,7 @@ pub async fn listen_for_downstream_mining( ); let mut incoming: StdFrame = node.receiver.recv().await.unwrap().try_into().unwrap(); - let message_type = incoming.get_header().unwrap().msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().unwrap(); let routing_logic = roles_logic_sv2::routing_logic::CommonRoutingLogic::None; let node = Arc::new(Mutex::new(node)); diff --git a/roles/jd-client/src/lib/job_declarator/mod.rs b/roles/jd-client/src/lib/job_declarator/mod.rs index b76e471ee..605975012 100644 --- a/roles/jd-client/src/lib/job_declarator/mod.rs +++ b/roles/jd-client/src/lib/job_declarator/mod.rs @@ -292,7 +292,7 @@ impl JobDeclarator { let receiver = self_mutex.safe_lock(|d| d.receiver.clone()).unwrap(); loop { let mut incoming: StdFrame = receiver.recv().await.unwrap().try_into().unwrap(); - let message_type = incoming.get_header().unwrap().msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().unwrap(); let next_message_to_send = ParseServerJobDeclarationMessages::handle_message_job_declaration( diff --git a/roles/jd-client/src/lib/job_declarator/setup_connection.rs b/roles/jd-client/src/lib/job_declarator/setup_connection.rs index a1ddc613c..c5c86e102 100644 --- a/roles/jd-client/src/lib/job_declarator/setup_connection.rs +++ b/roles/jd-client/src/lib/job_declarator/setup_connection.rs @@ -57,7 +57,7 @@ impl SetupConnectionHandler { let mut incoming: StdFrame = receiver.recv().await.unwrap().try_into().unwrap(); - let message_type = incoming.get_header().unwrap().msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().unwrap(); ParseUpstreamCommonMessages::handle_message_common( Arc::new(Mutex::new(SetupConnectionHandler {})), diff --git a/roles/jd-client/src/lib/template_receiver/mod.rs b/roles/jd-client/src/lib/template_receiver/mod.rs index 51dd8fbd6..6a879aa07 100644 --- a/roles/jd-client/src/lib/template_receiver/mod.rs +++ b/roles/jd-client/src/lib/template_receiver/mod.rs @@ -186,7 +186,7 @@ impl TemplateRx { let received = handle_result!(tx_status.clone(), receiver.recv().await); let mut frame: StdFrame = handle_result!(tx_status.clone(), received.try_into()); - let message_type = frame.get_header().unwrap().msg_type(); + let message_type = frame.header().msg_type(); let payload = frame.payload().expect("No payload set"); let next_message_to_send = @@ -273,7 +273,7 @@ impl TemplateRx { _ => { error!("{:?}", frame); error!("{:?}", frame.payload()); - error!("{:?}", frame.get_header()); + error!("{:?}", frame.header()); std::process::exit(1); } } @@ -282,14 +282,14 @@ impl TemplateRx { error!("{:?}", m); error!("{:?}", frame); error!("{:?}", frame.payload()); - error!("{:?}", frame.get_header()); + error!("{:?}", frame.header()); std::process::exit(1); } Err(e) => { error!("{:?}", e); error!("{:?}", frame); error!("{:?}", frame.payload()); - error!("{:?}", frame.get_header()); + error!("{:?}", frame.header()); std::process::exit(1); } } diff --git a/roles/jd-client/src/lib/template_receiver/setup_connection.rs b/roles/jd-client/src/lib/template_receiver/setup_connection.rs index 81fb0166c..010199781 100644 --- a/roles/jd-client/src/lib/template_receiver/setup_connection.rs +++ b/roles/jd-client/src/lib/template_receiver/setup_connection.rs @@ -53,7 +53,7 @@ impl SetupConnectionHandler { .expect("Connection to TP closed!") .try_into() .expect("Failed to parse incoming SetupConnectionResponse"); - let message_type = incoming.get_header().unwrap().msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().unwrap(); ParseUpstreamCommonMessages::handle_message_common( Arc::new(Mutex::new(SetupConnectionHandler {})), diff --git a/roles/jd-client/src/lib/upstream_sv2/upstream.rs b/roles/jd-client/src/lib/upstream_sv2/upstream.rs index 2228ae9b0..dd9571a68 100644 --- a/roles/jd-client/src/lib/upstream_sv2/upstream.rs +++ b/roles/jd-client/src/lib/upstream_sv2/upstream.rs @@ -231,11 +231,7 @@ impl Upstream { }; // Gets the binary frame message type from the message header - let message_type = if let Some(header) = incoming.get_header() { - header.msg_type() - } else { - return Err(framing_sv2::Error::ExpectedHandshakeFrame.into()); - }; + let message_type = incoming.header().msg_type(); // Gets the message payload let payload = match incoming.payload() { Some(payload) => payload, @@ -329,14 +325,7 @@ impl Upstream { let mut incoming: StdFrame = handle_result!(tx_status, incoming.try_into()); // On message receive, get the message type from the message header and get the // message payload - let message_type = - incoming - .get_header() - .ok_or(super::super::error::Error::FramingSv2( - framing_sv2::Error::ExpectedSv2Frame, - )); - - let message_type = handle_result!(tx_status, message_type).msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().expect("Payload not found"); diff --git a/roles/jd-server/src/lib/job_declarator/mod.rs b/roles/jd-server/src/lib/job_declarator/mod.rs index 051e1e795..ef1629380 100644 --- a/roles/jd-server/src/lib/job_declarator/mod.rs +++ b/roles/jd-server/src/lib/job_declarator/mod.rs @@ -201,10 +201,7 @@ impl JobDeclaratorDownstream { match recv.recv().await { Ok(message) => { let mut frame: StdFrame = handle_result!(tx_status, message.try_into()); - let header = frame - .get_header() - .ok_or_else(|| JdsError::Custom(String::from("No header set"))); - let header = handle_result!(tx_status, header); + let header = frame.header(); let message_type = header.msg_type(); let payload = match frame.payload() { Some(p) => p, diff --git a/roles/mining-proxy/src/lib/downstream_mining.rs b/roles/mining-proxy/src/lib/downstream_mining.rs index c810de0e9..243250f81 100644 --- a/roles/mining-proxy/src/lib/downstream_mining.rs +++ b/roles/mining-proxy/src/lib/downstream_mining.rs @@ -228,7 +228,7 @@ impl DownstreamMiningNode { /// Parse the received message and relay it to the right upstream pub async fn next(self_mutex: Arc>, mut incoming: StdFrame) { - let message_type = incoming.get_header().unwrap().msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().unwrap(); let routing_logic = super::get_routing_logic(); @@ -499,7 +499,7 @@ pub async fn listen_for_downstream_mining(address: SocketAddr) { task::spawn(async move { let mut incoming: StdFrame = node.receiver.recv().await.unwrap().try_into().unwrap(); - let message_type = incoming.get_header().unwrap().msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().unwrap(); let routing_logic = super::get_common_routing_logic(); let node = Arc::new(Mutex::new(node)); diff --git a/roles/mining-proxy/src/lib/upstream_mining.rs b/roles/mining-proxy/src/lib/upstream_mining.rs index 864ea5813..f5b370567 100644 --- a/roles/mining-proxy/src/lib/upstream_mining.rs +++ b/roles/mining-proxy/src/lib/upstream_mining.rs @@ -411,7 +411,7 @@ impl UpstreamMiningNode { .unwrap() .unwrap(); - let message_type = response.get_header().unwrap().msg_type(); + let message_type = response.header().msg_type(); let payload = response.payload().unwrap(); match (message_type, payload).try_into() { Ok(CommonMessages::SetupConnectionSuccess(_)) => { @@ -577,7 +577,7 @@ impl UpstreamMiningNode { } pub async fn next(self_mutex: Arc>, mut incoming: StdFrame) { - let message_type = incoming.get_header().unwrap().msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().unwrap(); let routing_logic = super::get_routing_logic(); @@ -615,7 +615,7 @@ impl UpstreamMiningNode { .unwrap() .unwrap(); - let message_type = response.get_header().unwrap().msg_type(); + let message_type = response.header().msg_type(); let payload = response.payload().unwrap(); match (message_type, payload).try_into() { Ok(CommonMessages::SetupConnectionSuccess(m)) => { @@ -861,7 +861,7 @@ impl UpstreamMiningNode { // #[cfg(test)] // #[allow(unused)] // pub async fn next_faster(&mut self, mut incoming: StdFrame) { - // let message_type = incoming.get_header().unwrap().msg_type(); + // let message_type = incoming.header().msg_type(); // // When a channel is opened we need to setup the channel id in order to relay next messages // // to the right Downstream diff --git a/roles/pool/src/lib/mining_pool/mod.rs b/roles/pool/src/lib/mining_pool/mod.rs index adc2a83d0..e2b526e8f 100644 --- a/roles/pool/src/lib/mining_pool/mod.rs +++ b/roles/pool/src/lib/mining_pool/mod.rs @@ -200,10 +200,7 @@ impl Downstream { } pub async fn next(self_mutex: Arc>, mut incoming: StdFrame) -> PoolResult<()> { - let message_type = incoming - .get_header() - .ok_or_else(|| PoolError::Custom(String::from("No header set")))? - .msg_type(); + let message_type = incoming.header().msg_type(); let payload = match incoming.payload() { Some(p) => p, None => return Err(PoolError::Custom(String::from("No payload set"))), diff --git a/roles/pool/src/lib/mining_pool/setup_connection.rs b/roles/pool/src/lib/mining_pool/setup_connection.rs index 80babbd57..8710ac541 100644 --- a/roles/pool/src/lib/mining_pool/setup_connection.rs +++ b/roles/pool/src/lib/mining_pool/setup_connection.rs @@ -58,10 +58,7 @@ impl SetupConnectionHandler { } }; - let message_type = incoming - .get_header() - .ok_or_else(|| PoolError::Custom(String::from("No header set")))? - .msg_type(); + let message_type = incoming.header().msg_type(); let payload = match incoming.payload() { Some(p) => p, None => return Err(PoolError::Custom(String::from("No payload set"))), diff --git a/roles/pool/src/lib/template_receiver/mod.rs b/roles/pool/src/lib/template_receiver/mod.rs index fb3222a3e..0e086383a 100644 --- a/roles/pool/src/lib/template_receiver/mod.rs +++ b/roles/pool/src/lib/template_receiver/mod.rs @@ -108,10 +108,7 @@ impl TemplateRx { .try_into() .map_err(|e| PoolError::Codec(codec_sv2::Error::FramingSv2Error(e))) ); - let message_type_res = message_from_tp - .get_header() - .ok_or_else(|| PoolError::Custom(String::from("No header set"))); - let message_type = handle_result!(status_tx, message_type_res).msg_type(); + let message_type = message_from_tp.header().msg_type(); let payload = match message_from_tp.payload() { Some(p) => p, None => { diff --git a/roles/pool/src/lib/template_receiver/setup_connection.rs b/roles/pool/src/lib/template_receiver/setup_connection.rs index 684937d78..5fc6bf97a 100644 --- a/roles/pool/src/lib/template_receiver/setup_connection.rs +++ b/roles/pool/src/lib/template_receiver/setup_connection.rs @@ -53,10 +53,7 @@ impl SetupConnectionHandler { .await? .try_into() .map_err(|e| PoolError::Codec(codec_sv2::Error::FramingSv2Error(e)))?; - let message_type = incoming - .get_header() - .ok_or_else(|| PoolError::Custom(String::from("No header set")))? - .msg_type(); + let message_type = incoming.header().msg_type(); let payload = match incoming.payload() { Some(p) => p, None => return Err(PoolError::Custom(String::from("No payload set"))), diff --git a/roles/test-utils/mining-device/src/main.rs b/roles/test-utils/mining-device/src/main.rs index 4251d293d..daf7a4bb8 100644 --- a/roles/test-utils/mining-device/src/main.rs +++ b/roles/test-utils/mining-device/src/main.rs @@ -180,7 +180,7 @@ impl SetupConnectionHandler { info!("Setup connection sent to {}", address); let mut incoming: StdFrame = receiver.recv().await.unwrap().try_into().unwrap(); - let message_type = incoming.get_header().unwrap().msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().unwrap(); ParseUpstreamCommonMessages::handle_message_common( self_, @@ -313,7 +313,7 @@ impl Device { loop { let mut incoming: StdFrame = receiver.recv().await.unwrap().try_into().unwrap(); - let message_type = incoming.get_header().unwrap().msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().unwrap(); let next = Device::handle_message_mining( self_mutex.clone(), diff --git a/roles/translator/src/lib/upstream_sv2/upstream.rs b/roles/translator/src/lib/upstream_sv2/upstream.rs index 4854682a1..979d8beda 100644 --- a/roles/translator/src/lib/upstream_sv2/upstream.rs +++ b/roles/translator/src/lib/upstream_sv2/upstream.rs @@ -205,11 +205,7 @@ impl Upstream { }; // Gets the binary frame message type from the message header - let message_type = if let Some(header) = incoming.get_header() { - header.msg_type() - } else { - return Err(framing_sv2::Error::ExpectedHandshakeFrame.into()); - }; + let message_type = incoming.header().msg_type(); // Gets the message payload let payload = match incoming.payload() { Some(payload) => payload, @@ -300,14 +296,7 @@ impl Upstream { let mut incoming: StdFrame = handle_result!(tx_status, incoming.try_into()); // On message receive, get the message type from the message header and get the // message payload - let message_type = - incoming - .get_header() - .ok_or(super::super::error::Error::FramingSv2( - framing_sv2::Error::ExpectedSv2Frame, - )); - - let message_type = handle_result!(tx_status, message_type).msg_type(); + let message_type = incoming.header().msg_type(); let payload = incoming.payload().expect("Payload is None"); diff --git a/utils/message-generator/src/executor.rs b/utils/message-generator/src/executor.rs index 2e3e306d1..9a6633683 100644 --- a/utils/message-generator/src/executor.rs +++ b/utils/message-generator/src/executor.rs @@ -211,7 +211,7 @@ impl Executor { let message: Sv2Frame, _> = message.try_into().unwrap(); debug!("RECV {:#?}", message); - let header = message.get_header().unwrap(); + let header = message.header(); if header.msg_type() != *message_type { error!( @@ -242,8 +242,8 @@ impl Executor { let mut message: Sv2Frame, _> = message.try_into().unwrap(); debug!("RECV {:#?}", message); - let header = message.get_header().unwrap(); - let payload = message.payload(); + let header = message.header(); + let payload = message.payload().unwrap(); if subprotocol.as_str() == "CommonMessages" { match (header.msg_type(), payload).try_into() { Ok(roles_logic_sv2::parsers::CommonMessages::SetupConnection(m)) => { @@ -543,8 +543,8 @@ impl Executor { let mut message: Sv2Frame, _> = message.try_into().unwrap(); debug!("RECV {:#?}", message); - let header = message.get_header().unwrap(); - let payload = message.payload(); + let header = message.header(); + let payload = message.payload().unwrap(); if subprotocol.as_str() == "CommonMessages" { match (header.msg_type(), payload).try_into() { Ok(parsers::CommonMessages::SetupConnection(m)) => { @@ -755,7 +755,7 @@ impl Executor { let mut message: Sv2Frame, _> = message.try_into().unwrap(); debug!("RECV {:#?}", message); - let payload = message.payload(); + let payload = message.payload().unwrap(); if payload.len() != *message_len { error!( "WRONG MESSAGE len expected: {} received: {}", @@ -778,7 +778,7 @@ impl Executor { let message: Sv2Frame, _> = message.try_into().unwrap(); debug!("RECV {:#?}", message); - let header = message.get_header().unwrap(); + let header = message.header(); if header.ext_type() != *ext_type { error!( "WRONG EXTENSION TYPE expected: {} received: {}", diff --git a/utils/message-generator/src/main.rs b/utils/message-generator/src/main.rs index dc0ff2a52..f802a7bf9 100644 --- a/utils/message-generator/src/main.rs +++ b/utils/message-generator/src/main.rs @@ -659,8 +659,8 @@ mod test { let client_received = client_recv.recv().await.unwrap(); match (server_received, client_received) { (EitherFrame::Sv2(mut frame1), EitherFrame::Sv2(mut frame2)) => { - let mt1 = frame1.get_header().unwrap().msg_type(); - let mt2 = frame2.get_header().unwrap().msg_type(); + let mt1 = frame1.header().msg_type(); + let mt2 = frame2.header().msg_type(); let p1 = frame1.payload().unwrap(); let p2 = frame2.payload().unwrap(); let message1: Mining = (mt1, p1).try_into().unwrap(); From 3c322e4167b0ac991dadaea36ba858c3718b6cd3 Mon Sep 17 00:00:00 2001 From: jbesraa Date: Tue, 16 Jul 2024 16:52:05 +0300 Subject: [PATCH 4/8] Change `Sv2Frame` from struct to enum `Sv2Frame` can handle serialized and non-serliazed data, both scenarios were previously in the same struct wrapped by Option, where if one is Some, the other is None but never both None or Some. --- protocols/v2/framing-sv2/src/framing.rs | 92 ++++++++++++------------- 1 file changed, 44 insertions(+), 48 deletions(-) diff --git a/protocols/v2/framing-sv2/src/framing.rs b/protocols/v2/framing-sv2/src/framing.rs index 1070478f4..7076118cd 100644 --- a/protocols/v2/framing-sv2/src/framing.rs +++ b/protocols/v2/framing-sv2/src/framing.rs @@ -44,11 +44,9 @@ impl + AsRef<[u8]>> From> /// Abstraction for a SV2 Frame. #[derive(Debug, Clone)] -pub struct Sv2Frame { - header: Header, - payload: Option, - /// Serialized header + payload - serialized: Option, +pub enum Sv2Frame { + Payload { header: Header, payload: T }, + Raw { header: Header, serialized: B }, } impl + AsRef<[u8]>> Sv2Frame { @@ -57,23 +55,23 @@ impl + AsRef<[u8]>> Sv2Frame { /// When called on a non serialized frame, it is not so cheap (because it serializes it). #[inline] pub fn serialize(self, dst: &mut [u8]) -> Result<(), Error> { - if let Some(mut serialized) = self.serialized { - dst.swap_with_slice(serialized.as_mut()); - Ok(()) - } else if let Some(payload) = self.payload { - #[cfg(not(feature = "with_serde"))] - to_writer(self.header, dst).map_err(Error::BinarySv2Error)?; - #[cfg(not(feature = "with_serde"))] - to_writer(payload, &mut dst[Header::SIZE..]).map_err(Error::BinarySv2Error)?; - #[cfg(feature = "with_serde")] - to_writer(&self.header, dst.as_mut()).map_err(Error::BinarySv2Error)?; - #[cfg(feature = "with_serde")] - to_writer(&payload, &mut dst.as_mut()[Header::SIZE..]) - .map_err(Error::BinarySv2Error)?; - Ok(()) - } else { - // Sv2Frame always has a payload or a serialized payload - panic!("Impossible state") + match self { + Sv2Frame::Raw { mut serialized, .. } => { + dst.swap_with_slice(serialized.as_mut()); + Ok(()) + } + Sv2Frame::Payload { header, payload } => { + #[cfg(not(feature = "with_serde"))] + to_writer(header, dst).map_err(Error::BinarySv2Error)?; + #[cfg(not(feature = "with_serde"))] + to_writer(payload, &mut dst[Header::SIZE..]).map_err(Error::BinarySv2Error)?; + #[cfg(feature = "with_serde")] + to_writer(&header, dst.as_mut()).map_err(Error::BinarySv2Error)?; + #[cfg(feature = "with_serde")] + to_writer(&payload, &mut dst.as_mut()[Header::SIZE..]) + .map_err(Error::BinarySv2Error)?; + Ok(()) + } } } @@ -83,16 +81,18 @@ impl + AsRef<[u8]>> Sv2Frame { /// already serialized payload. If the frame has not yet been /// serialized, this function should never be used (it will panic). pub fn payload(&mut self) -> Option<&mut [u8]> { - if let Some(serialized) = self.serialized.as_mut() { - Some(&mut serialized.as_mut()[Header::SIZE..]) - } else { - None + match self { + Sv2Frame::Raw { serialized, .. } => Some(&mut serialized.as_mut()[Header::SIZE..]), + Sv2Frame::Payload { .. } => None, } } /// `Sv2Frame` always returns `Some(self.header)`. pub fn header(&self) -> crate::header::Header { - self.header + match self { + Self::Payload { header, .. } => *header, + Self::Raw { header, .. } => *header, + } } /// Tries to build a `Sv2Frame` from raw bytes, assuming they represent a serialized `Sv2Frame` frame (`Self.serialized`). @@ -113,10 +113,9 @@ impl + AsRef<[u8]>> Sv2Frame { pub fn from_bytes_unchecked(mut bytes: B) -> Self { // Unchecked function caller is supposed to already know that the passed bytes are valid let header = Header::from_bytes(bytes.as_mut()).expect("Invalid header"); - Self { + Self::Raw { header, - payload: None, - serialized: Some(bytes), + serialized: bytes, } } @@ -150,13 +149,9 @@ impl + AsRef<[u8]>> Sv2Frame { /// otherwise, returns the length of `self.payload`. #[inline] pub fn encoded_length(&self) -> usize { - if let Some(serialized) = self.serialized.as_ref() { - serialized.as_ref().len() - } else if let Some(payload) = self.payload.as_ref() { - payload.get_size() + Header::SIZE - } else { - // Sv2Frame always has a payload or a serialized payload - panic!("Impossible state") + match self { + Sv2Frame::Raw { serialized, .. } => serialized.as_ref().len(), + Sv2Frame::Payload { payload, .. } => payload.get_size() + Header::SIZE, } } @@ -170,10 +165,9 @@ impl + AsRef<[u8]>> Sv2Frame { ) -> Option { let extension_type = update_extension_type(extension_type, channel_msg); let len = message.get_size() as u32; - Header::from_len(len, message_type, extension_type).map(|header| Self { + Header::from_len(len, message_type, extension_type).map(|header| Self::Payload { header, - payload: Some(message), - serialized: None, + payload: message, }) } } @@ -181,14 +175,16 @@ impl + AsRef<[u8]>> Sv2Frame { impl + AsRef<[u8]>> Sv2Frame { /// Maps a `Sv2Frame` to `Sv2Frame` by applying `fun`, /// which is assumed to be a closure that converts `A` to `C` - pub fn map(self, fun: fn(A) -> C) -> Sv2Frame { - let serialized = self.serialized; - let header = self.header; - let payload = self.payload.map(fun); - Sv2Frame { - header, - payload, - serialized, + pub fn map(self, fun: fn(A) -> C) -> Sv2Frame + where + C: Serialize + GetSize, + { + match self { + Sv2Frame::Raw { header, serialized } => Sv2Frame::Raw { header, serialized }, + Sv2Frame::Payload { header, payload } => Sv2Frame::Payload { + header, + payload: fun(payload), + }, } } } From 3b11866343fea201d3ab83033af26b0c1947cfbc Mon Sep 17 00:00:00 2001 From: jbesraa Date: Thu, 11 Jul 2024 13:04:33 +0300 Subject: [PATCH 5/8] Remove redundant tests --- protocols/v2/framing-sv2/src/framing.rs | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/protocols/v2/framing-sv2/src/framing.rs b/protocols/v2/framing-sv2/src/framing.rs index 7076118cd..1cc233ef2 100644 --- a/protocols/v2/framing-sv2/src/framing.rs +++ b/protocols/v2/framing-sv2/src/framing.rs @@ -265,16 +265,3 @@ fn update_extension_type(extension_type: u16, channel_msg: bool) -> u16 { extension_type & mask } } - -#[cfg(test)] -use binary_sv2::binary_codec_sv2; - -#[cfg(test)] -#[derive(Serialize)] -struct T {} - -#[test] -fn test_size_hint() { - let h = Sv2Frame::>::size_hint(&[0, 128, 30, 46, 0, 0][..]); - assert!(h == 46); -} From 302670dd8d3e23588cd1c705aa196c93c20252e5 Mon Sep 17 00:00:00 2001 From: jbesraa Date: Thu, 11 Jul 2024 13:04:57 +0300 Subject: [PATCH 6/8] Reorder functions in `impl` In `Sv2Frame` and `HandShakeFrame` put initialisers at the top --- protocols/v2/framing-sv2/src/framing.rs | 109 ++++++++++++------------ 1 file changed, 53 insertions(+), 56 deletions(-) diff --git a/protocols/v2/framing-sv2/src/framing.rs b/protocols/v2/framing-sv2/src/framing.rs index 1cc233ef2..67da26fe3 100644 --- a/protocols/v2/framing-sv2/src/framing.rs +++ b/protocols/v2/framing-sv2/src/framing.rs @@ -50,6 +50,46 @@ pub enum Sv2Frame { } impl + AsRef<[u8]>> Sv2Frame { + /// Tries to build a `Sv2Frame` from raw bytes, assuming they represent a serialized `Sv2Frame` frame (`Self.serialized`). + /// Returns a `Sv2Frame` on success, or the number of the bytes needed to complete the frame + /// as an error. `Self.serialized` is `Some`, but nothing is assumed or checked about the correctness of the payload. + #[inline] + pub fn from_bytes(mut bytes: B) -> Result { + let hint = Self::size_hint(bytes.as_mut()); + + if hint == 0 { + Ok(Self::from_bytes_unchecked(bytes)) + } else { + Err(hint) + } + } + + #[inline] + pub fn from_bytes_unchecked(mut bytes: B) -> Self { + // Unchecked function caller is supposed to already know that the passed bytes are valid + let header = Header::from_bytes(bytes.as_mut()).expect("Invalid header"); + Sv2Frame::Raw { + header, + serialized: bytes, + } + } + + /// Tries to build a `Sv2Frame` from a non-serialized payload. + /// Returns a `Sv2Frame` if the size of the payload fits in the frame, `None` otherwise. + pub fn from_message( + message: T, + message_type: u8, + extension_type: u16, + channel_msg: bool, + ) -> Option { + let extension_type = update_extension_type(extension_type, channel_msg); + let len = message.get_size() as u32; + Header::from_len(len, message_type, extension_type).map(|header| Self::Payload { + header, + payload: message, + }) + } + /// Write the serialized `Sv2Frame` into `dst`. /// This operation when called on an already serialized frame is very cheap. /// When called on a non serialized frame, it is not so cheap (because it serializes it). @@ -75,11 +115,10 @@ impl + AsRef<[u8]>> Sv2Frame { } } - /// `self` can be either serialized (`self.serialized` is `Some()`) or - /// deserialized (`self.serialized` is `None`, `self.payload` is `Some()`). - /// This function is only intended as a fast way to get a reference to an - /// already serialized payload. If the frame has not yet been - /// serialized, this function should never be used (it will panic). + /// `self` can be either serialized (`self.serialized` is `Some()`) or deserialized + /// (`self.serialized` is `None`, `self.payload` is `Some()`). This function is only intended + /// as a fast way to get a reference to an already serialized payload. If the frame has not yet + /// been serialized, this function should never be used (it will panic). pub fn payload(&mut self) -> Option<&mut [u8]> { match self { Sv2Frame::Raw { serialized, .. } => Some(&mut serialized.as_mut()[Header::SIZE..]), @@ -87,38 +126,13 @@ impl + AsRef<[u8]>> Sv2Frame { } } - /// `Sv2Frame` always returns `Some(self.header)`. - pub fn header(&self) -> crate::header::Header { + pub fn header(&self) -> Header { match self { Self::Payload { header, .. } => *header, Self::Raw { header, .. } => *header, } } - /// Tries to build a `Sv2Frame` from raw bytes, assuming they represent a serialized `Sv2Frame` frame (`Self.serialized`). - /// Returns a `Sv2Frame` on success, or the number of the bytes needed to complete the frame - /// as an error. `Self.serialized` is `Some`, but nothing is assumed or checked about the correctness of the payload. - #[inline] - pub fn from_bytes(mut bytes: B) -> Result { - let hint = Self::size_hint(bytes.as_mut()); - - if hint == 0 { - Ok(Self::from_bytes_unchecked(bytes)) - } else { - Err(hint) - } - } - - #[inline] - pub fn from_bytes_unchecked(mut bytes: B) -> Self { - // Unchecked function caller is supposed to already know that the passed bytes are valid - let header = Header::from_bytes(bytes.as_mut()).expect("Invalid header"); - Self::Raw { - header, - serialized: bytes, - } - } - /// After parsing `bytes` into a `Header`, this function helps to determine if the `msg_length` /// field is correctly representing the size of the frame. /// - Returns `0` if the byte slice is of the expected size according to the header. @@ -145,8 +159,8 @@ impl + AsRef<[u8]>> Sv2Frame { } } - /// If `Sv2Frame` is serialized, returns the length of `self.serialized`, - /// otherwise, returns the length of `self.payload`. + /// If `Sv2Frame` is serialized, returns the length of `self.serialized`, otherwise, returns the + /// length of `self.payload`. #[inline] pub fn encoded_length(&self) -> usize { match self { @@ -154,22 +168,6 @@ impl + AsRef<[u8]>> Sv2Frame { Sv2Frame::Payload { payload, .. } => payload.get_size() + Header::SIZE, } } - - /// Tries to build a `Sv2Frame` from a non-serialized payload. - /// Returns a `Sv2Frame` if the size of the payload fits in the frame, `None` otherwise. - pub fn from_message( - message: T, - message_type: u8, - extension_type: u16, - channel_msg: bool, - ) -> Option { - let extension_type = update_extension_type(extension_type, channel_msg); - let len = message.get_size() as u32; - Header::from_len(len, message_type, extension_type).map(|header| Self::Payload { - header, - payload: message, - }) - } } impl + AsRef<[u8]>> Sv2Frame { @@ -200,8 +198,7 @@ impl + AsRef<[u8]>> TryFrom> } } -/// Abstraction for a Noise Handshake Frame -/// Contains only a `Slice` payload with a fixed length +/// Abstraction for a Noise Handshake Frame Contains only a `Slice` payload with a fixed length /// Only used during Noise Handshake process #[derive(Debug)] pub struct HandShakeFrame { @@ -209,11 +206,6 @@ pub struct HandShakeFrame { } impl HandShakeFrame { - /// Returns payload of `HandShakeFrame` as a `Vec` - pub fn get_payload_when_handshaking(&self) -> Vec { - self.payload[0..].to_vec() - } - /// Builds a `HandShakeFrame` from raw bytes. Nothing is assumed or checked about the correctness of the payload. pub fn from_bytes(bytes: Slice) -> Result { Ok(Self::from_bytes_unchecked(bytes)) @@ -224,6 +216,11 @@ impl HandShakeFrame { Self { payload: bytes } } + /// Returns payload of `HandShakeFrame` as a `Vec` + pub fn get_payload_when_handshaking(&self) -> Vec { + self.payload[0..].to_vec() + } + /// Returns the size of the `HandShakeFrame` payload. #[inline] fn encoded_length(&self) -> usize { From 07f18c7923e55981078842521abdfc40aab81186 Mon Sep 17 00:00:00 2001 From: jbesraa Date: Thu, 11 Jul 2024 13:18:04 +0300 Subject: [PATCH 7/8] Move `handshake_message_to_frame` to.. `HandShakeFrame` as `HandShake::from_message` --- protocols/v2/codec-sv2/src/lib.rs | 12 +++++++++--- protocols/v2/framing-sv2/src/framing.rs | 20 ++++++++++---------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/protocols/v2/codec-sv2/src/lib.rs b/protocols/v2/codec-sv2/src/lib.rs index 0a2492890..7da4ff885 100644 --- a/protocols/v2/codec-sv2/src/lib.rs +++ b/protocols/v2/codec-sv2/src/lib.rs @@ -30,7 +30,7 @@ pub use noise_sv2::{self, Initiator, NoiseCodec, Responder}; pub use buffer_sv2; -pub use framing_sv2::{self, framing::handshake_message_to_frame as h2f}; +pub use framing_sv2::{self}; #[cfg(feature = "noise_sv2")] #[derive(Debug)] @@ -49,7 +49,10 @@ impl State { pub fn step_0(&mut self) -> core::result::Result { match self { Self::HandShake(h) => match h { - HandshakeRole::Initiator(i) => i.step_0().map_err(|e| e.into()).map(h2f), + HandshakeRole::Initiator(i) => i + .step_0() + .map_err(|e| e.into()) + .map(HandShakeFrame::from_message), HandshakeRole::Responder(_) => Err(Error::InvalidStepForResponder), }, _ => Err(Error::NotInHandShakeState), @@ -64,7 +67,10 @@ impl State { Self::HandShake(h) => match h { HandshakeRole::Responder(r) => { let (message, codec) = r.step_1(re_pub)?; - Ok((h2f(message), Self::Transport(codec))) + Ok(( + HandShakeFrame::from_message(message), + Self::Transport(codec), + )) } HandshakeRole::Initiator(_) => Err(Error::InvalidStepForInitiator), }, diff --git a/protocols/v2/framing-sv2/src/framing.rs b/protocols/v2/framing-sv2/src/framing.rs index 67da26fe3..ce7d2001d 100644 --- a/protocols/v2/framing-sv2/src/framing.rs +++ b/protocols/v2/framing-sv2/src/framing.rs @@ -211,6 +211,16 @@ impl HandShakeFrame { Ok(Self::from_bytes_unchecked(bytes)) } + /// Returns a `HandShakeFrame` from a generic byte array + #[allow(clippy::useless_conversion)] + pub fn from_message>(message: T) -> HandShakeFrame { + let mut payload = Vec::new(); + payload.extend_from_slice(message.as_ref()); + HandShakeFrame { + payload: payload.into(), + } + } + #[inline] pub fn from_bytes_unchecked(bytes: Slice) -> Self { Self { payload: bytes } @@ -239,16 +249,6 @@ impl + AsRef<[u8]>> TryFrom> } } -/// Returns a `HandShakeFrame` from a generic byte array -#[allow(clippy::useless_conversion)] -pub fn handshake_message_to_frame>(message: T) -> HandShakeFrame { - let mut payload = Vec::new(); - payload.extend_from_slice(message.as_ref()); - HandShakeFrame { - payload: payload.into(), - } -} - /// Basically a boolean bit filter for `extension_type`. /// Takes an `extension_type` represented as a `u16` and a boolean flag (`channel_msg`). /// If `channel_msg` is true, it sets the most significant bit of `extension_type` to 1, From 05f9b4ca399d8b609e75eba67ad7863ba5ea7de0 Mon Sep 17 00:00:00 2001 From: jbesraa Date: Thu, 11 Jul 2024 14:31:02 +0300 Subject: [PATCH 8/8] Test `Sv2Frame` and `HandShakeFrame` --- protocols/v2/framing-sv2/src/framing.rs | 97 ++++++++++++++++++++++--- 1 file changed, 88 insertions(+), 9 deletions(-) diff --git a/protocols/v2/framing-sv2/src/framing.rs b/protocols/v2/framing-sv2/src/framing.rs index ce7d2001d..d4018d3ca 100644 --- a/protocols/v2/framing-sv2/src/framing.rs +++ b/protocols/v2/framing-sv2/src/framing.rs @@ -50,9 +50,10 @@ pub enum Sv2Frame { } impl + AsRef<[u8]>> Sv2Frame { - /// Tries to build a `Sv2Frame` from raw bytes, assuming they represent a serialized `Sv2Frame` frame (`Self.serialized`). - /// Returns a `Sv2Frame` on success, or the number of the bytes needed to complete the frame - /// as an error. `Self.serialized` is `Some`, but nothing is assumed or checked about the correctness of the payload. + /// Tries to build a `Sv2Frame` from raw bytes, assuming they represent a serialized `Sv2Frame` + /// frame (`Self.serialized`). Returns a `Sv2Frame` on success, or the number of the bytes + /// needed to complete the frame as an error. `Self.serialized` is `Some`, but nothing is + /// assumed or checked about the correctness of the payload. #[inline] pub fn from_bytes(mut bytes: B) -> Result { let hint = Self::size_hint(bytes.as_mut()); @@ -74,8 +75,8 @@ impl + AsRef<[u8]>> Sv2Frame { } } - /// Tries to build a `Sv2Frame` from a non-serialized payload. - /// Returns a `Sv2Frame` if the size of the payload fits in the frame, `None` otherwise. + /// Tries to build a `Sv2Frame` from a non-serialized payload. Returns a `Sv2Frame` if the size + /// of the payload fits in the frame, `None` otherwise. pub fn from_message( message: T, message_type: u8, @@ -249,10 +250,9 @@ impl + AsRef<[u8]>> TryFrom> } } -/// Basically a boolean bit filter for `extension_type`. -/// Takes an `extension_type` represented as a `u16` and a boolean flag (`channel_msg`). -/// If `channel_msg` is true, it sets the most significant bit of `extension_type` to 1, -/// otherwise, it clears the most significant bit to 0. +/// Basically a boolean bit filter for `extension_type`. Takes an `extension_type` represented as +/// a `u16` and a boolean flag (`channel_msg`). If `channel_msg` is true, it sets the most +/// significant bit of `extension_type` to 1, otherwise, it clears the most significant bit to 0. fn update_extension_type(extension_type: u16, channel_msg: bool) -> u16 { if channel_msg { let mask = 0b1000_0000_0000_0000; @@ -262,3 +262,82 @@ fn update_extension_type(extension_type: u16, channel_msg: bool) -> u16 { extension_type & mask } } + +#[cfg(test)] +mod tests { + use super::*; + use alloc::vec; + + #[test] + fn test_sv2_frame_from_bytes() { + let slice: Slice = vec![].into(); + assert_eq!( + Sv2Frame::::from_bytes(slice.clone()).unwrap_err(), + 6 + ); + let slice: Slice = vec![0; 6].into(); + assert!(Sv2Frame::::from_bytes(slice.clone()).is_ok()); + let slice: Slice = vec![0; 10].into(); + assert_eq!( + Sv2Frame::::from_bytes(slice.clone()).unwrap_err(), + 4 + ); + let slice: Slice = vec![0; 8].into(); + assert_eq!( + Sv2Frame::::from_bytes(slice.clone()).unwrap_err(), + 2 + ); + let slice: Slice = vec![0; 4].into(); + assert_eq!( + Sv2Frame::::from_bytes(slice.clone()).unwrap_err(), + 2 + ); + let slice: Slice = vec![0; 2].into(); + assert_eq!( + Sv2Frame::::from_bytes(slice.clone()).unwrap_err(), + 4 + ); + } + + #[test] + fn test_sv2_frame_from_message() { + let message = 0u32; + let message_type = 0u8; + let extension_type = 0u16; + let ret = + Sv2Frame::::from_message(message, message_type, extension_type, false) + .unwrap(); + assert_eq!(ret.encoded_length(), 10); + } + + #[test] + fn test_sv2_frame_payload() { + let message = 2u32; + let message_type = 0u8; + let extension_type = 0u16; + let mut frame = + Sv2Frame::::from_message(message, message_type, extension_type, true) + .unwrap(); + assert!(frame.payload().is_none()); + let slice: Slice = vec![0; 6].into(); + let mut frame = Sv2Frame::::from_bytes(slice.clone()).unwrap(); + assert_eq!(frame.payload().unwrap().len(), 0); + } + + #[test] + fn test_handsahke_from_bytes() { + let slice: Slice = vec![].into(); + let frame = HandShakeFrame::from_bytes(slice.clone()).unwrap(); + assert_eq!(frame.encoded_length(), 0); + let slice: Slice = vec![0; 6].into(); + let frame = HandShakeFrame::from_bytes(slice.clone()).unwrap(); + assert_eq!(frame.encoded_length(), 6); + } + + #[test] + fn test_handshake_from_message() { + let message = vec![0u8; 6]; + let frame = HandShakeFrame::from_message(message); + assert_eq!(frame.encoded_length(), 6); + } +}