Skip to content

Commit

Permalink
feat: temporary support for sparse aggregate. (#489)
Browse files Browse the repository at this point in the history
* Implement aggregate functions for sparse vector.

Signed-off-by: my-vegetable-has-exploded <[email protected]>

* fix null.

Signed-off-by: my-vegetable-has-exploded <[email protected]>

* tests: add more e2e test.

Signed-off-by: my-vegetable-has-exploded <[email protected]>

---------

Signed-off-by: my-vegetable-has-exploded <[email protected]>
  • Loading branch information
my-vegetable-has-exploded committed May 28, 2024
1 parent 8f99933 commit 13e4245
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/datatype/functions_svecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::memory_svecf32::*;
use crate::error::*;
use base::scalar::*;
use base::vector::*;
use num_traits::Zero;

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_svecf32_dims(vector: SVecf32Input<'_>) -> i32 {
Expand Down Expand Up @@ -58,3 +59,26 @@ fn _vectors_to_svector(
}
SVecf32Output::new(SVecf32Borrowed::new(dims.get(), &indexes, &values))
}

/// divide a sparse vector by a scalar.
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_svecf32_div(vector: SVecf32Input<'_>, scalar: f32) -> SVecf32Output {
let scalar = F32(scalar);
let vector = vector.for_borrow();
let indexes = vector.indexes();
let values = vector.values();
let mut new_indexes = Vec::<u32>::with_capacity(indexes.len());
let mut new_values = Vec::<F32>::with_capacity(values.len());
for (value, index) in values.iter().zip(indexes.iter()) {
let v = *value / scalar;
if !v.is_zero() {
new_values.push(v);
new_indexes.push(*index);
}
}
SVecf32Output::new(SVecf32Borrowed::new(
vector.dims(),
&new_indexes,
&new_values,
))
}
1 change: 1 addition & 0 deletions src/sql/bootstrap.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ CREATE TYPE veci8;
CREATE TYPE vector_index_stat;

CREATE TYPE _vectors_vecf32_aggregate_avg_stype;
CREATE TYPE svector_accumulate_state;

-- bootstrap end
66 changes: 66 additions & 0 deletions src/sql/finalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ CREATE TYPE _vectors_vecf32_aggregate_avg_stype (
ALIGNMENT = double
);

CREATE TYPE svector_accumulate_state AS (
count INT,
sum svector
);

-- List of operators

CREATE OPERATOR + (
Expand Down Expand Up @@ -659,6 +664,51 @@ IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_binari
CREATE FUNCTION to_veci8("len" INT, "alpha" real, "offset" real, "values" INT[]) RETURNS veci8
IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_to_veci8_wrapper';

CREATE FUNCTION _vectors_svector_accum("state" svector_accumulate_state, "value" svector) RETURNS svector_accumulate_state AS $$
DECLARE
result svector_accumulate_state;
BEGIN
IF state.count = 0 THEN
result.count := 1;
result.sum := value;
RETURN result;
END IF;
result.count := state.count + 1;
result.sum := state.sum + value;
RETURN result;
END;
$$ LANGUAGE plpgsql STRICT PARALLEL SAFE;

CREATE FUNCTION _vectors_svector_combine("state1" svector_accumulate_state, "state2" svector_accumulate_state) RETURNS svector_accumulate_state AS $$
DECLARE
result svector_accumulate_state;
BEGIN
IF state1.count = 0 THEN
RETURN state2;
END IF;
IF state2.count = 0 THEN
RETURN state1;
END IF;
result.count := state1.count + state2.count;
result.sum := state1.sum + state2.sum;
RETURN result;
END;
$$ LANGUAGE plpgsql STRICT PARALLEL SAFE;

CREATE FUNCTION _vectors_svector_final("state" svector_accumulate_state) RETURNS svector AS $$
DECLARE
result svector;
count INT;
BEGIN
count := state.count;
IF count = 0 THEN
RETURN NULL;
END IF;
result := _vectors_svecf32_div(state.sum, count::real);
RETURN result;
END;
$$ LANGUAGE plpgsql STRICT PARALLEL SAFE;

-- List of aggregates

CREATE AGGREGATE avg(vector) (
Expand All @@ -677,6 +727,22 @@ CREATE AGGREGATE sum(vector) (
PARALLEL = SAFE
);

CREATE AGGREGATE avg(svector) (
SFUNC = _vectors_svector_accum,
STYPE = svector_accumulate_state,
COMBINEFUNC = _vectors_svector_combine,
FINALFUNC = _vectors_svector_final,
INITCOND = '(0, [0])',
PARALLEL = SAFE
);

CREATE AGGREGATE sum(svector) (
SFUNC = _vectors_svecf32_operator_add,
STYPE = svector,
COMBINEFUNC = _vectors_svecf32_operator_add,
PARALLEL = SAFE
);

-- List of casts

CREATE CAST (real[] AS vector)
Expand Down
165 changes: 165 additions & 0 deletions tests/sqllogictest/svector.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
statement ok
SET search_path TO pg_temp, vectors;

statement ok
CREATE TABLE t (id bigserial, val svector);

statement ok
INSERT INTO t (val)
VALUES ('[1,2,3]'), ('[4,5,6]');

query I
SELECT vector_dims(val) FROM t;
----
3
3

query R
SELECT round(vector_norm(val)::numeric, 5) FROM t;
----
3.74166
8.77496

query ?
SELECT avg(val) FROM t;
----
[2.5, 3.5, 4.5]

query ?
SELECT sum(val) FROM t;
----
[5, 7, 9]

statement ok
CREATE TABLE test_vectors (id serial, data vector(1000));

statement ok
INSERT INTO test_vectors (data)
SELECT
ARRAY_AGG(CASE WHEN random() < 0.95 THEN 0 ELSE (random() * 99 + 1)::real END)::real[]::vector AS v
FROM generate_series(1, 1000 * 5000) i
GROUP BY i % 5000;

query ?
SELECT count(*) FROM test_vectors;
----
5000

query R
SELECT vector_norm('[3,4]'::svector);
----
5

query I
SELECT vector_dims(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v;
----
2
1

query ?
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]']) v;
----
[2, 3.5, 5]

query ?
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[-1,2,-3]']) v;
----
[0, 2, 0]

query ?
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]', NULL]) v;
----
[2, 3.5, 5]

query ?
SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::svector,NULL]) v;
----
[1, 2, 3]

