diff --git a/crates/base/src/vector/vecf32.rs b/crates/base/src/vector/vecf32.rs
index 68f77dff2..70dc55e97 100644
--- a/crates/base/src/vector/vecf32.rs
+++ b/crates/base/src/vector/vecf32.rs
@@ -82,6 +82,10 @@ impl<'a> Vecf32Borrowed<'a> {
     pub fn slice(&self) -> &[F32] {
         self.0
     }
+    /// Returns the L2 norm, the square root of the vector's dot product with itself.
+    pub fn l2_norm(&self) -> F32 {
+        dot(self.slice(), self.slice()).sqrt()
+    }
 }
 
 impl<'a> VectorBorrowed for Vecf32Borrowed<'a> {
diff --git a/src/datatype/functions_vecf32.rs b/src/datatype/functions_vecf32.rs
new file mode 100644
index 000000000..3f4c84e5f
--- /dev/null
+++ b/src/datatype/functions_vecf32.rs
@@ -0,0 +1,322 @@
+#![allow(unused_lifetimes)]
+#![allow(clippy::extra_unused_lifetimes)]
+use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output};
+use crate::error::*;
+use base::scalar::*;
+use base::vector::*;
+use pgrx::pg_sys::Datum;
+use pgrx::pg_sys::Oid;
+use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
+use pgrx::pgrx_sql_entity_graph::metadata::Returns;
+use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
+use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
+use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
+use pgrx::{FromDatum, IntoDatum};
+use std::alloc::Layout;
+use std::ffi::{CStr, CString};
+use std::ops::{Deref, DerefMut};
+use std::ptr::NonNull;
+
+/// Header of the `avg(vector)` transition state.
+///
+/// The running sums (`dims` elements) are stored directly after the header,
+/// starting at the zero-sized `phantom` field.
+#[repr(C, align(8))]
+pub struct AccumulateStateHeader {
+    varlena: u32,
+    dims: u16,
+    count: u64,
+    phantom: [F32; 0],
+}
+
+impl AccumulateStateHeader {
+    /// Value stored in the `varlena` field for a state of `size` bytes (`size << 2`).
+    fn varlena(size: usize) -> u32 {
+        (size << 2) as u32
+    }
+    /// Layout of a header followed by `len` elements; panics if `len` exceeds `u16`.
+    fn layout(len: usize) -> Layout {
+        u16::try_from(len).expect("Vector is too large.");
+        let layout_alpha = Layout::new::<AccumulateStateHeader>();
+        let layout_beta = Layout::array::<F32>(len).unwrap();
+        let layout = layout_alpha.extend(layout_beta).unwrap().0;
+        layout.pad_to_align()
+    }
+    /// Number of elements in the running sum.
+    pub fn dims(&self) -> usize {
+        self.dims as usize
+    }
+    /// Number of vectors accumulated so far.
+    pub fn count(&self) -> u64 {
+        self.count
+    }
+    /// The running per-dimension sums.
+    pub fn slice(&self) -> &[F32] {
+        // SAFETY: relies on `dims` elements being allocated right after the header,
+        // which is how `AccumulateState::new_with_slice` builds a state.
+        unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize) }
+    }
+    /// Mutable view of the running per-dimension sums.
+    pub fn slice_mut(&mut self) -> &mut [F32] {
+        // SAFETY: same layout requirement as `slice`.
+        unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.dims as usize) }
+    }
+}
+
+/// A transition state that is either owned (freed with `pfree` on drop) or
+/// borrowed from the datum it was read from.
+pub enum AccumulateState<'a> {
+    Owned(NonNull<AccumulateStateHeader>),
+    Borrowed(&'a mut AccumulateStateHeader),
+}
+
+impl<'a> AccumulateState<'a> {
+    /// Wraps a state pointer, detoasting it first.
+    ///
+    /// # Safety
+    ///
+    /// `p` must point to a valid (possibly toasted) state datum.
+    unsafe fn new(p: NonNull<AccumulateStateHeader>) -> Self {
+        // datum maybe toasted, try to detoast it
+        let q = unsafe {
+            NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.as_ptr().cast()).cast()).unwrap()
+        };
+        // A different pointer is treated as a fresh copy that this handle owns.
+        if p != q {
+            AccumulateState::Owned(q)
+        } else {
+            unsafe { AccumulateState::Borrowed(&mut *p.as_ptr()) }
+        }
+    }
+
+    /// Allocates a state with `palloc` holding `count` and a copy of `slice`.
+    pub fn new_with_slice(count: u64, slice: &[F32]) -> Self {
+        let dims = slice.len();
+        let layout = AccumulateStateHeader::layout(dims);
+        // SAFETY: the allocation spans `layout.size()` bytes (header plus `dims`
+        // elements); fields are written through raw pointers without reading them.
+        unsafe {
+            let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut AccumulateStateHeader;
+            std::ptr::addr_of_mut!((*ptr).varlena)
+                .write(AccumulateStateHeader::varlena(layout.size()));
+            std::ptr::addr_of_mut!((*ptr).dims).write(dims as u16);
+            std::ptr::addr_of_mut!((*ptr).count).write(count);
+            if dims > 0 {
+                std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), dims);
+            }
+            AccumulateState::Owned(NonNull::new(ptr).unwrap())
+        }
+    }
+
+    /// Gives up the handle and returns the raw state pointer without freeing it.
+    pub fn into_raw(self) -> *mut AccumulateStateHeader {
+        // `ManuallyDrop` skips `Drop`, so an owned state is not `pfree`d here.
+        let mut this = std::mem::ManuallyDrop::new(self);
+        match &mut *this {
+            AccumulateState::Owned(p) => p.as_ptr(),
+            // Derive the pointer from the mutable borrow, not from a shared reborrow.
+            AccumulateState::Borrowed(p) => &mut **p as *mut AccumulateStateHeader,
+        }
+    }
+}
+
+impl Deref for AccumulateState<'_> {
+    type Target = AccumulateStateHeader;
+
+    fn deref(&self) -> &Self::Target {
+        match self {
+            // SAFETY: relies on an owned pointer referring to a live state.
+            AccumulateState::Owned(p) => unsafe { p.as_ref() },
+            AccumulateState::Borrowed(p) => p,
+        }
+    }
+}
+
+impl DerefMut for AccumulateState<'_> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        match self {
+            // SAFETY: relies on an owned pointer referring to a live state.
+            AccumulateState::Owned(p) => unsafe { p.as_mut() },
+            AccumulateState::Borrowed(p) => p,
+        }
+    }
+}
+
+impl Drop for AccumulateState<'_> {
+    fn drop(&mut self) {
+        match self {
+            // Only owned states are freed; borrowed ones are left untouched.
+            AccumulateState::Owned(p) => unsafe {
+                pgrx::pg_sys::pfree(p.as_ptr().cast());
+            },
+            AccumulateState::Borrowed(_) => {}
+        }
+    }
+}
+
+impl FromDatum for AccumulateState<'_> {
+    unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typmod: Oid) -> Option<Self> {
+        if is_null {
+            None
+        } else {
+            let ptr = NonNull::new(datum.cast_mut_ptr::<AccumulateStateHeader>()).unwrap();
+            unsafe { Some(AccumulateState::new(ptr)) }
+        }
+    }
+}
+
+impl IntoDatum for AccumulateState<'_> {
+    fn into_datum(self) -> Option<Datum> {
+        Some(Datum::from(self.into_raw() as *mut ()))
+    }
+
+    fn type_oid() -> Oid {
+        let namespace = pgrx::pg_catalog::PgNamespace::search_namespacename(c"vectors").unwrap();
+        let namespace = namespace.get().expect("pgvecto.rs is not installed.");
+        // The name must equal the SQL type name exactly; a trailing space would not match.
+        let t = pgrx::pg_catalog::PgType::search_typenamensp(
+            c"vector_accumulate_state",
+            namespace.oid(),
+        )
+        .unwrap();
+        let t = t.get().expect("pg_catalog is broken.");
+        t.oid()
+    }
+}
+
+unsafe impl SqlTranslatable for AccumulateState<'_> {
+    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
+        Ok(SqlMapping::As(String::from("vector_accumulate_state ")))
+    }
+    fn return_sql() -> Result<Returns, ReturnsError> {
+        Ok(Returns::One(SqlMapping::As(String::from(
+            "vector_accumulate_state ",
+        ))))
+    }
+}
+
+/// Parses the text form `count, [v1, v2, ...]` into a count and the summed vector.
+fn parse_accumulate_state(input: &[u8]) -> Result<(u64, Vec<F32>), String> {
+    use crate::utils::parse::parse_vector;
+    let hint = "Invalid input format for accumulatestate, using \'bigint, array \' like \'1, [1]\'";
+    let (count, slice) = input.split_once(|&c| c == b',').ok_or(hint)?;
+    let count = std::str::from_utf8(count)
+        .map_err(|e| e.to_string() + "\n" + hint)?
+        .parse::<u64>()
+        .map_err(|e| e.to_string() + "\n" + hint)?;
+    let v = parse_vector(slice, 0, |s| s.parse().ok());
+    match v {
+        Err(e) => Err(e.to_string() + "\n" + hint),
+        Ok(vector) => Ok((count, vector)),
+    }
+}
+
+/// Text input function of `vector_accumulate_state`.
+#[pgrx::pg_extern(immutable, strict, parallel_safe)]
+fn _vectors_accumulate_state_in(input: &CStr, _oid: Oid, _typmod: i32) -> AccumulateState<'_> {
+    // parse one bigint and a vector of f32, split with a comma
+    let res = parse_accumulate_state(input.to_bytes());
+    match res {
+        Err(e) => bad_literal(&e),
+        Ok((count, vector)) => AccumulateState::new_with_slice(count, &vector),
+    }
+}
+
+/// Text output function of `vector_accumulate_state`, formatted as `count, [v1, v2, ...]`.
+#[pgrx::pg_extern(immutable, strict, parallel_safe)]
+fn _vectors_accumulate_state_out(state: AccumulateState<'_>) -> CString {
+    use std::fmt::Write;
+    let mut buffer = String::new();
+    // Writing into a `String` cannot fail, so the results are discarded.
+    let _ = write!(buffer, "{}, [", state.count());
+    for (i, x) in state.slice().iter().enumerate() {
+        if i > 0 {
+            buffer.push_str(", ");
+        }
+        let _ = write!(buffer, "{}", x);
+    }
+    buffer.push(']');
+    CString::new(buffer).unwrap()
+}
+
+/// accumulate intermediate state for vector average
+#[pgrx::pg_extern(immutable, strict, parallel_safe)]
+fn _vectors_vector_accum<'a>(
+    mut state: AccumulateState<'a>,
+    value: Vecf32Input<'_>,
+) -> AccumulateState<'a> {
+    let count = state.count();
+    match count {
+        // if the state is empty, copy the input vector
+        0 => AccumulateState::new_with_slice(1, value.iter().as_slice()),
+        _ => {
+            let dims = state.dims();
+            let value_dims = value.dims();
+            check_matched_dims(dims, value_dims);
+            let sum = state.slice_mut();
+            // accumulate the input vector
+            for (x, y) in sum.iter_mut().zip(value.iter()) {
+                *x += *y;
+            }
+            // increase the count
+            state.count += 1;
+            state
+        }
+    }
+}
+
+/// combine two intermediate states for vector average
+#[pgrx::pg_extern(immutable, strict, parallel_safe)]
+fn _vectors_vector_combine<'a>(
+    mut state1: AccumulateState<'a>,
+    state2: AccumulateState<'a>,
+) -> AccumulateState<'a> {
+    let count1 = state1.count();
+    let count2 = state2.count();
+    if count1 == 0 {
+        state2
+    } else if count2 == 0 {
+        state1
+    } else {
+        let dims1 = state1.dims();
+        let dims2 = state2.dims();
+        check_matched_dims(dims1, dims2);
+        state1.count += count2;
+        let sum1 = state1.slice_mut();
+        let sum2 = state2.slice();
+        for (x, y) in sum1.iter_mut().zip(sum2.iter()) {
+            *x += *y;
+        }
+        state1
+    }
+}
+
+/// finalize the intermediate state for vector average
+#[pgrx::pg_extern(immutable, strict, parallel_safe)]
+fn _vectors_vector_final(state: AccumulateState<'_>) -> Option<Vecf32Output> {
+    let count = state.count();
+    if count == 0 {
+        // return NULL if all inputs are NULL
+        return None;
+    }
+    let sum = state
+        .slice()
+        .iter()
+        .map(|x| *x / F32(count as f32))
+        .collect::<Vec<_>>();
+    Some(Vecf32Output::new(
+        Vecf32Borrowed::new_checked(&sum).unwrap(),
+    ))
+}
+
+/// Get the dimensions of a vector.
+#[pgrx::pg_extern(immutable, strict, parallel_safe)]
+fn _vectors_vector_dims(vector: Vecf32Input<'_>) -> i32 {
+    vector.dims() as i32
+}
+
+/// Calculate the l2 norm of a vector.
+#[pgrx::pg_extern(immutable, strict, parallel_safe)]
+fn _vectors_vector_norm(vector: Vecf32Input<'_>) -> f32 {
+    vector.for_borrow().l2_norm().to_f32()
+}
diff --git a/src/datatype/mod.rs b/src/datatype/mod.rs
index abfa165ff..9ffcbadaa 100644
--- a/src/datatype/mod.rs
+++ b/src/datatype/mod.rs
@@ -7,6 +7,7 @@ pub mod binary_veci8;
 pub mod casts;
 pub mod functions_bvecf32;
 pub mod functions_svecf32;
