@@ -3881,6 +3881,22 @@ XPtr<tiledb::Context> libtiledb_query_get_ctx(XPtr<tiledb::Query> query) {
38813881 return make_xptr<tiledb::Context>(new tiledb::Context (ctx));
38823882}
38833883
3884+ template <typename T>
3885+ SEXP apply_unary_aggregate (XPtr<tiledb::Query> query, std::string operator_name, bool nullable = false ) {
3886+ #if TILEDB_VERSION >= TileDB_Version(2,18,0)
3887+ T result = 0 ;
3888+ std::vector<uint8_t > nulls = { 0 };
3889+ uint64_t size = 1 ;
3890+ query->set_data_buffer (operator_name, &result, size);
3891+ if (nullable) query->set_validity_buffer (operator_name, nulls);
3892+ query->submit ();
3893+ SEXP res = Rcpp::wrap (result);
3894+ return res;
3895+ #else
3896+ return Rcpp::wrap (R_NaReal);
3897+ #endif
3898+ }
3899+
38843900// [[Rcpp::export]]
38853901SEXP libtiledb_query_apply_aggregate (XPtr<tiledb::Query> query,
38863902 std::string attribute_name,
@@ -3889,35 +3905,83 @@ SEXP libtiledb_query_apply_aggregate(XPtr<tiledb::Query> query,
38893905#if TILEDB_VERSION >= TileDB_Version(2,18,0)
38903906 check_xptr_tag<tiledb::Query>(query);
38913907 tiledb::QueryChannel channel = tiledb::QueryExperimental::get_default_channel (*query.get ());
3892- tiledb::ChannelOperation operation;
38933908 if (operator_name == " Sum" ) {
3894- operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::SumOperator>(*query.get (), attribute_name);
3909+ tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::SumOperator>(*query.get (), attribute_name);
3910+ channel.apply_aggregate (operator_name, operation);
38953911 } else if (operator_name == " Min" ) {
3896- operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MinOperator>(*query.get (), attribute_name);
3912+ tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MinOperator>(*query.get (), attribute_name);
3913+ channel.apply_aggregate (operator_name, operation);
38973914 } else if (operator_name == " Max" ) {
3898- operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MaxOperator>(*query.get (), attribute_name);
3915+ tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MaxOperator>(*query.get (), attribute_name);
3916+ channel.apply_aggregate (operator_name, operation);
38993917 } else if (operator_name == " Mean" ) {
3900- operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MeanOperator>(*query.get (), attribute_name);
3918+ tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MeanOperator>(*query.get (), attribute_name);
3919+ channel.apply_aggregate (operator_name, operation);
39013920 } else if (operator_name == " NullCount" ) {
3902- operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::NullCountOperator>(*query.get (), attribute_name);
3921+ tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::NullCountOperator>(*query.get (), attribute_name);
3922+ channel.apply_aggregate (operator_name, operation);
3923+ } else if (operator_name == " Count" ) {
3924+ channel.apply_aggregate (operator_name, tiledb::CountOperation ());
39033925 } else {
39043926 Rcpp::stop (" Invalid aggregation operator '%s' specified." , operator_name.c_str ());
39053927 }
3906- channel.apply_aggregate (operator_name, operation);
39073928 std::vector<uint8_t > nulls = { 0 };
39083929 uint64_t size = 1 ;
3909- if (operator_name != " NullCount" ) {
3910- double result = 0 ;
3911- query->set_data_buffer (operator_name, &result, size);
3912- if (nullable) query->set_validity_buffer (operator_name, nulls);
3913- query->submit ();
3914- return Rcpp::wrap (result);
3915- } else {
3930+ if (operator_name == " NullCount" || operator_name == " Count" ) {
3931+ // Count and null count take uint64_t.
39163932 uint64_t result = 0 ;
39173933 query->set_data_buffer (operator_name, &result, size);
3918- // no validity buffer for NullCount
3934+ if (nullable && operator_name != " NullCount" ) { // no validity buffer for NullCount
3935+ query->set_validity_buffer (operator_name, nulls);
3936+ }
39193937 query->submit ();
39203938 return Rcpp::wrap (result);
3939+ } else if (operator_name == " Mean" ) {
3940+ // Mean always takes in a double.
3941+ return apply_unary_aggregate<double >(query, operator_name, nullable);
3942+ } else if (operator_name == " Sum" ) {
3943+ // Sum will take int64_t for signed integers, uint64_t for unsigned integers
3944+ // and double for floating point values.
3945+ tiledb::Context ctx = query->ctx ();
3946+ auto arr = query->array ();
3947+ auto sch = tiledb::ArraySchema (ctx, arr.uri ());
3948+ auto attr = tiledb::Attribute (sch.attribute (attribute_name));
3949+ std::string type_name = _tiledb_datatype_to_string (attr.type ());
3950+ if (type_name == " INT8" || type_name == " INT16" ||
3951+ type_name == " INT32" || type_name == " INT64" ) {
3952+ return apply_unary_aggregate<int64_t >(query, operator_name, nullable);
3953+ } else if (type_name == " UINT8" || type_name == " UINT16" ||
3954+ type_name == " UINT32" || type_name == " UINT64" ) {
3955+ return apply_unary_aggregate<uint64_t >(query, operator_name, nullable);
3956+ } else if (type_name == " FLOAT32" || type_name == " FLOAT64" ) {
3957+ return apply_unary_aggregate<double >(query, operator_name, nullable);
3958+ } else {
3959+ Rcpp::stop (" 'Sum' operator not valid for attribute '%s' of type '%s'" ,
3960+ attribute_name, type_name);
3961+ }
3962+ } else if (operator_name == " Min" || operator_name == " Max" ) {
3963+ // Min/max will take whatever the datatype of the column is.
3964+ tiledb::Context ctx = query->ctx ();
3965+ auto arr = query->array ();
3966+ auto sch = tiledb::ArraySchema (ctx, arr.uri ());
3967+ auto attr = tiledb::Attribute (sch.attribute (attribute_name));
3968+ std::string type_name = _tiledb_datatype_to_string (attr.type ());
3969+ switch (attr.type ()) {
3970+ case TILEDB_INT8: return apply_unary_aggregate<int16_t >(query, operator_name, nullable); // int8_t bites char
3971+ case TILEDB_INT16: return apply_unary_aggregate<int16_t >(query, operator_name, nullable);
3972+ case TILEDB_INT32: return apply_unary_aggregate<int32_t >(query, operator_name, nullable);
3973+ case TILEDB_INT64: return apply_unary_aggregate<int64_t >(query, operator_name, nullable);
3974+ case TILEDB_UINT8: return apply_unary_aggregate<uint16_t >(query, operator_name, nullable); // uint8_t bites char
3975+ case TILEDB_UINT16: return apply_unary_aggregate<uint16_t >(query, operator_name, nullable);
3976+ case TILEDB_UINT32: return apply_unary_aggregate<uint32_t >(query, operator_name, nullable);
3977+ case TILEDB_UINT64: return apply_unary_aggregate<uint64_t >(query, operator_name, nullable);
3978+ case TILEDB_FLOAT32: return apply_unary_aggregate<float >(query, operator_name, nullable);
3979+ case TILEDB_FLOAT64: return apply_unary_aggregate<double >(query, operator_name, nullable);
3980+ default : Rcpp::stop (" '%s' is not defined for attribute '%s' of type '%s'" ,
3981+ operator_name, attribute_name, type_name);
3982+ }
3983+ } else {
3984+ Rcpp::stop (" '%s' is not implemented for '%s'" , operator_name, attribute_name);
39213985 }
39223986#else
39233987 return Rcpp::wrap (R_NaReal);
0 commit comments