Skip to content

Commit f389063

Browse files
committed
Extended aggregates support
1 parent 54f9214 commit f389063

5 files changed

Lines changed: 130 additions & 16 deletions

File tree

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ export(tdb_select)
8383
export(tile)
8484
export(tile_order)
8585
export(tiledb_array)
86+
export(tiledb_array_apply_aggregate)
8687
export(tiledb_array_close)
8788
export(tiledb_array_create)
8889
export(tiledb_array_delete_fragments)

R/Array.R

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,29 @@ tiledb_array_has_enumeration <- function(arr) {
177177
}
178178
return(libtiledb_array_has_enumeration_vector(ctx@ptr, arr@ptr))
179179
}
180+
181+
##' Run an aggregate query on the given array and attribute
##'
##' The array is closed first if it is currently open, a new \dQuote{READ}
##' query is created on it, the query layout is switched to
##' \dQuote{UNORDERED} when it is not already, and the aggregation is then
##' evaluated by the compiled routine \code{libtiledb_query_apply_aggregate}.
##'
##' @param array A TileDB Array object
##' @param attrname The name of an attribute
##' @param operation The name of aggregation operation; the compiled routine
##' recognises \dQuote{Count}, \dQuote{NullCount}, \dQuote{Min}, \dQuote{Max},
##' \dQuote{Mean} and \dQuote{Sum} and raises an error for any other name
##' @param nullable A boolean toggle whether the attribute is nullable
##' @return The value of the aggregation
##' @export
tiledb_array_apply_aggregate <- function(array, attrname, operation, nullable = TRUE) {
    ## Argument checks; each message names the argument that is actually tested.
    stopifnot("The 'array' argument must be a TileDB Array object" = is(array, "tiledb_array"),
              "The 'attrname' argument must be character" = is.character(attrname),
              "The 'operation' argument must be character" = is.character(operation),
              "The 'nullable' argument must be logical" = is.logical(nullable))
    ## TODO: match.arg for operation

    ## Hand a closed array to tiledb_query() below.
    if (tiledb_array_is_open(array))
        array <- tiledb_array_close(array)

    query <- tiledb_query(array, "READ")

    ## The aggregate is run with an UNORDERED layout.
    if (tiledb_query_get_layout(query) != "UNORDERED")
        query <- tiledb_query_set_layout(query, "UNORDERED")  # TODO: allow GLOBAL_ORDER too?

    libtiledb_query_apply_aggregate(query@ptr, attrname, operation, nullable)
}

R/Query.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# MIT License
22
#
3-
# Copyright (c) 2017-2022 TileDB Inc.
3+
# Copyright (c) 2017-2023 TileDB Inc.
44
#
55
# Permission is hereby granted, free of charge, to any person obtaining a copy
66
# of this software and associated documentation files (the "Software"), to deal

man/tiledb_array_apply_aggregate.Rd

Lines changed: 23 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/libtiledb.cpp

Lines changed: 79 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3881,6 +3881,22 @@ XPtr<tiledb::Context> libtiledb_query_get_ctx(XPtr<tiledb::Query> query) {
38813881
return make_xptr<tiledb::Context>(new tiledb::Context(ctx));
38823882
}
38833883