+pub mod functions_vecf32;
 pub mod functions_veci8;
 pub mod memory_bvecf32;
 pub mod memory_svecf32;
diff --git a/src/lib.rs b/src/lib.rs
index a38ab08da..0bf11b7a8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -2,6 +2,7 @@
 //!
 //! Provides an easy-to-use extension for vector similarity search.
 #![feature(alloc_error_hook)]
+#![feature(slice_split_once)]
 #![allow(clippy::needless_range_loop)]
 #![allow(clippy::single_match)]
 #![allow(clippy::too_many_arguments)]
diff --git a/src/sql/bootstrap.sql b/src/sql/bootstrap.sql
index 7f9620bec..9a861ac9b 100644
--- a/src/sql/bootstrap.sql
+++ b/src/sql/bootstrap.sql
@@ -8,5 +8,6 @@ CREATE TYPE svector;
 CREATE TYPE bvector;
 CREATE TYPE veci8;
 CREATE TYPE vector_index_stat;
+CREATE TYPE vector_accumulate_state;
 
 -- bootstrap end
diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql
index 82f65ac94..5d94582a9 100644
--- a/src/sql/finalize.sql
+++ b/src/sql/finalize.sql
@@ -78,6 +78,14 @@ CREATE TYPE vector_index_stat AS (
     idx_options TEXT
 );
 
