Skip to content
2 changes: 1 addition & 1 deletion pgx-examples/schemas/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ mod tests {

#[pg_test]
fn test_my_some_schema_type() {
Spi::connect(|c| {
Spi::connect(|mut c| {
// "MySomeSchemaType" is in 'some_schema', so it needs to be discoverable
c.update("SET search_path TO some_schema,public", None, None);
assert_eq!(
Expand Down
4 changes: 2 additions & 2 deletions pgx-tests/src/tests/bgworker_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub extern "C" fn bgworker(arg: pg_sys::Datum) {
if arg > 0 {
BackgroundWorker::transaction(|| {
Spi::run("CREATE TABLE tests.bgworker_test (v INTEGER);");
Spi::connect(|client| {
Spi::connect(|mut client| {
client.update(
"INSERT INTO tests.bgworker_test VALUES ($1);",
None,
Expand Down Expand Up @@ -71,7 +71,7 @@ pub extern "C" fn bgworker_return_value(arg: pg_sys::Datum) {
};
while BackgroundWorker::wait_latch(Some(Duration::from_millis(100))) {}
BackgroundWorker::transaction(|| {
Spi::connect(|c| {
Spi::connect(|mut c| {
c.update(
"INSERT INTO tests.bgworker_test_return VALUES ($1)",
None,
Expand Down
12 changes: 6 additions & 6 deletions pgx-tests/src/tests/spi_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ mod tests {

#[pg_test]
fn test_inserting_null() -> Result<(), pgx::spi::Error> {
Spi::connect(|client| {
Spi::connect(|mut client| {
client.update("CREATE TABLE tests.null_test (id uuid)", None, None);
});
assert_eq!(
Expand All @@ -202,7 +202,7 @@ mod tests {

#[pg_test]
fn test_cursor() {
Spi::connect(|client| {
Spi::connect(|mut client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, None);
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand All @@ -224,7 +224,7 @@ mod tests {

#[pg_test]
fn test_cursor_prepared_statement() -> Result<(), pgx::spi::Error> {
Spi::connect(|client| {
Spi::connect(|mut client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, None);
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand All @@ -248,7 +248,7 @@ mod tests {

#[pg_test]
fn test_cursor_by_name() -> Result<(), pgx::spi::Error> {
let cursor_name = Spi::connect(|client| {
let cursor_name = Spi::connect(|mut client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, None);
client.update(
"INSERT INTO tests.cursor_table (id) \
Expand Down Expand Up @@ -308,7 +308,7 @@ mod tests {
assert_eq!(res.column_name(2).unwrap(), "b");
});

Spi::connect(|client| {
Spi::connect(|mut client| {
let res = client.update("SET TIME ZONE 'PST8PDT'", None, None);

assert_eq!(0, res.columns());
Expand All @@ -324,7 +324,7 @@ mod tests {
#[pg_test]
fn test_spi_non_mut() -> Result<(), pgx::spi::Error> {
// Ensures update and cursor APIs do not need mutable reference to SpiClient
Spi::connect(|client| {
Spi::connect(|mut client| {
client.update("SELECT 1", None, None);
let cursor = client.open_cursor("SELECT 1", None)?.detach_into_name();
client.find_cursor(&cursor).map(|_| ())
Expand Down
4 changes: 2 additions & 2 deletions pgx-tests/src/tests/srf_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ mod tests {

#[pg_test]
fn test_srf_setof_datum_detoasting_with_borrow() {
let cnt = Spi::connect(|client| {
let cnt = Spi::connect(|mut client| {
// build up a table with one large column that Postgres will be forced to TOAST
client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000000)) x;", None, None);

Expand All @@ -195,7 +195,7 @@ mod tests {

#[pg_test]
fn test_srf_table_datum_detoasting_with_borrow() {
let cnt = Spi::connect(|client| {
let cnt = Spi::connect(|mut client| {
// build up a table with one large column that Postgres will be forced to TOAST
client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000000)) x;", None, None);

Expand Down
2 changes: 1 addition & 1 deletion pgx-tests/src/tests/struct_type_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ mod tests {

#[pg_test]
fn test_complex_storage_and_retrieval() -> Result<(), pgx::spi::Error> {
let complex = Spi::connect(|client| {
let complex = Spi::connect(|mut client| {
client.update(
"CREATE TABLE complex_test AS SELECT s as id, (s || '.0, 2.0' || s)::complex as value FROM generate_series(1, 1000) s;\
SELECT value FROM complex_test ORDER BY id;", None, None).first().get_one::<PgBox<Complex>>()
Expand Down
88 changes: 56 additions & 32 deletions pgx/src/spi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,22 @@ pub enum Error {
pub struct Spi;

// TODO: should `'conn` be invariant?
pub struct SpiClient<'conn>(PhantomData<&'conn SpiConnection>);
pub struct SpiClient<'conn> {
phantom: PhantomData<&'conn SpiConnection>,
// This field indicates whether queries be readonly. Unless any `update` has been used
// `readonly` will be `true`.
// Postgres docs say:
//
// It is generally unwise to mix read-only and read-write commands within a single function
// using SPI; that could result in very confusing behavior, since the read-only queries
// would not see the results of any database updates done by the read-write queries.
//
// TODO: Alternatively, we can detect if the command counter (or something?) has incremented and if yes
// then we set read_only=false, else we can set it to true.
// However, we would still need to remember the previous value, which will be larger than the boolean.
// So, unless somebody will send commands to Postgres bypassing this SPI API, this flag seems sufficient.
readonly: bool,
}

/// a struct to manage our SPI connection lifetime
struct SpiConnection(PhantomData<*mut ()>);
Expand All @@ -156,7 +171,7 @@ impl Drop for SpiConnection {
impl SpiConnection {
/// Return a client that with a lifetime scoped to this connection.
fn client(&self) -> SpiClient<'_> {
SpiClient(PhantomData)
SpiClient { phantom: PhantomData, readonly: true }
}
}

Expand Down Expand Up @@ -263,7 +278,7 @@ impl<'a> Query for &'a str {

fn open_cursor<'c: 'cc, 'cc>(
self,
_client: &'cc SpiClient<'c>,
client: &'cc SpiClient<'c>,
args: Self::Arguments,
) -> Result<SpiCursor<'c>, Error> {
let src = std::ffi::CString::new(self).expect("query contained a null byte");
Expand All @@ -283,7 +298,7 @@ impl<'a> Query for &'a str {
argtypes.as_mut_ptr(),
datums.as_mut_ptr(),
nulls.as_ptr(),
false,
client.readonly,
0,
)
})
Expand Down Expand Up @@ -316,13 +331,13 @@ pub struct SpiHeapTupleData {

impl Spi {
pub fn get_one<A: FromDatum + IntoDatum>(query: &str) -> Result<Option<A>, Error> {
Spi::connect(|client| client.select(query, Some(1), None).first().get_one())
Spi::connect(|mut client| client.update(query, Some(1), None).first().get_one())
}

pub fn get_two<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
query: &str,
) -> Result<(Option<A>, Option<B>), Error> {
Spi::connect(|client| client.select(query, Some(1), None).first().get_two::<A, B>())
Spi::connect(|mut client| client.update(query, Some(1), None).first().get_two::<A, B>())
}

pub fn get_three<
Expand All @@ -332,21 +347,25 @@ impl Spi {
>(
query: &str,
) -> Result<(Option<A>, Option<B>, Option<C>), Error> {
Spi::connect(|client| client.select(query, Some(1), None).first().get_three::<A, B, C>())
Spi::connect(|mut client| {
client.update(query, Some(1), None).first().get_three::<A, B, C>()
})
}

pub fn get_one_with_args<A: FromDatum + IntoDatum>(
query: &str,
args: Vec<(PgOid, Option<pg_sys::Datum>)>,
) -> Result<Option<A>, Error> {
Spi::connect(|client| client.select(query, Some(1), Some(args)).first().get_one())
Spi::connect(|mut client| client.update(query, Some(1), Some(args)).first().get_one())
}

pub fn get_two_with_args<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
query: &str,
args: Vec<(PgOid, Option<pg_sys::Datum>)>,
) -> Result<(Option<A>, Option<B>), Error> {
Spi::connect(|client| client.select(query, Some(1), Some(args)).first().get_two::<A, B>())
Spi::connect(|mut client| {
client.update(query, Some(1), Some(args)).first().get_two::<A, B>()
})
}

pub fn get_three_with_args<
Expand All @@ -357,8 +376,8 @@ impl Spi {
query: &str,
args: Vec<(PgOid, Option<pg_sys::Datum>)>,
) -> Result<(Option<A>, Option<B>, Option<C>), Error> {
Spi::connect(|client| {
client.select(query, Some(1), Some(args)).first().get_three::<A, B, C>()
Spi::connect(|mut client| {
client.update(query, Some(1), Some(args)).first().get_three::<A, B, C>()
})
}

Expand All @@ -377,7 +396,7 @@ impl Spi {
///
/// The statement runs in read/write mode
pub fn run_with_args(query: &str, args: Option<Vec<(PgOid, Option<pg_sys::Datum>)>>) {
Spi::connect(|client| {
Spi::connect(|mut client| {
client.update(query, None, args);
})
}
Expand All @@ -392,7 +411,7 @@ impl Spi {
query: &str,
args: Option<Vec<(PgOid, Option<pg_sys::Datum>)>>,
) -> Result<Json, Error> {
Spi::connect(|client| {
Spi::connect(|mut client| {
let table =
client.update(&format!("EXPLAIN (format json) {}", query), None, args).first();
Ok(table.get_one::<Json>()?.unwrap())
Expand Down Expand Up @@ -452,23 +471,18 @@ impl Spi {
impl<'a> SpiClient<'a> {
/// perform a SELECT statement
pub fn select<Q: Query>(&self, query: Q, limit: Option<i64>, args: Q::Arguments) -> Q::Result {
// Postgres docs say:
//
// It is generally unwise to mix read-only and read-write commands within a single function
// using SPI; that could result in very confusing behavior, since the read-only queries
// would not see the results of any database updates done by the read-write queries.
//
// As such, we don't actually set read-only to true here

// TODO: can we detect if the command counter (or something?) has incremented and if yes
// then we set read_only=false, else we can set it to true?
// Is this even a good idea?
self.execute(query, false, limit, args)
self.execute(query, self.readonly, limit, args)
}

/// perform any query (including utility statements) that modify the database in some way
pub fn update<Q: Query>(&self, query: Q, limit: Option<i64>, args: Q::Arguments) -> Q::Result {
self.execute(query, false, limit, args)
pub fn update<Q: Query>(
&mut self,
query: Q,
limit: Option<i64>,
args: Q::Arguments,
) -> Q::Result {
self.readonly = false;
self.execute(query, self.readonly, limit, args)
}

fn execute<Q: Query>(
Expand Down Expand Up @@ -502,11 +516,21 @@ impl<'a> SpiClient<'a> {
/// Rows may be then fetched using [`SpiCursor::fetch`].
///
/// See [`SpiCursor`] docs for usage details.
pub fn open_cursor<Q: Query>(
&self,
pub fn open_cursor<Q: Query>(&self, query: Q, args: Q::Arguments) -> Result<SpiCursor, Error> {
query.open_cursor(&self, args)
}

/// Set up a cursor that will execute the specified update (mutating) query
///
/// Rows may be then fetched using [`SpiCursor::fetch`].
///
/// See [`SpiCursor`] docs for usage details.
pub fn open_cursor_mut<Q: Query>(
&mut self,
query: Q,
args: Q::Arguments,
) -> Result<SpiCursor<'a>, Error> {
) -> Result<SpiCursor, Error> {
self.readonly = false;
query.open_cursor(&self, args)
}

Expand Down Expand Up @@ -761,7 +785,7 @@ impl<'a: 'b, 'b> Query for &'b PreparedStatement<'a> {

fn open_cursor<'c: 'cc, 'cc>(
self,
_client: &'cc SpiClient<'c>,
client: &'cc SpiClient<'c>,
args: Self::Arguments,
) -> Result<SpiCursor<'c>, Error> {
let args = args.unwrap_or_default();
Expand All @@ -775,7 +799,7 @@ impl<'a: 'b, 'b> Query for &'b PreparedStatement<'a> {
self.plan,
datums.as_mut_ptr(),
nulls.as_ptr(),
false,
client.readonly,
)
})
.ok_or(Error::PortalIsNull)?;
Expand Down