Skip to content

Commit

Permalink
editoast: add roles check to rolling stocks
Browse files Browse the repository at this point in the history
  • Loading branch information
Wadjetz committed Aug 29, 2024
1 parent 80973c5 commit 811a2ec
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 2 deletions.
3 changes: 3 additions & 0 deletions editoast/editoast_authz/src/builtin_role.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub enum BuiltinRole {

#[strum(serialize = "rolling_stock_collection:read")]
RollingStockCollectionRead,
#[strum(serialize = "rolling_stock_collection:write")]
RollingStockCollectionWrite,

#[strum(serialize = "work_schedule:write")]
WorkScheduleWrite,
Expand All @@ -42,6 +44,7 @@ impl BuiltinRoleSet for BuiltinRole {
InfraRead => vec![],
InfraWrite => vec![InfraRead],
RollingStockCollectionRead => vec![],
RollingStockCollectionWrite => vec![RollingStockCollectionRead],
WorkScheduleWrite => vec![],
MapRead => vec![],
Stdcm => vec![MapRead],
Expand Down
29 changes: 29 additions & 0 deletions editoast/src/views/light_rolling_stocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use axum::extract::Json;
use axum::extract::Path;
use axum::extract::Query;
use axum::extract::State;
use axum::Extension;
use editoast_authz::BuiltinRole;
use editoast_models::DbConnection;
use editoast_models::DbConnectionPoolV2;
use editoast_schemas::rolling_stock::RollingStockLivery;
Expand All @@ -28,6 +30,9 @@ use crate::SelectionSettings;
#[cfg(test)]
use serde::Deserialize;

use super::AuthorizationError;
use super::AuthorizerExt;

crate::routes! {
"/light_rolling_stock" => {
list,
Expand Down Expand Up @@ -91,8 +96,16 @@ struct LightRollingStockWithLiveriesCountList {
)]
async fn list(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
Query(page_settings): Query<PaginationQueryParam>,
) -> Result<Json<LightRollingStockWithLiveriesCountList>> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionRead].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
let settings = page_settings
.validate(1000)?
.warn_page_size(100)
Expand Down Expand Up @@ -128,8 +141,16 @@ async fn list(
)]
async fn get(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
Path(light_rolling_stock_id): Path<i64>,
) -> Result<Json<LightRollingStockWithLiveries>> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionRead].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
let light_rolling_stock = LightRollingStockModel::retrieve_or_fail(
db_pool.get().await?.deref_mut(),
light_rolling_stock_id,
Expand Down Expand Up @@ -157,8 +178,16 @@ async fn get(
)]
async fn get_by_name(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
Path(light_rolling_stock_name): Path<String>,
) -> Result<Json<LightRollingStockWithLiveries>> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionRead].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
let light_rolling_stock = LightRollingStockModel::retrieve_or_fail(
db_pool.get().await?.deref_mut(),
light_rolling_stock_name.clone(),
Expand Down
74 changes: 72 additions & 2 deletions editoast/src/views/rolling_stocks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ use axum::extract::Query;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::Extension;
use editoast_authz::BuiltinRole;
use editoast_derive::EditoastError;
use editoast_models::DbConnection;
use editoast_models::DbConnectionPoolV2;
use editoast_schemas::rolling_stock::RollingStockLivery;
use image::DynamicImage;
use image::GenericImage;
Expand All @@ -35,8 +39,8 @@ use crate::modelsv2::rolling_stock_model::TrainScheduleScenarioStudyProject;
use crate::modelsv2::Document;
use crate::modelsv2::RollingStockModel;
use crate::modelsv2::RollingStockSeparatedImageModel;
use editoast_models::DbConnection;
use editoast_models::DbConnectionPoolV2;
use crate::views::AuthorizationError;
use crate::views::AuthorizerExt;

crate::routes! {
"/rolling_stock" => {
Expand Down Expand Up @@ -201,8 +205,16 @@ pub struct RollingStockNameParam {
)]
async fn get(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
Path(rolling_stock_id): Path<i64>,
) -> Result<Json<RollingStockWithLiveries>> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionRead].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
let rolling_stock = retrieve_existing_rolling_stock(
db_pool.get().await?.deref_mut(),
RollingStockKey::Id(rolling_stock_id),
Expand All @@ -225,8 +237,17 @@ async fn get(
)]
async fn get_by_name(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
Path(rolling_stock_name): Path<String>,
) -> Result<Json<RollingStockWithLiveries>> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionRead].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}

let rolling_stock = retrieve_existing_rolling_stock(
db_pool.get().await?.deref_mut(),
RollingStockKey::Name(rolling_stock_name),
Expand All @@ -248,7 +269,15 @@ async fn get_by_name(
)]
async fn get_power_restrictions(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
) -> Result<Json<Vec<String>>> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionRead].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
let conn = &mut db_pool.get().await?;
let power_restrictions = RollingStockModel::get_power_restrictions(conn).await?;
Ok(Json(
Expand Down Expand Up @@ -278,9 +307,17 @@ struct PostRollingStockQueryParams {
)]
async fn create(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
Query(query_params): Query<PostRollingStockQueryParams>,
Json(rolling_stock_form): Json<RollingStockForm>,
) -> Result<Json<RollingStockModel>> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionWrite].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
rolling_stock_form.validate()?;
let conn = &mut db_pool.get().await?;
let rolling_stock_name = rolling_stock_form.name.clone();
Expand Down Expand Up @@ -308,9 +345,17 @@ async fn create(
)]
async fn update(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
Path(rolling_stock_id): Path<i64>,
Json(rolling_stock_form): Json<RollingStockForm>,
) -> Result<Json<RollingStockWithLiveries>> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionWrite].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
rolling_stock_form.validate()?;
let name = rolling_stock_form.name.clone();

Expand Down Expand Up @@ -369,9 +414,17 @@ struct DeleteRollingStockQueryParams {
)]
async fn delete(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
Path(rolling_stock_id): Path<i64>,
Query(DeleteRollingStockQueryParams { force }): Query<DeleteRollingStockQueryParams>,
) -> Result<impl IntoResponse> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionWrite].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
let conn = &mut db_pool.get().await?;
assert_rolling_stock_unlocked(
&retrieve_existing_rolling_stock(conn, RollingStockKey::Id(rolling_stock_id)).await?,
Expand Down Expand Up @@ -423,9 +476,18 @@ struct RollingStockLockedUpdateForm {
)]
async fn update_locked(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
Path(rolling_stock_id): Path<i64>,
Json(RollingStockLockedUpdateForm { locked }): Json<RollingStockLockedUpdateForm>,
) -> Result<impl IntoResponse> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionWrite].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}

let conn = &mut db_pool.get().await?;

// FIXME: check that the rolling stock exists (the Option<RollingSrtockModel> is ignored here)
Expand Down Expand Up @@ -509,9 +571,17 @@ async fn parse_multipart_content(
)]
async fn create_livery(
State(db_pool): State<DbConnectionPoolV2>,
Extension(authorizer): AuthorizerExt,
Path(rolling_stock_id): Path<i64>,
form: Multipart,
) -> Result<Json<RollingStockLivery>> {
let authorized = authorizer
.check_roles([BuiltinRole::RollingStockCollectionWrite].into())
.await
.map_err(AuthorizationError::AuthError)?;
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
let conn = &mut db_pool.get().await?;

let (name, images) = parse_multipart_content(form)
Expand Down

0 comments on commit 811a2ec

Please sign in to comment.