Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ private static native long[] repeatColumnCount(long tableHandle,

private static native ContigSplitGroupByResult contiguousSplitGroups(long inputTable,
int[] keyIndices,
int[] projectionColumnIndices,
boolean ignoreNullKeys,
boolean keySorted,
boolean[] keysDescending,
Expand Down Expand Up @@ -4474,9 +4475,11 @@ public Table replaceNulls(ReplacePolicyWithColumn... replacements) {
* for the memory to be released.
*/
public ContiguousTable[] contiguousSplitGroups() {
int[] defaultValueIndices= null;
try (ContigSplitGroupByResult ret = Table.contiguousSplitGroups(
operation.table.nativeHandle,
operation.indices,
defaultValueIndices,
groupByOptions.getIgnoreNullKeys(),
groupByOptions.getKeySorted(),
groupByOptions.getKeysDescending(),
Expand Down Expand Up @@ -4504,15 +4507,51 @@ public ContiguousTable[] contiguousSplitGroups() {
* @return The split groups and uniq key table.
*/
public ContigSplitGroupByResult contiguousSplitGroupsAndGenUniqKeys() {
int[] defaultValueIndices = null;
return Table.contiguousSplitGroups(
operation.table.nativeHandle,
operation.indices,
defaultValueIndices,
groupByOptions.getIgnoreNullKeys(),
groupByOptions.getKeySorted(),
groupByOptions.getKeysDescending(),
groupByOptions.getKeysNullSmallest(),
true); // generate uniq key table
}

/**
* Similar to the above {@link #contiguousSplitGroupsAndGenUniqKeys}.
*
* The diff with the above method is:
* - Provide an extra input `projectionColumnIndices` which defines the columns to output.
* - The above method outputs keys columns in the split tables,
* but this method does not except `projectionColumnIndices` includes key columns.
*
* The split tables only contain the columns defined in the `projectionColumnIndices`
*
* @param projectionColumnIndices Defines the output columns.
* @return The split groups and uniq key table.
*/
public ContigSplitGroupByResult contiguousSplitGroupsAndGenUniqKeys(
int[] projectionColumnIndices) {
if (operation.indices == null || operation.indices.length == 0) {
throw new IllegalArgumentException("key indices is empty!");
}

if (projectionColumnIndices == null || projectionColumnIndices.length == 0) {
throw new IllegalArgumentException("value indices is empty!");
}

return Table.contiguousSplitGroups(
operation.table.nativeHandle,
operation.indices,
projectionColumnIndices,
groupByOptions.getIgnoreNullKeys(),
groupByOptions.getKeySorted(),
groupByOptions.getKeysDescending(),
groupByOptions.getKeysNullSmallest(),
true); // generate uniq key table
}
}

public static final class TableOperation {
Expand Down
90 changes: 61 additions & 29 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4606,6 +4606,7 @@ Java_ai_rapids_cudf_Table_contiguousSplitGroups(JNIEnv* env,
jclass,
jlong jinput_table,
jintArray jkey_indices,
jintArray jprojection_column_indices,
jboolean jignore_null_keys,
jboolean jkey_sorted,
jbooleanArray jkeys_sort_desc,
Expand Down Expand Up @@ -4643,42 +4644,73 @@ Java_ai_rapids_cudf_Table_contiguousSplitGroups(JNIEnv* env,

// 1) Gets the groups(keys, offsets, values) from groupby.
//
// Uses only the non-key columns as the input values instead of the whole table,
// to avoid duplicated key columns in output of `get_groups`.
// The code looks like a little more complicated, but it can reduce the peak memory.
auto num_value_cols = input_table->num_columns() - key_indices.size();
// If the `jprojection_column_indices` is null, uses all_columns - key_columns as value columns;
// If the `jprojection_column_indices` is not null, use the `jprojection_column_indices` columns
// as value columns.
std::vector<cudf::size_type> value_indices;
value_indices.reserve(num_value_cols);
// column indices start with 0.
cudf::size_type index = 0;
while (value_indices.size() < num_value_cols) {
if (std::find(key_indices.begin(), key_indices.end(), index) == key_indices.end()) {
// not key column, so adds it as value column.
value_indices.emplace_back(index);
auto num_value_cols = [&]() -> size_t {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: We might be able to remove the parentheses:

Suggested change
auto num_value_cols = [&]() -> size_t {
auto num_value_cols = [&] -> size_t {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding () causes a warning:

TableJni.cpp:4650:31: error: parameter declaration before lambda trailing return type only optional with ‘-std=c++2b’ or ‘-std=gnu++2b’ [-Werror=c++23-extensions]
[INFO]      [exec]  4650 |     auto num_value_cols = [&] -> size_t {
[INFO]      [exec]       |                               ^~
[INFO]      [exec] cc1plus: all warnings being treated as errors

if (jprojection_column_indices == NULL) {
// if a column is not in key columns, then it's a value column
auto num_v_cols = static_cast<size_t>(input_table->num_columns()) - key_indices.size();
value_indices.reserve(num_v_cols);
cudf::size_type index = 0;
while (value_indices.size() < num_v_cols) {
if (std::find(key_indices.begin(), key_indices.end(), index) == key_indices.end()) {
// not key column, so adds it as value column.
value_indices.emplace_back(index);
}
index++;
}
return num_v_cols;
} else {
// use the specified columns as value columns
cudf::jni::native_jintArray n_project_indices(env, jprojection_column_indices);
value_indices.reserve(n_project_indices.size());
for (auto i = 0; i < n_project_indices.size(); i++) {
value_indices.emplace_back(n_project_indices[i]);
}
return static_cast<size_t>(n_project_indices.size());
}
index++;
}
}();

cudf::table_view values_view = input_table->select(value_indices);
// execute grouping
cudf::groupby::groupby::groups groups = grouper.get_groups(values_view);

// When builds the table view from keys and values of 'groups', restores the
// original order of columns (same order with that in input table).
std::vector<cudf::column_view> grouped_cols(key_indices.size() + num_value_cols);
// key columns
auto key_view = groups.keys->view();
auto key_view_it = key_view.begin();
for (auto key_id : key_indices) {
grouped_cols.at(key_id) = std::move(*key_view_it);
key_view_it++;
}
// value columns
auto value_view = groups.values->view();
auto value_view_it = value_view.begin();
for (auto value_id : value_indices) {
grouped_cols.at(value_id) = std::move(*value_view_it);
value_view_it++;
// if jprojection_column_indices is null, output both key columns and value columns;
// otherwise, only output value columns.
auto num_grouped_cols =
num_value_cols + ((jprojection_column_indices == NULL) ? key_indices.size() : 0);

std::vector<cudf::column_view> grouped_cols(num_grouped_cols);

if (jprojection_column_indices == NULL) {
// When builds the table view from keys and values of 'groups', restores the
// original order of columns (same order with that in input table).
// key columns
auto key_view = groups.keys->view();
auto key_view_it = key_view.begin();
for (auto key_id : key_indices) {
grouped_cols[key_id] = *key_view_it;
key_view_it++;
}
// value columns
auto value_view = groups.values->view();
auto value_view_it = value_view.begin();
for (auto value_id : value_indices) {
grouped_cols[value_id] = *value_view_it;
value_view_it++;
}
} else {
// specified value_indices, do not output keys columns by default
auto value_view = groups.values->view();
auto value_view_it = value_view.begin();
for (size_t i = 0; i < num_value_cols; ++i) {
grouped_cols[i] = *value_view_it;
value_view_it++;
}
}

cudf::table_view grouped_table(grouped_cols);
// When no key columns, uses the input table instead, because the output
// of 'get_groups' is empty.
Expand Down
62 changes: 62 additions & 0 deletions java/src/test/java/ai/rapids/cudf/TableTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8237,6 +8237,68 @@ void testGroupByContiguousSplitGroups() throws Exception {
}
}

@Test
void testGroupByContiguousSplitGroupsSpecifyValueIndices() throws Exception {
try (Table table = new Table.TestBuilder()
.column(1, 1, 1, 1, 1, 1)
.column(1, 3, 3, 5, 5, 5)
.column(12, 14, 13, 17, 16, 18)
.column("s1", "s2", "s3", "s4", "s5", "s6")
.build()) {
// Normal case with primitive types.
try (Table expected1 = new Table.TestBuilder()
.column(12)
.column("s1").build();
Table expected2 = new Table.TestBuilder()
.column(14, 13)
.column("s2", "s3").build();
Table expected3 = new Table.TestBuilder()
.column(17, 16, 18)
.column("s4", "s5", "s6").build();
Table expectedUniqKeys = new Table.TestBuilder()
.column(1, 1, 1)
.column(1, 3, 5).build();
ContigSplitGroupByResult r =
table.groupBy(0, 1).contiguousSplitGroupsAndGenUniqKeys(new int[]{2, 3})) {
ContiguousTable[] splits = r.getGroups();
Table uniqKeys = r.getUniqKeyTable();

for (ContiguousTable ct : splits) {
if (ct.getRowCount() == 1) {
assertTablesAreEqual(expected1, ct.getTable());
} else if (ct.getRowCount() == 2) {
assertTablesAreEqual(expected2, ct.getTable());
} else if (ct.getRowCount() == 3) {
assertTablesAreEqual(expected3, ct.getTable());
} else {
throw new RuntimeException("unexpected behavior: contiguousSplitGroups");
}
}

// verify uniq keys table
assertTablesAreEqual(expectedUniqKeys, uniqKeys);
}

// Row count is 0
try (
Table emptyTable = new Table.TestBuilder()
.column(new Integer[0])
.column(new Integer[0])
.column(new Integer[0])
.column(new String[0]).build();
ContigSplitGroupByResult r =
emptyTable.groupBy(0, 1).contiguousSplitGroupsAndGenUniqKeys(new int[]{2, 3})) {
ContiguousTable[] splits = r.getGroups();
Table uniqKeys = r.getUniqKeyTable();

assertEquals(0, emptyTable.getRowCount());
assertEquals(1, splits.length);
assertEquals(0, splits[0].getTable().getRowCount());
assertEquals(0, uniqKeys.getRowCount());
}
}
}

@Test
void testGroupByCollectListIncludeNulls() {
try (Table input = new Table.TestBuilder()
Expand Down