diff --git a/src/bind.rs b/src/bind.rs new file mode 100644 index 0000000..fed6697 --- /dev/null +++ b/src/bind.rs @@ -0,0 +1,161 @@ +use { + crate::error::SpannerDbErr, + gcloud_spanner::statement::Statement as SpannerStatement, + sea_orm::{DbErr, Statement}, +}; + +/// Convert a SeaORM Statement into a Spanner Statement with bound parameters. +pub(crate) fn convert_statement(stmt: &Statement) -> Result { + let sql = &stmt.sql; + let mut spanner_stmt = SpannerStatement::new(sql); + + if let Some(values) = &stmt.values { + for (idx, value) in values.0.iter().enumerate() { + let param_name = format!("p{}", idx + 1); + bind_value(&mut spanner_stmt, ¶m_name, value)?; + } + } + + Ok(spanner_stmt) +} + +/// Bind a sea_orm::Value to a Spanner Statement parameter. +/// +/// This is the shared implementation used by SpannerExecutor, SpannerReadWriteTransaction, +/// and SpannerReadOnlyTransaction. The proxy path (SpannerProxy) has its own implementation +/// that uses custom wrapper types for Spanner-specific type handling. +pub(crate) fn bind_value( + stmt: &mut SpannerStatement, + param_name: &str, + value: &sea_orm::Value, +) -> Result<(), DbErr> { + use sea_orm::Value; + + match value { + Value::Bool(Some(v)) => stmt.add_param(param_name, v), + Value::Bool(None) => stmt.add_param(param_name, &Option::::None), + Value::TinyInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)), + Value::TinyInt(None) => stmt.add_param(param_name, &Option::::None), + Value::SmallInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)), + Value::SmallInt(None) => stmt.add_param(param_name, &Option::::None), + Value::Int(Some(v)) => stmt.add_param(param_name, &(*v as i64)), + Value::Int(None) => stmt.add_param(param_name, &Option::::None), + Value::BigInt(Some(v)) => stmt.add_param(param_name, v), + Value::BigInt(None) => stmt.add_param(param_name, &Option::::None), + Value::TinyUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), + Value::TinyUnsigned(None) => stmt.add_param(param_name, &Option::::None), + Value::SmallUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), + Value::SmallUnsigned(None) => stmt.add_param(param_name, &Option::::None), + Value::Unsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), + Value::Unsigned(None) => stmt.add_param(param_name, &Option::::None), + Value::BigUnsigned(Some(v)) => { + let i = i64::try_from(*v).map_err(|_| SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "i64".to_string(), + got: format!("u64 value {} overflows i64", v), + })?; + stmt.add_param(param_name, &i); + } + Value::BigUnsigned(None) => stmt.add_param(param_name, &Option::::None), + Value::Float(Some(v)) => stmt.add_param(param_name, &(*v as f64)), + Value::Float(None) => stmt.add_param(param_name, &Option::::None), + Value::Double(Some(v)) => stmt.add_param(param_name, v), + Value::Double(None) => stmt.add_param(param_name, &Option::::None), + Value::String(Some(v)) => { + let s: &str = v.as_ref(); + stmt.add_param(param_name, &s) + } + Value::String(None) => stmt.add_param(param_name, &Option::::None), + Value::Char(Some(v)) => stmt.add_param(param_name, &v.to_string()), + Value::Char(None) => stmt.add_param(param_name, &Option::::None), + Value::Bytes(Some(v)) => { + let b: &[u8] = v.as_ref(); + stmt.add_param(param_name, &b) + } + Value::Bytes(None) => stmt.add_param(param_name, &Option::>::None), + + #[cfg(feature = "with-chrono")] + Value::ChronoDate(Some(v)) => stmt.add_param(param_name, &v.format("%Y-%m-%d").to_string()), + #[cfg(feature = "with-chrono")] + Value::ChronoDate(None) => stmt.add_param(param_name, &Option::::None), + #[cfg(feature = "with-chrono")] + Value::ChronoTime(Some(v)) => { + stmt.add_param(param_name, &v.format("%H:%M:%S%.f").to_string()) + } + #[cfg(feature = "with-chrono")] + Value::ChronoTime(None) => stmt.add_param(param_name, &Option::::None), + #[cfg(feature = "with-chrono")] + Value::ChronoDateTime(Some(v)) => stmt.add_param( + param_name, + &crate::chrono_support::SpannerTimestamp::new(v.and_utc()), + ), + #[cfg(feature = "with-chrono")] + Value::ChronoDateTime(None) => stmt.add_param( + param_name, + &crate::chrono_support::SpannerOptionalTimestamp::none(), + ), + #[cfg(feature = "with-chrono")] + Value::ChronoDateTimeUtc(Some(v)) => stmt.add_param( + param_name, + &crate::chrono_support::SpannerTimestamp::new(v.to_utc()), + ), + #[cfg(feature = "with-chrono")] + Value::ChronoDateTimeUtc(None) => stmt.add_param( + param_name, + &crate::chrono_support::SpannerOptionalTimestamp::none(), + ), + #[cfg(feature = "with-chrono")] + Value::ChronoDateTimeLocal(Some(v)) => stmt.add_param( + param_name, + &crate::chrono_support::SpannerTimestamp::new(v.to_utc()), + ), + #[cfg(feature = "with-chrono")] + Value::ChronoDateTimeLocal(None) => stmt.add_param( + param_name, + &crate::chrono_support::SpannerOptionalTimestamp::none(), + ), + #[cfg(feature = "with-chrono")] + Value::ChronoDateTimeWithTimeZone(Some(v)) => stmt.add_param( + param_name, + &crate::chrono_support::SpannerTimestamp::new(v.to_utc()), + ), + #[cfg(feature = "with-chrono")] + Value::ChronoDateTimeWithTimeZone(None) => stmt.add_param( + param_name, + &crate::chrono_support::SpannerOptionalTimestamp::none(), + ), + + #[cfg(feature = "with-uuid")] + Value::Uuid(Some(v)) => stmt.add_param(param_name, &v.to_string()), + #[cfg(feature = "with-uuid")] + Value::Uuid(None) => stmt.add_param(param_name, &Option::::None), + + #[cfg(feature = "with-json")] + Value::Json(Some(v)) => stmt.add_param( + param_name, + &crate::json_support::SpannerOptionalJson::some(v.as_ref().clone()), + ), + #[cfg(feature = "with-json")] + Value::Json(None) => stmt.add_param( + param_name, + &crate::json_support::SpannerOptionalJson::none(), + ), + + #[cfg(feature = "with-rust_decimal")] + Value::Decimal(Some(v)) => stmt.add_param(param_name, &v.to_string()), + #[cfg(feature = "with-rust_decimal")] + Value::Decimal(None) => stmt.add_param(param_name, &Option::::None), + + #[allow(unreachable_patterns)] + _ => { + return Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "supported type".to_string(), + got: format!("{:?}", value), + } + .into()); + } + } + + Ok(()) +} diff --git a/src/connection.rs b/src/connection.rs index 506a159..0f27e22 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,17 +1,13 @@ -use crate::error::SpannerDbErr; -use crate::executor::SpannerExecutor; -use crate::query_result::SpannerQueryResult; -use gcloud_gax::cancel::CancellationToken; -use gcloud_gax::grpc::Status; -use gcloud_gax::retry::TryAs; -use gcloud_spanner::client::Client; -use gcloud_spanner::session::SessionError; -use gcloud_spanner::transaction_rw::ReadWriteTransaction; -use gcloud_spanner::transaction_ro::ReadOnlyTransaction; -use sea_orm::{DbBackend, DbErr, Statement}; -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; +use { + crate::{error::SpannerDbErr, executor::SpannerExecutor, query_result::SpannerQueryResult}, + gcloud_gax::{grpc::Status, retry::TryAs}, + gcloud_spanner::{ + client::Client, session::SessionError, transaction_ro::ReadOnlyTransaction, + transaction_rw::ReadWriteTransaction, + }, + sea_orm::{DbBackend, DbErr, Statement}, + std::{future::Future, pin::Pin, sync::Arc}, +}; #[derive(Clone)] pub struct SpannerConnection { @@ -29,8 +25,12 @@ impl SpannerConnection { &self.client } - pub async fn close(&self) { - self.client.close().await; + pub async fn close(self) { + Arc::try_unwrap(self.client) + .ok() + .expect("Cannot close: other references to Client exist") + .close() + .await; } pub fn get_database_backend(&self) -> DbBackend { @@ -63,15 +63,14 @@ impl SpannerConnection { E: TryAs + From + From + ToString, F: for<'tx> Fn( &'tx mut ReadWriteTransaction, - Option, ) -> Pin> + Send + 'tx>>, T: Send, { let result = self .client - .read_write_transaction(|tx, cancel| callback(tx, cancel)) + .read_write_transaction(|tx| callback(tx)) .await - .map_err(|e| DbErr::Custom(e.to_string()))?; + .map_err(|e: E| DbErr::Custom(e.to_string()))?; Ok(result.1) } @@ -79,8 +78,10 @@ impl SpannerConnection { pub async fn read_only_transaction(&self, callback: F) -> Result where F: for<'tx> FnOnce( - &'tx mut ReadOnlyTransaction, - ) -> Pin> + Send + 'tx>> + Send, + &'tx mut ReadOnlyTransaction, + ) + -> Pin> + Send + 'tx>> + + Send, T: Send, { let mut tx = self diff --git a/src/executor.rs b/src/executor.rs index 87a5e67..b32c3ee 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -1,10 +1,12 @@ -use crate::error::{SpannerDbErr, SpannerTxError}; -use crate::query_result::SpannerQueryResult; -use gcloud_gax::grpc::Status; -use gcloud_spanner::client::Client; -use gcloud_spanner::statement::Statement as SpannerStatement; -use sea_orm::{DbErr, Statement}; -use std::sync::Arc; +use { + crate::{ + error::{SpannerDbErr, SpannerTxError}, + query_result::SpannerQueryResult, + }, + gcloud_spanner::client::Client, + sea_orm::{DbErr, Statement}, + std::sync::Arc, +}; pub struct SpannerExecutor { client: Arc, @@ -16,19 +18,16 @@ impl SpannerExecutor { } pub async fn execute(&self, stmt: Statement) -> Result { - let spanner_stmt = self.convert_statement(&stmt)?; - - let result = self.client - .read_write_transaction(|tx, _cancel| { + let spanner_stmt = crate::bind::convert_statement(&stmt)?; + + let result = self + .client + .read_write_transaction(|tx| { let stmt = spanner_stmt.clone(); - Box::pin(async move { - tx.update(stmt) - .await - .map_err(|e: Status| SpannerTxError::from(e)) - }) + Box::pin(async move { tx.update(stmt).await.map_err(SpannerTxError::from) }) }) .await - .map_err(|e| SpannerDbErr::Execution(e.to_string()))?; + .map_err(|e: crate::error::SpannerTxError| SpannerDbErr::Execution(e.to_string()))?; Ok(result.1) } @@ -39,9 +38,10 @@ impl SpannerExecutor { } pub async fn query_all(&self, stmt: Statement) -> Result, DbErr> { - let spanner_stmt = self.convert_statement(&stmt)?; - - let mut tx = self.client + let spanner_stmt = crate::bind::convert_statement(&stmt)?; + + let mut tx = self + .client .single() .await .map_err(|e| SpannerDbErr::Query(e.to_string()))?; @@ -62,137 +62,4 @@ impl SpannerExecutor { Ok(results) } - - fn convert_statement(&self, stmt: &Statement) -> Result { - let sql = &stmt.sql; - let mut spanner_stmt = SpannerStatement::new(sql); - - if let Some(values) = &stmt.values { - for (idx, value) in values.0.iter().enumerate() { - let param_name = format!("p{}", idx + 1); - self.bind_value(&mut spanner_stmt, ¶m_name, value)?; - } - } - - Ok(spanner_stmt) - } - - fn bind_value( - &self, - stmt: &mut SpannerStatement, - param_name: &str, - value: &sea_orm::Value, - ) -> Result<(), DbErr> { - use sea_orm::Value; - - match value { - Value::Bool(Some(v)) => stmt.add_param(param_name, v), - Value::Bool(None) => stmt.add_param(param_name, &Option::::None), - - Value::TinyInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::TinyInt(None) => stmt.add_param(param_name, &Option::::None), - - Value::SmallInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::SmallInt(None) => stmt.add_param(param_name, &Option::::None), - - Value::Int(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::Int(None) => stmt.add_param(param_name, &Option::::None), - - Value::BigInt(Some(v)) => stmt.add_param(param_name, v), - Value::BigInt(None) => stmt.add_param(param_name, &Option::::None), - - Value::TinyUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::TinyUnsigned(None) => stmt.add_param(param_name, &Option::::None), - - Value::SmallUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::SmallUnsigned(None) => stmt.add_param(param_name, &Option::::None), - - Value::Unsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::Unsigned(None) => stmt.add_param(param_name, &Option::::None), - - Value::BigUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::BigUnsigned(None) => stmt.add_param(param_name, &Option::::None), - - Value::Float(Some(v)) => stmt.add_param(param_name, &(*v as f64)), - Value::Float(None) => stmt.add_param(param_name, &Option::::None), - - Value::Double(Some(v)) => stmt.add_param(param_name, v), - Value::Double(None) => stmt.add_param(param_name, &Option::::None), - - Value::String(Some(v)) => stmt.add_param(param_name, v.as_ref()), - Value::String(None) => stmt.add_param(param_name, &Option::::None), - - Value::Char(Some(v)) => stmt.add_param(param_name, &v.to_string()), - Value::Char(None) => stmt.add_param(param_name, &Option::::None), - - Value::Bytes(Some(v)) => stmt.add_param(param_name, v.as_ref()), - Value::Bytes(None) => stmt.add_param(param_name, &Option::>::None), - - #[cfg(feature = "with-chrono")] - Value::ChronoDate(Some(v)) => { - stmt.add_param(param_name, &v.format("%Y-%m-%d").to_string()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDate(None) => stmt.add_param(param_name, &Option::::None), - #[cfg(feature = "with-chrono")] - Value::ChronoTime(Some(v)) => { - stmt.add_param(param_name, &v.format("%H:%M:%S%.f").to_string()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoTime(None) => stmt.add_param(param_name, &Option::::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTime(Some(v)) => { - stmt.add_param(param_name, &v.and_utc()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDateTime(None) => stmt.add_param(param_name, &Option::>::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeUtc(Some(v)) => stmt.add_param(param_name, v.as_ref()), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeUtc(None) => stmt.add_param(param_name, &Option::>::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeLocal(Some(v)) => { - stmt.add_param(param_name, &v.to_utc()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeLocal(None) => stmt.add_param(param_name, &Option::>::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeWithTimeZone(Some(v)) => { - stmt.add_param(param_name, &v.to_utc()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeWithTimeZone(None) => stmt.add_param(param_name, &Option::>::None), - - #[cfg(feature = "with-uuid")] - Value::Uuid(Some(v)) => stmt.add_param(param_name, &v.to_string()), - #[cfg(feature = "with-uuid")] - Value::Uuid(None) => stmt.add_param(param_name, &Option::::None), - - #[cfg(feature = "with-json")] - Value::Json(Some(v)) => { - stmt.add_param(param_name, &crate::json_support::SpannerOptionalJson::some(v.as_ref().clone())) - } - #[cfg(feature = "with-json")] - Value::Json(None) => stmt.add_param(param_name, &crate::json_support::SpannerOptionalJson::none()), - - #[cfg(feature = "with-rust_decimal")] - Value::Decimal(Some(v)) => { - use rust_decimal::prelude::ToPrimitive; - stmt.add_param(param_name, &v.to_string()) - } - #[cfg(feature = "with-rust_decimal")] - Value::Decimal(None) => stmt.add_param(param_name, &Option::::None), - - #[allow(unreachable_patterns)] - _ => { - return Err(SpannerDbErr::TypeConversion { - column: param_name.to_string(), - expected: "supported type".to_string(), - got: format!("{:?}", value), - }.into()); - } - } - - Ok(()) - } } diff --git a/src/lib.rs b/src/lib.rs index 0217fc5..58f7038 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,16 @@ pub mod array_support; +mod bind; #[cfg(feature = "with-chrono")] pub mod chrono_support; +pub mod connection; mod database; mod error; +pub mod executor; #[cfg(feature = "with-json")] pub mod json_support; mod proxy; +pub mod query_result; +pub mod transaction; #[cfg(feature = "with-uuid")] pub mod uuid_support; diff --git a/src/proxy.rs b/src/proxy.rs index 99b7d75..acf792d 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -73,7 +73,33 @@ impl SpannerProxy { } fn rewrite_mysql_quotes(sql: &str) -> String { - sql.replace('`', "") + let mut result = String::with_capacity(sql.len()); + let mut chars = sql.chars().peekable(); + let mut in_string = false; + let mut string_char = ' '; + + while let Some(c) = chars.next() { + if !in_string && (c == '\'' || c == '"') { + in_string = true; + string_char = c; + result.push(c); + } else if in_string && c == string_char { + if chars.peek() == Some(&string_char) { + // Escaped quote inside string literal + result.push(c); + result.push(chars.next().unwrap()); + } else { + in_string = false; + result.push(c); + } + } else if !in_string && c == '`' { + // Skip backticks outside of string literals + } else { + result.push(c); + } + } + + result } fn bind_value( @@ -107,7 +133,14 @@ impl SpannerProxy { Value::Unsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), Value::Unsigned(None) => stmt.add_param(param_name, &Option::::None), - Value::BigUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), + Value::BigUnsigned(Some(v)) => { + let i = i64::try_from(*v).map_err(|_| SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "i64".to_string(), + got: format!("u64 value {} overflows i64", v), + })?; + stmt.add_param(param_name, &i); + } Value::BigUnsigned(None) => stmt.add_param(param_name, &Option::::None), Value::Float(Some(v)) => stmt.add_param(param_name, &(*v as f64)), @@ -250,24 +283,34 @@ impl SpannerProxy { ArrayType::Bool => { let arr: Vec = values .iter() - .filter_map(|v| match v { - Value::Bool(Some(b)) => Some(*b), - _ => None, + .enumerate() + .map(|(i, v)| match v { + Value::Bool(Some(b)) => Ok(*b), + _ => Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "Bool".to_string(), + got: format!("element [{}]: {:?}", i, v), + }), }) - .collect(); + .collect::, _>>()?; stmt.add_param(param_name, &arr); } ArrayType::TinyInt | ArrayType::SmallInt | ArrayType::Int | ArrayType::BigInt => { let arr: Vec = values .iter() - .filter_map(|v| match v { - Value::TinyInt(Some(i)) => Some(*i as i64), - Value::SmallInt(Some(i)) => Some(*i as i64), - Value::Int(Some(i)) => Some(*i as i64), - Value::BigInt(Some(i)) => Some(*i), - _ => None, + .enumerate() + .map(|(i, v)| match v { + Value::TinyInt(Some(i)) => Ok(*i as i64), + Value::SmallInt(Some(i)) => Ok(*i as i64), + Value::Int(Some(i)) => Ok(*i as i64), + Value::BigInt(Some(i)) => Ok(*i), + _ => Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "Int".to_string(), + got: format!("element [{}]: {:?}", i, v), + }), }) - .collect(); + .collect::, _>>()?; stmt.add_param(param_name, &arr); } ArrayType::TinyUnsigned @@ -276,46 +319,72 @@ impl SpannerProxy { | ArrayType::BigUnsigned => { let arr: Vec = values .iter() - .filter_map(|v| match v { - Value::TinyUnsigned(Some(i)) => Some(*i as i64), - Value::SmallUnsigned(Some(i)) => Some(*i as i64), - Value::Unsigned(Some(i)) => Some(*i as i64), - Value::BigUnsigned(Some(i)) => Some(*i as i64), - _ => None, + .enumerate() + .map(|(i, v)| match v { + Value::TinyUnsigned(Some(val)) => Ok(*val as i64), + Value::SmallUnsigned(Some(val)) => Ok(*val as i64), + Value::Unsigned(Some(val)) => Ok(*val as i64), + Value::BigUnsigned(Some(val)) => { + i64::try_from(*val).map_err(|_| SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "i64".to_string(), + got: format!("element [{}]: u64 value {} overflows i64", i, val), + }) + } + _ => Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "Unsigned Int".to_string(), + got: format!("element [{}]: {:?}", i, v), + }), }) - .collect(); + .collect::, _>>()?; stmt.add_param(param_name, &arr); } ArrayType::Float | ArrayType::Double => { let arr: Vec = values .iter() - .filter_map(|v| match v { - Value::Float(Some(f)) => Some(*f as f64), - Value::Double(Some(d)) => Some(*d), - _ => None, + .enumerate() + .map(|(i, v)| match v { + Value::Float(Some(f)) => Ok(*f as f64), + Value::Double(Some(d)) => Ok(*d), + _ => Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "Float".to_string(), + got: format!("element [{}]: {:?}", i, v), + }), }) - .collect(); + .collect::, _>>()?; stmt.add_param(param_name, &arr); } ArrayType::String | ArrayType::Char => { let arr: Vec = values .iter() - .filter_map(|v| match v { - Value::String(Some(s)) => Some(s.clone()), - Value::Char(Some(c)) => Some(c.to_string()), - _ => None, + .enumerate() + .map(|(i, v)| match v { + Value::String(Some(s)) => Ok(s.to_string()), + Value::Char(Some(c)) => Ok(c.to_string()), + _ => Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "String".to_string(), + got: format!("element [{}]: {:?}", i, v), + }), }) - .collect(); + .collect::, _>>()?; stmt.add_param(param_name, &arr); } ArrayType::Bytes => { let arr: Vec> = values .iter() - .filter_map(|v| match v { - Value::Bytes(Some(b)) => Some(b.clone()), - _ => None, + .enumerate() + .map(|(i, v)| match v { + Value::Bytes(Some(b)) => Ok(b.to_vec()), + _ => Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "Bytes".to_string(), + got: format!("element [{}]: {:?}", i, v), + }), }) - .collect(); + .collect::, _>>()?; stmt.add_param(param_name, &SpannerBytesArray(arr)); } #[cfg(feature = "with-chrono")] @@ -327,57 +396,77 @@ impl SpannerProxy { | ArrayType::ChronoDateTimeWithTimeZone => { let arr: Vec = values .iter() - .filter_map(|v| match v { - Value::ChronoDate(Some(d)) => Some(d.format("%Y-%m-%d").to_string()), - Value::ChronoTime(Some(t)) => Some(t.format("%H:%M:%S%.f").to_string()), + .enumerate() + .map(|(i, v)| match v { + Value::ChronoDate(Some(d)) => Ok(d.format("%Y-%m-%d").to_string()), + Value::ChronoTime(Some(t)) => Ok(t.format("%H:%M:%S%.f").to_string()), Value::ChronoDateTime(Some(dt)) => { - Some(dt.format("%Y-%m-%dT%H:%M:%S%.fZ").to_string()) + Ok(dt.format("%Y-%m-%dT%H:%M:%S%.fZ").to_string()) } Value::ChronoDateTimeUtc(Some(dt)) => { - Some(dt.format("%Y-%m-%dT%H:%M:%S%.fZ").to_string()) + Ok(dt.format("%Y-%m-%dT%H:%M:%S%.fZ").to_string()) } Value::ChronoDateTimeLocal(Some(dt)) => { - Some(dt.format("%Y-%m-%dT%H:%M:%S%.fZ").to_string()) + Ok(dt.format("%Y-%m-%dT%H:%M:%S%.fZ").to_string()) } Value::ChronoDateTimeWithTimeZone(Some(dt)) => { - Some(dt.format("%Y-%m-%dT%H:%M:%S%.fZ").to_string()) + Ok(dt.format("%Y-%m-%dT%H:%M:%S%.fZ").to_string()) } - _ => None, + _ => Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "Chrono DateTime".to_string(), + got: format!("element [{}]: {:?}", i, v), + }), }) - .collect(); + .collect::, _>>()?; stmt.add_param(param_name, &arr); } #[cfg(feature = "with-uuid")] ArrayType::Uuid => { let arr: Vec = values .iter() - .filter_map(|v| match v { - Value::Uuid(Some(u)) => Some(u.to_string()), - _ => None, + .enumerate() + .map(|(i, v)| match v { + Value::Uuid(Some(u)) => Ok(u.to_string()), + _ => Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "Uuid".to_string(), + got: format!("element [{}]: {:?}", i, v), + }), }) - .collect(); + .collect::, _>>()?; stmt.add_param(param_name, &arr); } #[cfg(feature = "with-json")] ArrayType::Json => { let arr: Vec = values .iter() - .filter_map(|v| match v { - Value::Json(Some(j)) => Some(j.to_string()), - _ => None, + .enumerate() + .map(|(i, v)| match v { + Value::Json(Some(j)) => Ok(j.to_string()), + _ => Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "Json".to_string(), + got: format!("element [{}]: {:?}", i, v), + }), }) - .collect(); + .collect::, _>>()?; stmt.add_param(param_name, &arr); } #[cfg(feature = "with-rust_decimal")] ArrayType::Decimal => { let arr: Vec = values .iter() - .filter_map(|v| match v { - Value::Decimal(Some(d)) => Some(d.to_string()), - _ => None, + .enumerate() + .map(|(i, v)| match v { + Value::Decimal(Some(d)) => Ok(d.to_string()), + _ => Err(SpannerDbErr::TypeConversion { + column: param_name.to_string(), + expected: "Decimal".to_string(), + got: format!("element [{}]: {:?}", i, v), + }), }) - .collect(); + .collect::, _>>()?; stmt.add_param(param_name, &arr); } #[allow(unreachable_patterns)] @@ -800,8 +889,7 @@ impl SpannerProxy { if sql[i..].starts_with("FROM") { let next_idx = i + 4; if next_idx >= bytes.len() - || !bytes[next_idx].is_ascii_alphanumeric() - || bytes[next_idx] == b'_' + || (!bytes[next_idx].is_ascii_alphanumeric() && bytes[next_idx] != b'_') { return Some(i); } @@ -891,15 +979,24 @@ impl ProxyDatabaseTrait for SpannerProxy { } async fn begin(&self) { - // Spanner uses callback-based transactions, handled differently + tracing::warn!( + "SpannerProxy::begin() is a no-op. Spanner uses callback-based transactions via \ + SpannerConnection::read_write_transaction() instead of begin/commit/rollback." + ); } async fn commit(&self) { - // Handled by transaction callback + tracing::warn!( + "SpannerProxy::commit() is a no-op. Spanner transactions are committed \ + automatically when the callback passed to read_write_transaction() succeeds." + ); } async fn rollback(&self) { - // Handled by transaction callback + tracing::warn!( + "SpannerProxy::rollback() is a no-op. Spanner transactions are rolled back \ + automatically when the callback passed to read_write_transaction() returns an error." + ); } async fn ping(&self) -> Result<(), DbErr> { diff --git a/src/query_result.rs b/src/query_result.rs index ad9f90f..d58df50 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -1,5 +1,7 @@ -use gcloud_spanner::row::Row as SpannerRow; -use sea_orm::{DbErr, TryGetError}; +use { + gcloud_spanner::row::Row as SpannerRow, + sea_orm::{DbErr, TryGetError}, +}; pub struct SpannerQueryResult { row: SpannerRow, diff --git a/src/transaction.rs b/src/transaction.rs index 276dfa6..f32703e 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -1,9 +1,8 @@ -use crate::error::SpannerDbErr; -use crate::query_result::SpannerQueryResult; -use gcloud_spanner::statement::Statement as SpannerStatement; -use gcloud_spanner::transaction_ro::ReadOnlyTransaction; -use gcloud_spanner::transaction_rw::ReadWriteTransaction; -use sea_orm::{DbErr, Statement}; +use { + crate::{error::SpannerDbErr, query_result::SpannerQueryResult}, + gcloud_spanner::{transaction_ro::ReadOnlyTransaction, transaction_rw::ReadWriteTransaction}, + sea_orm::{DbErr, Statement}, +}; pub struct SpannerReadWriteTransaction<'a> { tx: &'a mut ReadWriteTransaction, @@ -15,22 +14,29 @@ impl<'a> SpannerReadWriteTransaction<'a> { } pub async fn execute(&mut self, stmt: Statement) -> Result { - let spanner_stmt = self.convert_statement(&stmt)?; - let rows_affected = self.tx.update(spanner_stmt).await + let spanner_stmt = crate::bind::convert_statement(&stmt)?; + let rows_affected = self + .tx + .update(spanner_stmt) + .await .map_err(|e| SpannerDbErr::Execution(e.to_string()))?; - + Ok(rows_affected) } - pub async fn query_one(&mut self, stmt: Statement) -> Result, DbErr> { + pub async fn query_one( + &mut self, + stmt: Statement, + ) -> Result, DbErr> { let results = self.query_all(stmt).await?; Ok(results.into_iter().next()) } pub async fn query_all(&mut self, stmt: Statement) -> Result, DbErr> { - let spanner_stmt = self.convert_statement(&stmt)?; - - let mut iter = self.tx + let spanner_stmt = crate::bind::convert_statement(&stmt)?; + + let mut iter = self + .tx .query(spanner_stmt) .await .map_err(|e| SpannerDbErr::Query(e.to_string()))?; @@ -46,124 +52,6 @@ impl<'a> SpannerReadWriteTransaction<'a> { Ok(results) } - - fn convert_statement(&self, stmt: &Statement) -> Result { - let sql = &stmt.sql; - let mut spanner_stmt = SpannerStatement::new(sql); - - if let Some(values) = &stmt.values { - for (idx, value) in values.0.iter().enumerate() { - let param_name = format!("p{}", idx + 1); - self.bind_value(&mut spanner_stmt, ¶m_name, value)?; - } - } - - Ok(spanner_stmt) - } - - fn bind_value( - &self, - stmt: &mut SpannerStatement, - param_name: &str, - value: &sea_orm::Value, - ) -> Result<(), DbErr> { - use sea_orm::Value; - - match value { - Value::Bool(Some(v)) => stmt.add_param(param_name, v), - Value::Bool(None) => stmt.add_param(param_name, &Option::::None), - Value::TinyInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::TinyInt(None) => stmt.add_param(param_name, &Option::::None), - Value::SmallInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::SmallInt(None) => stmt.add_param(param_name, &Option::::None), - Value::Int(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::Int(None) => stmt.add_param(param_name, &Option::::None), - Value::BigInt(Some(v)) => stmt.add_param(param_name, v), - Value::BigInt(None) => stmt.add_param(param_name, &Option::::None), - Value::TinyUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::TinyUnsigned(None) => stmt.add_param(param_name, &Option::::None), - Value::SmallUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::SmallUnsigned(None) => stmt.add_param(param_name, &Option::::None), - Value::Unsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::Unsigned(None) => stmt.add_param(param_name, &Option::::None), - Value::BigUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::BigUnsigned(None) => stmt.add_param(param_name, &Option::::None), - Value::Float(Some(v)) => stmt.add_param(param_name, &(*v as f64)), - Value::Float(None) => stmt.add_param(param_name, &Option::::None), - Value::Double(Some(v)) => stmt.add_param(param_name, v), - Value::Double(None) => stmt.add_param(param_name, &Option::::None), - Value::String(Some(v)) => stmt.add_param(param_name, v.as_ref()), - Value::String(None) => stmt.add_param(param_name, &Option::::None), - Value::Char(Some(v)) => stmt.add_param(param_name, &v.to_string()), - Value::Char(None) => stmt.add_param(param_name, &Option::::None), - Value::Bytes(Some(v)) => stmt.add_param(param_name, v.as_ref()), - Value::Bytes(None) => stmt.add_param(param_name, &Option::>::None), - - #[cfg(feature = "with-chrono")] - Value::ChronoDate(Some(v)) => { - stmt.add_param(param_name, &v.format("%Y-%m-%d").to_string()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDate(None) => stmt.add_param(param_name, &Option::::None), - #[cfg(feature = "with-chrono")] - Value::ChronoTime(Some(v)) => { - stmt.add_param(param_name, &v.format("%H:%M:%S%.f").to_string()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoTime(None) => stmt.add_param(param_name, &Option::::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTime(Some(v)) => { - stmt.add_param(param_name, &v.and_utc()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDateTime(None) => stmt.add_param(param_name, &Option::>::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeUtc(Some(v)) => stmt.add_param(param_name, v.as_ref()), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeUtc(None) => stmt.add_param(param_name, &Option::>::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeLocal(Some(v)) => { - stmt.add_param(param_name, &v.to_utc()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeLocal(None) => stmt.add_param(param_name, &Option::>::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeWithTimeZone(Some(v)) => { - stmt.add_param(param_name, &v.to_utc()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeWithTimeZone(None) => stmt.add_param(param_name, &Option::>::None), - - #[cfg(feature = "with-uuid")] - Value::Uuid(Some(v)) => stmt.add_param(param_name, &v.to_string()), - #[cfg(feature = "with-uuid")] - Value::Uuid(None) => stmt.add_param(param_name, &Option::::None), - - #[cfg(feature = "with-json")] - Value::Json(Some(v)) => { - stmt.add_param(param_name, &crate::json_support::SpannerOptionalJson::some(v.as_ref().clone())) - } - #[cfg(feature = "with-json")] - Value::Json(None) => stmt.add_param(param_name, &crate::json_support::SpannerOptionalJson::none()), - - #[cfg(feature = "with-rust_decimal")] - Value::Decimal(Some(v)) => { - stmt.add_param(param_name, &v.to_string()) - } - #[cfg(feature = "with-rust_decimal")] - Value::Decimal(None) => stmt.add_param(param_name, &Option::::None), - - #[allow(unreachable_patterns)] - _ => { - return Err(SpannerDbErr::TypeConversion { - column: param_name.to_string(), - expected: "supported type".to_string(), - got: format!("{:?}", value), - }.into()); - } - } - Ok(()) - } } pub struct SpannerReadOnlyTransaction<'a> { @@ -175,15 +63,19 @@ impl<'a> SpannerReadOnlyTransaction<'a> { Self { tx } } - pub async fn query_one(&mut self, stmt: Statement) -> Result, DbErr> { + pub async fn query_one( + &mut self, + stmt: Statement, + ) -> Result, DbErr> { let results = self.query_all(stmt).await?; Ok(results.into_iter().next()) } pub async fn query_all(&mut self, stmt: Statement) -> Result, DbErr> { - let spanner_stmt = self.convert_statement(&stmt)?; - - let mut iter = self.tx + let spanner_stmt = crate::bind::convert_statement(&stmt)?; + + let mut iter = self + .tx .query(spanner_stmt) .await .map_err(|e| SpannerDbErr::Query(e.to_string()))?; @@ -199,124 +91,6 @@ impl<'a> SpannerReadOnlyTransaction<'a> { Ok(results) } - - fn convert_statement(&self, stmt: &Statement) -> Result { - let sql = &stmt.sql; - let mut spanner_stmt = SpannerStatement::new(sql); - - if let Some(values) = &stmt.values { - for (idx, value) in values.0.iter().enumerate() { - let param_name = format!("p{}", idx + 1); - self.bind_value(&mut spanner_stmt, ¶m_name, value)?; - } - } - - Ok(spanner_stmt) - } - - fn bind_value( - &self, - stmt: &mut SpannerStatement, - param_name: &str, - value: &sea_orm::Value, - ) -> Result<(), DbErr> { - use sea_orm::Value; - - match value { - Value::Bool(Some(v)) => stmt.add_param(param_name, v), - Value::Bool(None) => stmt.add_param(param_name, &Option::::None), - Value::TinyInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::TinyInt(None) => stmt.add_param(param_name, &Option::::None), - Value::SmallInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::SmallInt(None) => stmt.add_param(param_name, &Option::::None), - Value::Int(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::Int(None) => stmt.add_param(param_name, &Option::::None), - Value::BigInt(Some(v)) => stmt.add_param(param_name, v), - Value::BigInt(None) => stmt.add_param(param_name, &Option::::None), - Value::TinyUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::TinyUnsigned(None) => stmt.add_param(param_name, &Option::::None), - Value::SmallUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::SmallUnsigned(None) => stmt.add_param(param_name, &Option::::None), - Value::Unsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::Unsigned(None) => stmt.add_param(param_name, &Option::::None), - Value::BigUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)), - Value::BigUnsigned(None) => stmt.add_param(param_name, &Option::::None), - Value::Float(Some(v)) => stmt.add_param(param_name, &(*v as f64)), - Value::Float(None) => stmt.add_param(param_name, &Option::::None), - Value::Double(Some(v)) => stmt.add_param(param_name, v), - Value::Double(None) => stmt.add_param(param_name, &Option::::None), - Value::String(Some(v)) => stmt.add_param(param_name, v.as_ref()), - Value::String(None) => stmt.add_param(param_name, &Option::::None), - Value::Char(Some(v)) => stmt.add_param(param_name, &v.to_string()), - Value::Char(None) => stmt.add_param(param_name, &Option::::None), - Value::Bytes(Some(v)) => stmt.add_param(param_name, v.as_ref()), - Value::Bytes(None) => stmt.add_param(param_name, &Option::>::None), - - #[cfg(feature = "with-chrono")] - Value::ChronoDate(Some(v)) => { - stmt.add_param(param_name, &v.format("%Y-%m-%d").to_string()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDate(None) => stmt.add_param(param_name, &Option::::None), - #[cfg(feature = "with-chrono")] - Value::ChronoTime(Some(v)) => { - stmt.add_param(param_name, &v.format("%H:%M:%S%.f").to_string()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoTime(None) => stmt.add_param(param_name, &Option::::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTime(Some(v)) => { - stmt.add_param(param_name, &v.and_utc()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDateTime(None) => stmt.add_param(param_name, &Option::>::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeUtc(Some(v)) => stmt.add_param(param_name, v.as_ref()), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeUtc(None) => stmt.add_param(param_name, &Option::>::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeLocal(Some(v)) => { - stmt.add_param(param_name, &v.to_utc()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeLocal(None) => stmt.add_param(param_name, &Option::>::None), - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeWithTimeZone(Some(v)) => { - stmt.add_param(param_name, &v.to_utc()) - } - #[cfg(feature = "with-chrono")] - Value::ChronoDateTimeWithTimeZone(None) => stmt.add_param(param_name, &Option::>::None), - - #[cfg(feature = "with-uuid")] - Value::Uuid(Some(v)) => stmt.add_param(param_name, &v.to_string()), - #[cfg(feature = "with-uuid")] - Value::Uuid(None) => stmt.add_param(param_name, &Option::::None), - - #[cfg(feature = "with-json")] - Value::Json(Some(v)) => { - stmt.add_param(param_name, &crate::json_support::SpannerOptionalJson::some(v.as_ref().clone())) - } - #[cfg(feature = "with-json")] - Value::Json(None) => stmt.add_param(param_name, &crate::json_support::SpannerOptionalJson::none()), - - #[cfg(feature = "with-rust_decimal")] - Value::Decimal(Some(v)) => { - stmt.add_param(param_name, &v.to_string()) - } - #[cfg(feature = "with-rust_decimal")] - Value::Decimal(None) => stmt.add_param(param_name, &Option::::None), - - #[allow(unreachable_patterns)] - _ => { - return Err(SpannerDbErr::TypeConversion { - column: param_name.to_string(), - expected: "supported type".to_string(), - got: format!("{:?}", value), - }.into()); - } - } - Ok(()) - } } impl std::fmt::Debug for SpannerReadWriteTransaction<'_> {