diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml
index e3d91be0ce3..f178e84ffba 100644
--- a/integration_tests/pom.xml
+++ b/integration_tests/pom.xml
@@ -1,6 +1,6 @@
+
+ true
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala
new file mode 100644
index 00000000000..cabc8d2905d
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.protobuf
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import com.google.protobuf.DescriptorProtos
+import com.google.protobuf.Descriptors
+
+/**
+ * Minimal descriptor utilities for locating a message descriptor in a FileDescriptorSet.
+ *
+ * This is intentionally lightweight for the "simple types" from_protobuf patch: it supports
+ * descriptor sets produced by `protoc --include_imports --descriptor_set_out=...`.
+ *
+ * NOTE: This utility is currently not used in the initial implementation, which relies on
+ * Spark's ProtobufUtils via reflection (buildMessageDescriptorWithSparkProtobuf). This class
+ * is preserved for potential future use cases where direct descriptor parsing is needed
+ * without depending on Spark's shaded protobuf classes.
+ */
+object ProtobufDescriptorUtils {
+
+ def buildMessageDescriptor(
+ fileDescriptorSetBytes: Array[Byte],
+ messageName: String): Descriptors.Descriptor = {
+ val fds = DescriptorProtos.FileDescriptorSet.parseFrom(fileDescriptorSetBytes)
+ val protos = fds.getFileList.asScala.toSeq
+ val byName = protos.map(p => p.getName -> p).toMap
+ val cache = mutable.HashMap.empty[String, Descriptors.FileDescriptor]
+
+ def buildFileDescriptor(name: String): Descriptors.FileDescriptor = {
+ cache.getOrElseUpdate(name, {
+ val p = byName.getOrElse(name,
+ throw new IllegalArgumentException(s"Missing FileDescriptorProto for '$name'"))
+ val deps = p.getDependencyList.asScala.map(buildFileDescriptor _).toArray
+ Descriptors.FileDescriptor.buildFrom(p, deps)
+ })
+ }
+
+ val fileDescriptors = protos.map(p => buildFileDescriptor(p.getName))
+    val candidates = fileDescriptors.iterator
+      .flatMap(fd => findMessageDescriptors(fd, messageName))
+      .toSeq
+
+ candidates match {
+ case Seq(d) => d
+ case Seq() =>
+ throw new IllegalArgumentException(
+ s"Message '$messageName' not found in FileDescriptorSet")
+ case many =>
+ val names = many.map(_.getFullName).distinct.sorted
+ throw new IllegalArgumentException(
+ s"Message '$messageName' is ambiguous; matches: ${names.mkString(", ")}")
+ }
+ }
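+
+  // Minimal usage sketch (hypothetical; the plugin itself currently resolves descriptors
+  // through Spark's ProtobufUtils via reflection). Assumes a descriptor set generated with
+  //   protoc --include_imports --descriptor_set_out=person.desc person.proto
+  // and a message named "Person":
+  //
+  //   val bytes = java.nio.file.Files.readAllBytes(java.nio.file.Paths.get("person.desc"))
+  //   val personDesc = ProtobufDescriptorUtils.buildMessageDescriptor(bytes, "Person")
+  //   personDesc.getFields.asScala.foreach(f => println(s"${f.getNumber}: ${f.getName}"))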
+
+ private def findMessageDescriptors(
+ fd: Descriptors.FileDescriptor,
+ messageName: String): Iterator[Descriptors.Descriptor] = {
+ def matches(d: Descriptors.Descriptor): Boolean = {
+ d.getName == messageName ||
+ d.getFullName == messageName ||
+ d.getFullName.endsWith("." + messageName)
+ }
+
+ def walk(d: Descriptors.Descriptor): Iterator[Descriptors.Descriptor] = {
+ val nested = d.getNestedTypes.asScala.iterator.flatMap(walk _)
+ if (matches(d)) Iterator.single(d) ++ nested else nested
+ }
+
+ fd.getMessageTypes.asScala.iterator.flatMap(walk _)
+ }
+}
\ No newline at end of file
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala
new file mode 100644
index 00000000000..52e10feef7b
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala
@@ -0,0 +1,204 @@
+/*
+ * Copyright (c) 2025-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import ai.rapids.cudf
+import ai.rapids.cudf.{BinaryOp, CudfException, DType}
+import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression}
+import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.jni.Protobuf
+import com.nvidia.spark.rapids.shims.NullIntolerantShim
+
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+
+/**
+ * GPU implementation for Spark's `from_protobuf` decode path.
+ *
+ * This is designed to replace `org.apache.spark.sql.protobuf.ProtobufDataToCatalyst` when
+ * supported.
+ *
+ * @param fullSchema The complete output schema (must match the original expression's dataType)
+ * @param decodedFieldIndices Indices into fullSchema (in ascending order) for fields that
+ *                            will be decoded by the GPU. Fields not in this array are
+ *                            filled with null columns.
+ * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices)
+ * @param cudfTypeIds cuDF type IDs for decoded fields (parallel to decodedFieldIndices)
+ * @param cudfTypeScales Encoding selectors (ENC_DEFAULT, ENC_FIXED, or ENC_ZIGZAG) for the
+ *                       decoded fields (parallel to decodedFieldIndices)
+ * @param failOnErrors If true, throw exception on malformed data; if false, return null
+ */
+case class GpuFromProtobuf(
+ fullSchema: StructType,
+ decodedFieldIndices: Array[Int],
+ fieldNumbers: Array[Int],
+ cudfTypeIds: Array[Int],
+ cudfTypeScales: Array[Int],
+ failOnErrors: Boolean,
+ child: Expression)
+ extends GpuUnaryExpression with ExpectsInputTypes with NullIntolerantShim {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+
+ override def dataType: DataType = fullSchema.asNullable
+
+ override def nullable: Boolean = true
+
+ override protected def doColumnar(input: GpuColumnVector): cudf.ColumnVector = {
+ val numRows = input.getRowCount.toInt
+
+ // Decode only the requested fields from protobuf
+ val decoded = try {
+ Protobuf.decodeToStruct(
+ input.getBase,
+ fieldNumbers,
+ cudfTypeIds,
+ cudfTypeScales,
+ failOnErrors)
+ } catch {
+ case e: CudfException if failOnErrors =>
+ // Convert CudfException to Spark's standard protobuf error for consistent error handling.
+ // This allows user code to catch the same exception type regardless of CPU/GPU execution.
+ throw QueryExecutionErrors.malformedProtobufMessageDetectedInMessageParsingError(e)
+ }
+
+ // Build the full struct with all fields from fullSchema
+ // Decoded fields come from the GPU result, others are null columns
+ val result = withResource(decoded) { decodedStruct =>
+ val fullChildren = new Array[cudf.ColumnVector](fullSchema.fields.length)
+ var decodedIdx = 0
+
+ try {
+ for (i <- fullSchema.fields.indices) {
+ if (decodedIdx < decodedFieldIndices.length && decodedFieldIndices(decodedIdx) == i) {
+            // This field was decoded: copy it out of the decoded struct, closing the
+            // intermediate child view once it has been materialized as its own column
+            fullChildren(i) = withResource(decodedStruct.getChildColumnView(decodedIdx)) { v =>
+              v.copyToColumnVector()
+            }
+ decodedIdx += 1
+ } else {
+ // This field was not decoded - create null column
+ fullChildren(i) = GpuFromProtobuf.createNullColumn(
+ fullSchema.fields(i).dataType, numRows)
+ }
+ }
+ // cuDF's makeStruct increments the reference count of child columns, so the struct
+ // owns its own references. We must close our original references in the finally block
+ // regardless of whether makeStruct succeeds or fails.
+ cudf.ColumnVector.makeStruct(numRows, fullChildren: _*)
+ } finally {
+ // Safe to close: if loop failed mid-way, only non-null entries are closed.
+ // If makeStruct succeeded, struct has its own refs; if it failed, we clean up.
+ fullChildren.foreach(c => if (c != null) c.close())
+ }
+ }
+
+ // Apply input nulls to output
+ if (input.getBase.hasNulls) {
+ withResource(result) { _ =>
+ result.mergeAndSetValidity(BinaryOp.BITWISE_AND, input.getBase)
+ }
+ } else {
+ result
+ }
+ }
+}
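+
+// Illustrative (hypothetical) construction showing how the parallel arrays line up. Assume a
+// three-field output schema where only fields 0 and 2 ("id": int32 with field number 1 and
+// "name": string with field number 3) are decoded, and "score" is filled with nulls:
+//
+//   GpuFromProtobuf(
+//     fullSchema = StructType(Seq(
+//       StructField("id", IntegerType),
+//       StructField("score", DoubleType),
+//       StructField("name", StringType))),
+//     decodedFieldIndices = Array(0, 2),
+//     fieldNumbers = Array(1, 3),
+//     cudfTypeIds = Array(
+//       GpuFromProtobuf.sparkTypeToCudfId(IntegerType),
+//       GpuFromProtobuf.sparkTypeToCudfId(StringType)),
+//     cudfTypeScales = Array(GpuFromProtobuf.ENC_DEFAULT, GpuFromProtobuf.ENC_DEFAULT),
+//     failOnErrors = false,
+//     child = binaryChild)  // binaryChild: a placeholder BinaryType expression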
+
+object GpuFromProtobuf {
+ // Encodings from com.nvidia.spark.rapids.jni.Protobuf
+ val ENC_DEFAULT = 0
+ val ENC_FIXED = 1
+ val ENC_ZIGZAG = 2
+
+ /**
+ * Maps a Spark DataType to the corresponding cuDF native type ID.
+ * Note: The encoding (varint/zigzag/fixed) is determined by the protobuf field type,
+ * not the Spark data type, so it must be set separately based on the protobuf schema.
+ */
+ def sparkTypeToCudfId(dt: DataType): Int = dt match {
+ case BooleanType => DType.BOOL8.getTypeId.getNativeId
+ case IntegerType => DType.INT32.getTypeId.getNativeId
+ case LongType => DType.INT64.getTypeId.getNativeId
+ case FloatType => DType.FLOAT32.getTypeId.getNativeId
+ case DoubleType => DType.FLOAT64.getTypeId.getNativeId
+ case StringType => DType.STRING.getTypeId.getNativeId
+ case BinaryType => DType.LIST.getTypeId.getNativeId
+ case other =>
+ throw new IllegalArgumentException(s"Unsupported Spark type for protobuf: $other")
+ }
+
+ /**
+ * Creates a null column of the specified Spark data type with the given number of rows.
+ * Used for fields that are not decoded (schema projection optimization).
+ */
+ def createNullColumn(dataType: DataType, numRows: Int): cudf.ColumnVector = {
+ val cudfType = dataType match {
+ case BooleanType => DType.BOOL8
+ case IntegerType => DType.INT32
+ case LongType => DType.INT64
+ case FloatType => DType.FLOAT32
+ case DoubleType => DType.FLOAT64
+ case StringType => DType.STRING
+ case BinaryType =>
+ // Binary is LIST in cuDF
+        return withResource(cudf.Scalar.listFromNull(
+          new cudf.HostColumnVector.BasicType(false, DType.INT8))) { nullScalar =>
+          cudf.ColumnVector.fromScalar(nullScalar, numRows)
+        }
+ case st: StructType =>
+ // For nested struct, create struct with null children and set all rows to null
+ val nullChildren = st.fields.map(f => createNullColumn(f.dataType, numRows))
+ return withResource(new AutoCloseableArray(nullChildren)) { _ =>
+ withResource(cudf.ColumnVector.makeStruct(numRows, nullChildren: _*)) { struct =>
+ // Create a validity mask of all nulls
+ withResource(cudf.Scalar.fromBool(false)) { falseBool =>
+ withResource(cudf.ColumnVector.fromScalar(falseBool, numRows)) { allFalse =>
+ struct.mergeAndSetValidity(BinaryOp.BITWISE_AND, allFalse)
+ }
+ }
+ }
+ }
+ case ArrayType(elementType, _) =>
+ val elementDType = elementType match {
+ case BooleanType => DType.BOOL8
+ case IntegerType => DType.INT32
+ case LongType => DType.INT64
+ case FloatType => DType.FLOAT32
+ case DoubleType => DType.FLOAT64
+ case StringType => DType.STRING
+ case _ => DType.INT8 // fallback
+ }
+        return withResource(cudf.Scalar.listFromNull(
+          new cudf.HostColumnVector.BasicType(false, elementDType))) { nullScalar =>
+          cudf.ColumnVector.fromScalar(nullScalar, numRows)
+        }
+ case _ =>
+        // Fallback for Spark types this helper does not model explicitly. It is only reached
+        // for fields that are not decoded, but the resulting INT8 column may not match the
+        // expected cuDF type for the field.
+ DType.INT8
+ }
+
+ withResource(cudf.Scalar.fromNull(cudfType)) { nullScalar =>
+ cudf.ColumnVector.fromScalar(nullScalar, numRows)
+ }
+ }
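+
+  // For example (hypothetical calls): createNullColumn(LongType, 4) yields an INT64 column of
+  // 4 null rows, while createNullColumn(BinaryType, 4) yields a LIST<INT8> column of 4 null
+  // rows, matching how cuDF models Spark's BinaryType.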
+
+ /** Helper class to auto-close an array of ColumnVectors */
+ private class AutoCloseableArray(cols: Array[cudf.ColumnVector]) extends AutoCloseable {
+ override def close(): Unit = cols.foreach(c => if (c != null) c.close())
+ }
+}
\ No newline at end of file
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala
new file mode 100644
index 00000000000..9d79de00ccf
--- /dev/null
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala
@@ -0,0 +1,530 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "340"}
+{"spark": "341"}
+{"spark": "342"}
+{"spark": "343"}
+{"spark": "344"}
+{"spark": "350"}
+{"spark": "351"}
+{"spark": "352"}
+{"spark": "353"}
+{"spark": "354"}
+{"spark": "355"}
+{"spark": "356"}
+{"spark": "357"}
+{"spark": "400"}
+{"spark": "401"}
+spark-rapids-shim-json-lines ***/
+
+package com.nvidia.spark.rapids.shims
+
+import java.nio.file.{Files, Path}
+
+import scala.collection.mutable
+import scala.util.Try
+
+import com.nvidia.spark.rapids._
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GetStructField, UnaryExpression}
+import org.apache.spark.sql.execution.ProjectExec
+import org.apache.spark.sql.rapids.GpuFromProtobuf
+import org.apache.spark.sql.types._
+
+/**
+ * Information about a protobuf field for schema projection support.
+ */
+private[shims] case class ProtobufFieldInfo(
+ fieldNumber: Int,
+ protoTypeName: String,
+ sparkType: DataType,
+ encoding: Int,
+ isSupported: Boolean,
+ unsupportedReason: Option[String]
+)
+
+/**
+ * Spark 3.4+ optional integration for spark-protobuf expressions.
+ *
+ * spark-protobuf is an external module, so these rules must be registered by reflection.
+ */
+object ProtobufExprShims {
+ private[this] val protobufDataToCatalystClassName =
+ "org.apache.spark.sql.protobuf.ProtobufDataToCatalyst"
+
+ private[this] val sparkProtobufUtilsObjectClassName =
+ "org.apache.spark.sql.protobuf.utils.ProtobufUtils$"
+
+ def exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
+ try {
+ val clazz = ShimReflectionUtils.loadClass(protobufDataToCatalystClassName)
+ .asInstanceOf[Class[_ <: UnaryExpression]]
+ Map(clazz.asInstanceOf[Class[_ <: Expression]] -> fromProtobufRule)
+ } catch {
+ case _: ClassNotFoundException => Map.empty
+ }
+ }
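+
+  // The query shape this rule targets, as a hedged sketch (from_protobuf comes from
+  // org.apache.spark.sql.protobuf.functions; the message name and descriptor path are
+  // hypothetical):
+  //
+  //   import org.apache.spark.sql.protobuf.functions.from_protobuf
+  //   df.select(from_protobuf($"payload", "Person", "/path/to/person.desc").as("msg"))
+  //     .select($"msg.id", $"msg.name")
+  //
+  // The expression is only replaced when every field the plan actually reads maps to a
+  // supported simple type; otherwise tagExprForGpu keeps it on the CPU.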
+
+ private def fromProtobufRule: ExprRule[_ <: Expression] = {
+ GpuOverrides.expr[UnaryExpression](
+ "Decode a BinaryType column (protobuf) into a Spark SQL struct",
+ ExprChecks.unaryProject(
+ // Use TypeSig.all here because schema projection determines which fields
+ // actually need GPU support. Detailed type checking is done in tagExprForGpu.
+ TypeSig.all,
+ TypeSig.all,
+ TypeSig.BINARY,
+ TypeSig.BINARY),
+ (e, conf, p, r) => new UnaryExprMeta[UnaryExpression](e, conf, p, r) {
+
+ // Full schema from the expression (must match original dataType for compatibility)
+ private var fullSchema: StructType = _
+ // Indices into fullSchema for fields that will be decoded by GPU
+ private var decodedFieldIndices: Array[Int] = _
+ private var fieldNumbers: Array[Int] = _
+ private var cudfTypeIds: Array[Int] = _
+ private var cudfTypeScales: Array[Int] = _
+ private var failOnErrors: Boolean = _
+
+ override def tagExprForGpu(): Unit = {
+ fullSchema = e.dataType match {
+ case st: StructType => st
+ case other =>
+ willNotWorkOnGpu(
+ s"Only StructType output is supported for from_protobuf, got $other")
+ return
+ }
+
+ val options = getOptionsMap(e)
+ val supportedOptions = Set("enums.as.ints", "mode")
+ val unsupportedOptions = options.keys.filterNot(supportedOptions.contains)
+ if (unsupportedOptions.nonEmpty) {
+ val keys = unsupportedOptions.mkString(",")
+ willNotWorkOnGpu(
+ s"from_protobuf options are not supported yet on GPU: $keys")
+ return
+ }
+
+ val enumsAsInts = options.getOrElse("enums.as.ints", "false").toBoolean
+ failOnErrors = options.getOrElse("mode", "PERMISSIVE").equalsIgnoreCase("FAILFAST")
+ val messageName = getMessageName(e)
+ val descFilePathOpt = getDescFilePath(e).orElse {
+ // Newer Spark may embed a descriptor set (binaryDescriptorSet). Write it to a temp file
+ // so we can reuse Spark's ProtobufUtils (and its shaded protobuf classes) to resolve
+ // the descriptor.
+ getDescriptorBytes(e).map(writeTempDescFile)
+ }
+ if (descFilePathOpt.isEmpty) {
+ willNotWorkOnGpu(
+ "from_protobuf requires a descriptor set " +
+ "(descFilePath or binaryDescriptorSet)")
+ return
+ }
+
+ val msgDesc = try {
+ // Spark 3.4.x builds the descriptor as:
+ // ProtobufUtils.buildDescriptor(messageName, descFilePathOpt)
+ buildMessageDescriptorWithSparkProtobuf(messageName, descFilePathOpt)
+ } catch {
+ case t: Throwable =>
+ willNotWorkOnGpu(
+ s"Failed to resolve protobuf descriptor for message '$messageName': " +
+ s"${t.getMessage}")
+ return
+ }
+
+ // Step 1: Analyze all fields and build field info map
+ val allFieldsInfo = analyzeAllFields(fullSchema, msgDesc, enumsAsInts, messageName)
+ if (allFieldsInfo.isEmpty) {
+ // Error was already reported in analyzeAllFields
+ return
+ }
+ val fieldsInfoMap = allFieldsInfo.get
+
+ // Step 2: Determine which fields are actually required by downstream operations
+ val requiredFieldNames = analyzeRequiredFields(fieldsInfoMap.keySet)
+
+ // Step 3: Check if all required fields are supported
+ val unsupportedRequired = requiredFieldNames.filter { name =>
+ fieldsInfoMap.get(name).exists(!_.isSupported)
+ }
+
+ if (unsupportedRequired.nonEmpty) {
+ val reasons = unsupportedRequired.map { name =>
+ val info = fieldsInfoMap(name)
+ s"${name}: ${info.unsupportedReason.getOrElse("unknown reason")}"
+ }
+ willNotWorkOnGpu(
+ s"Required fields not supported for from_protobuf: ${reasons.mkString(", ")}")
+ return
+ }
+
+ // Step 4: Identify which fields in fullSchema need to be decoded
+ // These are fields that are required AND supported
+ val indicesToDecode = fullSchema.fields.zipWithIndex.collect {
+ case (sf, idx) if requiredFieldNames.contains(sf.name) => idx
+ }
+ decodedFieldIndices = indicesToDecode
+
+ // Step 5: Build arrays for the fields to decode (parallel to decodedFieldIndices)
+ val fnums = new Array[Int](indicesToDecode.length)
+ val typeIds = new Array[Int](indicesToDecode.length)
+ val scales = new Array[Int](indicesToDecode.length)
+
+ indicesToDecode.zipWithIndex.foreach { case (schemaIdx, arrIdx) =>
+ val sf = fullSchema.fields(schemaIdx)
+ val info = fieldsInfoMap(sf.name)
+ fnums(arrIdx) = info.fieldNumber
+ typeIds(arrIdx) = GpuFromProtobuf.sparkTypeToCudfId(sf.dataType)
+ scales(arrIdx) = info.encoding
+ }
+
+ fieldNumbers = fnums
+ cudfTypeIds = typeIds
+ cudfTypeScales = scales
+ }
+
+ /**
+ * Analyze all fields in the schema and build a map of field name to ProtobufFieldInfo.
+ * Returns None if there's an error that should abort processing.
+ */
+ private def analyzeAllFields(
+ schema: StructType,
+ msgDesc: AnyRef,
+ enumsAsInts: Boolean,
+ messageName: String): Option[Map[String, ProtobufFieldInfo]] = {
+ val result = mutable.Map[String, ProtobufFieldInfo]()
+
+ for (sf <- schema.fields) {
+ val fd = invoke1[AnyRef](msgDesc, "findFieldByName", classOf[String], sf.name)
+ if (fd == null) {
+ willNotWorkOnGpu(
+ s"Protobuf field '${sf.name}' not found in message '$messageName'")
+ return None
+ }
+
+ val isRepeated = Try {
+ invoke0[java.lang.Boolean](fd, "isRepeated").booleanValue()
+ }.getOrElse(false)
+
+ val protoType = invoke0[AnyRef](fd, "getType")
+ val protoTypeName = typeName(protoType)
+ val fieldNumber = invoke0[java.lang.Integer](fd, "getNumber").intValue()
+
+ // Check field support and determine encoding
+ val (isSupported, unsupportedReason, encoding) =
+ checkFieldSupport(sf.dataType, protoTypeName, isRepeated, enumsAsInts)
+
+ result(sf.name) = ProtobufFieldInfo(
+ fieldNumber = fieldNumber,
+ protoTypeName = protoTypeName,
+ sparkType = sf.dataType,
+ encoding = encoding,
+ isSupported = isSupported,
+ unsupportedReason = unsupportedReason
+ )
+ }
+
+ Some(result.toMap)
+ }
+
+ /**
+ * Check if a field type is supported and return encoding information.
+ * @return (isSupported, unsupportedReason, encoding)
+ */
+ private def checkFieldSupport(
+ sparkType: DataType,
+ protoTypeName: String,
+ isRepeated: Boolean,
+ enumsAsInts: Boolean): (Boolean, Option[String], Int) = {
+
+ if (isRepeated) {
+ return (false, Some("repeated fields are not supported"), GpuFromProtobuf.ENC_DEFAULT)
+ }
+
+ // Check Spark type is one of the supported simple types
+ sparkType match {
+ case BooleanType | IntegerType | LongType | FloatType | DoubleType |
+ StringType | BinaryType =>
+ // Supported Spark type, continue to check encoding
+ case other =>
+ return (false, Some(s"unsupported Spark type: $other"), GpuFromProtobuf.ENC_DEFAULT)
+ }
+
+ // Determine encoding based on Spark type and proto type combination
+ val encoding = (sparkType, protoTypeName) match {
+ case (BooleanType, "BOOL") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (IntegerType, "INT32" | "UINT32") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (IntegerType, "SINT32") => Some(GpuFromProtobuf.ENC_ZIGZAG)
+ case (IntegerType, "FIXED32" | "SFIXED32") => Some(GpuFromProtobuf.ENC_FIXED)
+ case (LongType, "INT64" | "UINT64") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (LongType, "SINT64") => Some(GpuFromProtobuf.ENC_ZIGZAG)
+ case (LongType, "FIXED64" | "SFIXED64") => Some(GpuFromProtobuf.ENC_FIXED)
+ // Spark may upcast smaller integers to LongType
+ case (LongType, "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32") =>
+ val enc = protoTypeName match {
+ case "SINT32" => GpuFromProtobuf.ENC_ZIGZAG
+ case "FIXED32" | "SFIXED32" => GpuFromProtobuf.ENC_FIXED
+ case _ => GpuFromProtobuf.ENC_DEFAULT
+ }
+ Some(enc)
+ case (FloatType, "FLOAT") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (DoubleType, "DOUBLE") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (StringType, "STRING") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (BinaryType, "BYTES") => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case (IntegerType, "ENUM") if enumsAsInts => Some(GpuFromProtobuf.ENC_DEFAULT)
+ case _ => None
+ }
+
+ encoding match {
+ case Some(enc) => (true, None, enc)
+ case None =>
+ (false,
+ Some(s"type mismatch: Spark $sparkType vs Protobuf $protoTypeName"),
+ GpuFromProtobuf.ENC_DEFAULT)
+ }
+ }
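+
+      // Worked examples of the mapping above (hypothetical inputs, for clarity only):
+      //   checkFieldSupport(IntegerType, "SINT32", isRepeated = false, enumsAsInts = false)
+      //     returns (true, None, GpuFromProtobuf.ENC_ZIGZAG)
+      //   checkFieldSupport(LongType, "FIXED64", isRepeated = false, enumsAsInts = false)
+      //     returns (true, None, GpuFromProtobuf.ENC_FIXED)
+      //   checkFieldSupport(StringType, "INT32", isRepeated = false, enumsAsInts = false)
+      //     returns (false, Some("type mismatch: ..."), GpuFromProtobuf.ENC_DEFAULT)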
+
+ /**
+ * Analyze which fields are actually required by downstream operations.
+ * Currently supports analyzing parent Project expressions.
+ *
+ * @param allFieldNames All field names in the full schema
+ * @return Set of field names that are actually required
+ */
+ private def analyzeRequiredFields(allFieldNames: Set[String]): Set[String] = {
+ // Try to find parent SparkPlanMeta and analyze downstream Project
+ val parentPlanOpt = findParentPlanMeta()
+
+ parentPlanOpt match {
+ case Some(planMeta) =>
+ // First, try to analyze the immediate parent
+ analyzeDownstreamProject(planMeta) match {
+ case Some(fields) if fields.nonEmpty =>
+ // Successfully identified required fields via schema projection
+ fields
+ case _ =>
+ // The immediate parent might be a ProjectExec that just aliases the output.
+ // Try to look at its parent (the grandparent) for GetStructField references.
+ planMeta.parent match {
+ case Some(grandParentMeta: SparkPlanMeta[_]) =>
+ analyzeDownstreamProject(grandParentMeta) match {
+ case Some(fields) if fields.nonEmpty => fields
+ case _ => allFieldNames
+ }
+ case _ => allFieldNames
+ }
+ }
+ case None =>
+ // No parent SparkPlanMeta found in the meta tree, assume all fields are needed
+ allFieldNames
+ }
+ }
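+
+      // Schema-projection sketch (hypothetical plan): for
+      //   df.select(from_protobuf($"payload", "Person", descPath).as("msg"))
+      //     .select($"msg.id", $"msg.name")
+      // the downstream ProjectExec only extracts the "id" and "name" struct fields, so this
+      // returns Set("id", "name") and every other field of the message is materialized as a
+      // null column instead of being decoded.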
+
+ /**
+ * Find the parent SparkPlanMeta by traversing up the parent chain.
+ */
+ private def findParentPlanMeta(): Option[SparkPlanMeta[_]] = {
+ def traverse(meta: Option[RapidsMeta[_, _, _]]): Option[SparkPlanMeta[_]] = {
+ meta match {
+ case Some(p: SparkPlanMeta[_]) => Some(p)
+ case Some(p: RapidsMeta[_, _, _]) => traverse(p.parent)
+ case _ => None
+ }
+ }
+ traverse(parent)
+ }
+
+ /**
+ * Analyze a Project plan to find which struct fields are actually used.
+ * This looks for GetStructField expressions that reference our protobuf output.
+ */
+ private def analyzeDownstreamProject(planMeta: SparkPlanMeta[_]): Option[Set[String]] = {
+ planMeta.wrapped match {
+ case p: ProjectExec =>
+ // Collect all GetStructField references from the project list
+ val fieldRefs = mutable.Set[String]()
+ var hasDirectStructRef = false
+
+ p.projectList.foreach { expr =>
+ collectStructFieldReferences(expr, fieldRefs, hasDirectStructRefHolder = () => {
+ hasDirectStructRef = true
+ })
+ }
+
+ if (hasDirectStructRef) {
+ // If the entire struct is referenced directly (not via GetStructField),
+ // we need all fields
+ None
+ } else if (fieldRefs.nonEmpty) {
+ Some(fieldRefs.toSet)
+ } else {
+              // No GetStructField references found in this project, so we cannot prove a
+              // narrower projection; the caller will fall back to decoding all fields.
+ None
+ }
+ case _ =>
+ // Not a ProjectExec, cannot analyze schema projection
+ None
+ }
+ }
+
+ /**
+ * Recursively collect field names from GetStructField expressions.
+ * Also tracks if the struct is used directly without field extraction.
+ */
+ private def collectStructFieldReferences(
+ expr: Expression,
+ fieldRefs: mutable.Set[String],
+ hasDirectStructRefHolder: () => Unit): Unit = {
+ expr match {
+ case GetStructField(child, ordinal, nameOpt) =>
+ // Check if this GetStructField extracts from our protobuf struct
+ if (isProtobufStructReference(child)) {
+ // Get field name from the schema using ordinal
+ val fieldName = nameOpt.getOrElse {
+ if (ordinal < fullSchema.fields.length) {
+ fullSchema.fields(ordinal).name
+ } else {
+ s"_$ordinal"
+ }
+ }
+ fieldRefs += fieldName
+ // Don't recurse into child - we've handled this protobuf reference
+ } else {
+ // Child is not a protobuf struct, recurse to check for nested access
+ collectStructFieldReferences(child, fieldRefs, hasDirectStructRefHolder)
+ }
+
+ case _ =>
+ // Check if this expression directly references our protobuf struct
+ // without extracting a field (e.g., passing the whole struct to a function)
+ if (isProtobufStructReference(expr)) {
+ hasDirectStructRefHolder()
+ }
+ // Recursively check children
+ expr.children.foreach { child =>
+ collectStructFieldReferences(child, fieldRefs, hasDirectStructRefHolder)
+ }
+ }
+ }
+
+ /**
+ * Check if an expression references the output of a protobuf decode expression.
+ * This can be either:
+ * 1. The ProtobufDataToCatalyst expression itself
+ * 2. An AttributeReference that references the output of ProtobufDataToCatalyst
+ * (when accessing from a downstream ProjectExec)
+ */
+ private def isProtobufStructReference(expr: Expression): Boolean = {
+ // Check if expr is a ProtobufDataToCatalyst expression
+ if (expr.getClass.getName.contains("ProtobufDataToCatalyst")) {
+ return true
+ }
+
+ // Check if expr is an AttributeReference with the same schema as our protobuf output
+ // This handles the case where GetStructField references a column from a parent Project
+ expr match {
+ case attr: AttributeReference =>
+ // Check if the data type matches our full schema (struct type from protobuf)
+ attr.dataType match {
+ case st: StructType =>
+ // Compare field names and types only. We intentionally do not compare
+ // nullable flags because schema transformations (like projections or
+ // certain optimizations) may change nullability while the underlying
+ // schema structure remains the same. For schema projection detection,
+ // matching names and types is sufficient to identify protobuf output.
+ st.fields.length == fullSchema.fields.length &&
+ st.fields.zip(fullSchema.fields).forall { case (a, b) =>
+ a.name == b.name && a.dataType == b.dataType
+ }
+ case _ => false
+ }
+ case _ => false
+ }
+ }
+
+ override def convertToGpu(child: Expression): GpuExpression = {
+ GpuFromProtobuf(
+ fullSchema, decodedFieldIndices, fieldNumbers, cudfTypeIds, cudfTypeScales,
+ failOnErrors, child)
+ }
+ }
+ )
+ }
+
+ private def getMessageName(e: Expression): String =
+ invoke0[String](e, "messageName")
+
+ /**
+ * Newer Spark versions may carry an in-expression descriptor set payload
+ * (e.g. binaryDescriptorSet).
+ * Spark 3.4.x does not, so callers should fall back to descFilePath().
+ */
+ private def getDescriptorBytes(e: Expression): Option[Array[Byte]] = {
+ // Spark 4.x/3.5+ (depending on the API): may be Array[Byte] or Option[Array[Byte]].
+ val direct = Try(invoke0[Array[Byte]](e, "binaryDescriptorSet")).toOption
+ direct.orElse {
+ Try(invoke0[Option[Array[Byte]]](e, "binaryDescriptorSet")).toOption.flatten
+ }
+ }
+
+ private def getDescFilePath(e: Expression): Option[String] =
+ Try(invoke0[Option[String]](e, "descFilePath")).toOption.flatten
+
+ private def writeTempDescFile(descBytes: Array[Byte]): String = {
+ val tmp: Path = Files.createTempFile("spark-rapids-protobuf-desc-", ".desc")
+ Files.write(tmp, descBytes)
+ // deleteOnExit() is not guaranteed to run on abnormal JVM termination, but these
+ // descriptor files are small (typically < 10KB) and only created when using
+ // binaryDescriptorSet (Spark 4.0+). The risk of temporary file accumulation is
+ // acceptable for this use case.
+ tmp.toFile.deleteOnExit()
+ tmp.toString
+ }
+
+ private def buildMessageDescriptorWithSparkProtobuf(
+ messageName: String,
+ descFilePathOpt: Option[String]): AnyRef = {
+ val cls = ShimReflectionUtils.loadClass(sparkProtobufUtilsObjectClassName)
+ val module = cls.getField("MODULE$").get(null)
+ // buildDescriptor(messageName: String, descFilePath: Option[String])
+ val m = cls.getMethod("buildDescriptor", classOf[String], classOf[scala.Option[_]])
+ m.invoke(module, messageName, descFilePathOpt).asInstanceOf[AnyRef]
+ }
+
+ private def typeName(t: AnyRef): String = {
+ if (t == null) {
+ "null"
+ } else {
+ // Prefer Enum.name() when available; fall back to toString.
+ Try(invoke0[String](t, "name")).getOrElse(t.toString)
+ }
+ }
+
+ private def getOptionsMap(e: Expression): Map[String, String] = {
+ val opt = Try(invoke0[scala.collection.Map[String, String]](e, "options")).toOption
+ opt.map(_.toMap).getOrElse(Map.empty)
+ }
+
+ private def invoke0[T](obj: AnyRef, method: String): T =
+ obj.getClass.getMethod(method).invoke(obj).asInstanceOf[T]
+
+ private def invoke1[T](obj: AnyRef, method: String, arg0Cls: Class[_], arg0: AnyRef): T =
+ obj.getClass.getMethod(method, arg0Cls).invoke(obj, arg0).asInstanceOf[T]
+}
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
index 6e28a071a00..56bfa229051 100644
--- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2025, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -162,7 +162,7 @@ trait Spark340PlusNonDBShims extends Spark331PlusNonDBShims {
),
GpuElementAtMeta.elementAtRule(true)
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
- super.getExprs ++ shimExprs
+ super.getExprs ++ shimExprs ++ ProtobufExprShims.exprs
}
override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],