Skip to content

Commit

Permalink
editoast: bulk get cached projection
Browse files Browse the repository at this point in the history
  • Loading branch information
flomonster committed Jul 9, 2024
1 parent c9332df commit e681c02
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 27 deletions.
48 changes: 48 additions & 0 deletions editoast/src/redis_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ impl RedisConnection {
}
}

/// Get a list of deserializable value from redis
#[tracing::instrument(name = "cache:get_bulk", skip(self), err)]
pub async fn json_get_bulk<T: DeserializeOwned, K: Debug + ToRedisArgs + Send + Sync>(
&mut self,
keys: &[K],
) -> Result<Vec<Option<T>>> {
let values: Vec<Option<String>> = self.mget(keys).await?;
values
.into_iter()
.map(|value| match value {
Some(v) => match serde_json::from_str::<T>(&v) {
Ok(value) => Ok(Some(value)),
Err(_) => {
Err(RedisError::from((ErrorKind::TypeError, "Expected valid json")).into())
}
},
None => Ok(None),
})
.collect()
}

/// Get a deserializable value from redis with expiry time
#[tracing::instrument(name = "cache:get_with_expiration", skip(self), err)]
pub async fn json_get_ex<T: DeserializeOwned, K: Debug + ToRedisArgs + Send + Sync>(
Expand Down Expand Up @@ -117,6 +138,33 @@ impl RedisConnection {
self.set_ex(key, str_value, seconds).await?;
Ok(())
}

/// Set a list of serializable values to redis
#[tracing::instrument(name = "cache:set_bulk", skip(self, items), err)]
pub async fn json_set_bulk<K: Debug + ToRedisArgs + Send + Sync, T: Serialize>(
&mut self,
items: &[(K, T)],
) -> Result<()> {
if items.is_empty() {
return Ok(());
}
let mut ser_items = vec![];
for (key, value) in items.iter() {
let str_value = match serde_json::to_string(value) {
Ok(value) => value,
Err(_) => {
return Err(RedisError::from((
ErrorKind::IoError,
"An error occured serializing to json",
))
.into())
}
};
ser_items.push((key, str_value));
}
self.mset(&ser_items).await?;
Ok(())
}
}

#[derive(Clone)]
Expand Down
55 changes: 28 additions & 27 deletions editoast/src/views/v2/train_schedule/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ use editoast_models::DbConnectionPoolV2;
use crate::RedisClient;
use crate::RollingStockModel;

const CACHE_PROJECTION_EXPIRATION: u64 = 604800; // 1 week

editoast_common::schemas! {
ProjectPathTrainResult,
ProjectPathForm,
Expand Down Expand Up @@ -104,7 +102,7 @@ struct ProjectPathTrainResult {
}

/// Project path output is described by time-space points and blocks
#[derive(Debug, Deserialize, Serialize, ToSchema)]
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
struct CachedProjectPathTrainResult {
/// List of space-time curves sections along the path
#[schema(inline)]
Expand Down Expand Up @@ -194,8 +192,8 @@ async fn project_path(
.await?;

// 1. Retrieve cached projection
let mut hit_cache: HashMap<i64, CachedProjectPathTrainResult> = HashMap::new();
let mut miss_cache = HashMap::new();
let mut trains_hash_values = HashMap::new();
let mut trains_details = HashMap::new();

for (train, sim) in trains.iter().zip(simulations) {
let pathfinding_result = pathfinding_from_train(
Expand Down Expand Up @@ -245,15 +243,25 @@ async fn project_path(
&path_routes,
&path_blocks,
);
let projection: Option<CachedProjectPathTrainResult> = redis_conn
.json_get_ex(&hash, CACHE_PROJECTION_EXPIRATION)
.await?;
trains_hash_values.insert(train.id, hash);
trains_details.insert(train.id, train_details);
}
let cached_projections: Vec<Option<CachedProjectPathTrainResult>> = redis_conn
.json_get_bulk(&trains_hash_values.values().collect::<Vec<_>>())
.await?;

let mut hit_cache: HashMap<i64, CachedProjectPathTrainResult> = HashMap::new();
let mut miss_cache = HashMap::new();
for ((train_id, train_details), projection) in
trains_details.into_iter().zip(cached_projections)
{
if let Some(cached) = projection {
hit_cache.insert(train.id, cached);
hit_cache.insert(train_id, cached);
} else {
miss_cache.insert(train.id, train_details);
miss_cache.insert(train_id, train_details.clone());
}
}

info!(
nb_hit = hit_cache.len(),
nb_miss = miss_cache.len(),
Expand All @@ -274,31 +282,24 @@ async fn project_path(
);
let signal_updates = signal_updates?;

// 3. Store the projection in the cache
for (id, train_details) in miss_cache {
let hash = train_projection_input_hash(
infra.id,
&infra.version,
&train_details,
&path_track_ranges,
&path_routes,
&path_blocks,
);
let cached = CachedProjectPathTrainResult {
// 3. Store the projection in the cache (using pipeline)
let mut new_items = vec![];
for id in miss_cache.keys() {
let hash = &trains_hash_values[id];
let cached_value = CachedProjectPathTrainResult {
space_time_curves: space_time_curves
.get(&id)
.get(id)
.expect("Space time curves not availabe for train")
.clone(),
signal_updates: signal_updates
.get(&id)
.get(id)
.expect("Signal update not availabe for train")
.clone(),
};
redis_conn
.json_set_ex(&hash, &cached, CACHE_PROJECTION_EXPIRATION)
.await?;
hit_cache.insert(id, cached);
hit_cache.insert(*id, cached_value.clone());
new_items.push((hash, cached_value));
}
redis_conn.json_set_bulk(&new_items).await?;

let train_map: HashMap<i64, TrainSchedule> = trains.into_iter().map(|ts| (ts.id, ts)).collect();

Expand Down

0 comments on commit e681c02

Please sign in to comment.