3884+
template <typename T>
3885+
SEXP apply_unary_aggregate(XPtr<tiledb::Query> query, std::string operator_name, bool nullable = false) {
3886+
#if TILEDB_VERSION >= TileDB_Version(2,18,0)
3887+
T result = 0;
3888+
std::vector<uint8_t> nulls = { 0 };
3889+
uint64_t size = 1;
3890+
query->set_data_buffer(operator_name, &result, size);
3891+
if (nullable) query->set_validity_buffer(operator_name, nulls);
3892+
query->submit();
3893+
SEXP res = Rcpp::wrap(result);
3894+
return res;
3895+
#else
3896+
return Rcpp::wrap(R_NaReal);
3897+
#endif
3898+
}
3899+
38843900
// [[Rcpp::export]]
38853901
SEXP libtiledb_query_apply_aggregate(XPtr<tiledb::Query> query,
38863902
std::string attribute_name,
@@ -3889,35 +3905,83 @@ SEXP libtiledb_query_apply_aggregate(XPtr<tiledb::Query> query,
38893905
#if TILEDB_VERSION >= TileDB_Version(2,18,0)
38903906
check_xptr_tag<tiledb::Query>(query);
38913907
tiledb::QueryChannel channel = tiledb::QueryExperimental::get_default_channel(*query.get());
3892-
tiledb::ChannelOperation operation;
38933908
if (operator_name == "Sum") {
3894-
operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::SumOperator>(*query.get(), attribute_name);
3909+
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::SumOperator>(*query.get(), attribute_name);
3910+
channel.apply_aggregate(operator_name, operation);
38953911
} else if (operator_name == "Min") {
3896-
operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MinOperator>(*query.get(), attribute_name);
3912+
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MinOperator>(*query.get(), attribute_name);
3913+
channel.apply_aggregate(operator_name, operation);
38973914
} else if (operator_name == "Max") {
3898-
operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MaxOperator>(*query.get(), attribute_name);
3915+
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MaxOperator>(*query.get(), attribute_name);
3916+
channel.apply_aggregate(operator_name, operation);
38993917
} else if (operator_name == "Mean") {
3900-
operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MeanOperator>(*query.get(), attribute_name);
3918+
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MeanOperator>(*query.get(), attribute_name);
3919+
channel.apply_aggregate(operator_name, operation);
39013920
} else if (operator_name == "NullCount") {
3902-
operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::NullCountOperator>(*query.get(), attribute_name);
3921+
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::NullCountOperator>(*query.get(), attribute_name);
3922+
channel.apply_aggregate(operator_name, operation);
3923+
} else if (operator_name == "Count") {
3924+
channel.apply_aggregate(operator_name, tiledb::CountOperation());
39033925
} else {
39043926
Rcpp::stop("Invalid aggregation operator '%s' specified.", operator_name.c_str());
39053927
}
3906-
channel.apply_aggregate(operator_name, operation);
39073928
std::vector<uint8_t> nulls = { 0 };
39083929
uint64_t size = 1;
3909-
if (operator_name != "NullCount") {
3910-
double result = 0;
3911-
query->set_data_buffer(operator_name, &result, size);
3912-
if (nullable) query->set_validity_buffer(operator_name, nulls);
3913-
query->submit();
3914-
return Rcpp::wrap(result);
3915-
} else {
3930+
if (operator_name == "NullCount" || operator_name == "Count") {
3931+
// Count and null count take uint64_t.
39163932
uint64_t result = 0;
39173933
query->set_data_buffer(operator_name, &result, size);
3918-
// no validity buffer for NullCount
3934+
if (nullable && operator_name != "NullCount") { // no validity buffer for NullCount
3935+
query->set_validity_buffer(operator_name, nulls);
3936+
}
39193937
query->submit();
39203938
return Rcpp::wrap(result);
3939+
} else if (operator_name == "Mean") {
3940+
// Mean always takes in a double.
3941+
return apply_unary_aggregate<double>(query, operator_name, nullable);
3942+
} else if (operator_name == "Sum") {
3943+
// Sum will take int64_t for signed integers, uint64_t for unsigned integers
3944+
// and double for floating point values.
3945+
tiledb::Context ctx = query->ctx();
3946+
auto arr = query->array();
3947+
auto sch = tiledb::ArraySchema(ctx, arr.uri());
3948+
auto attr = tiledb::Attribute(sch.attribute(attribute_name));
3949+
std::string type_name = _tiledb_datatype_to_string(attr.type());
3950+
if (type_name == "INT8" || type_name == "INT16" ||
3951+
type_name == "INT32" || type_name == "INT64") {
3952+
return apply_unary_aggregate<int64_t>(query, operator_name, nullable);
3953+
} else if (type_name == "UINT8" || type_name == "UINT16" ||
3954+
type_name == "UINT32" || type_name == "UINT64") {
3955+
return apply_unary_aggregate<uint64_t>(query, operator_name, nullable);
3956+
} else if (type_name == "FLOAT32" || type_name == "FLOAT64") {
3957+
return apply_unary_aggregate<double>(query, operator_name, nullable);
3958+
} else {
3959+
Rcpp::stop("'Sum' operator not valid for attribute '%s' of type '%s'",
3960+
attribute_name, type_name);
3961+
}
3962+
} else if (operator_name == "Min" || operator_name == "Max") {
3963+
// Min/max will take whatever the datatype of the column is.
3964+
tiledb::Context ctx = query->ctx();
3965+
auto arr = query->array();
3966+
auto sch = tiledb::ArraySchema(ctx, arr.uri());
3967+
auto attr = tiledb::Attribute(sch.attribute(attribute_name));
3968+
std::string type_name = _tiledb_datatype_to_string(attr.type());
3969+
switch (attr.type()) {
3970+
case TILEDB_INT8: return apply_unary_aggregate<int16_t>(query, operator_name, nullable); // int8_t bites char
3971+
case TILEDB_INT16: return apply_unary_aggregate<int16_t>(query, operator_name, nullable);
3972+
case TILEDB_INT32: return apply_unary_aggregate<int32_t>(query, operator_name, nullable);
3973+
case TILEDB_INT64: return apply_unary_aggregate<int64_t>(query, operator_name, nullable);
3974+
case TILEDB_UINT8: return apply_unary_aggregate<uint16_t>(query, operator_name, nullable); // uint8_t bites char
3975+
case TILEDB_UINT16: return apply_unary_aggregate<uint16_t>(query, operator_name, nullable);
3976+
case TILEDB_UINT32: return apply_unary_aggregate<uint32_t>(query, operator_name, nullable);
3977+
case TILEDB_UINT64: return apply_unary_aggregate<uint64_t>(query, operator_name, nullable);
3978+
case TILEDB_FLOAT32: return apply_unary_aggregate<float>(query, operator_name, nullable);
3979+
case TILEDB_FLOAT64: return apply_unary_aggregate<double>(query, operator_name, nullable);
3980+
default: Rcpp::stop("'%s' is not defined for attribute '%s' of type '%s'",
3981+
operator_name, attribute_name, type_name);
3982+
}
3983+
} else {
3984+
Rcpp::stop("'%s' is not implemented for '%s'", operator_name, attribute_name);
39213985
}
39223986
#else
39233987
return Rcpp::wrap(R_NaReal);

0 commit comments

Comments
 (0)