Skip to content

Commit cb9b648

Browse files
authored
OnClosed callback for producer (#293)
* Handle on_closed callback * Implement timeout. Unconfirmed messages on OnClosed
1 parent 7c13c2c commit cb9b648

File tree

9 files changed

+522
-84
lines changed

9 files changed

+522
-84
lines changed

examples/ha_producer.rs

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
use core::panic;
2+
use std::sync::atomic::AtomicBool;
3+
use std::sync::Arc;
4+
use std::time::Duration;
5+
6+
use rabbitmq_stream_client::error::{ProducerPublishError, StreamCreateError};
7+
use rabbitmq_stream_client::types::{ByteCapacity, Message, ResponseCode};
8+
use rabbitmq_stream_client::Environment;
9+
use rabbitmq_stream_client::{
10+
ConfirmationStatus, NoDedup, OnClosed, Producer, RabbitMQStreamResult,
11+
};
12+
use tokio::sync::Notify;
13+
use tokio::sync::RwLock;
14+
use tokio::time::sleep;
15+
use tracing::info;
16+
17+
struct MyHAProducerInner {
18+
environment: Environment,
19+
stream: String,
20+
producer: RwLock<Producer<NoDedup>>,
21+
notify: Notify,
22+
is_opened: AtomicBool,
23+
}
24+
25+
#[derive(Clone)]
26+
struct MyHAProducer(Arc<MyHAProducerInner>);
27+
28+
#[async_trait::async_trait]
29+
impl OnClosed for MyHAProducer {
30+
async fn on_closed(&self, unconfirmed: Vec<Message>) {
31+
info!("Producer is closed. Creating new one");
32+
33+
self.0
34+
.is_opened
35+
.store(false, std::sync::atomic::Ordering::SeqCst);
36+
37+
let mut producer = self.0.producer.write().await;
38+
39+
let new_producer = self
40+
.0
41+
.environment
42+
.producer()
43+
.build(&self.0.stream)
44+
.await
45+
.unwrap();
46+
47+
new_producer.set_on_closed(Box::new(self.clone())).await;
48+
49+
if !unconfirmed.is_empty() {
50+
info!("Resending {} unconfirmed messages.", unconfirmed.len());
51+
if let Err(e) = producer.batch_send_with_confirm(unconfirmed).await {
52+
eprintln!("Error resending unconfirmed messages: {:?}", e);
53+
}
54+
}
55+
56+
*producer = new_producer;
57+
58+
self.0
59+
.is_opened
60+
.store(true, std::sync::atomic::Ordering::SeqCst);
61+
self.0.notify.notify_waiters();
62+
}
63+
}
64+
65+
impl MyHAProducer {
66+
async fn new(environment: Environment, stream: &str) -> RabbitMQStreamResult<Self> {
67+
ensure_stream_exists(&environment, stream).await?;
68+
69+
let producer = environment.producer().build(stream).await.unwrap();
70+
71+
let inner = MyHAProducerInner {
72+
environment,
73+
stream: stream.to_string(),
74+
producer: RwLock::new(producer),
75+
notify: Notify::new(),
76+
is_opened: AtomicBool::new(true),
77+
};
78+
let s = Self(Arc::new(inner));
79+
80+
let p = s.0.producer.write().await;
81+
p.set_on_closed(Box::new(s.clone())).await;
82+
drop(p);
83+
84+
Ok(s)
85+
}
86+
87+
async fn send_with_confirm(
88+
&self,
89+
message: Message,
90+
) -> Result<ConfirmationStatus, ProducerPublishError> {
91+
if !self.0.is_opened.load(std::sync::atomic::Ordering::SeqCst) {
92+
self.0.notify.notified().await;
93+
}
94+
95+
let producer = self.0.producer.read().await;
96+
let err = producer.send_with_confirm(message.clone()).await;
97+
98+
match err {
99+
Ok(s) => Ok(s),
100+
Err(e) => match e {
101+
ProducerPublishError::Timeout | ProducerPublishError::Closed => {
102+
Box::pin(self.send_with_confirm(message)).await
103+
}
104+
_ => return Err(e),
105+
},
106+
}
107+
}
108+
}
109+
110+
async fn ensure_stream_exists(environment: &Environment, stream: &str) -> RabbitMQStreamResult<()> {
111+
let create_response = environment
112+
.stream_creator()
113+
.max_length(ByteCapacity::GB(5))
114+
.create(stream)
115+
.await;
116+
117+
if let Err(e) = create_response {
118+
if let StreamCreateError::Create { stream, status } = e {
119+
match status {
120+
// we can ignore this error because the stream already exists
121+
ResponseCode::StreamAlreadyExists => {}
122+
err => {
123+
panic!("Error creating stream: {:?} {:?}", stream, err);
124+
}
125+
}
126+
}
127+
}
128+
129+
Ok(())
130+
}
131+
132+
#[tokio::main]
133+
async fn main() -> Result<(), Box<dyn std::error::Error>> {
134+
let _ = tracing_subscriber::fmt::try_init();
135+
136+
let environment = Environment::builder().build().await?;
137+
let stream = "hello-rust-stream";
138+
139+
let producer = MyHAProducer::new(environment, stream).await?;
140+
141+
let number_of_messages = 1000000;
142+
for i in 0..number_of_messages {
143+
let msg = Message::builder()
144+
.body(format!("stream message_{}", i))
145+
.build();
146+
producer.send_with_confirm(msg).await?;
147+
sleep(Duration::from_millis(100)).await;
148+
}
149+
150+
Ok(())
151+
}

src/client/dispatcher.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,6 @@ use super::{channel::ChannelReceiver, handler::MessageHandler};
1919
#[derive(Clone)]
2020
pub(crate) struct Dispatcher<T>(DispatcherState<T>);
2121

22-
pub(crate) struct DispatcherState<T> {
23-
requests: Arc<RequestsMap>,
24-
correlation_id: Arc<AtomicU32>,
25-
handler: Arc<RwLock<Option<T>>>,
26-
}
27-
28-
impl<T> Clone for DispatcherState<T> {
29-
fn clone(&self) -> Self {
30-
DispatcherState {
31-
requests: self.requests.clone(),
32-
correlation_id: self.correlation_id.clone(),
33-
handler: self.handler.clone(),
34-
}
35-
}
36-
}
37-
3822
struct RequestsMap {
3923
requests: DashMap<u32, Sender<Response>>,
4024
closed: AtomicBool,
@@ -126,6 +110,22 @@ where
126110
}
127111
}
128112