+CREATE TYPE vector_accumulate_state (
+    INPUT = _vectors_accumulate_state_in,
+    OUTPUT = _vectors_accumulate_state_out,
+    STORAGE = EXTERNAL,
+    INTERNALLENGTH = VARIABLE,
+    ALIGNMENT = double
+);
+
 -- List of operators
 
 CREATE OPERATOR + (
@@ -593,6 +601,30 @@ $$;
 CREATE FUNCTION alter_vector_index("index" OID, "key" TEXT, "value" TEXT) RETURNS void
 STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_alter_vector_index_wrapper';
 
+CREATE FUNCTION vector_dims("v" vector) RETURNS INT
+IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_vector_dims_wrapper';
+
+CREATE FUNCTION vector_norm("v" vector) RETURNS real
+IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_vector_norm_wrapper';
+
+-- List of aggregates
+
+CREATE AGGREGATE avg(vector) (
+    SFUNC = _vectors_vector_accum,
+    STYPE = vector_accumulate_state,
+    COMBINEFUNC = _vectors_vector_combine,
+    FINALFUNC = _vectors_vector_final,
+    INITCOND = '0, []',
+    PARALLEL = SAFE
+);
+
+CREATE AGGREGATE sum(vector) (
+    SFUNC = _vectors_vecf32_operator_add,
+    STYPE = vector,
+    COMBINEFUNC = _vectors_vecf32_operator_add,
+    PARALLEL = SAFE
+);
+
 -- List of casts
 
 CREATE CAST (real[] AS vector)
diff --git a/tests/sqllogictest/vector.slt b/tests/sqllogictest/vector.slt
new file mode 100644
index 000000000..1b5d77f8d
--- /dev/null
+++ b/tests/sqllogictest/vector.slt
@@ -0,0 +1,104 @@
+statement ok
+SET search_path TO pg_temp, vectors;
+
+statement ok
+CREATE TABLE t (id bigserial, val vector);
+
+statement ok
+INSERT INTO t (val)
+VALUES ('[1,2,3]'), ('[4,5,6]');
+
+query I
+SELECT vector_dims(val) FROM t;
+----
+3
+3
+
+query R
+SELECT round(vector_norm(val)::numeric, 5) FROM t;
+----
+3.74166
+8.77496
+
+query ?
+SELECT avg(val) FROM t;
+----
+[2.5, 3.5, 4.5]
+
+query ?
+SELECT sum(val) FROM t;
+----
+[5, 7, 9]
+
+query R
+SELECT vector_norm('[3,4]');
+----
+5
+
+query I
+SELECT vector_dims(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v;
+----
+2
+1
+
+query ?
+SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
+----
+[2, 3.5, 5]
+
+query ?
+SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
+----
+[2, 3.5, 5]
+
+query ?
+SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector,NULL]) v;
+----
+[1, 2, 3]
+
+query ?
+SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v;
+----
+NULL
+
+query ?
+SELECT avg(v) FROM unnest(ARRAY[NULL]::vector[]) v;
+----
+NULL
+
+query ?
+SELECT avg(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v;
+----
+[inf]
+
+statement error differs in dimensions
+SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v;
+
+query ?
+SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v;
+----
+[4, 7, 10]
+
+query ?
+SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v;
+----
+[4, 7, 10]
+
+query ?
+SELECT sum(v) FROM unnest(ARRAY[]::vector[]) v;
+----
+NULL
+
+query ?
+SELECT sum(v) FROM unnest(ARRAY[NULL]::vector[]) v;
+----
+NULL
+
+statement error differs in dimensions
+SELECT sum(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v;
+
+# should this return an error ?
+query ?
+SELECT sum(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v;
+----
+[inf]