From 700b7841ab80898dd150a2372ac9b91b962e80eb Mon Sep 17 00:00:00 2001 From: Florian Amsallem Date: Tue, 9 Jul 2024 18:59:35 +0200 Subject: [PATCH 1/2] editoast: improve performances of batch pathfinding and simulation - Massive reduction of Redis call - Stop handling expiration time of redis entries --- editoast/src/core/v2/simulation.rs | 18 +- editoast/src/redis_utils.rs | 67 +---- editoast/src/views/v2/path.rs | 6 +- editoast/src/views/v2/path/pathfinding.rs | 230 +++++++++----- editoast/src/views/v2/path/properties.rs | 10 +- editoast/src/views/v2/timetable.rs | 34 +-- editoast/src/views/v2/timetable/stdcm.rs | 58 ++-- editoast/src/views/v2/train_schedule.rs | 281 +++++++++--------- .../src/views/v2/train_schedule/projection.rs | 29 +- editoast/src/views/v2/train_schedule/proxy.rs | 90 ------ 10 files changed, 350 insertions(+), 473 deletions(-) delete mode 100644 editoast/src/views/v2/train_schedule/proxy.rs diff --git a/editoast/src/core/v2/simulation.rs b/editoast/src/core/v2/simulation.rs index 61a7dfc34e0..4fd775590d8 100644 --- a/editoast/src/core/v2/simulation.rs +++ b/editoast/src/core/v2/simulation.rs @@ -137,7 +137,7 @@ pub struct SimulationPath { pub track_section_ranges: Vec, } -#[derive(Deserialize, Serialize, Clone, Debug, ToSchema)] +#[derive(Deserialize, Default, Serialize, Clone, Debug, ToSchema)] #[schema(as = ReportTrainV2)] pub struct ReportTrain { /// List of positions of a train @@ -152,7 +152,7 @@ pub struct ReportTrain { pub scheduled_points_honored: bool, } -#[derive(Deserialize, Serialize, Clone, Debug, ToSchema)] +#[derive(Deserialize, Default, Serialize, Clone, Debug, ToSchema)] pub struct CompleteReportTrain { #[serde(flatten)] #[schema(value_type = ReportTrainV2)] @@ -200,7 +200,7 @@ pub struct RoutingZoneRequirement { pub end_time: u64, } -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)] +#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize, ToSchema)] pub struct ElectricalProfiles { /// List of `n` boundaries of the ranges. /// A boundary is a distance from the beginning of the path in mm. @@ -279,6 +279,18 @@ pub enum SimulationResponse { }, } +impl Default for SimulationResponse { + fn default() -> Self { + Self::Success { + base: Default::default(), + provisional: Default::default(), + final_output: Default::default(), + mrsp: Default::default(), + electrical_profiles: Default::default(), + } + } +} + impl AsCoreRequest> for SimulationRequest { const METHOD: reqwest::Method = reqwest::Method::POST; const URL_PATH: &'static str = "/v2/standalone_simulation"; diff --git a/editoast/src/redis_utils.rs b/editoast/src/redis_utils.rs index 496c62b4970..753057bf1c7 100644 --- a/editoast/src/redis_utils.rs +++ b/editoast/src/redis_utils.rs @@ -8,7 +8,6 @@ use redis::cluster_async::ClusterConnection; use redis::AsyncCommands; use redis::Client; use redis::ErrorKind; -use redis::Expiry; use redis::RedisError; use redis::RedisFuture; use redis::RedisResult; @@ -60,7 +59,7 @@ impl ConnectionLike for RedisConnection { impl RedisConnection { /// Get a deserializable value from redis - #[tracing::instrument(name = "cache:get", skip(self), err)] + #[tracing::instrument(name = "cache:json_get", skip(self), err)] pub async fn json_get( &mut self, key: K, @@ -102,32 +101,12 @@ impl RedisConnection { .collect() } - /// Get a deserializable value from redis with expiry time - #[tracing::instrument(name = "cache:get_with_expiration", skip(self), err)] - pub async fn json_get_ex( - &mut self, - key: K, - seconds: u64, - ) -> Result> { - let value: Option = self.get_ex(key, Expiry::EX(seconds as usize)).await?; - match value { - Some(v) => match serde_json::from_str(&v) { - Ok(value) => Ok(value), - Err(_) => { - Err(RedisError::from((ErrorKind::TypeError, "Expected valid json")).into()) - } - }, - None => Ok(None), - } - } - /// Set a serializable value to redis with expiry time - #[tracing::instrument(name = "cache:set_with_expiration", skip(self, value), err)] - pub async fn json_set_ex( + #[tracing::instrument(name = "cache:json_set", skip(self, value), err)] + pub async fn json_set( &mut self, key: K, value: &T, - seconds: u64, ) -> Result<()> { let str_value = match serde_json::to_string(value) { Ok(value) => value, @@ -139,7 +118,7 @@ impl RedisConnection { .into()) } }; - self.set_ex(key, str_value, seconds).await?; + self.set(key, str_value).await?; Ok(()) } @@ -209,41 +188,3 @@ impl RedisClient { } } } - -#[cfg(test)] -mod tests { - use super::RedisClient; - use rstest::rstest; - use serde::{Deserialize, Serialize}; - use tokio::time::{sleep, Duration}; - - /// Test get and set json values to redis - #[rstest] - async fn json_get_set() { - #[derive(Serialize, Deserialize, Debug, PartialEq)] - struct TestStruct { - name: String, - age: u8, - } - - let redis_client = RedisClient::new(Default::default()).unwrap(); - let mut redis = redis_client.get_connection().await.unwrap(); - - let key = "__test__.json_get_set"; - let test_struct = TestStruct { - name: "John".to_string(), - age: 25, - }; - - redis.json_set_ex(key, &test_struct, 60).await.unwrap(); - let value: TestStruct = redis.json_get_ex(key, 2).await.unwrap().unwrap(); - assert_eq!(value, test_struct); - - // Wait for 5 seconds - sleep(Duration::from_secs(3)).await; - - // Check if the value has expired - let value = redis.json_get::(key).await.unwrap(); - assert_eq!(value, None); - } -} diff --git a/editoast/src/views/v2/path.rs b/editoast/src/views/v2/path.rs index ec3cc360167..bca88e738ae 100644 --- a/editoast/src/views/v2/path.rs +++ b/editoast/src/views/v2/path.rs @@ -2,7 +2,7 @@ pub mod pathfinding; pub mod projection; mod properties; -pub use pathfinding::pathfinding_from_train; +pub use pathfinding::pathfinding_from_train_batch; use editoast_derive::EditoastError; use thiserror::Error; @@ -13,10 +13,6 @@ use crate::modelsv2::prelude::*; use crate::modelsv2::Infra; use editoast_models::DbConnection; -/// Expiration time for the cache of the pathfinding and path properties. -/// Note: 604800 seconds = 1 week -const CACHE_PATH_EXPIRATION: u64 = 604800; - crate::routes! { properties::routes(), pathfinding::routes(), diff --git a/editoast/src/views/v2/path/pathfinding.rs b/editoast/src/views/v2/path/pathfinding.rs index 30094da6eb3..fabc99bb0f1 100644 --- a/editoast/src/views/v2/path/pathfinding.rs +++ b/editoast/src/views/v2/path/pathfinding.rs @@ -15,7 +15,6 @@ use serde::Deserialize; use tracing::info; use utoipa::ToSchema; -use super::CACHE_PATH_EXPIRATION; use crate::core::v2::pathfinding::PathfindingRequest; use crate::core::v2::pathfinding::PathfindingResult; use crate::core::AsCoreRequest; @@ -27,12 +26,12 @@ use crate::modelsv2::OperationalPointModel; use crate::modelsv2::Retrieve; use crate::modelsv2::RetrieveBatch; use crate::modelsv2::RetrieveBatchUnchecked; +use crate::modelsv2::RollingStockModel; use crate::modelsv2::TrackSectionModel; use crate::redis_utils::RedisClient; use crate::redis_utils::RedisConnection; use crate::views::get_app_version; use crate::views::v2::path::PathfindingError; -use crate::views::v2::train_schedule::TrainScheduleProxy; use editoast_models::DbConnection; use editoast_models::DbConnectionPoolV2; use editoast_schemas::infra::OperationalPoint; @@ -95,117 +94,202 @@ pub async fn post( infra_id: *infra_id, }) .await?; + Ok(Json( - pathfinding_blocks(conn, &mut redis_conn, core, &infra, &path_input).await?, + pathfinding_blocks(conn, &mut redis_conn, core, &infra, path_input).await?, )) } +/// Pathfinding computation given a path input async fn pathfinding_blocks( conn: &mut DbConnection, redis_conn: &mut RedisConnection, core: Arc, infra: &Infra, - path_input: &PathfindingInput, + path_input: PathfindingInput, ) -> Result { - // Compute unique hash of PathInput - let hash = path_input_hash(infra.id, &infra.version, path_input); + let mut path = pathfinding_blocks_batch(conn, redis_conn, core, infra, &[path_input]).await?; + Ok(path.pop().unwrap()) +} + +/// Pathfinding batch computation given a list of path inputs +async fn pathfinding_blocks_batch( + conn: &mut DbConnection, + redis_conn: &mut RedisConnection, + core: Arc, + infra: &Infra, + pathfinding_inputs: &[PathfindingInput], +) -> Result> { + // Compute hashes of all path_inputs + let hashes: Vec<_> = pathfinding_inputs + .iter() + .map(|input| path_input_hash(infra.id, &infra.version, input)) + .collect(); + // Try to retrieve the result from Redis - let result: Option = - redis_conn.json_get_ex(&hash, CACHE_PATH_EXPIRATION).await?; - if let Some(pathfinding) = result { - info!("Hit cache"); - return Ok(pathfinding); - } - // If miss cache: + let mut pathfinding_results: Vec> = + redis_conn.json_get_bulk(&hashes).await?; + + // Report number of hit cache + let nb_hit = pathfinding_results.iter().flatten().count(); + info!( + nb_hit, + nb_miss = pathfinding_inputs.len() - nb_hit, + "Hit cache" + ); + + // Handle miss cache: // 1) extract locations from path items - let path_items = path_input.clone().path_items; + let mut to_cache = vec![]; + let mut pathfinding_requests = vec![]; + let mut pathfinding_requests_index = vec![]; + for (index, (pathfinding_result, pathfinding_input)) in pathfinding_results + .iter_mut() + .zip(pathfinding_inputs) + .enumerate() + { + if pathfinding_result.is_some() { + continue; + } + + match build_pathfinding_request(pathfinding_input, conn, infra).await? { + Ok(pathfinding_request) => { + pathfinding_requests.push(pathfinding_request); + pathfinding_requests_index.push(index); + } + Err(result) => { + *pathfinding_result = Some(result.clone()); + to_cache.push((&hashes[index], result)); + } + } + } + + // 2) Send pathfinding requests to core + let mut futures = vec![]; + for request in &pathfinding_requests { + futures.push(Box::pin(request.fetch(core.as_ref()))); + } + let computed_paths: Vec<_> = futures::future::join_all(futures) + .await + .into_iter() + .collect::>()?; + + for (index, computed_path) in computed_paths.into_iter().enumerate() { + let path_index = pathfinding_requests_index[index]; + to_cache.push((&hashes[path_index], computed_path.clone())); + pathfinding_results[path_index] = Some(computed_path); + } + + // 3) Put in cache + redis_conn.json_set_bulk(&to_cache).await?; + + Ok(pathfinding_results.into_iter().flatten().collect()) +} + +async fn build_pathfinding_request( + pathfinding_input: &PathfindingInput, + conn: &mut DbConnection, + infra: &Infra, +) -> Result> { + let path_items = &pathfinding_input.path_items; if path_items.len() <= 1 { - return Ok(PathfindingResult::NotEnoughPathItems); + return Ok(Err(PathfindingResult::NotEnoughPathItems)); } - let result = extract_location_from_path_items(conn, infra.id, &path_items).await?; - let track_offsets = match result { + let track_offsets = match extract_location_from_path_items(conn, infra.id, path_items).await? { Ok(track_offsets) => track_offsets, - Err(e) => return Ok(e.into()), + Err(e) => return Ok(Err(e.into())), }; // Check if tracks exist - if let Err(pathfinding_result) = - check_tracks_from_path_items(conn, infra.id, &track_offsets).await? - { - return Ok(pathfinding_result); + if let Err(err) = check_tracks_from_path_items(conn, infra.id, &track_offsets).await? { + return Ok(Err(err)); } - // 2) Compute path from core - let pathfinding_request = PathfindingRequest { + // Create the pathfinding request + Ok(Ok(PathfindingRequest { infra: infra.id, expected_version: infra.version.clone(), path_items: track_offsets, - rolling_stock_loading_gauge: path_input.rolling_stock_loading_gauge, - rolling_stock_is_thermal: path_input.rolling_stock_is_thermal, - rolling_stock_supported_electrifications: path_input + rolling_stock_loading_gauge: pathfinding_input.rolling_stock_loading_gauge, + rolling_stock_is_thermal: pathfinding_input.rolling_stock_is_thermal, + rolling_stock_supported_electrifications: pathfinding_input .rolling_stock_supported_electrifications .clone(), - rolling_stock_supported_signaling_systems: path_input + rolling_stock_supported_signaling_systems: pathfinding_input .rolling_stock_supported_signaling_systems .clone(), - }; - let pathfinding_result = pathfinding_request.fetch(core.as_ref()).await?; - - // 3) Put in cache - redis_conn - .json_set_ex(&hash, &pathfinding_result, CACHE_PATH_EXPIRATION) - .await?; - - Ok(pathfinding_result) + })) } -/// Compute a path given a batch of trainschedule and an infrastructure. -/// -/// ## Important -/// -/// If this function was called with the same train schedule, the result will be cached. -/// If you call this function multiple times with the same train schedule but with another infra, then you must provide a fresh `cache`. +/// Compute a path given a train schedule and an infrastructure. pub async fn pathfinding_from_train( conn: &mut DbConnection, redis: &mut RedisConnection, core: Arc, infra: &Infra, train_schedule: TrainSchedule, - proxy: Arc, ) -> Result { - if let Some(res) = proxy.get_pathfinding_result(train_schedule.id) { - return Ok(res); - } + let rolling_stocks = + RollingStockModel::retrieve(conn, train_schedule.rolling_stock_name.clone()) + .await? + .into_iter() + .map(|rs| (rs.name.clone(), rs)) + .collect(); - // Retrieve rolling stock - let rolling_stock_name = train_schedule.rolling_stock_name.clone(); - let Some(rolling_stock) = proxy - .get_rolling_stock(rolling_stock_name.clone(), conn) - .await? - else { - return Ok(PathfindingResult::RollingStockNotFound { rolling_stock_name }); - }; + Ok( + pathfinding_from_train_batch(conn, redis, core, infra, &[train_schedule], &rolling_stocks) + .await? + .pop() + .unwrap(), + ) +} - // Create the path input - let path_input = PathfindingInput { - rolling_stock_loading_gauge: rolling_stock.loading_gauge, - rolling_stock_is_thermal: rolling_stock.has_thermal_curves(), - rolling_stock_supported_electrifications: rolling_stock.supported_electrification(), - rolling_stock_supported_signaling_systems: rolling_stock.supported_signaling_systems.0, - path_items: train_schedule - .path - .into_iter() - .map(|item| item.location) - .collect(), - }; +/// Compute a path given a batch of trainschedule and an infrastructure. +pub async fn pathfinding_from_train_batch( + conn: &mut DbConnection, + redis: &mut RedisConnection, + core: Arc, + infra: &Infra, + train_schedules: &[TrainSchedule], + rolling_stocks: &HashMap, +) -> Result> { + let mut results = vec![PathfindingResult::NotEnoughPathItems; train_schedules.len()]; + let mut to_compute = vec![]; + let mut to_compute_index = vec![]; + for (index, train_schedule) in train_schedules.iter().enumerate() { + // Retrieve rolling stock + let rolling_stock_name = &train_schedule.rolling_stock_name; + let Some(rolling_stock) = rolling_stocks.get(rolling_stock_name).cloned() else { + let rolling_stock_name = rolling_stock_name.clone(); + results[index] = PathfindingResult::RollingStockNotFound { rolling_stock_name }; + continue; + }; - match pathfinding_blocks(conn, redis, core, infra, &path_input).await { - Ok(res) => { - proxy.set_pathfinding_result(train_schedule.id, res.clone()); - Ok(res) - } - err => err, + // Create the path input + let path_input = PathfindingInput { + rolling_stock_loading_gauge: rolling_stock.loading_gauge, + rolling_stock_is_thermal: rolling_stock.has_thermal_curves(), + rolling_stock_supported_electrifications: rolling_stock.supported_electrification(), + rolling_stock_supported_signaling_systems: rolling_stock.supported_signaling_systems.0, + path_items: train_schedule + .path + .clone() + .into_iter() + .map(|item| item.location) + .collect(), + }; + to_compute.push(path_input); + to_compute_index.push(index); + } + + for (index, res) in pathfinding_blocks_batch(conn, redis, core, infra, &to_compute) + .await? + .into_iter() + .enumerate() + { + results[to_compute_index[index]] = res; } + Ok(results) } /// Generates a unique hash based on the pathfinding entries. diff --git a/editoast/src/views/v2/path/properties.rs b/editoast/src/views/v2/path/properties.rs index a025c39bf5d..851b8a22fa3 100644 --- a/editoast/src/views/v2/path/properties.rs +++ b/editoast/src/views/v2/path/properties.rs @@ -21,7 +21,6 @@ use std::hash::Hasher; use tracing::info; use utoipa::ToSchema; -use super::CACHE_PATH_EXPIRATION; use crate::client::get_app_version; use crate::core::v2::path_properties::OperationalPointOnPath; use crate::core::v2::path_properties::PathPropertiesRequest; @@ -230,10 +229,7 @@ async fn retrieve_path_properties( let track_ranges = &path_properties_input.track_section_ranges; let hash = path_properties_input_hash(infra, infra_version, track_ranges); - let path_properties: PathProperties = redis_conn - .json_get_ex(&hash, CACHE_PATH_EXPIRATION) - .await? - .unwrap_or_default(); + let path_properties: PathProperties = redis_conn.json_get(&hash).await?.unwrap_or_default(); Ok(path_properties) } @@ -251,9 +247,7 @@ async fn cache_path_properties( let hash = path_properties_input_hash(infra, infra_version, track_ranges); // Cache all properties except electrifications - redis_conn - .json_set_ex(&hash, &path_properties, CACHE_PATH_EXPIRATION) - .await?; + redis_conn.json_set(&hash, &path_properties).await?; Ok(()) } diff --git a/editoast/src/views/v2/timetable.rs b/editoast/src/views/v2/timetable.rs index 83f05cc85e2..571fc85dbf4 100644 --- a/editoast/src/views/v2/timetable.rs +++ b/editoast/src/views/v2/timetable.rs @@ -2,7 +2,6 @@ pub mod stdcm; use std::collections::HashMap; use std::ops::DerefMut as _; -use std::sync::Arc; use actix_web::delete; use actix_web::get; @@ -35,13 +34,11 @@ use crate::modelsv2::timetable::TimetableWithTrains; use crate::modelsv2::train_schedule::TrainSchedule; use crate::modelsv2::train_schedule::TrainScheduleChangeset; use crate::modelsv2::Infra; -use crate::modelsv2::RollingStockModel; use crate::views::pagination::PaginatedList; use crate::views::pagination::PaginationQueryParam; use crate::views::pagination::PaginationStats; use crate::views::v2::train_schedule::train_simulation_batch; use crate::views::v2::train_schedule::TrainScheduleForm; -use crate::views::v2::train_schedule::TrainScheduleProxy; use crate::views::v2::train_schedule::TrainScheduleResult; use crate::CoreClient; use crate::RedisClient; @@ -329,50 +326,41 @@ pub async fn conflicts( query: Query, ) -> Result>> { let db_pool = db_pool.into_inner(); - let conn = &mut db_pool.clone().get().await?; let redis_client = redis_client.into_inner(); let core_client = core_client.into_inner(); let timetable_id = timetable_id.into_inner().id; let infra_id = query.into_inner().infra_id; // 1. Retrieve Timetable / Infra / Trains / Simultion - let timetable_trains = TimetableWithTrains::retrieve_or_fail(conn, timetable_id, || { - TimetableError::NotFound { timetable_id } - }) + let timetable_trains = TimetableWithTrains::retrieve_or_fail( + db_pool.get().await?.deref_mut(), + timetable_id, + || TimetableError::NotFound { timetable_id }, + ) .await?; - let timetable: Timetable = timetable_trains.clone().into(); - let infra = Infra::retrieve_or_fail(conn, infra_id, || TimetableError::InfraNotFound { - infra_id, + let infra = Infra::retrieve_or_fail(db_pool.get().await?.deref_mut(), infra_id, || { + TimetableError::InfraNotFound { infra_id } }) .await?; let (trains, _): (Vec<_>, _) = - TrainSchedule::retrieve_batch(conn, timetable_trains.train_ids).await?; - - let (rolling_stocks, _): (Vec<_>, _) = RollingStockModel::retrieve_batch( - db_pool.get().await?.deref_mut(), - trains - .iter() - .map::(|t| t.rolling_stock_name.clone()), - ) - .await?; - - let proxy = Arc::new(TrainScheduleProxy::new(&rolling_stocks, &[timetable])); + TrainSchedule::retrieve_batch(db_pool.get().await?.deref_mut(), timetable_trains.train_ids) + .await?; let simulations = train_simulation_batch( - db_pool.clone(), + db_pool.get().await?.deref_mut(), redis_client.clone(), core_client.clone(), &trains, &infra, - proxy, ) .await?; // 2. Build core request let mut trains_requirements = HashMap::with_capacity(trains.len()); for (train, sim) in trains.into_iter().zip(simulations) { + let (sim, _) = sim; let final_output = match sim { SimulationResponse::Success { final_output, .. } => final_output, _ => continue, diff --git a/editoast/src/views/v2/timetable/stdcm.rs b/editoast/src/views/v2/timetable/stdcm.rs index d9439c17bc0..e3081530d1e 100644 --- a/editoast/src/views/v2/timetable/stdcm.rs +++ b/editoast/src/views/v2/timetable/stdcm.rs @@ -27,7 +27,6 @@ use crate::core::v2::stdcm::{STDCMRequest, STDCMStepTimingData}; use crate::core::AsCoreRequest; use crate::core::CoreClient; use crate::error::Result; -use crate::modelsv2::timetable::Timetable; use crate::modelsv2::timetable::TimetableWithTrains; use crate::modelsv2::train_schedule::TrainSchedule; use crate::modelsv2::work_schedules::WorkSchedule; @@ -35,8 +34,8 @@ use crate::modelsv2::RollingStockModel; use crate::modelsv2::{Infra, List}; use crate::views::v2::path::pathfinding::extract_location_from_path_items; use crate::views::v2::path::pathfinding::TrackOffsetExtractionError; -use crate::views::v2::train_schedule::TrainScheduleProxy; -use crate::views::v2::train_schedule::{train_simulation, train_simulation_batch}; +use crate::views::v2::train_schedule::train_simulation; +use crate::views::v2::train_schedule::train_simulation_batch; use crate::RedisClient; use crate::Retrieve; use crate::RetrieveBatch; @@ -157,7 +156,6 @@ async fn stdcm( query: Query, data: Json, ) -> Result> { - let conn = &mut db_pool.clone().get().await?; let db_pool = db_pool.into_inner(); let core_client = core_client.into_inner(); let timetable_id = id.into_inner(); @@ -166,48 +164,44 @@ async fn stdcm( let redis_client_inner = redis_client.into_inner(); // 1. Retrieve Timetable / Infra / Trains / Simulation / Rolling Stock - let timetable_trains = TimetableWithTrains::retrieve_or_fail(conn, timetable_id, || { - STDCMError::TimetableNotFound { timetable_id } - }) + let timetable_trains = TimetableWithTrains::retrieve_or_fail( + db_pool.get().await?.deref_mut(), + timetable_id, + || STDCMError::TimetableNotFound { timetable_id }, + ) .await?; - let timetable: Timetable = timetable_trains.clone().into(); - let infra = - Infra::retrieve_or_fail(conn, infra_id, || STDCMError::InfraNotFound { infra_id }).await?; + let infra = Infra::retrieve_or_fail(db_pool.get().await?.deref_mut(), infra_id, || { + STDCMError::InfraNotFound { infra_id } + }) + .await?; let (trains, _): (Vec<_>, _) = - TrainSchedule::retrieve_batch(conn, timetable_trains.train_ids).await?; + TrainSchedule::retrieve_batch(db_pool.get().await?.deref_mut(), timetable_trains.train_ids) + .await?; - let (rolling_stocks, _): (Vec<_>, _) = RollingStockModel::retrieve_batch( + let rolling_stock = RollingStockModel::retrieve_or_fail( db_pool.get().await?.deref_mut(), - trains - .iter() - .map::(|t| t.rolling_stock_name.clone()), + data.rolling_stock_id, + || STDCMError::RollingStockNotFound { + rolling_stock_id: data.rolling_stock_id, + }, ) .await?; - let proxy = Arc::new(TrainScheduleProxy::new(&rolling_stocks, &[timetable])); - let simulations = train_simulation_batch( - db_pool.clone(), + db_pool.get().await?.deref_mut(), redis_client_inner.clone(), core_client.clone(), &trains, &infra, - proxy, ) .await?; - let rolling_stock = RollingStockModel::retrieve_or_fail(conn, data.rolling_stock_id, || { - STDCMError::RollingStockNotFound { - rolling_stock_id: data.rolling_stock_id, - } - }) - .await?; - // 2. Build core request let mut trains_requirements = HashMap::new(); for (train, sim) in trains.iter().zip(simulations) { + let (sim, _) = sim; let final_output = match sim { SimulationResponse::Success { final_output, .. } => final_output, _ => continue, @@ -244,7 +238,7 @@ async fn stdcm( let departure_time = get_earliest_departure_time(&data, maximum_run_time); // 3. Parse stdcm path items - let path_items = parse_stdcm_steps(conn, &data, &infra).await?; + let path_items = parse_stdcm_steps(db_pool.get().await?.deref_mut(), &data, &infra).await?; // 4. Build STDCM request let stdcm_response = STDCMRequest { @@ -267,7 +261,7 @@ async fn stdcm( margin: data.margin, time_step: Some(2000), work_schedules: build_work_schedules( - conn, + db_pool.get().await?.deref_mut(), departure_time, data.maximum_departure_delay, maximum_run_time, @@ -346,14 +340,12 @@ async fn get_maximum_run_time( options: Default::default(), }; - let conn = &mut db_pool.clone().get().await?; - let sim_result = train_simulation( - conn, + let (sim_result, _) = train_simulation( + db_pool.get().await?.deref_mut(), redis_client, core_client, - &train_schedule, + train_schedule, infra, - Arc::default(), ) .await?; diff --git a/editoast/src/views/v2/train_schedule.rs b/editoast/src/views/v2/train_schedule.rs index 310fe2eb4ed..ddc9b47d34b 100644 --- a/editoast/src/views/v2/train_schedule.rs +++ b/editoast/src/views/v2/train_schedule.rs @@ -1,5 +1,4 @@ mod projection; -mod proxy; use std::collections::hash_map::DefaultHasher; use std::collections::HashMap; @@ -42,7 +41,8 @@ use crate::modelsv2::prelude::*; use crate::modelsv2::timetable::Timetable; use crate::modelsv2::train_schedule::TrainSchedule; use crate::modelsv2::train_schedule::TrainScheduleChangeset; -use crate::views::v2::path::pathfinding_from_train; +use crate::views::v2::path::pathfinding::pathfinding_from_train; +use crate::views::v2::path::pathfinding_from_train_batch; use crate::views::v2::path::PathfindingError; use crate::RedisClient; use crate::RollingStockModel; @@ -50,10 +50,6 @@ use editoast_models::DbConnection; use editoast_models::DbConnectionPool; use editoast_models::DbConnectionPoolV2; -pub use proxy::TrainScheduleProxy; - -const CACHE_SIMULATION_EXPIRATION: u64 = 604800; // 1 week - crate::routes! { "/v2/train_schedule" => { delete, @@ -309,116 +305,169 @@ pub async fn simulation( .await?; Ok(Json( - train_simulation( - conn, - redis_client, - core_client, - &train_schedule, - &infra, - Arc::default(), - ) - .await?, + train_simulation(conn, redis_client, core_client, train_schedule, &infra) + .await? + .0, )) } -/// Compute the simulation of a given train schedule +/// Compute simulation of a train schedule pub async fn train_simulation( conn: &mut DbConnection, redis_client: Arc, core: Arc, - train_schedule: &TrainSchedule, + train_schedule: TrainSchedule, infra: &Infra, - proxy: Arc, -) -> Result { +) -> Result<(SimulationResponse, PathfindingResult)> { + Ok( + train_simulation_batch(conn, redis_client, core, &[train_schedule], infra) + .await? + .pop() + .unwrap(), + ) +} + +/// Compute in batch the simulation of a list of train schedule +/// +/// Note: The order of the returned simulations is the same as the order of the train schedules. +pub async fn train_simulation_batch( + conn: &mut DbConnection, + redis_client: Arc, + core: Arc, + train_schedules: &[TrainSchedule], + infra: &Infra, +) -> Result> { let mut redis_conn = redis_client.get_connection().await?; // Compute path - let pathfinding_result = pathfinding_from_train( + let (rolling_stocks, _): (Vec<_>, _) = RollingStockModel::retrieve_batch( + conn, + train_schedules + .iter() + .map::(|t| t.rolling_stock_name.clone()), + ) + .await?; + let rolling_stocks: HashMap<_, _> = rolling_stocks + .into_iter() + .map(|rs| (rs.name.clone(), rs)) + .collect(); + let (timetables, _): (Vec<_>, _) = Timetable::retrieve_batch( + conn, + train_schedules + .iter() + .map(|t| t.timetable_id) + .collect::>(), + ) + .await?; + let timetables: HashMap<_, _> = timetables + .into_iter() + .map(|timetable| (timetable.id, timetable)) + .collect(); + let pathfinding_results = pathfinding_from_train_batch( conn, &mut redis_conn, core.clone(), infra, - train_schedule.clone(), - proxy.clone(), + train_schedules, + &rolling_stocks, ) .await?; - let (path, path_items_positions) = match pathfinding_result { - PathfindingResult::Success(PathfindingResultSuccess { - blocks, - routes, - track_section_ranges, - path_items_positions, - .. - }) => ( - SimulationPath { + let mut simulation_results = vec![SimulationResponse::default(); train_schedules.len()]; + let mut to_sim = Vec::with_capacity(train_schedules.len()); + for (index, (pathfinding, train_schedule)) in + pathfinding_results.iter().zip(train_schedules).enumerate() + { + let (path, path_items_positions) = match pathfinding { + PathfindingResult::Success(PathfindingResultSuccess { blocks, routes, track_section_ranges, - }, - path_items_positions, - ), - _ => { - return Ok(SimulationResponse::PathfindingFailed { pathfinding_result }); - } - }; + path_items_positions, + .. + }) => ( + SimulationPath { + blocks: blocks.clone(), + routes: routes.clone(), + track_section_ranges: track_section_ranges.clone(), + }, + path_items_positions, + ), + _ => { + simulation_results[index] = SimulationResponse::PathfindingFailed { + pathfinding_result: pathfinding.clone(), + }; + continue; + } + }; - // Build simulation request - let simulation_request = build_simulation_request( - conn, - infra, - train_schedule, - &path_items_positions, - path, - proxy, - ) - .await?; + // Build simulation request + let rolling_stock = rolling_stocks[&train_schedule.rolling_stock_name].clone(); + let timetable = timetables[&train_schedule.timetable_id].clone(); + let simulation_request = build_simulation_request( + infra, + train_schedule, + path_items_positions, + path, + rolling_stock, + timetable, + ); - // Compute unique hash of the simulation input - let hash = train_simulation_input_hash(infra.id, &infra.version, &simulation_request); + // Compute unique hash of the simulation input + let simulation_hash = + train_simulation_input_hash(infra.id, &infra.version, &simulation_request); + to_sim.push((index, simulation_hash, simulation_request)); + } - let result: Option = redis_conn - .json_get_ex(&hash, CACHE_SIMULATION_EXPIRATION) + let cached_results: Vec> = redis_conn + .json_get_bulk(&to_sim.iter().map(|(_, hash, _)| hash).collect::>()) .await?; - if let Some(simulation_result) = result { - info!("Simulation hit cache"); - return Ok(simulation_result); - } + + let nb_hit = cached_results.iter().flatten().count(); + let nb_miss = to_sim.len() - nb_hit; + info!(nb_hit, nb_miss, "Hit cache"); // Compute simulation from core - let result = simulation_request.fetch(core.as_ref()).await?; + let mut futures = Vec::with_capacity(nb_miss); + let mut futures_index_hash = Vec::with_capacity(nb_miss); + for ((train_index, train_hash, sim_request), sim_cached) in to_sim.iter().zip(cached_results) { + if let Some(sim_cached) = sim_cached { + simulation_results[*train_index] = sim_cached; + continue; + } + futures.push(Box::pin(sim_request.fetch(core.as_ref()))); + futures_index_hash.push((*train_index, train_hash)); + } + + let simulated: Vec<_> = futures::future::join_all(futures) + .await + .into_iter() + .collect::>()?; + + let mut to_cache = Vec::with_capacity(simulated.len()); + for ((train_index, train_hash), sim_res) in futures_index_hash.into_iter().zip(simulated) { + simulation_results[train_index] = sim_res.clone(); + to_cache.push((train_hash, sim_res)); + } // Cache the simulation response - redis_conn - .json_set_ex(&hash, &result, CACHE_SIMULATION_EXPIRATION) - .await?; + redis_conn.json_set_bulk(&to_cache).await?; // Return the response - Ok(result) + Ok(simulation_results + .into_iter() + .zip(pathfinding_results) + .collect()) } -async fn build_simulation_request( - conn: &mut DbConnection, +fn build_simulation_request( infa: &Infra, train_schedule: &TrainSchedule, path_items_position: &[u64], path: SimulationPath, - proxy: Arc, -) -> Result { - // Get rolling stock - let rolling_stock_name = train_schedule.rolling_stock_name.clone(); - let rolling_stock = proxy - .get_rolling_stock(rolling_stock_name, conn) - .await? - .expect("Rolling stock should exist since the pathfinding succeeded"); - // Get electrical_profile_set_id - let timetable_id = train_schedule.timetable_id; - let timetable = proxy - .get_timetable(timetable_id, conn) - .await? - .expect("Timetable should exist since it's a foreign key"); - + rolling_stock: RollingStockModel, + timetable: Timetable, +) -> SimulationRequest { assert_eq!(path_items_position.len(), train_schedule.path.len()); - // Project path items to path offset let path_items_to_position: HashMap<_, _> = train_schedule .path @@ -464,7 +513,7 @@ async fn build_simulation_request( }) .collect(); - Ok(SimulationRequest { + SimulationRequest { infra: infa.id, expected_version: infa.version.clone(), path, @@ -478,7 +527,7 @@ async fn build_simulation_request( options: train_schedule.options.clone(), rolling_stock: rolling_stock.into(), electrical_profile_set_id: timetable.electrical_profile_set_id, - }) + } } // Compute hash input of a simulation @@ -560,37 +609,20 @@ pub async fn simulation_summary( }, ) .await?; - let (rolling_stocks, _): (Vec<_>, _) = RollingStockModel::retrieve_batch( - db_pool.get().await?.deref_mut(), - trains - .iter() - .map::(|t| t.rolling_stock_name.clone()), - ) - .await?; - let (timetables, _): (Vec<_>, _) = Timetable::retrieve_batch( - db_pool.get().await?.deref_mut(), - trains - .iter() - .map(|t| t.timetable_id) - .collect::>(), - ) - .await?; - - let proxy = Arc::new(TrainScheduleProxy::new(&rolling_stocks, &timetables)); let simulations = train_simulation_batch( - db_pool.clone(), + db_pool.get().await?.deref_mut(), redis_client, core, &trains, &infra, - proxy.clone(), ) .await?; // Transform simulations to simulation summary let mut simulation_summaries = HashMap::new(); for (train, sim) in trains.iter().zip(simulations) { + let (sim, _) = sim; let simulation_summary_result = match sim { SimulationResponse::Success { final_output, .. } => { let report = final_output.report_train; @@ -626,45 +658,6 @@ pub async fn simulation_summary( Ok(Json(simulation_summaries)) } -/// Compute train simulation in batch given a list of train schedules. -/// -/// Note: The order of the returned simulations is the same as the order of the train schedules. -/// -/// Returns an error if any of the train ids are not found. -pub async fn train_simulation_batch( - db_pool: Arc, - redis_client: Arc, - core_client: Arc, - train_schedules: &[TrainSchedule], - infra: &Infra, - proxy: Arc, -) -> Result> { - let pending_simulations = - train_schedules - .iter() - .zip(db_pool.iter_conn()) - .map(|(train_schedule, conn)| { - let redis_client = redis_client.clone(); - let core_client = core_client.clone(); - let cache = proxy.clone(); - async move { - train_simulation( - conn.await - .expect("Failed to get database connection") - .deref_mut(), - redis_client, - core_client, - train_schedule, - infra, - cache, - ) - .await - } - }); - - futures::future::try_join_all(pending_simulations).await -} - #[derive(Debug, Default, Clone, Serialize, Deserialize, IntoParams, ToSchema)] pub struct InfraIdQueryParam { infra_id: i64, @@ -704,15 +697,7 @@ async fn get_path( }) .await?; Ok(Json( - pathfinding_from_train( - conn, - &mut redis_conn, - core, - &infra, - train_schedule, - Arc::default(), - ) - .await?, + pathfinding_from_train(conn, &mut redis_conn, core, &infra, train_schedule).await?, )) } diff --git a/editoast/src/views/v2/train_schedule/projection.rs b/editoast/src/views/v2/train_schedule/projection.rs index 4a4f99d240d..81f49b45871 100644 --- a/editoast/src/views/v2/train_schedule/projection.rs +++ b/editoast/src/views/v2/train_schedule/projection.rs @@ -30,18 +30,15 @@ use crate::core::AsCoreRequest; use crate::core::CoreClient; use crate::error::Result; use crate::modelsv2::infra::Infra; -use crate::modelsv2::timetable::Timetable; use crate::modelsv2::train_schedule::TrainSchedule; use crate::modelsv2::Retrieve; use crate::modelsv2::RetrieveBatch; -use crate::views::v2::path::pathfinding_from_train; use crate::views::v2::path::projection::PathProjection; use crate::views::v2::path::projection::TrackLocationFromPath; use crate::views::v2::train_schedule::train_simulation_batch; use crate::views::v2::train_schedule::CompleteReportTrain; use crate::views::v2::train_schedule::ReportTrain; use crate::views::v2::train_schedule::SignalSighting; -use crate::views::v2::train_schedule::TrainScheduleProxy; use crate::views::v2::train_schedule::ZoneUpdate; use editoast_models::DbConnectionPoolV2; @@ -170,24 +167,12 @@ async fn project_path( ) .await?; - let (timetables, _): (Vec<_>, _) = Timetable::retrieve_batch( - db_pool.get().await?.deref_mut(), - trains - .iter() - .map(|t| t.timetable_id) - .collect::>(), - ) - .await?; - - let proxy = Arc::new(TrainScheduleProxy::new(&rolling_stocks, &timetables)); - let simulations = train_simulation_batch( - db_pool.clone(), + db_pool.get().await?.deref_mut(), redis_client.clone(), core.clone(), &trains, &infra, - proxy.clone(), ) .await?; @@ -195,17 +180,7 @@ async fn project_path( let mut trains_hash_values = HashMap::new(); let mut trains_details = HashMap::new(); - for (train, sim) in trains.iter().zip(simulations) { - let pathfinding_result = pathfinding_from_train( - db_pool.get().await?.deref_mut(), - &mut redis_conn, - core.clone(), - &infra, - train.clone(), - proxy.clone(), - ) - .await?; - + for (train, (sim, pathfinding_result)) in trains.iter().zip(simulations) { let track_ranges = match pathfinding_result { PathfindingResult::Success(PathfindingResultSuccess { track_section_ranges, diff --git a/editoast/src/views/v2/train_schedule/proxy.rs b/editoast/src/views/v2/train_schedule/proxy.rs deleted file mode 100644 index 6c28535919c..00000000000 --- a/editoast/src/views/v2/train_schedule/proxy.rs +++ /dev/null @@ -1,90 +0,0 @@ -use chashmap::CHashMap; -use editoast_models::DbConnection; - -use crate::core::v2::pathfinding::PathfindingResult; -use crate::error::Result; -use crate::modelsv2::timetable::Timetable; -use crate::{Retrieve, RollingStockModel}; - -/// Used to cache postgres and redis queries while simulating train schedules -#[derive(Debug, Default)] -pub struct TrainScheduleProxy { - /// Map train schedule id with their respective timetable - timetables: CHashMap, - /// Map rolling stock name with their respective rolling stock - rolling_stocks: CHashMap, - /// Map train schedule id with their computed path - pathfinding_results: CHashMap, -} - -impl TrainScheduleProxy { - /// Initialize the cache with a list of rolling stocks - pub fn new(rolling_stocks: &[RollingStockModel], timetables: &[Timetable]) -> Self { - Self { - rolling_stocks: rolling_stocks - .iter() - .map(|rs| (rs.name.clone(), rs.clone())) - .collect(), - timetables: timetables - .iter() - .map(|timetable| (timetable.id, timetable.clone())) - .collect(), - ..Default::default() - } - } - - /// Returns the cached value given a timetable ID. - /// If the value is not cached, it will retrieve it and cache it. - pub async fn get_timetable( - &self, - id: i64, - conn: &mut DbConnection, - ) -> Result> { - if let Some(timetable) = self.timetables.get(&id) { - return Ok(Some(timetable.clone())); - } - - let Some(timetable) = Timetable::retrieve(conn, id).await? else { - return Ok(None); - }; - - if self.timetables.get_mut(&id).is_none() { - self.timetables.insert_new(id, timetable.clone()); - } - Ok(Some(timetable)) - } - - /// Returns the cached value given a rolling stock name. - /// If the value is not cached, it will retrieve it and cache it. - pub async fn get_rolling_stock( - &self, - name: String, - conn: &mut DbConnection, - ) -> Result> { - if let Some(rs) = self.rolling_stocks.get(&name) { - return Ok(Some(rs.clone())); - } - - let Some(rs) = RollingStockModel::retrieve(conn, name).await? else { - return Ok(None); - }; - - if self.rolling_stocks.get(&rs.name).is_none() { - self.rolling_stocks.insert(rs.name.clone(), rs.clone()); - } - Ok(Some(rs)) - } - - /// Returns the cached value given a train schedule ID. - pub fn get_pathfinding_result(&self, id: i64) -> Option { - self.pathfinding_results.get(&id).map(|r| r.clone()) - } - - /// Caches a pathfinding result given a train schedule ID. - /// If the value is already cached, it won't be updated. - pub fn set_pathfinding_result(&self, id: i64, result: PathfindingResult) { - if self.pathfinding_results.get(&id).is_none() { - self.pathfinding_results.insert(id, result); - } - } -} From eb59e0764212da362255c6aca8fb71aa92b889d0 Mon Sep 17 00:00:00 2001 From: Jean SIMARD Date: Thu, 11 Jul 2024 12:03:28 +0200 Subject: [PATCH 2/2] editoast: rename typo in a variable --- editoast/src/views/v2/train_schedule.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/editoast/src/views/v2/train_schedule.rs b/editoast/src/views/v2/train_schedule.rs index ddc9b47d34b..d79ecac02f1 100644 --- a/editoast/src/views/v2/train_schedule.rs +++ b/editoast/src/views/v2/train_schedule.rs @@ -460,7 +460,7 @@ pub async fn train_simulation_batch( } fn build_simulation_request( - infa: &Infra, + infra: &Infra, train_schedule: &TrainSchedule, path_items_position: &[u64], path: SimulationPath, @@ -514,8 +514,8 @@ fn build_simulation_request( .collect(); SimulationRequest { - infra: infa.id, - expected_version: infa.version.clone(), + infra: infra.id, + expected_version: infra.version.clone(), path, schedule, margins,