Skip to content

Commit

Permalink
feat: support aggregate function for vector. (#463)
Browse files Browse the repository at this point in the history
* implement vector avg.

Signed-off-by: my-vegetable-has-exploded <[email protected]>

* Implement sum、 vector_dims()、 vector_norm()

Signed-off-by: my-vegetable-has-exploded <[email protected]>

* update AccumulateState

Signed-off-by: my-vegetable-has-exploded <[email protected]>

* fix norm.

Signed-off-by: my-vegetable-has-exploded <[email protected]>

* add l2_norm for VectorBorrowed.

Signed-off-by: my-vegetable-has-exploded <[email protected]>

---------

Signed-off-by: my-vegetable-has-exploded <[email protected]>
  • Loading branch information
my-vegetable-has-exploded committed May 6, 2024
1 parent eaa3093 commit d81473c
Show file tree
Hide file tree
Showing 7 changed files with 434 additions and 0 deletions.
3 changes: 3 additions & 0 deletions crates/base/src/vector/vecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ impl<'a> Vecf32Borrowed<'a> {
pub fn slice(&self) -> &[F32] {
self.0
}
pub fn l2_norm(&self) -> F32 {
dot(self.slice(), self.slice()).sqrt()
}
}

impl<'a> VectorBorrowed for Vecf32Borrowed<'a> {
Expand Down
292 changes: 292 additions & 0 deletions src/datatype/functions_vecf32.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
#![allow(unused_lifetimes)]
#![allow(clippy::extra_unused_lifetimes)]
use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output};
use crate::error::*;
use base::scalar::*;
use base::vector::*;
use pgrx::pg_sys::Datum;
use pgrx::pg_sys::Oid;
use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError;
use pgrx::pgrx_sql_entity_graph::metadata::Returns;
use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError;
use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping;
use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable;
use pgrx::{FromDatum, IntoDatum};
use std::alloc::Layout;
use std::ffi::{CStr, CString};
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;

#[repr(C, align(8))]
pub struct AccumulateStateHeader {
varlena: u32,
dims: u16,
count: u64,
phantom: [F32; 0],
}

impl AccumulateStateHeader {
fn varlena(size: usize) -> u32 {
(size << 2) as u32
}
fn layout(len: usize) -> Layout {
u16::try_from(len).expect("Vector is too large.");
let layout_alpha = Layout::new::<AccumulateStateHeader>();
let layout_beta = Layout::array::<F32>(len).unwrap();
let layout = layout_alpha.extend(layout_beta).unwrap().0;
layout.pad_to_align()
}
pub fn dims(&self) -> usize {
self.dims as usize
}
pub fn count(&self) -> u64 {
self.count
}
pub fn slice(&self) -> &[F32] {
unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize) }
}
pub fn slice_mut(&mut self) -> &mut [F32] {
unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.dims as usize) }
}
}

pub enum AccumulateState<'a> {
Owned(NonNull<AccumulateStateHeader>),
Borrowed(&'a mut AccumulateStateHeader),
}

impl<'a> AccumulateState<'a> {
unsafe fn new(p: NonNull<AccumulateStateHeader>) -> Self {
// datum maybe toasted, try to detoast it
let q = unsafe {
NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.as_ptr().cast()).cast()).unwrap()
};
if p != q {
AccumulateState::Owned(q)
} else {
unsafe { AccumulateState::Borrowed(&mut *p.as_ptr()) }
}
}

pub fn new_with_slice(count: u64, slice: &[F32]) -> Self {
let dims = slice.len();
let layout = AccumulateStateHeader::layout(dims);
unsafe {
let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut AccumulateStateHeader;
std::ptr::addr_of_mut!((*ptr).varlena)
.write(AccumulateStateHeader::varlena(layout.size()));
std::ptr::addr_of_mut!((*ptr).dims).write(dims as u16);
std::ptr::addr_of_mut!((*ptr).count).write(count);
if dims > 0 {
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), dims);
}
AccumulateState::Owned(NonNull::new(ptr).unwrap())
}
}

