diff --git a/lib/protoflow-zeromq/Cargo.toml b/lib/protoflow-zeromq/Cargo.toml index 00fde5ab..7c31ba0c 100644 --- a/lib/protoflow-zeromq/Cargo.toml +++ b/lib/protoflow-zeromq/Cargo.toml @@ -17,16 +17,25 @@ publish.workspace = true [features] default = ["all", "std"] all = ["tracing"] -std = ["protoflow-core/std", "tracing?/std"] #, "zeromq/default"] +std = ["protoflow-core/std", "tracing?/std"] tracing = ["protoflow-core/tracing", "dep:tracing"] unstable = ["protoflow-core/unstable"] [build-dependencies] cfg_aliases.workspace = true +prost-build = "0.13.2" [dependencies] protoflow-core.workspace = true tracing = { version = "0.1", default-features = false, optional = true } -#zeromq = { version = "0.4", default-features = false } +zeromq = { version = "0.4.1", default-features = false, features = [ + "tokio-runtime", + "all-transport", +] } +tokio = { version = "1.40.0", default-features = false } +prost = "0.13.2" +prost-types = "0.13.2" [dev-dependencies] +futures-util = "0.3.31" +tracing-subscriber = "0.3.19" diff --git a/lib/protoflow-zeromq/build.rs b/lib/protoflow-zeromq/build.rs new file mode 100644 index 00000000..71f34fd0 --- /dev/null +++ b/lib/protoflow-zeromq/build.rs @@ -0,0 +1,6 @@ +use std::io::Result; +fn main() -> Result<()> { + prost_build::Config::default() + .out_dir("src/") + .compile_protos(&["proto/transport_event.proto"], &["proto/"]) +} diff --git a/lib/protoflow-zeromq/proto/transport_event.proto b/lib/protoflow-zeromq/proto/transport_event.proto new file mode 100644 index 00000000..8a31eda4 --- /dev/null +++ b/lib/protoflow-zeromq/proto/transport_event.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package protoflow.zmq; + +message Connect { + int64 output = 1; + int64 input = 2; +} + +message AckConnection { + int64 output = 1; + int64 input = 2; +} + +message Message { + int64 output = 1; + int64 input = 2; + uint64 sequence = 3; + bytes message = 4; +} + +message AckMessage { + int64 output = 1; + int64 input = 2; + uint64 sequence = 3; +} + +message CloseOutput { + int64 output = 1; + int64 input = 2; +} + +message CloseInput { + int64 input = 1; +} + +message Event { + oneof payload { + Connect connect = 1; + AckConnection ack_connection = 2; + Message message = 3; + AckMessage ack_message = 4; + CloseOutput close_output = 5; + CloseInput close_input = 6; + } +} diff --git a/lib/protoflow-zeromq/src/event.rs b/lib/protoflow-zeromq/src/event.rs new file mode 100644 index 00000000..f5f60c19 --- /dev/null +++ b/lib/protoflow-zeromq/src/event.rs @@ -0,0 +1,160 @@ +// This is free and unencumbered software released into the public domain. + +use protoflow_core::{ + prelude::{Bytes, Vec}, + InputPortID, OutputPortID, +}; +use zeromq::ZmqMessage; + +pub type SequenceID = u64; + +/// ZmqTransportEvent represents the data that goes over the wire from one port to another. +#[derive(Clone, Debug, PartialEq)] +pub enum ZmqTransportEvent { + Connect(OutputPortID, InputPortID), + AckConnection(OutputPortID, InputPortID), + Message(OutputPortID, InputPortID, SequenceID, Bytes), + AckMessage(OutputPortID, InputPortID, SequenceID), + CloseOutput(OutputPortID, InputPortID), + CloseInput(InputPortID), +} + +impl ZmqTransportEvent { + fn write_topic(&self, f: &mut W) -> Result<(), std::io::Error> { + use ZmqTransportEvent::*; + match self { + Connect(o, i) => write!(f, "{}:conn:{}", i, o), + AckConnection(o, i) => write!(f, "{}:ackConn:{}", i, o), + Message(o, i, seq, _) => write!(f, "{}:msg:{}:{}", i, o, seq), + AckMessage(o, i, seq) => write!(f, "{}:ackMsg:{}:{}", i, o, seq), + CloseOutput(o, i) => write!(f, "{}:closeOut:{}", i, o), + CloseInput(i) => write!(f, "{}:closeIn", i), + } + } +} + +impl From for ZmqMessage { + fn from(value: ZmqTransportEvent) -> Self { + let mut topic = Vec::new(); + value.write_topic(&mut topic).unwrap(); + + // first frame of the message is the topic + let mut msg = ZmqMessage::from(topic); + + fn map_id(id: T) -> i64 + where + isize: From, + { + isize::from(id) as i64 + } + + // second frame of the message is the payload + use crate::protoflow_zmq::{self, event::Payload, Event}; + use prost::Message; + use ZmqTransportEvent::*; + let payload = match value { + Connect(output, input) => Payload::Connect(protoflow_zmq::Connect { + output: map_id(output), + input: map_id(input), + }), + AckConnection(output, input) => Payload::AckConnection(protoflow_zmq::AckConnection { + output: map_id(output), + input: map_id(input), + }), + Message(output, input, sequence, message) => Payload::Message(protoflow_zmq::Message { + output: map_id(output), + input: map_id(input), + sequence, + message: message.to_vec(), + }), + AckMessage(output, input, sequence) => Payload::AckMessage(protoflow_zmq::AckMessage { + output: map_id(output), + input: map_id(input), + sequence, + }), + CloseOutput(output, input) => Payload::CloseOutput(protoflow_zmq::CloseOutput { + output: map_id(output), + input: map_id(input), + }), + CloseInput(input) => Payload::CloseInput(protoflow_zmq::CloseInput { + input: map_id(input), + }), + }; + + let bytes = Event { + payload: Some(payload), + } + .encode_to_vec(); + msg.push_back(bytes.into()); + + msg + } +} + +impl TryFrom for ZmqTransportEvent { + type Error = protoflow_core::DecodeError; + + fn try_from(value: ZmqMessage) -> Result { + use crate::protoflow_zmq::{self, event::Payload, Event}; + use prost::Message; + use protoflow_core::DecodeError; + + fn map_id(id: i64) -> Result + where + T: TryFrom, + std::borrow::Cow<'static, str>: From<>::Error>, + { + (id as isize).try_into().map_err(DecodeError::new) + } + + value + .get(1) + .ok_or_else(|| { + protoflow_core::DecodeError::new("message contains less than two frames") + }) + .and_then(|bytes| { + let event = Event::decode(bytes.as_ref())?; + + use ZmqTransportEvent::*; + Ok(match event.payload { + None => { + return Err(protoflow_core::DecodeError::new("message payload is empty")) + } + Some(Payload::Connect(protoflow_zmq::Connect { output, input })) => { + Connect(map_id(output)?, map_id(input)?) + } + + Some(Payload::AckConnection(protoflow_zmq::AckConnection { + output, + input, + })) => AckConnection(map_id(output)?, map_id(input)?), + + Some(Payload::Message(protoflow_zmq::Message { + output, + input, + sequence, + message, + })) => Message( + map_id(output)?, + map_id(input)?, + sequence, + Bytes::from(message), + ), + + Some(Payload::AckMessage(protoflow_zmq::AckMessage { + output, + input, + sequence, + })) => AckMessage(map_id(output)?, map_id(input)?, sequence), + + Some(Payload::CloseOutput(protoflow_zmq::CloseOutput { output, input })) => { + CloseOutput(map_id(output)?, map_id(input)?) + } + + Some(Payload::CloseInput(protoflow_zmq::CloseInput { input })) => { + CloseInput(map_id(input)?) + } + }) + }) + } +} diff --git a/lib/protoflow-zeromq/src/input_port.rs b/lib/protoflow-zeromq/src/input_port.rs new file mode 100644 index 00000000..15074e7b --- /dev/null +++ b/lib/protoflow-zeromq/src/input_port.rs @@ -0,0 +1,619 @@ +// This is free and unencumbered software released into the public domain. + +use crate::{ + subscribe_topics, unsubscribe_topics, SequenceID, ZmqSubscriptionRequest, ZmqTransport, + ZmqTransportEvent, +}; +use protoflow_core::{ + prelude::{fmt, format, vec, Arc, BTreeMap, Bytes, String, ToString, Vec}, + InputPortID, OutputPortID, PortError, PortState, +}; +use tokio::sync::{ + mpsc::{channel, Receiver, Sender}, + Mutex, RwLock, +}; + +#[cfg(feature = "tracing")] +use tracing::{error, info, trace, trace_span, warn}; + +#[derive(Clone, Debug)] +pub enum ZmqInputPortRequest { + Close, +} + +/// ZmqInputPortEvent represents events that we receive from the background worker of the port. +#[derive(Clone, Debug, PartialEq)] +pub enum ZmqInputPortEvent { + Message(Bytes), + Closed, +} + +#[derive(Clone, Debug)] +pub enum ZmqInputPortState { + Open( + // channel for close requests from the public `close` method + Sender<(ZmqInputPortRequest, Sender>)>, + // channel used internally for events from socket + Sender, + ), + Connected( + // channel for requests from public close + Sender<(ZmqInputPortRequest, Sender>)>, + // channels to send-to and receive-from the public `recv` method + Sender, + Arc>>, + // channel used internally for events from socket + Sender, + // vec of the connected port ids + BTreeMap, + ), + Closed, +} + +impl fmt::Display for ZmqInputPortState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use ZmqInputPortState::*; + match *self { + Open(..) => write!(f, "Open"), + Connected(.., ref ids) => { + write!( + f, + "Connected({:?})", + ids.keys().map(|id| isize::from(*id)).collect::>() + ) + } + Closed => write!(f, "Closed"), + } + } +} + +impl ZmqInputPortState { + pub fn state(&self) -> PortState { + use ZmqInputPortState::*; + match self { + Open(..) => PortState::Open, + Connected(..) => PortState::Connected, + Closed => PortState::Closed, + } + } + + pub async fn event_sender(&self) -> Option> { + use ZmqInputPortState::*; + match self { + Open(_, sender) | Connected(.., sender, _) => Some(sender.clone()), + Closed => None, + } + } +} + +fn input_topics(id: InputPortID) -> Vec { + vec![ + format!("{}:conn", id), + format!("{}:msg", id), + format!("{}:closeOut", id), + ] +} + +pub async fn input_port_event_sender( + inputs: &RwLock>>, + id: InputPortID, +) -> Option> { + inputs + .read() + .await + .get(&id)? + .read() + .await + .event_sender() + .await +} + +pub fn start_input_worker( + transport: &ZmqTransport, + input_port_id: InputPortID, +) -> Result<(), PortError> { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::input_port_worker", ?input_port_id); + + let (to_worker_send, mut to_worker_recv) = channel(1); + let (req_send, mut req_recv) = channel(1); + + { + let mut inputs = transport.tokio.block_on(transport.inputs.write()); + if inputs.contains_key(&input_port_id) { + return Err(PortError::Invalid(input_port_id.into())); + } + let state = ZmqInputPortState::Open(req_send.clone(), to_worker_send.clone()); + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("saving new state: {}", state)); + inputs.insert(input_port_id, RwLock::new(state)); + } + + let sub_queue = transport.sub_queue.clone(); + let pub_queue = transport.pub_queue.clone(); + let inputs = transport.inputs.clone(); + + let topics = input_topics(input_port_id); + if transport + .tokio + .block_on(subscribe_topics(&topics, &sub_queue)) + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("topic subscription failed")); + return Err(PortError::Other("topic subscription failed".to_string())); + } + + async fn handle_socket_event( + event: ZmqTransportEvent, + inputs: &RwLock>>, + pub_queue: &Sender, + input_port_id: InputPortID, + ) { + #[cfg(feature = "tracing")] + let span = trace_span!( + "ZmqTransport::input_port_worker::handle_socket_event", + ?input_port_id + ); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?event, "got socket event")); + + use ZmqTransportEvent::*; + match event { + Connect(output_port_id, target_id) => { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "Connect", ?output_port_id); + + debug_assert_eq!(input_port_id, target_id); + + let inputs = inputs.read().await; + let Some(input_state) = inputs.get(&input_port_id) else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + return; + }; + let mut input_state = input_state.write().await; + + use ZmqInputPortState::*; + match &*input_state { + Open(..) => (), + Connected(.., connected_ids) => { + if connected_ids.contains_key(&output_port_id) { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("output port is already connected")); + return; + } + } + Closed => return, + }; + + let add_connection = |input_state: &mut ZmqInputPortState| match input_state { + Open(req_send, to_worker_send) => { + let (msgs_send, msgs_recv) = channel(1); + let msgs_recv = Arc::new(Mutex::new(msgs_recv)); + let mut connected_ids = BTreeMap::new(); + connected_ids.insert(output_port_id, 0); + *input_state = Connected( + req_send.clone(), + msgs_send, + msgs_recv, + to_worker_send.clone(), + connected_ids, + ); + } + Connected(.., ids) => { + ids.insert(output_port_id, 0); + } + Closed => unreachable!(), + }; + + if pub_queue + .send(ZmqTransportEvent::AckConnection( + output_port_id, + input_port_id, + )) + .await + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("publish channel is closed")); + return; + } + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("sent conn-ack")); + + add_connection(&mut input_state); + + #[cfg(feature = "tracing")] + span.in_scope(|| info!("Connected new port: {}", input_state)); + } + Message(output_port_id, target_id, msg_seq_id, bytes) => { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "Message", ?output_port_id, ?msg_seq_id); + + debug_assert_eq!(input_port_id, target_id); + + let inputs = inputs.read().await; + let Some(input_state) = inputs.get(&input_port_id) else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + return; + }; + let mut input_state = input_state.write().await; + + use ZmqInputPortState::*; + match *input_state { + Connected(_, ref sender, _, _, ref mut connected_ids) => { + let Some(&last_seen_seq_id) = connected_ids.get(&output_port_id) else { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("got message from non-connected output port")); + return; + }; + + let send_ack = { + #[cfg(feature = "tracing")] + let span = span.clone(); + + |ack_id| async move { + if pub_queue + .send(ZmqTransportEvent::AckMessage( + output_port_id, + input_port_id, + ack_id, + )) + .await + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("publish channel is closed")); + } + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?ack_id, "sent msg-ack")); + } + }; + + use std::cmp::Ordering::*; + match msg_seq_id.cmp(&last_seen_seq_id) { + // seq_id for msg is greater than last seen seq_id by one + Greater if (msg_seq_id - last_seen_seq_id == 1) => { + if sender + .send(ZmqInputPortEvent::Message(bytes)) + .await + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("receiver for input events has closed")); + return; + } + send_ack(msg_seq_id).await; + let _ = connected_ids.insert(output_port_id, msg_seq_id); + } + Equal => { + send_ack(last_seen_seq_id).await; + } + // either the seq_id is greater than the last seen seq_id by more than + // one, or somehow less than the last seen seq_id: + _ => { + send_ack(last_seen_seq_id).await; + } + } + } + + Open(..) | Closed => { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("port is not connected: {}", input_state)); + } + } + } + CloseOutput(output_port_id, target_id) => { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "CloseOutput", ?output_port_id); + + debug_assert_eq!(input_port_id, target_id); + + let inputs = inputs.read().await; + let Some(input_state) = inputs.get(&input_port_id) else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + return; + }; + + let mut input_state = input_state.write().await; + + use ZmqInputPortState::*; + let Connected(ref req_send, ref sender, _, ref event_sender, ref mut connected_ids) = + *input_state + else { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("input port wasn't connected")); + return; + }; + + if connected_ids.remove(&output_port_id).is_none() { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("output port doesn't match any connected port")); + return; + } + + if !connected_ids.is_empty() { + return; + } + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("last connected port disconnected")); + + if let Err(err) = sender.try_send(ZmqInputPortEvent::Closed) { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("did not send InputPortEvent::Closed: {}", err)); + } + + // TODO: Should last connection closing close the input port too? + // It does in the MPSC transport. + //*input_state = ZmqInputPortState::Closed; + + *input_state = Open(req_send.clone(), event_sender.clone()) + } + + // ignore, ideally we never receive these here: + AckConnection(..) | AckMessage(..) | CloseInput(_) => (), + } + } + + async fn handle_input_request( + request: ZmqInputPortRequest, + response_chan: Sender>, + inputs: &RwLock>>, + pub_queue: &Sender, + sub_queue: &Sender, + input_port_id: InputPortID, + ) { + #[cfg(feature = "tracing")] + let span = trace_span!( + "ZmqTransport::input_port_worker::handle_input_event", + ?input_port_id + ); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?request, "got input request")); + + use ZmqInputPortRequest::*; + match request { + Close => { + let inputs = inputs.read().await; + let Some(input_state) = inputs.get(&input_port_id) else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + return; + }; + let mut input_state = input_state.write().await; + + use ZmqInputPortState::*; + + if let Closed = *input_state { + return; + } + + if let Connected(_, ref port_events, ..) = *input_state { + if let Err(err) = port_events.try_send(ZmqInputPortEvent::Closed) { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("did not send InputPortEvent::Closed: {}", err)); + } + }; + + if pub_queue + .send(ZmqTransportEvent::CloseInput(input_port_id)) + .await + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("can't publish CloseInput event")); + // don't exit, continue to close the port + } + + *input_state = ZmqInputPortState::Closed; + + let topics = input_topics(input_port_id); + if unsubscribe_topics(&topics, sub_queue).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("topic unsubscription failed")); + } + + if response_chan.send(Ok(())).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("response channel is closed")); + } + } + } + } + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("spawning")); + + tokio::task::spawn(async move { + // Input worker loop: + // 1. Receive connection attempts and respond + // 2. Receive messages and forward to channel + // 3. Receive and handle disconnects + loop { + tokio::select! { + Some(event) = to_worker_recv.recv() => { + handle_socket_event(event, &inputs, &pub_queue, input_port_id).await; + } + Some((request, response_chan)) = req_recv.recv() => { + handle_input_request(request, response_chan, &inputs, &pub_queue, &sub_queue, input_port_id).await; + } + else => break, + }; + } + + #[cfg(feature = "tracing")] + { + let state = match inputs.read().await.get(&input_port_id) { + Some(input) => Some(input.read().await.clone()), + None => None, + }; + span.in_scope(|| { + trace!( + events_closed = to_worker_recv.is_closed(), + requests_closed = req_recv.is_closed(), + ?state, + "exited input worker loop" + ) + }); + } + }); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::time::Duration; + + #[test] + fn redelivery_is_idempotent() { + let _ = tracing_subscriber::fmt::try_init(); + + let rt = tokio::runtime::Runtime::new().unwrap(); + let _guard = rt.enter(); + + let (pub_queue, mut pub_queue_recv) = channel(1); + let (sub_queue, mut sub_queue_recv) = channel(1); + + let inputs = Arc::new(RwLock::new(BTreeMap::new())); + let outputs = Arc::new(RwLock::new(BTreeMap::new())); + + let output_id = OutputPortID::try_from(1).unwrap(); + let input_id = InputPortID::try_from(-1).unwrap(); + + let transport = ZmqTransport { + tokio: rt.handle().clone(), + pub_queue, + sub_queue, + inputs: inputs.clone(), + outputs: outputs.clone(), + }; + + // start a fake socket worker that just drops all messages + let sub_queue = tokio::task::spawn(async move { + while sub_queue_recv.recv().await.is_some() {} + Some(()) + }); + + start_input_worker(&transport, input_id).unwrap(); + + let (recv_send, recv_recv) = channel(1); + let recv_recv = Arc::new(Mutex::new(recv_recv)); + + // manually connect the port + let (req_sender, event_sender) = rt.block_on(async { + let inputs = inputs.read().await; + let mut input_state = inputs.get(&input_id).unwrap().write().await; + + let ZmqInputPortState::Open(ref req_sender, ref event_sender) = *input_state else { + panic!(""); + }; + let req_sender = req_sender.clone(); + let event_sender = event_sender.clone(); + + let mut connected_ids = BTreeMap::new(); + connected_ids.insert(output_id, 0); + + *input_state = ZmqInputPortState::Connected( + req_sender.clone(), + recv_send, + recv_recv.clone(), + event_sender.clone(), + connected_ids, + ); + + (req_sender.clone(), event_sender.clone()) + }); + + let timeout = Duration::from_secs(1); + + // send a message from the `output_id` to the worker + rt.block_on(tokio::time::timeout( + timeout, + event_sender.send(ZmqTransportEvent::Message( + output_id, + input_id, + 1, + Bytes::new(), + )), + )) + .unwrap() + .unwrap(); + + // verify that the worker tries to publish a msg-ack + assert_eq!( + Some(ZmqTransportEvent::AckMessage(output_id, input_id, 1)), + rt.block_on(pub_queue_recv.recv()) + ); + + // verify that the worker forwards a new message + assert_eq!( + Ok(Some(ZmqInputPortEvent::Message(Bytes::new()))), + rt.block_on(tokio::time::timeout(timeout, async { + recv_recv.lock().await.recv().await + })) + ); + + // send a new message with the same sequence id to the worker + rt.block_on(tokio::time::timeout( + timeout, + event_sender.send(ZmqTransportEvent::Message( + output_id, + input_id, + 1, + Bytes::new(), + )), + )) + .unwrap() + .unwrap(); + + // verify that the worker tries to publish a msg-ack + assert_eq!( + Ok(Some(ZmqTransportEvent::AckMessage(output_id, input_id, 1))), + rt.block_on(tokio::time::timeout(timeout, pub_queue_recv.recv())) + ); + + // verify that the worker *DOESN'T* forward the message + assert!(rt + .block_on(tokio::time::timeout(timeout, async { + recv_recv.lock().await.recv().await + })) + .is_err()); + + let (close_send, mut close_recv) = channel(1); + + // send a close request the worker + rt.block_on(tokio::time::timeout(timeout, async { + req_sender + .send((ZmqInputPortRequest::Close, close_send)) + .await + .unwrap(); + close_recv.recv().await.unwrap() + })) + .unwrap() + .unwrap(); + + // drop remaining references to the channels that the worker is waiting on + drop(event_sender); + drop(req_sender); + drop(transport); + + // verify that the fake socket worker also exits, implies that the worker has exited as the + // channel sender references must be dropped for the fake worker to exit. + assert_eq!( + Some(()), + rt.block_on(tokio::time::timeout(timeout, sub_queue)) + .unwrap() + .unwrap() + ); + } +} diff --git a/lib/protoflow-zeromq/src/lib.rs b/lib/protoflow-zeromq/src/lib.rs index 0f186b2c..2b579c4d 100644 --- a/lib/protoflow-zeromq/src/lib.rs +++ b/lib/protoflow-zeromq/src/lib.rs @@ -5,3 +5,512 @@ #[doc(hidden)] pub use protoflow_core::prelude; + +#[path = "protoflow.zmq.rs"] +mod protoflow_zmq; + +mod input_port; +use input_port::*; + +mod output_port; +use output_port::*; + +mod socket; +use socket::*; + +mod event; +use event::*; + +extern crate std; + +use protoflow_core::{ + prelude::{Arc, BTreeMap, Bytes, ToString}, + InputPortID, OutputPortID, PortError, PortResult, PortState, Transport, +}; + +use tokio::sync::{ + mpsc::{channel, error::TryRecvError, Sender}, + RwLock, +}; +use zeromq::{util::PeerIdentity, Socket, SocketOptions}; + +#[cfg(feature = "tracing")] +use tracing::trace; + +const DEFAULT_PUB_SOCKET: &str = "tcp://127.0.0.1:10000"; +const DEFAULT_SUB_SOCKET: &str = "tcp://127.0.0.1:10001"; + +pub struct ZmqTransport { + tokio: tokio::runtime::Handle, + + pub_queue: Sender, + sub_queue: Sender, + + outputs: Arc>>>, + inputs: Arc>>>, +} + +impl Default for ZmqTransport { + fn default() -> Self { + Self::new(DEFAULT_PUB_SOCKET, DEFAULT_SUB_SOCKET) + } +} + +impl ZmqTransport { + pub fn new(pub_url: &str, sub_url: &str) -> Self { + let tokio = tokio::runtime::Handle::current(); + + let peer_id = PeerIdentity::new(); + + let psock = { + let peer_id = peer_id.clone(); + let mut sock_opts = SocketOptions::default(); + sock_opts.peer_identity(peer_id); + + let mut psock = zeromq::PubSocket::with_options(sock_opts); + tokio + .block_on(psock.connect(pub_url)) + .expect("failed to connect PUB"); + psock + }; + + let ssock = { + let mut sock_opts = SocketOptions::default(); + sock_opts.peer_identity(peer_id); + + let mut ssock = zeromq::SubSocket::with_options(sock_opts); + tokio + .block_on(ssock.connect(sub_url)) + .expect("failed to connect SUB"); + ssock + }; + + let outputs = Arc::new(RwLock::new(BTreeMap::default())); + let inputs = Arc::new(RwLock::new(BTreeMap::default())); + + let (pub_queue, pub_queue_recv) = channel(1); + + let (sub_queue, sub_queue_recv) = channel(1); + + let transport = Self { + pub_queue, + sub_queue, + tokio, + outputs, + inputs, + }; + + start_pub_socket_worker(&transport, psock, pub_queue_recv); + start_sub_socket_worker(&transport, ssock, sub_queue_recv); + + transport + } +} + +impl Transport for ZmqTransport { + fn input_state(&self, input: InputPortID) -> PortResult { + self.tokio.block_on(async { + Ok(self + .inputs + .read() + .await + .get(&input) + .ok_or_else(|| PortError::Invalid(input.into()))? + .read() + .await + .state()) + }) + } + + fn output_state(&self, output: OutputPortID) -> PortResult { + self.tokio.block_on(async { + Ok(self + .outputs + .read() + .await + .get(&output) + .ok_or_else(|| PortError::Invalid(output.into()))? + .read() + .await + .state()) + }) + } + + fn open_input(&self) -> PortResult { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::open_input", "creating new input port"); + + let new_id = { + let inputs = self.tokio.block_on(self.inputs.read()); + InputPortID::try_from(-(inputs.len() as isize + 1)) + .map_err(|e| PortError::Other(e.to_string()))? + }; + + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::open_input", ?new_id, "created new input port"); + + start_input_worker(self, new_id).map(|_| new_id) + } + + fn open_output(&self) -> PortResult { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::open_output", "creating new output port"); + + let new_id = { + let outputs = self.tokio.block_on(self.outputs.read()); + OutputPortID::try_from(outputs.len() as isize + 1) + .map_err(|e| PortError::Other(e.to_string()))? + }; + + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::open_output", ?new_id, "created new output port"); + + start_output_worker(self, new_id).map(|_| new_id) + } + + fn close_input(&self, input: InputPortID) -> PortResult { + self.tokio.block_on(async { + let sender = { + let inputs = self.inputs.read().await; + let Some(input_state) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; + let input_state = input_state.read().await; + + use ZmqInputPortState::*; + match *input_state { + Open(ref sender, _) | Connected(ref sender, ..) => sender.clone(), + Closed => return Ok(false), // already closed + } + }; + + let (close_send, mut close_recv) = channel(1); + + sender + .send((ZmqInputPortRequest::Close, close_send)) + .await + .map_err(|e| PortError::Other(e.to_string()))?; + + close_recv + .recv() + .await + .ok_or(PortError::Disconnected)? + .map(|_| true) + }) + } + + fn close_output(&self, output: OutputPortID) -> PortResult { + self.tokio.block_on(async { + let mut close_recv = { + let outputs = self.outputs.read().await; + let Some(output_state) = outputs.get(&output) else { + return Err(PortError::Invalid(output.into())); + }; + + let output_state = output_state.read().await; + let (close_send, close_recv) = channel(1); + + use ZmqOutputPortState::*; + match *output_state { + Open(_, ref sender, _) => sender + .send(close_send) + .await + .map_err(|e| PortError::Other(e.to_string()))?, + Connected(ref sender, ..) => sender + .send((ZmqOutputPortRequest::Close, close_send)) + .await + .map_err(|e| PortError::Other(e.to_string()))?, + Closed => return Ok(false), // already closed + }; + + close_recv + }; + + close_recv + .recv() + .await + .ok_or(PortError::Disconnected)? + .map(|_| true) + }) + } + + fn connect(&self, source: OutputPortID, target: InputPortID) -> PortResult { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::connect", ?source, ?target, "connecting ports"); + + self.tokio.block_on(async { + let sender = { + let outputs = self.outputs.read().await; + let Some(output_state) = outputs.get(&source) else { + return Err(PortError::Invalid(source.into())); + }; + + let output_state = output_state.read().await; + let ZmqOutputPortState::Open(ref sender, _, _) = *output_state else { + return Err(PortError::Invalid(source.into())); + }; + + sender.clone() + }; + + let (confirm_send, mut confirm_recv) = channel(1); + + sender + .send((target, confirm_send)) + .await + .map_err(|e| PortError::Other(e.to_string()))?; + + confirm_recv + .recv() + .await + .ok_or(PortError::Disconnected)? + .map(|_| true) + }) + } + + fn send(&self, output: OutputPortID, message: Bytes) -> PortResult<()> { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::send", ?output, "sending from output port"); + + self.tokio.block_on(async { + let sender = { + let outputs = self.outputs.read().await; + let Some(output) = outputs.get(&output) else { + return Err(PortError::Invalid(output.into())); + }; + let output = output.read().await; + + let ZmqOutputPortState::Connected(sender, _, _) = &*output else { + return Err(PortError::Disconnected); + }; + + sender.clone() + }; + + let (ack_send, mut ack_recv) = channel(1); + + sender + .send((ZmqOutputPortRequest::Send(message), ack_send)) + .await + .map_err(|e| PortError::Other(e.to_string()))?; + + ack_recv.recv().await.ok_or(PortError::Disconnected)? + }) + } + + fn recv(&self, input: InputPortID) -> PortResult> { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::recv", ?input, "receiving from input port"); + + self.tokio.block_on(async { + let receiver = { + let inputs = self.inputs.read().await; + let Some(input_state) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; + + let input_state = input_state.read().await; + let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input_state else { + return Err(PortError::Disconnected); + }; + + receiver.clone() + }; + + let mut receiver = receiver.lock().await; + + use ZmqInputPortEvent::*; + match receiver.recv().await { + Some(Closed) => Ok(None), // EOS + Some(Message(bytes)) => Ok(Some(bytes)), + None => Err(PortError::Disconnected), + } + }) + } + + fn try_recv(&self, input: InputPortID) -> PortResult> { + #[cfg(feature = "tracing")] + trace!(target: "ZmqTransport::try_recv", ?input, "receiving from input port"); + + self.tokio.block_on(async { + let receiver = { + let inputs = self.inputs.read().await; + let Some(input_state) = inputs.get(&input) else { + return Err(PortError::Invalid(input.into())); + }; + + let input_state = input_state.read().await; + let ZmqInputPortState::Connected(_, _, receiver, _, _) = &*input_state else { + return Err(PortError::Disconnected); + }; + + receiver.clone() + }; + + let mut receiver = receiver.lock().await; + + use ZmqInputPortEvent::*; + match receiver.try_recv() { + Ok(Closed) => Ok(None), // EOS + Ok(Message(bytes)) => Ok(Some(bytes)), + Err(TryRecvError::Disconnected) => Err(PortError::Disconnected), + // TODO: what should we answer with here?: + Err(TryRecvError::Empty) => Err(PortError::RecvFailed), + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use protoflow_core::{runtimes::StdRuntime, System}; + use std::time::Duration; + + use futures_util::future::TryFutureExt; + use zeromq::{PubSocket, SocketRecv, SocketSend, SubSocket}; + + async fn start_zmqtransport_server() { + // retry for a second + for _ in 0..20 { + // bind a *SUB* socket to the *PUB* address so that the transport can *PUB* to it + let mut pub_srv = SubSocket::new(); + if pub_srv.bind(DEFAULT_PUB_SOCKET).await.is_err() { + tokio::time::sleep(Duration::from_millis(50)).await; + continue; + } + + // bind a *PUB* socket to the *SUB* address so that the transport can *SUB* to it + let mut sub_srv = PubSocket::new(); + if sub_srv.bind(DEFAULT_SUB_SOCKET).await.is_err() { + tokio::time::sleep(Duration::from_millis(50)).await; + continue; + } + + // subscribe to all messages + pub_srv.subscribe("").await.unwrap(); + + // resend anything received on the *SUB* socket to the *PUB* socket + tokio::task::spawn(async move { + let mut pub_srv = pub_srv; + loop { + pub_srv + .recv() + .and_then(|msg| sub_srv.send(msg)) + .await + .unwrap(); + } + }); + + return; + } + + panic!("unable to start server for tests, are the ports 10000 and 10001 available?"); + } + + #[test] + fn implementation_matches() { + let rt = tokio::runtime::Runtime::new().unwrap(); + let _guard = rt.enter(); + + rt.block_on(start_zmqtransport_server()); + + let _ = System::::build(|_s| { /* do nothing */ }); + } + + #[test] + fn run_transport() { + let _ = tracing_subscriber::fmt::try_init(); + + let rt = tokio::runtime::Runtime::new().unwrap(); + let _guard = rt.enter(); + + rt.block_on(start_zmqtransport_server()); + + let transport = ZmqTransport::default(); + let runtime = StdRuntime::new(transport).unwrap(); + let system = System::new(&runtime); + + let output = system.output(); + let input = system.input(); + + system.connect(&output, &input); + + let output = std::thread::spawn(move || { + let mut output = output; + output.send(&"Hello world!".to_string())?; + output.close() + }); + + let input = std::thread::spawn(move || { + let mut input = input; + + let msg = input.recv()?; + assert_eq!(Some("Hello world!".to_string()), msg); + + let msg = input.recv()?; + assert_eq!(None, msg); + + input.close() + }); + + output.join().expect("thread failed").unwrap(); + input.join().expect("thread failed").unwrap(); + } + + #[test] + fn multiple_outputs_to_one_input() { + let _ = tracing_subscriber::fmt::try_init(); + + let rt = tokio::runtime::Runtime::new().unwrap(); + let _guard = rt.enter(); + + rt.block_on(start_zmqtransport_server()); + + let transport = ZmqTransport::default(); + let runtime = StdRuntime::new(transport).unwrap(); + let system = System::new(&runtime); + + let mut output1 = system.output(); + let mut output2 = system.output(); + + let mut input = system.input(); + + assert!(system.connect(&output1, &input)); + assert!(system.connect(&output2, &input)); + + output1.send(&"Hello from one!".to_string()).unwrap(); + assert_eq!(Some("Hello from one!".to_string()), input.recv().unwrap()); + + output2.send(&"Hello from two!".to_string()).unwrap(); + assert_eq!(Some("Hello from two!".to_string()), input.recv().unwrap()); + + output1.send(&"Hello from one again!".to_string()).unwrap(); + assert_eq!( + Some("Hello from one again!".to_string()), + input.recv().unwrap() + ); + + assert!(input.close().unwrap()); + assert_eq!( + Err(PortError::Disconnected), + output1.send(&"Hello from one!".to_string()) + ); + assert_eq!( + Err(PortError::Disconnected), + output2.send(&"Hello from two!".to_string()) + ); + + assert_eq!(Err(PortError::Disconnected), input.try_recv()); + + assert!( + !output1.close().unwrap(), + "closing output should return Ok(false) because input was already closed" + ); + assert!( + !output2.close().unwrap(), + "closing output should return Ok(false) because input was already closed" + ); + } +} diff --git a/lib/protoflow-zeromq/src/output_port.rs b/lib/protoflow-zeromq/src/output_port.rs new file mode 100644 index 00000000..8d18f7d0 --- /dev/null +++ b/lib/protoflow-zeromq/src/output_port.rs @@ -0,0 +1,393 @@ +// This is free and unencumbered software released into the public domain. + +use crate::{subscribe_topics, unsubscribe_topics, ZmqTransport, ZmqTransportEvent}; +use protoflow_core::{ + prelude::{fmt, format, vec, BTreeMap, Bytes, String, ToString, Vec}, + InputPortID, OutputPortID, PortError, PortState, +}; +use tokio::sync::{ + mpsc::{channel, Sender}, + RwLock, +}; + +#[cfg(feature = "tracing")] +use tracing::{debug, error, info, trace, trace_span, warn}; + +#[derive(Clone, Debug)] +pub enum ZmqOutputPortRequest { + Close, + Send(Bytes), +} + +const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(200); +const DEFAULT_MAX_RETRIES: u64 = 10; + +#[derive(Clone, Debug)] +pub enum ZmqOutputPortState { + Open( + // channel for connection requests from public `connect` method + Sender<(InputPortID, Sender>)>, + // channel for close requests from the public `close` method + Sender>>, + // channel used internally for events from socket + Sender, + ), + Connected( + // channel for public `send` and `close` methods, contained channel is for the ack back + Sender<(ZmqOutputPortRequest, Sender>)>, + // channel used internally for events from socket + Sender, + // id of the connected input port + InputPortID, + ), + Closed, +} + +impl fmt::Display for ZmqOutputPortState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use ZmqOutputPortState::*; + match *self { + Open(..) => write!(f, "Open"), + Connected(.., ref id) => { + write!(f, "Connected({:?})", isize::from(*id),) + } + Closed => write!(f, "Closed"), + } + } +} + +impl ZmqOutputPortState { + pub fn state(&self) -> PortState { + use ZmqOutputPortState::*; + match self { + Open(..) => PortState::Open, + Connected(..) => PortState::Connected, + Closed => PortState::Closed, + } + } + + pub async fn event_sender(&self) -> Option> { + use ZmqOutputPortState::*; + match self { + Open(.., sender) | Connected(.., sender, _) => Some(sender.clone()), + Closed => None, + } + } +} + +fn output_topics(source: OutputPortID, target: InputPortID) -> Vec { + vec![ + format!("{}:ackConn:{}", target, source), + format!("{}:ackMsg:{}:", target, source), + format!("{}:closeIn", target), + ] +} + +pub async fn output_port_event_sender( + outputs: &RwLock>>, + id: OutputPortID, +) -> Option> { + outputs + .read() + .await + .get(&id)? + .read() + .await + .event_sender() + .await +} + +pub fn start_output_worker( + transport: &ZmqTransport, + output_port_id: OutputPortID, +) -> Result<(), PortError> { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::output_port_worker", ?output_port_id); + + let (conn_send, mut conn_recv) = channel(1); + let (close_send, mut close_recv) = channel(1); + let (to_worker_send, mut to_worker_recv) = channel(1); + + { + let mut outputs = transport.tokio.block_on(transport.outputs.write()); + if outputs.contains_key(&output_port_id) { + return Err(PortError::Invalid(output_port_id.into())); + } + let state = ZmqOutputPortState::Open(conn_send, close_send, to_worker_send.clone()); + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("saving new state: {}", state)); + outputs.insert(output_port_id, RwLock::new(state)); + } + + let sub_queue = transport.sub_queue.clone(); + let pub_queue = transport.pub_queue.clone(); + let outputs = transport.outputs.clone(); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("spawning")); + + tokio::task::spawn(async move { + let (input_port_id, conn_confirm) = tokio::select! { + Some((input_port_id, conn_confirm)) = conn_recv.recv() => (input_port_id, conn_confirm), + Some(close_confirm) = close_recv.recv() => { + let response = { + if let Some(output_state) = outputs.read().await.get(&output_port_id) { + let mut output_state = output_state.write().await; + debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); + *output_state = ZmqOutputPortState::Closed; + Ok(()) + } else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + Err(PortError::Invalid(output_port_id.into())) + } + }; + + let _ = close_confirm.try_send(response); + return; + } + else => { + // all senders have dropped, i.e. there's no connection request coming + + if let Some(output_state) = outputs.read().await.get(&output_port_id) { + let mut output_state = output_state.write().await; + debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); + *output_state = ZmqOutputPortState::Closed; + } + + #[cfg(feature = "tracing")] + debug!(parent: &span, "no connection or close request"); + return; + } + }; + + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "task", ?input_port_id); + + let topics = output_topics(output_port_id, input_port_id); + if subscribe_topics(&topics, &sub_queue).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("topic subscription failed")); + return; + } + + let (msg_req_send, mut msg_req_recv) = channel(1); + + // Output worker loop: + // 1. Send connection attempt + // 2. Send messages + // 2.1 Wait for ACK + // 2.2. Resend on timeout + // 3. Send disconnect events + + loop { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "connect_loop"); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("sending connection attempt...")); + + if pub_queue + .send(ZmqTransportEvent::Connect(output_port_id, input_port_id)) + .await + .is_err() + { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("publish channel is closed")); + return; + } + + let Some(response) = to_worker_recv.recv().await else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("all senders to worker have dropped?")); + return; + }; + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?response, "got response")); + + use ZmqTransportEvent::*; + match response { + AckConnection(_, input_port_id) => { + let response = match outputs.read().await.get(&output_port_id) { + None => { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + Err(PortError::Invalid(output_port_id.into())) + } + Some(output_state) => { + let mut output_state = output_state.write().await; + debug_assert!(matches!(*output_state, ZmqOutputPortState::Open(..))); + *output_state = ZmqOutputPortState::Connected( + msg_req_send, + to_worker_send, + input_port_id, + ); + + #[cfg(feature = "tracing")] + span.in_scope(|| info!("Connected!")); + + Ok(()) + } + }; + + if conn_confirm.send(response).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("connection confirmation channel is closed")); + // don't exit, proceed to send loop + } + drop(conn_confirm); + + break; + } + _ => continue, + } + } + + let mut seq_id = 1; + 'send: while let Some((request, response_chan)) = msg_req_recv.recv().await { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "send_loop", ?seq_id); + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?request, "sending request")); + + let respond = |response| async { + if response_chan.send(response).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("response channel is closed")); + } + }; + + match request { + ZmqOutputPortRequest::Close => { + let response = pub_queue + .send(ZmqTransportEvent::CloseOutput( + output_port_id, + input_port_id, + )) + .await + .map_err(|e| PortError::Other(e.to_string())); + respond(response).await; + break 'send; + } + ZmqOutputPortRequest::Send(bytes) => { + let msg = ZmqTransportEvent::Message( + output_port_id, + input_port_id, + seq_id, + bytes.clone(), + ); + + let mut attempts = 0; + 'retry: loop { + attempts += 1; + + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "retry_loop", ?attempts); + + if attempts >= DEFAULT_MAX_RETRIES { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("reached max send attempts")); + respond(Err(PortError::Disconnected)).await; + break 'send; + } + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!("attempting to send message")); + + if pub_queue.send(msg.clone()).await.is_err() { + // the socket for publishing has closed, we won't be able to send any + // messages + respond(Err(PortError::Disconnected)).await; + break 'send; + } + + 'recv: loop { + #[cfg(feature = "tracing")] + let span = trace_span!(parent: &span, "recv_loop"); + + let timeout = tokio::time::sleep(DEFAULT_TIMEOUT); + + let event = tokio::select! { + // after DEFAULT_TIMEOUT duration has passed since the last + // received event from the socket, retry + _ = timeout => continue 'retry, + event_opt = to_worker_recv.recv() => match event_opt { + Some(event) => event, + None => { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("all senders to worker have dropped")); + respond(Err(PortError::Invalid(output_port_id.into()))).await; + break 'send; + } + } + }; + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?event, "received event")); + + use ZmqTransportEvent::*; + match event { + AckMessage(_, _, ack_id) => { + if ack_id == seq_id { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?ack_id, "msg-ack matches")); + respond(Ok(())).await; + break 'retry; + } else { + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!(?ack_id, "got msg-ack for different sequence") + }); + continue 'recv; + } + } + CloseInput(_) => { + // report that the input port was closed + respond(Err(PortError::Disconnected)).await; + break 'send; + } + + // ignore others, we shouldn't receive any new conn-acks + // nor should we be receiving input port events + AckConnection(..) | Connect(..) | Message(..) | CloseOutput(..) => { + continue 'recv + } + } + } + } + } + } + + seq_id += 1; + } + + let outputs = outputs.read().await; + let Some(output_state) = outputs.get(&output_port_id) else { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("port state not found")); + return; + }; + let mut output_state = output_state.write().await; + debug_assert!(matches!(*output_state, ZmqOutputPortState::Connected(..))); + *output_state = ZmqOutputPortState::Closed; + + #[cfg(feature = "tracing")] + span.in_scope(|| { + trace!( + events_closed = to_worker_recv.is_closed(), + requests_closed = msg_req_recv.is_closed(), + state = ?*output_state, + "exited output worker loop" + ) + }); + + if unsubscribe_topics(&topics, &sub_queue).await.is_err() { + #[cfg(feature = "tracing")] + span.in_scope(|| error!("topic unsubscription failed")); + } + }); + + Ok(()) +} diff --git a/lib/protoflow-zeromq/src/protoflow.zmq.rs b/lib/protoflow-zeromq/src/protoflow.zmq.rs new file mode 100644 index 00000000..1afabab7 --- /dev/null +++ b/lib/protoflow-zeromq/src/protoflow.zmq.rs @@ -0,0 +1,70 @@ +// This file is @generated by prost-build. +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct Connect { + #[prost(int64, tag = "1")] + pub output: i64, + #[prost(int64, tag = "2")] + pub input: i64, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct AckConnection { + #[prost(int64, tag = "1")] + pub output: i64, + #[prost(int64, tag = "2")] + pub input: i64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Message { + #[prost(int64, tag = "1")] + pub output: i64, + #[prost(int64, tag = "2")] + pub input: i64, + #[prost(uint64, tag = "3")] + pub sequence: u64, + #[prost(bytes = "vec", tag = "4")] + pub message: ::prost::alloc::vec::Vec, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct AckMessage { + #[prost(int64, tag = "1")] + pub output: i64, + #[prost(int64, tag = "2")] + pub input: i64, + #[prost(uint64, tag = "3")] + pub sequence: u64, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct CloseOutput { + #[prost(int64, tag = "1")] + pub output: i64, + #[prost(int64, tag = "2")] + pub input: i64, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct CloseInput { + #[prost(int64, tag = "1")] + pub input: i64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Event { + #[prost(oneof = "event::Payload", tags = "1, 2, 3, 4, 5, 6")] + pub payload: ::core::option::Option, +} +/// Nested message and enum types in `Event`. +pub mod event { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Payload { + #[prost(message, tag = "1")] + Connect(super::Connect), + #[prost(message, tag = "2")] + AckConnection(super::AckConnection), + #[prost(message, tag = "3")] + Message(super::Message), + #[prost(message, tag = "4")] + AckMessage(super::AckMessage), + #[prost(message, tag = "5")] + CloseOutput(super::CloseOutput), + #[prost(message, tag = "6")] + CloseInput(super::CloseInput), + } +} diff --git a/lib/protoflow-zeromq/src/socket.rs b/lib/protoflow-zeromq/src/socket.rs new file mode 100644 index 00000000..cef3d2f1 --- /dev/null +++ b/lib/protoflow-zeromq/src/socket.rs @@ -0,0 +1,198 @@ +// This is free and unencumbered software released into the public domain. + +use crate::{ + input_port_event_sender, output_port_event_sender, ZmqInputPortState, ZmqOutputPortState, + ZmqTransport, ZmqTransportEvent, +}; +use protoflow_core::{ + prelude::{BTreeMap, String, Vec}, + InputPortID, OutputPortID, PortError, +}; +use tokio::sync::{ + mpsc::{error::SendError, Receiver, Sender}, + RwLock, +}; +use zeromq::{SocketRecv, SocketSend, ZmqMessage}; + +#[derive(Clone, Debug)] +pub enum ZmqSubscriptionRequest { + Subscribe(String), + Unsubscribe(String), +} + +#[cfg(feature = "tracing")] +use tracing::{debug, error, trace, trace_span, warn}; + +pub fn start_pub_socket_worker( + transport: &ZmqTransport, + psock: zeromq::PubSocket, + pub_queue: Receiver, +) { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::pub_socket"); + let outputs = transport.outputs.clone(); + let inputs = transport.inputs.clone(); + let mut psock = psock; + let mut pub_queue = pub_queue; + tokio::task::spawn(async move { + while let Some(event) = pub_queue.recv().await { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?event, "sending event to socket")); + + use ZmqTransportEvent::*; + let shortcut_sender = match event { + Connect(_, id) | Message(_, id, _, _) | CloseOutput(_, id) => { + input_port_event_sender(&inputs, id).await + } + AckConnection(id, _) | AckMessage(id, ..) => { + output_port_event_sender(&outputs, id).await + } + CloseInput(..) => None, + }; + + if let Some(sender) = shortcut_sender { + #[cfg(feature = "tracing")] + span.in_scope(|| debug!("attempting to shortcut send directly to target port")); + if sender.send(event.clone()).await.is_ok() { + continue; + } + #[cfg(feature = "tracing")] + span.in_scope(|| warn!("failed to send message with shortcut, sending to socket")); + } + + if let Err(err) = psock.send(event.into()).await { + #[cfg(feature = "tracing")] + span.in_scope(|| error!(?err, "failed to send message")); + } + } + }); +} + +pub async fn subscribe_topics( + topics: &[String], + sub_queue: &Sender, +) -> Result<(), SendError> { + let mut handles = Vec::with_capacity(topics.len()); + for topic in topics { + handles.push(sub_queue.send(ZmqSubscriptionRequest::Subscribe(topic.clone()))); + } + for handle in handles { + handle.await?; + } + Ok(()) +} + +pub async fn unsubscribe_topics( + topics: &[String], + sub_queue: &Sender, +) -> Result<(), SendError> { + let mut handles = Vec::with_capacity(topics.len()); + for topic in topics { + handles.push(sub_queue.send(ZmqSubscriptionRequest::Unsubscribe(topic.clone()))); + } + for handle in handles { + handle.await?; + } + Ok(()) +} + +pub fn start_sub_socket_worker( + transport: &ZmqTransport, + ssock: zeromq::SubSocket, + sub_queue: Receiver, +) { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::sub_socket"); + let outputs = transport.outputs.clone(); + let inputs = transport.inputs.clone(); + let mut ssock = ssock; + let mut sub_queue = sub_queue; + tokio::task::spawn(async move { + loop { + tokio::select! { + Ok(msg) = ssock.recv() => { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?msg, "got message from socket")); + + if let Err(err) = handle_zmq_msg(msg, &outputs, &inputs).await { + #[cfg(feature = "tracing")] + span.in_scope(|| error!(?err, "failed to process message")); + } + }, + Some(req) = sub_queue.recv() => { + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?req, "got sub update request")); + + use ZmqSubscriptionRequest::*; + match req { + Subscribe(topic) => if let Err(err) = ssock.subscribe(&topic).await { + #[cfg(feature = "tracing")] + span.in_scope(|| error!(?err, ?topic, "subscribe failed")); + }, + Unsubscribe(topic) => if let Err(err) = ssock.unsubscribe(&topic).await { + #[cfg(feature = "tracing")] + span.in_scope(|| error!(?err, ?topic, "unsubscribe failed")); + } + }; + } + }; + } + }); +} + +async fn handle_zmq_msg( + msg: ZmqMessage, + outputs: &RwLock>>, + inputs: &RwLock>>, +) -> Result<(), PortError> { + #[cfg(feature = "tracing")] + let span = trace_span!("ZmqTransport::handle_zmq_msg"); + + let event = ZmqTransportEvent::try_from(msg)?; + + #[cfg(feature = "tracing")] + span.in_scope(|| trace!(?event, "got event")); + + use ZmqTransportEvent::*; + match event { + // input ports + Connect(_, input_port_id) + | Message(_, input_port_id, _, _) + | CloseOutput(_, input_port_id) => { + let sender = input_port_event_sender(inputs, input_port_id) + .await + .ok_or_else(|| PortError::Invalid(input_port_id.into()))?; + + sender.send(event).await.map_err(|_| PortError::Closed) + } + + // output ports + AckConnection(output_port_id, _) | AckMessage(output_port_id, _, _) => { + let sender = output_port_event_sender(outputs, output_port_id) + .await + .ok_or_else(|| PortError::Invalid(output_port_id.into()))?; + + sender.send(event).await.map_err(|_| PortError::Closed) + } + CloseInput(input_port_id) => { + for (_, state) in outputs.read().await.iter() { + let sender = { + let state = state.read().await; + let ZmqOutputPortState::Connected(_, ref sender, ref id) = *state else { + continue; + }; + if *id != input_port_id { + continue; + } + + sender.clone() + }; + + if let Err(_e) = sender.send(event.clone()).await { + continue; // TODO + } + } + Ok(()) + } + } +}