From 7a061ac3d2ddcc5ba78d81aebf279f5778ce2985 Mon Sep 17 00:00:00 2001 From: Mohammad Linjawi Date: Sun, 10 May 2026 14:08:15 +0300 Subject: [PATCH] [VL] Add Spark 4.1 TimeType support Map Substrait TIME to Velox TIME_MICRO_UTC and add time literal support. Convert Spark TimeType nanos to Substrait/Velox micros for literals and convert Velox micros back to Spark UnsafeRow nanos in columnar-to-row. --- .../backendsapi/velox/VeloxValidatorApi.scala | 2 + .../execution/VeloxColumnarToRowExec.scala | 2 + .../gluten/execution/VeloxLiteralSuite.scala | 19 +++ .../serializer/VeloxColumnarToRowConverter.cc | 160 +++++++++++++++++- .../serializer/VeloxColumnarToRowConverter.h | 1 + cpp/velox/substrait/SubstraitParser.cc | 6 + cpp/velox/substrait/SubstraitToVeloxExpr.cc | 2 + .../substrait/VeloxSubstraitSignature.cc | 7 + cpp/velox/substrait/VeloxToSubstraitType.cc | 6 + cpp/velox/tests/FunctionTest.cc | 17 ++ cpp/velox/tests/VeloxColumnarToRowTest.cc | 9 + .../tests/VeloxSubstraitSignatureTest.cc | 2 + cpp/velox/tests/VeloxToSubstraitTypeTest.cc | 7 + .../expression/ExpressionBuilder.java | 13 ++ .../expression/StructLiteralNode.java | 3 + .../substrait/expression/TimeLiteralNode.java | 37 ++++ .../gluten/substrait/type/TimeTypeNode.java | 39 +++++ .../gluten/substrait/type/TypeBuilder.java | 4 + .../HashAggregateExecBaseTransformer.scala | 1 + .../gluten/expression/ConverterUtils.scala | 78 +++++++++ 20 files changed, 411 insertions(+), 4 deletions(-) create mode 100644 gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/TimeLiteralNode.java create mode 100644 gluten-substrait/src/main/java/org/apache/gluten/substrait/type/TimeTypeNode.java diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxValidatorApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxValidatorApi.scala index f46b8b03540b..b3e6b8be39b5 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxValidatorApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxValidatorApi.scala @@ -19,6 +19,7 @@ package org.apache.gluten.backendsapi.velox import org.apache.gluten.backendsapi.{BackendsApiManager, ValidatorApi} import org.apache.gluten.config.VeloxConfig import org.apache.gluten.execution.ValidationResult +import org.apache.gluten.expression.ConverterUtils import org.apache.gluten.substrait.`type`.TypeNode import org.apache.gluten.substrait.SubstraitContext import org.apache.gluten.substrait.expression.ExpressionNode @@ -107,6 +108,7 @@ object VeloxValidatorApi { private def isPrimitiveType(dataType: DataType): Boolean = { val enableTimestampNtzValidation = VeloxConfig.get.enableTimestampNtzValidation dataType match { + case dt if ConverterUtils.isSupportedTimeType(dt) => true case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | StringType | BinaryType | _: DecimalType | DateType | TimestampType | YearMonthIntervalType.DEFAULT | NullType => diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala index 2354ebf39faf..ab10acced065 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala @@ -19,6 +19,7 @@ package org.apache.gluten.execution import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches} import org.apache.gluten.exception.GlutenNotSupportException +import org.apache.gluten.expression.ConverterUtils import org.apache.gluten.iterator.Iterators import org.apache.gluten.runtime.Runtimes import org.apache.gluten.vectorized.{NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper} @@ -41,6 +42,7 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBas // Depending on the input type, VeloxColumnarToRowConverter. for (field <- schema.fields) { field.dataType match { + case dt if ConverterUtils.isSupportedTimeType(dt) => case _: BooleanType => case _: ByteType => case _: ShortType => diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxLiteralSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxLiteralSuite.scala index cf2e7257f528..c883d19572db 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxLiteralSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxLiteralSuite.scala @@ -34,6 +34,7 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite { .set("spark.sql.shuffle.partitions", "1") .set("spark.memory.offHeap.size", "2g") .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set("spark.sql.ansi.enabled", "false") .set("spark.sql.autoBroadcastJoinThreshold", "-1") .set("spark.sql.sources.useV1SourceList", "avro") } @@ -56,6 +57,16 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite { } } + def validateOffloadPlan(sql: String): Unit = { + val df = spark.sql(sql) + val plan = df.queryExecution.executedPlan + assert(plan.find(_.isInstanceOf[ProjectExecTransformer]).isDefined, sql) + assert(plan.find(_.isInstanceOf[ProjectExec]).isEmpty, sql) + val wholeStageTransformers = plan.collect { case w: WholeStageTransformer => w } + assert(wholeStageTransformers.nonEmpty, sql) + wholeStageTransformers.foreach(_.nativePlanString()) + } + test("Struct Literal") { validateOffloadResult("SELECT struct('Spark', 5)") validateOffloadResult("SELECT struct(7, struct(5, 'test'))") @@ -135,6 +146,14 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite { validateOffloadResult("SELECT DATE'2020-12-31', DATE'2020-12-30'") } + testWithMinSparkVersion("Time Literal", "4.1") { + // Spark 4.1 cannot collect TIME through Row encoders yet, so this validates planning and + // native plan conversion without comparing collected results. + validateOffloadPlan("SELECT TIME'12:34:56.123456', TIME'00:00:00'") + validateOffloadPlan("SELECT array(TIME'12:34:56.123456', TIME'00:00:00')") + validateOffloadPlan("SELECT struct(TIME'12:34:56.123456')") + } + test("Literal Fallback") { validateFallbackResult("SELECT struct(cast(null as struct))") validateFallbackResult("SELECT array(struct(1, 'a'), null)") diff --git a/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.cc b/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.cc index cbf25ff7ba76..91cebb8afded 100644 --- a/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.cc +++ b/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.cc @@ -22,20 +22,172 @@ #include "memory/VeloxColumnarBatch.h" #include "utils/Exception.h" #include "velox/row/UnsafeRowFast.h" +#include "velox/vector/DecodedVector.h" +#include "velox/vector/FlatVector.h" +#include "velox/vector/LazyVector.h" using namespace facebook; namespace gluten { +namespace { + +constexpr int64_t kMicrosToNanos = 1000; + +bool isTimeMicroUtc(const velox::TypePtr& type) { + return type->equivalent(*velox::TIME_MICRO_UTC()); +} + +bool containsTimeMicroUtc(const velox::TypePtr& type) { + if (isTimeMicroUtc(type)) { + return true; + } + + switch (type->kind()) { + case velox::TypeKind::ARRAY: + return containsTimeMicroUtc(type->asArray().elementType()); + case velox::TypeKind::MAP: + return containsTimeMicroUtc(type->asMap().keyType()) || containsTimeMicroUtc(type->asMap().valueType()); + case velox::TypeKind::ROW: { + const auto& rowType = type->asRow(); + for (const auto& child : rowType.children()) { + if (containsTimeMicroUtc(child)) { + return true; + } + } + return false; + } + default: + return false; + } +} + +velox::VectorPtr normalizeTimeForSparkUnsafeRow(const velox::VectorPtr& vector, velox::memory::MemoryPool* pool); + +velox::VectorPtr normalizeTimeScalarForSparkUnsafeRow(const velox::VectorPtr& vector, velox::memory::MemoryPool* pool) { + velox::DecodedVector decoded(*vector); + auto normalized = velox::BaseVector::create(velox::BIGINT(), vector->size(), pool); + auto* flat = normalized->asFlatVector(); + + for (auto row = 0; row < vector->size(); ++row) { + if (decoded.isNullAt(row)) { + flat->setNull(row, true); + } else { + flat->set(row, decoded.valueAt(row) * kMicrosToNanos); + } + } + return normalized; +} + +velox::VectorPtr loadedFlatVector(const velox::VectorPtr& vector) { + auto loaded = velox::BaseVector::loadedVectorShared(vector); + velox::BaseVector::flattenVector(loaded); + if (loaded->isLazy()) { + loaded = loaded->as()->loadedVectorShared(); + } + return loaded; +} + +velox::VectorPtr normalizeArrayForSparkUnsafeRow(const velox::VectorPtr& vector, velox::memory::MemoryPool* pool) { + auto array = loadedFlatVector(vector)->as(); + auto elements = normalizeTimeForSparkUnsafeRow(array->elements(), pool); + if (elements == array->elements()) { + return vector; + } + return std::make_shared( + pool, + velox::ARRAY(elements->type()), + array->nulls(), + array->size(), + array->offsets(), + array->sizes(), + elements, + array->getNullCount()); +} + +velox::VectorPtr normalizeMapForSparkUnsafeRow(const velox::VectorPtr& vector, velox::memory::MemoryPool* pool) { + auto map = loadedFlatVector(vector)->as(); + auto keys = normalizeTimeForSparkUnsafeRow(map->mapKeys(), pool); + auto values = normalizeTimeForSparkUnsafeRow(map->mapValues(), pool); + if (keys == map->mapKeys() && values == map->mapValues()) { + return vector; + } + return std::make_shared( + pool, + velox::MAP(keys->type(), values->type()), + map->nulls(), + map->size(), + map->offsets(), + map->sizes(), + keys, + values, + map->getNullCount(), + map->hasSortedKeys()); +} + +velox::RowVectorPtr normalizeRowForSparkUnsafeRow( + const velox::RowVectorPtr& rowVector, + velox::memory::MemoryPool* pool) { + std::vector children; + children.reserve(rowVector->childrenSize()); + bool changed = false; + for (const auto& child : rowVector->children()) { + auto normalized = normalizeTimeForSparkUnsafeRow(child, pool); + changed = changed || normalized != child; + children.emplace_back(std::move(normalized)); + } + + if (!changed) { + return rowVector; + } + + std::vector childTypes; + childTypes.reserve(children.size()); + for (const auto& child : children) { + childTypes.emplace_back(child->type()); + } + return std::make_shared( + pool, + velox::ROW(velox::asRowType(rowVector->type())->names(), std::move(childTypes)), + rowVector->nulls(), + rowVector->size(), + std::move(children), + rowVector->getNullCount()); +} + +velox::VectorPtr normalizeTimeForSparkUnsafeRow(const velox::VectorPtr& vector, velox::memory::MemoryPool* pool) { + if (!containsTimeMicroUtc(vector->type())) { + return vector; + } + + if (isTimeMicroUtc(vector->type())) { + return normalizeTimeScalarForSparkUnsafeRow(vector, pool); + } + + switch (vector->typeKind()) { + case velox::TypeKind::ARRAY: + return normalizeArrayForSparkUnsafeRow(vector, pool); + case velox::TypeKind::MAP: + return normalizeMapForSparkUnsafeRow(vector, pool); + case velox::TypeKind::ROW: + return normalizeRowForSparkUnsafeRow(std::dynamic_pointer_cast(loadedFlatVector(vector)), pool); + default: + return vector; + } +} + +} // namespace void VeloxColumnarToRowConverter::refreshStates(facebook::velox::RowVectorPtr rowVector, int64_t startRow) { - auto vectorLength = rowVector->size(); - numCols_ = rowVector->childrenSize(); + rowVectorForUnsafeRow_ = normalizeRowForSparkUnsafeRow(rowVector, veloxPool_.get()); + + auto vectorLength = rowVectorForUnsafeRow_->size(); + numCols_ = rowVectorForUnsafeRow_->childrenSize(); - fast_ = std::make_unique(rowVector); + fast_ = std::make_unique(rowVectorForUnsafeRow_); int64_t totalMemorySize; - if (auto fixedRowSize = velox::row::UnsafeRowFast::fixedRowSize(velox::asRowType(rowVector->type()))) { + if (auto fixedRowSize = velox::row::UnsafeRowFast::fixedRowSize(velox::asRowType(rowVectorForUnsafeRow_->type()))) { auto rowSize = fixedRowSize.value(); // make sure it has at least one row numRows_ = std::max(1, std::min(memThreshold_ / rowSize, vectorLength - startRow)); diff --git a/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.h b/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.h index 540d991a6c65..674a7d40d10a 100644 --- a/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.h +++ b/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.h @@ -42,6 +42,7 @@ class VeloxColumnarToRowConverter final : public ColumnarToRowConverter { std::shared_ptr veloxPool_; std::shared_ptr fast_; facebook::velox::BufferPtr veloxBuffers_; + facebook::velox::RowVectorPtr rowVectorForUnsafeRow_; int64_t memThreshold_; }; diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index c67ad56f0932..afa7c46f6ab2 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -76,6 +76,8 @@ TypePtr SubstraitParser::parseType(const ::substrait::Type& substraitType, bool return UNKNOWN(); case ::substrait::Type::KindCase::kDate: return DATE(); + case ::substrait::Type::KindCase::kTime: + return TIME_MICRO_UTC(); case ::substrait::Type::KindCase::kTimestampTz: return TIMESTAMP(); case ::substrait::Type::KindCase::kDecimal: { @@ -356,6 +358,9 @@ int64_t SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& memcpy(&decimalValue, decimal.c_str(), 16); return static_cast(decimalValue); } + if (literal.has_time()) { + return literal.time(); + } return literal.i64(); } @@ -431,6 +436,7 @@ const std::unordered_map SubstraitParser::typeMap_ = { {"fp32", "REAL"}, {"fp64", "DOUBLE"}, {"date", "DATE"}, + {"time", "TIME MICRO UTC"}, {"ts", "TIMESTAMP"}, {"str", "VARCHAR"}, {"vbin", "VARBINARY"}, diff --git a/cpp/velox/substrait/SubstraitToVeloxExpr.cc b/cpp/velox/substrait/SubstraitToVeloxExpr.cc index 467df25ca881..0ce09b6c3805 100755 --- a/cpp/velox/substrait/SubstraitToVeloxExpr.cc +++ b/cpp/velox/substrait/SubstraitToVeloxExpr.cc @@ -131,6 +131,8 @@ TypePtr getScalarType(const ::substrait::Expression::Literal& literal) { } case ::substrait::Expression_Literal::LiteralTypeCase::kDate: return DATE(); + case ::substrait::Expression_Literal::LiteralTypeCase::kTime: + return TIME_MICRO_UTC(); case ::substrait::Expression_Literal::LiteralTypeCase::kTimestampTz: return TIMESTAMP(); case ::substrait::Expression_Literal::LiteralTypeCase::kString: diff --git a/cpp/velox/substrait/VeloxSubstraitSignature.cc b/cpp/velox/substrait/VeloxSubstraitSignature.cc index 2dd01e8c7218..0b7e508ddadd 100644 --- a/cpp/velox/substrait/VeloxSubstraitSignature.cc +++ b/cpp/velox/substrait/VeloxSubstraitSignature.cc @@ -24,6 +24,9 @@ std::string VeloxSubstraitSignature::toSubstraitSignature(const TypePtr& type) { if (type->isDate()) { return "date"; } + if (type->equivalent(*TIME_MICRO_UTC())) { + return "time"; + } switch (type->kind()) { case TypeKind::BOOLEAN: @@ -159,6 +162,10 @@ TypePtr VeloxSubstraitSignature::fromSubstraitSignature(const std::string& signa return DATE(); } + if (signature == "time") { + return TIME_MICRO_UTC(); + } + if (signature == "nothing") { return UNKNOWN(); } diff --git a/cpp/velox/substrait/VeloxToSubstraitType.cc b/cpp/velox/substrait/VeloxToSubstraitType.cc index b6bcf3bcc978..d6c7f02bb2be 100644 --- a/cpp/velox/substrait/VeloxToSubstraitType.cc +++ b/cpp/velox/substrait/VeloxToSubstraitType.cc @@ -31,6 +31,12 @@ const ::substrait::Type& VeloxToSubstraitTypeConvertor::toSubstraitType( substraitType->set_allocated_date(substraitDate); return *substraitType; } + if (type->equivalent(*velox::TIME_MICRO_UTC())) { + auto substraitTime = google::protobuf::Arena::CreateMessage<::substrait::Type_Time>(&arena); + substraitTime->set_nullability(::substrait::Type_Nullability_NULLABILITY_NULLABLE); + substraitType->set_allocated_time(substraitTime); + return *substraitType; + } switch (type->kind()) { case velox::TypeKind::BOOLEAN: { diff --git a/cpp/velox/tests/FunctionTest.cc b/cpp/velox/tests/FunctionTest.cc index 74be36ee0f54..2a551c8701c4 100644 --- a/cpp/velox/tests/FunctionTest.cc +++ b/cpp/velox/tests/FunctionTest.cc @@ -25,6 +25,7 @@ #include "velox/vector/tests/utils/VectorTestBase.h" #include "substrait/SubstraitParser.h" +#include "substrait/SubstraitToVeloxExpr.h" #include "substrait/SubstraitToVeloxPlan.h" #include "substrait/TypeUtils.h" #include "substrait/VariantToVectorConverter.h" @@ -77,6 +78,22 @@ TEST_F(FunctionTest, getIdxFromNodeName) { ASSERT_EQ(index, 0); } +TEST_F(FunctionTest, substraitTime) { + ::substrait::Type type; + type.mutable_time()->set_nullability(::substrait::Type_Nullability_NULLABILITY_NULLABLE); + auto parsedType = SubstraitParser::parseType(type); + ASSERT_TRUE(parsedType->equivalent(*TIME_MICRO_UTC())); + + ::substrait::Expression_Literal literal; + literal.set_time(45'296'123'456L); + ASSERT_EQ(SubstraitParser::getLiteralValue(literal), 45'296'123'456L); + + SubstraitVeloxExprConverter exprConverter(pool(), {}); + auto constant = exprConverter.toVeloxExpr(literal); + ASSERT_TRUE(constant->type()->equivalent(*TIME_MICRO_UTC())); + ASSERT_EQ(constant->value().value(), 45'296'123'456L); +} + TEST_F(FunctionTest, getNameBeforeDelimiter) { std::string functionSpec = "lte:fp64_fp64"; auto funcName = SubstraitParser::getNameBeforeDelimiter(functionSpec); diff --git a/cpp/velox/tests/VeloxColumnarToRowTest.cc b/cpp/velox/tests/VeloxColumnarToRowTest.cc index 4ed15e134c71..bf3836e65278 100644 --- a/cpp/velox/tests/VeloxColumnarToRowTest.cc +++ b/cpp/velox/tests/VeloxColumnarToRowTest.cc @@ -94,4 +94,13 @@ TEST_F(VeloxColumnarToRowTest, Buffer_int64_int64_with_null) { testRowBufferAddr(vector, expectArr, sizeof(expectArr)); } +TEST_F(VeloxColumnarToRowTest, Buffer_time_micro_utc) { + auto vector = makeRowVector({makeFlatVector({1, 2}, TIME_MICRO_UTC())}); + + uint8_t expectArr[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 232, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 208, 7, 0, 0, 0, 0, 0, 0, + }; + testRowBufferAddr(vector, expectArr, sizeof(expectArr)); +} + } // namespace gluten diff --git a/cpp/velox/tests/VeloxSubstraitSignatureTest.cc b/cpp/velox/tests/VeloxSubstraitSignatureTest.cc index cb62f9764913..3246d96f3bc0 100644 --- a/cpp/velox/tests/VeloxSubstraitSignatureTest.cc +++ b/cpp/velox/tests/VeloxSubstraitSignatureTest.cc @@ -56,6 +56,7 @@ TEST_F(VeloxSubstraitSignatureTest, toSubstraitSignatureWithType) { ASSERT_EQ(toSubstraitSignature(VARBINARY()), "vbin"); ASSERT_EQ(toSubstraitSignature(TIMESTAMP()), "ts"); ASSERT_EQ(toSubstraitSignature(DATE()), "date"); + ASSERT_EQ(toSubstraitSignature(TIME_MICRO_UTC()), "time"); ASSERT_EQ(toSubstraitSignature(ARRAY(BOOLEAN())), "list"); ASSERT_EQ(toSubstraitSignature(ARRAY(INTEGER())), "list"); ASSERT_EQ(toSubstraitSignature(MAP(INTEGER(), BIGINT())), "map"); @@ -107,6 +108,7 @@ TEST_F(VeloxSubstraitSignatureTest, fromSubstraitSignature) { ASSERT_EQ(fromSubstraitSignature("vbin")->kind(), TypeKind::VARBINARY); ASSERT_EQ(fromSubstraitSignature("ts")->kind(), TypeKind::TIMESTAMP); ASSERT_EQ(fromSubstraitSignature("date")->kind(), TypeKind::INTEGER); + ASSERT_TRUE(fromSubstraitSignature("time")->equivalent(*TIME_MICRO_UTC())); ASSERT_EQ(fromSubstraitSignature("dec<18,2>")->kind(), TypeKind::BIGINT); ASSERT_EQ(fromSubstraitSignature("dec<19,2>")->kind(), TypeKind::HUGEINT); diff --git a/cpp/velox/tests/VeloxToSubstraitTypeTest.cc b/cpp/velox/tests/VeloxToSubstraitTypeTest.cc index e7d637ddbb99..441900babed9 100644 --- a/cpp/velox/tests/VeloxToSubstraitTypeTest.cc +++ b/cpp/velox/tests/VeloxToSubstraitTypeTest.cc @@ -62,4 +62,11 @@ TEST_F(VeloxToSubstraitTypeTest, basic) { testTypeConversion(ROW({}, {})); } +TEST_F(VeloxToSubstraitTypeTest, time) { + google::protobuf::Arena arena; + auto substraitType = typeConvertor_->toSubstraitType(arena, TIME_MICRO_UTC()); + ASSERT_EQ(substraitType.kind_case(), ::substrait::Type::KindCase::kTime); + ASSERT_TRUE(SubstraitParser::parseType(substraitType)->equivalent(*TIME_MICRO_UTC())); +} + } // namespace gluten diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java index 4bdef37878c2..f1f7a98c1247 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/ExpressionBuilder.java @@ -110,6 +110,14 @@ public static TimestampLiteralNode makeTimestampLiteral(Long vTimestamp, TypeNod return new TimestampLiteralNode(vTimestamp, typeNode); } + public static TimeLiteralNode makeTimeLiteral(Long vTime) { + return new TimeLiteralNode(vTime); + } + + public static TimeLiteralNode makeTimeLiteral(Long vTime, TypeNode typeNode) { + return new TimeLiteralNode(vTime, typeNode); + } + public static StringLiteralNode makeStringLiteral(String vString) { return new StringLiteralNode(vString); } @@ -177,6 +185,11 @@ public static LiteralNode makeLiteral(Object obj, TypeNode typeNode) { if (typeNode instanceof TimestampTypeNode) { return makeTimestampLiteral((Long) obj, typeNode); } + if (typeNode instanceof TimeTypeNode) { + // Spark stores TimeType literals as nanoseconds since midnight. Substrait time literals + // and Velox TIME_MICRO_UTC use microseconds since midnight. + return makeTimeLiteral(((Long) obj) / 1000L, typeNode); + } if (typeNode instanceof StringTypeNode) { return makeStringLiteral(obj.toString(), typeNode); } diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/StructLiteralNode.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/StructLiteralNode.java index 981fa36f4c6a..caf2537bdd15 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/StructLiteralNode.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/StructLiteralNode.java @@ -61,6 +61,9 @@ public LiteralNode getFieldLiteral(int index) { if (type instanceof TimestampTypeNode) { return ExpressionBuilder.makeLiteral(value.getLong(index), type); } + if (type instanceof TimeTypeNode) { + return ExpressionBuilder.makeLiteral(value.getLong(index), type); + } if (type instanceof StringTypeNode) { return ExpressionBuilder.makeLiteral(value.getUTF8String(index), type); } diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/TimeLiteralNode.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/TimeLiteralNode.java new file mode 100644 index 000000000000..e3e36059ddaa --- /dev/null +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/expression/TimeLiteralNode.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.substrait.expression; + +import org.apache.gluten.substrait.type.TimeTypeNode; +import org.apache.gluten.substrait.type.TypeNode; + +import io.substrait.proto.Expression.Literal.Builder; + +public class TimeLiteralNode extends LiteralNodeWithValue { + public TimeLiteralNode(Long value) { + super(value, new TimeTypeNode(true)); + } + + public TimeLiteralNode(Long value, TypeNode typeNode) { + super(value, typeNode); + } + + @Override + protected void updateLiteralBuilder(Builder literalBuilder, Long value) { + literalBuilder.setTime(value); + } +} diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/type/TimeTypeNode.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/type/TimeTypeNode.java new file mode 100644 index 000000000000..4f340424734a --- /dev/null +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/type/TimeTypeNode.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.substrait.type; + +import io.substrait.proto.Type; + +public class TimeTypeNode extends TypeNode { + + public TimeTypeNode(Boolean nullable) { + super(nullable); + } + + @Override + public Type toProtobuf() { + Type.Time.Builder timeBuilder = Type.Time.newBuilder(); + if (nullable) { + timeBuilder.setNullability(Type.Nullability.NULLABILITY_NULLABLE); + } else { + timeBuilder.setNullability(Type.Nullability.NULLABILITY_REQUIRED); + } + Type.Builder builder = Type.newBuilder(); + builder.setTime(timeBuilder.build()); + return builder.build(); + } +} diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/type/TypeBuilder.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/type/TypeBuilder.java index 28cb10be27d6..c9af82d57f55 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/type/TypeBuilder.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/type/TypeBuilder.java @@ -69,6 +69,10 @@ public static TypeNode makeDate(Boolean nullable) { return new DateTypeNode(nullable); } + public static TypeNode makeTime(Boolean nullable) { + return new TimeTypeNode(nullable); + } + public static TypeNode makeIntervalYear(Boolean nullable) { return new IntervalYearTypeNode(nullable); } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala index f4e174d9f509..f8ba544b28bb 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/HashAggregateExecBaseTransformer.scala @@ -98,6 +98,7 @@ abstract class HashAggregateExecBaseTransformer( protected def checkType(dataType: DataType): Boolean = { dataType match { + case dt if ConverterUtils.isSupportedTimeType(dt) => true case BooleanType | StringType | TimestampType | DateType | BinaryType => true case _: NumericType => true diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConverterUtils.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConverterUtils.scala index 7eb7a7322ad1..3f3a5a017fc4 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConverterUtils.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ConverterUtils.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SparkReflectionUtil import com.google.protobuf.CodedInputStream import io.substrait.proto.Type @@ -33,10 +34,79 @@ import io.substrait.proto.Type import java.util.{ArrayList => JArrayList, List => JList, Locale} import scala.collection.JavaConverters._ +import scala.util.control.NonFatal case class ExpressionType(dataType: DataType, nullable: Boolean) {} object ConverterUtils extends Logging { + final private val SparkTimeTypeClassName = "org.apache.spark.sql.types.TimeType" + final private val SparkTimeTypeObjectClassName = SparkTimeTypeClassName + "$" + final private val TimeMicroPrecision = 6 + final private val TimeCatalogPattern = "time(?:\\((\\d+)\\))?".r + + private lazy val sparkTimeTypeClass: Option[Class[_]] = { + try { + Some(SparkReflectionUtil.classForName(SparkTimeTypeClassName)) + } catch { + case _: ClassNotFoundException => None + } + } + + def isTimeType(dataType: DataType): Boolean = { + val typeName = dataType.typeName.toLowerCase(Locale.ROOT) + val catalogString = dataType.catalogString.toLowerCase(Locale.ROOT) + sparkTimeTypeClass.exists(_.isInstance(dataType)) || + typeName == "time" || + catalogString == "time" || + catalogString.startsWith("time(") + } + + def isSupportedTimeType(dataType: DataType): Boolean = { + isTimeType(dataType) && timePrecision(dataType).forall(_ <= TimeMicroPrecision) + } + + private def timePrecision(dataType: DataType): Option[Int] = { + if (!isTimeType(dataType)) { + return None + } + try { + Some(dataType.getClass.getMethod("precision").invoke(dataType).asInstanceOf[Int]) + } catch { + case NonFatal(_) => + dataType.catalogString.toLowerCase(Locale.ROOT) match { + case TimeCatalogPattern(precision) if precision != null => + Some(precision.toInt) + case TimeCatalogPattern(_) => + None + case _ => + None + } + } + } + + private def validateTimeType(dataType: DataType): Unit = { + timePrecision(dataType).foreach { + precision => + if (precision > TimeMicroPrecision) { + throw new GlutenNotSupportException( + s"Type $dataType is not supported. Velox TIME_MICRO_UTC supports up to " + + s"$TimeMicroPrecision fractional digits.") + } + } + } + + private def defaultTimeType(): DataType = { + try { + val module = + SparkReflectionUtil.classForName(SparkTimeTypeObjectClassName).getField("MODULE$").get(null) + module.getClass.getMethod("apply").invoke(module).asInstanceOf[DataType] + } catch { + case NonFatal(e) => + throw new GlutenNotSupportException( + "Substrait TIME is only supported when Spark TimeType is available.", + e) + } + } /** * Get the source Attribute for the input Expression. It will traverse the Expression tree in a @@ -162,6 +232,8 @@ object ConverterUtils extends Logging { (BinaryType, isNullable(substraitType.getBinary.getNullability)) case Type.KindCase.TIMESTAMP_TZ => (TimestampType, isNullable(substraitType.getTimestampTz.getNullability)) + case Type.KindCase.TIME => + (defaultTimeType(), isNullable(substraitType.getTime.getNullability)) case Type.KindCase.DATE => (DateType, isNullable(substraitType.getDate.getNullability)) case Type.KindCase.DECIMAL => @@ -197,6 +269,9 @@ object ConverterUtils extends Logging { def getTypeNode(datatype: DataType, nullable: Boolean): TypeNode = { datatype match { + case dt if isTimeType(dt) => + validateTimeType(dt) + TypeBuilder.makeTime(nullable) case BooleanType => TypeBuilder.makeBoolean(nullable) case FloatType => @@ -271,6 +346,8 @@ object ConverterUtils extends Logging { BinaryType case _: DateTypeNode => DateType + case _: TimeTypeNode => + defaultTimeType() case _: IntervalYearTypeNode => YearMonthIntervalType.DEFAULT case d: DecimalTypeNode => @@ -398,6 +475,7 @@ object ConverterUtils extends Logging { case FloatType => "fp32" case DoubleType => "fp64" case DateType => "date" + case dt if isTimeType(dt) => "time" case TimestampType => "ts" case StringType => "str" case BinaryType => "vbin"