query ?
SELECT avg(v) FROM unnest(ARRAY[]::svector[]) v;
----
NULL

query ?
SELECT avg(v) FROM unnest(ARRAY[NULL]::svector[]) v;
----
NULL

query ?
SELECT avg(v) FROM unnest(ARRAY['[3e38]'::svector, '[3e38]']) v;
----
[inf]

statement error differs in dimensions
SELECT avg(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v;

query ?
SELECT avg(v) FROM unnest(ARRAY[to_svector(5, '{0,1}', '{2,3}'), to_svector(5, '{0,2}', '{1,3}'), to_svector(5, '{3,4}', '{3,3}')]) v;
----
[1, 1, 1, 1, 1]

query ?
SELECT avg(v) FROM unnest(ARRAY[to_svector(32, '{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}') ]) v;
----
[0.33333334, 0.6666667, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.6666667, 0.33333334, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# test avg(svector) get the same result as avg(vector)
query ?
SELECT avg(data) = avg(data::svector)::vector FROM test_vectors;
----
t

query ?
SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]']) v;
----
[4, 7, 10]

# test zero element
query ?
SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[-1,2,-3]']) v;
----
[0, 4, 0]

query ?
SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::svector, '[3,5,7]', NULL]) v;
----
[4, 7, 10]

query ?
SELECT sum(v) FROM unnest(ARRAY[]::svector[]) v;
----
NULL

query ?
SELECT sum(v) FROM unnest(ARRAY[NULL]::svector[]) v;
----
NULL

statement error differs in dimensions
SELECT sum(v) FROM unnest(ARRAY['[1,2]'::svector, '[3]']) v;

# should this return an error ?
query ?
SELECT sum(v) FROM unnest(ARRAY['[3e38]'::svector, '[3e38]']) v;
----
[inf]

query ?
SELECT sum(v) FROM unnest(ARRAY[to_svector(5, '{0,1}', '{1,2}'), to_svector(5, '{0,2}', '{1,2}'), to_svector(5, '{3,4}', '{3,3}')]) v;
----
[2, 2, 2, 3, 3]

query ?
SELECT sum(v) FROM unnest(ARRAY[to_svector(32, '{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}'), to_svector(32, '{2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}', '{1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}') ]) v;
----
[1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

# test sum(svector) get the same result as sum(vector)
query ?
SELECT sum(data) = sum(data::svector)::vector FROM test_vectors;
----
t

statement ok
DROP TABLE t, test_vectors;

0 comments on commit 13e4245

Please sign in to comment.