diff --git a/editoast/src/redis_utils.rs b/editoast/src/redis_utils.rs index 665625acccf..34cf0767b67 100644 --- a/editoast/src/redis_utils.rs +++ b/editoast/src/redis_utils.rs @@ -77,6 +77,27 @@ impl RedisConnection { } } + /// Get a list of deserializable value from redis + #[tracing::instrument(name = "cache:get_bulk", skip(self), err)] + pub async fn json_get_bulk( + &mut self, + keys: &[K], + ) -> Result>> { + let values: Vec> = self.mget(keys).await?; + values + .into_iter() + .map(|value| match value { + Some(v) => match serde_json::from_str::(&v) { + Ok(value) => Ok(Some(value)), + Err(_) => { + Err(RedisError::from((ErrorKind::TypeError, "Expected valid json")).into()) + } + }, + None => Ok(None), + }) + .collect() + } + /// Get a deserializable value from redis with expiry time #[tracing::instrument(name = "cache:get_with_expiration", skip(self), err)] pub async fn json_get_ex( @@ -117,6 +138,33 @@ impl RedisConnection { self.set_ex(key, str_value, seconds).await?; Ok(()) } + + /// Set a list of serializable values to redis + #[tracing::instrument(name = "cache:set_bulk", skip(self, items), err)] + pub async fn json_set_bulk( + &mut self, + items: &[(K, T)], + ) -> Result<()> { + if items.is_empty() { + return Ok(()); + } + let mut ser_items = vec![]; + for (key, value) in items.iter() { + let str_value = match serde_json::to_string(value) { + Ok(value) => value, + Err(_) => { + return Err(RedisError::from(( + ErrorKind::IoError, + "An error occured serializing to json", + )) + .into()) + } + }; + ser_items.push((key, str_value)); + } + self.mset(&ser_items).await?; + Ok(()) + } } #[derive(Clone)] diff --git a/editoast/src/views/v2/train_schedule/projection.rs b/editoast/src/views/v2/train_schedule/projection.rs index 355faa52fda..4a4f99d240d 100644 --- a/editoast/src/views/v2/train_schedule/projection.rs +++ b/editoast/src/views/v2/train_schedule/projection.rs @@ -48,8 +48,6 @@ use editoast_models::DbConnectionPoolV2; use crate::RedisClient; use crate::RollingStockModel; -const CACHE_PROJECTION_EXPIRATION: u64 = 604800; // 1 week - editoast_common::schemas! { ProjectPathTrainResult, ProjectPathForm, @@ -104,7 +102,7 @@ struct ProjectPathTrainResult { } /// Project path output is described by time-space points and blocks -#[derive(Debug, Deserialize, Serialize, ToSchema)] +#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)] struct CachedProjectPathTrainResult { /// List of space-time curves sections along the path #[schema(inline)] @@ -194,8 +192,8 @@ async fn project_path( .await?; // 1. Retrieve cached projection - let mut hit_cache: HashMap = HashMap::new(); - let mut miss_cache = HashMap::new(); + let mut trains_hash_values = HashMap::new(); + let mut trains_details = HashMap::new(); for (train, sim) in trains.iter().zip(simulations) { let pathfinding_result = pathfinding_from_train( @@ -245,15 +243,25 @@ async fn project_path( &path_routes, &path_blocks, ); - let projection: Option = redis_conn - .json_get_ex(&hash, CACHE_PROJECTION_EXPIRATION) - .await?; + trains_hash_values.insert(train.id, hash); + trains_details.insert(train.id, train_details); + } + let cached_projections: Vec> = redis_conn + .json_get_bulk(&trains_hash_values.values().collect::>()) + .await?; + + let mut hit_cache: HashMap = HashMap::new(); + let mut miss_cache = HashMap::new(); + for ((train_id, train_details), projection) in + trains_details.into_iter().zip(cached_projections) + { if let Some(cached) = projection { - hit_cache.insert(train.id, cached); + hit_cache.insert(train_id, cached); } else { - miss_cache.insert(train.id, train_details); + miss_cache.insert(train_id, train_details.clone()); } } + info!( nb_hit = hit_cache.len(), nb_miss = miss_cache.len(), @@ -274,31 +282,24 @@ async fn project_path( ); let signal_updates = signal_updates?; - // 3. Store the projection in the cache - for (id, train_details) in miss_cache { - let hash = train_projection_input_hash( - infra.id, - &infra.version, - &train_details, - &path_track_ranges, - &path_routes, - &path_blocks, - ); - let cached = CachedProjectPathTrainResult { + // 3. Store the projection in the cache (using pipeline) + let mut new_items = vec![]; + for id in miss_cache.keys() { + let hash = &trains_hash_values[id]; + let cached_value = CachedProjectPathTrainResult { space_time_curves: space_time_curves - .get(&id) + .get(id) .expect("Space time curves not availabe for train") .clone(), signal_updates: signal_updates - .get(&id) + .get(id) .expect("Signal update not availabe for train") .clone(), }; - redis_conn - .json_set_ex(&hash, &cached, CACHE_PROJECTION_EXPIRATION) - .await?; - hit_cache.insert(id, cached); + hit_cache.insert(*id, cached_value.clone()); + new_items.push((hash, cached_value)); } + redis_conn.json_set_bulk(&new_items).await?; let train_map: HashMap = trains.into_iter().map(|ts| (ts.id, ts)).collect();