diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml
index e3d91be0ce3..f178e84ffba 100644
--- a/integration_tests/pom.xml
+++ b/integration_tests/pom.xml
@@ -1,6 +1,6 @@
+
+    true
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala
new file mode 100644
index 00000000000..cabc8d2905d
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.protobuf
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import com.google.protobuf.DescriptorProtos
+import com.google.protobuf.Descriptors
+
+/**
+ * Minimal descriptor utilities for locating a message descriptor in a FileDescriptorSet.
+ *
+ * This is intentionally lightweight for the "simple types" from_protobuf patch: it supports
+ * descriptor sets produced by `protoc --include_imports --descriptor_set_out=...`.
+ *
+ * NOTE: This utility is currently not used in the initial implementation, which relies on
+ * Spark's ProtobufUtils via reflection (buildMessageDescriptorWithSparkProtobuf). This class
+ * is preserved for potential future use cases where direct descriptor parsing is needed
+ * without depending on Spark's shaded protobuf classes.
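+ *
+ * Usage sketch (illustrative only, not part of the original patch; `events.desc` and
+ * `com.example.Event` are hypothetical names for a descriptor-set file and a message in it):
+ * {{{
+ *   import java.nio.file.{Files, Paths}
+ *   // Bytes of a file produced by: protoc --include_imports --descriptor_set_out=events.desc
+ *   val descBytes = Files.readAllBytes(Paths.get("events.desc"))
+ *   val desc = ProtobufDescriptorUtils.buildMessageDescriptor(descBytes, "com.example.Event")
+ *   // Print field numbers and names of the resolved message
+ *   desc.getFields.asScala.foreach(f => println(s"${f.getNumber}: ${f.getName}"))
+ * }}}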
+ */
+object ProtobufDescriptorUtils {
+
+  def buildMessageDescriptor(
+      fileDescriptorSetBytes: Array[Byte],
+      messageName: String): Descriptors.Descriptor = {
+    val fds = DescriptorProtos.FileDescriptorSet.parseFrom(fileDescriptorSetBytes)
+    val protos = fds.getFileList.asScala.toSeq
+    val byName = protos.map(p => p.getName -> p).toMap
+    val cache = mutable.HashMap.empty[String, Descriptors.FileDescriptor]
+
+    def buildFileDescriptor(name: String): Descriptors.FileDescriptor = {
+      cache.getOrElseUpdate(name, {
+        val p = byName.getOrElse(name,
+          throw new IllegalArgumentException(s"Missing FileDescriptorProto for '$name'"))
+        val deps = p.getDependencyList.asScala.map(buildFileDescriptor _).toArray
+        Descriptors.FileDescriptor.buildFrom(p, deps)
+      })
+    }
+
+    val fileDescriptors = protos.map(p => buildFileDescriptor(p.getName))
+    val candidates = fileDescriptors.iterator.flatMap(fd => findMessageDescriptors(fd, messageName))
+      .toSeq
+
+    candidates match {
+      case Seq(d) => d
+      case Seq() =>
+        throw new IllegalArgumentException(
+          s"Message '$messageName' not found in FileDescriptorSet")
+      case many =>
+        val names = many.map(_.getFullName).distinct.sorted
+        throw new IllegalArgumentException(
+          s"Message '$messageName' is ambiguous; matches: ${names.mkString(", ")}")
+    }
+  }
+
+  private def findMessageDescriptors(
+      fd: Descriptors.FileDescriptor,
+      messageName: String): Iterator[Descriptors.Descriptor] = {
+    def matches(d: Descriptors.Descriptor): Boolean = {
+      d.getName == messageName ||
+        d.getFullName == messageName ||
+        d.getFullName.endsWith("." + messageName)
+    }
+
+    def walk(d: Descriptors.Descriptor): Iterator[Descriptors.Descriptor] = {
+      val nested = d.getNestedTypes.asScala.iterator.flatMap(walk _)
+      if (matches(d)) Iterator.single(d) ++ nested else nested
+    }
+
+    fd.getMessageTypes.asScala.iterator.flatMap(walk _)
+  }
+}
\ No newline at end of file
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala
new file mode 100644
index 00000000000..52e10feef7b
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobuf.scala
@@ -0,0 +1,204 @@
+/*
+ * Copyright (c) 2025-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+import ai.rapids.cudf
+import ai.rapids.cudf.{BinaryOp, CudfException, DType}
+import com.nvidia.spark.rapids.{GpuColumnVector, GpuUnaryExpression}
+import com.nvidia.spark.rapids.Arm.withResource
+import com.nvidia.spark.rapids.jni.Protobuf
+import com.nvidia.spark.rapids.shims.NullIntolerantShim
+
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+
+/**
+ * GPU implementation for Spark's `from_protobuf` decode path.
+ *
+ * This is designed to replace `org.apache.spark.sql.protobuf.ProtobufDataToCatalyst` when
+ * the requested message, fields, and options are supported on the GPU.
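+ *
+ * Hypothetical example (not taken from the original patch) of how the parallel arrays
+ * line up: for fullSchema = struct<id: long, name: string, payload: binary> where only
+ * `id` and `name` are referenced downstream, decodedFieldIndices = Array(0, 1), and
+ * fieldNumbers, cudfTypeIds, and cudfTypeScales each carry two entries describing `id`
+ * and `name`; `payload` is emitted as an all-null column of the matching type.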
+ *
+ * @param fullSchema The complete output schema (must match the original expression's dataType)
+ * @param decodedFieldIndices Indices into fullSchema for fields that will be decoded by GPU.
+ *                            Fields not in this array will be null columns.
+ * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices)
+ * @param cudfTypeIds cuDF type IDs for decoded fields (parallel to decodedFieldIndices)
+ * @param cudfTypeScales Encodings for decoded fields (parallel to decodedFieldIndices)
+ * @param failOnErrors If true, throw an exception on malformed data; if false, return null
+ */
+case class GpuFromProtobuf(
+    fullSchema: StructType,
+    decodedFieldIndices: Array[Int],
+    fieldNumbers: Array[Int],
+    cudfTypeIds: Array[Int],
+    cudfTypeScales: Array[Int],
+    failOnErrors: Boolean,
+    child: Expression)
+  extends GpuUnaryExpression with ExpectsInputTypes with NullIntolerantShim {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+
+  override def dataType: DataType = fullSchema.asNullable
+
+  override def nullable: Boolean = true
+
+  override protected def doColumnar(input: GpuColumnVector): cudf.ColumnVector = {
+    val numRows = input.getRowCount.toInt
+
+    // Decode only the requested fields from protobuf
+    val decoded = try {
+      Protobuf.decodeToStruct(
+        input.getBase,
+        fieldNumbers,
+        cudfTypeIds,
+        cudfTypeScales,
+        failOnErrors)
+    } catch {
+      case e: CudfException if failOnErrors =>
+        // Convert CudfException to Spark's standard protobuf error for consistent error handling.
+        // This allows user code to catch the same exception type regardless of CPU/GPU execution.
+        throw QueryExecutionErrors.malformedProtobufMessageDetectedInMessageParsingError(e)
+    }
+
+    // Build the full struct with all fields from fullSchema.
+    // Decoded fields come from the GPU result, others are null columns.
+    val result = withResource(decoded) { decodedStruct =>
+      val fullChildren = new Array[cudf.ColumnVector](fullSchema.fields.length)
+      var decodedIdx = 0
+
+      try {
+        for (i <- fullSchema.fields.indices) {
+          if (decodedIdx < decodedFieldIndices.length && decodedFieldIndices(decodedIdx) == i) {
+            // This field was decoded - extract from decoded struct
+            fullChildren(i) = decodedStruct.getChildColumnView(decodedIdx).copyToColumnVector()
+            decodedIdx += 1
+          } else {
+            // This field was not decoded - create null column
+            fullChildren(i) = GpuFromProtobuf.createNullColumn(
+              fullSchema.fields(i).dataType, numRows)
+          }
+        }
+        // cuDF's makeStruct increments the reference count of child columns, so the struct
+        // owns its own references. We must close our original references in the finally block
+        // regardless of whether makeStruct succeeds or fails.
+        cudf.ColumnVector.makeStruct(numRows, fullChildren: _*)
+      } finally {
+        // Safe to close: if loop failed mid-way, only non-null entries are closed.
+        // If makeStruct succeeded, struct has its own refs; if it failed, we clean up.
+        fullChildren.foreach(c => if (c != null) c.close())
+      }
+    }
+
+    // Apply input nulls to output
+    if (input.getBase.hasNulls) {
+      withResource(result) { _ =>
+        result.mergeAndSetValidity(BinaryOp.BITWISE_AND, input.getBase)
+      }
+    } else {
+      result
+    }
+  }
+}
+
+object GpuFromProtobuf {
+  // Encodings from com.nvidia.spark.rapids.jni.Protobuf
+  val ENC_DEFAULT = 0
+  val ENC_FIXED = 1
+  val ENC_ZIGZAG = 2
+
+  /**
+   * Maps a Spark DataType to the corresponding cuDF native type ID.
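+   * For example, LongType maps to DType.INT64.getTypeId.getNativeId, and BinaryType maps to
+   * DType.LIST.getTypeId.getNativeId because binary values are represented as lists of bytes.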
+   * Note: The encoding (varint/zigzag/fixed) is determined by the protobuf field type,
+   * not the Spark data type, so it must be set separately based on the protobuf schema.
+   */
+  def sparkTypeToCudfId(dt: DataType): Int = dt match {
+    case BooleanType => DType.BOOL8.getTypeId.getNativeId
+    case IntegerType => DType.INT32.getTypeId.getNativeId
+    case LongType => DType.INT64.getTypeId.getNativeId
+    case FloatType => DType.FLOAT32.getTypeId.getNativeId
+    case DoubleType => DType.FLOAT64.getTypeId.getNativeId
+    case StringType => DType.STRING.getTypeId.getNativeId
+    case BinaryType => DType.LIST.getTypeId.getNativeId
+    case other =>
+      throw new IllegalArgumentException(s"Unsupported Spark type for protobuf: $other")
+  }
+
+  /**
+   * Creates a null column of the specified Spark data type with the given number of rows.
+   * Used for fields that are not decoded (schema projection optimization).
+   */
+  def createNullColumn(dataType: DataType, numRows: Int): cudf.ColumnVector = {
+    val cudfType = dataType match {
+      case BooleanType => DType.BOOL8
+      case IntegerType => DType.INT32
+      case LongType => DType.INT64
+      case FloatType => DType.FLOAT32
+      case DoubleType => DType.FLOAT64
+      case StringType => DType.STRING
+      case BinaryType =>
+        // Binary is LIST in cuDF
+        return withResource(cudf.Scalar.listFromNull(
+            new cudf.HostColumnVector.BasicType(false, DType.INT8))) { nullScalar =>
+          withResource(cudf.ColumnVector.fromScalar(nullScalar, numRows)) { col =>
+            col.incRefCount()
+          }
+        }
+      case st: StructType =>
+        // For nested struct, create struct with null children and set all rows to null
+        val nullChildren = st.fields.map(f => createNullColumn(f.dataType, numRows))
+        return withResource(new AutoCloseableArray(nullChildren)) { _ =>
+          withResource(cudf.ColumnVector.makeStruct(numRows, nullChildren: _*)) { struct =>
+            // Create a validity mask of all nulls
+            withResource(cudf.Scalar.fromBool(false)) { falseBool =>
+              withResource(cudf.ColumnVector.fromScalar(falseBool, numRows)) { allFalse =>
+                struct.mergeAndSetValidity(BinaryOp.BITWISE_AND, allFalse)
+              }
+            }
+          }
+        }
+      case ArrayType(elementType, _) =>
+        val elementDType = elementType match {
+          case BooleanType => DType.BOOL8
+          case IntegerType => DType.INT32
+          case LongType => DType.INT64
+          case FloatType => DType.FLOAT32
+          case DoubleType => DType.FLOAT64
+          case StringType => DType.STRING
+          case _ => DType.INT8 // fallback
+        }
+        return withResource(cudf.Scalar.listFromNull(
+            new cudf.HostColumnVector.BasicType(false, elementDType))) { nullScalar =>
+          withResource(cudf.ColumnVector.fromScalar(nullScalar, numRows)) { col =>
+            col.incRefCount()
+          }
+        }
+      case _ =>
+        // Fallback: use INT8 (this case should not be reached for supported types)
+        DType.INT8
+    }
+
+    withResource(cudf.Scalar.fromNull(cudfType)) { nullScalar =>
+      cudf.ColumnVector.fromScalar(nullScalar, numRows)
+    }
+  }
+
+  /** Helper class to auto-close an array of ColumnVectors */
+  private class AutoCloseableArray(cols: Array[cudf.ColumnVector]) extends AutoCloseable {
+    override def close(): Unit = cols.foreach(c => if (c != null) c.close())
+  }
+}
\ No newline at end of file
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala
new file mode 100644
index 00000000000..9d79de00ccf
--- /dev/null
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala
@@ -0,0 +1,530 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "340"}
+{"spark": "341"}
+{"spark": "342"}
+{"spark": "343"}
+{"spark": "344"}
+{"spark": "350"}
+{"spark": "351"}
+{"spark": "352"}
+{"spark": "353"}
+{"spark": "354"}
+{"spark": "355"}
+{"spark": "356"}
+{"spark": "357"}
+{"spark": "400"}
+{"spark": "401"}
+spark-rapids-shim-json-lines ***/
+
+package com.nvidia.spark.rapids.shims
+
+import java.nio.file.{Files, Path}
+
+import scala.collection.mutable
+import scala.util.Try
+
+import com.nvidia.spark.rapids._
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GetStructField, UnaryExpression}
+import org.apache.spark.sql.execution.ProjectExec
+import org.apache.spark.sql.rapids.GpuFromProtobuf
+import org.apache.spark.sql.types._
+
+/**
+ * Information about a protobuf field for schema projection support.
+ */
+private[shims] case class ProtobufFieldInfo(
+    fieldNumber: Int,
+    protoTypeName: String,
+    sparkType: DataType,
+    encoding: Int,
+    isSupported: Boolean,
+    unsupportedReason: Option[String]
+)
+
+/**
+ * Spark 3.4+ optional integration for spark-protobuf expressions.
+ *
+ * spark-protobuf is an external module, so these rules must be registered by reflection.
+ */
+object ProtobufExprShims {
+  private[this] val protobufDataToCatalystClassName =
+    "org.apache.spark.sql.protobuf.ProtobufDataToCatalyst"
+
+  private[this] val sparkProtobufUtilsObjectClassName =
+    "org.apache.spark.sql.protobuf.utils.ProtobufUtils$"
+
+  def exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
+    try {
+      val clazz = ShimReflectionUtils.loadClass(protobufDataToCatalystClassName)
+        .asInstanceOf[Class[_ <: UnaryExpression]]
+      Map(clazz.asInstanceOf[Class[_ <: Expression]] -> fromProtobufRule)
+    } catch {
+      case _: ClassNotFoundException => Map.empty
+    }
+  }
+
+  private def fromProtobufRule: ExprRule[_ <: Expression] = {
+    GpuOverrides.expr[UnaryExpression](
+      "Decode a BinaryType column (protobuf) into a Spark SQL struct",
+      ExprChecks.unaryProject(
+        // Use TypeSig.all here because schema projection determines which fields
+        // actually need GPU support. Detailed type checking is done in tagExprForGpu.
+        TypeSig.all,
+        TypeSig.all,
+        TypeSig.BINARY,
+        TypeSig.BINARY),
+      (e, conf, p, r) => new UnaryExprMeta[UnaryExpression](e, conf, p, r) {
+
+        // Full schema from the expression (must match original dataType for compatibility)
+        private var fullSchema: StructType = _
+        // Indices into fullSchema for fields that will be decoded by GPU
+        private var decodedFieldIndices: Array[Int] = _
+        private var fieldNumbers: Array[Int] = _
+        private var cudfTypeIds: Array[Int] = _
+        private var cudfTypeScales: Array[Int] = _
+        private var failOnErrors: Boolean = _
+
+        override def tagExprForGpu(): Unit = {
+          fullSchema = e.dataType match {
+            case st: StructType => st
+            case other =>
+              willNotWorkOnGpu(
+                s"Only StructType output is supported for from_protobuf, got $other")
+              return
+          }
+
+          val options = getOptionsMap(e)
+          val supportedOptions = Set("enums.as.ints", "mode")
+          val unsupportedOptions = options.keys.filterNot(supportedOptions.contains)
+          if (unsupportedOptions.nonEmpty) {
+            val keys = unsupportedOptions.mkString(",")
+            willNotWorkOnGpu(
+              s"from_protobuf options are not supported yet on GPU: $keys")
+            return
+          }
+
+          val enumsAsInts = options.getOrElse("enums.as.ints", "false").toBoolean
+          failOnErrors = options.getOrElse("mode", "PERMISSIVE").equalsIgnoreCase("FAILFAST")
+          val messageName = getMessageName(e)
+          val descFilePathOpt = getDescFilePath(e).orElse {
+            // Newer Spark may embed a descriptor set (binaryDescriptorSet). Write it to a temp file
+            // so we can reuse Spark's ProtobufUtils (and its shaded protobuf classes) to resolve
+            // the descriptor.
+            getDescriptorBytes(e).map(writeTempDescFile)
+          }
+          if (descFilePathOpt.isEmpty) {
+            willNotWorkOnGpu(
+              "from_protobuf requires a descriptor set " +
+                "(descFilePath or binaryDescriptorSet)")
+            return
+          }
+
+          val msgDesc = try {
+            // Spark 3.4.x builds the descriptor as:
+            //   ProtobufUtils.buildDescriptor(messageName, descFilePathOpt)
+            buildMessageDescriptorWithSparkProtobuf(messageName, descFilePathOpt)
+          } catch {
+            case t: Throwable =>
+              willNotWorkOnGpu(
+                s"Failed to resolve protobuf descriptor for message '$messageName': " +
+                  s"${t.getMessage}")
+              return
+          }
+
+          // Step 1: Analyze all fields and build field info map
+          val allFieldsInfo = analyzeAllFields(fullSchema, msgDesc, enumsAsInts, messageName)
+          if (allFieldsInfo.isEmpty) {
+            // Error was already reported in analyzeAllFields
+            return
+          }
+          val fieldsInfoMap = allFieldsInfo.get
+
+          // Step 2: Determine which fields are actually required by downstream operations
+          val requiredFieldNames = analyzeRequiredFields(fieldsInfoMap.keySet)
+
+          // Step 3: Check if all required fields are supported
+          val unsupportedRequired = requiredFieldNames.filter { name =>
+            fieldsInfoMap.get(name).exists(!_.isSupported)
+          }
+
+          if (unsupportedRequired.nonEmpty) {
+            val reasons = unsupportedRequired.map { name =>
+              val info = fieldsInfoMap(name)
+              s"${name}: ${info.unsupportedReason.getOrElse("unknown reason")}"
+            }
+            willNotWorkOnGpu(
+              s"Required fields not supported for from_protobuf: ${reasons.mkString(", ")}")
+            return
+          }
+
+          // Step 4: Identify which fields in fullSchema need to be decoded.
+          // These are fields that are required AND supported.
+          val indicesToDecode = fullSchema.fields.zipWithIndex.collect {
+            case (sf, idx) if requiredFieldNames.contains(sf.name) => idx
+          }
+          decodedFieldIndices = indicesToDecode
+
+          // Step 5: Build arrays for the fields to decode (parallel to decodedFieldIndices)
+          val fnums = new Array[Int](indicesToDecode.length)
+          val typeIds = new Array[Int](indicesToDecode.length)
+          val scales = new Array[Int](indicesToDecode.length)
+
+          indicesToDecode.zipWithIndex.foreach { case (schemaIdx, arrIdx) =>
+            val sf = fullSchema.fields(schemaIdx)
+            val info = fieldsInfoMap(sf.name)
+            fnums(arrIdx) = info.fieldNumber
+            typeIds(arrIdx) = GpuFromProtobuf.sparkTypeToCudfId(sf.dataType)
+            scales(arrIdx) = info.encoding
+          }
+
+          fieldNumbers = fnums
+          cudfTypeIds = typeIds
+          cudfTypeScales = scales
+        }
+
+        /**
+         * Analyze all fields in the schema and build a map of field name to ProtobufFieldInfo.
+         * Returns None if there's an error that should abort processing.
+         */
+        private def analyzeAllFields(
+            schema: StructType,
+            msgDesc: AnyRef,
+            enumsAsInts: Boolean,
+            messageName: String): Option[Map[String, ProtobufFieldInfo]] = {
+          val result = mutable.Map[String, ProtobufFieldInfo]()
+
+          for (sf <- schema.fields) {
+            val fd = invoke1[AnyRef](msgDesc, "findFieldByName", classOf[String], sf.name)
+            if (fd == null) {
+              willNotWorkOnGpu(
+                s"Protobuf field '${sf.name}' not found in message '$messageName'")
+              return None
+            }
+
+            val isRepeated = Try {
+              invoke0[java.lang.Boolean](fd, "isRepeated").booleanValue()
+            }.getOrElse(false)
+
+            val protoType = invoke0[AnyRef](fd, "getType")
+            val protoTypeName = typeName(protoType)
+            val fieldNumber = invoke0[java.lang.Integer](fd, "getNumber").intValue()
+
+            // Check field support and determine encoding
+            val (isSupported, unsupportedReason, encoding) =
+              checkFieldSupport(sf.dataType, protoTypeName, isRepeated, enumsAsInts)
+
+            result(sf.name) = ProtobufFieldInfo(
+              fieldNumber = fieldNumber,
+              protoTypeName = protoTypeName,
+              sparkType = sf.dataType,
+              encoding = encoding,
+              isSupported = isSupported,
+              unsupportedReason = unsupportedReason
+            )
+          }
+
+          Some(result.toMap)
+        }
+
+        /**
+         * Check if a field type is supported and return encoding information.
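+         * For example, (IntegerType, "SINT32") is supported with zigzag encoding
+         * (ENC_ZIGZAG), while (IntegerType, "INT64") is rejected as a type mismatch.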
+         * @return (isSupported, unsupportedReason, encoding)
+         */
+        private def checkFieldSupport(
+            sparkType: DataType,
+            protoTypeName: String,
+            isRepeated: Boolean,
+            enumsAsInts: Boolean): (Boolean, Option[String], Int) = {
+
+          if (isRepeated) {
+            return (false, Some("repeated fields are not supported"), GpuFromProtobuf.ENC_DEFAULT)
+          }
+
+          // Check Spark type is one of the supported simple types
+          sparkType match {
+            case BooleanType | IntegerType | LongType | FloatType | DoubleType |
+                 StringType | BinaryType =>
+              // Supported Spark type, continue to check encoding
+            case other =>
+              return (false, Some(s"unsupported Spark type: $other"), GpuFromProtobuf.ENC_DEFAULT)
+          }
+
+          // Determine encoding based on Spark type and proto type combination
+          val encoding = (sparkType, protoTypeName) match {
+            case (BooleanType, "BOOL") => Some(GpuFromProtobuf.ENC_DEFAULT)
+            case (IntegerType, "INT32" | "UINT32") => Some(GpuFromProtobuf.ENC_DEFAULT)
+            case (IntegerType, "SINT32") => Some(GpuFromProtobuf.ENC_ZIGZAG)
+            case (IntegerType, "FIXED32" | "SFIXED32") => Some(GpuFromProtobuf.ENC_FIXED)
+            case (LongType, "INT64" | "UINT64") => Some(GpuFromProtobuf.ENC_DEFAULT)
+            case (LongType, "SINT64") => Some(GpuFromProtobuf.ENC_ZIGZAG)
+            case (LongType, "FIXED64" | "SFIXED64") => Some(GpuFromProtobuf.ENC_FIXED)
+            // Spark may upcast smaller integers to LongType
+            case (LongType, "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32") =>
+              val enc = protoTypeName match {
+                case "SINT32" => GpuFromProtobuf.ENC_ZIGZAG
+                case "FIXED32" | "SFIXED32" => GpuFromProtobuf.ENC_FIXED
+                case _ => GpuFromProtobuf.ENC_DEFAULT
+              }
+              Some(enc)
+            case (FloatType, "FLOAT") => Some(GpuFromProtobuf.ENC_DEFAULT)
+            case (DoubleType, "DOUBLE") => Some(GpuFromProtobuf.ENC_DEFAULT)
+            case (StringType, "STRING") => Some(GpuFromProtobuf.ENC_DEFAULT)
+            case (BinaryType, "BYTES") => Some(GpuFromProtobuf.ENC_DEFAULT)
+            case (IntegerType, "ENUM") if enumsAsInts => Some(GpuFromProtobuf.ENC_DEFAULT)
+            case _ => None
+          }
+
+          encoding match {
+            case Some(enc) => (true, None, enc)
+            case None =>
+              (false,
+                Some(s"type mismatch: Spark $sparkType vs Protobuf $protoTypeName"),
+                GpuFromProtobuf.ENC_DEFAULT)
+          }
+        }
+
+        /**
+         * Analyze which fields are actually required by downstream operations.
+         * Currently supports analyzing parent Project expressions.
+         *
+         * @param allFieldNames All field names in the full schema
+         * @return Set of field names that are actually required
+         */
+        private def analyzeRequiredFields(allFieldNames: Set[String]): Set[String] = {
+          // Try to find parent SparkPlanMeta and analyze downstream Project
+          val parentPlanOpt = findParentPlanMeta()
+
+          parentPlanOpt match {
+            case Some(planMeta) =>
+              // First, try to analyze the immediate parent
+              analyzeDownstreamProject(planMeta) match {
+                case Some(fields) if fields.nonEmpty =>
+                  // Successfully identified required fields via schema projection
+                  fields
+                case _ =>
+                  // The immediate parent might be a ProjectExec that just aliases the output.
+                  // Try to look at its parent (the grandparent) for GetStructField references.
+                  planMeta.parent match {
+                    case Some(grandParentMeta: SparkPlanMeta[_]) =>
+                      analyzeDownstreamProject(grandParentMeta) match {
+                        case Some(fields) if fields.nonEmpty => fields
+                        case _ => allFieldNames
+                      }
+                    case _ => allFieldNames
+                  }
+              }
+            case None =>
+              // No parent SparkPlanMeta found in the meta tree, assume all fields are needed
+              allFieldNames
+          }
+        }
+
+        /**
+         * Find the parent SparkPlanMeta by traversing up the parent chain.
+         */
+        private def findParentPlanMeta(): Option[SparkPlanMeta[_]] = {
+          def traverse(meta: Option[RapidsMeta[_, _, _]]): Option[SparkPlanMeta[_]] = {
+            meta match {
+              case Some(p: SparkPlanMeta[_]) => Some(p)
+              case Some(p: RapidsMeta[_, _, _]) => traverse(p.parent)
+              case _ => None
+            }
+          }
+          traverse(parent)
+        }
+
+        /**
+         * Analyze a Project plan to find which struct fields are actually used.
+         * This looks for GetStructField expressions that reference our protobuf output.
+         */
+        private def analyzeDownstreamProject(planMeta: SparkPlanMeta[_]): Option[Set[String]] = {
+          planMeta.wrapped match {
+            case p: ProjectExec =>
+              // Collect all GetStructField references from the project list
+              val fieldRefs = mutable.Set[String]()
+              var hasDirectStructRef = false
+
+              p.projectList.foreach { expr =>
+                collectStructFieldReferences(expr, fieldRefs, hasDirectStructRefHolder = () => {
+                  hasDirectStructRef = true
+                })
+              }
+
+              if (hasDirectStructRef) {
+                // If the entire struct is referenced directly (not via GetStructField),
+                // we need all fields
+                None
+              } else if (fieldRefs.nonEmpty) {
+                Some(fieldRefs.toSet)
+              } else {
+                // No GetStructField found - this shouldn't happen for valid plans
+                // where from_protobuf is followed by field access
+                None
+              }
+            case _ =>
+              // Not a ProjectExec, cannot analyze schema projection
+              None
+          }
+        }
+
+        /**
+         * Recursively collect field names from GetStructField expressions.
+         * Also tracks if the struct is used directly without field extraction.
+         */
+        private def collectStructFieldReferences(
+            expr: Expression,
+            fieldRefs: mutable.Set[String],
+            hasDirectStructRefHolder: () => Unit): Unit = {
+          expr match {
+            case GetStructField(child, ordinal, nameOpt) =>
+              // Check if this GetStructField extracts from our protobuf struct
+              if (isProtobufStructReference(child)) {
+                // Get field name from the schema using ordinal
+                val fieldName = nameOpt.getOrElse {
+                  if (ordinal < fullSchema.fields.length) {
+                    fullSchema.fields(ordinal).name
+                  } else {
+                    s"_$ordinal"
+                  }
+                }
+                fieldRefs += fieldName
+                // Don't recurse into child - we've handled this protobuf reference
+              } else {
+                // Child is not a protobuf struct, recurse to check for nested access
+                collectStructFieldReferences(child, fieldRefs, hasDirectStructRefHolder)
+              }
+
+            case _ =>
+              // Check if this expression directly references our protobuf struct
+              // without extracting a field (e.g., passing the whole struct to a function)
+              if (isProtobufStructReference(expr)) {
+                hasDirectStructRefHolder()
+              }
+              // Recursively check children
+              expr.children.foreach { child =>
+                collectStructFieldReferences(child, fieldRefs, hasDirectStructRefHolder)
+              }
+          }
+        }
+
+        /**
+         * Check if an expression references the output of a protobuf decode expression.
+         * This can be either:
+         * 1. The ProtobufDataToCatalyst expression itself
+         * 2. An AttributeReference that references the output of ProtobufDataToCatalyst
+         *    (when accessing from a downstream ProjectExec)
+         */
+        private def isProtobufStructReference(expr: Expression): Boolean = {
+          // Check if expr is a ProtobufDataToCatalyst expression
+          if (expr.getClass.getName.contains("ProtobufDataToCatalyst")) {
+            return true
+          }
+
+          // Check if expr is an AttributeReference with the same schema as our protobuf output.
+          // This handles the case where GetStructField references a column from a parent Project.
+          expr match {
+            case attr: AttributeReference =>
+              // Check if the data type matches our full schema (struct type from protobuf)
+              attr.dataType match {
+                case st: StructType =>
+                  // Compare field names and types only. We intentionally do not compare
+                  // nullable flags because schema transformations (like projections or
+                  // certain optimizations) may change nullability while the underlying
+                  // schema structure remains the same. For schema projection detection,
+                  // matching names and types is sufficient to identify protobuf output.
+                  st.fields.length == fullSchema.fields.length &&
+                    st.fields.zip(fullSchema.fields).forall { case (a, b) =>
+                      a.name == b.name && a.dataType == b.dataType
+                    }
+                case _ => false
+              }
+            case _ => false
+          }
+        }
+
+        override def convertToGpu(child: Expression): GpuExpression = {
+          GpuFromProtobuf(
+            fullSchema, decodedFieldIndices, fieldNumbers, cudfTypeIds, cudfTypeScales,
+            failOnErrors, child)
+        }
+      }
+    )
+  }
+
+  private def getMessageName(e: Expression): String =
+    invoke0[String](e, "messageName")
+
+  /**
+   * Newer Spark versions may carry an in-expression descriptor set payload
+   * (e.g. binaryDescriptorSet).
+   * Spark 3.4.x does not, so callers should fall back to descFilePath().
+   */
+  private def getDescriptorBytes(e: Expression): Option[Array[Byte]] = {
+    // Spark 4.x/3.5+ (depending on the API): may be Array[Byte] or Option[Array[Byte]].
+    val direct = Try(invoke0[Array[Byte]](e, "binaryDescriptorSet")).toOption
+    direct.orElse {
+      Try(invoke0[Option[Array[Byte]]](e, "binaryDescriptorSet")).toOption.flatten
+    }
+  }
+
+  private def getDescFilePath(e: Expression): Option[String] =
+    Try(invoke0[Option[String]](e, "descFilePath")).toOption.flatten
+
+  private def writeTempDescFile(descBytes: Array[Byte]): String = {
+    val tmp: Path = Files.createTempFile("spark-rapids-protobuf-desc-", ".desc")
+    Files.write(tmp, descBytes)
+    // deleteOnExit() is not guaranteed to run on abnormal JVM termination, but these
+    // descriptor files are small (typically < 10KB) and only created when using
+    // binaryDescriptorSet (Spark 4.0+). The risk of temporary file accumulation is
+    // acceptable for this use case.
+    tmp.toFile.deleteOnExit()
+    tmp.toString
+  }
+
+  private def buildMessageDescriptorWithSparkProtobuf(
+      messageName: String,
+      descFilePathOpt: Option[String]): AnyRef = {
+    val cls = ShimReflectionUtils.loadClass(sparkProtobufUtilsObjectClassName)
+    val module = cls.getField("MODULE$").get(null)
+    // buildDescriptor(messageName: String, descFilePath: Option[String])
+    val m = cls.getMethod("buildDescriptor", classOf[String], classOf[scala.Option[_]])
+    m.invoke(module, messageName, descFilePathOpt).asInstanceOf[AnyRef]
+  }
+
+  private def typeName(t: AnyRef): String = {
+    if (t == null) {
+      "null"
+    } else {
+      // Prefer Enum.name() when available; fall back to toString.
+      Try(invoke0[String](t, "name")).getOrElse(t.toString)
+    }
+  }
+
+  private def getOptionsMap(e: Expression): Map[String, String] = {
+    val opt = Try(invoke0[scala.collection.Map[String, String]](e, "options")).toOption
+    opt.map(_.toMap).getOrElse(Map.empty)
+  }
+
+  private def invoke0[T](obj: AnyRef, method: String): T =
+    obj.getClass.getMethod(method).invoke(obj).asInstanceOf[T]
+
+  private def invoke1[T](obj: AnyRef, method: String, arg0Cls: Class[_], arg0: AnyRef): T =
+    obj.getClass.getMethod(method, arg0Cls).invoke(obj, arg0).asInstanceOf[T]
+}
diff --git a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
index 6e28a071a00..56bfa229051 100644
--- a/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
+++ b/sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2025, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -162,7 +162,7 @@ trait Spark340PlusNonDBShims extends Spark331PlusNonDBShims {
       ),
       GpuElementAtMeta.elementAtRule(true)
     ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
-    super.getExprs ++ shimExprs
+    super.getExprs ++ shimExprs ++ ProtobufExprShims.exprs
   }
 
   override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],