pub fn into_raw(self) -> *mut AccumulateStateHeader {
let result = match self {
AccumulateState::Owned(p) => p.as_ptr(),
AccumulateState::Borrowed(ref p) => {
*p as *const AccumulateStateHeader as *mut AccumulateStateHeader
}
};
std::mem::forget(self);
result
}
}

impl Deref for AccumulateState<'_> {
type Target = AccumulateStateHeader;

fn deref(&self) -> &Self::Target {
match self {
AccumulateState::Owned(p) => unsafe { p.as_ref() },
AccumulateState::Borrowed(p) => p,
}
}
}

impl DerefMut for AccumulateState<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
match self {
AccumulateState::Owned(p) => unsafe { p.as_mut() },
AccumulateState::Borrowed(p) => p,
}
}
}

impl Drop for AccumulateState<'_> {
fn drop(&mut self) {
match self {
AccumulateState::Owned(p) => unsafe {
pgrx::pg_sys::pfree(p.as_ptr().cast());
},
AccumulateState::Borrowed(_) => {}
}
}
}

impl FromDatum for AccumulateState<'_> {
unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typmod: Oid) -> Option<Self> {
if is_null {
None
} else {
let ptr = NonNull::new(datum.cast_mut_ptr::<AccumulateStateHeader>()).unwrap();
unsafe { Some(AccumulateState::new(ptr)) }
}
}
}

impl IntoDatum for AccumulateState<'_> {
fn into_datum(self) -> Option<Datum> {
Some(Datum::from(self.into_raw() as *mut ()))
}

fn type_oid() -> Oid {
let namespace = pgrx::pg_catalog::PgNamespace::search_namespacename(c"vectors").unwrap();
let namespace = namespace.get().expect("pgvecto.rs is not installed.");
let t = pgrx::pg_catalog::PgType::search_typenamensp(
c"vector_accumulate_state ",
namespace.oid(),
)
.unwrap();
let t = t.get().expect("pg_catalog is broken.");
t.oid()
}
}

unsafe impl SqlTranslatable for AccumulateState<'_> {
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
Ok(SqlMapping::As(String::from("vector_accumulate_state ")))
}
fn return_sql() -> Result<Returns, ReturnsError> {
Ok(Returns::One(SqlMapping::As(String::from(
"vector_accumulate_state ",
))))
}
}

fn parse_accumulate_state(input: &[u8]) -> Result<(u64, Vec<F32>), String> {
use crate::utils::parse::parse_vector;
let hint = "Invalid input format for accumulatestate, using \'bigint, array \' like \'1, [1]\'";
let (count, slice) = input.split_once(|&c| c == b',').ok_or(hint)?;
let count = std::str::from_utf8(count)
.map_err(|e| e.to_string() + "\n" + hint)?
.parse::<u64>()
.map_err(|e| e.to_string() + "\n" + hint)?;
let v = parse_vector(slice, 0, |s| s.parse().ok());
match v {
Err(e) => Err(e.to_string() + "\n" + hint),
Ok(vector) => Ok((count, vector)),
}
}

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_accumulate_state_in(input: &CStr, _oid: Oid, _typmod: i32) -> AccumulateState<'_> {
// parse one bigint and a vector of f32, split with a comma
let res = parse_accumulate_state(input.to_bytes());
match res {
Err(e) => {
bad_literal(&e.to_string());
}
Ok((count, vector)) => AccumulateState::new_with_slice(count, &vector),
}
}

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_accumulate_state_out(state: AccumulateState<'_>) -> CString {
let mut buffer = String::new();
buffer.push_str(format!("{}, ", state.count()).as_str());
buffer.push('[');
if let Some(&x) = state.slice().first() {
buffer.push_str(format!("{}", x).as_str());
}
for &x in state.slice().iter().skip(1) {
buffer.push_str(format!(", {}", x).as_str());
}
buffer.push(']');
CString::new(buffer).unwrap()
}

/// accumulate intermediate state for vector average
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_accum<'a>(
mut state: AccumulateState<'a>,
value: Vecf32Input<'_>,
) -> AccumulateState<'a> {
let count = state.count();
match count {
// if the state is empty, copy the input vector
0 => AccumulateState::new_with_slice(1, value.iter().as_slice()),
_ => {
let dims = state.dims();
let value_dims = value.dims();
check_matched_dims(dims, value_dims);
let sum = state.slice_mut();
// accumulate the input vector
for (x, y) in sum.iter_mut().zip(value.iter()) {
*x += *y;
}
// increase the count
state.count += 1;
state
}
}
}

