Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions src/bind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
use crate::error::SpannerDbErr;
use gcloud_spanner::statement::Statement as SpannerStatement;
use sea_orm::{DbErr, Statement};

/// Convert a SeaORM Statement into a Spanner Statement with bound parameters.
pub(crate) fn convert_statement(stmt: &Statement) -> Result<SpannerStatement, DbErr> {
let sql = &stmt.sql;
let mut spanner_stmt = SpannerStatement::new(sql);

if let Some(values) = &stmt.values {
for (idx, value) in values.0.iter().enumerate() {
let param_name = format!("p{}", idx + 1);
bind_value(&mut spanner_stmt, &param_name, value)?;
}
}

Ok(spanner_stmt)
}

/// Bind a sea_orm::Value to a Spanner Statement parameter.
///
/// This is the shared implementation used by SpannerExecutor, SpannerReadWriteTransaction,
/// and SpannerReadOnlyTransaction. The proxy path (SpannerProxy) has its own implementation
/// that uses custom wrapper types for Spanner-specific type handling.
pub(crate) fn bind_value(
stmt: &mut SpannerStatement,
param_name: &str,
value: &sea_orm::Value,
) -> Result<(), DbErr> {
use sea_orm::Value;

match value {
Value::Bool(Some(v)) => stmt.add_param(param_name, v),
Value::Bool(None) => stmt.add_param(param_name, &Option::<bool>::None),
Value::TinyInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::TinyInt(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::SmallInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::SmallInt(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::Int(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::Int(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::BigInt(Some(v)) => stmt.add_param(param_name, v),
Value::BigInt(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::TinyUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::TinyUnsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::SmallUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::SmallUnsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::Unsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::Unsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::BigUnsigned(Some(v)) => {
let i = i64::try_from(*v).map_err(|_| SpannerDbErr::TypeConversion {
column: param_name.to_string(),
expected: "i64".to_string(),
got: format!("u64 value {} overflows i64", v),
})?;
stmt.add_param(param_name, &i);
}
Value::BigUnsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::Float(Some(v)) => stmt.add_param(param_name, &(*v as f64)),
Value::Float(None) => stmt.add_param(param_name, &Option::<f64>::None),
Value::Double(Some(v)) => stmt.add_param(param_name, v),
Value::Double(None) => stmt.add_param(param_name, &Option::<f64>::None),
Value::String(Some(v)) => {
let s: &str = v.as_ref();
stmt.add_param(param_name, &s)
}
Value::String(None) => stmt.add_param(param_name, &Option::<String>::None),
Value::Char(Some(v)) => stmt.add_param(param_name, &v.to_string()),
Value::Char(None) => stmt.add_param(param_name, &Option::<String>::None),
Value::Bytes(Some(v)) => {
let b: &[u8] = v.as_ref();
stmt.add_param(param_name, &b)
}
Value::Bytes(None) => stmt.add_param(param_name, &Option::<Vec<u8>>::None),

#[cfg(feature = "with-chrono")]
Value::ChronoDate(Some(v)) => stmt.add_param(param_name, &v.format("%Y-%m-%d").to_string()),
#[cfg(feature = "with-chrono")]
Value::ChronoDate(None) => stmt.add_param(param_name, &Option::<String>::None),
#[cfg(feature = "with-chrono")]
Value::ChronoTime(Some(v)) => {
stmt.add_param(param_name, &v.format("%H:%M:%S%.f").to_string())
}
#[cfg(feature = "with-chrono")]
Value::ChronoTime(None) => stmt.add_param(param_name, &Option::<String>::None),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTime(Some(v)) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerTimestamp::new(v.and_utc()),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTime(None) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerOptionalTimestamp::none(),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeUtc(Some(v)) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerTimestamp::new(v.to_utc()),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeUtc(None) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerOptionalTimestamp::none(),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeLocal(Some(v)) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerTimestamp::new(v.to_utc()),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeLocal(None) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerOptionalTimestamp::none(),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeWithTimeZone(Some(v)) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerTimestamp::new(v.to_utc()),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeWithTimeZone(None) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerOptionalTimestamp::none(),
),

#[cfg(feature = "with-uuid")]
Value::Uuid(Some(v)) => stmt.add_param(param_name, &v.to_string()),
#[cfg(feature = "with-uuid")]
Value::Uuid(None) => stmt.add_param(param_name, &Option::<String>::None),

#[cfg(feature = "with-json")]
Value::Json(Some(v)) => stmt.add_param(
param_name,
&crate::json_support::SpannerOptionalJson::some(v.as_ref().clone()),
),
#[cfg(feature = "with-json")]
Value::Json(None) => stmt.add_param(
param_name,
&crate::json_support::SpannerOptionalJson::none(),
),

#[cfg(feature = "with-rust_decimal")]
Value::Decimal(Some(v)) => stmt.add_param(param_name, &v.to_string()),
#[cfg(feature = "with-rust_decimal")]
Value::Decimal(None) => stmt.add_param(param_name, &Option::<String>::None),

#[allow(unreachable_patterns)]
_ => {
return Err(SpannerDbErr::TypeConversion {
column: param_name.to_string(),
expected: "supported type".to_string(),
got: format!("{:?}", value),
}
.into());
}
}

Ok(())
}
148 changes: 8 additions & 140 deletions src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::error::{SpannerDbErr, SpannerTxError};
use crate::query_result::SpannerQueryResult;
use gcloud_gax::grpc::Status;
use gcloud_spanner::client::Client;
use gcloud_spanner::statement::Statement as SpannerStatement;
use sea_orm::{DbErr, Statement};
use std::sync::Arc;

Expand All @@ -16,9 +15,10 @@ impl SpannerExecutor {
}

pub async fn execute(&self, stmt: Statement) -> Result<i64, DbErr> {
let spanner_stmt = self.convert_statement(&stmt)?;

let result = self.client
let spanner_stmt = crate::bind::convert_statement(&stmt)?;

let result = self
.client
.read_write_transaction(|tx, _cancel| {
let stmt = spanner_stmt.clone();
Box::pin(async move {
Expand All @@ -39,9 +39,10 @@ impl SpannerExecutor {
}

pub async fn query_all(&self, stmt: Statement) -> Result<Vec<SpannerQueryResult>, DbErr> {
let spanner_stmt = self.convert_statement(&stmt)?;

let mut tx = self.client
let spanner_stmt = crate::bind::convert_statement(&stmt)?;

let mut tx = self
.client
.single()
.await
.map_err(|e| SpannerDbErr::Query(e.to_string()))?;
Expand All @@ -62,137 +63,4 @@ impl SpannerExecutor {

Ok(results)
}

fn convert_statement(&self, stmt: &Statement) -> Result<SpannerStatement, DbErr> {
let sql = &stmt.sql;
let mut spanner_stmt = SpannerStatement::new(sql);

if let Some(values) = &stmt.values {
for (idx, value) in values.0.iter().enumerate() {
let param_name = format!("p{}", idx + 1);
self.bind_value(&mut spanner_stmt, &param_name, value)?;
}
}

Ok(spanner_stmt)
}

fn bind_value(
&self,
stmt: &mut SpannerStatement,
param_name: &str,
value: &sea_orm::Value,
) -> Result<(), DbErr> {
use sea_orm::Value;

match value {
Value::Bool(Some(v)) => stmt.add_param(param_name, v),
Value::Bool(None) => stmt.add_param(param_name, &Option::<bool>::None),

Value::TinyInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::TinyInt(None) => stmt.add_param(param_name, &Option::<i64>::None),

Value::SmallInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::SmallInt(None) => stmt.add_param(param_name, &Option::<i64>::None),

Value::Int(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::Int(None) => stmt.add_param(param_name, &Option::<i64>::None),

Value::BigInt(Some(v)) => stmt.add_param(param_name, v),
Value::BigInt(None) => stmt.add_param(param_name, &Option::<i64>::None),

Value::TinyUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::TinyUnsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),

Value::SmallUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::SmallUnsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),

Value::Unsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::Unsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),

Value::BigUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::BigUnsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),

Value::Float(Some(v)) => stmt.add_param(param_name, &(*v as f64)),
Value::Float(None) => stmt.add_param(param_name, &Option::<f64>::None),

Value::Double(Some(v)) => stmt.add_param(param_name, v),
Value::Double(None) => stmt.add_param(param_name, &Option::<f64>::None),

Value::String(Some(v)) => stmt.add_param(param_name, v.as_ref()),
Value::String(None) => stmt.add_param(param_name, &Option::<String>::None),

Value::Char(Some(v)) => stmt.add_param(param_name, &v.to_string()),
Value::Char(None) => stmt.add_param(param_name, &Option::<String>::None),

Value::Bytes(Some(v)) => stmt.add_param(param_name, v.as_ref()),
Value::Bytes(None) => stmt.add_param(param_name, &Option::<Vec<u8>>::None),

#[cfg(feature = "with-chrono")]
Value::ChronoDate(Some(v)) => {
stmt.add_param(param_name, &v.format("%Y-%m-%d").to_string())
}
#[cfg(feature = "with-chrono")]
Value::ChronoDate(None) => stmt.add_param(param_name, &Option::<String>::None),
#[cfg(feature = "with-chrono")]
Value::ChronoTime(Some(v)) => {
stmt.add_param(param_name, &v.format("%H:%M:%S%.f").to_string())
}
#[cfg(feature = "with-chrono")]
Value::ChronoTime(None) => stmt.add_param(param_name, &Option::<String>::None),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTime(Some(v)) => {
stmt.add_param(param_name, &v.and_utc())
}
#[cfg(feature = "with-chrono")]
Value::ChronoDateTime(None) => stmt.add_param(param_name, &Option::<chrono::DateTime<chrono::Utc>>::None),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeUtc(Some(v)) => stmt.add_param(param_name, v.as_ref()),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeUtc(None) => stmt.add_param(param_name, &Option::<chrono::DateTime<chrono::Utc>>::None),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeLocal(Some(v)) => {
stmt.add_param(param_name, &v.to_utc())
}
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeLocal(None) => stmt.add_param(param_name, &Option::<chrono::DateTime<chrono::Utc>>::None),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeWithTimeZone(Some(v)) => {
stmt.add_param(param_name, &v.to_utc())
}
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeWithTimeZone(None) => stmt.add_param(param_name, &Option::<chrono::DateTime<chrono::Utc>>::None),

#[cfg(feature = "with-uuid")]
Value::Uuid(Some(v)) => stmt.add_param(param_name, &v.to_string()),
#[cfg(feature = "with-uuid")]
Value::Uuid(None) => stmt.add_param(param_name, &Option::<String>::None),

#[cfg(feature = "with-json")]
Value::Json(Some(v)) => {
stmt.add_param(param_name, &crate::json_support::SpannerOptionalJson::some(v.as_ref().clone()))
}
#[cfg(feature = "with-json")]
Value::Json(None) => stmt.add_param(param_name, &crate::json_support::SpannerOptionalJson::none()),

#[cfg(feature = "with-rust_decimal")]
Value::Decimal(Some(v)) => {
use rust_decimal::prelude::ToPrimitive;
stmt.add_param(param_name, &v.to_string())
}
#[cfg(feature = "with-rust_decimal")]
Value::Decimal(None) => stmt.add_param(param_name, &Option::<String>::None),

#[allow(unreachable_patterns)]
_ => {
return Err(SpannerDbErr::TypeConversion {
column: param_name.to_string(),
expected: "supported type".to_string(),
got: format!("{:?}", value),
}.into());
}
}

Ok(())
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod array_support;
mod bind;
#[cfg(feature = "with-chrono")]
pub mod chrono_support;
mod database;
Expand Down
Loading
Loading