diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java index 9f4c11e4c1c..2fd18f1a6ef 100644 --- a/java/src/main/java/ai/rapids/cudf/Table.java +++ b/java/src/main/java/ai/rapids/cudf/Table.java @@ -806,6 +806,7 @@ private static native long[] repeatColumnCount(long tableHandle, private static native ContigSplitGroupByResult contiguousSplitGroups(long inputTable, int[] keyIndices, + int[] projectionColumnIndices, boolean ignoreNullKeys, boolean keySorted, boolean[] keysDescending, @@ -4474,9 +4475,11 @@ public Table replaceNulls(ReplacePolicyWithColumn... replacements) { * for the memory to be released. */ public ContiguousTable[] contiguousSplitGroups() { + int[] defaultValueIndices= null; try (ContigSplitGroupByResult ret = Table.contiguousSplitGroups( operation.table.nativeHandle, operation.indices, + defaultValueIndices, groupByOptions.getIgnoreNullKeys(), groupByOptions.getKeySorted(), groupByOptions.getKeysDescending(), @@ -4504,15 +4507,51 @@ public ContiguousTable[] contiguousSplitGroups() { * @return The split groups and uniq key table. */ public ContigSplitGroupByResult contiguousSplitGroupsAndGenUniqKeys() { + int[] defaultValueIndices = null; return Table.contiguousSplitGroups( operation.table.nativeHandle, operation.indices, + defaultValueIndices, groupByOptions.getIgnoreNullKeys(), groupByOptions.getKeySorted(), groupByOptions.getKeysDescending(), groupByOptions.getKeysNullSmallest(), true); // generate uniq key table } + + /** + * Similar to the above {@link #contiguousSplitGroupsAndGenUniqKeys}. + * + * The differences from the above method are: + * - Provide an extra input `projectionColumnIndices` which defines the columns to output. + * - The above method outputs key columns in the split tables, + * but this method does not, unless `projectionColumnIndices` includes key columns. 
+ * + * The split tables only contain the columns defined in the `projectionColumnIndices` + * + * @param projectionColumnIndices Defines the output columns. + * @return The split groups and uniq key table. + */ + public ContigSplitGroupByResult contiguousSplitGroupsAndGenUniqKeys( + int[] projectionColumnIndices) { + if (operation.indices == null || operation.indices.length == 0) { + throw new IllegalArgumentException("key indices is empty!"); + } + + if (projectionColumnIndices == null || projectionColumnIndices.length == 0) { + throw new IllegalArgumentException("value indices is empty!"); + } + + return Table.contiguousSplitGroups( + operation.table.nativeHandle, + operation.indices, + projectionColumnIndices, + groupByOptions.getIgnoreNullKeys(), + groupByOptions.getKeySorted(), + groupByOptions.getKeysDescending(), + groupByOptions.getKeysNullSmallest(), + true); // generate uniq key table + } } public static final class TableOperation { diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index a55b96db9ac..c90123dc606 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -4606,6 +4606,7 @@ Java_ai_rapids_cudf_Table_contiguousSplitGroups(JNIEnv* env, jclass, jlong jinput_table, jintArray jkey_indices, + jintArray jprojection_column_indices, jboolean jignore_null_keys, jboolean jkey_sorted, jbooleanArray jkeys_sort_desc, @@ -4643,42 +4644,73 @@ Java_ai_rapids_cudf_Table_contiguousSplitGroups(JNIEnv* env, // 1) Gets the groups(keys, offsets, values) from groupby. // - // Uses only the non-key columns as the input values instead of the whole table, - // to avoid duplicated key columns in output of `get_groups`. - // The code looks like a little more complicated, but it can reduce the peak memory. 
- auto num_value_cols = input_table->num_columns() - key_indices.size(); + // If `jprojection_column_indices` is null, use all columns except the key columns as value columns; + // if `jprojection_column_indices` is not null, use the `jprojection_column_indices` columns + // as value columns. std::vector<cudf::size_type> value_indices; - value_indices.reserve(num_value_cols); - // column indices start with 0. - cudf::size_type index = 0; - while (value_indices.size() < num_value_cols) { - if (std::find(key_indices.begin(), key_indices.end(), index) == key_indices.end()) { - // not key column, so adds it as value column. - value_indices.emplace_back(index); + auto num_value_cols = [&]() -> size_t { + if (jprojection_column_indices == NULL) { + // if a column is not in key columns, then it's a value column + auto num_v_cols = static_cast<size_t>(input_table->num_columns()) - key_indices.size(); + value_indices.reserve(num_v_cols); + cudf::size_type index = 0; + while (value_indices.size() < num_v_cols) { + if (std::find(key_indices.begin(), key_indices.end(), index) == key_indices.end()) { + // not key column, so adds it as value column. + value_indices.emplace_back(index); + } + index++; + } + return num_v_cols; + } else { + // use the specified columns as value columns + cudf::jni::native_jintArray n_project_indices(env, jprojection_column_indices); + value_indices.reserve(n_project_indices.size()); + for (auto i = 0; i < n_project_indices.size(); i++) { + value_indices.emplace_back(n_project_indices[i]); + } + return static_cast<size_t>(n_project_indices.size()); } - index++; - } + }(); + cudf::table_view values_view = input_table->select(value_indices); // execute grouping cudf::groupby::groupby::groups groups = grouper.get_groups(values_view); - // When builds the table view from keys and values of 'groups', restores the - // original order of columns (same order with that in input table). 
- std::vector<cudf::column_view> grouped_cols(key_indices.size() + num_value_cols); - // key columns - auto key_view = groups.keys->view(); - auto key_view_it = key_view.begin(); - for (auto key_id : key_indices) { - grouped_cols.at(key_id) = std::move(*key_view_it); - key_view_it++; - } - // value columns - auto value_view = groups.values->view(); - auto value_view_it = value_view.begin(); - for (auto value_id : value_indices) { - grouped_cols.at(value_id) = std::move(*value_view_it); - value_view_it++; + // if jprojection_column_indices is null, output both key columns and value columns; + // otherwise, only output value columns. + auto num_grouped_cols = + num_value_cols + ((jprojection_column_indices == NULL) ? key_indices.size() : 0); + + std::vector<cudf::column_view> grouped_cols(num_grouped_cols); + + if (jprojection_column_indices == NULL) { + // When building the table view from the keys and values of 'groups', restore the + // original order of columns (the same order as in the input table). + // key columns + auto key_view = groups.keys->view(); + auto key_view_it = key_view.begin(); + for (auto key_id : key_indices) { + grouped_cols[key_id] = *key_view_it; + key_view_it++; + } + // value columns + auto value_view = groups.values->view(); + auto value_view_it = value_view.begin(); + for (auto value_id : value_indices) { + grouped_cols[value_id] = *value_view_it; + value_view_it++; + } + } else { + // value_indices were specified, so key columns are not output by default + auto value_view = groups.values->view(); + auto value_view_it = value_view.begin(); + for (size_t i = 0; i < num_value_cols; ++i) { + grouped_cols[i] = *value_view_it; + value_view_it++; + } } + cudf::table_view grouped_table(grouped_cols); // When no key columns, uses the input table instead, because the output // of 'get_groups' is empty. 
diff --git a/java/src/test/java/ai/rapids/cudf/TableTest.java b/java/src/test/java/ai/rapids/cudf/TableTest.java index 805442571d7..0767f471478 100644 --- a/java/src/test/java/ai/rapids/cudf/TableTest.java +++ b/java/src/test/java/ai/rapids/cudf/TableTest.java @@ -8237,6 +8237,68 @@ void testGroupByContiguousSplitGroups() throws Exception { } } + @Test + void testGroupByContiguousSplitGroupsSpecifyValueIndices() throws Exception { + try (Table table = new Table.TestBuilder() + .column(1, 1, 1, 1, 1, 1) + .column(1, 3, 3, 5, 5, 5) + .column(12, 14, 13, 17, 16, 18) + .column("s1", "s2", "s3", "s4", "s5", "s6") + .build()) { + // Normal case with primitive types. + try (Table expected1 = new Table.TestBuilder() + .column(12) + .column("s1").build(); + Table expected2 = new Table.TestBuilder() + .column(14, 13) + .column("s2", "s3").build(); + Table expected3 = new Table.TestBuilder() + .column(17, 16, 18) + .column("s4", "s5", "s6").build(); + Table expectedUniqKeys = new Table.TestBuilder() + .column(1, 1, 1) + .column(1, 3, 5).build(); + ContigSplitGroupByResult r = + table.groupBy(0, 1).contiguousSplitGroupsAndGenUniqKeys(new int[]{2, 3})) { + ContiguousTable[] splits = r.getGroups(); + Table uniqKeys = r.getUniqKeyTable(); + + for (ContiguousTable ct : splits) { + if (ct.getRowCount() == 1) { + assertTablesAreEqual(expected1, ct.getTable()); + } else if (ct.getRowCount() == 2) { + assertTablesAreEqual(expected2, ct.getTable()); + } else if (ct.getRowCount() == 3) { + assertTablesAreEqual(expected3, ct.getTable()); + } else { + throw new RuntimeException("unexpected behavior: contiguousSplitGroups"); + } + } + + // verify uniq keys table + assertTablesAreEqual(expectedUniqKeys, uniqKeys); + } + + // Row count is 0 + try ( + Table emptyTable = new Table.TestBuilder() + .column(new Integer[0]) + .column(new Integer[0]) + .column(new Integer[0]) + .column(new String[0]).build(); + ContigSplitGroupByResult r = + emptyTable.groupBy(0, 
1).contiguousSplitGroupsAndGenUniqKeys(new int[]{2, 3})) { + ContiguousTable[] splits = r.getGroups(); + Table uniqKeys = r.getUniqKeyTable(); + + assertEquals(0, emptyTable.getRowCount()); + assertEquals(1, splits.length); + assertEquals(0, splits[0].getTable().getRowCount()); + assertEquals(0, uniqKeys.getRowCount()); + } + } + } + @Test void testGroupByCollectListIncludeNulls() { try (Table input = new Table.TestBuilder()