Skip to content

Commit

Permalink
Merge pull request #27 from digital-society-coop/state
Browse files Browse the repository at this point in the history
Simplify API and use "state" to improve type-safety
  • Loading branch information
connec committed Nov 2, 2023
2 parents fcba66c + e896b5a commit 1940658
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 130 deletions.
8 changes: 6 additions & 2 deletions examples/example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::error::Error;
use axum::{response::IntoResponse, routing::get, Json};
use http::StatusCode;

// OPTIONAL: use a type alias to avoid repeating your database type
// Recommended: use a type alias to avoid repeating your database type
type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite>;

#[tokio::main]
Expand All @@ -19,11 +19,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
.execute(&pool)
.await?;

let (state, layer) = Tx::setup(pool);

// Standard axum app setup
let app = axum::Router::new()
.route("/numbers", get(list_numbers).post(generate_number))
// Apply the Tx middleware
.layer(axum_sqlx_tx::Layer::new(pool.clone()));
.layer(layer)
// Add the Tx state
.with_state(state);

let server = axum::Server::bind(&([0, 0, 0, 0], 0).into()).serve(app.into_make_service());

Expand Down
68 changes: 31 additions & 37 deletions src/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use bytes::Bytes;
use futures_core::future::BoxFuture;
use http_body::{combinators::UnsyncBoxBody, Body};

use crate::{tx::TxSlot, Error};
use crate::{tx::TxSlot, State};

/// A [`tower_layer::Layer`] that enables the [`Tx`] extractor.
///
Expand All @@ -20,34 +20,19 @@ use crate::{tx::TxSlot, Error};
///
/// [`Tx`]: crate::Tx
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
pub struct Layer<DB: sqlx::Database, E = Error> {
pool: sqlx::Pool<DB>,
pub struct Layer<DB: sqlx::Database, E> {
state: State<DB>,
_error: PhantomData<E>,
}

impl<DB: sqlx::Database> Layer<DB> {
/// Construct a new layer with the given `pool`.
///
/// A connection will be obtained from the pool the first time a [`Tx`](crate::Tx) is extracted
/// from a request.
///
/// If you want to access the pool outside of a transaction, you should add it also with
/// [`axum::Extension`].
///
/// To use a different type than [`Error`] to convert commit errors into responses, see
/// [`new_with_error`](Self::new_with_error).
///
/// [`axum::Extension`]: https://docs.rs/axum/latest/axum/extract/struct.Extension.html
pub fn new(pool: sqlx::Pool<DB>) -> Self {
Self::new_with_error(pool)
}

/// Construct a new layer with a specific error type.
///
/// See [`Layer::new`] for more information.
pub fn new_with_error<E>(pool: sqlx::Pool<DB>) -> Layer<DB, E> {
Layer {
pool,
impl<DB: sqlx::Database, E> Layer<DB, E>
where
E: IntoResponse,
sqlx::Error: Into<E>,
{
pub(crate) fn new(state: State<DB>) -> Self {
Self {
state,
_error: PhantomData,
}
}
Expand All @@ -56,18 +41,22 @@ impl<DB: sqlx::Database> Layer<DB> {
impl<DB: sqlx::Database, E> Clone for Layer<DB, E> {
fn clone(&self) -> Self {
Self {
pool: self.pool.clone(),
state: self.state.clone(),
_error: self._error,
}
}
}

impl<DB: sqlx::Database, S, E> tower_layer::Layer<S> for Layer<DB, E> {
impl<DB: sqlx::Database, S, E> tower_layer::Layer<S> for Layer<DB, E>
where
E: IntoResponse,
sqlx::Error: Into<E>,
{
type Service = Service<DB, S, E>;

fn layer(&self, inner: S) -> Self::Service {
Service {
pool: self.pool.clone(),
state: self.state.clone(),
inner,
_error: self._error,
}
Expand All @@ -77,8 +66,8 @@ impl<DB: sqlx::Database, S, E> tower_layer::Layer<S> for Layer<DB, E> {
/// A [`tower_service::Service`] that enables the [`Tx`](crate::Tx) extractor.
///
/// See [`Layer`] for more information.
pub struct Service<DB: sqlx::Database, S, E = Error> {
pool: sqlx::Pool<DB>,
pub struct Service<DB: sqlx::Database, S, E> {
state: State<DB>,
inner: S,
_error: PhantomData<E>,
}
Expand All @@ -87,7 +76,7 @@ pub struct Service<DB: sqlx::Database, S, E = Error> {
impl<DB: sqlx::Database, S: Clone, E> Clone for Service<DB, S, E> {
fn clone(&self) -> Self {
Self {
pool: self.pool.clone(),
state: self.state.clone(),
inner: self.inner.clone(),
_error: self._error,
}
Expand All @@ -103,7 +92,8 @@ where
Error = std::convert::Infallible,
>,
S::Future: Send + 'static,
E: From<Error> + IntoResponse,
E: IntoResponse,
sqlx::Error: Into<E>,
ResBody: Body<Data = Bytes> + Send + 'static,
ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
Expand All @@ -119,7 +109,7 @@ where
}

fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
let transaction = TxSlot::bind(req.extensions_mut(), self.pool.clone());
let transaction = TxSlot::bind(req.extensions_mut(), self.state.clone());

let res = self.inner.call(req);

Expand All @@ -128,7 +118,7 @@ where

if !res.status().is_server_error() && !res.status().is_client_error() {
if let Err(error) = transaction.commit().await {
return Ok(E::from(Error::Database { error }).into_response());
return Ok(error.into().into_response());
}
}

Expand All @@ -139,17 +129,21 @@ where

#[cfg(test)]
mod tests {
use crate::{Error, State};

use super::Layer;

// The trait shenanigans required by axum for layers are significant, so this "test" ensures
// we've got it right.
#[allow(unused, unreachable_code, clippy::diverging_sub_expression)]
fn layer_compiles() {
let pool: sqlx::Pool<sqlx::Sqlite> = todo!();
let state: State<sqlx::Sqlite> = todo!();

let layer = Layer::<_, Error>::new(state);

let app = axum::Router::new()
.route("/", axum::routing::get(|| async { "hello" }))
.layer(Layer::new(pool));
.layer(layer);

axum::Server::bind(todo!()).serve(app.into_make_service());
}
Expand Down
Loading

0 comments on commit 1940658

Please sign in to comment.