Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: getting more than one connection for Redis channel subscriber #178

Merged
merged 3 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/benchmark/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub struct Args {
#[arg(short, long, default_value_t = 60, help = "Request timeout")]
pub timeout: u8,

#[arg(short, long, default_value_t = true, help = "Authenticate WebSocket")]
#[arg(long, default_value_t = true, help = "Authenticate WebSocket")]
pub authenticate: bool,

#[arg(short, long, default_value_t = 10, help = "Amount of quests to create")]
Expand Down
4 changes: 4 additions & 0 deletions crates/benchmark/src/quests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ fn random_action() -> Action {
}

#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateQuestRequest {
name: String,
description: String,
definition: QuestDefinition,
image_url: String,
}

pub fn random_quest() -> CreateQuestRequest {
Expand Down Expand Up @@ -64,6 +66,7 @@ pub fn random_quest() -> CreateQuestRequest {
name: create_random_string(10),
description: create_random_string(100),
definition: QuestDefinition { connections, steps },
image_url: "http://google.com".to_string(),
}
}

Expand Down Expand Up @@ -171,6 +174,7 @@ pub fn grab_some_pies() -> CreateQuestRequest {
},
],
},
image_url: "http://google.com".to_string(),
}
}

Expand Down
53 changes: 27 additions & 26 deletions crates/benchmark/src/quests_simulation.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
use std::{fmt::Display, time::Duration};

use crate::{
args::Args,
client::{create_test_identity, get_signed_headers, TestWebSocketTransport},
quests::{create_random_string, random_quest},
simulation::{Client, Context},
};
use async_trait::async_trait;
use dcl_rpc::{client::RpcClient, stream_protocol::Generator};
use log::{debug, error, info};
Expand All @@ -8,13 +14,6 @@ use rand::{seq::SliceRandom, thread_rng};
use serde::Deserialize;
use tokio::time::timeout;

use crate::{
args::Args,
client::{create_test_identity, get_signed_headers, TestWebSocketTransport},
quests::{create_random_string, random_quest},
simulation::{Client, Context},
};

