Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ use crate::{
notification, request_response, UserProtocol,
},
transport::{
manager::limits::ConnectionLimitsConfig, tcp::config::Config as TcpConfig,
KEEP_ALIVE_TIMEOUT, MAX_PARALLEL_DIALS,
manager::{limits::ConnectionLimitsConfig, TransportHandle},
tcp::config::Config as TcpConfig,
Transport, TransportEvent, KEEP_ALIVE_TIMEOUT, MAX_PARALLEL_DIALS,
},
types::protocol::ProtocolName,
PeerId,
Expand All @@ -43,6 +44,7 @@ use crate::transport::webrtc::config::Config as WebRtcConfig;
#[cfg(feature = "websocket")]
use crate::transport::websocket::config::Config as WebSocketConfig;

use hickory_resolver::TokioResolver;
use multiaddr::Multiaddr;

use std::{collections::HashMap, sync::Arc, time::Duration};
Expand Down Expand Up @@ -83,6 +85,15 @@ pub struct ConfigBuilder {
#[cfg(feature = "websocket")]
websocket: Option<WebSocketConfig>,

/// List of custom transports.
custom_transports: Vec<(
&'static str,
fn(
TransportHandle,
Arc<TokioResolver>,
) -> crate::Result<(Box<dyn Transport<Item = TransportEvent>>, Vec<Multiaddr>)>,
)>,

/// Keypair.
keypair: Option<Keypair>,

Expand Down Expand Up @@ -146,6 +157,7 @@ impl ConfigBuilder {
webrtc: None,
#[cfg(feature = "websocket")]
websocket: None,
custom_transports: Vec::new(),
keypair: None,
ping: None,
identify: None,
Expand Down Expand Up @@ -191,6 +203,20 @@ impl ConfigBuilder {
self
}

/// Add a custom transport configuration, enabling the transport.
pub fn with_custom_transport(
mut self,
name: &'static str,
transport: fn(
TransportHandle,
Arc<TokioResolver>,
)
-> crate::Result<(Box<dyn Transport<Item = TransportEvent>>, Vec<Multiaddr>)>,
) -> Self {
self.custom_transports.push((name, transport));
self
}

/// Add keypair.
///
/// If no keypair is specified, litep2p creates a new keypair.
Expand Down Expand Up @@ -305,6 +331,7 @@ impl ConfigBuilder {
webrtc: self.webrtc.take(),
#[cfg(feature = "websocket")]
websocket: self.websocket.take(),
custom_transports: self.custom_transports,
ping: self.ping.take(),
identify: self.identify.take(),
kademlia: self.kademlia.take(),
Expand Down Expand Up @@ -339,6 +366,15 @@ pub struct Litep2pConfig {
#[cfg(feature = "websocket")]
pub(crate) websocket: Option<WebSocketConfig>,

/// Custom transports.
pub(crate) custom_transports: Vec<(
&'static str,
fn(
TransportHandle,
Arc<TokioResolver>,
) -> crate::Result<(Box<dyn Transport<Item = TransportEvent>>, Vec<Multiaddr>)>,
)>,

/// Keypair.
pub(crate) keypair: Keypair,

Expand Down
17 changes: 16 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ impl Litep2p {
if let Some(config) = litep2p_config.websocket.take() {
let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor));
let (transport, transport_listen_addresses) =
<WebSocketTransport as TransportBuilder>::new(handle, config, resolver)?;
<WebSocketTransport as TransportBuilder>::new(handle, config, resolver.clone())?;

for address in transport_listen_addresses {
transport_manager.register_listen_address(address.clone());
Expand All @@ -383,6 +383,21 @@ impl Litep2p {
.register_transport(SupportedTransport::WebSocket, Box::new(transport));
}

// enable custom transports
for (name, transport_factory) in litep2p_config.custom_transports {
let handle = transport_manager.transport_handle(Arc::clone(&litep2p_config.executor));

let (transport, transport_listen_addresses) =
transport_factory(handle, resolver.clone())?;

for address in transport_listen_addresses {
transport_manager.register_listen_address(address.clone());
listen_addresses.push(address.with(Protocol::P2p(*local_peer_id.as_ref())));
}

transport_manager.register_transport(SupportedTransport::Custom(name), transport);
}

// enable mdns if the config exists
if let Some(config) = litep2p_config.mdns.take() {
let mdns = Mdns::new(transport_handle, config, listen_addresses.clone());
Expand Down
102 changes: 34 additions & 68 deletions src/transport/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,18 @@ impl TransportContext {
) {
assert!(self.transports.insert(name, transport).is_none());
}

/// Iterate through all transports
pub fn iter_mut(
&mut self,
) -> impl Iterator<
Item = (
&SupportedTransport,
&mut (dyn Transport<Item = TransportEvent> + 'static),
),
> {
self.transports.iter_mut().map(|(a, b)| (a, &mut **b))
}
}

impl Stream for TransportContext {
Expand Down Expand Up @@ -615,64 +627,6 @@ impl TransportManager {

tracing::debug!(target: LOG_TARGET, address = ?address_record.address(), "dial address");

let mut protocol_stack = address_record.as_ref().iter();
match protocol_stack
.next()
.ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))?
{
Protocol::Ip4(_) | Protocol::Ip6(_) => {}
Protocol::Dns(_) | Protocol::Dns4(_) | Protocol::Dns6(_) => {}
transport => {
tracing::error!(
target: LOG_TARGET,
?transport,
"invalid transport, expected `ip4`/`ip6`"
);
return Err(Error::TransportNotSupported(
address_record.address().clone(),
));
}
};

let supported_transport = match protocol_stack
.next()
.ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))?
{
Protocol::Tcp(_) => match protocol_stack.next() {
#[cfg(feature = "websocket")]
Some(Protocol::Ws(_)) | Some(Protocol::Wss(_)) => SupportedTransport::WebSocket,
Some(Protocol::P2p(_)) => SupportedTransport::Tcp,
_ =>
return Err(Error::TransportNotSupported(
address_record.address().clone(),
)),
},
#[cfg(feature = "quic")]
Protocol::Udp(_) => match protocol_stack
.next()
.ok_or_else(|| Error::TransportNotSupported(address_record.address().clone()))?
{
Protocol::QuicV1 => SupportedTransport::Quic,
_ => {
tracing::debug!(target: LOG_TARGET, address = ?address_record.address(), "expected `quic-v1`");
return Err(Error::TransportNotSupported(
address_record.address().clone(),
));
}
},
protocol => {
tracing::error!(
target: LOG_TARGET,
?protocol,
"invalid protocol"
);

return Err(Error::TransportNotSupported(
address_record.address().clone(),
));
}
};

// when constructing `AddressRecord`, `PeerId` was verified to be part of the address
let remote_peer_id =
PeerId::try_from_multiaddr(address_record.address()).expect("`PeerId` to exist");
Expand All @@ -699,12 +653,27 @@ impl TransportManager {
};
}

self.transports
.get_mut(&supported_transport)
.ok_or(Error::TransportNotSupported(
let mut dailed = false;

for (_, transport) in self.transports.iter_mut() {
if let Err(err) = transport.dial(connection_id, address_record.address().clone()) {
if let Error::AddressError(AddressError::InvalidProtocol) = err {
continue;
}

return Err(err);
}

dailed = true;
break;
}

if !dailed {
return Err(Error::TransportNotSupported(
address_record.address().clone(),
))?
.dial(connection_id, address_record.address().clone())?;
));
}

self.pending_connections.insert(connection_id, remote_peer_id);

Ok(())
Expand Down Expand Up @@ -1687,7 +1656,7 @@ mod tests {
}

#[tokio::test]
async fn try_to_dial_over_disabled_transport() {
async fn try_to_dial_over_custom_transport() {
let mut manager = TransportManagerBuilder::new().build();
let _handle = manager.transport_handle(Arc::new(DefaultExecutor {}));
manager.register_transport(SupportedTransport::Tcp, Box::new(DummyTransport::new()));
Expand All @@ -1700,10 +1669,7 @@ mod tests {
Multihash::from_bytes(&PeerId::random().to_bytes()).unwrap(),
));

assert!(std::matches!(
manager.dial_address(address).await,
Err(Error::TransportNotSupported(_))
));
assert!(std::matches!(manager.dial_address(address).await, Ok(())));
}

#[tokio::test]
Expand Down
3 changes: 3 additions & 0 deletions src/transport/manager/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ pub enum SupportedTransport {
/// WebSocket
#[cfg(feature = "websocket")]
WebSocket,

/// Custom transport
Custom(&'static str),
}

/// Peer context.
Expand Down
13 changes: 8 additions & 5 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

//! Transport protocol implementations provided by [`Litep2p`](`crate::Litep2p`).

use crate::{error::DialError, transport::manager::TransportHandle, types::ConnectionId, PeerId};
use crate::{error::DialError, types::ConnectionId, PeerId};

use futures::Stream;
use hickory_resolver::TokioResolver;
Expand All @@ -42,7 +42,10 @@ pub(crate) mod dummy;

pub(crate) mod manager;

pub use manager::limits::{ConnectionLimitsConfig, ConnectionLimitsError};
pub use manager::{
limits::{ConnectionLimitsConfig, ConnectionLimitsError},
TransportHandle,
};

/// Timeout for opening a connection.
pub(crate) const CONNECTION_OPEN_TIMEOUT: Duration = Duration::from_secs(10);
Expand Down Expand Up @@ -119,7 +122,7 @@ impl Endpoint {

/// Transport event.
#[derive(Debug)]
pub(crate) enum TransportEvent {
pub enum TransportEvent {
/// Fully negotiated connection established to remote peer.
ConnectionEstablished {
/// Peer ID.
Expand Down Expand Up @@ -175,7 +178,7 @@ pub(crate) enum TransportEvent {
},
}

pub(crate) trait TransportBuilder {
pub trait TransportBuilder {
type Config: Debug;
type Transport: Transport;

Expand All @@ -189,7 +192,7 @@ pub(crate) trait TransportBuilder {
Self: Sized;
}

pub(crate) trait Transport: Stream + Unpin + Send {
pub trait Transport: Stream + Unpin + Send {
/// Dial `address` and negotiate connection.
fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()>;

Expand Down
2 changes: 1 addition & 1 deletion src/transport/tcp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ enum RawConnectionResult {
}

/// TCP transport.
pub(crate) struct TcpTransport {
pub struct TcpTransport {
/// Transport context.
context: TransportHandle,

Expand Down
22 changes: 21 additions & 1 deletion tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use litep2p::{config::ConfigBuilder, transport::tcp::config::Config as TcpConfig};
use std::sync::Arc;

use hickory_resolver::TokioResolver;
use libp2p::Multiaddr;
use litep2p::{
config::ConfigBuilder,
transport::{tcp::config::Config as TcpConfig, TransportEvent, TransportHandle},
};

#[cfg(feature = "quic")]
use litep2p::transport::quic::config::Config as QuicConfig;
Expand All @@ -31,6 +38,18 @@ pub(crate) enum Transport {
Quic(QuicConfig),
#[cfg(feature = "websocket")]
WebSocket(WebSocketConfig),
Custom(
(
&'static str,
fn(
TransportHandle,
Arc<TokioResolver>,
) -> litep2p::Result<(
Box<dyn litep2p::transport::Transport<Item = TransportEvent>>,
Vec<Multiaddr>,
)>,
),
),
}

pub(crate) fn add_transport(config: ConfigBuilder, transport: Transport) -> ConfigBuilder {
Expand All @@ -40,5 +59,6 @@ pub(crate) fn add_transport(config: ConfigBuilder, transport: Transport) -> Conf
Transport::Quic(transport) => config.with_quic(transport),
#[cfg(feature = "websocket")]
Transport::WebSocket(transport) => config.with_websocket(transport),
Transport::Custom((name, transport)) => config.with_custom_transport(name, transport),
}
}
Loading