Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions src/bind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use {
crate::error::SpannerDbErr,
gcloud_spanner::statement::Statement as SpannerStatement,
sea_orm::{DbErr, Statement},
};

/// Convert a SeaORM Statement into a Spanner Statement with bound parameters.
pub(crate) fn convert_statement(stmt: &Statement) -> Result<SpannerStatement, DbErr> {
let sql = &stmt.sql;
let mut spanner_stmt = SpannerStatement::new(sql);

if let Some(values) = &stmt.values {
for (idx, value) in values.0.iter().enumerate() {
let param_name = format!("p{}", idx + 1);
bind_value(&mut spanner_stmt, &param_name, value)?;
}
}

Ok(spanner_stmt)
}

/// Bind a sea_orm::Value to a Spanner Statement parameter.
///
/// This is the shared implementation used by SpannerExecutor, SpannerReadWriteTransaction,
/// and SpannerReadOnlyTransaction. The proxy path (SpannerProxy) has its own implementation
/// that uses custom wrapper types for Spanner-specific type handling.
pub(crate) fn bind_value(
stmt: &mut SpannerStatement,
param_name: &str,
value: &sea_orm::Value,
) -> Result<(), DbErr> {
use sea_orm::Value;

match value {
Value::Bool(Some(v)) => stmt.add_param(param_name, v),
Value::Bool(None) => stmt.add_param(param_name, &Option::<bool>::None),
Value::TinyInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::TinyInt(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::SmallInt(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::SmallInt(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::Int(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::Int(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::BigInt(Some(v)) => stmt.add_param(param_name, v),
Value::BigInt(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::TinyUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::TinyUnsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::SmallUnsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::SmallUnsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::Unsigned(Some(v)) => stmt.add_param(param_name, &(*v as i64)),
Value::Unsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::BigUnsigned(Some(v)) => {
let i = i64::try_from(*v).map_err(|_| SpannerDbErr::TypeConversion {
column: param_name.to_string(),
expected: "i64".to_string(),
got: format!("u64 value {} overflows i64", v),
})?;
stmt.add_param(param_name, &i);
}
Value::BigUnsigned(None) => stmt.add_param(param_name, &Option::<i64>::None),
Value::Float(Some(v)) => stmt.add_param(param_name, &(*v as f64)),
Value::Float(None) => stmt.add_param(param_name, &Option::<f64>::None),
Value::Double(Some(v)) => stmt.add_param(param_name, v),
Value::Double(None) => stmt.add_param(param_name, &Option::<f64>::None),
Value::String(Some(v)) => {
let s: &str = v.as_ref();
stmt.add_param(param_name, &s)
}
Value::String(None) => stmt.add_param(param_name, &Option::<String>::None),
Value::Char(Some(v)) => stmt.add_param(param_name, &v.to_string()),
Value::Char(None) => stmt.add_param(param_name, &Option::<String>::None),
Value::Bytes(Some(v)) => {
let b: &[u8] = v.as_ref();
stmt.add_param(param_name, &b)
}
Value::Bytes(None) => stmt.add_param(param_name, &Option::<Vec<u8>>::None),

#[cfg(feature = "with-chrono")]
Value::ChronoDate(Some(v)) => stmt.add_param(param_name, &v.format("%Y-%m-%d").to_string()),
#[cfg(feature = "with-chrono")]
Value::ChronoDate(None) => stmt.add_param(param_name, &Option::<String>::None),
#[cfg(feature = "with-chrono")]
Value::ChronoTime(Some(v)) => {
stmt.add_param(param_name, &v.format("%H:%M:%S%.f").to_string())
}
#[cfg(feature = "with-chrono")]
Value::ChronoTime(None) => stmt.add_param(param_name, &Option::<String>::None),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTime(Some(v)) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerTimestamp::new(v.and_utc()),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTime(None) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerOptionalTimestamp::none(),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeUtc(Some(v)) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerTimestamp::new(v.to_utc()),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeUtc(None) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerOptionalTimestamp::none(),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeLocal(Some(v)) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerTimestamp::new(v.to_utc()),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeLocal(None) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerOptionalTimestamp::none(),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeWithTimeZone(Some(v)) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerTimestamp::new(v.to_utc()),
),
#[cfg(feature = "with-chrono")]
Value::ChronoDateTimeWithTimeZone(None) => stmt.add_param(
param_name,
&crate::chrono_support::SpannerOptionalTimestamp::none(),
),

#[cfg(feature = "with-uuid")]
Value::Uuid(Some(v)) => stmt.add_param(param_name, &v.to_string()),
#[cfg(feature = "with-uuid")]
Value::Uuid(None) => stmt.add_param(param_name, &Option::<String>::None),

#[cfg(feature = "with-json")]
Value::Json(Some(v)) => stmt.add_param(
param_name,
&crate::json_support::SpannerOptionalJson::some(v.as_ref().clone()),
),
#[cfg(feature = "with-json")]
Value::Json(None) => stmt.add_param(
param_name,
&crate::json_support::SpannerOptionalJson::none(),
),

#[cfg(feature = "with-rust_decimal")]
Value::Decimal(Some(v)) => stmt.add_param(param_name, &v.to_string()),
#[cfg(feature = "with-rust_decimal")]
Value::Decimal(None) => stmt.add_param(param_name, &Option::<String>::None),

#[allow(unreachable_patterns)]
_ => {
return Err(SpannerDbErr::TypeConversion {
column: param_name.to_string(),
expected: "supported type".to_string(),
got: format!("{:?}", value),
}
.into());
}
}

Ok(())
}
43 changes: 22 additions & 21 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
use crate::error::SpannerDbErr;
use crate::executor::SpannerExecutor;
use crate::query_result::SpannerQueryResult;
use gcloud_gax::cancel::CancellationToken;
use gcloud_gax::grpc::Status;
use gcloud_gax::retry::TryAs;
use gcloud_spanner::client::Client;
use gcloud_spanner::session::SessionError;
use gcloud_spanner::transaction_rw::ReadWriteTransaction;
use gcloud_spanner::transaction_ro::ReadOnlyTransaction;
use sea_orm::{DbBackend, DbErr, Statement};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use {
crate::{error::SpannerDbErr, executor::SpannerExecutor, query_result::SpannerQueryResult},
gcloud_gax::{grpc::Status, retry::TryAs},
gcloud_spanner::{
client::Client, session::SessionError, transaction_ro::ReadOnlyTransaction,
transaction_rw::ReadWriteTransaction,
},
sea_orm::{DbBackend, DbErr, Statement},
std::{future::Future, pin::Pin, sync::Arc},
};

#[derive(Clone)]
pub struct SpannerConnection {
Expand All @@ -29,8 +25,12 @@ impl SpannerConnection {
&self.client
}

pub async fn close(&self) {
self.client.close().await;
pub async fn close(self) {
Arc::try_unwrap(self.client)
.ok()
.expect("Cannot close: other references to Client exist")
.close()
.await;
Comment on lines +28 to +33
Copy link

Copilot AI Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SpannerConnection::close(self) will panic at runtime if the Arc<Client> has been cloned anywhere else (Arc::try_unwrap(...).expect(...)). Since SpannerConnection is Clone and now publicly exported, this is easy to trigger unintentionally and would crash consumer applications. Consider returning a Result<(), DbErr> (or logging and returning) when try_unwrap fails instead of panicking, or provide a separate infallible close_if_unique(self) helper if you want to preserve strict behavior.

Suggested change
pub async fn close(self) {
Arc::try_unwrap(self.client)
.ok()
.expect("Cannot close: other references to Client exist")
.close()
.await;
pub async fn close(self) -> Result<(), DbErr> {
match Arc::try_unwrap(self.client) {
Ok(client) => {
client.close().await;
Ok(())
}
Err(_) => Err(DbErr::Custom(
"Cannot close: other references to Client exist".to_owned(),
)),
}

Copilot uses AI. Check for mistakes.
}

pub fn get_database_backend(&self) -> DbBackend {
Expand Down Expand Up @@ -63,24 +63,25 @@ impl SpannerConnection {
E: TryAs<Status> + From<SessionError> + From<Status> + ToString,
F: for<'tx> Fn(
&'tx mut ReadWriteTransaction,
Option<CancellationToken>,
) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'tx>>,
T: Send,
{
let result = self
.client
.read_write_transaction(|tx, cancel| callback(tx, cancel))
.read_write_transaction(|tx| callback(tx))
.await
.map_err(|e| DbErr::Custom(e.to_string()))?;
.map_err(|e: E| DbErr::Custom(e.to_string()))?;

Ok(result.1)
}

pub async fn read_only_transaction<F, T>(&self, callback: F) -> Result<T, DbErr>
where
F: for<'tx> FnOnce(
&'tx mut ReadOnlyTransaction,
) -> Pin<Box<dyn Future<Output = Result<T, DbErr>> + Send + 'tx>> + Send,
&'tx mut ReadOnlyTransaction,
)
-> Pin<Box<dyn Future<Output = Result<T, DbErr>> + Send + 'tx>>
+ Send,
T: Send,
{
let mut tx = self
Expand Down
Loading
Loading