Skip to content

Commit

Permalink
editoast: refactor connection pool for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Wadjetz committed May 14, 2024
1 parent af41523 commit e387a6c
Show file tree
Hide file tree
Showing 8 changed files with 238 additions and 62 deletions.
6 changes: 2 additions & 4 deletions editoast/src/fixtures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pub mod tests {
use std::io::Cursor;
use std::ops::{Deref, DerefMut};

use crate::modelsv2::connection_pool::DbConnectionConfig;
use crate::modelsv2::connection_pool::create_connection_pool_for_tests;
use crate::modelsv2::connection_pool::DbConnectionPool;
use crate::modelsv2::DbConnection;
use crate::{
Expand Down Expand Up @@ -143,9 +143,7 @@ pub mod tests {
let pg_config_url = PostgresConfig::default()
.url()
.expect("cannot get postgres config url");
let config = DbConnectionConfig::new(pg_config_url);
let pool = DbConnectionPool::builder(config).build().unwrap();
Data::new(pool)
Data::new(create_connection_pool_for_tests(pg_config_url))
}

pub fn get_fast_rolling_stock_form(name: &str) -> RollingStockForm {
Expand Down
12 changes: 9 additions & 3 deletions editoast/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod views;
use crate::core::CoreClient;
use crate::error::InternalError;
use crate::modelsv2::DbConnectionPool;
use crate::modelsv2::DbConnectionPoolV2;
use crate::modelsv2::Infra;
use crate::views::OpenApiRoot;
use actix_cors::Cors;
Expand All @@ -28,6 +29,7 @@ use actix_web::web::{scope, Data, JsonConfig, PayloadConfig};
use actix_web::{App, HttpServer};
use chashmap::CHashMap;
use clap::Parser;
use client::PostgresConfig;
use client::{
ClearArgs, Client, Color, Commands, DeleteProfileSetArgs, ElectricalProfilesCommands,
ExportTimetableArgs, GenerateArgs, ImportProfileSetArgs, ImportRailjsonArgs,
Expand Down Expand Up @@ -192,7 +194,7 @@ async fn run() -> Result<(), Box<dyn Error + Send + Sync>> {
}

match client.command {
Commands::Runserver(args) => runserver(args, create_db_pool()?, redis_config).await,
Commands::Runserver(args) => runserver(args, pg_config, redis_config).await,
Commands::ImportRollingStock(args) => import_rolling_stock(args, create_db_pool()?).await,
Commands::OsmToRailjson(args) => {
osm_to_railjson::osm_to_railjson(args.osm_pbf_in, args.railjson_out)
Expand Down Expand Up @@ -365,13 +367,16 @@ fn log_received_request(req: &ServiceRequest) {
/// Create and run the server
async fn runserver(
args: RunserverArgs,
db_pool: Data<DbConnectionPool>,
postgres_config: PostgresConfig,
redis_config: RedisConfig,
) -> Result<(), Box<dyn Error + Send + Sync>> {
info!("Building server...");
// Config database
let redis = RedisClient::new(redis_config)?;

// Create database pool
let db_pool = create_connection_pool(postgres_config.url()?, postgres_config.pool_size);

// Custom Json extractor configuration
let json_cfg = JsonConfig::default()
.limit(250 * 1024 * 1024) // 250MiB
Expand Down Expand Up @@ -426,7 +431,8 @@ async fn runserver(
.wrap(Logger::new(actix_logger_format).log_target("actix_logger"))
.app_data(json_cfg.clone())
.app_data(payload_config.clone())
.app_data(db_pool.clone())
.app_data(Data::new(db_pool.clone()))
.app_data(Data::new(DbConnectionPoolV2::create(db_pool.clone())))
.app_data(Data::new(redis.clone()))
.app_data(infra_caches.clone())
.app_data(Data::new(MapLayers::parse()))
Expand Down
17 changes: 16 additions & 1 deletion editoast/src/modelsv2/connection_pool.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use diesel::{ConnectionError, ConnectionResult};
use diesel::ConnectionError;
use diesel::ConnectionResult;
use diesel_async::pooled_connection::deadpool::Pool;
use diesel_async::pooled_connection::AsyncDieselConnectionManager;
use diesel_async::pooled_connection::ManagerConfig;
Expand All @@ -9,6 +10,11 @@ use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
use tracing::error;
use url::Url;

mod db_connection_error;
pub use db_connection_error::DbConnectionError;
mod db_connection_pool;
pub use db_connection_pool::DbConnectionPoolV2;

pub type DbConnection = AsyncPgConnection;
pub type DbConnectionPool = Pool<DbConnection>;
pub type DbConnectionConfig = AsyncDieselConnectionManager<AsyncPgConnection>;
Expand All @@ -23,6 +29,15 @@ pub fn create_connection_pool(url: Url, max_size: usize) -> DbConnectionPool {
.expect("Failed to create pool.")
}

#[cfg(test)]
pub fn create_connection_pool_for_tests(url: Url) -> DbConnectionPool {
let config = DbConnectionConfig::new(url);
let pool = DbConnectionPool::builder(config)
.build()
.expect("Failed to create pool for tests.");
pool
}

fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<DbConnection>> {
let fut = async {
let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
Expand Down
26 changes: 26 additions & 0 deletions editoast/src/modelsv2/connection_pool/db_connection_error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use reqwest::StatusCode;
use thiserror::Error;

use crate::error::EditoastError;

#[derive(Debug, Error)]
pub enum DbConnectionError {
#[error(transparent)]
DeadpoolPool(diesel_async::pooled_connection::deadpool::PoolError),
#[error(transparent)]
#[allow(dead_code)]
DieselError(diesel::result::Error),
#[allow(dead_code)]
#[error("Test connection not initialized")]
TestConnection,
}

impl EditoastError for DbConnectionError {
fn get_status(&self) -> StatusCode {
StatusCode::INTERNAL_SERVER_ERROR
}

fn get_type(&self) -> &str {
"editoast:ConnectionPoolError"
}
}
77 changes: 77 additions & 0 deletions editoast/src/modelsv2/connection_pool/db_connection_pool.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use diesel_async::pooled_connection::deadpool::Object;
use diesel_async::pooled_connection::deadpool::Pool;

use diesel_async::AsyncPgConnection;
use std::sync::Arc;
#[cfg(test)]
use tokio::sync::OwnedRwLockWriteGuard;
#[cfg(test)]
use tokio::sync::RwLock;
use url::Url;

use super::DbConnectionError;

#[derive(Clone)]
pub struct DbConnectionPoolV2 {
pub pool: Arc<Pool<AsyncPgConnection>>,
#[cfg(test)]
pub test_connection: Option<Arc<RwLock<Object<AsyncPgConnection>>>>,
}

impl DbConnectionPoolV2 {
#[cfg(test)]
pub async fn create(pool: Pool<AsyncPgConnection>) -> Result<Self, DbConnectionError> {
use diesel_async::AsyncConnection;
use tokio::sync::RwLock;
let mut conn = pool.get().await.map_err(DbConnectionError::DeadpoolPool)?;
conn.begin_test_transaction()
.await
.map_err(DbConnectionError::DieselError)?;
let test_connection = Arc::new(RwLock::new(conn));

Ok(Self {
pool: Arc::new(pool),
test_connection: Some(test_connection),
})
}

#[cfg(not(test))]
pub async fn create(pool: Pool<AsyncPgConnection>) -> Result<Self, DbConnectionError> {
Ok(Self {
pool: Arc::new(pool),
})
}

#[cfg(test)]
pub async fn create_from_url(url: Url, pool_size: usize) -> Result<Self, DbConnectionError> {
use diesel_async::pooled_connection::AsyncDieselConnectionManager;
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(url.as_str());
let pool = Pool::builder(manager)
.max_size(pool_size)
.build()
.expect("Failed to create pool.");
Ok(Self::create(pool).await?)
}

#[cfg(test)]
pub async fn get(
&self,
) -> Result<OwnedRwLockWriteGuard<Object<AsyncPgConnection>>, DbConnectionError> {
if let Some(test_connection) = &self.test_connection {
let connection = test_connection.clone().write_owned().await;
Ok(connection)
} else {
Err(DbConnectionError::TestConnection)
}
}

#[cfg(not(test))]
pub async fn get(&self) -> Result<Object<AsyncPgConnection>, DbConnectionError> {
let co = self
.pool
.get()
.await
.map_err(DbConnectionError::DeadpoolPool)?;
Ok(co)
}
}
1 change: 1 addition & 0 deletions editoast/src/modelsv2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub use prelude::*;

pub use connection_pool::DbConnection;
pub use connection_pool::DbConnectionPool;
pub use connection_pool::DbConnectionPoolV2;
pub use documents::Document;
pub use electrical_profiles::ElectricalProfileSet;
pub use infra::Infra;
Expand Down
127 changes: 73 additions & 54 deletions editoast/src/views/documents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use thiserror::Error;
use utoipa::ToSchema;

use crate::error::Result;
use crate::modelsv2::DbConnectionPool;
use crate::modelsv2::DbConnectionPoolV2;
use crate::modelsv2::*;

crate::routes! {
Expand Down Expand Up @@ -49,7 +49,7 @@ pub enum DocumentErrors {
)
)]
#[get("/{document_key}")]
async fn get(db_pool: Data<DbConnectionPool>, document_key: Path<i64>) -> Result<HttpResponse> {
async fn get(db_pool: Data<DbConnectionPoolV2>, document_key: Path<i64>) -> Result<HttpResponse> {
let document_key = document_key.into_inner();
let conn = &mut db_pool.get().await?;
let doc = Document::retrieve_or_fail(conn, document_key, || DocumentErrors::NotFound {
Expand Down Expand Up @@ -79,7 +79,7 @@ struct NewDocumentResponse {
)]
#[post("")]
async fn post(
db_pool: Data<DbConnectionPool>,
db_pool: Data<DbConnectionPoolV2>,
content_type: Header<ContentType>,
bytes: Bytes,
) -> Result<HttpResponse> {
Expand Down Expand Up @@ -112,7 +112,10 @@ async fn post(
)
)]
#[delete("/{document_key}")]
async fn delete(db_pool: Data<DbConnectionPool>, document_key: Path<i64>) -> Result<HttpResponse> {
async fn delete(
db_pool: Data<DbConnectionPoolV2>,
document_key: Path<i64>,
) -> Result<HttpResponse> {
let document_key = document_key.into_inner();
let conn = &mut db_pool.get().await?;
Document::delete_static_or_fail(conn, document_key, || DocumentErrors::NotFound {
Expand All @@ -131,72 +134,88 @@ mod tests {
use serde::Deserialize;

use super::*;
use crate::fixtures::tests::db_pool;
use crate::fixtures::tests::document_example;
use crate::fixtures::tests::TestFixture;
use crate::views::tests::create_test_service;

#[rstest]
async fn get_document(
#[future] document_example: TestFixture<Document>,
db_pool: Data<DbConnectionPool>,
) {
let service = create_test_service().await;
let doc = document_example.await;

let doc_key = doc.id();
let url = format!("/documents/{}", doc_key);

// Should succeed
let request = TestRequest::get().uri(&url).to_request();
let response = call_service(&service, request).await;
assert!(response.status().is_success());

// Delete the document
assert!(doc
.model
.delete(&mut db_pool.get().await.unwrap())
.await
.unwrap());

// Should fail
let request = TestRequest::get().uri(&url).to_request();
let response = call_service(&service, request).await;
assert!(response.status().is_client_error());
}
use crate::views::tests::create_test_connection_pool;
use crate::views::tests::create_test_service_with_connection_pool;

#[derive(Deserialize, Clone, Debug)]
struct PostDocumentResponse {
document_key: i64,
}

async fn create_document(connection_pool: &DbConnectionPoolV2) -> Document {
let mut conn = connection_pool
.get()
.await
.expect("Failed to get connection pool");
Document::changeset()
.data("Document post test data".as_bytes().to_vec())
.content_type(String::from("text/plain"))
.create(&mut conn)
.await
.expect("Failed to create document")
}

#[rstest]
async fn document_post(db_pool: Data<DbConnectionPool>) {
let service = create_test_service().await;
async fn document_post() {
let pool = create_test_connection_pool().await;
let service = create_test_service_with_connection_pool(pool.clone()).await;

// Insert document
let request = TestRequest::post()
.uri("/documents")
.insert_header(ContentType::plaintext())
.set_payload("Test data".as_bytes().to_vec())
.set_payload("Document post test data".as_bytes().to_vec())
.to_request();

// Insert document
let create_response: PostDocumentResponse =
call_and_read_body_json(&service, request).await;

println!("Document id: {}", &create_response.document_key);

assert!(create_response.document_key > 0);
}

#[rstest]
async fn get_document() {
let pool = create_test_connection_pool().await;
let service = create_test_service_with_connection_pool(pool.clone()).await;

// Insert document test
let document = create_document(&pool).await;

println!("Document id: {}", document.id);

// Get document test
let request = TestRequest::get()
.uri(&format!("/documents/{}", document.id))
.to_request();
let response: PostDocumentResponse = call_and_read_body_json(&service, request).await;

// Delete the document
assert!(
Document::delete_static(&mut db_pool.get().await.unwrap(), response.document_key)
.await
.unwrap()
);
let response = call_service(&service, request).await;

assert!(response.status().is_success());
}

#[rstest]
async fn document_delete(#[future] document_example: TestFixture<Document>) {
let document_example = document_example.await;
let service = create_test_service().await;
async fn document_delete() {
let pool = create_test_connection_pool().await;
let service = create_test_service_with_connection_pool(pool.clone()).await;

// Insert document test
let document = create_document(&pool).await;

// Delete document request
let request = TestRequest::delete()
.uri(format!("/documents/{}", document_example.id()).as_str())
.uri(format!("/documents/{}", document.id).as_str())
.to_request();
assert!(call_service(&service, request).await.status().is_success());
let response = call_service(&service, request).await;

assert!(response.status().is_success());

// Get document request
let request = TestRequest::get()
.uri(&format!("/documents/{}", document.id))
.to_request();
let response = call_service(&service, request).await;

assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
}
Loading

0 comments on commit e387a6c

Please sign in to comment.