#[derive(Deserialize)]
struct CreateQuestResponse {
id: String,
Expand All @@ -31,14 +30,14 @@ impl TestContext {
let headers = get_signed_headers(
create_test_identity(),
"post",
"/quests",
"/api/quests",
serde_json::to_string(&quest).unwrap().as_str(),
);

let client = reqwest::Client::new();

let res = client
.post(format!("{api_host}/quests"))
.post(format!("{api_host}/api/quests"))
.header(headers[0].0.clone(), headers[0].1.clone())
.header(headers[1].0.clone(), headers[1].1.clone())
.header(headers[2].0.clone(), headers[2].1.clone())
Expand Down Expand Up @@ -155,9 +154,7 @@ impl ClientState {
.await;
debug!(
"User {} > StartQuestRequest: id {} > Response: {:?}",
quest_id,
&user_address[..4],
response
quest_id, &user_address, response
);
match response {
Ok(StartQuestResponse {
Expand All @@ -174,7 +171,7 @@ impl ClientState {
let Ok(quest_updates) = act else {
error!(
"User {} > Timeout while fetching started quest state",
&user_address[..4]
&user_address
);
return ClientState::StartQuestRequested { updates, quest_id };
};
Expand All @@ -193,13 +190,17 @@ impl ClientState {
quest_instance_id: id,
quest_state: state,
},
_ => {
Some(other) => {
error!(
"User {} > Start Quest > Received update is not the quest state",
&user_address[..4]
"User {} > Start Quest > Received update is not the quest state > {other:?}",
&user_address
);
ClientState::StartQuestRequested { updates, quest_id }
}
None => {
error!("No update. Decoding error");
ClientState::StartQuestRequested { updates, quest_id }
}
}
}
ClientState::MakeQuestProgress {
Expand Down Expand Up @@ -232,7 +233,7 @@ impl ClientState {
_ => {
error!(
"> User {} > Make Quest Progress > event not accepted, retrying",
&user_address[..4]
&user_address
);
ClientState::MakeQuestProgress {
updates,
Expand All @@ -247,24 +248,24 @@ impl ClientState {
quest_instance_id,
quest_state,
} => {
debug!("User {} > Fetch next event > Next.", &user_address[..4]);
debug!("User {} > Fetch next event > Next.", &user_address);
let act = timeout(context.timeout, updates.next()).await;
let Ok(quest_update) = act else {
error!(
"User {} > Timeout while fetching next event!",
&user_address[..4]
&user_address
);
return ClientState::FetchQuestUpdate {
updates,
quest_instance_id,
quest_state,
};
};
debug!("User {} > Fetch next event > Done.", &user_address[..4]);
debug!("User {} > Fetch next event > Done.", &user_address);

debug!(
"User {} > quest_update received > {quest_update:?}",
&user_address[..4]
&user_address
);

match quest_update {
Expand All @@ -286,7 +287,7 @@ impl ClientState {
}
}
Some(user_update::Message::EventIgnored(_)) => {
error!("User {} > Event ignored", &user_address[..4]);
error!("User {} > Event ignored", &user_address);
ClientState::MakeQuestProgress {
updates,
quest_instance_id,
Expand All @@ -299,7 +300,7 @@ impl ClientState {
})) => {
debug!(
"User {} > QuestStateUpdate received for wrong quest instance {}, expected instance was {}",
&user_address[..4],
&user_address,
instance_id,
quest_instance_id
);
Expand All @@ -321,9 +322,9 @@ impl ClientState {
};

if std::mem::discriminant(&state) == current_state_discriminant {
info!("User {} > State didn't change", &user_address[..4]);
info!("User {} > State didn't change", &user_address);
} else {
info!("User {} > {state}", &user_address[..4]);
info!("User {} > {state}", &user_address);
}
state
}
Expand All @@ -337,7 +338,7 @@ impl Context for TestContext {
for _ in 0..args.quests {
match Self::create_random_quest(&args.api_host).await {
Ok(quest_id) => quest_ids.push(quest_id),
Err(reason) => debug!("Quest Creation > Couldn't POST quest: {reason}"),
Err(reason) => error!("Quest Creation > Couldn't POST quest: {reason}"),
}
}
Self {
Expand Down
117 changes: 47 additions & 70 deletions crates/message_broker/src/channel.rs
Original file line number Diff line number Diff line change
@@ -1,89 +1,71 @@
use crate::redis::Redis;
use async_trait::async_trait;
use deadpool_redis::redis::AsyncCommands;
use deadpool_redis::redis::{aio::PubSub, AsyncCommands};
use futures_util::{Future, StreamExt as _};
use log::{debug, error};
use quests_protocol::definitions::*;
use std::sync::Arc;
use tokio::task::JoinHandle;

pub trait ChannelSubscriber<OnUpdateOutput>: Send + Sync {
fn subscribe<
NewPublishment: ProtocolMessage + Default,
U: Future<Output = OnUpdateOutput> + Send + Sync,
>(
&self,
channel_name: &str,
on_update_fn: impl Fn(NewPublishment) -> U + Send + Sync + 'static,
) -> JoinHandle<()>;
}

#[async_trait]
pub trait ChannelPublisher<Publishment>: Send + Sync {
async fn publish(&self, update: Publishment);
#[derive(Debug)]
pub enum RedisChannelSubscriberError {
RedisError,
NoConnectionAvailable,
}

pub struct RedisChannelSubscriber {
redis: Arc<Redis>,
subscriptor: PubSub,
}

impl RedisChannelSubscriber {
pub fn new(redis: Arc<Redis>) -> Self {
Self { redis }
pub async fn new(
redis: Arc<Redis>,
channel_name: &str,
) -> Result<Self, RedisChannelSubscriberError> {
let connection = redis
.get_async_connection()
.await
.ok_or(RedisChannelSubscriberError::NoConnectionAvailable)?;

let connection = deadpool_redis::Connection::take(connection);
let mut pubsub = connection.into_pubsub();

pubsub
.subscribe(channel_name)
.await
.map_err(|_| RedisChannelSubscriberError::RedisError)?;

Ok(Self {
subscriptor: pubsub,
})
}
}

impl ChannelSubscriber<bool> for RedisChannelSubscriber {
/// Listens to a specific channel for new messages
fn subscribe<
pub async fn on_new_message<
NewPublishment: ProtocolMessage + Default,
U: Future<Output = bool> + Send + Sync,
U: Future<Output = ()> + Send + Sync,
>(
&self,
channel_name: &str,
on_update_fn: impl Fn(NewPublishment) -> U + Send + Sync + 'static,
) -> JoinHandle<()> {
let redis = self.redis.clone();
let channel_name = channel_name.to_string();
tokio::spawn(async move {
debug!("Subscribing to channel {channel_name}");
let connection = redis
.get_async_connection()
.await
.expect("to get a connection"); // TODO: Error handling
&mut self,
on_new_message_fn: impl Fn(NewPublishment) -> U + Send + Sync + 'static,
) {
let mut on_message_stream = self.subscriptor.on_message();

let connection = deadpool_redis::Connection::take(connection);
let mut pubsub = connection.into_pubsub();
pubsub
.subscribe(channel_name.clone())
.await
.expect("to be able to listen to this channel");

debug!("Subscribed to channel {channel_name}!");
let mut on_message_stream = pubsub.on_message();

loop {
if let Some(message) = on_message_stream.next().await {
let payload = message.get_payload::<Vec<u8>>();
match payload {
Ok(payload) => {
debug!("New message received from channel");
let update = NewPublishment::decode(&*payload);
match update {
Ok(update) => {
debug!("New publishment parsed {update:?}");
if !on_update_fn(update).await {
break;
}
}
Err(_) => error!("Couldn't deserialize update"),
loop {
if let Some(message) = on_message_stream.next().await {
let payload = message.get_payload::<Vec<u8>>();
match payload {
Ok(payload) => {
debug!("New message received from channel");
let update = NewPublishment::decode(&*payload);
match update {
Ok(update) => {
debug!("New publishment parsed {update:?}");
on_new_message_fn(update).await;
}
Err(_) => error!("Couldn't deserialize update"),
}
Err(_) => error!("Couldn't retrieve payload"),
}
Err(_) => error!("Couldn't retrieve payload"),
}
}
})
}
}
}

Expand All @@ -99,13 +81,8 @@ impl RedisChannelPublisher {
channel_name: channel_name.to_string(),
}
}
}

#[async_trait]
impl<Publishment: ProtocolMessage + 'static> ChannelPublisher<Publishment>
for RedisChannelPublisher
{
async fn publish(&self, publishment: Publishment) {
pub async fn publish<P: ProtocolMessage + 'static>(&self, publishment: P) {
debug!("Publish > Getting connection...");
let mut publish = self
.redis
Expand Down
8 changes: 5 additions & 3 deletions crates/server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,25 @@ pub async fn run_app() {
redis.clone(),
QUESTS_CHANNEL_NAME,
));
let quests_channel_subscriber = RedisChannelSubscriber::new(redis.clone());
let quests_channel_subscriber = RedisChannelSubscriber::new(redis.clone(), QUESTS_CHANNEL_NAME)
.await
.expect("> run_app > Couldn't initialize subscriber");

let http_metrics_collector = Arc::new(HttpMetricsCollectorBuilder::default().build());

let (warp_websocket_server, rpc_server) = rpc::run_rpc_server((
config.clone(),
database.clone(),
events_queue.clone(),
quests_channel_subscriber,
quests_channel_publisher.clone(),
quests_channel_subscriber,
))
.await;

let event_processing = event_processing::run_event_processor(
database.clone(),
events_queue.clone(),
quests_channel_publisher.clone(),
quests_channel_publisher,
);

let actix_rest_api_server = api::run_server(
Expand Down
Loading
Loading