21 changes: 20 additions & 1 deletion integration_tests/pom.xml
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Copyright (c) 2020-2025, NVIDIA CORPORATION.
Copyright (c) 2020-2026, NVIDIA CORPORATION.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -142,6 +142,7 @@
<includes>
<include>parquet-hadoop*.jar</include>
<include>spark-avro*.jar</include>
<include>spark-protobuf*.jar</include>
</includes>
</fileset>
</filesets>
@@ -176,6 +177,24 @@
</artifactItems>
</configuration>
</execution>
<execution>
<id>copy-spark-protobuf</id>
<phase>package</phase>
<goals>
<goal>copy</goal>
</goals>
<configuration>
<skip>${spark.protobuf.copy.skip}</skip>
<useBaseVersion>true</useBaseVersion>
<artifactItems>
<artifactItem>
<groupId>org.apache.spark</groupId>
<artifactId>spark-protobuf_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
</artifactItem>
</artifactItems>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
25 changes: 22 additions & 3 deletions integration_tests/run_pyspark_from_build.sh
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
# Copyright (c) 2020-2026, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,6 +29,7 @@
# - SPARK_HOME: Path to your Apache Spark installation.
# - SKIP_TESTS: If set to true, skips running the Python integration tests.
# - INCLUDE_SPARK_AVRO_JAR: If set to true, includes Avro tests.
# - INCLUDE_SPARK_PROTOBUF_JAR: If set to true, includes spark-protobuf (Spark 3.4.0+) on the JVM classpath.
# - TEST: Specifies a specific test to run.
# - TEST_TAGS: Allows filtering tests based on tags.
# - TEST_TYPE: Specifies the type of tests to run.
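#
# For example, a run that keeps spark-protobuf on the classpath and targets a single
# test module could look like the following (the test name here is illustrative only):
#
#   INCLUDE_SPARK_PROTOBUF_JAR=true TEST="protobuf_test" ./run_pyspark_from_build.sh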
@@ -100,6 +101,7 @@ else
# support alternate local jars NOT building from the source code
if [ -d "$LOCAL_JAR_PATH" ]; then
AVRO_JARS=$(echo "$LOCAL_JAR_PATH"/spark-avro*.jar)
PROTOBUF_JARS=$(echo "$LOCAL_JAR_PATH"/spark-protobuf*.jar)
PLUGIN_JAR=$(echo "$LOCAL_JAR_PATH"/rapids-4-spark_*.jar)
if [ -f $(echo $LOCAL_JAR_PATH/parquet-hadoop*.jar) ]; then
export INCLUDE_PARQUET_HADOOP_TEST_JAR=true
@@ -116,6 +118,7 @@
else
[[ "$SCALA_VERSION" != "2.12" ]] && TARGET_DIR=${TARGET_DIR/integration_tests/scala$SCALA_VERSION\/integration_tests}
AVRO_JARS=$(echo "$TARGET_DIR"/dependency/spark-avro*.jar)
PROTOBUF_JARS=$(echo "$TARGET_DIR"/dependency/spark-protobuf*.jar)
PARQUET_HADOOP_TESTS=$(echo "$TARGET_DIR"/dependency/parquet-hadoop*.jar)
# remove the log4j.properties file so it doesn't conflict with ours, ignore errors
# if it isn't present or already removed
@@ -141,9 +144,25 @@
AVRO_JARS=""
fi

# ALL_JARS includes dist.jar integration-test.jar avro.jar parquet.jar if they exist
# spark-protobuf is an optional Spark module that exists in Spark 3.4.0+. If we have the jar staged
# under target/dependency, include it so from_protobuf() is callable from PySpark.
if [[ $( echo ${INCLUDE_SPARK_PROTOBUF_JAR:-true} | tr '[:upper:]' '[:lower:]' ) == "true" ]];
then
# VERSION_STRING >= 3.4.0 ?
if printf '%s\n' "3.4.0" "$VERSION_STRING" | sort -V | head -1 | grep -qx "3.4.0"; then
export INCLUDE_SPARK_PROTOBUF_JAR=true
else
export INCLUDE_SPARK_PROTOBUF_JAR=false
PROTOBUF_JARS=""
fi
else
export INCLUDE_SPARK_PROTOBUF_JAR=false
PROTOBUF_JARS=""
fi

# ALL_JARS includes dist.jar integration-test.jar avro.jar protobuf.jar parquet.jar if they exist
# Remove non-existing paths and canonicalize the paths including get rid of links and `..`
ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PARQUET_HADOOP_TESTS || true)
ALL_JARS=$(readlink -e $PLUGIN_JAR $TEST_JARS $AVRO_JARS $PROTOBUF_JARS $PARQUET_HADOOP_TESTS || true)
# `:` separated jars
ALL_JARS="${ALL_JARS//$'\n'/:}"

127 changes: 126 additions & 1 deletion integration_tests/src/main/python/data_gen.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2025, NVIDIA CORPORATION.
# Copyright (c) 2020-2026, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -857,6 +857,131 @@ def gen_bytes():
return bytes([ rand.randint(0, 255) for _ in range(length) ])
self._start(rand, gen_bytes)


# -----------------------------------------------------------------------------
# Protobuf (simple types) generators/utilities (for from_protobuf/to_protobuf tests)
# -----------------------------------------------------------------------------

_PROTOBUF_WIRE_VARINT = 0
_PROTOBUF_WIRE_64BIT = 1
_PROTOBUF_WIRE_LEN_DELIM = 2
_PROTOBUF_WIRE_32BIT = 5

def _encode_protobuf_uvarint(value):
"""Encode a non-negative integer as protobuf varint."""
if value is None:
raise ValueError("value must not be None")
if value < 0:
raise ValueError("uvarint only supports non-negative integers")
out = bytearray()
v = int(value)
while True:
b = v & 0x7F
v >>= 7
if v:
out.append(b | 0x80)
else:
out.append(b)
break
return bytes(out)

def _encode_protobuf_key(field_number, wire_type):
return _encode_protobuf_uvarint((int(field_number) << 3) | int(wire_type))

def _encode_protobuf_field(field_number, spark_type, value):
"""
Encode a single protobuf field for a subset of scalar types.
Notes on signed ints:
- Protobuf `int32`/`int64` use *varint* encoding of the two's-complement integer.
- Negative `int32` values are encoded as a 10-byte varint (because they are sign-extended to 64 bits).
"""
if value is None:
return b""

if isinstance(spark_type, BooleanType):
return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_VARINT) + _encode_protobuf_uvarint(1 if value else 0)
elif isinstance(spark_type, IntegerType):
# Match protobuf-java behavior for writeInt32NoTag: negative values are sign-extended and written as uint64.
u64 = int(value) & 0xFFFFFFFFFFFFFFFF
return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_VARINT) + _encode_protobuf_uvarint(u64)
elif isinstance(spark_type, LongType):
u64 = int(value) & 0xFFFFFFFFFFFFFFFF
return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_VARINT) + _encode_protobuf_uvarint(u64)
elif isinstance(spark_type, FloatType):
return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_32BIT) + struct.pack("<f", float(value))
elif isinstance(spark_type, DoubleType):
return _encode_protobuf_key(field_number, _PROTOBUF_WIRE_64BIT) + struct.pack("<d", float(value))
elif isinstance(spark_type, StringType):
b = value.encode("utf-8")
return (_encode_protobuf_key(field_number, _PROTOBUF_WIRE_LEN_DELIM) +
_encode_protobuf_uvarint(len(b)) + b)
else:
raise ValueError("Unsupported type for protobuf simple generator: {}".format(spark_type))


class ProtobufSimpleMessageRowGen(DataGen):
"""
Generates rows that include:
- one column per message field (Spark scalar types)
- a binary column containing those same fields serialized as a single protobuf message

This is intentionally limited to the simple scalar types currently supported:
boolean/int32/int64/float/double/string.

Fields are omitted from the encoded message if the corresponding value is None.
"""
def __init__(self, fields, binary_col_name="bin", nullable=False):
"""
fields: list of (field_name, field_number, DataGen)
"""
self._fields = fields
self._binary_col_name = binary_col_name

struct_fields = []
for (name, _num, gen) in fields:
struct_fields.append(StructField(name, gen.data_type, nullable=gen.nullable))
struct_fields.append(StructField(binary_col_name, BinaryType(), nullable=True))
super().__init__(StructType(struct_fields), nullable=nullable)

def __repr__(self):
return "ProtobufSimpleMessageRowGen({})".format(
",".join(["{}#{}".format(n, num) for (n, num, _g) in self._fields]))

def _cache_repr(self):
kids = ",".join(["{}:{}#{}".format(n, str(g.data_type), num) for (n, num, g) in self._fields])
return super()._cache_repr() + "(" + kids + "," + self._binary_col_name + ")"

def __eq__(self, other):
if not isinstance(other, ProtobufSimpleMessageRowGen):
return False
if len(self._fields) != len(other._fields):
return False
for (n1, num1, g1), (n2, num2, g2) in zip(self._fields, other._fields):
if n1 != n2 or num1 != num2 or g1.data_type != g2.data_type:
return False
return (self._binary_col_name == other._binary_col_name and
self.nullable == other.nullable)

def __hash__(self):
field_tuple = tuple((n, num, str(g.data_type)) for (n, num, g) in self._fields)
return hash((field_tuple, self._binary_col_name, self.nullable))

def start(self, rand):
for (_name, _num, gen) in self._fields:
gen.start(rand)

def make_row():
values = []
encoded_parts = []
for (name, num, gen) in self._fields:
v = gen.gen()
values.append(v)
encoded_parts.append(_encode_protobuf_field(num, gen.data_type, v))
msg = b"".join(encoded_parts)
return tuple(values + [msg])

self._start(rand, make_row)
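
# A rough usage sketch (not wired into any test): this generator is meant to be paired
# with from_protobuf() from pyspark.sql.protobuf.functions, available in Spark 3.4.0+.
# The message name, descriptor file path, and helper name below are hypothetical
# placeholders that a real test would need to supply.
def _example_from_protobuf_round_trip(spark, desc_file_path, message_name="SimpleMessage"):
    from pyspark.sql.protobuf.functions import from_protobuf
    gen = ProtobufSimpleMessageRowGen([
        ("a", 1, IntegerGen()),
        ("b", 2, StringGen())])
    # the generated 'bin' column holds each row's fields encoded as a protobuf message
    df = gen_df(spark, gen)
    # parse the binary column back into a struct using the compiled descriptor file
    return df.select(from_protobuf("bin", message_name, desc_file_path).alias("msg"))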

# Note: Current (2023/06/06) maximum IT data size is 7282688 bytes, so an LRU cache with maxsize 128
# will lead to 7282688 * 128 = 932 MB additional memory usage in edge case, which is acceptable.
@lru_cache(maxsize=128, typed=True)