From 0927aab8580870e68bfd58e3d15686af4f39edbd Mon Sep 17 00:00:00 2001 From: Chris Connelly Date: Wed, 1 Nov 2023 15:33:21 +0000 Subject: [PATCH 1/6] chore: doc typo --- src/slot.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/slot.rs b/src/slot.rs index 5a078ae..2102ef9 100644 --- a/src/slot.rs +++ b/src/slot.rs @@ -4,7 +4,7 @@ //! Conceptually, the `Slot` is the "primary" owner of the value, and access can be leased to one //! other owner through an associated `Lease`. //! -//! It's implemented as a wrapper around an `Arc>>`, where the `Slot` `take`s the +//! It's implemented as a wrapper around an `Arc>>`, where the `Slot` `take`s the //! value from the `Option` on lease, and the `Lease` puts it back in on drop. //! //! Note that while this is **safe** to use across threads (it is `Send` + `Sync`), concurrently From c9eccafb9b68dbac0e3ead7ce81d3dc38f30eeb7 Mon Sep 17 00:00:00 2001 From: Chris Connelly Date: Wed, 1 Nov 2023 14:29:44 +0000 Subject: [PATCH 2/6] feat!: introduce `State` to make `Tx` more type-safe `Tx` now implements `FromRequestParts`, meaning the `Router` must be provided an instance of `State` in order for it to compile. `State` can be constructed with `Layer::new`, which now returns a tuple of `Layer` and `State`. The idea behind constructing both in the same function is to make it harder to forget to add the `Layer`, however `Layer::new` is perhaps not the best place for it long-term. Ultimately this has the desired effect of making `Tx` more type-safe - applications that attempt to use `Tx` without providing `State` won't compile, and the API for obtaining `State` makes it harder to forget to add the `Layer`. BREAKING CHANGE: `Layer::new` now returns a `(Layer, State)` tuple. This can be consumed easily by destructuring assignment. `Tx` now requires `State` to be provided on `Router`s in order for the `Router` to be usable. --- examples/example.rs | 6 +++- src/layer.rs | 38 ++++++++++++++---------- src/lib.rs | 50 +++++++++++++++++++++++++++---- src/tx.rs | 16 ++++++---- tests/lib.rs | 72 +++++++++++++++++++++++++++++++++++++++++---- 5 files changed, 148 insertions(+), 34 deletions(-) diff --git a/examples/example.rs b/examples/example.rs index 01039db..bb9d348 100644 --- a/examples/example.rs +++ b/examples/example.rs @@ -19,11 +19,15 @@ async fn main() -> Result<(), Box> { .execute(&pool) .await?; + let (layer, state) = axum_sqlx_tx::Layer::new(pool.clone()); + // Standard axum app setup let app = axum::Router::new() .route("/numbers", get(list_numbers).post(generate_number)) // Apply the Tx middleware - .layer(axum_sqlx_tx::Layer::new(pool.clone())); + .layer(layer) + // Add the Tx state + .with_state(state); let server = axum::Server::bind(&([0, 0, 0, 0], 0).into()).serve(app.into_make_service()); diff --git a/src/layer.rs b/src/layer.rs index 2637690..3d89535 100644 --- a/src/layer.rs +++ b/src/layer.rs @@ -7,7 +7,7 @@ use bytes::Bytes; use futures_core::future::BoxFuture; use http_body::{combinators::UnsyncBoxBody, Body}; -use crate::{tx::TxSlot, Error}; +use crate::{tx::TxSlot, Error, State}; /// A [`tower_layer::Layer`] that enables the [`Tx`] extractor. /// @@ -21,42 +21,46 @@ use crate::{tx::TxSlot, Error}; /// [`Tx`]: crate::Tx /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html pub struct Layer { - pool: sqlx::Pool, + state: State, _error: PhantomData, } impl Layer { - /// Construct a new layer with the given `pool`. + /// Construct a new layer and [`State`] with the given `pool`. /// /// A connection will be obtained from the pool the first time a [`Tx`](crate::Tx) is extracted /// from a request. /// /// If you want to access the pool outside of a transaction, you should add it also with - /// [`axum::Extension`]. + /// [`axum::Extension`] or as router state. /// /// To use a different type than [`Error`] to convert commit errors into responses, see /// [`new_with_error`](Self::new_with_error). /// /// [`axum::Extension`]: https://docs.rs/axum/latest/axum/extract/struct.Extension.html - pub fn new(pool: sqlx::Pool) -> Self { + pub fn new(pool: sqlx::Pool) -> (Self, State) { Self::new_with_error(pool) } /// Construct a new layer with a specific error type. /// /// See [`Layer::new`] for more information. - pub fn new_with_error(pool: sqlx::Pool) -> Layer { - Layer { - pool, - _error: PhantomData, - } + pub fn new_with_error(pool: sqlx::Pool) -> (Layer, State) { + let state = State::new(pool); + ( + Layer { + state: state.clone(), + _error: PhantomData, + }, + state, + ) } } impl Clone for Layer { fn clone(&self) -> Self { Self { - pool: self.pool.clone(), + state: self.state.clone(), _error: self._error, } } @@ -67,7 +71,7 @@ impl tower_layer::Layer for Layer { fn layer(&self, inner: S) -> Self::Service { Service { - pool: self.pool.clone(), + state: self.state.clone(), inner, _error: self._error, } @@ -78,7 +82,7 @@ impl tower_layer::Layer for Layer { /// /// See [`Layer`] for more information. pub struct Service { - pool: sqlx::Pool, + state: State, inner: S, _error: PhantomData, } @@ -87,7 +91,7 @@ pub struct Service { impl Clone for Service { fn clone(&self) -> Self { Self { - pool: self.pool.clone(), + state: self.state.clone(), inner: self.inner.clone(), _error: self._error, } @@ -119,7 +123,7 @@ where } fn call(&mut self, mut req: http::Request) -> Self::Future { - let transaction = TxSlot::bind(req.extensions_mut(), self.pool.clone()); + let transaction = TxSlot::bind(req.extensions_mut(), self.state.clone()); let res = self.inner.call(req); @@ -147,9 +151,11 @@ mod tests { fn layer_compiles() { let pool: sqlx::Pool = todo!(); + let (layer, _state) = Layer::new(pool); + let app = axum::Router::new() .route("/", axum::routing::get(|| async { "hello" })) - .layer(Layer::new(pool)); + .layer(layer); axum::Server::bind(todo!()).serve(app.into_make_service()); } diff --git a/src/lib.rs b/src/lib.rs index aa25699..f96dacf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,15 +17,20 @@ //! //! # Usage //! -//! To use the [`Tx`] extractor, you must first add [`Layer`] to your app: +//! To use the [`Tx`] extractor, you must first add [`State`] and [`Layer`] to your app: //! //! ``` //! # async fn foo() { //! let pool = /* any sqlx::Pool */ //! # sqlx::SqlitePool::connect(todo!()).await.unwrap(); +//! +//! let (layer, state) = axum_sqlx_tx::Layer::new(pool); +//! //! let app = axum::Router::new() //! // .route(...)s -//! .layer(axum_sqlx_tx::Layer::new(pool)); +//! # .route("/", axum::routing::get(|tx: axum_sqlx_tx::Tx| async move {})) +//! .layer(layer) +//! .with_state(state); //! # axum::Server::bind(todo!()).serve(app.into_make_service()); //! # } //! ``` @@ -67,7 +72,7 @@ //! the response. //! //! ``` -//! use axum::response::IntoResponse; +//! use axum::{response::IntoResponse, routing::post}; //! use axum_sqlx_tx::Tx; //! use sqlx::Sqlite; //! @@ -91,9 +96,13 @@ //! // Change the layer error type //! # async fn foo() { //! # let pool: sqlx::SqlitePool = todo!(); +//! +//! let (layer, state) = axum_sqlx_tx::Layer::new_with_error::(pool); +//! //! let app = axum::Router::new() -//! // .route(...)s -//! .layer(axum_sqlx_tx::Layer::new_with_error::(pool)); +//! .route("/", post(create_user)) +//! .layer(layer) +//! .with_state(state); //! # axum::Server::bind(todo!()).serve(app.into_make_service()); //! # } //! @@ -120,6 +129,37 @@ pub use crate::{ tx::Tx, }; +/// Application state that enables the [`Tx`] extractor. +/// +/// `State` must be provided to `Router`s in order to use the [`Tx`] extractor, or else attempting +/// to use the `Router` will not compile. +/// +/// `State` is constructed via [`Layer::new`](crate::Layer::new), which also returns a +/// [middleware](crate::Layer). The state and the middleware together enable the [`Tx`] extractor to +/// work. +#[derive(Debug)] +pub struct State { + pool: sqlx::Pool, +} + +impl State { + pub(crate) fn new(pool: sqlx::Pool) -> Self { + Self { pool } + } + + pub(crate) async fn transaction(&self) -> Result, sqlx::Error> { + self.pool.begin().await + } +} + +impl Clone for State { + fn clone(&self) -> Self { + Self { + pool: self.pool.clone(), + } + } +} + /// Possible errors when extracting [`Tx`] from a request. /// /// `axum` requires that the `FromRequest` `Rejection` implements `IntoResponse`, which this does diff --git a/src/tx.rs b/src/tx.rs index 2a0d29d..7d32ac5 100644 --- a/src/tx.rs +++ b/src/tx.rs @@ -2,13 +2,16 @@ use std::marker::PhantomData; -use axum_core::{extract::FromRequestParts, response::IntoResponse}; +use axum_core::{ + extract::{FromRef, FromRequestParts}, + response::IntoResponse, +}; use http::request::Parts; use sqlx::Transaction; use crate::{ slot::{Lease, Slot}, - Error, + Error, State, }; /// An `axum` extractor for a database transaction. @@ -115,6 +118,7 @@ impl std::ops::DerefMut for Tx { impl FromRequestParts for Tx where E: From + IntoResponse, + State: FromRef, { type Rejection = E; @@ -145,9 +149,9 @@ impl TxSlot { /// /// When the request extensions are dropped, `commit` can be called to commit the transaction /// (if any). - pub(crate) fn bind(extensions: &mut http::Extensions, pool: sqlx::Pool) -> Self { + pub(crate) fn bind(extensions: &mut http::Extensions, state: State) -> Self { let (slot, tx) = Slot::new_leased(None); - extensions.insert(Lazy { pool, tx }); + extensions.insert(Lazy { state, tx }); Self(slot) } @@ -164,7 +168,7 @@ impl TxSlot { /// When the transaction is started, it's inserted into the `Option` leased from the `TxSlot`, so /// that when `Lazy` is dropped the transaction is moved to the `TxSlot`. struct Lazy { - pool: sqlx::Pool, + state: State, tx: Lease>>>, } @@ -173,7 +177,7 @@ impl Lazy { let tx = if let Some(tx) = self.tx.as_mut() { tx } else { - let tx = self.pool.begin().await?; + let tx = self.state.transaction().await?; self.tx.insert(Slot::new(tx)) }; diff --git a/tests/lib.rs b/tests/lib.rs index 498d690..8d4737d 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1,4 +1,5 @@ use axum::{middleware, response::IntoResponse}; +use axum_sqlx_tx::State; use sqlx::{sqlite::SqliteArguments, Arguments as _}; use tempfile::NamedTempFile; use tower::ServiceExt; @@ -96,6 +97,8 @@ async fn extract_from_middleware_and_handler() { next.run(req).await } + let (layer, state) = axum_sqlx_tx::Layer::new(pool.clone()); + let app = axum::Router::new() .route( "/", @@ -107,8 +110,12 @@ async fn extract_from_middleware_and_handler() { axum::Json(users) }), ) - .layer(middleware::from_fn(test_middleware)) - .layer(axum_sqlx_tx::Layer::new(pool.clone())); + .layer(middleware::from_fn_with_state( + state.clone(), + test_middleware, + )) + .layer(layer) + .with_state(state); let response = app .oneshot( @@ -126,9 +133,56 @@ async fn extract_from_middleware_and_handler() { assert_eq!(body.as_ref(), b"[[1,\"bobby tables\"]]"); } +#[tokio::test] +async fn substates() { + #[derive(Clone)] + struct MyState { + state: State, + } + + impl axum_core::extract::FromRef for State { + fn from_ref(state: &MyState) -> Self { + state.state.clone() + } + } + + let db = NamedTempFile::new().unwrap(); + let pool = sqlx::SqlitePool::connect(&format!("sqlite://{}", db.path().display())) + .await + .unwrap(); + + let (layer, state) = axum_sqlx_tx::Layer::new(pool); + + let app = axum::Router::new() + .route("/", axum::routing::get(|_: Tx| async move {})) + .layer(layer) + .with_state(MyState { state }); + let response = app + .oneshot( + http::Request::builder() + .uri("/") + .body(axum::body::Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert!(response.status().is_success()); +} + #[tokio::test] async fn missing_layer() { - let app = axum::Router::new().route("/", axum::routing::get(|_: Tx| async move {})); + let db = NamedTempFile::new().unwrap(); + let pool = sqlx::SqlitePool::connect(&format!("sqlite://{}", db.path().display())) + .await + .unwrap(); + + // Note that we have to explicitly ignore the `_layer`, making it hard to do this accidentally. + let (_layer, state) = axum_sqlx_tx::Layer::new(pool); + + let app = axum::Router::new() + .route("/", axum::routing::get(|_: Tx| async move {})) + .with_state(state); let response = app .oneshot( http::Request::builder() @@ -187,6 +241,8 @@ async fn layer_error_override() { .await .unwrap(); + let (layer, state) = axum_sqlx_tx::Layer::new_with_error::(pool.clone()); + let app = axum::Router::new() .route( "/", @@ -197,7 +253,8 @@ async fn layer_error_override() { .unwrap(); }), ) - .layer(axum_sqlx_tx::Layer::new_with_error::(pool.clone())); + .layer(layer) + .with_state(state); let response = app .oneshot( @@ -242,7 +299,7 @@ struct Response { async fn build_app(handler: H) -> (NamedTempFile, sqlx::SqlitePool, Response) where - H: axum::handler::Handler, + H: axum::handler::Handler, axum::body::Body>, T: 'static, { let db = NamedTempFile::new().unwrap(); @@ -255,9 +312,12 @@ where .await .unwrap(); + let (layer, state) = axum_sqlx_tx::Layer::new(pool.clone()); + let app = axum::Router::new() .route("/", axum::routing::get(handler)) - .layer(axum_sqlx_tx::Layer::new(pool.clone())); + .layer(layer) + .with_state(state); let response = app .oneshot( From 32a0a0fd3fda415c6b82180d174ecc22f5af06c5 Mon Sep 17 00:00:00 2001 From: Chris Connelly Date: Thu, 2 Nov 2023 11:31:26 +0000 Subject: [PATCH 3/6] feat: introduce `Tx::setup`, `Tx::config`, and `Config` APIs This will centralise the configuration API and reduce the number of types that need to interacted with directly (i.e. `use`d) down to just `Tx`. --- src/lib.rs | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/tx.rs | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index f96dacf..8b92d08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -124,11 +124,61 @@ mod layer; mod slot; mod tx; +use std::marker::PhantomData; + pub use crate::{ layer::{Layer, Service}, tx::Tx, }; +/// Configuration for [`Tx`] extractors. +/// +/// Use `Config` to configure and build the [`State`] and [`Layer`] supporting [`Tx`] extractors. +/// +/// A new `Config` can be constructed using [`Tx::config`]. +/// +/// ``` +/// # async fn foo() { +/// # let pool: sqlx::SqlitePool = todo!(); +/// type Tx = axum_sqlx_tx::Tx; +/// +/// let config = Tx::config(pool); +/// # } +/// ``` +pub struct Config { + pool: sqlx::Pool, + _layer_error: PhantomData, +} + +impl Config { + fn new(pool: sqlx::Pool) -> Self { + Self { + pool, + _layer_error: PhantomData, + } + } + + /// Change the layer error type. + /// + /// The [`Layer`] middleware can return an error if the transaction fails to commit after a + /// successful response. + pub fn layer_error(self) -> Config + where + Error: Into, + { + Config { + pool: self.pool, + _layer_error: PhantomData, + } + } + + /// Create a [`State`] and [`Layer`] to enable the [`Tx`] extractor. + pub fn setup(self) -> (State, Layer) { + let (layer, state) = Layer::new(self.pool); + (state, layer.with_error()) + } +} + /// Application state that enables the [`Tx`] extractor. /// /// `State` must be provided to `Router`s in order to use the [`Tx`] extractor, or else attempting diff --git a/src/tx.rs b/src/tx.rs index 7d32ac5..f08ab05 100644 --- a/src/tx.rs +++ b/src/tx.rs @@ -11,7 +11,7 @@ use sqlx::Transaction; use crate::{ slot::{Lease, Slot}, - Error, State, + Config, Error, State, }; /// An `axum` extractor for a database transaction. @@ -76,6 +76,40 @@ use crate::{ pub struct Tx(Lease>, PhantomData); impl Tx { + /// Crate a [`State`] and [`Layer`](crate::Layer) to enable the extractor. + /// + /// This is convenient to use from a type alias, e.g. + /// + /// ``` + /// # async fn foo() { + /// type Tx = axum_sqlx_tx::Tx; + /// + /// let pool: sqlx::SqlitePool = todo!(); + /// let (state, layer) = Tx::setup(pool); + /// # } + /// ``` + pub fn setup(pool: sqlx::Pool) -> (State, crate::Layer) { + Config::new(pool).setup() + } + + /// Configure extractor behaviour. + /// + /// See the [`Config`] API for available options. + /// + /// This is convenient to use from a type alias, e.g. + /// + /// ``` + /// # async fn foo() { + /// type Tx = axum_sqlx_tx::Tx; + /// + /// # let pool: sqlx::SqlitePool = todo!(); + /// let config = Tx::config(pool); + /// # } + /// ``` + pub fn config(pool: sqlx::Pool) -> Config { + Config::new(pool) + } + /// Explicitly commit the transaction. /// /// By default, the transaction will be committed when a successful response is returned From e8f694e301b01ebbd18e368c7f33601cab786912 Mon Sep 17 00:00:00 2001 From: Chris Connelly Date: Thu, 2 Nov 2023 13:09:28 +0000 Subject: [PATCH 4/6] refactor!: remove `Layer::new` and `Layer::new_with_error` This leaves `Tx::{setup,config}` as the only entrypoints to the API. Error handling documentation was also rewritten. BREAKING CHANGE: `Layer::{new,new_with_error}` have been removed. Use `Tx::{setup,config}` instead. --- examples/example.rs | 4 +- src/layer.rs | 39 +++-------- src/lib.rs | 154 ++++++++++++++++++++------------------------ tests/lib.rs | 15 +++-- 4 files changed, 89 insertions(+), 123 deletions(-) diff --git a/examples/example.rs b/examples/example.rs index bb9d348..6b02e04 100644 --- a/examples/example.rs +++ b/examples/example.rs @@ -5,7 +5,7 @@ use std::error::Error; use axum::{response::IntoResponse, routing::get, Json}; use http::StatusCode; -// OPTIONAL: use a type alias to avoid repeating your database type +// Recommended: use a type alias to avoid repeating your database type type Tx = axum_sqlx_tx::Tx; #[tokio::main] @@ -19,7 +19,7 @@ async fn main() -> Result<(), Box> { .execute(&pool) .await?; - let (layer, state) = axum_sqlx_tx::Layer::new(pool.clone()); + let (state, layer) = Tx::setup(pool); // Standard axum app setup let app = axum::Router::new() diff --git a/src/layer.rs b/src/layer.rs index 3d89535..2a4d3c1 100644 --- a/src/layer.rs +++ b/src/layer.rs @@ -25,35 +25,12 @@ pub struct Layer { _error: PhantomData, } -impl Layer { - /// Construct a new layer and [`State`] with the given `pool`. - /// - /// A connection will be obtained from the pool the first time a [`Tx`](crate::Tx) is extracted - /// from a request. - /// - /// If you want to access the pool outside of a transaction, you should add it also with - /// [`axum::Extension`] or as router state. - /// - /// To use a different type than [`Error`] to convert commit errors into responses, see - /// [`new_with_error`](Self::new_with_error). - /// - /// [`axum::Extension`]: https://docs.rs/axum/latest/axum/extract/struct.Extension.html - pub fn new(pool: sqlx::Pool) -> (Self, State) { - Self::new_with_error(pool) - } - - /// Construct a new layer with a specific error type. - /// - /// See [`Layer::new`] for more information. - pub fn new_with_error(pool: sqlx::Pool) -> (Layer, State) { - let state = State::new(pool); - ( - Layer { - state: state.clone(), - _error: PhantomData, - }, +impl Layer { + pub(crate) fn new(state: State) -> Self { + Self { state, - ) + _error: PhantomData, + } } } @@ -143,15 +120,17 @@ where #[cfg(test)] mod tests { + use crate::{Error, State}; + use super::Layer; // The trait shenanigans required by axum for layers are significant, so this "test" ensures // we've got it right. #[allow(unused, unreachable_code, clippy::diverging_sub_expression)] fn layer_compiles() { - let pool: sqlx::Pool = todo!(); + let state: State = todo!(); - let (layer, _state) = Layer::new(pool); + let layer = Layer::<_, Error>::new(state); let app = axum::Router::new() .route("/", axum::routing::get(|| async { "hello" })) diff --git a/src/lib.rs b/src/lib.rs index 8b92d08..a851463 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,18 +17,22 @@ //! //! # Usage //! -//! To use the [`Tx`] extractor, you must first add [`State`] and [`Layer`] to your app: +//! To use the [`Tx`] extractor, you must first add [`State`] and [`Layer`] to your app. [`State`] +//! holds the configuration for the extractor, and the [`Layer`] middleware manages the +//! request-bound transaction. //! //! ``` //! # async fn foo() { -//! let pool = /* any sqlx::Pool */ -//! # sqlx::SqlitePool::connect(todo!()).await.unwrap(); +//! // It's recommended to create aliases specialised for your extractor(s) +//! type Tx = axum_sqlx_tx::Tx; //! -//! let (layer, state) = axum_sqlx_tx::Layer::new(pool); +//! let pool = sqlx::SqlitePool::connect("...").await.unwrap(); +//! +//! let (state, layer) = Tx::setup(pool); //! //! let app = axum::Router::new() //! // .route(...)s -//! # .route("/", axum::routing::get(|tx: axum_sqlx_tx::Tx| async move {})) +//! # .route("/", axum::routing::get(|tx: Tx| async move {})) //! .layer(layer) //! .with_state(state); //! # axum::Server::bind(todo!()).serve(app.into_make_service()); @@ -38,10 +42,9 @@ //! You can then simply add [`Tx`] as an argument to your handlers: //! //! ``` -//! use axum_sqlx_tx::Tx; -//! use sqlx::Sqlite; +//! type Tx = axum_sqlx_tx::Tx; //! -//! async fn create_user(mut tx: Tx, /* ... */) { +//! async fn create_user(mut tx: Tx, /* ... */) { //! // `&mut Tx` implements `sqlx::Executor` //! let user = sqlx::query("INSERT INTO users (...) VALUES (...)") //! .fetch_one(&mut tx) @@ -55,62 +58,14 @@ //! } //! ``` //! -//! If you forget to add the middleware you'll get [`Error::MissingExtension`] (internal server -//! error) when using the extractor. You'll also get an error ([`Error::OverlappingExtractors`]) if -//! you have multiple `Tx` arguments in a single handler, or call `Tx::from_request` multiple times -//! in a single middleware. -//! //! ## Error handling //! -//! `axum` requires that middleware do not return errors, and that the errors returned by extractors -//! implement `IntoResponse`. By default, [`Error`] is used by [`Layer`] and [`Tx`] to -//! convert errors into HTTP 500 responses, with the error's `Display` value as the response body, -//! however it's generally not a good practice to return internal error details to clients! -//! -//! To make it easier to customise error handling, both [`Layer`] and [`Tx`] have a second generic -//! type parameter, `E`, that can be used to override the error type that will be used to convert -//! the response. -//! -//! ``` -//! use axum::{response::IntoResponse, routing::post}; -//! use axum_sqlx_tx::Tx; -//! use sqlx::Sqlite; -//! -//! struct MyError(axum_sqlx_tx::Error); -//! -//! // Errors must implement From -//! impl From for MyError { -//! fn from(error: axum_sqlx_tx::Error) -> Self { -//! Self(error) -//! } -//! } -//! -//! // Errors must implement IntoResponse -//! impl IntoResponse for MyError { -//! fn into_response(self) -> axum::response::Response { -//! // note that you would probably want to log the error or something -//! (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() -//! } -//! } -//! -//! // Change the layer error type -//! # async fn foo() { -//! # let pool: sqlx::SqlitePool = todo!(); +//! `axum` requires that errors can be turned into responses. The [`Error`] type converts into a +//! HTTP 500 response with the error message as the response body. This may be suitable for +//! development or internal services but it's generally not advisable to return internal error +//! details to clients. //! -//! let (layer, state) = axum_sqlx_tx::Layer::new_with_error::(pool); -//! -//! let app = axum::Router::new() -//! .route("/", post(create_user)) -//! .layer(layer) -//! .with_state(state); -//! # axum::Server::bind(todo!()).serve(app.into_make_service()); -//! # } -//! -//! // Change the extractor error type -//! async fn create_user(mut tx: Tx, /* ... */) { -//! /* ... */ -//! } -//! ``` +//! See [`Error`] for how to customise error handling. //! //! # Examples //! @@ -133,9 +88,9 @@ pub use crate::{ /// Configuration for [`Tx`] extractors. /// -/// Use `Config` to configure and build the [`State`] and [`Layer`] supporting [`Tx`] extractors. +/// Use `Config` to configure and create a [`State`] and [`Layer`]. /// -/// A new `Config` can be constructed using [`Tx::config`]. +/// Access the `Config` API from [`Tx::config`]. /// /// ``` /// # async fn foo() { @@ -159,9 +114,6 @@ impl Config { } /// Change the layer error type. - /// - /// The [`Layer`] middleware can return an error if the transaction fails to commit after a - /// successful response. pub fn layer_error(self) -> Config where Error: Into, @@ -174,8 +126,9 @@ impl Config { /// Create a [`State`] and [`Layer`] to enable the [`Tx`] extractor. pub fn setup(self) -> (State, Layer) { - let (layer, state) = Layer::new(self.pool); - (state, layer.with_error()) + let state = State::new(self.pool); + let layer = Layer::new(state.clone()); + (state, layer) } } @@ -184,9 +137,8 @@ impl Config { /// `State` must be provided to `Router`s in order to use the [`Tx`] extractor, or else attempting /// to use the `Router` will not compile. /// -/// `State` is constructed via [`Layer::new`](crate::Layer::new), which also returns a -/// [middleware](crate::Layer). The state and the middleware together enable the [`Tx`] extractor to -/// work. +/// `State` is constructed via [`Tx::setup`] or [`Config::setup`], which also return a middleware +/// [`Layer`]. The state and the middleware together enable the [`Tx`] extractor to work. #[derive(Debug)] pub struct State { pool: sqlx::Pool, @@ -212,35 +164,69 @@ impl Clone for State { /// Possible errors when extracting [`Tx`] from a request. /// -/// `axum` requires that the `FromRequest` `Rejection` implements `IntoResponse`, which this does -/// by returning the `Display` representation of the variant. Note that this means returning -/// configuration and database errors to clients, but you can override the type of error that -/// `Tx::from_request` returns using the `E` generic parameter: +/// Errors can occur at two points during the request lifecycle: +/// +/// 1. The [`Tx`] extractor might fail to obtain a connection from the pool and `BEGIN` a +/// transaction. This could be due to: +/// +/// - Forgetting to add the middleware: [`Error::MissingExtension`]. +/// - Calling the extractor multiple times in the same request: [`Error::OverlappingExtractors`]. +/// - A problem communicating with the database: [`Error::Database`]. +/// +/// 2. The middleware [`Layer`] might fail to commit the transaction. This could be due to a problem +/// communicating with the database, or else a logic error (e.g. unsatisfied deferred +/// constraint): [`Error::Database`]. +/// +/// `axum` requires that errors can be turned into responses. The [`Error`] type converts into a +/// HTTP 500 response with the error message as the response body. This may be suitable for +/// development or internal services but it's generally not advisable to return internal error +/// details to clients. +/// +/// You can override the error types for both the [`Tx`] extractor and [`Layer`]: +/// +/// - Override the [`Tx`]`` error type using the `E` generic type parameter. +/// - Override the [`Layer`] error type using [`Config::layer_error`]. +/// +/// In both cases, the error type must implement `From<`[`Error`]`>` and +/// `axum::response::IntoResponse`. /// /// ``` -/// use axum::response::IntoResponse; -/// use axum_sqlx_tx::Tx; -/// use sqlx::Sqlite; +/// use axum::{response::IntoResponse, routing::post}; /// /// struct MyError(axum_sqlx_tx::Error); /// -/// // The error type must implement From /// impl From for MyError { /// fn from(error: axum_sqlx_tx::Error) -> Self { /// Self(error) /// } /// } /// -/// // The error type must implement IntoResponse /// impl IntoResponse for MyError { /// fn into_response(self) -> axum::response::Response { +/// // note that you would probably want to log the error as well /// (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() /// } /// } /// -/// async fn handler(tx: Tx) { -/// /* ... */ -/// } +/// // Override the `Tx` error type using the second generic type parameter +/// type Tx = axum_sqlx_tx::Tx; +/// +/// # async fn foo() { +/// let pool = sqlx::SqlitePool::connect("...").await.unwrap(); +/// +/// let (state, layer) = Tx::config(pool) +/// // Override the `Layer` error type using the `Config` API +/// .layer_error::() +/// .setup(); +/// # let app = axum::Router::new() +/// # .route("/", post(create_user)) +/// # .layer(layer) +/// # .with_state(state); +/// # axum::Server::bind(todo!()).serve(app.into_make_service()); +/// # } +/// # async fn create_user(mut tx: Tx, /* ... */) { +/// # /* ... */ +/// # } /// ``` #[derive(Debug, thiserror::Error)] pub enum Error { @@ -252,7 +238,7 @@ pub enum Error { #[error("axum_sqlx_tx::Tx extractor used multiple times in the same handler/middleware")] OverlappingExtractors, - /// A database error occurred when starting the transaction. + /// A database error occurred when starting or committing the transaction. #[error(transparent)] Database { #[from] diff --git a/tests/lib.rs b/tests/lib.rs index 8d4737d..865e8af 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -4,7 +4,7 @@ use sqlx::{sqlite::SqliteArguments, Arguments as _}; use tempfile::NamedTempFile; use tower::ServiceExt; -type Tx = axum_sqlx_tx::Tx; +type Tx = axum_sqlx_tx::Tx; #[tokio::test] async fn commit_on_success() { @@ -97,7 +97,7 @@ async fn extract_from_middleware_and_handler() { next.run(req).await } - let (layer, state) = axum_sqlx_tx::Layer::new(pool.clone()); + let (state, layer) = Tx::setup(pool); let app = axum::Router::new() .route( @@ -151,7 +151,7 @@ async fn substates() { .await .unwrap(); - let (layer, state) = axum_sqlx_tx::Layer::new(pool); + let (state, layer) = Tx::setup(pool); let app = axum::Router::new() .route("/", axum::routing::get(|_: Tx| async move {})) @@ -178,7 +178,7 @@ async fn missing_layer() { .unwrap(); // Note that we have to explicitly ignore the `_layer`, making it hard to do this accidentally. - let (_layer, state) = axum_sqlx_tx::Layer::new(pool); + let (state, _layer) = Tx::setup(pool); let app = axum::Router::new() .route("/", axum::routing::get(|_: Tx| async move {})) @@ -212,7 +212,8 @@ async fn overlapping_extractors() { #[tokio::test] async fn extractor_error_override() { - let (_, _, response) = build_app(|_: Tx, _: Tx| async move {}).await; + let (_, _, response) = + build_app(|_: Tx, _: axum_sqlx_tx::Tx| async move {}).await; assert!(response.status.is_client_error()); assert_eq!(response.body, "internal server error"); @@ -241,7 +242,7 @@ async fn layer_error_override() { .await .unwrap(); - let (layer, state) = axum_sqlx_tx::Layer::new_with_error::(pool.clone()); + let (state, layer) = Tx::config(pool).layer_error::().setup(); let app = axum::Router::new() .route( @@ -312,7 +313,7 @@ where .await .unwrap(); - let (layer, state) = axum_sqlx_tx::Layer::new(pool.clone()); + let (state, layer) = Tx::setup(pool.clone()); let app = axum::Router::new() .route("/", axum::routing::get(handler)) From e2701a7431a461ea62d9a7c8fd77dd2b8792ed63 Mon Sep 17 00:00:00 2001 From: Chris Connelly Date: Thu, 2 Nov 2023 13:26:36 +0000 Subject: [PATCH 5/6] refactor!: rationalise generic error type bounds Required bounds are now present on more `impl` blocks to provide better diagnostics if error types are missing required traits. The `Layer` error type is now only required to implement `From` (technically `sqlx::Error` is required to implement `Into` since this is more flexible). BREAKING CHANGE: The tighter bounds on `impl` blocks shouldn't break already working code, however the change to the `Layer` error type bounds will break error overrides that implement `From` - this must be changed to `From`. --- src/layer.rs | 23 ++++++++++++++++------- src/lib.rs | 31 +++++++++++++++++++++++-------- src/tx.rs | 2 +- tests/lib.rs | 24 +++++++++++++++++++----- 4 files changed, 59 insertions(+), 21 deletions(-) diff --git a/src/layer.rs b/src/layer.rs index 2a4d3c1..b0ebf6d 100644 --- a/src/layer.rs +++ b/src/layer.rs @@ -7,7 +7,7 @@ use bytes::Bytes; use futures_core::future::BoxFuture; use http_body::{combinators::UnsyncBoxBody, Body}; -use crate::{tx::TxSlot, Error, State}; +use crate::{tx::TxSlot, State}; /// A [`tower_layer::Layer`] that enables the [`Tx`] extractor. /// @@ -20,12 +20,16 @@ use crate::{tx::TxSlot, Error, State}; /// /// [`Tx`]: crate::Tx /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html -pub struct Layer { +pub struct Layer { state: State, _error: PhantomData, } -impl Layer { +impl Layer +where + E: IntoResponse, + sqlx::Error: Into, +{ pub(crate) fn new(state: State) -> Self { Self { state, @@ -43,7 +47,11 @@ impl Clone for Layer { } } -impl tower_layer::Layer for Layer { +impl tower_layer::Layer for Layer +where + E: IntoResponse, + sqlx::Error: Into, +{ type Service = Service; fn layer(&self, inner: S) -> Self::Service { @@ -58,7 +66,7 @@ impl tower_layer::Layer for Layer { /// A [`tower_service::Service`] that enables the [`Tx`](crate::Tx) extractor. /// /// See [`Layer`] for more information. -pub struct Service { +pub struct Service { state: State, inner: S, _error: PhantomData, @@ -84,7 +92,8 @@ where Error = std::convert::Infallible, >, S::Future: Send + 'static, - E: From + IntoResponse, + E: IntoResponse, + sqlx::Error: Into, ResBody: Body + Send + 'static, ResBody::Error: Into>, { @@ -109,7 +118,7 @@ where if !res.status().is_server_error() && !res.status().is_client_error() { if let Err(error) = transaction.commit().await { - return Ok(E::from(Error::Database { error }).into_response()); + return Ok(error.into().into_response()); } } diff --git a/src/lib.rs b/src/lib.rs index a851463..e7f3545 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -105,7 +105,11 @@ pub struct Config { _layer_error: PhantomData, } -impl Config { +impl Config +where + LayerError: axum_core::response::IntoResponse, + sqlx::Error: Into, +{ fn new(pool: sqlx::Pool) -> Self { Self { pool, @@ -116,7 +120,7 @@ impl Config { /// Change the layer error type. pub fn layer_error(self) -> Config where - Error: Into, + sqlx::Error: Into, { Config { pool: self.pool, @@ -184,20 +188,31 @@ impl Clone for State { /// /// You can override the error types for both the [`Tx`] extractor and [`Layer`]: /// -/// - Override the [`Tx`]`` error type using the `E` generic type parameter. -/// - Override the [`Layer`] error type using [`Config::layer_error`]. +/// - Override the [`Tx`]`` error type using the `E` generic type parameter. `E` must be +/// convertible from [`Error`] (e.g. [`Error`]`: Into`). /// -/// In both cases, the error type must implement `From<`[`Error`]`>` and -/// `axum::response::IntoResponse`. +/// - Override the [`Layer`] error type using [`Config::layer_error`]. The layer error type must be +/// convertible from `sqlx::Error` (e.g. `sqlx::Error: Into`). +/// +/// In both cases, the error type must implement `axum::response::IntoResponse`. /// /// ``` /// use axum::{response::IntoResponse, routing::post}; /// -/// struct MyError(axum_sqlx_tx::Error); +/// enum MyError{ +/// Extractor(axum_sqlx_tx::Error), +/// Layer(sqlx::Error), +/// } /// /// impl From for MyError { /// fn from(error: axum_sqlx_tx::Error) -> Self { -/// Self(error) +/// Self::Extractor(error) +/// } +/// } +/// +/// impl From for MyError { +/// fn from(error: sqlx::Error) -> Self { +/// Self::Layer(error) /// } /// } /// diff --git a/src/tx.rs b/src/tx.rs index f08ab05..13e24fd 100644 --- a/src/tx.rs +++ b/src/tx.rs @@ -88,7 +88,7 @@ impl Tx { /// let (state, layer) = Tx::setup(pool); /// # } /// ``` - pub fn setup(pool: sqlx::Pool) -> (State, crate::Layer) { + pub fn setup(pool: sqlx::Pool) -> (State, crate::Layer) { Config::new(pool).setup() } diff --git a/tests/lib.rs b/tests/lib.rs index 865e8af..773a013 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -213,7 +213,7 @@ async fn overlapping_extractors() { #[tokio::test] async fn extractor_error_override() { let (_, _, response) = - build_app(|_: Tx, _: axum_sqlx_tx::Tx| async move {}).await; + build_app(|_: Tx, _: axum_sqlx_tx::Tx| async move {}).await; assert!(response.status.is_client_error()); assert_eq!(response.body, "internal server error"); @@ -242,7 +242,7 @@ async fn layer_error_override() { .await .unwrap(); - let (state, layer) = Tx::config(pool).layer_error::().setup(); + let (state, layer) = Tx::config(pool).layer_error::().setup(); let app = axum::Router::new() .route( @@ -335,15 +335,29 @@ where (db, pool, Response { status, body }) } -struct MyError(axum_sqlx_tx::Error); +struct MyExtractorError(axum_sqlx_tx::Error); -impl From for MyError { +impl From for MyExtractorError { fn from(error: axum_sqlx_tx::Error) -> Self { Self(error) } } -impl IntoResponse for MyError { +impl IntoResponse for MyExtractorError { + fn into_response(self) -> axum::response::Response { + (http::StatusCode::IM_A_TEAPOT, "internal server error").into_response() + } +} + +struct MyLayerError(sqlx::Error); + +impl From for MyLayerError { + fn from(error: sqlx::Error) -> Self { + Self(error) + } +} + +impl IntoResponse for MyLayerError { fn into_response(self) -> axum::response::Response { (http::StatusCode::IM_A_TEAPOT, "internal server error").into_response() } From e896b5aa93d0f18a5dc3b3a0fe34842869220124 Mon Sep 17 00:00:00 2001 From: Chris Connelly Date: Thu, 2 Nov 2023 13:38:48 +0000 Subject: [PATCH 6/6] chore: doc fix --- src/tx.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tx.rs b/src/tx.rs index 13e24fd..9eabce8 100644 --- a/src/tx.rs +++ b/src/tx.rs @@ -113,8 +113,8 @@ impl Tx { /// Explicitly commit the transaction. /// /// By default, the transaction will be committed when a successful response is returned - /// (specifically, when the [`Service`](crate::Service) middleware intercepts an HTTP `2XX` - /// response). This method allows the transaction to be committed explicitly. + /// (specifically, when the [`Service`](crate::Service) middleware intercepts an HTTP `2XX` or + /// `3XX` response). This method allows the transaction to be committed explicitly. /// /// **Note:** trying to use the `Tx` extractor again after calling `commit` will currently /// generate [`Error::OverlappingExtractors`] errors. This may change in future.