/// combine two intermediate states for vector average
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_combine<'a>(
mut state1: AccumulateState<'a>,
state2: AccumulateState<'a>,
) -> AccumulateState<'a> {
let count1 = state1.count();
let count2 = state2.count();
if count1 == 0 {
state2
} else if count2 == 0 {
state1
} else {
let dims1 = state1.dims();
let dims2 = state2.dims();
check_matched_dims(dims1, dims2);
state1.count += count2;
let sum1 = state1.slice_mut();
let sum2 = state2.slice();
for (x, y) in sum1.iter_mut().zip(sum2.iter()) {
*x += *y;
}
state1
}
}

/// finalize the intermediate state for vector average
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_final(state: AccumulateState<'_>) -> Option<Vecf32Output> {
let count = state.count();
if count == 0 {
// return NULL if all inputs are NULL
return None;
}
let sum = state
.slice()
.iter()
.map(|x| *x / F32(count as f32))
.collect::<Vec<_>>();
Some(Vecf32Output::new(
Vecf32Borrowed::new_checked(&sum).unwrap(),
))
}

/// Get the dimensions of a vector.
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_dims(vector: Vecf32Input<'_>) -> i32 {
vector.dims() as i32
}

/// Calculate the l2 norm of a vector.
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_vector_norm(vector: Vecf32Input<'_>) -> f32 {
vector.for_borrow().l2_norm().to_f32()
}
1 change: 1 addition & 0 deletions src/datatype/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod binary_veci8;
pub mod casts;
pub mod functions_bvecf32;
pub mod functions_svecf32;
pub mod functions_vecf32;
pub mod functions_veci8;
pub mod memory_bvecf32;
pub mod memory_svecf32;
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//!
//! Provides an easy-to-use extension for vector similarity search.
#![feature(alloc_error_hook)]
#![feature(slice_split_once)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::single_match)]
#![allow(clippy::too_many_arguments)]
Expand Down
1 change: 1 addition & 0 deletions src/sql/bootstrap.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ CREATE TYPE svector;
CREATE TYPE bvector;
CREATE TYPE veci8;
CREATE TYPE vector_index_stat;
CREATE TYPE vector_accumulate_state;

-- bootstrap end
32 changes: 32 additions & 0 deletions src/sql/finalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ CREATE TYPE vector_index_stat AS (
idx_options TEXT
);

CREATE TYPE vector_accumulate_state (
INPUT = _vectors_accumulate_state_in,
OUTPUT = _vectors_accumulate_state_out,
STORAGE = EXTERNAL,
INTERNALLENGTH = VARIABLE,
ALIGNMENT = double
);

-- List of operators

CREATE OPERATOR + (
Expand Down Expand Up @@ -593,6 +601,30 @@ $$;
CREATE FUNCTION alter_vector_index("index" OID, "key" TEXT, "value" TEXT) RETURNS void
STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_alter_vector_index_wrapper';

CREATE FUNCTION vector_dims("v" vector) RETURNS INT
STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_vector_dims_wrapper';

CREATE FUNCTION vector_norm("v" vector) RETURNS real
STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_vector_norm_wrapper';

-- List of aggregates

CREATE AGGREGATE avg(vector) (
SFUNC = _vectors_vector_accum,
STYPE = vector_accumulate_state,
COMBINEFUNC = _vectors_vector_combine,
FINALFUNC = _vectors_vector_final,
INITCOND = '0, []',
PARALLEL = SAFE
);

CREATE AGGREGATE sum(vector) (
SFUNC = _vectors_vecf32_operator_add,
STYPE = vector,
COMBINEFUNC = _vectors_vecf32_operator_add,
PARALLEL = SAFE
);

-- List of casts

CREATE CAST (real[] AS vector)
Expand Down
Loading

0 comments on commit d81473c

Please sign in to comment.