Skip to content

Commit

Permalink
Add JNI backend for Spark SQL function conv for (hexa)decimals (#1314)
Browse files Browse the repository at this point in the history
Contributes to NVIDIA/spark-rapids#8511

POC supporting form/to  radices 10 and 16 leveraging existing libcudf API 

Signed-off-by: Gera Shegalov <[email protected]>
  • Loading branch information
gerashegalov authored Aug 17, 2023
1 parent becb973 commit 72cb837
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 13 deletions.
108 changes: 108 additions & 0 deletions src/main/cpp/src/CastStringJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
*/

#include "cast_string.hpp"
#include <cudf/binaryop.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/copying.hpp>
#include <cudf/replace.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/contains.hpp>
#include <cudf/strings/convert/convert_integers.hpp>
#include <cudf/strings/extract.hpp>
#include <cudf/strings/find.hpp>
#include <cudf/strings/regex/regex_program.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/transform.hpp>
#include <cudf/unary.hpp>

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
Expand Down Expand Up @@ -111,4 +124,99 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_fromDecimal
}
CATCH_CAST_EXCEPTION(env, 0);
}

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_toIntegersWithBase(
JNIEnv* env, jclass, jlong input_column, jint base, jboolean ansi_enabled, jint j_dtype)
{
JNI_NULL_CHECK(env, input_column, "input column is null", 0);
using namespace cudf;
try {
if (base != 10 && base != 16) {
auto const error_msg = "Bases supported 10, 16; Actual: " + std::to_string(base);
throw spark_rapids_jni::cast_error(0, error_msg);
}

jni::auto_set_device(env);
auto const zero_scalar = numeric_scalar<uint64_t>(0);
auto const res_data_type = jni::make_data_type(j_dtype, 0);
auto const input_view{*reinterpret_cast<column_view const*>(input_column)};
auto const validity_regex_str = [&] {
switch (base) {
case 10: return R"(^\s*(-?[0-9]+).*)"; break;
case 16: return R"(^\s*(-?[0-9a-fA-F]+).*)"; break;
default: throw spark_rapids_jni::cast_error(0, "INFEASIBLE"); break;
}
}();

auto const validity_regex = strings::regex_program::create(validity_regex_str);
auto const valid_rows = strings::matches_re(input_view, *validity_regex);
auto const prepped_table = strings::extract(input_view, *validity_regex);
const strings_column_view prepped_view{prepped_table->get_column(0)};
auto int_col = [&] {
switch (base) {
case 10: {
return strings::to_integers(prepped_view, res_data_type);
} break;
case 16: {
auto const is_negative = strings::starts_with(prepped_view, string_scalar("-"));
auto const pos_vals = strings::hex_to_integers(prepped_view, res_data_type);
auto neg_vals =
binary_operation(zero_scalar, *pos_vals, binary_operator::SUB, res_data_type);
return copy_if_else(*neg_vals, *pos_vals, *is_negative);
}
default: {
throw spark_rapids_jni::cast_error(0, "INFEASIBLE");
break;
}
}
}();

auto unmatched_implies_zero = copy_if_else(*int_col, zero_scalar, *valid_rows);

// output nulls: original + all rows matching \s*

auto const space_only_regex = strings::regex_program::create(R"(^\s*$)");
auto const extra_null_rows = strings::matches_re(input_view, *space_only_regex);
auto const extra_mask = unary_operation(*extra_null_rows, unary_operator::NOT);

auto const original_mask = mask_to_bools(input_view.null_mask(), 0, input_view.size());
auto const new_mask = binary_operation(
*original_mask, *extra_mask, binary_operator::BITWISE_AND, data_type(type_id::BOOL8));

auto const [null_mask, null_count] = bools_to_mask(*new_mask);
unmatched_implies_zero->set_null_mask(*null_mask, null_count);
return jni::release_as_jlong(unmatched_implies_zero);
}
CATCH_CAST_EXCEPTION(env, 0);
}

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_fromIntegersWithBase(
JNIEnv* env, jclass, jlong input_column, jint base)
{
JNI_NULL_CHECK(env, input_column, "input column is null", 0);
using namespace cudf;
try {
jni::auto_set_device(env);
auto input_view{*reinterpret_cast<column_view const*>(input_column)};
auto result = [&] {
switch (base) {
case 10: {
return strings::from_integers(input_view);
} break;
case 16: {
auto pre_res = strings::integers_to_hex(input_view);
auto const regex = strings::regex_program::create("^0?([0-9a-fA-F]+)$");
auto const wo_leading_zeros = strings::extract(strings_column_view(*pre_res), *regex);
return std::move(wo_leading_zeros->release()[0]);
}
default: {
auto const error_msg = "Bases supported 10, 16; Actual: " + std::to_string(base);
throw spark_rapids_jni::cast_error(0, error_msg);
}
}
}();
return jni::release_as_jlong(result);
}
CATCH_CAST_EXCEPTION(env, 0);
}
}
4 changes: 2 additions & 2 deletions src/main/cpp/src/row_conversion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ struct tile_info {
*
*/
struct row_batch {
size_type num_bytes; // number of bytes in this batch
size_type row_count; // number of rows in the batch
size_type num_bytes; // number of bytes in this batch
size_type row_count; // number of rows in the batch
device_uvector<size_type> row_offsets; // offsets column of output cudf column
};

Expand Down
16 changes: 15 additions & 1 deletion src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public static ColumnVector toDecimal(ColumnView cv, boolean ansiMode, boolean st

/**
* Convert a decimal column to a string column.
*
*
* @param cv the column data to process
* @return the converted column
*/
Expand All @@ -102,10 +102,24 @@ public static ColumnVector toFloat(ColumnView cv, boolean ansiMode, DType type)
return new ColumnVector(toFloat(cv.getNativeView(), ansiMode, type.getTypeId().getNativeId()));
}


public static ColumnVector toIntegersWithBase(ColumnView cv, int base,
boolean ansiEnabled, DType type) {
return new ColumnVector(toIntegersWithBase(cv.getNativeView(), base, ansiEnabled,
type.getTypeId().getNativeId()));
}

public static ColumnVector fromIntegersWithBase(ColumnView cv, int base) {
return new ColumnVector(fromIntegersWithBase(cv.getNativeView(), base));
}

private static native long toInteger(long nativeColumnView, boolean ansi_enabled, boolean strip,
int dtype);
private static native long toDecimal(long nativeColumnView, boolean ansi_enabled, boolean strip,
int precision, int scale);
private static native long toFloat(long nativeColumnView, boolean ansi_enabled, int dtype);
private static native long fromDecimal(long nativeColumnView);
private static native long toIntegersWithBase(long nativeColumnView, int base,
boolean ansiEnabled, int dtype);
private static native long fromIntegersWithBase(long nativeColumnView, int base);
}
126 changes: 116 additions & 10 deletions src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,18 @@

