diff --git a/Cargo.lock b/Cargo.lock index da52a0f9..3e0295f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -146,12 +146,12 @@ dependencies = [ "log", "once_cell", "paste", - "postgres", "proc-macro2 1.0.47", "quote 1.0.21", "r2d2", "rusqlite", "serde_json", + "tokio-postgres", "uuid", ] @@ -183,18 +183,18 @@ dependencies = [ name = "butane_core" version = "0.5.0" dependencies = [ + "async-trait", "bytes", "cfg-if", "chrono", "fallible-iterator", "fallible-streaming-iterator", "fs2", + "futures-util", "hex", "log", "native-tls", - "once_cell", "pin-project", - "postgres", "postgres-native-tls", "proc-macro2 1.0.47", "quote 1.0.21", @@ -205,6 +205,8 @@ dependencies = [ "serde_json", "syn 1.0.103", "thiserror", + "tokio", + "tokio-postgres", "uuid", ] @@ -1027,20 +1029,6 @@ version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" -[[package]] -name = "postgres" -version = "0.19.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "960c214283ef8f0027974c03e9014517ced5db12f021a9abb66185a5751fab0a" -dependencies = [ - "bytes", - "fallible-iterator", - "futures-util", - "log", - "tokio", - "tokio-postgres", -] - [[package]] name = "postgres-native-tls" version = "0.5.0" @@ -1504,9 +1492,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.21.2" +version = "1.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e03c497dc955702ba729190dc4aac6f2a0ce97f913e5b1b5912fc5039d9099" +checksum = "597a12a59981d9e3c38d216785b0c37399f6e415e8d0712047620f189371b0bb" dependencies = [ "autocfg", "bytes", @@ -1515,7 +1503,7 @@ dependencies = [ "mio", "pin-project-lite", "socket2", - "winapi", + "windows-sys 0.42.0", ] [[package]] diff --git a/butane/Cargo.toml b/butane/Cargo.toml index c31b764d..f3783cbc 100644 --- a/butane/Cargo.toml +++ b/butane/Cargo.toml @@ -40,11 +40,12 @@ log_for_test = {package="log", version = "0.4"} quote = "1.0" proc-macro2="1.0" once_cell="1.5.2" -postgres = { version = "0.19", features=["with-geo-types-0_7"] } +tokio-postgres = { version = "0.7", features=["with-geo-types-0_7"] } +# tokio-postgres = { version = "0.7" } r2d2_for_test = {package="r2d2", version = "0.8"} rusqlite = {workspace=true} serde_json = "1.0" uuid_for_test = {package="uuid", version = "1.2", features=["v4"] } [package.metadata.docs.rs] -all-features = true \ No newline at end of file +all-features = true diff --git a/butane/tests/basic.rs b/butane/tests/basic.rs index d1dcec7a..15a5c67c 100644 --- a/butane/tests/basic.rs +++ b/butane/tests/basic.rs @@ -4,7 +4,7 @@ use butane::{butane_type, find, model, query}; use butane::{ForeignKey, ObjectState}; use paste; #[cfg(feature = "pg")] -use postgres; +use tokio_postgres as postgres; #[cfg(feature = "sqlite")] use rusqlite; diff --git a/butane/tests/custom_pg.rs b/butane/tests/custom_pg.rs index 4e44ec50..b00f70f0 100644 --- a/butane/tests/custom_pg.rs +++ b/butane/tests/custom_pg.rs @@ -9,7 +9,7 @@ mod custom_pg { use butane::{butane_type, db::Connection, model, ObjectState}; use butane::{FieldType, FromSql, SqlType, SqlVal, SqlValRef, ToSql}; use geo_types; - use postgres; + use tokio_postgres as postgres; use std::result::Result; // newtype so we can implement traits for it. diff --git a/butane_core/Cargo.toml b/butane_core/Cargo.toml index 7b68c807..50aef87a 100644 --- a/butane_core/Cargo.toml +++ b/butane_core/Cargo.toml @@ -15,7 +15,7 @@ debug = ["log"] sqlite = ["rusqlite"] sqlite-bundled = ["rusqlite/bundled"] tls = ["postgres-native-tls", "native-tls"] -pg = ["postgres", "bytes"] +pg = ["tokio-postgres", "bytes"] [dependencies] @@ -25,10 +25,9 @@ fallible-iterator = "0.2" fallible-streaming-iterator = "0.1" fs2 = "0.4" # for file locks hex = "0.4" -once_cell="1.5" log = { version="0.4", optional=true } native-tls={ version = "0.2", optional = true } -postgres={ version = "0.19", features=["with-chrono-0_4"], optional = true} +tokio-postgres={ version = "0.7", features=["with-chrono-0_4"], optional = true} postgres-native-tls={ version = "0.5", optional = true } proc-macro2 = "1.0" pin-project = "1" @@ -42,3 +41,6 @@ syn = { version = "1.0", features = ["full", "extra-traits"] } thiserror = "1.0" chrono = { version = "0.4", features=["serde"], optional = true } uuid = {workspace=true, optional=true} +async-trait = "0.1" +tokio = { version = "1.24", features=["rt", "sync"] } +futures-util = "0.3" diff --git a/butane_core/src/custom.rs b/butane_core/src/custom.rs index 1c77f846..6495cd77 100644 --- a/butane_core/src/custom.rs +++ b/butane_core/src/custom.rs @@ -7,12 +7,13 @@ use serde::{Deserialize, Serialize}; use std::fmt; +use tokio_postgres as postgres; /// For use with [SqlType::Custom](crate::SqlType) #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub enum SqlTypeCustom { #[cfg(feature = "pg")] - Pg(#[serde(with = "pgtypeser")] postgres::types::Type), + Pg(#[serde(with = "pgtypeser")] tokio_postgres::types::Type), } /// For use with [SqlVal::Custom](crate::SqlVal) @@ -137,6 +138,7 @@ impl From> for SqlValCustom { #[cfg(feature = "pg")] mod pgtypeser { use serde::{Deserialize, Deserializer, Serialize, Serializer}; + use tokio_postgres as postgres; pub fn serialize(ty: &postgres::types::Type, serializer: S) -> Result where diff --git a/butane_core/src/db/connmethods.rs b/butane_core/src/db/connmethods.rs index 85078ec4..be37df71 100644 --- a/butane_core/src/db/connmethods.rs +++ b/butane_core/src/db/connmethods.rs @@ -5,14 +5,16 @@ use crate::query::{BoolExpr, Expr, Order}; use crate::{Result, SqlType, SqlVal, SqlValRef}; use std::ops::{Deref, DerefMut}; use std::vec::Vec; +use async_trait::async_trait; /// Methods available on a database connection. Most users do not need /// to call these methods directly and will instead use methods on /// [DataObject][crate::DataObject] or the `query!` macro. This trait is /// implemented by both database connections and transactions. -pub trait ConnectionMethods { - fn execute(&self, sql: &str) -> Result<()>; - fn query<'a, 'b, 'c: 'a>( +#[async_trait] +pub trait ConnectionMethods: Sync { + async fn execute(&self, sql: &str) -> Result<()>; + async fn query<'a, 'b, 'c: 'a>( &'c self, table: &str, columns: &'b [Column], @@ -21,7 +23,7 @@ pub trait ConnectionMethods { offset: Option, sort: Option<&[Order]>, ) -> Result>; - fn insert_returning_pk( + async fn insert_returning_pk( &self, table: &str, columns: &[Column], @@ -29,30 +31,30 @@ pub trait ConnectionMethods { values: &[SqlValRef<'_>], ) -> Result; /// Like `insert_returning_pk` but with no return value - fn insert_only(&self, table: &str, columns: &[Column], values: &[SqlValRef<'_>]) -> Result<()>; + async fn insert_only(&self, table: &str, columns: &[Column], values: &[SqlValRef<'_>]) -> Result<()>; /// Insert unless there's a conflict on the primary key column, in which case update - fn insert_or_replace( + async fn insert_or_replace( &self, table: &str, columns: &[Column], pkcol: &Column, values: &[SqlValRef<'_>], ) -> Result<()>; - fn update( + async fn update<'a>( &self, table: &str, pkcol: Column, - pk: SqlValRef, + pk: SqlValRef<'a>, columns: &[Column], values: &[SqlValRef<'_>], ) -> Result<()>; - fn delete(&self, table: &str, pkcol: &'static str, pk: SqlVal) -> Result<()> { - self.delete_where(table, BoolExpr::Eq(pkcol, Expr::Val(pk)))?; + async fn delete(&self, table: &str, pkcol: &'static str, pk: SqlVal) -> Result<()> { + self.delete_where(table, BoolExpr::Eq(pkcol, Expr::Val(pk))).await?; Ok(()) } - fn delete_where(&self, table: &str, expr: BoolExpr) -> Result; + async fn delete_where(&self, table: &str, expr: BoolExpr) -> Result; /// Tests if a table exists in the database. - fn has_table(&self, table: &str) -> Result; + async fn has_table(&self, table: &str) -> Result; } /// Represents a database column. Most users do not need to use this diff --git a/butane_core/src/db/macros.rs b/butane_core/src/db/macros.rs index 6a1f1897..1af8edbe 100644 --- a/butane_core/src/db/macros.rs +++ b/butane_core/src/db/macros.rs @@ -1,11 +1,12 @@ #[macro_export] macro_rules! connection_method_wrapper { ($ty:path) => { + #[async_trait::async_trait] impl ConnectionMethods for $ty { - fn execute(&self, sql: &str) -> Result<()> { - ConnectionMethods::execute(self.wrapped_connection_methods()?, sql) + async fn execute(&self, sql: &str) -> Result<()> { + ConnectionMethods::execute(self.wrapped_connection_methods()?, sql).await } - fn query<'a, 'b, 'c: 'a>( + async fn query<'a, 'b, 'c: 'a>( &'c self, table: &str, columns: &'b [Column], @@ -15,9 +16,9 @@ macro_rules! connection_method_wrapper { sort: Option<&[$crate::query::Order]>, ) -> Result> { self.wrapped_connection_methods()? - .query(table, columns, expr, limit, offset, sort) + .query(table, columns, expr, limit, offset, sort).await } - fn insert_returning_pk( + async fn insert_returning_pk( &self, table: &str, columns: &[Column], @@ -26,8 +27,9 @@ macro_rules! connection_method_wrapper { ) -> Result { self.wrapped_connection_methods()? .insert_returning_pk(table, columns, pkcol, values) + .await } - fn insert_only( + async fn insert_only( &self, table: &str, columns: &[Column], @@ -35,8 +37,9 @@ macro_rules! connection_method_wrapper { ) -> Result<()> { self.wrapped_connection_methods()? .insert_only(table, columns, values) + .await } - fn insert_or_replace( + async fn insert_or_replace( &self, table: &str, columns: &[Column], @@ -45,23 +48,25 @@ macro_rules! connection_method_wrapper { ) -> Result<()> { self.wrapped_connection_methods()? .insert_or_replace(table, columns, pkcol, values) + .await } - fn update( + async fn update<'a>( &self, table: &str, pkcol: Column, - pk: SqlValRef, + pk: SqlValRef<'a>, columns: &[Column], values: &[SqlValRef<'_>], ) -> Result<()> { self.wrapped_connection_methods()? .update(table, pkcol, pk, columns, values) + .await } - fn delete_where(&self, table: &str, expr: BoolExpr) -> Result { - self.wrapped_connection_methods()?.delete_where(table, expr) + async fn delete_where(&self, table: &str, expr: BoolExpr) -> Result { + self.wrapped_connection_methods()?.delete_where(table, expr).await } - fn has_table(&self, table: &str) -> Result { - self.wrapped_connection_methods()?.has_table(table) + async fn has_table(&self, table: &str) -> Result { + self.wrapped_connection_methods()?.has_table(table).await } } }; diff --git a/butane_core/src/db/mod.rs b/butane_core/src/db/mod.rs index 85978135..3c5fbba6 100644 --- a/butane_core/src/db/mod.rs +++ b/butane_core/src/db/mod.rs @@ -20,14 +20,13 @@ use std::fs; use std::io::Write; use std::ops::{Deref, DerefMut}; use std::path::Path; +use async_trait::async_trait; mod connmethods; mod helper; mod macros; #[cfg(feature = "pg")] pub mod pg; -#[cfg(feature = "sqlite")] -pub mod sqlite; #[cfg(feature = "r2d2")] mod r2; @@ -42,10 +41,11 @@ pub use connmethods::{ }; /// Database connection. +#[async_trait] pub trait BackendConnection: ConnectionMethods + Send + 'static { /// Begin a database transaction. The transaction object must be /// used in place of this connection until it is committed and aborted. - fn transaction(&mut self) -> Result; + async fn transaction(&mut self) -> Result; /// Retrieve the backend backend this connection fn backend(&self) -> Box; fn backend_name(&self) -> &'static str; @@ -60,8 +60,8 @@ pub struct Connection { conn: Box, } impl Connection { - pub fn execute(&mut self, sql: impl AsRef) -> Result<()> { - self.conn.execute(sql.as_ref()) + pub async fn execute(&mut self, sql: impl AsRef) -> Result<()> { + self.conn.execute(sql.as_ref()).await } // For use with connection_method_wrapper macro #[allow(clippy::unnecessary_wraps)] @@ -69,9 +69,10 @@ impl Connection { Ok(self.conn.as_ref()) } } +#[async_trait] impl BackendConnection for Connection { - fn transaction(&mut self) -> Result { - self.conn.transaction() + async fn transaction(&mut self) -> Result { + self.conn.transaction().await } fn backend(&self) -> Box { self.conn.backend() @@ -129,12 +130,14 @@ fn conn_complete_if_dir(path: &Path) -> Cow { } /// Database backend. A boxed implementation can be returned by name via [get_backend][crate::db::get_backend]. -pub trait Backend { +#[async_trait] +pub trait Backend: Sync { fn name(&self) -> &'static str; fn create_migration_sql(&self, current: &adb::ADB, ops: Vec) -> Result; - fn connect(&self, conn_str: &str) -> Result; + async fn connect(&self, conn_str: &str) -> Result; } +#[async_trait] impl Backend for Box { fn name(&self) -> &'static str { self.deref().name() @@ -142,16 +145,14 @@ impl Backend for Box { fn create_migration_sql(&self, current: &adb::ADB, ops: Vec) -> Result { self.deref().create_migration_sql(current, ops) } - fn connect(&self, conn_str: &str) -> Result { - self.deref().connect(conn_str) + async fn connect(&self, conn_str: &str) -> Result { + self.deref().connect(conn_str).await } } /// Find a backend by name. pub fn get_backend(name: &str) -> Option> { match name { - #[cfg(feature = "sqlite")] - sqlite::BACKEND_NAME => Some(Box::new(sqlite::SQLiteBackend::new())), #[cfg(feature = "pg")] pg::BACKEND_NAME => Some(Box::new(pg::PgBackend::new())), _ => None, @@ -160,20 +161,22 @@ pub fn get_backend(name: &str) -> Option> { /// Connect to a database. For non-boxed connections, see individual /// [Backend][crate::db::Backend] implementations. -pub fn connect(spec: &ConnectionSpec) -> Result { +pub async fn connect(spec: &ConnectionSpec) -> Result { get_backend(&spec.backend_name) .ok_or_else(|| Error::UnknownBackend(spec.backend_name.clone()))? .connect(&spec.conn_str) + .await } -trait BackendTransaction<'c>: ConnectionMethods { +#[async_trait] +trait BackendTransaction<'c>: ConnectionMethods + Send { /// Commit the transaction Unfortunately because we use this as a /// trait object, we can't consume self. It should be understood /// that no methods should be called after commit. This trait is /// not public, and that behavior is enforced by Transaction - fn commit(&mut self) -> Result<()>; + async fn commit(&mut self) -> Result<()>; /// Roll back the transaction. Same comment about consuming self as above. - fn rollback(&mut self) -> Result<()>; + async fn rollback(&mut self) -> Result<()>; // Workaround for https://github.com/rust-lang/rfcs/issues/2765 fn connection_methods(&self) -> &dyn ConnectionMethods; @@ -194,12 +197,12 @@ impl<'c> Transaction<'c> { Transaction { trans } } /// Commit the transaction - pub fn commit(mut self) -> Result<()> { - self.trans.deref_mut().commit() + pub async fn commit(mut self) -> Result<()> { + self.trans.commit().await } /// Roll back the transaction. Equivalent to dropping it. - pub fn rollback(mut self) -> Result<()> { - self.trans.deref_mut().rollback() + pub async fn rollback(mut self) -> Result<()> { + self.trans.deref_mut().rollback().await } // For use with connection_method_wrapper macro #[allow(clippy::unnecessary_wraps)] diff --git a/butane_core/src/db/pg.rs b/butane_core/src/db/pg.rs index 96613722..d1b8fe8c 100644 --- a/butane_core/src/db/pg.rs +++ b/butane_core/src/db/pg.rs @@ -9,10 +9,12 @@ use crate::{Result, SqlType, SqlVal, SqlValRef}; use bytes::BufMut; #[cfg(feature = "datetime")] use chrono::NaiveDateTime; -use postgres::fallible_iterator::FallibleIterator; -use postgres::GenericClient; -use std::cell::RefCell; +use tokio_postgres::GenericClient; +use tokio_postgres as postgres; use std::fmt::Write; +use async_trait::async_trait; +use tokio; +use futures_util::stream::StreamExt; /// The name of the postgres backend. pub const BACKEND_NAME: &str = "pg"; @@ -26,10 +28,12 @@ impl PgBackend { } } impl PgBackend { - fn connect(&self, params: &str) -> Result { - PgConnection::open(params) + async fn connect(&self, params: &str) -> Result { + PgConnection::open(params).await } } + +#[async_trait] impl Backend for PgBackend { fn name(&self) -> &'static str { BACKEND_NAME @@ -44,24 +48,28 @@ impl Backend for PgBackend { .join("\n")) } - fn connect(&self, path: &str) -> Result { + async fn connect(&self, path: &str) -> Result { Ok(Connection { - conn: Box::new(self.connect(path)?), + conn: Box::new(self.connect(path).await?), }) } } +// type PgConnHandle = tokio::task::JoinHandle>; + /// Pg database connection. pub struct PgConnection { - conn: RefCell, + client: postgres::Client, + // Save the handle to the task running the connection to keep it alive + _conn_handle: tokio::task::JoinHandle<()>, } + impl PgConnection { - fn open(params: &str) -> Result { - Ok(PgConnection { - conn: RefCell::new(Self::connect(params)?), - }) + async fn open(params: &str) -> Result { + let (client, _conn_handle) = Self::connect(params).await?; + Ok( Self { client, _conn_handle }) } - fn connect(params: &str) -> Result { + async fn connect(params: &str) -> Result<(postgres::Client, tokio::task::JoinHandle<()>)> { cfg_if::cfg_if! { if #[cfg(feature = "tls")] { let connector = native_tls::TlsConnector::new()?; @@ -70,18 +78,27 @@ impl PgConnection { let connector = postgres::NoTls; } } - Ok(postgres::Client::connect(params, connector)?) + let (client, conn) = postgres::connect(params, connector).await?; + let conn_handle = tokio::spawn(async move { + if let Err(_e) = conn.await { + // TODO don't panic + panic!() + } + }); + Ok((client, conn_handle)) } } impl PgConnectionLike for PgConnection { type Client = postgres::Client; - fn cell(&self) -> Result<&RefCell> { - Ok(&self.conn) + fn client(&self) -> Result<&Self::Client> { + Ok(&self.client) } } + +#[async_trait] impl BackendConnection for PgConnection { - fn transaction(&mut self) -> Result> { - let trans: postgres::Transaction<'_> = self.conn.get_mut().transaction()?; + async fn transaction(&mut self) -> Result> { + let trans: postgres::Transaction<'_> = self.client.transaction().await?; let trans = Box::new(PgTransaction::new(trans)); Ok(Transaction::new(trans)) } @@ -92,7 +109,7 @@ impl BackendConnection for PgConnection { BACKEND_NAME } fn is_closed(&self) -> bool { - self.conn.borrow().is_closed() + self.client.is_closed() } } @@ -109,23 +126,24 @@ fn sqlvalref_for_pg_query<'a>(v: &'a SqlValRef<'a>) -> &'a dyn postgres::types:: /// Shared functionality between connection and /// transaction. Implementation detail. Semver exempt. pub trait PgConnectionLike { - type Client: postgres::GenericClient; - fn cell(&self) -> Result<&RefCell>; + type Client: postgres::GenericClient + Send + Sync; + fn client(&self) -> Result<&Self::Client>; } +#[async_trait] impl ConnectionMethods for T where - T: PgConnectionLike, + T: PgConnectionLike + std::marker::Sync, { - fn execute(&self, sql: &str) -> Result<()> { + async fn execute(&self, sql: &str) -> Result<()> { if cfg!(feature = "log") { debug!("execute sql {}", sql); } - self.cell()?.try_borrow_mut()?.batch_execute(sql.as_ref())?; + self.client()?.batch_execute(sql.as_ref()).await?; Ok(()) } - fn query<'a, 'b, 'c: 'a>( + async fn query<'a, 'b, 'c: 'a>( &'c self, table: &str, columns: &'b [Column], @@ -166,23 +184,22 @@ where let types: Vec = values.iter().map(pgtype_for_val).collect(); let stmt = self - .cell()? - .try_borrow_mut()? - .prepare_typed(&sqlquery, types.as_ref())?; - // todo avoid intermediate vec? - let rowvec: Vec = self - .cell()? - .try_borrow_mut()? - .query_raw(&stmt, values.iter().map(sqlval_for_pg_query))? - .map_err(Error::Postgres) - .map(|r| { - check_columns(&r, columns)?; - Ok(r) - }) - .collect()?; + .client()? + .prepare_typed(&sqlquery, types.as_ref()).await?; + let mut rowvec = Vec::::new(); + let rowstream = self + .client()? + .query_raw(&stmt, values.iter().map(sqlval_for_pg_query)).await + .map_err(Error::Postgres)?; + let mut rowstream = Box::pin(rowstream); + while let Some(r) = rowstream.next().await { + let r = r?; + check_columns(&r, columns)?; + rowvec.push(r); + } Ok(Box::new(VecRows::new(rowvec))) } - fn insert_returning_pk( + async fn insert_returning_pk( &self, table: &str, columns: &[Column], @@ -202,16 +219,13 @@ where } // use query instead of execute so we can get our result back - let pk: Option = self - .cell()? - .try_borrow_mut()? - .query_raw(sql.as_str(), values.iter().map(sqlvalref_for_pg_query))? - .map_err(Error::Postgres) - .map(|r| sql_val_from_postgres(&r, 0, pkcol)) - .nth(0)?; - pk.ok_or_else(|| Error::Internal("could not get pk".to_string())) - } - fn insert_only(&self, table: &str, columns: &[Column], values: &[SqlValRef<'_>]) -> Result<()> { + let pk_stream = self.client()? + .query_raw(sql.as_str(), values.iter().map(sqlvalref_for_pg_query)).await + .map_err(Error::Postgres)? + .map(|r| r.map(|x| sql_val_from_postgres(&x, 0, pkcol))); + Box::pin(pk_stream).next().await.ok_or(Error::Internal(("could not get pk").to_string()))?? + } + async fn insert_only(&self, table: &str, columns: &[Column], values: &[SqlValRef<'_>]) -> Result<()> { let mut sql = String::new(); helper::sql_insert_with_placeholders( table, @@ -220,31 +234,29 @@ where &mut sql, ); let params: Vec<&DynToSqlPg> = values.iter().map(|v| v as &DynToSqlPg).collect(); - self.cell()? - .try_borrow_mut()? - .execute(sql.as_str(), params.as_slice())?; + self.client()? + .execute(sql.as_str(), params.as_slice()).await?; Ok(()) } - fn insert_or_replace<'a>( + async fn insert_or_replace( &self, table: &str, columns: &[Column], pkcol: &Column, - values: &[SqlValRef<'a>], + values: &[SqlValRef<'_>], ) -> Result<()> { let mut sql = String::new(); sql_insert_or_replace_with_placeholders(table, columns, pkcol, &mut sql); let params: Vec<&DynToSqlPg> = values.iter().map(|v| v as &DynToSqlPg).collect(); - self.cell()? - .try_borrow_mut()? - .execute(sql.as_str(), params.as_slice())?; + self.client()? + .execute(sql.as_str(), params.as_slice()).await?; Ok(()) } - fn update( + async fn update<'a>( &self, table: &str, pkcol: Column, - pk: SqlValRef, + pk: SqlValRef<'a>, columns: &[Column], values: &[SqlValRef<'_>], ) -> Result<()> { @@ -264,12 +276,10 @@ where if cfg!(feature = "log") { debug!("update sql {}", sql); } - self.cell()? - .try_borrow_mut()? - .execute(sql.as_str(), params.as_slice())?; + self.client()?.execute(sql.as_str(), params.as_slice()).await?; Ok(()) } - fn delete_where(&self, table: &str, expr: BoolExpr) -> Result { + async fn delete_where(&self, table: &str, expr: BoolExpr) -> Result { let mut sql = String::new(); let mut values: Vec = Vec::new(); write!(&mut sql, "DELETE FROM {} WHERE ", table).unwrap(); @@ -281,35 +291,33 @@ where ); let params: Vec<&DynToSqlPg> = values.iter().map(|v| v as &DynToSqlPg).collect(); let cnt = self - .cell()? - .try_borrow_mut()? - .execute(sql.as_str(), params.as_slice())?; + .client()? + .execute(sql.as_str(), params.as_slice()).await?; Ok(cnt as usize) } - fn has_table(&self, table: &str) -> Result { + async fn has_table(&self, table: &str) -> Result { // future improvement, should be schema-aware let stmt = self - .cell()? - .try_borrow_mut()? - .prepare("SELECT table_name FROM information_schema.tables WHERE table_name=$1;")?; - let rows = self.cell()?.try_borrow_mut()?.query(&stmt, &[&table])?; + .client()? + .prepare("SELECT table_name FROM information_schema.tables WHERE table_name=$1;").await?; + let rows = self.client()?.query(&stmt, &[&table]).await?; Ok(!rows.is_empty()) } } struct PgTransaction<'c> { - trans: Option>>, + trans: Option>, } impl<'c> PgTransaction<'c> { fn new(trans: postgres::Transaction<'c>) -> Self { PgTransaction { - trans: Some(RefCell::new(trans)), + trans: Some(trans), } } - fn get(&self) -> Result<&RefCell>> { + fn get(&self) -> Result<&postgres::Transaction<'c>> { match &self.trans { + Some(x) => Ok(x), None => Err(Self::already_consumed()), - Some(trans) => Ok(trans), } } fn already_consumed() -> Error { @@ -318,22 +326,24 @@ impl<'c> PgTransaction<'c> { } impl<'c> PgConnectionLike for PgTransaction<'c> { type Client = postgres::Transaction<'c>; - fn cell(&self) -> Result<&RefCell> { + fn client(&self) -> Result<&Self::Client> { self.get() } } +#[async_trait] impl<'c> BackendTransaction<'c> for PgTransaction<'c> { - fn commit(&mut self) -> Result<()> { + async fn commit(&mut self) -> Result<()> { match self.trans.take() { None => Err(Self::already_consumed()), - Some(trans) => Ok(trans.into_inner().commit()?), + Some(trans) => Ok(trans.commit().await?), } } - fn rollback(&mut self) -> Result<()> { + + async fn rollback(&mut self) -> Result<()> { match self.trans.take() { None => Err(Self::already_consumed()), - Some(trans) => Ok(trans.into_inner().rollback()?), + Some(trans) => Ok(trans.rollback().await?), } } // Workaround for https://github.com/rust-lang/rfcs/issues/2765 diff --git a/butane_core/src/db/sqlite.rs b/butane_core/src/db/sqlite.rs deleted file mode 100644 index b1f0bc2e..00000000 --- a/butane_core/src/db/sqlite.rs +++ /dev/null @@ -1,614 +0,0 @@ -//! SQLite database backend -use super::helper; -use super::*; -use crate::db::connmethods::BackendRows; -use crate::debug; -use crate::migrations::adb::{AColumn, ATable, Operation, TypeIdentifier, ADB}; -use crate::query; -use crate::query::Order; -use crate::{Result, SqlType, SqlVal, SqlValRef}; -#[cfg(feature = "datetime")] -use chrono::naive::NaiveDateTime; -use fallible_streaming_iterator::FallibleStreamingIterator; -use pin_project::pin_project; -use std::borrow::Cow; -use std::fmt::Write; -use std::pin::Pin; - -#[cfg(feature = "datetime")] -const SQLITE_DT_FORMAT: &str = "%Y-%m-%d %H:%M:%S"; - -/// The name of the sqlite backend. -pub const BACKEND_NAME: &str = "sqlite"; - -/// SQLite [Backend][crate::db::Backend] implementation. -#[derive(Default)] -pub struct SQLiteBackend {} -impl SQLiteBackend { - pub fn new() -> SQLiteBackend { - SQLiteBackend {} - } -} -impl SQLiteBackend { - fn connect(&self, path: &str) -> Result { - SQLiteConnection::open(Path::new(path)) - } -} -impl Backend for SQLiteBackend { - fn name(&self) -> &'static str { - BACKEND_NAME - } - - fn create_migration_sql(&self, current: &ADB, ops: Vec) -> Result { - let mut current: ADB = (*current).clone(); - Ok(ops - .into_iter() - .map(|o| { - let sql = sql_for_op(&mut current, &o); - current.transform_with(o); - sql - }) - .collect::>>()? - .join("\n")) - } - - fn connect(&self, path: &str) -> Result { - Ok(Connection { - conn: Box::new(self.connect(path)?), - }) - } -} - -/// SQLite database connection. -pub struct SQLiteConnection { - conn: rusqlite::Connection, -} -impl SQLiteConnection { - fn open(path: impl AsRef) -> Result { - rusqlite::Connection::open(path) - .map(|conn| SQLiteConnection { conn }) - .map_err(|e| e.into()) - } - - // For use with connection_method_wrapper macro - #[allow(clippy::unnecessary_wraps)] - fn wrapped_connection_methods(&self) -> Result<&rusqlite::Connection> { - Ok(&self.conn) - } -} -connection_method_wrapper!(SQLiteConnection); - -impl BackendConnection for SQLiteConnection { - fn transaction(&mut self) -> Result> { - let trans: rusqlite::Transaction<'_> = self.conn.transaction()?; - let trans = Box::new(SqliteTransaction::new(trans)); - Ok(Transaction::new(trans)) - } - fn backend(&self) -> Box { - Box::new(SQLiteBackend {}) - } - fn backend_name(&self) -> &'static str { - "sqlite" - } - fn is_closed(&self) -> bool { - false - } -} - -impl ConnectionMethods for rusqlite::Connection { - fn execute(&self, sql: &str) -> Result<()> { - if cfg!(feature = "log") { - debug!("execute sql {}", sql); - } - self.execute_batch(sql.as_ref())?; - Ok(()) - } - - fn query<'a, 'b, 'c: 'a>( - &'c self, - table: &str, - columns: &'b [Column], - expr: Option, - limit: Option, - offset: Option, - order: Option<&[Order]>, - ) -> Result> { - let mut sqlquery = String::new(); - helper::sql_select(columns, table, &mut sqlquery); - let mut values: Vec = Vec::new(); - if let Some(expr) = expr { - sqlquery.write_str(" WHERE ").unwrap(); - sql_for_expr( - query::Expr::Condition(Box::new(expr)), - &mut values, - &mut SQLitePlaceholderSource::new(), - &mut sqlquery, - ); - } - - if let Some(order) = order { - helper::sql_order(order, &mut sqlquery) - } - - if let Some(limit) = limit { - helper::sql_limit(limit, &mut sqlquery) - } - - if let Some(offset) = offset { - if limit.is_none() { - // Sqlite only supports offset in conjunction with - // limit, so add a max limit if we don't have one - // already. - helper::sql_limit(i32::MAX, &mut sqlquery) - } - helper::sql_offset(offset, &mut sqlquery) - } - - debug!("query sql {}", sqlquery); - - let stmt = self.prepare(&sqlquery)?; - let adapter = QueryAdapter::new(stmt, rusqlite::params_from_iter(values))?; - Ok(Box::new(adapter)) - } - fn insert_returning_pk( - &self, - table: &str, - columns: &[Column], - pkcol: &Column, - values: &[SqlValRef<'_>], - ) -> Result { - let mut sql = String::new(); - helper::sql_insert_with_placeholders( - table, - columns, - &mut SQLitePlaceholderSource::new(), - &mut sql, - ); - if cfg!(feature = "log") { - debug!("insert sql {}", sql); - } - self.execute(&sql, rusqlite::params_from_iter(values))?; - let pk: SqlVal = self.query_row_and_then( - &format!( - "SELECT {} FROM {} WHERE ROWID = last_insert_rowid()", - pkcol.name(), - table - ), - [], - |row| sql_val_from_rusqlite(row.get_ref_unwrap(0), pkcol), - )?; - Ok(pk) - } - fn insert_only(&self, table: &str, columns: &[Column], values: &[SqlValRef<'_>]) -> Result<()> { - let mut sql = String::new(); - helper::sql_insert_with_placeholders( - table, - columns, - &mut SQLitePlaceholderSource::new(), - &mut sql, - ); - if cfg!(feature = "log") { - debug!("insert sql {}", sql); - } - self.execute(&sql, rusqlite::params_from_iter(values))?; - Ok(()) - } - fn insert_or_replace( - &self, - table: &str, - columns: &[Column], - _pkcol: &Column, - values: &[SqlValRef], - ) -> Result<()> { - let mut sql = String::new(); - sql_insert_or_update(table, columns, &mut sql); - self.execute(&sql, rusqlite::params_from_iter(values))?; - Ok(()) - } - fn update( - &self, - table: &str, - pkcol: Column, - pk: SqlValRef, - columns: &[Column], - values: &[SqlValRef<'_>], - ) -> Result<()> { - let mut sql = String::new(); - helper::sql_update_with_placeholders( - table, - pkcol, - columns, - &mut SQLitePlaceholderSource::new(), - &mut sql, - ); - let placeholder_values = [values, &[pk]].concat(); - if cfg!(feature = "log") { - debug!("update sql {}", sql); - } - self.execute(&sql, rusqlite::params_from_iter(placeholder_values))?; - Ok(()) - } - fn delete_where(&self, table: &str, expr: BoolExpr) -> Result { - let mut sql = String::new(); - let mut values: Vec = Vec::new(); - write!(&mut sql, "DELETE FROM {} WHERE ", table).unwrap(); - sql_for_expr( - query::Expr::Condition(Box::new(expr)), - &mut values, - &mut SQLitePlaceholderSource::new(), - &mut sql, - ); - let cnt = self.execute(&sql, rusqlite::params_from_iter(values))?; - Ok(cnt) - } - fn has_table(&self, table: &str) -> Result { - let mut stmt = - self.prepare("SELECT name FROM sqlite_master WHERE type='table' AND name=?;")?; - let mut rows = stmt.query([table])?; - Ok(rows.next()?.is_some()) - } -} - -struct SqliteTransaction<'c> { - trans: Option>, -} -impl<'c> SqliteTransaction<'c> { - fn new(trans: rusqlite::Transaction<'c>) -> Self { - SqliteTransaction { trans: Some(trans) } - } - fn get(&self) -> Result<&rusqlite::Transaction<'c>> { - match &self.trans { - None => Err(Self::already_consumed()), - Some(trans) => Ok(trans), - } - } - fn wrapped_connection_methods(&self) -> Result<&rusqlite::Connection> { - Ok(self.get()?.deref()) - } - fn already_consumed() -> Error { - Error::Internal("transaction has already been consumed".to_string()) - } -} -connection_method_wrapper!(SqliteTransaction<'_>); -impl<'c> BackendTransaction<'c> for SqliteTransaction<'c> { - fn commit(&mut self) -> Result<()> { - match self.trans.take() { - None => Err(Self::already_consumed()), - Some(trans) => Ok(trans.commit()?), - } - } - fn rollback(&mut self) -> Result<()> { - match self.trans.take() { - None => Err(Self::already_consumed()), - Some(trans) => Ok(trans.rollback()?), - } - } - // Workaround for https://github.com/rust-lang/rfcs/issues/2765 - fn connection_methods(&self) -> &dyn ConnectionMethods { - self - } - fn connection_methods_mut(&mut self) -> &mut dyn ConnectionMethods { - self - } -} - -impl rusqlite::ToSql for SqlVal { - fn to_sql(&self) -> rusqlite::Result> { - Ok(sqlvalref_to_sqlite(&self.as_ref())) - } -} - -impl<'a> rusqlite::ToSql for SqlValRef<'a> { - fn to_sql<'b>(&'b self) -> rusqlite::Result> { - Ok(sqlvalref_to_sqlite(self)) - } -} - -fn sqlvalref_to_sqlite<'a, 'b>(valref: &'b SqlValRef<'a>) -> rusqlite::types::ToSqlOutput<'a> { - use rusqlite::types::{ToSqlOutput::Borrowed, ToSqlOutput::Owned, Value, ValueRef}; - use SqlValRef::*; - match valref { - Bool(b) => Owned(Value::Integer(*b as i64)), - Int(i) => Owned(Value::Integer(*i as i64)), - BigInt(i) => Owned(Value::Integer(*i)), - Real(r) => Owned(Value::Real(*r)), - Text(t) => Borrowed(ValueRef::Text(t.as_bytes())), - Blob(b) => Borrowed(ValueRef::Blob(b)), - #[cfg(feature = "datetime")] - Timestamp(dt) => { - let f = dt.format(SQLITE_DT_FORMAT); - Owned(Value::Text(f.to_string())) - } - Null => Owned(Value::Null), - Custom(_) => panic!("Custom types not supported in sqlite"), - } -} - -#[pin_project] -struct QueryAdapterInner<'a> { - stmt: rusqlite::Statement<'a>, - // will always be Some when the constructor has finished. We use an option only to get the - // stmt in place before we can reference it. - rows: Option>, -} - -impl<'a> QueryAdapterInner<'a> { - fn new(stmt: rusqlite::Statement<'a>, params: impl rusqlite::Params) -> Result>> { - let mut q = Box::pin(QueryAdapterInner { stmt, rows: None }); - unsafe { - //Soundness: we pin a QueryAdapterInner value containing - // both the stmt and the rows referencing the statement - // together. It is not possible to drop/move the stmt without - // bringing the referencing rows along with it. - let q_ref = Pin::get_unchecked_mut(Pin::as_mut(&mut q)); - let stmt_ref: *mut rusqlite::Statement<'a> = &mut q_ref.stmt; - q_ref.rows = Some((*stmt_ref).query(params)?) - } - Ok(q) - } - - fn next<'b>(self: Pin<&'b mut Self>) -> Result>> { - let this = self.project(); - let rows: &mut rusqlite::Rows<'a> = this.rows.as_mut().unwrap(); - Ok(rows.next()?) - } - - fn current(self: Pin<&Self>) -> Option<&rusqlite::Row> { - let this = self.project_ref(); - this.rows.as_ref().unwrap().get() - } -} - -struct QueryAdapter<'a> { - inner: Pin>>, -} -impl<'a> QueryAdapter<'a> { - fn new(stmt: rusqlite::Statement<'a>, params: impl rusqlite::Params) -> Result { - Ok(QueryAdapter { - inner: QueryAdapterInner::new(stmt, params)?, - }) - } -} - -impl<'a> BackendRows for QueryAdapter<'a> { - fn next<'b>(&'b mut self) -> Result> { - Ok(self - .inner - .as_mut() - .next()? - .map(|row| row as &dyn BackendRow)) - } - fn current<'b>(&'b self) -> Option<&'b (dyn BackendRow + 'b)> { - self.inner - .as_ref() - .current() - .map(|row| row as &dyn BackendRow) - } -} - -impl BackendRow for rusqlite::Row<'_> { - fn get(&self, idx: usize, ty: SqlType) -> Result { - sql_valref_from_rusqlite(self.get_ref(idx)?, &ty) - } - fn len(&self) -> usize { - self.as_ref().column_count() - } -} - -fn sql_for_expr( - expr: query::Expr, - values: &mut Vec, - pls: &mut SQLitePlaceholderSource, - w: &mut W, -) where - W: Write, -{ - helper::sql_for_expr(expr, sql_for_expr, values, pls, w) -} - -fn sql_val_from_rusqlite(val: rusqlite::types::ValueRef, col: &Column) -> Result { - sql_valref_from_rusqlite(val, col.ty()).map(|v| v.into()) -} - -fn sql_valref_from_rusqlite<'a>( - val: rusqlite::types::ValueRef<'a>, - ty: &SqlType, -) -> Result> { - if let rusqlite::types::ValueRef::Null = val { - return Ok(SqlValRef::Null); - } - Ok(match ty { - SqlType::Bool => SqlValRef::Bool(val.as_i64()? != 0), - SqlType::Int => SqlValRef::Int(val.as_i64()? as i32), - SqlType::BigInt => SqlValRef::BigInt(val.as_i64()?), - SqlType::Real => SqlValRef::Real(val.as_f64()?), - SqlType::Text => SqlValRef::Text(val.as_str()?), - #[cfg(feature = "datetime")] - SqlType::Timestamp => SqlValRef::Timestamp(NaiveDateTime::parse_from_str( - val.as_str()?, - SQLITE_DT_FORMAT, - )?), - SqlType::Blob => SqlValRef::Blob(val.as_blob()?), - SqlType::Custom(v) => { - return Err(Error::IncompatibleCustomT(v.deref().clone(), BACKEND_NAME)) - } - }) -} - -fn sql_for_op(current: &mut ADB, op: &Operation) -> Result { - match op { - Operation::AddTable(table) => Ok(create_table(table, false)), - Operation::AddTableIfNotExists(table) => Ok(create_table(table, true)), - Operation::RemoveTable(name) => Ok(drop_table(name)), - Operation::AddColumn(tbl, col) => add_column(tbl, col), - Operation::RemoveColumn(tbl, name) => Ok(remove_column(current, tbl, name)), - Operation::ChangeColumn(tbl, old, new) => Ok(change_column(current, tbl, old, Some(new))), - } -} - -fn create_table(table: &ATable, allow_exists: bool) -> String { - let coldefs = table - .columns - .iter() - .map(define_column) - .collect::>() - .join(",\n"); - let modifier = if allow_exists { "IF NOT EXISTS " } else { "" }; - format!("CREATE TABLE {}{} (\n{}\n);", modifier, table.name, coldefs) -} - -fn define_column(col: &AColumn) -> String { - let mut constraints: Vec = Vec::new(); - if !col.nullable() { - constraints.push("NOT NULL".to_string()); - } - if col.is_pk() { - constraints.push("PRIMARY KEY".to_string()); - } - if col.is_auto() && !col.is_pk() { - // integer primary key is automatically an alias for ROWID, - // and we only allow auto on integer types - constraints.push("AUTOINCREMENT".to_string()); - } - if col.unique() { - constraints.push("UNIQUE".to_string()); - } - format!( - "{} {} {}", - &col.name(), - col_sqltype(col), - constraints.join(" ") - ) -} - -fn col_sqltype(col: &AColumn) -> Cow { - match col.typeid() { - Ok(TypeIdentifier::Ty(ty)) => Cow::Borrowed(sqltype(&ty)), - Ok(TypeIdentifier::Name(name)) => Cow::Owned(name), - // sqlite doesn't actually require that the column type be - // specified - Err(_) => Cow::Borrowed(""), - } -} - -fn sqltype(ty: &SqlType) -> &'static str { - match ty { - SqlType::Bool => "INTEGER", - SqlType::Int => "INTEGER", - SqlType::BigInt => "INTEGER", - SqlType::Real => "REAL", - SqlType::Text => "TEXT", - #[cfg(feature = "datetime")] - SqlType::Timestamp => "TEXT", - SqlType::Blob => "BLOB", - SqlType::Custom(_) => panic!("Custom types not supported by sqlite backend"), - } -} - -fn drop_table(name: &str) -> String { - format!("DROP TABLE {};", name) -} - -fn add_column(tbl_name: &str, col: &AColumn) -> Result { - let default: SqlVal = helper::column_default(col)?; - Ok(format!( - "ALTER TABLE {} ADD COLUMN {} DEFAULT {};", - tbl_name, - define_column(col), - helper::sql_literal_value(default)? - )) -} - -fn remove_column(current: &mut ADB, tbl_name: &str, name: &str) -> String { - let old = current - .get_table(tbl_name) - .and_then(|table| table.column(name)) - .cloned(); - match old { - Some(col) => change_column(current, tbl_name, &col, None), - None => { - crate::warn!( - "Cannot remove column {} that does not exist from table {}", - name, - tbl_name - ); - "".to_string() - } - } -} - -fn copy_table(old: &ATable, new: &ATable) -> String { - let column_names = new - .columns - .iter() - .map(|col| col.name()) - .collect::>() - .join(", "); - format!( - "INSERT INTO {} SELECT {} FROM {};", - &new.name, column_names, &old.name - ) -} - -fn tmp_table_name(name: &str) -> String { - format!("{}__butane_tmp", name) -} - -fn change_column( - current: &mut ADB, - tbl_name: &str, - old: &AColumn, - new: Option<&AColumn>, -) -> String { - let table = current.get_table(tbl_name); - if table.is_none() { - crate::warn!( - "Cannot alter column {} from table {} that does not exist", - &old.name(), - tbl_name - ); - return "".to_string(); - } - let old_table = table.unwrap(); - let mut new_table = old_table.clone(); - new_table.name = tmp_table_name(&new_table.name); - match new { - Some(col) => new_table.replace_column(col.clone()), - None => new_table.remove_column(old.name()), - } - let stmts: [&str; 4] = [ - &create_table(&new_table, false), - ©_table(old_table, &new_table), - &drop_table(&old_table.name), - &format!("ALTER TABLE {} RENAME TO {};", &new_table.name, tbl_name), - ]; - let result = stmts.join("\n"); - new_table.name = old_table.name.clone(); - current.replace_table(new_table); - result -} - -pub fn sql_insert_or_update(table: &str, columns: &[Column], w: &mut impl Write) { - write!(w, "INSERT OR REPLACE ").unwrap(); - write!(w, "INTO {} (", table).unwrap(); - helper::list_columns(columns, w); - write!(w, ") VALUES (").unwrap(); - columns.iter().fold("", |sep, _| { - write!(w, "{}?", sep).unwrap(); - ", " - }); - write!(w, ")").unwrap(); -} - -struct SQLitePlaceholderSource {} -impl SQLitePlaceholderSource { - fn new() -> Self { - SQLitePlaceholderSource {} - } -} -impl helper::PlaceholderSource for SQLitePlaceholderSource { - fn next_placeholder(&mut self) -> Cow { - // sqlite placeholder is always a question mark. - Cow::Borrowed("?") - } -} diff --git a/butane_core/src/fkey.rs b/butane_core/src/fkey.rs index 834e12a8..c48e23d4 100644 --- a/butane_core/src/fkey.rs +++ b/butane_core/src/fkey.rs @@ -1,6 +1,6 @@ use crate::db::ConnectionMethods; use crate::*; -use once_cell::unsync::OnceCell; +use tokio::sync::OnceCell; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::borrow::Cow; use std::fmt::{Debug, Formatter}; @@ -55,17 +55,6 @@ impl ForeignKey { } } - /// Loads the value referred to by this foreign key from the - /// database if necessary and returns a reference to it. - pub fn load(&self, conn: &impl ConnectionMethods) -> Result<&T> { - self.val - .get_or_try_init(|| { - let pk = self.valpk.get().unwrap(); - T::get(conn, &T::PKType::from_sql_ref(pk.as_ref())?).map(Box::new) - }) - .map(|v| v.as_ref()) - } - fn new_raw() -> Self { ForeignKey { val: OnceCell::new(), @@ -85,6 +74,20 @@ impl ForeignKey { } } +impl ForeignKey { + /// Loads the value referred to by this foreign key from the + /// database if necessary and returns a reference to it. + pub async fn load(&self, conn: &impl ConnectionMethods) -> Result<&T> { + self.val + .get_or_try_init(|| async { + let pk = self.valpk.get().unwrap(); + T::get(conn, &T::PKType::from_sql_ref(pk.as_ref())?).await.map(Box::new) + }) + .await + .map(|v| v.as_ref()) + } +} + impl From for ForeignKey { fn from(obj: T) -> Self { let ret = Self::new_raw(); diff --git a/butane_core/src/lib.rs b/butane_core/src/lib.rs index ccdc2412..748b750a 100644 --- a/butane_core/src/lib.rs +++ b/butane_core/src/lib.rs @@ -5,6 +5,7 @@ use std::borrow::Borrow; use std::cmp::{Eq, PartialEq}; use std::default::Default; use thiserror::Error as ThisError; +use async_trait::async_trait; pub mod codegen; pub mod custom; @@ -53,6 +54,7 @@ impl Eq for ObjectState {} /// object type. The purpose of a result type which is not also an /// object type is to allow a query to retrieve a subset of an /// object's columns. +#[async_trait] pub trait DataResult: Sized { /// Corresponding object type. type DBO: DataObject; @@ -61,14 +63,15 @@ pub trait DataResult: Sized { where Self: Sized; /// Create a blank query (matching all rows) for this type. - fn query() -> Query; + async fn query() -> Query; } /// An object in the database. /// /// Rather than implementing this type manually, use the /// `#[model]` attribute. -pub trait DataObject: DataResult { +#[async_trait] +pub trait DataObject: DataResult + Sync { /// The type of the primary key field. type PKType: PrimaryKeyType; type Fields: Default; @@ -82,25 +85,27 @@ pub trait DataObject: DataResult { /// Get the primary key fn pk(&self) -> &Self::PKType; /// Find this object in the database based on primary key. - fn get(conn: &impl ConnectionMethods, id: impl Borrow) -> Result + async fn get(conn: &impl ConnectionMethods, id: impl Borrow + Send + Sync) -> Result where Self: Sized, + Self::PKType: Sync, { - ::query() - .filter(query::BoolExpr::Eq( + let query = ::query().await; + query.filter(query::BoolExpr::Eq( Self::PKCOL, query::Expr::Val(id.borrow().to_sql()), )) .limit(1) - .load(conn)? + .load(conn).await? .into_iter() .nth(0) .ok_or(Error::NoSuchObject) } + /// Save the object to the database. - fn save(&mut self, conn: &impl ConnectionMethods) -> Result<()>; + async fn save(&mut self, conn: &impl ConnectionMethods) -> Result<()>; /// Delete the object from the database. - fn delete(&self, conn: &impl ConnectionMethods) -> Result<()>; + async fn delete(&self, conn: &impl ConnectionMethods) -> Result<()>; } pub trait ModelTyped { @@ -162,7 +167,7 @@ pub enum Error { SQLiteFromSQL(rusqlite::types::FromSqlError), #[cfg(feature = "pg")] #[error("Postgres error {0}")] - Postgres(#[from] postgres::Error), + Postgres(#[from] tokio_postgres::Error), #[cfg(feature = "datetime")] #[error("Chrono error {0}")] Chrono(#[from] chrono::ParseError), diff --git a/butane_core/src/many.rs b/butane_core/src/many.rs index c6969f69..58e404a8 100644 --- a/butane_core/src/many.rs +++ b/butane_core/src/many.rs @@ -1,12 +1,13 @@ use crate::db::{Column, ConnectionMethods}; use crate::query::{BoolExpr, Expr}; use crate::{DataObject, Error, FieldType, Result, SqlType, SqlVal, ToSql}; -use once_cell::unsync::OnceCell; +use tokio::sync::OnceCell; use serde::{Deserialize, Serialize}; use std::borrow::Cow; fn default_oc() -> OnceCell> { - OnceCell::default() + // Same as impl Default for once_cell::unsync::OnceCell + OnceCell::new() } /// Used to implement a many-to-many relationship between models. @@ -102,7 +103,7 @@ where } /// Used by macro-generated code. You do not need to call this directly. - pub fn save(&mut self, conn: &impl ConnectionMethods) -> Result<()> { + pub async fn save(&mut self, conn: &impl ConnectionMethods) -> Result<()> { let owner = self.owner.as_ref().ok_or(Error::NotInitialized)?; while !self.new_values.is_empty() { conn.insert_only( @@ -112,13 +113,13 @@ where owner.as_ref(), self.new_values.pop().unwrap().as_ref().clone(), ], - )?; + ).await?; } if !self.removed_values.is_empty() { conn.delete_where( &self.item_table, BoolExpr::In("has", std::mem::take(&mut self.removed_values)), - )?; + ).await?; } self.new_values.clear(); Ok(()) @@ -126,31 +127,31 @@ where /// Loads the values referred to by this foreign key from the /// database if necessary and returns a reference to them. - pub fn load(&self, conn: &impl ConnectionMethods) -> Result> { - let vals: Result<&Vec> = self.all_values.get_or_try_init(|| { + pub async fn load(&self, conn: &impl ConnectionMethods) -> Result> { + let vals: Result<&Vec> = self.all_values.get_or_try_init(|| async { //if we don't have an owner then there are no values let owner: &SqlVal = match &self.owner { Some(o) => o, None => return Ok(Vec::new()), }; - let mut vals = T::query() + let mut vals = T::query().await .filter(BoolExpr::Subquery { col: T::PKCOL, tbl2: self.item_table.clone(), tbl2_col: "has", expr: Box::new(BoolExpr::Eq("owner", Expr::Val(owner.clone()))), }) - .load(conn)?; + .load(conn).await?; // Now add in the values for things not saved to the db yet if !self.new_values.is_empty() { vals.append( - &mut T::query() + &mut T::query().await .filter(BoolExpr::In(T::PKCOL, self.new_values.clone())) - .load(conn)?, + .load(conn).await?, ); } Ok(vals) - }); + }).await; vals.map(|v| v.iter()) } pub fn columns(&self) -> [Column; 2] { diff --git a/butane_core/src/migrations/fsmigrations.rs b/butane_core/src/migrations/fsmigrations.rs index 328ac66d..a52d8dac 100644 --- a/butane_core/src/migrations/fsmigrations.rs +++ b/butane_core/src/migrations/fsmigrations.rs @@ -10,7 +10,6 @@ use std::fs::{File, OpenOptions}; use std::io::{Read, Write}; use std::path::{Path, PathBuf}; -use std::rc::Rc; type SqlTypeMap = BTreeMap; const TYPES_FILENAME: &str = "types.json"; @@ -43,7 +42,7 @@ impl MigrationsState { /// A migration stored in the filesystem pub struct FsMigration { - fs: Rc, + fs: std::sync::Arc, root: PathBuf, } @@ -236,13 +235,13 @@ impl Eq for FsMigration {} /// A collection of migrations stored in the filesystem. pub struct FsMigrations { - fs: Rc, + fs: std::sync::Arc, root: PathBuf, current: FsMigration, } impl FsMigrations { pub fn new(root: PathBuf) -> Self { - let fs = Rc::new(OsFilesystem {}); + let fs = std::sync::Arc::new(OsFilesystem {}); let current = FsMigration { fs: fs.clone(), root: root.join("current"), @@ -297,6 +296,7 @@ impl Migrations for FsMigrations { } } +#[async_trait::async_trait] impl MigrationsMut for FsMigrations { fn current(&mut self) -> &mut Self::M { &mut self.current @@ -320,7 +320,7 @@ impl MigrationsMut for FsMigrations { Ok(()) } - fn clear_migrations(&mut self, conn: &impl ConnectionMethods) -> Result<()> { + async fn clear_migrations(&mut self, conn: &impl ConnectionMethods) -> Result<()> { for entry in std::fs::read_dir(&self.root)? { let entry = entry?; if matches!(entry.path().file_name(), Some(name) if name == "current") { @@ -332,7 +332,7 @@ impl MigrationsMut for FsMigrations { std::fs::remove_file(entry.path())?; } } - conn.delete_where(super::ButaneMigration::TABLE, crate::query::BoolExpr::True)?; + conn.delete_where(super::ButaneMigration::TABLE, crate::query::BoolExpr::True).await?; Ok(()) } } diff --git a/butane_core/src/migrations/memmigrations.rs b/butane_core/src/migrations/memmigrations.rs index 030d4c11..6f5fd2d1 100644 --- a/butane_core/src/migrations/memmigrations.rs +++ b/butane_core/src/migrations/memmigrations.rs @@ -128,6 +128,7 @@ impl Migrations for MemMigrations { } } +#[async_trait::async_trait] impl MigrationsMut for MemMigrations { fn current(&mut self) -> &mut Self::M { &mut self.current @@ -152,10 +153,10 @@ impl MigrationsMut for MemMigrations { Ok(()) } - fn clear_migrations(&mut self, conn: &impl ConnectionMethods) -> Result<()> { + async fn clear_migrations(&mut self, conn: &impl ConnectionMethods) -> Result<()> { self.migrations.clear(); self.latest = None; - conn.delete_where(ButaneMigration::TABLE, BoolExpr::True)?; + conn.delete_where(ButaneMigration::TABLE, BoolExpr::True).await?; Ok(()) } } diff --git a/butane_core/src/migrations/migration.rs b/butane_core/src/migrations/migration.rs index b73b5d39..d15dbbc5 100644 --- a/butane_core/src/migrations/migration.rs +++ b/butane_core/src/migrations/migration.rs @@ -13,6 +13,7 @@ use std::cmp::PartialEq; /// /// A Migration cannot be constructed directly, only retrieved from /// [Migrations][crate::migrations::Migrations]. +#[async_trait::async_trait] pub trait Migration: PartialEq { /// Retrieves the full abstract database state describing all tables fn db(&self) -> Result; @@ -37,46 +38,46 @@ pub trait Migration: PartialEq { /// Apply the migration to a database connection. The connection /// must be for the same type of database as this and the database /// must be in the state of the migration prior to this one - fn apply(&self, conn: &mut impl db::BackendConnection) -> Result<()> { + async fn apply(&self, conn: &mut impl db::BackendConnection) -> Result<()> { let backend_name = conn.backend_name(); - let tx = conn.transaction()?; + let tx = conn.transaction().await?; let sql = self .up_sql(backend_name)? .ok_or_else(|| Error::UnknownBackend(backend_name.to_string()))?; - tx.execute(&sql)?; - self.mark_applied(&tx)?; - tx.commit() + tx.execute(&sql).await?; + self.mark_applied(&tx).await?; + tx.commit().await } /// Mark the migration as being applied without doing any /// work. Use carefully -- the caller must ensure that the /// database schema already matches that expected by this /// migration. - fn mark_applied(&self, conn: &impl db::ConnectionMethods) -> Result<()> { + async fn mark_applied(&self, conn: &impl db::ConnectionMethods) -> Result<()> { conn.insert_only( ButaneMigration::TABLE, ButaneMigration::COLUMNS, &[self.name().as_ref().to_sql_ref()], - ) + ).await } /// Un-apply (downgrade) the migration to a database /// connection. The connection must be for the same type of /// database as this and this must be the latest migration applied /// to the database. - fn downgrade(&self, conn: &mut impl db::BackendConnection) -> Result<()> { + async fn downgrade(&self, conn: &mut impl db::BackendConnection) -> Result<()> { let backend_name = conn.backend_name(); - let tx = conn.transaction()?; + let tx = conn.transaction().await?; let sql = self .down_sql(backend_name)? .ok_or_else(|| Error::UnknownBackend(backend_name.to_string()))?; - tx.execute(&sql)?; + tx.execute(&sql).await?; let nameval = self.name().as_ref().to_sql(); tx.delete_where( ButaneMigration::TABLE, BoolExpr::Eq(ButaneMigration::PKCOL, Expr::Val(nameval)), - )?; - tx.commit() + ).await?; + tx.commit().await } } diff --git a/butane_core/src/migrations/mod.rs b/butane_core/src/migrations/mod.rs index 531eb75c..21c53020 100644 --- a/butane_core/src/migrations/mod.rs +++ b/butane_core/src/migrations/mod.rs @@ -20,8 +20,10 @@ mod fsmigrations; pub use fsmigrations::{FsMigration, FsMigrations}; mod memmigrations; pub use memmigrations::{MemMigration, MemMigrations}; +use async_trait::async_trait; /// A collection of migrations. +#[async_trait] pub trait Migrations { type M: Migration; @@ -66,8 +68,8 @@ pub trait Migrations { } /// Get migrations which have not yet been applied to the database - fn unapplied_migrations(&self, conn: &impl ConnectionMethods) -> Result> { - match self.last_applied_migration(conn)? { + async fn unapplied_migrations(&self, conn: &impl ConnectionMethods) -> Result> { + match self.last_applied_migration(conn).await? { None => self.all_migrations(), Some(m) => self.migrations_since(&m), } @@ -75,8 +77,8 @@ pub trait Migrations { /// Get the last migration that has been applied to the database or None /// if no migrations have been applied - fn last_applied_migration(&self, conn: &impl ConnectionMethods) -> Result> { - if !conn.has_table(ButaneMigration::TABLE)? { + async fn last_applied_migration(&self, conn: &impl ConnectionMethods) -> Result> { + if !conn.has_table(ButaneMigration::TABLE).await? { return Ok(None); } let migrations: Vec = conn @@ -87,7 +89,7 @@ pub trait Migrations { None, None, None, - )? + ).await? .mapped(ButaneMigration::from_row) .collect()?; @@ -106,6 +108,7 @@ pub trait Migrations { } } +#[async_trait] pub trait MigrationsMut: Migrations where Self::M: MigrationMut, @@ -124,7 +127,7 @@ where /// any storage backing it) and deleting the record of their /// existence/application from the database. The database schema /// is not modified, nor is any other data removed. Use carefully. - fn clear_migrations(&mut self, conn: &impl ConnectionMethods) -> Result<()>; + async fn clear_migrations(&mut self, conn: &impl ConnectionMethods) -> Result<()>; /// Get a pseudo-migration representing the current state as /// determined by the last build of models. This does not @@ -233,6 +236,8 @@ pub fn copy_migration(from: &impl Migration, to: &mut impl MigrationMut) -> Resu struct ButaneMigration { name: String, } + +#[async_trait] impl DataResult for ButaneMigration { type DBO = Self; const COLUMNS: &'static [Column] = &[Column::new("name", SqlType::Text)]; @@ -246,10 +251,13 @@ impl DataResult for ButaneMigration { name: FromSql::from_sql_ref(row.get(0, SqlType::Text).unwrap())?, }) } - fn query() -> query::Query { + + async fn query() -> query::Query { query::Query::new("butane_migrations") } } + +#[async_trait] impl DataObject for ButaneMigration { type PKType = String; type Fields = (); // we don't need Fields as we never filter @@ -259,7 +267,7 @@ impl DataObject for ButaneMigration { fn pk(&self) -> &String { &self.name } - fn save(&mut self, conn: &impl ConnectionMethods) -> Result<()> { + async fn save(&mut self, conn: &impl ConnectionMethods) -> Result<()> { let mut values: Vec> = Vec::with_capacity(2usize); values.push(self.name.to_sql_ref()); conn.insert_or_replace( @@ -267,9 +275,9 @@ impl DataObject for ButaneMigration { ::COLUMNS, &Column::new(Self::PKCOL, SqlType::Text), &values, - ) + ).await } - fn delete(&self, conn: &impl ConnectionMethods) -> Result<()> { - conn.delete(Self::TABLE, Self::PKCOL, self.pk().to_sql()) + async fn delete(&self, conn: &impl ConnectionMethods) -> Result<()> { + conn.delete(Self::TABLE, Self::PKCOL, self.pk().to_sql()).await } } diff --git a/butane_core/src/query/mod.rs b/butane_core/src/query/mod.rs index 705cc5f2..f0157e4c 100644 --- a/butane_core/src/query/mod.rs +++ b/butane_core/src/query/mod.rs @@ -179,26 +179,26 @@ impl Query { } /// Executes the query against `conn` and returns the first result (if any). - pub fn load_first(self, conn: &impl ConnectionMethods) -> Result> { - conn.query(&self.table, T::COLUMNS, self.filter, Some(1), None, None)? + pub async fn load_first(self, conn: &impl ConnectionMethods) -> Result> { + conn.query(&self.table, T::COLUMNS, self.filter, Some(1), None, None).await? .mapped(T::from_row) .nth(0) } /// Executes the query against `conn`. - pub fn load(self, conn: &impl ConnectionMethods) -> Result> { + pub async fn load(self, conn: &impl ConnectionMethods) -> Result> { let sort = if self.sort.is_empty() { None } else { Some(self.sort.as_slice()) }; - conn.query(&self.table, T::COLUMNS, self.filter, self.limit, self.offset, sort)? + conn.query(&self.table, T::COLUMNS, self.filter, self.limit, self.offset, sort).await? .mapped(T::from_row) .collect() } /// Executes the query against `conn` and deletes all matching objects. - pub fn delete(self, conn: &impl ConnectionMethods) -> Result { - conn.delete_where(&self.table, self.filter.unwrap_or(BoolExpr::True)) + pub async fn delete(self, conn: &impl ConnectionMethods) -> Result { + conn.delete_where(&self.table, self.filter.unwrap_or(BoolExpr::True)).await } } diff --git a/butane_core/src/sqlval.rs b/butane_core/src/sqlval.rs index 2efd50e2..990dcbdc 100644 --- a/butane_core/src/sqlval.rs +++ b/butane_core/src/sqlval.rs @@ -269,7 +269,7 @@ pub trait FieldType: ToSql + FromSql { } /// Marker trait for a type suitable for being a primary key -pub trait PrimaryKeyType: FieldType + Clone + PartialEq {} +pub trait PrimaryKeyType: FieldType + Clone + PartialEq + Sync{} /// Trait for referencing the primary key for a given model. Used to /// implement ForeignKey equality tests.