diff --git a/src/datatype/functions_svecf32.rs b/src/datatype/functions_svecf32.rs index 19b5d8a37..798a1a4e6 100644 --- a/src/datatype/functions_svecf32.rs +++ b/src/datatype/functions_svecf32.rs @@ -2,6 +2,7 @@ use super::memory_svecf32::*; use crate::error::*; use base::scalar::*; use base::vector::*; +use num_traits::Zero; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vectors_svecf32_dims(vector: SVecf32Input<'_>) -> i32 { @@ -58,3 +59,26 @@ fn _vectors_to_svector( } SVecf32Output::new(SVecf32Borrowed::new(dims.get(), &indexes, &values)) } + +/// divide a sparse vector by a scalar. +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vectors_svecf32_div(vector: SVecf32Input<'_>, scalar: f32) -> SVecf32Output { + let scalar = F32(scalar); + let vector = vector.for_borrow(); + let indexes = vector.indexes(); + let values = vector.values(); + let mut new_indexes = Vec::::with_capacity(indexes.len()); + let mut new_values = Vec::::with_capacity(values.len()); + for (value, index) in values.iter().zip(indexes.iter()) { + let v = *value / scalar; + if !v.is_zero() { + new_values.push(v); + new_indexes.push(*index); + } + } + SVecf32Output::new(SVecf32Borrowed::new( + vector.dims(), + &new_indexes, + &new_values, + )) +} diff --git a/src/sql/bootstrap.sql b/src/sql/bootstrap.sql index eee68d215..6b023423a 100644 --- a/src/sql/bootstrap.sql +++ b/src/sql/bootstrap.sql @@ -11,5 +11,6 @@ CREATE TYPE veci8; CREATE TYPE vector_index_stat; CREATE TYPE _vectors_vecf32_aggregate_avg_stype; +CREATE TYPE svector_accumulate_state; -- bootstrap end diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index 9c8bf143b..29b479a4c 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -86,6 +86,11 @@ CREATE TYPE _vectors_vecf32_aggregate_avg_stype ( ALIGNMENT = double ); +CREATE TYPE svector_accumulate_state AS ( + count INT, + sum svector +); + -- List of operators CREATE OPERATOR + ( @@ -659,6 +664,51 @@ IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_binari CREATE FUNCTION to_veci8("len" INT, "alpha" real, "offset" real, "values" INT[]) RETURNS veci8 IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_to_veci8_wrapper'; +CREATE FUNCTION _vectors_svector_accum("state" svector_accumulate_state, "value" svector) RETURNS svector_accumulate_state AS $$ +DECLARE + result svector_accumulate_state; +BEGIN + IF state.count = 0 THEN + result.count := 1; + result.sum := value; + RETURN result; + END IF; + result.count := state.count + 1; + result.sum := state.sum + value; + RETURN result; +END; +$$ LANGUAGE plpgsql STRICT PARALLEL SAFE; + +CREATE FUNCTION _vectors_svector_combine("state1" svector_accumulate_state, "state2" svector_accumulate_state) RETURNS svector_accumulate_state AS $$ +DECLARE + result svector_accumulate_state; +BEGIN + IF state1.count = 0 THEN + RETURN state2; + END IF; + IF state2.count = 0 THEN + RETURN state1; + END IF; + result.count := state1.count + state2.count; + result.sum := state1.sum + state2.sum; + RETURN result; +END; +$$ LANGUAGE plpgsql STRICT PARALLEL SAFE; + +CREATE FUNCTION _vectors_svector_final("state" svector_accumulate_state) RETURNS svector AS $$ +DECLARE + result svector; + count INT; +BEGIN + count := state.count; + IF count = 0 THEN + RETURN NULL; + END IF; + result := _vectors_svecf32_div(state.sum, count::real); + RETURN result; +END; +$$ LANGUAGE plpgsql STRICT PARALLEL SAFE; + -- List of aggregates CREATE AGGREGATE avg(vector) ( @@ -677,6 +727,22 @@ CREATE AGGREGATE sum(vector) ( PARALLEL = SAFE ); +CREATE AGGREGATE avg(svector) ( + SFUNC = _vectors_svector_accum, + STYPE = svector_accumulate_state, + COMBINEFUNC = _vectors_svector_combine, + FINALFUNC = _vectors_svector_final, + INITCOND = '(0, [0])', + PARALLEL = SAFE +); + +CREATE AGGREGATE sum(svector) ( + SFUNC = _vectors_svecf32_operator_add, + STYPE = svector, + COMBINEFUNC = _vectors_svecf32_operator_add, + PARALLEL = SAFE +); + -- List of casts CREATE CAST (real[] AS vector) diff --git a/tests/sqllogictest/svector.slt b/tests/sqllogictest/svector.slt new file mode 100644 index 000000000..3815cd1ae --- /dev/null +++ b/tests/sqllogictest/svector.slt @@ -0,0 +1,165 @@ +statement ok +SET search_path TO pg_temp, vectors; + +statement ok +CREATE TABLE t (id bigserial, val svector); + +statement ok +INSERT INTO t (val) +VALUES ('[1,2,3]'), ('[4,5,6]'); + +query I +SELECT vector_dims(val) FROM t; +---- +3 +3 + +query R +SELECT round(vector_norm(val)::numeric, 5) FROM t; +---- +3.74166 +8.77496 + +query ? +SELECT avg(val) FROM t; +---- +[2.5, 3.5, 4.5] + +query ? +SELECT sum(val) FROM t; +---- +[5, 7, 9] + +statement ok +CREATE TABLE test_vectors (id serial, data vector(1000)); + +statement ok +INSERT INTO test_vectors (data) +SELECT + ARRAY_AGG(CASE WHEN random() < 0.95 THEN 0 ELSE (random() * 99 + 1)::real END)::real[]::vector AS v +FROM generate_series(1, 1000 * 5000) i +GROUP BY i % 5000; + +query ? +SELECT count(*) FROM test_vectors; +---- +5000 + +query R +SELECT vector_norm('[3,4]'::svector); +---- +5 + +query I +SELECT vector_dims(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v; +---- +2 +1 + +query ? +SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]']) v; +---- +[2, 3.5, 5] + +query ? +SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[-1,2,-3]']) v; +---- +[0, 2, 0] + +query ? +SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]', NULL]) v; +---- +[2, 3.5, 5] + +query ? +SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector,NULL]) v; +---- +[1, 2, 3] + +query ? +SELECT avg(v) FROM unnest(ARRAY[]::svector[]) v; +---- +NULL + +query ? +SELECT avg(v) FROM unnest(ARRAY[NULL]::svector[]) v; +---- +NULL + +query ? +SELECT avg(v) FROM unnest(ARRAY['[3e38]'::svector, '[3e38]']) v; +---- +[inf] + +statement error differs in dimensions +SELECT avg(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v; + +query ? +SELECT avg(v) FROM unnest(ARRAY[to_svector(5, '{0,1}', '{2,3}'), to_svector(5, '{0,2}', '{1,3}'), to_svector(5, '{3,4}', '{3,3}')]) v; +---- +[1, 1, 1, 1, 1] + +query ? +SELECT avg(v) FROM unnest(ARRAY[to_svector(32, '{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}') ]) v; +---- +[0.33333334, 0.6666667, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.6666667, 0.33333334, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + +# test avg(svector) get the same result as avg(vector) +query ? +SELECT avg(data) = avg(data::svector)::vector FROM test_vectors; +---- +t + +query ? +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]']) v; +---- +[4, 7, 10] + +# test zero element +query ? +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[-1,2,-3]']) v; +---- +[0, 4, 0] + +query ? +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]', NULL]) v; +---- +[4, 7, 10] + +query ? +SELECT sum(v) FROM unnest(ARRAY[]::svector[]) v; +---- +NULL + +query ? +SELECT sum(v) FROM unnest(ARRAY[NULL]::svector[]) v; +---- +NULL + +statement error differs in dimensions +SELECT sum(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v; + +# should this return an error ? +query ? +SELECT sum(v) FROM unnest(ARRAY['[3e38]'::svector, '[3e38]']) v; +---- +[inf] + +query ? +SELECT sum(v) FROM unnest(ARRAY[to_svector(5, '{0,1}', '{1,2}'), to_svector(5, '{0,2}', '{1,2}'), to_svector(5, '{3,4}', '{3,3}')]) v; +---- +[2, 2, 2, 3, 3] + +query ? +SELECT sum(v) FROM unnest(ARRAY[to_svector(32, '{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}') ]) v; +---- +[1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + +# test sum(svector) get the same result as sum(vector) +query ? +SELECT sum(data) = sum(data::svector)::vector FROM test_vectors; +---- +t + +statement ok +DROP TABLE t, test_vectors; \ No newline at end of file