package com.nvidia.spark.rapids.jni;

import ai.rapids.cudf.AssertUtils;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.Table;
import com.nvidia.spark.rapids.jni.CastException;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.stream.IntStream;
import java.util.ArrayList;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;

import ai.rapids.cudf.AssertUtils;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.Table;

public class CastStringsTest {
@Test
Expand Down Expand Up @@ -194,4 +192,112 @@ void castToDecimalNoStripTest() {
}
}
}


@Test
void baseDec2HexTest() {
try(
Table input = new Table.TestBuilder().column(
null,
" ",
"junk-510junk510",
"--510",
" -510junk510",
" 510junk510",
"510",
"00510",
"00-510"
).build();

Table expected = new Table.TestBuilder().column(
null,
null,
"0",
"0",
"18446744073709551106",
"510",
"510",
"510",
"0"
).column(
null,
null,
"0",
"0",
"FFFFFFFFFFFFFE02",
"1FE",
"1FE",
"1FE",
"0"
).build();

ColumnVector intCol = CastStrings.toIntegersWithBase(input.getColumn(0), 10, false,
DType.UINT64);
ColumnVector decStrCol = CastStrings.fromIntegersWithBase(intCol, 10);
ColumnVector hexStrCol = CastStrings.fromIntegersWithBase(intCol, 16);
) {
ai.rapids.cudf.TableDebug.get().debug("intCol", intCol);
AssertUtils.assertColumnsAreEqual(expected.getColumn(0), decStrCol, "decStrCol");
AssertUtils.assertColumnsAreEqual(expected.getColumn(1), hexStrCol, "hexStrCol");
}
}

@Test
void baseHex2DecTest() {
try(
Table input = new Table.TestBuilder().column(
null,
"junk",
"0",
"f",
"junk-5Ajunk5A",
"--5A",
" -5Ajunk5A",
" 5Ajunk5A",
"5a",
"05a",
"005a",
"00-5a",
"NzGGImWNRh"
).build();

Table expected = new Table.TestBuilder().column(
null,
"0",
"0",
"15",
"0",
"0",
"18446744073709551526",
"90",
"90",
"90",
"90",
"0",
"0"
).column(
null,
"0",
"0",
"F",
"0",
"0",
"FFFFFFFFFFFFFFA6",
"5A",
"5A",
"5A",
"5A",
"0",
"0"
).build();

ColumnVector intCol = CastStrings.toIntegersWithBase(input.getColumn(0), 16, false, DType.UINT64);
ColumnVector decStrCol = CastStrings.fromIntegersWithBase(intCol, 10);
ColumnVector hexStrCol = CastStrings.fromIntegersWithBase(intCol, 16);
) {
ai.rapids.cudf.TableDebug.get().debug("intCol", intCol);
AssertUtils.assertColumnsAreEqual(expected.getColumn(0), decStrCol, "decStrCol");
AssertUtils.assertColumnsAreEqual(expected.getColumn(1), hexStrCol, "hexStrCol");
}
}
}

0 comments on commit 72cb837

Please sign in to comment.