Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
gerashegalov committed Aug 5, 2023
1 parent 7420b3f commit 3752093
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 19 deletions.
60 changes: 44 additions & 16 deletions src/main/cpp/src/CastStringJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
*/

#include "cast_string.hpp"
#include <cudf/replace.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/strings/convert/convert_integers.hpp>
#include <cudf/strings/strip.hpp>
#include <cudf/strings/strings_column_view.hpp>

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
Expand Down Expand Up @@ -113,41 +117,65 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_fromDecimal
CATCH_CAST_EXCEPTION(env, 0);
}

// NOTE(review): this span is a rendered commit diff without +/- markers. Lines from
// the removed changeRadix(fromRadix, toRadix) entry point and from the added
// toIntegerUsingBase(base) entry point appear interleaved, so it does not compile
// exactly as shown here.
//
// JNI entry point: converts a strings column to a UINT64 column using `base`.
//   input_column - address of a cudf::column_view (reinterpret_cast below)
//   base         - 10 uses cudf::strings::to_integers, 16 uses
//                  cudf::strings::hex_to_integers; any other value currently
//                  produces a null std::unique_ptr<cudf::column>
// Null rows in the converted column are replaced with a uint64_t zero scalar and
// the resulting column is returned through cudf::jni::release_as_jlong.
// Presumably JNI_NULL_CHECK / CATCH_CAST_EXCEPTION raise a Java exception and
// return 0 on failure -- TODO confirm against the macro definitions.
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_changeRadix(
JNIEnv* env, jclass, jlong input_column, jint fromRadix, jint toRadix)
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_toIntegerUsingBase(
JNIEnv* env, jclass, jlong input_column, jint base)
{
JNI_NULL_CHECK(env, input_column, "input column is null", 0);

try {
cudf::jni::auto_set_device(env);

cudf::column_view input_view{*reinterpret_cast<cudf::column_view const*>(input_column)};
auto integer_view = [&] {
switch (fromRadix) {
// The handle passed from Java is interpreted as a pointer to a column_view and copied.
auto input_view{*reinterpret_cast<cudf::column_view const*>(input_column)};
// Immediately-invoked lambda so each switch arm can return the converted column.
auto integer_view_with_nulls = [&] {
switch (base) {
case 10: {
return cudf::strings::from_integers(input_view);
return cudf::strings::to_integers(input_view, cudf::data_type(cudf::type_id::UINT64));
} break;
case 16: {
return cudf::strings::hex_to_integers(input_view, cudf::data_type(cudf::type_id::UINT64));
}
default: {
// NOTE(review): this null pointer is dereferenced by `*integer_view_with_nulls`
// below, so an unsupported base must be handled before this ships.
return std::unique_ptr<cudf::column>(nullptr); // TODO all zeros
}
}
// NOTE(review): appears unreachable when every switch arm returns -- confirm.
return std::unique_ptr<cudf::column>(nullptr);
}();

auto result_col = [&] {
switch (toRadix) {
case 16: {
return cudf::strings::integers_to_hex(*integer_view);
} break;
// Replace null rows with 0 so the returned column carries a value for every row.
cudf::numeric_scalar<uint64_t> zero(0);
auto integer_view = cudf::replace_nulls(*integer_view_with_nulls, zero);
return cudf::jni::release_as_jlong(integer_view);
}
CATCH_CAST_EXCEPTION(env, 0);
}


// NOTE(review): this span is a rendered commit diff without +/- markers; a few
// lines that reference `integer_view` / `result_col` look like removed lines shown
// next to their replacements, so it does not compile exactly as shown here.
//
// JNI entry point: converts an integer column to a strings column using `base`.
//   input_column - address of a cudf::column_view (reinterpret_cast below)
//   base         - 10 uses cudf::strings::from_integers, 16 uses
//                  cudf::strings::integers_to_hex followed by a left strip of "0";
//                  any other value currently produces a null
//                  std::unique_ptr<cudf::column>
// The resulting column is returned through cudf::jni::release_as_jlong.
// Presumably JNI_NULL_CHECK / CATCH_CAST_EXCEPTION raise a Java exception and
// return 0 on failure -- TODO confirm against the macro definitions.
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_fromIntegerUsingBase(
JNIEnv* env, jclass, jlong input_column, jint base)
{
JNI_NULL_CHECK(env, input_column, "input column is null", 0);

try {
cudf::jni::auto_set_device(env);

// The handle passed from Java is interpreted as a pointer to a column_view and copied.
auto input_view{*reinterpret_cast<cudf::column_view const*>(input_column)};
// Immediately-invoked lambda so each switch arm can return the converted column.
auto result = [&] {
switch (base) {
case 10: {
return cudf::strings::from_integers(*integer_view);
return cudf::strings::from_integers(input_view);
} break;
case 16: {
// Strip "0" characters from the left of the hex text produced by integers_to_hex.
// NOTE(review): a value whose hex text consists only of zeros would presumably be
// stripped to an empty string -- confirm and special-case if so.
auto hex_with_leading_zeros = cudf::strings::integers_to_hex(input_view);
return cudf::strings::strip(
cudf::strings_column_view(*hex_with_leading_zeros),
cudf::strings::side_type::LEFT, cudf::string_scalar("0"));
}
default: {
// NOTE(review): a null column pointer is handed to release_as_jlong for any
// unsupported base; confirm how the Java side treats the resulting handle.
return std::unique_ptr<cudf::column>(nullptr); // TODO all zeros
}
}
// NOTE(review): appears unreachable when every switch arm returns -- confirm.
return std::unique_ptr<cudf::column>(nullptr);
}();

return cudf::jni::release_as_jlong(result_col);
return cudf::jni::release_as_jlong(result);
}
CATCH_CAST_EXCEPTION(env, 0);
}

}
11 changes: 8 additions & 3 deletions src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,12 @@ public static ColumnVector toFloat(ColumnView cv, boolean ansiMode, DType type)
}


public static ColumnVector changeRadix(ColumnView cv, int fromRadix, int toRadix) {
return new ColumnVector(changeRadix(cv.getNativeView(), fromRadix, toRadix));
/**
 * Calls the native {@code toIntegerUsingBase} with the column's native view and
 * the given base, and wraps the returned handle in a new {@link ColumnVector}.
 *
 * @param cv   column whose native view is passed to the native call
 * @param base numeric base forwarded unchanged to the native call
 * @return a new ColumnVector built from the handle the native call returns
 */
public static ColumnVector toIntegerUsingBase(ColumnView cv, int base) {
  final long nativeView = cv.getNativeView();
  final long resultHandle = toIntegerUsingBase(nativeView, base);
  return new ColumnVector(resultHandle);
}

/**
 * Calls the native {@code fromIntegerUsingBase} with the column's native view and
 * the given base, and wraps the returned handle in a new {@link ColumnVector}.
 *
 * @param cv   column whose native view is passed to the native call
 * @param base numeric base forwarded unchanged to the native call
 * @return a new ColumnVector built from the handle the native call returns
 */
public static ColumnVector fromIntegerUsingBase(ColumnView cv, int base) {
  final long nativeView = cv.getNativeView();
  final long resultHandle = fromIntegerUsingBase(nativeView, base);
  return new ColumnVector(resultHandle);
}

private static native long toInteger(long nativeColumnView, boolean ansi_enabled, boolean strip,
Expand All @@ -113,5 +117,6 @@ private static native long toDecimal(long nativeColumnView, boolean ansi_enabled
int precision, int scale);
private static native long toFloat(long nativeColumnView, boolean ansi_enabled, int dtype);
private static native long fromDecimal(long nativeColumnView);
private static native long changeRadix(long nativeColumnView, int fromRadix, int toRadix);
private static native long toIntegerUsingBase(long nativeColumnView, int base);
private static native long fromIntegerUsingBase(long nativeColumnView, int base);
}

0 comments on commit 3752093

Please sign in to comment.