Add a protocol buffer decode kernel with limited features #4107
Signed-off-by: Haoyang Li <[email protected]>
Pull request overview
This PR adds a GPU-accelerated protocol buffer decoder with intentionally limited features, focusing on simple scalar field types. The implementation provides a JNI interface for decoding binary protobuf messages into cuDF STRUCT columns.
Key changes:
- Implements GPU kernels for decoding protobuf varint, fixed32/64, and length-delimited (string) fields
- Adds JNI bindings between Java and CUDA implementation
- Provides basic test coverage for INT64 and STRING field types
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java | Java API providing decodeToStruct() method with parameter validation |
| src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java | Basic test case covering varint (INT64) and string decoding with missing fields and null messages |
| src/main/cpp/src/protobuf_simple.hpp | C++ API declaration with documentation of supported types |
| src/main/cpp/src/protobuf_simple.cu | CUDA implementation with three specialized kernels for varint, fixed-width, and string extraction |
| src/main/cpp/src/ProtobufSimpleJni.cpp | JNI bridge translating Java arrays to C++ vectors and invoking decode logic |
| src/main/cpp/CMakeLists.txt | Build configuration adding new source files to compilation targets |
| break; | ||
| } | ||
| cur += n; | ||
| value = static_cast<OutT>(v); |
Copilot
AI
Dec 23, 2025
Missing zigzag decoding for signed integers: The test comment says "int64 id = 1" which suggests a signed int64 field type. Protocol buffer's signed integer types (sint32, sint64) use zigzag encoding where the value must be decoded as (n >>> 1) ^ -(n & 1) after varint decoding. The current implementation only performs raw varint decoding and casts to the output type, which will produce incorrect results for negative values when using sint32 or sint64 field types. Note that if the protobuf definition uses int32/int64 (not sint32/sint64), then no zigzag decoding is needed, but the API documentation should clarify which protobuf types are supported.
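To make the distinction concrete, here is a minimal host-side sketch of zigzag decoding as it would apply to `sint32` / `sint64` fields. The function names are hypothetical and not taken from the PR; the point is only that the shift and the negation are both performed on the unsigned value, so the shift is logical and the negation wraps.

```cpp
#include <cassert>
#include <cstdint>

// ZigZag decode for protobuf sint64: maps 0, 1, 2, 3, ... to 0, -1, 1, -2, ...
// `n` is the raw value produced by varint decoding. Unary minus on an unsigned
// value wraps, so -(n & 1) is either 0 or all ones.
constexpr int64_t zigzag_decode64(uint64_t n)
{
  return static_cast<int64_t>((n >> 1) ^ (-(n & 1)));
}

// Same idea for sint32.
constexpr int32_t zigzag_decode32(uint32_t n)
{
  return static_cast<int32_t>((n >> 1) ^ (-(n & 1)));
}
```

Plain `int32` / `int64` fields skip this step entirely: the varint payload is the two's complement value, so a negative `int64` occupies ten bytes on the wire.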
src/main/cpp/src/protobuf_simple.hpp
Outdated
| * Supported output child types: | ||
| * - BOOL8, INT32, INT64, FLOAT32, FLOAT64, STRING | ||
| * |
Copilot
AI
Dec 23, 2025
Missing documentation for supported protobuf types: The API documentation lists supported output types (BOOL8, INT32, INT64, FLOAT32, FLOAT64, STRING) but doesn't clarify which protobuf wire types or field types they map to. For example, does INT32/INT64 support sint32/sint64 (zigzag encoded) or only int32/int64 (standard varint)? Does it support fixed32/fixed64 or only varint encoding? This information is critical for correct usage.
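For readers less familiar with the wire format, the sketch below shows how a field key splits into a field number and a wire type, which is what determines whether a varint, fixed-width, or length-delimited decoder applies. The `WT_*` names mirror the constants used in the kernels, but their numeric values here come from the protobuf encoding specification rather than from the PR, and `split_key` / `field_key` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Wire types from the protobuf encoding spec (deprecated group types 3 and 4 omitted).
enum wire_type_id : int { WT_VARINT = 0, WT_64BIT = 1, WT_LEN = 2, WT_32BIT = 5 };

struct field_key {
  int fn;  // field number
  int wt;  // wire type
};

// A key is a varint whose low 3 bits hold the wire type and whose
// remaining bits hold the field number.
constexpr field_key split_key(uint64_t key)
{
  return field_key{static_cast<int>(key >> 3), static_cast<int>(key & 0x7)};
}
```

With this split, the test's key byte `0x08` is field 1 with a varint payload and `0x12` is field 2 with a length-delimited payload.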
| * Supported output child types: | |
| * - BOOL8, INT32, INT64, FLOAT32, FLOAT64, STRING | |
| * | |
| * Supported output child types (cudf dtypes) and corresponding protobuf field types: | |
| * - BOOL8 : protobuf `bool` (varint wire type) | |
| * - INT32 : protobuf `int32` (standard varint encoding) | |
| * - INT64 : protobuf `int64` (standard varint encoding) | |
| * - FLOAT32 : protobuf `float` (fixed32 wire type) | |
| * - FLOAT64 : protobuf `double` (fixed64 wire type) | |
| * - STRING : protobuf `string` (length-delimited wire type, UTF-8 text) | |
| * | |
| * Integer handling: | |
| * - For INT32/INT64, only the standard varint-encoded `int32` / `int64` field types are supported. | |
| * - Zigzag-encoded `sint32` / `sint64`, and fixed-width `fixed32` / `fixed64` / `sfixed32` / `sfixed64` | |
| * field types are not interpreted and must not be used for INT32/INT64 outputs. | |
| * |
| @Test | ||
| void decodeVarintAndStringToStruct() { | ||
| // message Msg { int64 id = 1; string name = 2; } | ||
| // Row0: id=100, name="alice" | ||
| Byte[] row0 = concat( | ||
| new Byte[]{(byte) 0x08}, // field 1, varint | ||
| box(encodeVarint(100)), | ||
| new Byte[]{(byte) 0x12}, // field 2, len-delimited | ||
| box(encodeVarint(5)), | ||
| box("alice".getBytes())); | ||
|
|
||
| // Row1: id=200, name missing | ||
| Byte[] row1 = concat( | ||
| new Byte[]{(byte) 0x08}, | ||
| box(encodeVarint(200))); | ||
|
|
||
| // Row2: null input message | ||
| Byte[] row2 = null; | ||
|
|
||
| try (Table input = new Table.TestBuilder().column(row0, row1, row2).build(); | ||
| ColumnVector expectedId = ColumnVector.fromBoxedLongs(100L, 200L, null); | ||
| ColumnVector expectedName = ColumnVector.fromStrings("alice", null, null); | ||
| ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); | ||
| ColumnVector actualStruct = ProtobufSimple.decodeToStruct( | ||
| input.getColumn(0), | ||
| new int[]{1, 2}, | ||
| new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, | ||
| new int[]{0, 0})) { | ||
| AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); | ||
| } | ||
| } |
Copilot
AI
Dec 23, 2025
Incomplete test coverage for error conditions: The test only covers the happy path with valid protobuf data, missing fields, and null messages. It should also test error cases like malformed varint encoding, truncated messages, incorrect wire types for fields, and overflow conditions to verify the error handling behaves correctly.
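As an illustration of the malformed-input cases worth covering, here is a host-side sketch of a bounds-checked varint reader. It borrows the name and call shape of the kernel's `read_varint(cur, stop, value, bytes)`, but the body is an assumption written for this example, not the PR's implementation.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side counterpart of the kernel's read_varint helper.
// Returns false when the input is truncated, longer than 10 bytes, or
// would overflow 64 bits.
inline bool read_varint(uint8_t const* cur, uint8_t const* end, uint64_t& out, int& bytes)
{
  out       = 0;
  bytes     = 0;
  int shift = 0;
  while (cur < end && bytes < 10) {
    uint8_t const b = *cur++;
    ++bytes;
    // The 10th byte may only carry bit 63 of the value and must terminate.
    if (bytes == 10 && (b & 0xFE) != 0) { return false; }
    out |= static_cast<uint64_t>(b & 0x7F) << shift;
    if ((b & 0x80) == 0) { return true; }
    shift += 7;
  }
  return false;  // ran out of bytes before seeing a terminating byte
}
```

Error-path tests could then feed inputs such as a lone `0x80` (truncated), ten `0xFF` bytes (overlong), or an empty buffer, and assert that decoding reports failure instead of producing a value.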
| @Test | ||
| void decodeVarintAndStringToStruct() { | ||
| // message Msg { int64 id = 1; string name = 2; } | ||
| // Row0: id=100, name="alice" | ||
| Byte[] row0 = concat( | ||
| new Byte[]{(byte) 0x08}, // field 1, varint | ||
| box(encodeVarint(100)), | ||
| new Byte[]{(byte) 0x12}, // field 2, len-delimited | ||
| box(encodeVarint(5)), | ||
| box("alice".getBytes())); | ||
|
|
||
| // Row1: id=200, name missing | ||
| Byte[] row1 = concat( | ||
| new Byte[]{(byte) 0x08}, | ||
| box(encodeVarint(200))); | ||
|
|
||
| // Row2: null input message | ||
| Byte[] row2 = null; | ||
|
|
||
| try (Table input = new Table.TestBuilder().column(row0, row1, row2).build(); | ||
| ColumnVector expectedId = ColumnVector.fromBoxedLongs(100L, 200L, null); | ||
| ColumnVector expectedName = ColumnVector.fromStrings("alice", null, null); | ||
| ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); | ||
| ColumnVector actualStruct = ProtobufSimple.decodeToStruct( | ||
| input.getColumn(0), | ||
| new int[]{1, 2}, | ||
| new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, | ||
| new int[]{0, 0})) { | ||
| AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); | ||
| } | ||
| } |
Copilot
AI
Dec 23, 2025
Incomplete test coverage for data types: The test only covers INT64 and STRING types. It should also include test cases for BOOL8, INT32, FLOAT32, and FLOAT64 to ensure all supported types work correctly, especially the fixed-width types that use different wire types and extraction kernels.
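To show what input for the fixed-width cases could look like, the sketch below builds the wire bytes for a `float` and a `double` field. It is written in C++ for consistency with the native sources; the helper names are hypothetical, it assumes IEEE-754 floating point, and it assumes field numbers below 16 so the key fits in one byte. A Java test could produce the same byte sequences.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Append a protobuf `float` field: one key byte with wire type 5 (32-bit),
// followed by the 4 value bytes in little-endian order.
inline void append_float_field(std::vector<uint8_t>& buf, uint32_t field_number, float value)
{
  buf.push_back(static_cast<uint8_t>((field_number << 3) | 5u));
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  for (int i = 0; i < 4; ++i) { buf.push_back(static_cast<uint8_t>(bits >> (8 * i))); }
}

// Append a protobuf `double` field: one key byte with wire type 1 (64-bit),
// followed by the 8 value bytes in little-endian order.
inline void append_double_field(std::vector<uint8_t>& buf, uint32_t field_number, double value)
{
  buf.push_back(static_cast<uint8_t>((field_number << 3) | 1u));
  uint64_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  for (int i = 0; i < 8; ++i) { buf.push_back(static_cast<uint8_t>(bits >> (8 * i))); }
}
```

For example, field 3 set to `1.5f` encodes as `1D 00 00 C0 3F`, which exercises the fixed32 kernel path rather than the varint one.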
src/main/cpp/src/protobuf_simple.cu
Outdated
| *error_flag = 1; | ||
| out_pairs[row] = cudf::strings::detail::string_index_pair{nullptr, 0}; | ||
| return; | ||
| } | ||
| uint8_t const* cur = bytes + start; | ||
| uint8_t const* stop = bytes + end; | ||
|
|
||
| cudf::strings::detail::string_index_pair pair{nullptr, 0}; | ||
| while (cur < stop) { | ||
| uint64_t key; | ||
| int key_bytes; | ||
| if (!read_varint(cur, stop, key, key_bytes)) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| cur += key_bytes; | ||
| int fn = static_cast<int>(key >> 3); | ||
| int wt = static_cast<int>(key & 0x7); | ||
| if (fn == field_number) { | ||
| if (wt != WT_LEN) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| uint64_t len64; | ||
| int n; | ||
| if (!read_varint(cur, stop, len64, n)) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| cur += n; | ||
| if (len64 > static_cast<uint64_t>(stop - cur)) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| pair.first = reinterpret_cast<char const*>(cur); | ||
| pair.second = static_cast<cudf::size_type>(len64); | ||
| cur += static_cast<int>(len64); | ||
| // Continue scanning to allow "last one wins". | ||
| } else { | ||
| uint8_t const* next; | ||
| if (!skip_field(cur, stop, wt, next)) { | ||
| *error_flag = 1; |
Copilot
AI
Dec 23, 2025
Race condition: Multiple threads may write to the same error_flag memory location simultaneously without synchronization. This should use atomicOr or atomicExch to ensure thread-safe updates when multiple threads detect errors concurrently.
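The pattern being suggested can be illustrated on the host with `std::atomic`. This is an analogy only: each `std::thread` stands in for a GPU thread and `fetch_or` plays the role that `atomicOr(error_flag, 1)` would play in the kernels. The function name and the row-status input are hypothetical.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Scan rows from several threads and record whether any row is bad.
// Every writer stores the same value, but going through an atomic
// read-modify-write keeps the concurrent updates well-defined.
inline bool any_row_failed(std::vector<int> const& row_status, unsigned num_threads)
{
  std::atomic<int> error_flag{0};
  std::vector<std::thread> workers;
  for (unsigned t = 0; t < num_threads; ++t) {
    workers.emplace_back([&, t] {
      // Strided loop: thread t handles rows t, t + num_threads, ...
      for (std::size_t row = t; row < row_status.size(); row += num_threads) {
        if (row_status[row] != 0) { error_flag.fetch_or(1, std::memory_order_relaxed); }
      }
    });
  }
  for (auto& w : workers) { w.join(); }
  return error_flag.load() != 0;
}
```

Relaxed ordering is enough in this sketch because the flag is only read after all workers have been joined, which is comparable to reading the device flag after the stream has synchronized.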
| for (std::size_t i = 0; i < out_types.size(); ++i) { | ||
| auto const fn = field_numbers[i]; | ||
| auto const dt = out_types[i]; | ||
| switch (dt.id()) { | ||
| case cudf::type_id::BOOL8: { | ||
| rmm::device_uvector<uint8_t> out(rows, stream, mr); | ||
| rmm::device_uvector<bool> valid(rows, stream, mr); | ||
| extract_varint_kernel<uint8_t><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); | ||
| children.push_back( | ||
| std::make_unique<cudf::column>(dt, rows, out.release(), std::move(mask), null_count)); | ||
| break; | ||
| } | ||
| case cudf::type_id::INT32: { | ||
| rmm::device_uvector<int32_t> out(rows, stream, mr); | ||
| rmm::device_uvector<bool> valid(rows, stream, mr); | ||
| extract_varint_kernel<int32_t><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); | ||
| children.push_back( | ||
| std::make_unique<cudf::column>(dt, rows, out.release(), std::move(mask), null_count)); | ||
| break; | ||
| } | ||
| case cudf::type_id::INT64: { | ||
| rmm::device_uvector<int64_t> out(rows, stream, mr); | ||
| rmm::device_uvector<bool> valid(rows, stream, mr); | ||
| extract_varint_kernel<int64_t><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); | ||
| children.push_back( | ||
| std::make_unique<cudf::column>(dt, rows, out.release(), std::move(mask), null_count)); | ||
| break; | ||
| } | ||
| case cudf::type_id::FLOAT32: { | ||
| rmm::device_uvector<float> out(rows, stream, mr); | ||
| rmm::device_uvector<bool> valid(rows, stream, mr); | ||
| extract_fixed_kernel<float, WT_32BIT><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); | ||
| children.push_back( | ||
| std::make_unique<cudf::column>(dt, rows, out.release(), std::move(mask), null_count)); | ||
| break; | ||
| } | ||
| case cudf::type_id::FLOAT64: { | ||
| rmm::device_uvector<double> out(rows, stream, mr); | ||
| rmm::device_uvector<bool> valid(rows, stream, mr); | ||
| extract_fixed_kernel<double, WT_64BIT><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); | ||
| children.push_back( | ||
| std::make_unique<cudf::column>(dt, rows, out.release(), std::move(mask), null_count)); | ||
| break; | ||
| } | ||
| case cudf::type_id::STRING: { | ||
| rmm::device_uvector<cudf::strings::detail::string_index_pair> pairs(rows, stream, mr); | ||
| extract_string_kernel<<<blocks, threads, 0, stream.value()>>>(*d_in, fn, pairs.data(), d_error.data()); | ||
| children.push_back(cudf::strings::detail::make_strings_column( | ||
| pairs.begin(), pairs.end(), stream, mr)); | ||
| break; | ||
| } | ||
| default: CUDF_FAIL("Unsupported output type for protobuf_simple"); | ||
| } | ||
| } |
Copilot
AI
Dec 23, 2025
Missing kernel launch error check: After launching CUDA kernels, the code should check for kernel launch errors using CUDF_CUDA_TRY(cudaPeekAtLastError()) to catch configuration errors or launch failures immediately. Currently, errors would only be detected later during the memcpy synchronization, making debugging more difficult.
src/main/cpp/src/protobuf_simple.cu
Outdated
| auto end = in.offset_at(row + 1) - base; | ||
| // Defensive bounds checks: if offsets are inconsistent, avoid illegal memory access. | ||
| if (start < 0 || end < start || end > child.size()) { | ||
| *error_flag = 1; |
Copilot
AI
Dec 23, 2025
Race condition: Multiple threads may write to the same error_flag memory location simultaneously without synchronization. This should use atomicOr or atomicExch to ensure thread-safe updates when multiple threads detect errors concurrently.
| *error_flag = 1; | |
| atomicExch(error_flag, 1); |
src/main/cpp/src/protobuf_simple.cu
Outdated
| *error_flag = 1; | ||
| break; | ||
| } | ||
| cur += key_bytes; | ||
| int fn = static_cast<int>(key >> 3); | ||
| int wt = static_cast<int>(key & 0x7); | ||
| if (fn == field_number) { | ||
| if (wt != WT_VARINT) { | ||
| *error_flag = 1; |
Copilot
AI
Dec 23, 2025
Race condition: Multiple threads may write to the same error_flag memory location simultaneously without synchronization. This should use atomicOr or atomicExch to ensure thread-safe updates when multiple threads detect errors concurrently.
src/main/cpp/src/protobuf_simple.cu
Outdated
| *error_flag = 1; | ||
| break; | ||
| } | ||
| cur += n; | ||
| value = static_cast<OutT>(v); | ||
| found = true; | ||
| // Continue scanning to allow "last one wins" semantics. | ||
| } else { | ||
| uint8_t const* next; | ||
| if (!skip_field(cur, stop, wt, next)) { | ||
| *error_flag = 1; |
Copilot
AI
Dec 23, 2025
Race condition: Multiple threads may write to the same error_flag memory location simultaneously without synchronization. This should use atomicOr or atomicExch to ensure thread-safe updates when multiple threads detect errors concurrently.
src/main/cpp/src/protobuf_simple.cu
Outdated
| *error_flag = 1; | ||
| valid[row] = false; | ||
| return; | ||
| } | ||
| uint8_t const* cur = bytes + start; | ||
| uint8_t const* stop = bytes + end; | ||
|
|
||
| bool found = false; | ||
| OutT value{}; | ||
| while (cur < stop) { | ||
| uint64_t key; | ||
| int key_bytes; | ||
| if (!read_varint(cur, stop, key, key_bytes)) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| cur += key_bytes; | ||
| int fn = static_cast<int>(key >> 3); | ||
| int wt = static_cast<int>(key & 0x7); | ||
| if (fn == field_number) { | ||
| if (wt != WT) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| if constexpr (WT == WT_32BIT) { | ||
| if (stop - cur < 4) { *error_flag = 1; break; } | ||
| uint32_t raw = load_le<uint32_t>(cur); | ||
| cur += 4; | ||
| value = *reinterpret_cast<OutT*>(&raw); | ||
| } else { | ||
| if (stop - cur < 8) { *error_flag = 1; break; } | ||
| uint64_t raw = load_le<uint64_t>(cur); | ||
| cur += 8; | ||
| value = *reinterpret_cast<OutT*>(&raw); | ||
| } | ||
| found = true; | ||
| } else { | ||
| uint8_t const* next; | ||
| if (!skip_field(cur, stop, wt, next)) { | ||
| *error_flag = 1; |
Copilot
AI
Dec 23, 2025
Race condition: Multiple threads may write to the same error_flag memory location simultaneously without synchronization. This should use atomicOr or atomicExch to ensure thread-safe updates when multiple threads detect errors concurrently.
Signed-off-by: Haoyang Li <[email protected]>
@greptile full review
Greptile Summary
Added a GPU-accelerated protobuf decoder that extracts scalar fields from binary protobuf messages into cuDF STRUCT columns.
Key Changes:
Issues Found:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Java as Java Client
participant JNI as ProtobufSimpleJni
participant Native as decode_protobuf_simple_to_struct
participant Kernel as CUDA Kernels
Java->>JNI: decodeToStruct(binaryInput, fieldNumbers, typeIds, typeScales)
JNI->>JNI: Validate parameters (null checks, array length)
JNI->>JNI: Convert Java arrays to native vectors
JNI->>JNI: Create data_type objects from typeIds/scales
JNI->>Native: decode_protobuf_simple_to_struct()
Native->>Native: Validate input column type (LIST<INT8/UINT8>)
Native->>Native: Create device column view
Native->>Native: Allocate error flag device vector
loop For each output field
Native->>Native: Determine field type and encoding
alt Varint types (BOOL8, INT32/64, UINT32/64)
Native->>Kernel: Launch extract_varint_kernel
Kernel->>Kernel: Parse protobuf key-value pairs
Kernel->>Kernel: Decode varint with optional zigzag
Kernel->>Kernel: Write output and validity
else Fixed types (FLOAT32/64, fixed32/64)
Native->>Kernel: Launch extract_fixed_kernel
Kernel->>Kernel: Parse protobuf key-value pairs
Kernel->>Kernel: Read fixed-width little-endian value
Kernel->>Kernel: Write output and validity
else String/Bytes types
Native->>Kernel: Launch extract_string_kernel
Kernel->>Kernel: Parse protobuf key-value pairs
Kernel->>Kernel: Extract length-delimited data
Kernel->>Kernel: Write string_index_pairs
end
Native->>Native: Create null mask from validity buffer
Native->>Native: Add child column to results
end
Native->>Native: Check error flag (fail if errors & failOnErrors)
Native->>Native: Create STRUCT column from children
Native->>JNI: Return column pointer
JNI->>Java: Return ColumnVector handle
Additional Comments (6)
- src/main/cpp/src/protobuf_simple.cu, line 89-91: logic: potential overflow: `len64` can be up to 2^64-1, but casting to `int` on line 90 can overflow if `len64 > INT_MAX`
- src/main/cpp/src/protobuf_simple.cu, line 323-324: logic: potential overflow: `len64` can be larger than `INT_MAX`, but casting to `int` on line 324 will overflow
- src/main/cpp/src/protobuf_simple.cu, line 375-376: logic: race condition: multiple threads write to `*error_flag` without atomics, causing undefined behavior when multiple threads encounter errors simultaneously. Then in kernels, use `atomicOr(error_flag, 1)` instead of `*error_flag = 1`
- src/main/cpp/src/protobuf_simple.cu, line 398-407: logic: protobuf uses zigzag encoding for signed integers (sint32/sint64), but varint decoding here treats them as unsigned - decoding negative values will produce incorrect results. Are you only supporting unsigned int32/int64, or should zigzag decoding be implemented for signed types?
- src/main/cpp/src/protobuf_simple.cu, line 240: syntax: type punning through `reinterpret_cast` of incompatible pointer types is undefined behavior in C++
- src/main/cpp/src/protobuf_simple.cu, line 248: syntax: type punning through `reinterpret_cast` of incompatible pointer types is undefined behavior in C++
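For the type-punning comments, one commonly used alternative is to copy the bytes with `memcpy`, which compilers typically lower to a plain register move. The sketch below is host-side C++ under the assumption of IEEE-754 floats. `load_le` borrows its name from the kernel helper but its body is written for this example, and `bit_cast_via_memcpy` is a hypothetical name. In device code, intrinsics such as `__uint_as_float` and `__longlong_as_double` may be another option.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

// Read a little-endian unsigned integer byte by byte, independent of host
// endianness and of the alignment of `p`.
template <typename T>
inline T load_le(uint8_t const* p)
{
  T v = 0;
  for (std::size_t i = 0; i < sizeof(T); ++i) { v |= static_cast<T>(p[i]) << (8 * i); }
  return v;
}

// Reinterpret the bits of `from` as type To without violating aliasing rules.
template <typename To, typename From>
inline To bit_cast_via_memcpy(From const& from)
{
  static_assert(sizeof(To) == sizeof(From), "source and target sizes must match");
  static_assert(std::is_trivially_copyable<To>::value && std::is_trivially_copyable<From>::value,
                "both types must be trivially copyable");
  To to;
  std::memcpy(&to, &from, sizeof(To));
  return to;
}
```

In the fixed-width kernel this would replace `value = *reinterpret_cast<OutT*>(&raw);` with something along the lines of `value = bit_cast_via_memcpy<OutT>(raw);`.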
6 files reviewed, 6 comments
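The two overflow comments on `len64` could be addressed by validating the length against both the remaining bytes and the 32-bit range before narrowing. A minimal sketch follows, with a hypothetical helper name and with `int32_t` standing in for `cudf::size_type` (which is a 32-bit signed integer).

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Validate a length-delimited payload before narrowing its length.
// Requires cur <= stop. Returns true and writes the narrowed length only when
// the payload fits in the remaining bytes and in the int32_t range.
inline bool checked_length(uint64_t len64, uint8_t const* cur, uint8_t const* stop, int32_t& out_len)
{
  auto const remaining = static_cast<uint64_t>(stop - cur);
  if (len64 > remaining) { return false; }
  if (len64 > static_cast<uint64_t>(std::numeric_limits<int32_t>::max())) { return false; }
  out_len = static_cast<int32_t>(len64);
  return true;
}
```

The pointer can then be advanced by the checked value rather than by `static_cast<int>(len64)`. If a row can never exceed the 32-bit range, the existing comparison against `stop - cur` already bounds `len64`, and the second check is purely defensive.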
Signed-off-by: Haoyang Li <[email protected]>
Signed-off-by: Haoyang Li <[email protected]>
@greptile full review
Additional Comments (2)
- src/main/cpp/src/protobuf_simple.cu, line 186: logic: zigzag decode uses signed right shift which is correct for signed types, but applied to unsigned `v`
- src/main/cpp/src/ProtobufSimpleJni.cpp, line 55: logic: `encodings` is constructed from `n_type_scales` but then `out_types` is also constructed using `n_type_scales[i]` as the scale parameter, which would be wrong for non-decimal types where this represents encoding
6 files reviewed, 2 comments
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 13 comments.
src/main/cpp/src/protobuf_simple.cu
Outdated
| auto end = in.offset_at(row + 1) - base; | ||
| // Defensive bounds checks: if offsets are inconsistent, avoid illegal memory access. | ||
| if (start < 0 || end < start || end > child.size()) { | ||
| *error_flag = 1; |
Copilot
AI
Dec 25, 2025
The error_flag is accessed from multiple threads without atomic operations. When multiple threads encounter errors simultaneously, the non-atomic write (*error_flag = 1) can lead to a race condition. While the final outcome (error_flag being set to 1) may be correct, this violates best practices for concurrent access. Consider using atomicOr or atomicExch to safely update the shared error flag.
src/main/cpp/src/protobuf_simple.cu
Outdated
| *error_flag = 1; | ||
| break; | ||
| } | ||
| cur += key_bytes; | ||
| int fn = static_cast<int>(key >> 3); | ||
| int wt = static_cast<int>(key & 0x7); | ||
| if (fn == 0) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| if (fn == field_number) { | ||
| if (wt != WT_VARINT) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| uint64_t v; | ||
| int n; | ||
| if (!read_varint(cur, stop, v, n)) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| cur += n; | ||
| if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } | ||
| value = static_cast<OutT>(v); | ||
| found = true; | ||
| // Continue scanning to allow "last one wins" semantics. | ||
| } else { | ||
| uint8_t const* next; | ||
| if (!skip_field(cur, stop, wt, next)) { | ||
| *error_flag = 1; |
Copilot
AI
Dec 25, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error_flag is accessed from multiple threads without atomic operations. Multiple concurrent writes could create a race condition. Use atomicOr or atomicExch to ensure thread-safe updates to the shared error flag.
src/main/cpp/src/protobuf_simple.cu
Outdated
| *error_flag = 1; | ||
| valid[row] = false; | ||
| return; | ||
| } | ||
| uint8_t const* cur = bytes + start; | ||
| uint8_t const* stop = bytes + end; | ||
|
|
||
| bool found = false; | ||
| OutT value{}; | ||
| while (cur < stop) { | ||
| uint64_t key; | ||
| int key_bytes; | ||
| if (!read_varint(cur, stop, key, key_bytes)) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| cur += key_bytes; | ||
| int fn = static_cast<int>(key >> 3); | ||
| int wt = static_cast<int>(key & 0x7); | ||
| if (fn == 0) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| if (fn == field_number) { | ||
| if (wt != WT) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| if constexpr (WT == WT_32BIT) { | ||
| if (stop - cur < 4) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| uint32_t raw = load_le<uint32_t>(cur); | ||
| cur += 4; | ||
| value = *reinterpret_cast<OutT*>(&raw); | ||
| } else { | ||
| if (stop - cur < 8) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| uint64_t raw = load_le<uint64_t>(cur); | ||
| cur += 8; | ||
| value = *reinterpret_cast<OutT*>(&raw); | ||
| } | ||
| found = true; | ||
| } else { | ||
| uint8_t const* next; | ||
| if (!skip_field(cur, stop, wt, next)) { | ||
| *error_flag = 1; |
Copilot
AI
Dec 25, 2025
The error_flag is accessed from multiple threads without atomic operations. Multiple concurrent writes could create a race condition. Use atomicOr or atomicExch to ensure thread-safe updates to the shared error flag.
src/main/cpp/src/protobuf_simple.cu
Outdated
| *error_flag = 1; | ||
| out_pairs[row] = cudf::strings::detail::string_index_pair{nullptr, 0}; | ||
| return; | ||
| } | ||
| uint8_t const* cur = bytes + start; | ||
| uint8_t const* stop = bytes + end; | ||
|
|
||
| cudf::strings::detail::string_index_pair pair{nullptr, 0}; | ||
| while (cur < stop) { | ||
| uint64_t key; | ||
| int key_bytes; | ||
| if (!read_varint(cur, stop, key, key_bytes)) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| cur += key_bytes; | ||
| int fn = static_cast<int>(key >> 3); | ||
| int wt = static_cast<int>(key & 0x7); | ||
| if (fn == 0) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| if (fn == field_number) { | ||
| if (wt != WT_LEN) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| uint64_t len64; | ||
| int n; | ||
| if (!read_varint(cur, stop, len64, n)) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| cur += n; | ||
| if (len64 > static_cast<uint64_t>(stop - cur)) { | ||
| *error_flag = 1; | ||
| break; | ||
| } | ||
| pair.first = reinterpret_cast<char const*>(cur); | ||
| pair.second = static_cast<cudf::size_type>(len64); | ||
| cur += static_cast<int>(len64); | ||
| // Continue scanning to allow "last one wins". | ||
| } else { | ||
| uint8_t const* next; | ||
| if (!skip_field(cur, stop, wt, next)) { | ||
| *error_flag = 1; |
Copilot
AI
Dec 25, 2025
The error_flag is accessed from multiple threads without atomic operations. Multiple concurrent writes could create a race condition. Use atomicOr or atomicExch to ensure thread-safe updates to the shared error flag.
| CUDF_EXPECTS(field_numbers.size() == out_types.size(), | ||
| "field_numbers and out_types must have the same length"); | ||
| CUDF_EXPECTS(encodings.size() == out_types.size(), | ||
| "encodings and out_types must have the same length"); |
Copilot
AI
Dec 25, 2025
There's no validation that field_numbers contain positive values. Protobuf field numbers must be positive integers (1 to 536,870,911, excluding 19000-19999). Passing a non-positive field number would cause the kernel to never match any field (since fn == 0 is treated as an error in the kernel), resulting in silently missing data. Consider adding validation in the host code to reject invalid field numbers.
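A host-side check along these lines might look like the sketch below. The function name is hypothetical. The range limits are the ones quoted in the comment above (1 to 536,870,911, with 19000 to 19999 reserved), and whether the decoder should reject the reserved range outright is a design decision left open here.

```cpp
#include <cassert>
#include <cstdint>

// Returns true when `fn` is a field number that user-defined protobuf
// fields may use.
constexpr bool is_valid_field_number(int64_t fn)
{
  constexpr int64_t max_field_number = 536870911;  // 2^29 - 1
  bool const in_range = fn >= 1 && fn <= max_field_number;
  bool const reserved = fn >= 19000 && fn <= 19999;  // reserved for the protobuf implementation
  return in_range && !reserved;
}
```

In the decode entry point this could sit next to the existing size checks, for example inside a loop over `field_numbers` that calls `CUDF_EXPECTS(is_valid_field_number(fn), ...)`.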
|
|
||
| try (Table input = new Table.TestBuilder().column(row0).build(); | ||
| ColumnVector expectedU32 = ColumnVector.fromBoxedLongs(4000000000L); // cuDF doesn't have boxed UInt32 easily, use Longs for test if needed, but we want native id | ||
| // Wait, I'll use direct values to avoid Boxing issues with UInt32 |
Copilot
AI
Dec 25, 2025
The comment on lines 133-134 appears to be leftover draft text ("Wait, I'll use direct values...") that should either be removed or clarified. This comment doesn't add value and suggests the code is still in draft state.
| // Wait, I'll use direct values to avoid Boxing issues with UInt32 |
| children.reserve(out_types.size()); | ||
|
|
||
| auto const threads = 256; | ||
| auto const blocks = static_cast<int>((rows + threads - 1) / threads); |
Copilot
AI
Dec 25, 2025
When rows equals 0, blocks will be 0, which means no kernel will launch. While this is technically correct behavior (no work to do), it may cause issues if downstream code expects valid output columns. Consider adding an explicit check for the zero-row case to ensure proper column initialization or early return if that's the intended behavior.
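One way to make the zero-row case explicit is to compute the grid size through a small helper and treat a result of 0 as "skip the launch". The helper name below is hypothetical. To my understanding a grid dimension of zero is not a valid CUDA launch configuration, so it may be safer not to rely on a zero-block launch being a silent no-op.

```cpp
#include <cassert>
#include <cstdint>

// Number of thread blocks needed to cover `rows` items with
// `threads_per_block` threads each. A result of 0 means there is no work
// and the caller should skip the kernel launch.
constexpr int blocks_for(int64_t rows, int threads_per_block)
{
  return rows <= 0 ? 0
                   : static_cast<int>((rows + threads_per_block - 1) / threads_per_block);
}
```

When the helper returns 0, the caller could return early with an empty STRUCT column whose children have the requested types.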
| if (enc == ENC_ZIGZAG) { | ||
| extract_varint_kernel<int64_t, true><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| } else if (enc == ENC_FIXED) { | ||
| extract_fixed_kernel<int64_t, WT_64BIT><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| } else if (enc == ENC_DEFAULT) { | ||
| extract_varint_kernel<int64_t, false><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| } else { | ||
| CUDF_FAIL("Unsupported encoding for INT64 protobuf field"); | ||
| } | ||
| auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); | ||
| children.push_back( | ||
| std::make_unique<cudf::column>(dt, rows, out.release(), std::move(mask), null_count)); | ||
| break; | ||
| } | ||
| case cudf::type_id::UINT64: { | ||
| rmm::device_uvector<uint64_t> out(rows, stream, mr); | ||
| rmm::device_uvector<bool> valid(rows, stream, mr); | ||
| if (enc == ENC_FIXED) { | ||
| extract_fixed_kernel<uint64_t, WT_64BIT><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| } else if (enc == ENC_DEFAULT) { | ||
| extract_varint_kernel<uint64_t><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| } else { | ||
| CUDF_FAIL("Unsupported encoding for UINT64 protobuf field"); | ||
| } | ||
| auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); | ||
| children.push_back( | ||
| std::make_unique<cudf::column>(dt, rows, out.release(), std::move(mask), null_count)); | ||
| break; | ||
| } | ||
| case cudf::type_id::FLOAT32: { | ||
| rmm::device_uvector<float> out(rows, stream, mr); | ||
| rmm::device_uvector<bool> valid(rows, stream, mr); | ||
| if (enc == ENC_DEFAULT) { | ||
| extract_fixed_kernel<float, WT_32BIT><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| } else { | ||
| CUDF_FAIL("Unsupported encoding for FLOAT32 protobuf field"); | ||
| } | ||
| auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); | ||
| children.push_back( | ||
| std::make_unique<cudf::column>(dt, rows, out.release(), std::move(mask), null_count)); | ||
| break; | ||
| } | ||
| case cudf::type_id::FLOAT64: { | ||
| rmm::device_uvector<double> out(rows, stream, mr); | ||
| rmm::device_uvector<bool> valid(rows, stream, mr); | ||
| if (enc == ENC_DEFAULT) { | ||
| extract_fixed_kernel<double, WT_64BIT><<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, out.data(), valid.data(), d_error.data()); | ||
| } else { | ||
| CUDF_FAIL("Unsupported encoding for FLOAT64 protobuf field"); | ||
| } | ||
| auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); | ||
| children.push_back( | ||
| std::make_unique<cudf::column>(dt, rows, out.release(), std::move(mask), null_count)); | ||
| break; | ||
| } | ||
| case cudf::type_id::STRING: { | ||
| rmm::device_uvector<cudf::strings::detail::string_index_pair> pairs(rows, stream, mr); | ||
| if (enc == ENC_DEFAULT) { | ||
| extract_string_kernel<<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, pairs.data(), d_error.data()); | ||
| } else { | ||
| CUDF_FAIL("Unsupported encoding for STRING protobuf field"); | ||
| } | ||
| children.push_back( | ||
| cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr)); | ||
| break; | ||
| } | ||
| case cudf::type_id::LIST: { | ||
| rmm::device_uvector<cudf::strings::detail::string_index_pair> pairs(rows, stream, mr); | ||
| if (enc == ENC_DEFAULT) { | ||
| extract_string_kernel<<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, pairs.data(), d_error.data()); | ||
| } else { | ||
| CUDF_FAIL("Unsupported encoding for LIST protobuf field"); | ||
| } | ||
| auto strings = | ||
| cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); | ||
| auto const null_count = strings->null_count(); | ||
| auto contents = strings->release(); | ||
| auto null_mask = | ||
| contents.null_mask ? std::move(*contents.null_mask) : rmm::device_buffer{0, stream, mr}; | ||
| children.push_back(cudf::make_lists_column(rows, | ||
| std::move(contents.children[0]), | ||
| std::move(contents.children[1]), | ||
| null_count, | ||
| std::move(null_mask), | ||
| stream, | ||
| mr)); | ||
| break; | ||
| } | ||
| default: CUDF_FAIL("Unsupported output type for protobuf_simple"); | ||
| } | ||
| } |
Copilot
AI
Dec 25, 2025
The implementation launches a separate kernel for each output field (one per iteration of the loop), which means the entire input is scanned multiple times. For messages with many fields, this could be inefficient. While this may be acceptable for a "simple" first implementation, consider noting this performance limitation in the documentation or adding a TODO comment about potential optimization (e.g., single-pass kernel that extracts all fields).
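To illustrate the single-pass idea, here is a host-side C++ sketch that walks one message once and fills a slot for every requested varint field. It is only an illustration under simplifying assumptions (varint fields only, one message, CPU code); `host_read_varint` and `extract_varints_single_pass` are hypothetical names, not functions from this PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Reads a base-128 varint at `pos` and advances `pos` past the bytes consumed.
// Returns false if the input ends early or the varint exceeds 10 bytes.
bool host_read_varint(const std::vector<std::uint8_t>& buf, std::size_t& pos, std::uint64_t& value)
{
  value = 0;
  for (int shift = 0; shift < 64 && pos < buf.size(); shift += 7) {
    std::uint8_t const byte = buf[pos++];
    value |= static_cast<std::uint64_t>(byte & 0x7F) << shift;
    if ((byte & 0x80) == 0) { return true; }
  }
  return false;
}

// Scans `msg` once and records the last varint seen for each field number in
// `wanted` (out[i] corresponds to wanted[i]). Returns false on malformed input.
bool extract_varints_single_pass(const std::vector<std::uint8_t>& msg,
                                 const std::vector<std::uint32_t>& wanted,
                                 std::vector<std::optional<std::uint64_t>>& out)
{
  out.assign(wanted.size(), std::nullopt);
  std::size_t pos = 0;
  while (pos < msg.size()) {
    std::uint64_t key = 0;
    if (!host_read_varint(msg, pos, key)) { return false; }
    std::uint64_t const field = key >> 3;
    auto const wire_type      = static_cast<int>(key & 0x7);
    if (field == 0) { return false; }
    std::size_t const remaining = msg.size() - pos;
    std::uint64_t payload       = 0;
    switch (wire_type) {
      case 0:  // varint: decode once, then offer the value to every matching slot
        if (!host_read_varint(msg, pos, payload)) { return false; }
        for (std::size_t i = 0; i < wanted.size(); ++i) {
          if (wanted[i] == field) { out[i] = payload; }  // last occurrence wins
        }
        break;
      case 1:  // 64-bit fixed: skipped in this sketch
        if (remaining < 8) { return false; }
        pos += 8;
        break;
      case 2: {  // length-delimited: skipped in this sketch
        std::uint64_t len = 0;
        if (!host_read_varint(msg, pos, len)) { return false; }
        if (len > msg.size() - pos) { return false; }
        pos += static_cast<std::size_t>(len);
        break;
      }
      case 5:  // 32-bit fixed: skipped in this sketch
        if (remaining < 4) { return false; }
        pos += 4;
        break;
      default: return false;  // groups and unknown wire types are rejected
    }
  }
  return true;
}
```

The inner loop over `wanted` is linear in the number of requested fields; a GPU version would likely want a small lookup table instead, which is a design choice this sketch does not explore.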
| std::vector<cudf::data_type> out_types; | ||
| out_types.reserve(n_type_ids.size()); | ||
| for (int i = 0; i < n_type_ids.size(); ++i) { | ||
| out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], n_type_scales[i])); |
Copilot
AI
Dec 25, 2025
The typeScales parameter serves dual purposes: encoding information for protobuf fields and scale for data type construction. When creating data_type via make_data_type with encoding values like ENC_FIXED=1 or ENC_ZIGZAG=2 as the scale parameter, verify that make_data_type properly ignores or handles these values for non-decimal types. If make_data_type interprets these as actual scale values for types like INT32/INT64, it could create incorrect type metadata. Consider either: (1) always passing 0 as scale when creating data types for non-decimal types, or (2) documenting that make_data_type safely ignores scale for such types.
| out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], n_type_scales[i])); | |
| auto const type_id = static_cast<cudf::type_id>(n_type_ids[i]); | |
| auto const scale = | |
| (type_id == cudf::type_id::DECIMAL32 || type_id == cudf::type_id::DECIMAL64 || | |
| type_id == cudf::type_id::DECIMAL128) | |
| ? n_type_scales[i] | |
| : 0; | |
| out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], scale)); |
src/main/cpp/src/protobuf_simple.cu
Outdated
| if (!read_varint(out_cur, end, len64, n)) return false; | ||
| out_cur += n; | ||
| if (len64 > static_cast<uint64_t>(end - out_cur)) return false; | ||
| out_cur += static_cast<int>(len64); |
Copilot
AI
Dec 25, 2025
The cast from uint64_t to int can cause integer overflow for large length-delimited fields. If len64 is greater than INT_MAX (approximately 2GB), the cast will produce a negative or incorrect value, potentially causing memory corruption or out-of-bounds access. Consider checking that len64 fits within a safe range before casting, or use a larger integral type for the offset calculation.
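One way to avoid the narrowing cast, sketched here on the host with a hypothetical helper `skip_bytes`, is to compare the length against the remaining bytes in 64 bits and advance the cursor with a pointer-sized type. The helper name and signature are assumptions for illustration, not code from the PR.

```cpp
#include <cstddef>
#include <cstdint>

// Advances `cur` past `len64` payload bytes if they fit before `end`.
// Leaves `cur` untouched and returns false otherwise.
// Precondition (assumed): cur <= end and both point into the same buffer.
inline bool skip_bytes(const std::uint8_t*& cur, const std::uint8_t* end, std::uint64_t len64)
{
  auto const remaining = static_cast<std::uint64_t>(end - cur);
  if (len64 > remaining) { return false; }     // also rejects lengths above INT_MAX
  cur += static_cast<std::ptrdiff_t>(len64);   // safe: len64 <= remaining <= PTRDIFF_MAX
  return true;
}
```

Because the bounds check happens before any narrowing, the only cast left is to `std::ptrdiff_t`, which can always represent a distance inside a valid buffer.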
Signed-off-by: Haoyang Li <[email protected]>
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 10 comments.
src/main/cpp/src/protobuf_simple.hpp
Outdated
| * @param out_types output cudf data types (one per output child) | ||
| * @param encodings encoding type for each field (0=default, 1=fixed, 2=zigzag) | ||
| * @param fail_on_errors whether to throw on malformed messages | ||
| * @return STRUCT column with the given children types, with nullability propagated from input rows |
Copilot
AI
Dec 25, 2025
The documentation in the header file states that the STRUCT column will have "nullability propagated from input rows", but the implementation explicitly does NOT propagate input nulls to the output STRUCT validity (see lines 590-597 in the .cu file). The STRUCT row is always valid regardless of whether the input message is null. This discrepancy between the documentation and implementation should be corrected. Based on the implementation and test expectations, the header documentation should be updated to clarify that the STRUCT itself is always non-null, and only individual child fields can be null.
| * @return STRUCT column with the given children types, with nullability propagated from input rows | |
| * @return STRUCT column with the given children types; the STRUCT itself is always non-null, | |
| * and only individual child fields may be null (including due to null input rows) |
| case cudf::type_id::LIST: { | ||
| rmm::device_uvector<cudf::strings::detail::string_index_pair> pairs(rows, stream, mr); | ||
| if (enc == ENC_DEFAULT) { | ||
| extract_string_kernel<<<blocks, threads, 0, stream.value()>>>( | ||
| *d_in, fn, pairs.data(), d_error.data()); | ||
| } else { | ||
| CUDF_FAIL("Unsupported encoding for LIST protobuf field"); | ||
| } | ||
| auto strings = | ||
| cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); | ||
| auto const null_count = strings->null_count(); | ||
| auto contents = strings->release(); | ||
| auto null_mask = | ||
| contents.null_mask ? std::move(*contents.null_mask) : rmm::device_buffer{0, stream, mr}; | ||
| children.push_back(cudf::make_lists_column(rows, | ||
| std::move(contents.children[0]), | ||
| std::move(contents.children[1]), | ||
| null_count, | ||
| std::move(null_mask), | ||
| stream, | ||
| mr)); | ||
| break; |
Copilot
AI
Dec 25, 2025
The LIST type handling reuses the string extraction kernel and then converts string column internals to a list column (lines 551-572). This is clever but undocumented and non-obvious. The documentation should explain this implementation detail, and the code comment should clarify that bytes fields are internally treated as strings and then re-interpreted as LIST<INT8> for better maintainability.
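The reinterpretation works because, in the layout the code above relies on, a strings column and a LIST<INT8> column are both "offsets plus concatenated bytes". The host-side sketch below mimics that layout with plain vectors; `OffsetsAndBytes`, `pack_rows`, and `row_as_int8_list` are hypothetical names used only to make the idea concrete, and nothing here calls cuDF.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Shared physical layout: row i spans bytes[offsets[i]] .. bytes[offsets[i + 1]).
struct OffsetsAndBytes {
  std::vector<std::int32_t> offsets;  // rows + 1 entries
  std::vector<std::int8_t> bytes;     // concatenated payloads
};

// Packs rows the way a "string extraction" step would: one offsets array, one byte array.
OffsetsAndBytes pack_rows(const std::vector<std::string>& rows)
{
  OffsetsAndBytes col;
  col.offsets.push_back(0);
  for (auto const& r : rows) {
    for (char c : r) { col.bytes.push_back(static_cast<std::int8_t>(c)); }
    col.offsets.push_back(static_cast<std::int32_t>(col.bytes.size()));
  }
  return col;
}

// Reads the same buffers as a list of INT8 values; no bytes are copied or changed
// by the reinterpretation itself, only the element type differs.
std::vector<std::int8_t> row_as_int8_list(const OffsetsAndBytes& col, std::size_t row)
{
  auto const first = col.bytes.begin() + col.offsets[row];
  auto const last  = col.bytes.begin() + col.offsets[row + 1];
  return std::vector<std::int8_t>(first, last);
}
```

A short code comment next to the LIST case stating this equivalence would probably be enough to address the maintainability concern.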
| children.reserve(out_types.size()); | ||
|
|
||
| auto const threads = 256; | ||
| auto const blocks = static_cast<int>((rows + threads - 1) / threads); |
Copilot
AI
Dec 25, 2025
The blocks calculation casts the result to int, which could overflow if the input has more than approximately 2^31 * 256 rows. While this is an extremely large number, cudf::size_type is typically int32_t, so this may not be a practical issue. However, for robustness and consistency with CUDA best practices, consider using a safer calculation or at least adding a check that rows doesn't exceed INT_MAX before computing blocks.
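A possible shape for the safer calculation is sketched below: do the ceiling division in 64 bits and range-check before narrowing. `checked_num_blocks` is a hypothetical helper, and throwing standard exceptions is an assumption made to keep the sketch self-contained; the real code would presumably use its own error macros.

```cpp
#include <cstdint>
#include <limits>
#include <stdexcept>

// Number of blocks needed to cover `rows` items, computed without overflow.
// Throws instead of silently truncating when the result does not fit in int.
inline int checked_num_blocks(std::int64_t rows, int threads_per_block)
{
  if (rows < 0 || threads_per_block <= 0) {
    throw std::invalid_argument("rows must be >= 0 and threads_per_block > 0");
  }
  // Ceiling division written as quotient plus remainder test, so no addition can wrap.
  std::int64_t const blocks =
    rows / threads_per_block + (rows % threads_per_block != 0 ? 1 : 0);
  if (blocks > std::numeric_limits<int>::max()) {
    throw std::overflow_error("grid size does not fit in int");
  }
  return static_cast<int>(blocks);
}
```

Writing the division this way also sidesteps the `rows + threads - 1` addition, which can wrap when it is evaluated in 32 bits and `rows` is close to the maximum value.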
| ColumnVector expectedU32 = ColumnVector.fromBoxedLongs(4000000000L); // cuDF doesn't have boxed UInt32 easily, use Longs for test if needed, but we want native id | ||
| // Wait, I'll use direct values to avoid Boxing issues with UInt32 | ||
| ColumnVector expectedS64 = ColumnVector.fromBoxedLongs(-1234567890123L); | ||
| ColumnVector expectedF32 = ColumnVector.fromBoxedInts(12345); | ||
| ColumnVector expectedB = ColumnVector.fromLists( | ||
| new ListType(true, new BasicType(true, DType.INT8)), | ||
| Arrays.asList((byte) 1, (byte) 2, (byte) 3)); | ||
| ColumnVector actualStruct = ProtobufSimple.decodeToStruct( | ||
| input.getColumn(0), | ||
| new int[]{1, 2, 3, 4}, | ||
| new int[]{ | ||
| DType.UINT32.getTypeId().getNativeId(), | ||
| DType.INT64.getTypeId().getNativeId(), | ||
| DType.INT32.getTypeId().getNativeId(), | ||
| DType.LIST.getTypeId().getNativeId()}, | ||
| new int[]{ | ||
| ProtobufSimple.ENC_DEFAULT, | ||
| ProtobufSimple.ENC_ZIGZAG, | ||
| ProtobufSimple.ENC_FIXED, | ||
| ProtobufSimple.ENC_DEFAULT})) { | ||
| // For UINT32, expectedU32 from fromBoxedLongs will be INT64. | ||
| // I should use makeColumn to get exactly the right types for comparison. |
Copilot
AI
Dec 25, 2025
The comments at lines 133-134 and 153-154 are confusing and appear to be leftover thoughts from development. The first comment mentions "cuDF doesn't have boxed UInt32 easily" followed by "Wait, I'll use direct values...", and the second comment says "I should use makeColumn to get exactly the right types..." These development notes should be removed or rewritten as clear explanations of why the casting is necessary.
| int h_error = 0; | ||
| CUDF_CUDA_TRY( | ||
| cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); | ||
| stream.synchronize(); | ||
| if (fail_on_errors) { | ||
| CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type"); |
Copilot
AI
Dec 25, 2025
The error flag is checked only once after all kernels have been launched (line 582-587). If fail_on_errors is true and an early kernel detects an error, subsequent kernels will still launch and execute. Consider checking cudaPeekAtLastError after each kernel launch to fail fast, or at minimum, check the error flag between kernel launches to avoid unnecessary computation when an error has already occurred.
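The control flow of a fail-fast variant can be sketched on the host as below. Each `Stage` is a hypothetical stand-in for "launch one field's kernel, then read back the error flag"; `run_stages` is likewise an illustrative name. Note that `cudaPeekAtLastError` reports launch failures rather than the kernels' own error flag, so a real fail-fast check would need to read the flag itself, which costs a device-to-host copy and a synchronization per field.

```cpp
#include <functional>
#include <stdexcept>
#include <vector>

// Stand-in for one kernel launch followed by reading the error flag (0 = ok).
using Stage = std::function<int()>;

// Runs the stages in order and returns how many ran.
// With fail_on_errors set, stops at the first stage that reports an error.
int run_stages(const std::vector<Stage>& stages, bool fail_on_errors)
{
  int executed = 0;
  for (auto const& stage : stages) {
    int const flag = stage();
    ++executed;
    if (fail_on_errors && flag != 0) {
      throw std::runtime_error("Malformed protobuf message or unsupported wire type");
    }
  }
  return executed;
}
```

Whether the extra synchronization is worth it depends on how common malformed input is; when errors are rare, a single check at the end keeps the kernels queued back to back.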
| if (fieldNumbers == null || typeIds == null || typeScales == null) { | ||
| throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be non-null"); | ||
| } | ||
| if (fieldNumbers.length != typeIds.length || fieldNumbers.length != typeScales.length) { | ||
| throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be the same length"); | ||
| } |
Copilot
AI
Dec 25, 2025
There is no validation that the field_numbers, typeIds, and typeScales arrays are non-empty. If all three arrays have length 0, the code would create a STRUCT with no children, which may be valid but is likely an error. Consider adding a check to ensure at least one field is being decoded, or document that empty field arrays are allowed.
src/main/cpp/src/protobuf_simple.cu
Outdated
| auto start = in.offset_at(row) - base; | ||
| auto end = in.offset_at(row + 1) - base; | ||
| if (start < 0 || end < start || end > child.size()) { | ||
| *error_flag = 1; |
Copilot
AI
Dec 25, 2025
Inconsistent use of atomicExch vs direct assignment for error_flag. At line 313, a non-atomic assignment is used, while all other error flag updates in the file use atomicExch. This is a race condition because multiple threads could write to error_flag concurrently. Change this to atomicExch for thread safety consistency.
| *error_flag = 1; | |
| atomicExch(error_flag, 1); |
| @Test | ||
| void decodeVarintAndStringToStruct() { | ||
| // message Msg { int64 id = 1; string name = 2; } | ||
| // Row0: id=100, name="alice" | ||
| Byte[] row0 = concat( | ||
| new Byte[]{(byte) 0x08}, // field 1, varint | ||
| box(encodeVarint(100)), | ||
| new Byte[]{(byte) 0x12}, // field 2, len-delimited | ||
| box(encodeVarint(5)), | ||
| box("alice".getBytes())); | ||
|
|
||
| // Row1: id=200, name missing | ||
| Byte[] row1 = concat( | ||
| new Byte[]{(byte) 0x08}, | ||
| box(encodeVarint(200))); | ||
|
|
||
| // Row2: null input message | ||
| Byte[] row2 = null; | ||
|
|
||
| try (Table input = new Table.TestBuilder().column(row0, row1, row2).build(); | ||
| ColumnVector expectedId = ColumnVector.fromBoxedLongs(100L, 200L, null); | ||
| ColumnVector expectedName = ColumnVector.fromStrings("alice", null, null); | ||
| ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); | ||
| ColumnVector actualStruct = ProtobufSimple.decodeToStruct( | ||
| input.getColumn(0), | ||
| new int[]{1, 2}, | ||
| new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, | ||
| new int[]{0, 0})) { | ||
| AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); | ||
| } | ||
| } | ||
|
|
||
| @Test | ||
| void decodeMoreTypes() { | ||
| // message Msg { | ||
| // uint32 u32 = 1; | ||
| // sint64 s64 = 2; | ||
| // fixed32 f32 = 3; | ||
| // bytes b = 4; | ||
| // } | ||
| Byte[] row0 = concat( | ||
| new Byte[]{(byte) 0x08}, // field 1, varint | ||
| box(encodeVarint(4000000000L)), | ||
| new Byte[]{(byte) 0x10}, // field 2, varint | ||
| box(encodeVarint(zigzagEncode(-1234567890123L))), | ||
| new Byte[]{(byte) 0x1d}, // field 3, fixed32 | ||
| box(encodeFixed32(12345)), | ||
| new Byte[]{(byte) 0x22}, // field 4, len-delimited | ||
| box(encodeVarint(3)), | ||
| box(new byte[]{1, 2, 3})); | ||
|
|
||
| try (Table input = new Table.TestBuilder().column(row0).build(); | ||
| ColumnVector expectedU32 = ColumnVector.fromBoxedLongs(4000000000L); // cuDF doesn't have boxed UInt32 easily, use Longs for test if needed, but we want native id | ||
| // Wait, I'll use direct values to avoid Boxing issues with UInt32 | ||
| ColumnVector expectedS64 = ColumnVector.fromBoxedLongs(-1234567890123L); | ||
| ColumnVector expectedF32 = ColumnVector.fromBoxedInts(12345); | ||
| ColumnVector expectedB = ColumnVector.fromLists( | ||
| new ListType(true, new BasicType(true, DType.INT8)), | ||
| Arrays.asList((byte) 1, (byte) 2, (byte) 3)); | ||
| ColumnVector actualStruct = ProtobufSimple.decodeToStruct( | ||
| input.getColumn(0), | ||
| new int[]{1, 2, 3, 4}, | ||
| new int[]{ | ||
| DType.UINT32.getTypeId().getNativeId(), | ||
| DType.INT64.getTypeId().getNativeId(), | ||
| DType.INT32.getTypeId().getNativeId(), | ||
| DType.LIST.getTypeId().getNativeId()}, | ||
| new int[]{ | ||
| ProtobufSimple.ENC_DEFAULT, | ||
| ProtobufSimple.ENC_ZIGZAG, | ||
| ProtobufSimple.ENC_FIXED, | ||
| ProtobufSimple.ENC_DEFAULT})) { | ||
| // For UINT32, expectedU32 from fromBoxedLongs will be INT64. | ||
| // I should use makeColumn to get exactly the right types for comparison. | ||
| try (ColumnVector expectedU32Correct = expectedU32.castTo(DType.UINT32); | ||
| ColumnVector expectedStructCorrect = ColumnVector.makeStruct(expectedU32Correct, expectedS64, expectedF32, expectedB)) { | ||
| AssertUtils.assertStructColumnsAreEqual(expectedStructCorrect, actualStruct); | ||
| } | ||
| } | ||
| } | ||
| } |
Copilot
AI
Dec 25, 2025
The test coverage is limited to only two basic test cases. Missing test coverage for important scenarios including: (1) malformed protobuf messages with invalid varints or field tags, (2) testing the failOnErrors parameter behavior, (3) duplicate field numbers in a message (testing "last one wins" semantics), (4) empty protobuf messages, (5) messages with unknown field numbers, (6) boundary cases like very large field numbers or lengths, (7) float and double types, (8) BOOL type, and (9) all encoding types for each supported type.
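Several of these scenarios only need a few hand-built byte sequences. The C++ sketch below shows how such test vectors could be produced; the existing tests are in Java, so these helpers (`encode_varint`, `zigzag_encode`, `truncated_varint_message`, `duplicate_field_message`) are hypothetical illustrations of the inputs rather than proposed test code.

```cpp
#include <cstdint>
#include <vector>

// Encodes `v` as a protobuf base-128 varint (little-endian groups of 7 bits).
std::vector<std::uint8_t> encode_varint(std::uint64_t v)
{
  std::vector<std::uint8_t> out;
  while (v >= 0x80) {
    out.push_back(static_cast<std::uint8_t>((v & 0x7F) | 0x80));
    v >>= 7;
  }
  out.push_back(static_cast<std::uint8_t>(v));
  return out;
}

// ZigZag mapping used by sint64: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
// Relies on arithmetic right shift of negative values, which mainstream compilers provide.
std::uint64_t zigzag_encode(std::int64_t v)
{
  return (static_cast<std::uint64_t>(v) << 1) ^ static_cast<std::uint64_t>(v >> 63);
}

// Scenario (1): tag for field 1 (varint) followed by a varint that never terminates.
std::vector<std::uint8_t> truncated_varint_message() { return {0x08, 0x80}; }

// Scenario (3): field 1 appears twice; a "last one wins" decoder should return `second`.
std::vector<std::uint8_t> duplicate_field_message(std::uint64_t first, std::uint64_t second)
{
  std::vector<std::uint8_t> msg{0x08};
  auto const a = encode_varint(first);
  msg.insert(msg.end(), a.begin(), a.end());
  msg.push_back(0x08);
  auto const b = encode_varint(second);
  msg.insert(msg.end(), b.begin(), b.end());
  return msg;
}
```

Scenario (4), the empty message, is simply a zero-length row, and scenario (5) can reuse `duplicate_field_message`-style construction with a tag whose field number is not requested.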
| if (fieldNumbers == null || typeIds == null || typeScales == null) { | ||
| throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be non-null"); | ||
| } | ||
| if (fieldNumbers.length != typeIds.length || fieldNumbers.length != typeScales.length) { | ||
| throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be the same length"); | ||
| } |
Copilot
AI
Dec 25, 2025
There is no validation that field numbers are positive (greater than 0). Protobuf field numbers must be in the range 1-536870911, with 19000-19999 reserved. While the code checks for field number 0 in the parsed message (lines 172-175), it doesn't validate the input field_numbers array to ensure callers pass valid field numbers. Consider adding validation in the Java layer or the C++ entry point to reject invalid field numbers before processing.
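If the check were placed at the C++ entry point, it might look like the sketch below. `is_valid_field_number` and `validate_field_numbers` are hypothetical names, and rejecting the reserved 19000-19999 range is an assumption about the desired policy rather than something the PR specifies.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

constexpr std::int32_t kMaxFieldNumber      = 536870911;  // 2^29 - 1
constexpr std::int32_t kReservedFieldFirst  = 19000;
constexpr std::int32_t kReservedFieldLast   = 19999;

// True when `fn` is in [1, 2^29 - 1] and outside the reserved range.
bool is_valid_field_number(std::int32_t fn)
{
  if (fn < 1 || fn > kMaxFieldNumber) { return false; }
  return fn < kReservedFieldFirst || fn > kReservedFieldLast;
}

// Throws std::invalid_argument naming the first offending index, if any.
void validate_field_numbers(const std::vector<std::int32_t>& field_numbers)
{
  for (std::size_t i = 0; i < field_numbers.size(); ++i) {
    if (!is_valid_field_number(field_numbers[i])) {
      throw std::invalid_argument("Invalid protobuf field number at index " + std::to_string(i) +
                                  ": " + std::to_string(field_numbers[i]));
    }
  }
}
```

Running this once before any kernel launch keeps the device code free of caller-input checks and gives the caller an error message that points at the bad argument.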
| } | ||
| if (fieldNumbers.length != typeIds.length || fieldNumbers.length != typeScales.length) { | ||
| throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be the same length"); | ||
| } |
Copilot
AI
Dec 25, 2025
There is no validation of encoding values in the Java layer. Invalid encoding values (not 0, 1, or 2) will be passed to the C++ layer where they may cause silent failures or unexpected behavior. Consider adding validation in the Java layer to check that all encoding values are within the valid range (ENC_DEFAULT, ENC_FIXED, or ENC_ZIGZAG) before passing to native code.
| } | |
| } | |
| // Validate encodings for non-decimal types to avoid passing invalid values to native code. | |
| for (int i = 0; i < typeScales.length; i++) { | |
| int enc = typeScales[i]; | |
| if (enc < ENC_DEFAULT || enc > ENC_ZIGZAG) { | |
| throw new IllegalArgumentException( | |
| "Invalid encoding value at index " + i + ": " + enc | |
| + " (expected " + ENC_DEFAULT + ", " + ENC_FIXED + ", or " + ENC_ZIGZAG + ")"); | |
| } | |
| } |
Signed-off-by: Haoyang Li <[email protected]>
Signed-off-by: Haoyang Li <[email protected]>
Signed-off-by: Haoyang Li <[email protected]>
…ven/spark-rapids-jni into protocol_buffer_jni_dev
Signed-off-by: Haoyang Li <[email protected]>
WIP
Now an AI draft seeking AI review