From 4643ffbf21f62195210e7626e10b0a4c26d45560 Mon Sep 17 00:00:00 2001 From: Tim Ling <791016+kettlebell@users.noreply.github.com> Date: Fri, 25 Nov 2022 14:01:45 +1100 Subject: [PATCH] Introduce supported protocol ids and vers --- spectrum-network/src/lib.rs | 2 +- spectrum-network/src/network_controller.rs | 87 ++--- spectrum-network/src/peer_conn_handler.rs | 316 +++++++++--------- spectrum-network/src/protocol.rs | 4 +- spectrum-network/src/protocol_upgrade.rs | 73 ++-- .../supported_protocol_vers.rs | 152 +++++++++ spectrum-network/src/types.rs | 8 +- spectrum-network/tests/tests.rs | 3 +- spectrum-node/src/main.rs | 3 +- 9 files changed, 407 insertions(+), 241 deletions(-) create mode 100644 spectrum-network/src/protocol_upgrade/supported_protocol_vers.rs diff --git a/spectrum-network/src/lib.rs b/spectrum-network/src/lib.rs index 93b2fe1c..af9300c5 100644 --- a/spectrum-network/src/lib.rs +++ b/spectrum-network/src/lib.rs @@ -2,7 +2,7 @@ pub mod network_controller; pub mod peer_conn_handler; pub mod peer_manager; pub mod protocol; +pub mod protocol_api; pub mod protocol_handler; pub mod protocol_upgrade; pub mod types; -pub mod protocol_api; diff --git a/spectrum-network/src/network_controller.rs b/spectrum-network/src/network_controller.rs index 6c062fa4..0504f37f 100644 --- a/spectrum-network/src/network_controller.rs +++ b/spectrum-network/src/network_controller.rs @@ -5,6 +5,7 @@ use crate::peer_conn_handler::{ use crate::peer_manager::{PeerEvents, PeerManagerOut, Peers}; use crate::protocol::ProtocolConfig; use crate::protocol_api::ProtocolEvents; +use crate::protocol_upgrade::supported_protocol_vers::SupportedProtocolIdMap; use crate::types::{ProtocolId, ProtocolVer}; use libp2p::core::connection::ConnectionId; @@ -215,7 +216,7 @@ impl NetworkEvents for NetworkController { conn_handler_conf: PeerConnHandlerConf, /// All supported protocols and their handlers - supported_protocols: HashMap, + supported_protocols: SupportedProtocolIdMap<(ProtocolConfig, THandler)>, /// PeerManager API peers: TPeers, /// PeerManager stream itself @@ -240,7 +241,7 @@ where ) -> Self { Self { conn_handler_conf, - supported_protocols, + supported_protocols: supported_protocols.into(), peers, peer_manager, enabled_peers: HashMap::new(), @@ -254,8 +255,7 @@ where self.conn_handler_conf.clone(), self.supported_protocols .iter() - .clone() - .map(|(prot_id, (conf, _))| (*prot_id, conf.clone())) + .map(|(prot_id, (conf, _))| (prot_id, conf.clone())) .collect::>(), ) } @@ -361,7 +361,7 @@ where { let protocol_id = protocol_tag.protocol_id(); let protocol_ver = protocol_tag.protocol_ver(); - match enabled_protocols.entry(protocol_id) { + match enabled_protocols.entry(protocol_id.get_inner()) { Entry::Occupied(mut entry) => { trace!( "Current state of protocol {:?} is {:?}", @@ -371,12 +371,12 @@ where if let (EnabledProtocol::PendingEnable, handler) = entry.get() { handler.protocol_enabled( peer_id, - protocol_ver, + protocol_ver.get_inner(), out_channel.clone(), handshake, ); let enabled_protocol = EnabledProtocol::Enabled { - ver: protocol_ver, + ver: protocol_ver.get_inner(), sink: out_channel, }; entry.insert((enabled_protocol, handler.clone())); @@ -398,42 +398,51 @@ where }) = self.enabled_peers.get_mut(&peer_id) { let protocol_id = protocol_tag.protocol_id(); - if let Some((_, prot_handler)) = self.supported_protocols.get(&protocol_id) { - match enabled_protocols.entry(protocol_id) { - Entry::Vacant(entry) => { - entry.insert((EnabledProtocol::PendingApprove, prot_handler.clone())); - prot_handler.protocol_requested( + let (_, prot_handler) = self.supported_protocols.get_supported(protocol_id); + match enabled_protocols.entry(protocol_id.get_inner()) { + Entry::Vacant(entry) => { + entry.insert((EnabledProtocol::PendingApprove, prot_handler.clone())); + prot_handler.protocol_requested( + peer_id, + protocol_tag.protocol_ver().get_inner(), + handshake, + ); + } + Entry::Occupied(_) => { + warn!( + "Peer {:?} opened already enabled protocol {:?}", + peer_id, protocol_id + ); + self.pending_actions + .push_back(NetworkBehaviourAction::NotifyHandler { peer_id, - protocol_tag.protocol_ver(), - handshake, - ); - } - Entry::Occupied(_) => { - warn!( - "Peer {:?} opened already enabled protocol {:?}", - peer_id, protocol_id - ); - self.pending_actions - .push_back(NetworkBehaviourAction::NotifyHandler { - peer_id, - handler: NotifyHandler::One(connection), - event: ConnHandlerIn::Close(protocol_id), - }) - } + handler: NotifyHandler::One(connection), + event: ConnHandlerIn::Close(protocol_id), + }) } - } else { - self.pending_actions - .push_back(NetworkBehaviourAction::NotifyHandler { + } + } + } + ConnHandlerOut::ClosedByPeer(protocol_id) | ConnHandlerOut::RefusedToOpen(protocol_id) => { + if let Some(ConnectedPeer::Connected { + enabled_protocols, .. + }) = self.enabled_peers.get_mut(&peer_id) + { + match enabled_protocols.entry(protocol_id.get_inner()) { + Entry::Occupied(entry) => { + trace!( + "Peer {:?} closed the substream for protocol {:?}", peer_id, - handler: NotifyHandler::One(connection), - event: ConnHandlerIn::Close(protocol_id), - }) + protocol_id + ); + entry.remove(); + } + Entry::Vacant(_) => {} } } } - ConnHandlerOut::ClosedByPeer(protocol_id) - | ConnHandlerOut::RefusedToOpen(protocol_id) - | ConnHandlerOut::Closed(protocol_id) => { + + ConnHandlerOut::Closed(protocol_id) => { if let Some(ConnectedPeer::Connected { enabled_protocols, .. }) = self.enabled_peers.get_mut(&peer_id) @@ -564,7 +573,7 @@ where ConnectedPeer::Connected { enabled_protocols, .. } => { - if let Some((_, prot_handler)) = self.supported_protocols.get(&protocol) { + if let Some((_, prot_handler)) = self.supported_protocols.get(protocol) { match enabled_protocols.entry(protocol) { Entry::Occupied(_) => warn!( "PM requested already enabled protocol {:?} with peer {:?}", @@ -617,7 +626,7 @@ where enabled_protocols, }) = self.enabled_peers.get_mut(&peer_id) { - let (_, prot_handler) = self.supported_protocols.get(&protocol_id).unwrap(); + let (_, prot_handler) = self.supported_protocols.get(protocol_id).unwrap(); match enabled_protocols.entry(protocol_id) { Entry::Occupied(protocol_entry) => match protocol_entry.remove_entry().1 { // Protocol handler approves either outbound or inbound protocol request. diff --git a/spectrum-network/src/peer_conn_handler.rs b/spectrum-network/src/peer_conn_handler.rs index f114e028..22f4bdcb 100644 --- a/spectrum-network/src/peer_conn_handler.rs +++ b/spectrum-network/src/peer_conn_handler.rs @@ -4,7 +4,10 @@ use crate::peer_conn_handler::message_sink::{MessageSink, StreamNotification}; use crate::protocol_upgrade::combinators::AnyUpgradeOf; use crate::protocol_upgrade::handshake::PolyVerHandshakeSpec; use crate::protocol_upgrade::substream::{ProtocolSubstreamIn, ProtocolSubstreamOut}; -use crate::types::{ProtocolId, ProtocolTag, ProtocolVer, RawMessage}; +use crate::protocol_upgrade::supported_protocol_vers::{ + SupportedProtocolId, SupportedProtocolIdMap, SupportedProtocolTag, SupportedProtocolVer, +}; +use crate::types::{ProtocolId, ProtocolTag, RawMessage}; use futures::channel::mpsc; pub use futures::prelude::*; use libp2p::core::ConnectedPoint; @@ -17,7 +20,7 @@ use libp2p::{InboundUpgrade, OutboundUpgrade, PeerId}; use crate::protocol::{ProtocolConfig, ProtocolSpec}; use crate::protocol_upgrade::{ProtocolUpgradeErr, ProtocolUpgradeIn, ProtocolUpgradeOut}; use log::trace; -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::fmt::{Debug, Formatter}; use std::mem; use std::pin::Pin; @@ -27,7 +30,7 @@ use std::time::{Duration, Instant}; #[derive(Debug)] pub struct Protocol { /// Negotiated protocol version - pub ver: ProtocolVer, + pub ver: SupportedProtocolVer, /// Spec for negotiated protocol version pub spec: ProtocolSpec, /// Protocol state @@ -35,7 +38,7 @@ pub struct Protocol { pub state: Option, /// Specs for all supported versions of this protocol /// Note, versions must be listed in descending order. - pub all_versions_specs: Vec<(ProtocolVer, ProtocolSpec)>, + pub all_versions_specs: Vec<(SupportedProtocolVer, ProtocolSpec)>, } pub enum ProtocolState { @@ -127,7 +130,7 @@ pub enum ConnHandlerIn { /// substream request for the given [`ProtocolId`]. /// /// Must always be answered by a [`ConnHandlerOut::Closed`] event. - Close(ProtocolId), + Close(SupportedProtocolId), /// Instruct the handler to close the notification substreams, or reject any pending incoming /// substream request for all protocols. /// @@ -140,12 +143,12 @@ pub enum ConnHandlerOut { // Input commands outcomes: /// Ack [`ConnHandlerIn::Open`]. Substream was negotiated. Opened { - protocol_tag: ProtocolTag, + protocol_tag: SupportedProtocolTag, out_channel: MessageSink, handshake: Option, }, /// Ack [`ConnHandlerIn::Open`]. Peer refused to open a substream. - RefusedToOpen(ProtocolId), + RefusedToOpen(SupportedProtocolId), /// Ack [`ConnHandlerIn::Close`] Closed(ProtocolId), /// Ack [`ConnHandlerIn::CloseAllProtocols`] @@ -158,14 +161,14 @@ pub enum ConnHandlerOut { /// yet been acknowledged by a matching [`ConnHandlerOut`], then you don't need to a send /// another [`ConnHandlerIn`]. OpenedByPeer { - protocol_tag: ProtocolTag, + protocol_tag: SupportedProtocolTag, handshake: Option, }, /// The remote would like the substreams to be closed. Send a [`ConnHandlerIn::Close`] in /// order to close them. If a [`ConnHandlerIn::Close`] has been sent before and has not yet /// been acknowledged by a [`ConnHandlerOut::CloseResult`], then you don't need to a send /// another one. - ClosedByPeer(ProtocolId), + ClosedByPeer(SupportedProtocolId), /// Received a message on a custom protocol substream. /// Can only happen when the handler is in the open state. Message { @@ -190,11 +193,14 @@ pub trait PeerConnHandlerActions { pub struct PartialPeerConnHandler { conf: PeerConnHandlerConf, - supported_protocols: Vec<(ProtocolId, ProtocolConfig)>, + supported_protocols: Vec<(SupportedProtocolId, ProtocolConfig)>, } impl PartialPeerConnHandler { - pub fn new(conf: PeerConnHandlerConf, supported_protocols: Vec<(ProtocolId, ProtocolConfig)>) -> Self { + pub fn new( + conf: PeerConnHandlerConf, + supported_protocols: Vec<(SupportedProtocolId, ProtocolConfig)>, + ) -> Self { Self { conf, supported_protocols, @@ -206,25 +212,24 @@ impl IntoConnectionHandler for PartialPeerConnHandler { type Handler = PeerConnHandler; fn into_handler(self, remote_peer_id: &PeerId, connected_point: &ConnectedPoint) -> Self::Handler { - let protocols = HashMap::from_iter(self.supported_protocols.iter().flat_map(|(protocol_id, p)| { - p.supported_versions.iter().map(|(ver, spec)| { - ( - *protocol_id, - Protocol { - ver: *ver, - spec: *spec, - state: Some(ProtocolState::Closed), - all_versions_specs: p.supported_versions.clone(), - }, - ) + let protocols = self + .supported_protocols + .iter() + .flat_map(|(protocol_id, p)| { + p.supported_versions.iter().map(|(ver, spec)| { + ( + *protocol_id, + Protocol { + ver: *ver, + spec: *spec, + state: Some(ProtocolState::Closed), + all_versions_specs: p.supported_versions.clone(), + }, + ) + }) }) - })); - - #[cfg(not(feature = "test_peer_punish_too_slow"))] - let throttle_recv = ThrottleStage::Disable; - #[cfg(feature = "test_peer_punish_too_slow")] - let throttle_recv = ThrottleStage::Start; - + .collect::>() + .into(); PeerConnHandler { conf: self.conf, protocols, @@ -248,7 +253,7 @@ impl IntoConnectionHandler for PartialPeerConnHandler { pub struct PeerConnHandler { conf: PeerConnHandlerConf, - protocols: HashMap, + protocols: SupportedProtocolIdMap, /// When the connection with the remote has been successfully established. created_at: Instant, /// Whether we are the connection dialer or listener. @@ -256,8 +261,9 @@ pub struct PeerConnHandler { /// Remote we are connected to. peer_id: PeerId, /// Events to return in priority from `poll`. - pending_events: - VecDeque>, + pending_events: VecDeque< + ConnectionHandlerEvent, + >, /// Is the handler going to terminate due to this err. fault: Option, /// Current throttle stage. Throttling is performed for each inbound `StreamNotification`. @@ -280,13 +286,13 @@ impl ConnectionHandler for PeerConnHandler { type InboundProtocol = AnyUpgradeOf; type OutboundProtocol = ProtocolUpgradeOut; type InboundOpenInfo = (); - type OutboundOpenInfo = ProtocolTag; + type OutboundOpenInfo = SupportedProtocolTag; fn listen_protocol(&self) -> SubstreamProtocol { let protocols = self .protocols .iter() - .map(|(pid, prot)| ProtocolUpgradeIn::new(*pid, prot.all_versions_specs.clone())) + .map(|(pid, prot)| ProtocolUpgradeIn::new(pid, prot.all_versions_specs.clone())) .collect::>(); SubstreamProtocol::new(protocols, ()) } @@ -299,70 +305,65 @@ impl ConnectionHandler for PeerConnHandler { trace!("inject_fully_negotiated_inbound()"); let negotiated_tag = upgrade.negotiated_tag; let protocol_id = negotiated_tag.protocol_id(); - if let Some(protocol) = self.protocols.get_mut(&protocol_id) { - let state = protocol.state.take(); - if let Some(state) = state { - trace!("Current protocol state is {:?}", state); - let state_next = match state { - ProtocolState::Closed => { - let event = ConnectionHandlerEvent::Custom(ConnHandlerOut::OpenedByPeer { + let protocol = self.protocols.get_supported_mut(protocol_id); + let state = protocol.state.take(); + if let Some(state) = state { + trace!("Current protocol state is {:?}", state); + let state_next = match state { + ProtocolState::Closed => { + let event = ConnectionHandlerEvent::Custom(ConnHandlerOut::OpenedByPeer { + protocol_tag: negotiated_tag, + handshake: upgrade.handshake, + }); + self.pending_events.push_back(event); + ProtocolState::PartiallyOpenedByPeer { + substream_in: upgrade.substream, + } + } + // Should not happen in normal network conditions. + ProtocolState::Opening => ProtocolState::PartiallyOpenedByPeer { + substream_in: upgrade.substream, + }, + ProtocolState::PartiallyOpened { substream_out } + | ProtocolState::InboundClosedByPeer { substream_out, .. } => { + if protocol.spec.approve_required { + // Approve immediately if required. + trace!("Sending approve for outbound protocol {:?}", protocol_id); + upgrade.substream.send_approve(); + } + let (async_msg_snd, async_msg_recv) = + mpsc::channel::(self.conf.async_msg_buffer_size); + let (sync_msg_snd, sync_msg_recv) = + mpsc::channel::(self.conf.sync_msg_buffer_size); + let sink = MessageSink::new(self.peer_id, async_msg_snd, sync_msg_snd); + self.pending_events + .push_back(ConnectionHandlerEvent::Custom(ConnHandlerOut::Opened { protocol_tag: negotiated_tag, + out_channel: sink, handshake: upgrade.handshake, - }); - self.pending_events.push_back(event); - ProtocolState::PartiallyOpenedByPeer { - substream_in: upgrade.substream, - } - } - // Should not happen in normal network conditions. - ProtocolState::Opening => ProtocolState::PartiallyOpenedByPeer { + })); + ProtocolState::Opened { + substream_out, substream_in: upgrade.substream, - }, - ProtocolState::PartiallyOpened { substream_out } - | ProtocolState::InboundClosedByPeer { substream_out, .. } => { - if protocol.spec.approve_required { - // Approve immediately if required. - trace!("Sending approve for outbound protocol {:?}", protocol_id); - upgrade.substream.send_approve(); - } - let (async_msg_snd, async_msg_recv) = - mpsc::channel::(self.conf.async_msg_buffer_size); - let (sync_msg_snd, sync_msg_recv) = - mpsc::channel::(self.conf.sync_msg_buffer_size); - let sink = MessageSink::new(self.peer_id, async_msg_snd, sync_msg_snd); - self.pending_events.push_back(ConnectionHandlerEvent::Custom( - ConnHandlerOut::Opened { - protocol_tag: negotiated_tag, - out_channel: sink, - handshake: upgrade.handshake, - }, - )); - ProtocolState::Opened { - substream_out, - substream_in: upgrade.substream, - pending_messages_recv: stream::select( - async_msg_recv.fuse(), - sync_msg_recv.fuse(), - ) + pending_messages_recv: stream::select(async_msg_recv.fuse(), sync_msg_recv.fuse()) .peekable(), - } } - // If a substream already exists, silently drop the new one. - // Note that we drop the substream, which will send an equivalent to a - // TCP "RST" to the remote and force-close the substream. It might - // seem like an unclean way to get rid of a substream. However, keep - // in mind that it is invalid for the remote to open multiple such - // substreams, and therefore sending a "RST" is the most correct thing - // to do. - ProtocolState::PartiallyOpenedByPeer { .. } - | ProtocolState::Opened { .. } - | ProtocolState::Accepting { .. } - | ProtocolState::OutboundClosedByPeer { .. } => state, - }; - trace!("Next protocol state is {:?}", state_next); - protocol.state = Some(state_next); + } + // If a substream already exists, silently drop the new one. + // Note that we drop the substream, which will send an equivalent to a + // TCP "RST" to the remote and force-close the substream. It might + // seem like an unclean way to get rid of a substream. However, keep + // in mind that it is invalid for the remote to open multiple such + // substreams, and therefore sending a "RST" is the most correct thing + // to do. + ProtocolState::PartiallyOpenedByPeer { .. } + | ProtocolState::Opened { .. } + | ProtocolState::Accepting { .. } + | ProtocolState::OutboundClosedByPeer { .. } => state, }; - } + trace!("Next protocol state is {:?}", state_next); + protocol.state = Some(state_next); + }; } fn inject_fully_negotiated_outbound( @@ -372,48 +373,43 @@ impl ConnectionHandler for PeerConnHandler { ) { trace!("inject_fully_negotiated_outbound()"); let protocol_id = negotiated_tag.protocol_id(); - if let Some(protocol) = self.protocols.get_mut(&protocol_id) { - let state = protocol.state.take(); - trace!("Current protocol state is {:?}", state); - if let Some(state) = state { - let state_next = match state { - ProtocolState::Opening => ProtocolState::PartiallyOpened { + let protocol = self.protocols.get_supported_mut(protocol_id); + let state = protocol.state.take(); + trace!("Current protocol state is {:?}", state); + if let Some(state) = state { + let state_next = match state { + ProtocolState::Opening => ProtocolState::PartiallyOpened { + substream_out: upgrade.substream, + }, + ProtocolState::Accepting { + substream_in: Some(substream_in), + } => { + let (async_msg_snd, async_msg_recv) = + mpsc::channel::(self.conf.async_msg_buffer_size); + let (sync_msg_snd, sync_msg_recv) = + mpsc::channel::(self.conf.sync_msg_buffer_size); + let sink = MessageSink::new(self.peer_id, async_msg_snd, sync_msg_snd); + self.pending_events + .push_back(ConnectionHandlerEvent::Custom(ConnHandlerOut::Opened { + protocol_tag: negotiated_tag, + out_channel: sink, + handshake: None, + })); + ProtocolState::Opened { + substream_in, substream_out: upgrade.substream, - }, - ProtocolState::Accepting { - substream_in: Some(substream_in), - } => { - let (async_msg_snd, async_msg_recv) = - mpsc::channel::(self.conf.async_msg_buffer_size); - let (sync_msg_snd, sync_msg_recv) = - mpsc::channel::(self.conf.sync_msg_buffer_size); - let sink = MessageSink::new(self.peer_id, async_msg_snd, sync_msg_snd); - self.pending_events.push_back(ConnectionHandlerEvent::Custom( - ConnHandlerOut::Opened { - protocol_tag: negotiated_tag, - out_channel: sink, - handshake: None, - }, - )); - ProtocolState::Opened { - substream_in, - substream_out: upgrade.substream, - pending_messages_recv: stream::select( - async_msg_recv.fuse(), - sync_msg_recv.fuse(), - ) + pending_messages_recv: stream::select(async_msg_recv.fuse(), sync_msg_recv.fuse()) .peekable(), - } } - // todo: handle this in the case we decide to re-open out substream. - ProtocolState::OutboundClosedByPeer { .. } => state, - // todo: warn, inconsistent state; discard other options explicitly. - _ => state, - }; - trace!("Next protocol state is {:?}", state_next); - protocol.state = Some(state_next); + } + // todo: handle this in the case we decide to re-open out substream. + ProtocolState::OutboundClosedByPeer { .. } => state, + // todo: warn, inconsistent state; discard other options explicitly. + _ => state, }; - } + trace!("Next protocol state is {:?}", state_next); + protocol.state = Some(state_next); + }; } fn inject_event(&mut self, cmd: ConnHandlerIn) { @@ -423,13 +419,13 @@ impl ConnectionHandler for PeerConnHandler { handshake, } => { trace!("ConnHandlerIn::Open[{:?}]", protocol_id); - if let Some(protocol) = self.protocols.get_mut(&protocol_id) { + if let Some((supported_protocol_id, protocol)) = self.protocols.get_mut(protocol_id) { let state = protocol.state.take(); if let Some(state) = state { let state_next = match state { ProtocolState::Closed => { let upgrade = ProtocolUpgradeOut::new( - protocol_id, + supported_protocol_id, protocol .all_versions_specs .clone() @@ -442,7 +438,7 @@ impl ConnectionHandler for PeerConnHandler { ConnectionHandlerEvent::OutboundSubstreamRequest { protocol: SubstreamProtocol::new( upgrade, - ProtocolTag::new(protocol_id, protocol.ver), + SupportedProtocolTag::new(supported_protocol_id, protocol.ver), ) .with_timeout(self.conf.open_timeout), }, @@ -450,14 +446,14 @@ impl ConnectionHandler for PeerConnHandler { ProtocolState::Opening } ProtocolState::PartiallyOpenedByPeer { mut substream_in } => { - let ver_handshake = handshake.handshake_for(protocol.ver); + let ver_handshake = handshake.handshake_for(protocol.ver.get_inner()); if ver_handshake.is_some() { // If handshake is defined dialer is waiting for approve, so we send it. trace!("Sending approve for inbound protocol {:?}", protocol_id); substream_in.send_approve() } let upgrade = ProtocolUpgradeOut::new( - protocol_id, + supported_protocol_id, // Version is negotiated during inbound upgr, so we pass it exclusively to outbound upgr. vec![(protocol.ver, protocol.spec, ver_handshake)], ); @@ -465,7 +461,7 @@ impl ConnectionHandler for PeerConnHandler { ConnectionHandlerEvent::OutboundSubstreamRequest { protocol: SubstreamProtocol::new( upgrade, - ProtocolTag::new(protocol_id, protocol.ver), + SupportedProtocolTag::new(supported_protocol_id, protocol.ver), ) .with_timeout(self.conf.open_timeout), }, @@ -481,13 +477,12 @@ impl ConnectionHandler for PeerConnHandler { } } ConnHandlerIn::Close(protocol_id) => { - if let Some(protocol) = self.protocols.get_mut(&protocol_id) { - protocol.state = Some(ProtocolState::Closed); - self.pending_events - .push_back(ConnectionHandlerEvent::Custom(ConnHandlerOut::Closed( - protocol_id, - ))) - } + let protocol = self.protocols.get_supported_mut(protocol_id); + protocol.state = Some(ProtocolState::Closed); + self.pending_events + .push_back(ConnectionHandlerEvent::Custom(ConnHandlerOut::Closed( + protocol_id.get_inner(), + ))) } ConnHandlerIn::CloseAllProtocols => { for protocol in self.protocols.values_mut() { @@ -505,20 +500,20 @@ impl ConnectionHandler for PeerConnHandler { err: ConnectionHandlerUpgrErr, ) { let protocol_id = protocol_tag.protocol_id(); - if let Some(protocol) = self.protocols.get_mut(&protocol_id) { - if let Some(state) = &protocol.state { - match state { - ProtocolState::Opening | ProtocolState::Accepting { .. } => { - trace!("Failed to open protocol {:?}, {:?}", protocol_id, err); - self.pending_events.push_back(ConnectionHandlerEvent::Custom( - ConnHandlerOut::RefusedToOpen(protocol_id), - )) - } - _ => {} + let protocol = self.protocols.get_supported_mut(protocol_id); + + if let Some(state) = &protocol.state { + match state { + ProtocolState::Opening | ProtocolState::Accepting { .. } => { + trace!("Failed to open protocol {:?}, {:?}", protocol_id, err); + self.pending_events.push_back(ConnectionHandlerEvent::Custom( + ConnHandlerOut::RefusedToOpen(protocol_id), + )) } + _ => {} } - protocol.state = Some(ProtocolState::Closed) } + protocol.state = Some(ProtocolState::Closed) } fn connection_keep_alive(&self) -> KeepAlive { @@ -618,7 +613,7 @@ impl ConnectionHandler for PeerConnHandler { // performed before the code paths that can produce `Ready` (with some rare exceptions). // Importantly, the flush is performed *after* notifications are queued with // `Sink::start_send`. - for (protocol_id, protocol) in &mut self.protocols { + for (protocol_id, protocol) in self.protocols.iter_mut() { if let Some(state) = &mut protocol.state { if let ProtocolState::Opened { substream_out, .. } | ProtocolState::InboundClosedByPeer { substream_out, .. } = state @@ -634,7 +629,7 @@ impl ConnectionHandler for PeerConnHandler { } else { protocol.state = Some(ProtocolState::Closed) } - let event = ConnHandlerOut::ClosedByPeer(*protocol_id); + let event = ConnHandlerOut::ClosedByPeer(protocol_id); return Poll::Ready(ConnectionHandlerEvent::Custom(event)); } } @@ -643,7 +638,7 @@ impl ConnectionHandler for PeerConnHandler { } // Poll inbound substreams. - for (protocol_id, protocol) in &mut self.protocols { + for (protocol_id, protocol) in self.protocols.iter_mut() { if let Some(state) = &mut protocol.state { match state { ProtocolState::Opened { substream_in, .. } @@ -652,7 +647,10 @@ impl ConnectionHandler for PeerConnHandler { Poll::Pending => {} Poll::Ready(Some(Ok(msg))) => { let event = ConnHandlerOut::Message { - protocol_tag: ProtocolTag::new(*protocol_id, protocol.ver), + protocol_tag: ProtocolTag::new( + protocol_id.get_inner(), + protocol.ver.get_inner(), + ), content: msg, }; return Poll::Ready(ConnectionHandlerEvent::Custom(event)); @@ -685,7 +683,7 @@ impl ConnectionHandler for PeerConnHandler { Poll::Ready(Err(_)) => { protocol.state = Some(ProtocolState::Closed); return Poll::Ready(ConnectionHandlerEvent::Custom( - ConnHandlerOut::ClosedByPeer(*protocol_id), + ConnHandlerOut::ClosedByPeer(protocol_id), )); } } diff --git a/spectrum-network/src/protocol.rs b/spectrum-network/src/protocol.rs index d483ba2e..3c2cb557 100644 --- a/spectrum-network/src/protocol.rs +++ b/spectrum-network/src/protocol.rs @@ -1,4 +1,4 @@ -use crate::types::{ProtocolId, ProtocolVer}; +use crate::{protocol_upgrade::supported_protocol_vers::SupportedProtocolVer, types::ProtocolId}; pub const SYNC_PROTOCOL_ID: ProtocolId = ProtocolId::from_u8(0); @@ -12,5 +12,5 @@ pub struct ProtocolSpec { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ProtocolConfig { - pub supported_versions: Vec<(ProtocolVer, ProtocolSpec)>, + pub supported_versions: Vec<(SupportedProtocolVer, ProtocolSpec)>, } diff --git a/spectrum-network/src/protocol_upgrade.rs b/spectrum-network/src/protocol_upgrade.rs index 29eb546f..bd5ac137 100644 --- a/spectrum-network/src/protocol_upgrade.rs +++ b/spectrum-network/src/protocol_upgrade.rs @@ -2,23 +2,28 @@ pub mod combinators; pub mod handshake; mod message; pub(crate) mod substream; +pub(crate) mod supported_protocol_vers; +pub use supported_protocol_vers::GetSupportedProtocolVer; use crate::protocol::ProtocolSpec; use crate::protocol_upgrade::message::{Approve, APPROVE_SIZE}; use crate::protocol_upgrade::substream::{ProtocolApproveState, ProtocolSubstreamIn, ProtocolSubstreamOut}; -use crate::types::{ProtocolId, ProtocolTag, ProtocolVer, RawMessage}; +use crate::types::RawMessage; use asynchronous_codec::Framed; use futures::{AsyncRead, AsyncReadExt, AsyncWrite}; use libp2p::core::{upgrade, UpgradeInfo}; use libp2p::{InboundUpgrade, OutboundUpgrade}; use log::trace; -use std::collections::BTreeMap; use std::fmt::Debug; use std::future::Future; use std::pin::Pin; use std::{io, vec}; use unsigned_varint::codec::UviBytes; +use self::supported_protocol_vers::{ + SupportedProtocolId, SupportedProtocolTag, SupportedProtocolVer, SupportedProtocolVerBTreeMap, +}; + #[derive(Debug, thiserror::Error)] pub enum ProtocolHandshakeErr { #[error(transparent)] @@ -33,11 +38,9 @@ pub enum ProtocolHandshakeErr { pub enum ProtocolUpgradeErr { #[error(transparent)] HandshakeErr(#[from] ProtocolHandshakeErr), - #[error("Unsupported {0:?}")] - UnsupportedProtocolVer(ProtocolVer), } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] pub struct InboundProtocolSpec { /// Maximum allowed size for a single message. max_message_size: usize, @@ -59,19 +62,22 @@ impl From for InboundProtocolSpec { #[derive(Debug, Clone)] pub struct ProtocolUpgradeIn { /// Protocol to negotiate. - protocol_id: ProtocolId, + protocol_id: SupportedProtocolId, /// Protocol versions to negotiate. /// The first one is the main name, while the other ones are fall backs. - supported_versions: BTreeMap, + supported_versions: SupportedProtocolVerBTreeMap, } impl ProtocolUpgradeIn { - pub fn new(protocol_id: ProtocolId, supported_versions: Vec<(ProtocolVer, ProtocolSpec)>) -> Self { - let supported_versions = BTreeMap::from_iter( - supported_versions - .into_iter() - .map(|(ver, spec)| (ver, InboundProtocolSpec::from(spec))), - ); + pub fn new( + protocol_id: SupportedProtocolId, + supported_versions: Vec<(SupportedProtocolVer, ProtocolSpec)>, + ) -> Self { + let supported_versions = supported_versions + .into_iter() + .map(|(ver, spec)| (ver, InboundProtocolSpec::from(spec))) + .collect::>() + .into(); Self { protocol_id, supported_versions, @@ -80,14 +86,13 @@ impl ProtocolUpgradeIn { } impl UpgradeInfo for ProtocolUpgradeIn { - type Info = ProtocolTag; + type Info = SupportedProtocolTag; type InfoIter = vec::IntoIter; fn protocol_info(&self) -> Self::InfoIter { self.supported_versions .keys() - .cloned() - .map(|v| ProtocolTag::new(self.protocol_id, v)) + .map(|v| SupportedProtocolTag::new(self.protocol_id, v)) .collect::>() .into_iter() } @@ -106,10 +111,7 @@ where let target = format!("Inbound({})", negotiated_tag); trace!(target: &target, "upgrade_inbound()"); let protocol_ver = negotiated_tag.protocol_ver(); - let pspec = self - .supported_versions - .get(&protocol_ver) - .ok_or(ProtocolUpgradeErr::UnsupportedProtocolVer(protocol_ver))?; + let pspec = self.supported_versions.get(protocol_ver); let mut codec = UviBytes::default(); codec.set_max_len(pspec.max_message_size); let handshake = if pspec.handshake_required { @@ -161,21 +163,22 @@ impl OutboundProtocolSpec { #[derive(Debug, Clone)] pub struct ProtocolUpgradeOut { /// Protocol to negotiate. - protocol_id: ProtocolId, + protocol_id: SupportedProtocolId, /// Protocol versions to negotiate. /// The first one is the main name, while the other ones are fall backs. - supported_versions: BTreeMap, + supported_versions: SupportedProtocolVerBTreeMap, } impl ProtocolUpgradeOut { pub fn new( - protocol_id: ProtocolId, - supported_versions: Vec<(ProtocolVer, ProtocolSpec, Option)>, + protocol_id: SupportedProtocolId, + supported_versions: Vec<(SupportedProtocolVer, ProtocolSpec, Option)>, ) -> Self { - let supported_versions = - BTreeMap::from_iter(supported_versions.into_iter().map(|(ver, spec, handshake)| { - (ver, OutboundProtocolSpec::new(spec.max_message_size, handshake)) - })); + let supported_versions = supported_versions + .into_iter() + .map(|(ver, spec, handshake)| (ver, OutboundProtocolSpec::new(spec.max_message_size, handshake))) + .collect::>() + .into(); Self { protocol_id, supported_versions, @@ -184,14 +187,13 @@ impl ProtocolUpgradeOut { } impl UpgradeInfo for ProtocolUpgradeOut { - type Info = ProtocolTag; + type Info = SupportedProtocolTag; type InfoIter = vec::IntoIter; fn protocol_info(&self) -> Self::InfoIter { self.supported_versions .keys() - .cloned() - .map(|v| ProtocolTag::new(self.protocol_id, v)) + .map(|v| SupportedProtocolTag::new(self.protocol_id, v)) .collect::>() .into_iter() } @@ -210,10 +212,7 @@ where let target = format!("Outbound({})", negotiated_tag); trace!(target: &target, "upgrade_outbound()"); let protocol_ver = negotiated_tag.protocol_ver(); - let pspec = self - .supported_versions - .get(&protocol_ver) - .ok_or(ProtocolUpgradeErr::UnsupportedProtocolVer(protocol_ver))?; + let pspec = self.supported_versions.get(protocol_ver); let mut codec = UviBytes::default(); codec.set_max_len(pspec.max_message_size); if let Some(handshake) = &pspec.handshake { @@ -240,7 +239,7 @@ where pub struct InboundProtocolUpgraded { /// ProtocolTag negotiated with the peer. - pub negotiated_tag: ProtocolTag, + pub negotiated_tag: SupportedProtocolTag, /// Handshake sent by the peer. pub handshake: Option, pub substream: Substream, @@ -248,7 +247,7 @@ pub struct InboundProtocolUpgraded { pub struct OutboundProtocolUpgraded { /// ProtocolTag negotiated with the peer. - pub negotiated_tag: ProtocolTag, + pub negotiated_tag: SupportedProtocolTag, pub substream: Substream, } diff --git a/spectrum-network/src/protocol_upgrade/supported_protocol_vers.rs b/spectrum-network/src/protocol_upgrade/supported_protocol_vers.rs new file mode 100644 index 00000000..020180f8 --- /dev/null +++ b/spectrum-network/src/protocol_upgrade/supported_protocol_vers.rs @@ -0,0 +1,152 @@ +use std::{ + collections::{BTreeMap, HashMap}, + fmt::Display, +}; + +use libp2p::core::upgrade; + +use crate::{ + protocol_handler::sync::message::SyncSpec, + types::{ProtocolId, ProtocolTag, ProtocolVer}, +}; + +pub trait GetSupportedProtocolVer { + fn get_supported_ver() -> SupportedProtocolVer; +} + +impl GetSupportedProtocolVer for SyncSpec { + fn get_supported_ver() -> SupportedProtocolVer { + SupportedProtocolVer(Self::v1()) + } +} + +#[derive(Debug, Clone)] +pub struct SupportedProtocolVerBTreeMap(BTreeMap); + +impl SupportedProtocolVerBTreeMap { + pub fn get(&self, ver: SupportedProtocolVer) -> &T { + #[allow(clippy::unwrap_used)] + self.0.get(&ver.0).unwrap() + } + + pub fn keys(&self) -> impl Iterator + '_ { + self.0.keys().cloned().map(SupportedProtocolVer) + } +} + +impl From> for SupportedProtocolVerBTreeMap { + fn from(v: Vec<(SupportedProtocolVer, T)>) -> Self { + Self(v.into_iter().map(|(ver, t)| (ver.get_inner(), t)).collect()) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct SupportedProtocolId(ProtocolId); + +impl SupportedProtocolId { + /// It's safe to expose the underlying [`ProtocolId`]. + pub fn get_inner(&self) -> ProtocolId { + self.0 + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct SupportedProtocolVer(ProtocolVer); + +impl SupportedProtocolVer { + /// It's safe to expose the underlying [`ProtocolVer`]. + pub fn get_inner(&self) -> ProtocolVer { + self.0 + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct SupportedProtocolTag(ProtocolTag); + +impl SupportedProtocolTag { + pub fn protocol_ver(&self) -> SupportedProtocolVer { + SupportedProtocolVer::from(*self) + } + + pub fn protocol_id(&self) -> SupportedProtocolId { + SupportedProtocolId::from(*self) + } +} + +impl SupportedProtocolTag { + pub fn new(protocol_id: SupportedProtocolId, protocol_ver: SupportedProtocolVer) -> Self { + Self(ProtocolTag::new(protocol_id.0, protocol_ver.0)) + } +} + +impl From for SupportedProtocolVer { + fn from(p: SupportedProtocolTag) -> Self { + Self(ProtocolVer::from(p.0)) + } +} + +impl From for SupportedProtocolId { + fn from(p: SupportedProtocolTag) -> Self { + Self(ProtocolId::from(p.0)) + } +} + +impl upgrade::ProtocolName for SupportedProtocolTag { + fn protocol_name(&self) -> &[u8] { + self.0.protocol_name() + } +} + +impl Display for SupportedProtocolTag { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +pub struct SupportedProtocolIdMap(HashMap); + +impl From> for SupportedProtocolIdMap { + fn from(h: HashMap) -> Self { + Self(h) + } +} + +impl From> for SupportedProtocolIdMap { + fn from(v: Vec<(SupportedProtocolId, T)>) -> Self { + Self(v.into_iter().map(|(id, t)| (id.get_inner(), t)).collect()) + } +} + +impl SupportedProtocolIdMap { + pub fn get(&self, id: ProtocolId) -> Option<&T> { + self.0.get(&id) + } + + pub fn get_mut(&mut self, id: ProtocolId) -> Option<(SupportedProtocolId, &mut T)> { + self.0.get_mut(&id).map(|t| (SupportedProtocolId(id), t)) + } + + pub fn get_supported(&self, id: SupportedProtocolId) -> &T { + self.0.get(&id.0).unwrap() + } + + pub fn get_supported_mut(&mut self, id: SupportedProtocolId) -> &mut T { + self.0.get_mut(&id.0).unwrap() + } + + pub fn iter(&self) -> impl Iterator { + self.0.iter().map(|(id, v)| (SupportedProtocolId(*id), v)) + } + + pub fn iter_mut(&mut self) -> impl Iterator { + self.0.iter_mut().map(|(id, v)| (SupportedProtocolId(*id), v)) + } + + pub fn values(&self) -> impl Iterator { + self.0.values() + } + + pub fn values_mut(&mut self) -> impl Iterator { + self.0.values_mut() + } +} diff --git a/spectrum-network/src/types.rs b/spectrum-network/src/types.rs index b43a9f84..01989925 100644 --- a/spectrum-network/src/types.rs +++ b/spectrum-network/src/types.rs @@ -117,7 +117,13 @@ impl ProtocolTag { impl From for ProtocolVer { fn from(p: ProtocolTag) -> Self { - ProtocolVer::from(p.0[1]) + ProtocolVer::from(p.0[2]) + } +} + +impl From for ProtocolId { + fn from(p: ProtocolTag) -> Self { + ProtocolId::from(p.0[1]) } } diff --git a/spectrum-network/tests/tests.rs b/spectrum-network/tests/tests.rs index cae94b87..cd049a53 100644 --- a/spectrum-network/tests/tests.rs +++ b/spectrum-network/tests/tests.rs @@ -27,6 +27,7 @@ use spectrum_network::protocol_api::ProtocolMailbox; use spectrum_network::protocol_handler::sync::message::SyncSpec; use spectrum_network::protocol_handler::sync::{NodeStatus, SyncBehaviour}; use spectrum_network::protocol_handler::ProtocolHandler; +use spectrum_network::protocol_upgrade::GetSupportedProtocolVer; use spectrum_network::types::Reputation; use std::collections::HashMap; use std::{ @@ -201,7 +202,7 @@ pub fn build_node( let (peer_manager, peers) = PeerManager::new(peer_state, peer_manager_conf); let sync_conf = ProtocolConfig { supported_versions: vec![( - SyncSpec::v1(), + SyncSpec::get_supported_ver(), ProtocolSpec { max_message_size: 100, approve_required: true, diff --git a/spectrum-node/src/main.rs b/spectrum-node/src/main.rs index d2c579d0..20b7984e 100644 --- a/spectrum-node/src/main.rs +++ b/spectrum-node/src/main.rs @@ -18,6 +18,7 @@ use spectrum_network::protocol::{ProtocolConfig, ProtocolSpec, SYNC_PROTOCOL_ID} use spectrum_network::protocol_handler::sync::message::{SyncMessage, SyncMessageV1, SyncSpec}; use spectrum_network::protocol_handler::sync::{NodeStatus, SyncBehaviour}; use spectrum_network::protocol_handler::ProtocolHandler; +use spectrum_network::protocol_upgrade::GetSupportedProtocolVer; use spectrum_network::types::Reputation; use std::time::Duration; @@ -76,7 +77,7 @@ async fn main() -> Result<(), Box> { let (peer_manager, peers) = PeerManager::new(peer_state, peer_manager_conf); let sync_conf = ProtocolConfig { supported_versions: vec![( - SyncSpec::v1(), + SyncSpec::get_supported_ver(), ProtocolSpec { max_message_size: 100, approve_required: true,