113+
pub(crate) struct DispatcherState<T> {
114+
requests: Arc<RequestsMap>,
115+
correlation_id: Arc<AtomicU32>,
116+
handler: Arc<RwLock<Option<T>>>,
117+
}
118+
119+
impl<T> Clone for DispatcherState<T> {
120+
fn clone(&self) -> Self {
121+
DispatcherState {
122+
requests: self.requests.clone(),
123+
correlation_id: self.correlation_id.clone(),
124+
handler: self.handler.clone(),
125+
}
126+
}
127+
}
128+
129129
impl<T> DispatcherState<T>
130130
where
131131
T: MessageHandler,

src/client/message.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ impl BaseMessage for Message {
2020
}
2121
}
2222

23-
#[derive(Debug)]
23+
#[derive(Debug, Clone)]
2424
pub struct ClientMessage {
2525
publishing_id: u64,
2626
message: Message,
@@ -39,6 +39,10 @@ impl ClientMessage {
3939
pub fn filter_value_extract(&mut self, filter_value_extractor: impl Fn(&Message) -> String) {
4040
self.filter_value = Some(filter_value_extractor(&self.message));
4141
}
42+
43+
pub fn into_message(self) -> Message {
44+
self.message
45+
}
4246
}
4347

4448
impl BaseMessage for ClientMessage {

0 commit comments

Comments
 (0)