diff --git a/spectrum-network/src/network_controller.rs b/spectrum-network/src/network_controller.rs index c66c5175..7057cd34 100644 --- a/spectrum-network/src/network_controller.rs +++ b/spectrum-network/src/network_controller.rs @@ -114,7 +114,9 @@ pub struct NetworkController { peer_manager: TPeerManager, enabled_peers: HashMap>, requests_recv: UnboundedReceiver, - pending_actions: VecDeque>, + pending_actions: VecDeque< + NetworkBehaviourAction, PartialPeerConnHandler>, + >, } impl NetworkController @@ -158,7 +160,7 @@ where THandler: ProtocolEvents + Clone + 'static, { type ConnectionHandler = PartialPeerConnHandler; - type OutEvent = NetworkControllerOut; + type OutEvent = Result; fn new_handler(&mut self) -> Self::ConnectionHandler { trace!("New handler is created"); @@ -443,20 +445,25 @@ where ConnectedPeer::Connected { enabled_protocols, .. } => { - let (_, prot_handler) = self.supported_protocols.get(&protocol).unwrap(); - match enabled_protocols.entry(protocol) { - Entry::Occupied(_) => warn!( - "PM requested already enabled protocol {:?} with peer {:?}", - protocol, pid - ), - Entry::Vacant(protocol_entry) => { - protocol_entry.insert(( - EnabledProtocol::PendingEnable, - prot_handler.clone(), - )); - prot_handler.protocol_requested_local(pid); + if let Some((_, prot_handler)) = self.supported_protocols.get(&protocol) { + match enabled_protocols.entry(protocol) { + Entry::Occupied(_) => warn!( + "PM requested already enabled protocol {:?} with peer {:?}", + protocol, pid + ), + Entry::Vacant(protocol_entry) => { + protocol_entry.insert(( + EnabledProtocol::PendingEnable, + prot_handler.clone(), + )); + prot_handler.protocol_requested_local(pid); + } } - }; + } else { + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(Err( + NetworkControllerError::UnsupportedProtocol(protocol), + ))); + } } ConnectedPeer::PendingConnect | ConnectedPeer::PendingApprove(_) @@ -548,3 +555,9 @@ where } } } + +#[derive(Debug, thiserror::Error)] +pub enum NetworkControllerError { + #[error("Unsupported protocol: {0:?}")] + UnsupportedProtocol(ProtocolId), +}