From fdd0769485a00347b6234827e32d8961d19fd00a Mon Sep 17 00:00:00 2001 From: Pramod Satya Date: Tue, 29 Jul 2025 14:22:28 -0700 Subject: [PATCH 1/4] feat(native): Add Velox to Presto expression converter --- .../presto_cpp/main/types/CMakeLists.txt | 10 + .../main/types/PrestoToVeloxExpr.cpp | 115 +++--- .../presto_cpp/main/types/PrestoToVeloxExpr.h | 2 + .../main/types/VeloxToPrestoExpr.cpp | 331 ++++++++++++++++++ .../presto_cpp/main/types/VeloxToPrestoExpr.h | 108 ++++++ 5 files changed, 514 insertions(+), 52 deletions(-) create mode 100644 presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.cpp create mode 100644 presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.h diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index 8f11b63083536..739039391e916 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -33,6 +33,16 @@ set_property(TARGET presto_types PROPERTY JOB_POOL_LINK presto_link_job_pool) add_library(presto_velox_plan_conversion OBJECT VeloxPlanConversion.cpp) target_link_libraries(presto_velox_plan_conversion velox_type) +add_library(presto_velox_to_presto_expr VeloxToPrestoExpr.cpp) + +target_link_libraries( + presto_velox_to_presto_expr + presto_exception + presto_type_converter + presto_types + presto_protocol +) + if(PRESTO_ENABLE_TESTING) add_subdirectory(tests) endif() diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index b83f12ca0e36b..e58171b7b4bfa 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -36,59 +36,9 @@ std::string toJsonString(const T& value) { } std::string mapScalarFunction(const std::string& name) { - static const std::string 
prestoDefaultNamespacePrefix = - SystemConfig::instance()->prestoDefaultNamespacePrefix(); - static const std::unordered_map kFunctionNames = { - // Operator overrides: com.facebook.presto.common.function.OperatorType - {"presto.default.$operator$add", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "plus")}, - {"presto.default.$operator$between", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "between")}, - {"presto.default.$operator$divide", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "divide")}, - {"presto.default.$operator$equal", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "eq")}, - {"presto.default.$operator$greater_than", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "gt")}, - {"presto.default.$operator$greater_than_or_equal", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "gte")}, - {"presto.default.$operator$is_distinct_from", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "distinct_from")}, - {"presto.default.$operator$less_than", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "lt")}, - {"presto.default.$operator$less_than_or_equal", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "lte")}, - {"presto.default.$operator$modulus", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "mod")}, - {"presto.default.$operator$multiply", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "multiply")}, - {"presto.default.$operator$negation", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "negate")}, - {"presto.default.$operator$not_equal", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "neq")}, - {"presto.default.$operator$subtract", - util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "minus")}, - {"presto.default.$operator$subscript", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "subscript")}, - 
{"presto.default.$operator$xx_hash_64", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "xxhash64_internal")}, - {"presto.default.combine_hash", - util::addDefaultNamespacePrefix( - prestoDefaultNamespacePrefix, "combine_hash_internal")}, - // Special form function overrides. - {"presto.default.in", "in"}, - }; - std::string lowerCaseName = boost::to_lower_copy(name); - - auto it = kFunctionNames.find(lowerCaseName); - if (it != kFunctionNames.end()) { - return it->second; + if (prestoOperatorMap().find(lowerCaseName) != prestoOperatorMap().end()) { + return prestoOperatorMap().at(lowerCaseName); } return lowerCaseName; @@ -388,6 +338,67 @@ std::optional tryConvertLiteralArray( } } // namespace +const std::unordered_map prestoOperatorMap() { + static const std::string prestoDefaultNamespacePrefix = + SystemConfig::instance()->prestoDefaultNamespacePrefix(); + static const std::unordered_map kPrestoOperatorMap = + { + // Operator overrides: + // com.facebook.presto.common.function.OperatorType + {"presto.default.$operator$add", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "plus")}, + {"presto.default.$operator$between", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "between")}, + {"presto.default.$operator$divide", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "divide")}, + {"presto.default.$operator$equal", + util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "eq")}, + {"presto.default.$operator$greater_than", + util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "gt")}, + {"presto.default.$operator$greater_than_or_equal", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "gte")}, + {"presto.default.$operator$is_distinct_from", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "distinct_from")}, + {"presto.default.$operator$less_than", + util::addDefaultNamespacePrefix(prestoDefaultNamespacePrefix, "lt")}, + 
{"presto.default.$operator$less_than_or_equal", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "lte")}, + {"presto.default.$operator$modulus", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "mod")}, + {"presto.default.$operator$multiply", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "multiply")}, + {"presto.default.$operator$negation", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "negate")}, + {"presto.default.$operator$not_equal", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "neq")}, + {"presto.default.$operator$subtract", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "minus")}, + {"presto.default.$operator$subscript", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "subscript")}, + {"presto.default.$operator$xx_hash_64", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "xxhash64_internal")}, + {"presto.default.combine_hash", + util::addDefaultNamespacePrefix( + prestoDefaultNamespacePrefix, "combine_hash_internal")}, + // Special form function overrides. 
+ {"presto.default.in", "in"}, + }; + return kPrestoOperatorMap; +} + std::optional VeloxExprConverter::tryConvertDate( const protocol::CallExpression& pexpr) const { static const std::string prestoDefaultNamespacePrefix = diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h index f63a84ec35ad2..55c6adbcfdeaf 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h @@ -20,6 +20,8 @@ namespace facebook::presto { +const std::unordered_map prestoOperatorMap(); + class VeloxExprConverter { public: VeloxExprConverter(velox::memory::MemoryPool* pool, TypeParser* typeParser) diff --git a/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.cpp b/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.cpp new file mode 100644 index 0000000000000..cb510879f51db --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.cpp @@ -0,0 +1,331 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "presto_cpp/main/types/VeloxToPrestoExpr.h" +#include +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "velox/core/ITypedExpr.h" +#include "velox/expression/ExprConstants.h" +#include "velox/vector/ConstantVector.h" + +using namespace facebook::presto; + +namespace facebook::presto::expression { + +using VariableReferenceExpressionPtr = + std::shared_ptr; + +namespace { +const std::string kVariable = "variable"; +const std::string kCall = "call"; +const std::string kStatic = "$static"; +const std::string kSpecial = "special"; +const std::string kDereference = "DEREFERENCE"; +const std::string kRowConstructor = "ROW_CONSTRUCTOR"; +const std::string kSwitch = "SWITCH"; +const std::string kWhen = "WHEN"; + +protocol::TypeSignature getTypeSignature(const velox::TypePtr& type) { + std::string signature = type->toString(); + if (type->isPrimitiveType()) { + boost::algorithm::to_lower(signature); + } else { + // `Type::toString()` API in Velox returns a capitalized string, which could + // contain extraneous characters like `""`, `:` for certain types like ROW. + // This string should be converted to lower case and modified to get a valid + // protocol::TypeSignature. Eg: for type `ROW(a BIGINT, b VARCHAR)`, Velox + // `Type::toString()` API returns the string `ROW`. The + // corresponding protocol::TypeSignature is `row(a bigint, b varchar)`. 
+ boost::algorithm::erase_all(signature, "\"\""); + boost::algorithm::replace_all(signature, ":", " "); + boost::algorithm::replace_all(signature, "<", "("); + boost::algorithm::replace_all(signature, ">", ")"); + boost::algorithm::to_lower(signature); + } + return signature; +} + +VariableReferenceExpressionPtr getVariableReferenceExpression( + const velox::core::FieldAccessTypedExpr* field) { + protocol::VariableReferenceExpression vexpr; + vexpr.name = field->name(); + vexpr._type = kVariable; + vexpr.type = getTypeSignature(field->type()); + return std::make_shared(vexpr); +} + +bool isPrestoSpecialForm(const std::string& name) { + static const std::unordered_set kPrestoSpecialForms = { + "if", + "null_if", + "switch", + "when", + "is_null", + "coalesce", + "in", + "and", + "or", + "dereference", + "row_constructor", + "bind"}; + return kPrestoSpecialForms.contains(name); +} + +json getWhenSpecialForm( + const velox::TypePtr& type, + const json::array_t& whenArgs) { + json when; + when["@type"] = kSpecial; + when["form"] = kWhen; + when["arguments"] = whenArgs; + when["returnType"] = getTypeSignature(type); + return when; +} + +const std::unordered_map& veloxToPrestoOperatorMap() { + static std::unordered_map veloxToPrestoOperatorMap = + {{"cast", "presto.default.$operator$cast"}}; + for (const auto& entry : prestoOperatorMap()) { + veloxToPrestoOperatorMap[entry.second] = entry.first; + } + return veloxToPrestoOperatorMap; +} +} // namespace + +std::string VeloxToPrestoExprConverter::getValueBlock( + const velox::VectorPtr& vector) const { + std::ostringstream output; + serde_->serializeSingleColumn(vector, nullptr, pool_, &output); + const auto serialized = output.str(); + const auto serializedSize = serialized.size(); + return velox::encoding::Base64::encode(serialized.c_str(), serializedSize); +} + +ConstantExpressionPtr VeloxToPrestoExprConverter::getConstantExpression( + const velox::core::ConstantTypedExpr* constantExpr) const { + 
protocol::ConstantExpression cexpr; + cexpr.type = getTypeSignature(constantExpr->type()); + cexpr.valueBlock.data = getValueBlock(constantExpr->toConstantVector(pool_)); + return std::make_shared(cexpr); +} + +std::vector +VeloxToPrestoExprConverter::getSwitchSpecialFormExpressionArgs( + const velox::core::CallTypedExpr* switchExpr) const { + std::vector result; + const auto& switchInputs = switchExpr->inputs(); + const auto numInputs = switchInputs.size(); + for (auto i = 0; i < numInputs - 1; i += 2) { + const json::array_t resultWhenArgs = { + getRowExpression(switchInputs[i]), + getRowExpression(switchInputs[i + 1])}; + result.emplace_back( + getWhenSpecialForm(switchInputs[i + 1]->type(), resultWhenArgs)); + } + + // Else clause. + if (numInputs % 2 != 0) { + result.emplace_back(getRowExpression(switchInputs[numInputs - 1])); + } + return result; +} + +SpecialFormExpressionPtr VeloxToPrestoExprConverter::getSpecialFormExpression( + const velox::core::CallTypedExpr* expr) const { + VELOX_CHECK( + isPrestoSpecialForm(expr->name()), + "Not a special form expression: {}.", + expr->toString()); + + protocol::SpecialFormExpression result; + result._type = kSpecial; + result.returnType = getTypeSignature(expr->type()); + auto name = expr->name(); + // Presto requires the field form to be in upper case. + std::transform(name.begin(), name.end(), name.begin(), ::toupper); + protocol::Form form; + protocol::from_json(name, form); + result.form = form; + + // Arguments for switch expression include 'WHEN' special form expression(s) + // so they are constructed separately. + if (name == kSwitch) { + result.arguments = getSwitchSpecialFormExpressionArgs(expr); + } else { + // Presto special form expressions that are not of type `SWITCH`, such as + // `IN`, `AND`, `OR` etc,. are handled in this clause. The list of Presto + // special form expressions can be found in `kPrestoSpecialForms` in the + // helper function `isPrestoSpecialForm`. 
+ auto exprInputs = expr->inputs(); + for (const auto& input : exprInputs) { + result.arguments.push_back(getRowExpression(input)); + } + } + + return std::make_shared(result); +} + +SpecialFormExpressionPtr +VeloxToPrestoExprConverter::getRowConstructorExpression( + const velox::core::ConstantTypedExpr* constantExpr) const { + json result; + result["@type"] = kSpecial; + result["form"] = kRowConstructor; + result["returnType"] = getTypeSignature(constantExpr->valueVector()->type()); + + const auto& constVector = constantExpr->toConstantVector(pool_); + const auto* rowVector = constVector->valueVector()->as(); + VELOX_CHECK_NOT_NULL( + rowVector, + "Constant vector not of row type: {}.", + constVector->type()->toString()); + VELOX_CHECK( + constantExpr->type()->isRow(), + "Constant expression not of ROW type: {}.", + constantExpr->type()->toString()); + const auto type = asRowType(constantExpr->type()); + + protocol::ConstantExpression cexpr; + json j; + result["arguments"] = json::array(); + for (const auto& child : rowVector->children()) { + cexpr.type = getTypeSignature(child->type()); + cexpr.valueBlock.data = getValueBlock(child); + protocol::to_json(j, cexpr); + result["arguments"].push_back(j); + } + return result; +} + +SpecialFormExpressionPtr VeloxToPrestoExprConverter::getDereferenceExpression( + const velox::core::DereferenceTypedExpr* dereferenceExpr) const { + json result; + result["@type"] = kSpecial; + result["form"] = kDereference; + result["returnType"] = getTypeSignature(dereferenceExpr->type()); + + json j; + result["arguments"] = json::array(); + const auto dereferenceInputs = std::vector{ + dereferenceExpr->inputs().at(0), + std::make_shared( + velox::BIGINT(), static_cast(dereferenceExpr->index()))}; + for (const auto& input : dereferenceInputs) { + const auto rowExpr = getRowExpression(input); + protocol::to_json(j, rowExpr); + result["arguments"].push_back(j); + } + + return result; +} + +CallExpressionPtr 
VeloxToPrestoExprConverter::getCallExpression( + const velox::core::CallTypedExpr* expr) const { + json result; + result["@type"] = kCall; + protocol::Signature signature; + std::string exprName = expr->name(); + if (veloxToPrestoOperatorMap().find(exprName) != + veloxToPrestoOperatorMap().end()) { + exprName = veloxToPrestoOperatorMap().at(exprName); + } + signature.name = exprName; + result["displayName"] = exprName; + signature.kind = protocol::FunctionKind::SCALAR; + signature.typeVariableConstraints = {}; + signature.longVariableConstraints = {}; + signature.returnType = getTypeSignature(expr->type()); + + std::vector argumentTypes; + auto exprInputs = expr->inputs(); + argumentTypes.reserve(exprInputs.size()); + for (const auto& input : exprInputs) { + argumentTypes.emplace_back(getTypeSignature(input->type())); + } + signature.argumentTypes = argumentTypes; + signature.variableArity = false; + + protocol::BuiltInFunctionHandle builtInFunctionHandle; + builtInFunctionHandle._type = kStatic; + builtInFunctionHandle.signature = signature; + result["functionHandle"] = builtInFunctionHandle; + result["returnType"] = getTypeSignature(expr->type()); + result["arguments"] = json::array(); + for (const auto& exprInput : exprInputs) { + result["arguments"].push_back(getRowExpression(exprInput)); + } + + return result; +} + +RowExpressionPtr VeloxToPrestoExprConverter::getRowExpression( + const velox::core::TypedExprPtr& expr, + const RowExpressionPtr& inputRowExpr) const { + switch (expr->kind()) { + case velox::core::ExprKind::kConstant: { + const auto* constantExpr = + expr->asUnchecked(); + // ConstantTypedExpr of ROW type maps to SpecialFormExpression of type + // ROW_CONSTRUCTOR in Presto. 
+ if (expr->type()->isRow()) { + return getRowConstructorExpression(constantExpr); + } + return getConstantExpression(constantExpr); + } + case velox::core::ExprKind::kFieldAccess: { + const auto* field = + expr->asUnchecked(); + return getVariableReferenceExpression(field); + } + case velox::core::ExprKind::kDereference: { + const auto* dereferenceTypedExpr = + expr->asUnchecked(); + return getDereferenceExpression(dereferenceTypedExpr); + } + case velox::core::ExprKind::kCast: { + // Velox CastTypedExpr maps to Presto CallExpression. + const auto* castExpr = expr->asUnchecked(); + auto call = std::make_shared( + expr->type(), castExpr->inputs(), velox::expression::kCast); + return getCallExpression(call.get()); + } + case velox::core::ExprKind::kCall: { + const auto* callTypedExpr = + expr->asUnchecked(); + // Check if special form expression or call expression. + auto exprName = callTypedExpr->name(); + boost::algorithm::to_lower(exprName); + if (isPrestoSpecialForm(exprName)) { + return getSpecialFormExpression(callTypedExpr); + } + return getCallExpression(callTypedExpr); + } + case velox::core::ExprKind::kConcat: + [[fallthrough]]; + case velox::core::ExprKind::kInput: + [[fallthrough]]; + case velox::core::ExprKind::kLambda: + [[fallthrough]]; + default: { + // Log Velox to Presto expression conversion error and return the + // unoptimized input RowExpression. 
+ LOG(ERROR) << fmt::format( + "Unable to convert Velox expression: {} of kind: {} to Presto RowExpression.", + expr->toString(), + velox::core::ExprKindName::toName(expr->kind())); + return inputRowExpr; + } + } +} + +} // namespace facebook::presto::expression diff --git a/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.h b/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.h new file mode 100644 index 0000000000000..8a62267a08fd2 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/VeloxToPrestoExpr.h @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "presto_cpp/presto_protocol/presto_protocol.h" +#include "velox/core/Expressions.h" +#include "velox/serializers/PrestoSerializer.h" + +using RowExpressionPtr = + std::shared_ptr; +using ConstantExpressionPtr = + std::shared_ptr; +using CallExpressionPtr = + std::shared_ptr; +using SpecialFormExpressionPtr = + std::shared_ptr; + +namespace facebook::presto::expression { + +/// Helper class to convert a Velox expression of type `core::ITypedExpr` to its +/// equivalent Presto expression of type `protocol::RowExpression`: +/// 1. A constant Velox expression of type `core::ConstantTypedExpr` is +/// converted to a Presto expression of type `protocol::ConstantExpression`. 
+/// If the Velox constant expression is of `ROW` type, it is converted to a +/// Presto expression of type `protocol::SpecialFormExpression` with Form +/// as `ROW_CONSTRUCTOR`. +/// 2. A Velox expression representing a field in the input, of type +/// `core::FieldAccessTypedExpr`, is converted to a Presto expression of +/// type `protocol::VariableReferenceExpression`. +/// 3. A Velox dereference expression of type `core::DereferenceTypedExpr` is +/// converted to a Presto expression of type +/// `protocol::SpecialFormExpression` with Form as `DEREFERENCE`. +/// 4. A Velox cast expression of type `core::CastTypedExpr` is converted to a +/// Presto expression of type `protocol::CallExpression`. +/// 5. A Velox call expression of type `core::CallTypedExpr` is converted +/// either to a Presto expression of type `protocol::CallExpression` or of +/// type `protocol::SpecialFormExpression`. This is because Velox does +/// not have a separate expression kind for special form expressions and the +/// special forms in Presto and Velox do not have a one to one mapping. If +/// the velox call expression's name belongs to the set of Presto's +/// `SpecialFormExpression` names, it is converted to a Presto +/// `protocol::SpecialFormExpression`; else it is converted to a Presto +/// `protocol::CallExpression`. +class VeloxToPrestoExprConverter { + public: + explicit VeloxToPrestoExprConverter(velox::memory::MemoryPool* pool) + : pool_(pool) {} + + /// Converts a Velox expression `expr` to a Presto protocol RowExpression. + /// The input Presto protocol RowExpression, `inputRowExpr`, is returned in + /// case the Velox to Presto expression conversion fails. + RowExpressionPtr getRowExpression( + const velox::core::TypedExprPtr& expr, + const RowExpressionPtr& inputRowExpr = nullptr) const; + + private: + /// `ValueBlock` in Presto `protocol::ConstantExpression` requires only the + /// column from the serialized PrestoPage without the page header. 
This + /// function is used to serialize a velox vector to `ValueBlock`. + std::string getValueBlock(const velox::VectorPtr& vector) const; + + /// Helper function to get a Presto `protocol::ConstantExpression` from a + /// Velox constant expression. + ConstantExpressionPtr getConstantExpression( + const velox::core::ConstantTypedExpr* constantExpr) const; + + /// Helper function to get the arguments for Presto `SWITCH` expression from + /// Velox switch expression. + std::vector getSwitchSpecialFormExpressionArgs( + const velox::core::CallTypedExpr* switchExpr) const; + + /// Helper function to construct a Presto `protocol::SpecialFormExpression` + /// from a Velox call expression. + SpecialFormExpressionPtr getSpecialFormExpression( + const velox::core::CallTypedExpr* expr) const; + + /// Helper function to construct a Presto `protocol::SpecialFormExpression` of + /// type `ROW_CONSTRUCTOR` from a Velox constant expression. + SpecialFormExpressionPtr getRowConstructorExpression( + const velox::core::ConstantTypedExpr* constantExpr) const; + + /// Helper function to construct a Presto `protocol::SpecialFormExpression` of + /// type `DEREFERENCE` from a Velox dereference expression. + SpecialFormExpressionPtr getDereferenceExpression( + const velox::core::DereferenceTypedExpr* dereferenceExpr) const; + + /// Helper function to construct a Presto `protocol::CallExpression` from a + /// Velox call expression. 
+ CallExpressionPtr getCallExpression( + const velox::core::CallTypedExpr* expr) const; + + velox::memory::MemoryPool* pool_; + const std::unique_ptr serde_ = + std::make_unique(); +}; +} // namespace facebook::presto::expression From 9fbfd78cd033ae1ceef605eafd78645aa108bd9b Mon Sep 17 00:00:00 2001 From: Pramod Satya Date: Tue, 29 Jul 2025 14:24:03 -0700 Subject: [PATCH 2/4] feat(native): Add expression optimization endpoint in sidecar --- .../presto_cpp/main/CMakeLists.txt | 1 + .../presto_cpp/main/PrestoServer.cpp | 17 ++++ .../presto_cpp/main/types/CMakeLists.txt | 4 + .../main/types/ExpressionOptimizer.cpp | 93 +++++++++++++++++++ .../main/types/ExpressionOptimizer.h | 42 +++++++++ 5 files changed, 157 insertions(+) create mode 100644 presto-native-execution/presto_cpp/main/types/ExpressionOptimizer.cpp create mode 100644 presto-native-execution/presto_cpp/main/types/ExpressionOptimizer.h diff --git a/presto-native-execution/presto_cpp/main/CMakeLists.txt b/presto-native-execution/presto_cpp/main/CMakeLists.txt index 4183f1be0bb3f..61dea5affc4b9 100644 --- a/presto-native-execution/presto_cpp/main/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/CMakeLists.txt @@ -61,6 +61,7 @@ target_link_libraries( $ presto_common presto_exception + presto_expression_optimizer presto_function_metadata presto_connectors presto_http diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 11811b5923bad..def4b85b7db85 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -41,6 +41,7 @@ #include "presto_cpp/main/operators/PartitionAndSerialize.h" #include "presto_cpp/main/operators/ShuffleExchangeSource.h" #include "presto_cpp/main/operators/ShuffleRead.h" +#include "presto_cpp/main/types/ExpressionOptimizer.h" #include "presto_cpp/main/types/PrestoToVeloxQueryPlan.h" #include 
"presto_cpp/main/types/VeloxPlanConversion.h" #include "velox/common/base/Counters.h" @@ -1715,6 +1716,21 @@ void PrestoServer::registerSidecarEndpoints() { http::sendOkResponse(downstream, getFunctionsMetadata(catalog)); }); }); + httpServer_->registerPost( + "/v1/expressions", + [&](proxygen::HTTPMessage* message, + const std::vector>& body, + proxygen::ResponseHandler* downstream) { + const json::array_t inputRowExpressions = + json::parse(util::extractMessageBody(body)); + expression::optimizeExpressions( + message->getHeaders(), + inputRowExpressions, + downstream, + driverExecutor_.get(), + pool_.get()); + }); + httpServer_->registerPost( "/v1/velox/plan", [server = this]( @@ -1832,4 +1848,5 @@ void PrestoServer::registerTraceNodeFactories() { return nullptr; }); } + } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index 739039391e916..2cf8d85100ecd 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -43,6 +43,10 @@ target_link_libraries( presto_protocol ) +add_library(presto_expression_optimizer ExpressionOptimizer.cpp) + +target_link_libraries(presto_expression_optimizer presto_types presto_velox_to_presto_expr) + if(PRESTO_ENABLE_TESTING) add_subdirectory(tests) endif() diff --git a/presto-native-execution/presto_cpp/main/types/ExpressionOptimizer.cpp b/presto-native-execution/presto_cpp/main/types/ExpressionOptimizer.cpp new file mode 100644 index 0000000000000..2babba17d04ab --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/ExpressionOptimizer.cpp @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/types/ExpressionOptimizer.h" +#include "presto_cpp/main/common/Configs.h" +#include "presto_cpp/main/http/HttpServer.h" +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include "presto_cpp/main/types/TypeParser.h" +#include "presto_cpp/main/types/VeloxToPrestoExpr.h" +#include "presto_cpp/presto_protocol/core/presto_protocol_core.h" +#include "velox/expression/ExprOptimizer.h" + +namespace facebook::presto::expression { + +constexpr char const* kOptimized = "OPTIMIZED"; +constexpr char const* kTimezoneHeader = "X-Presto-Time-Zone"; +constexpr char const* kOptimizerLevelHeader = + "X-Presto-Expression-Optimizer-Level"; + +void optimizeExpressions( + const proxygen::HTTPHeaders& httpHeaders, + const json::array_t& inputRowExpressions, + proxygen::ResponseHandler* downstream, + folly::Executor* driverExecutor, + velox::memory::MemoryPool* pool) { + static const velox::expression::MakeFailExpr kMakeFailExpr = + [](const std::string& error, + const velox::TypePtr& type) -> velox::core::TypedExprPtr { + return std::make_shared( + type, + std::vector{ + std::make_shared( + velox::UNKNOWN(), + std::vector{ + std::make_shared( + velox::VARCHAR(), error)}, + fmt::format( + "{}fail", + SystemConfig::instance()->prestoDefaultNamespacePrefix()))}, + false); + }; + + const auto& optimizerLevel = + httpHeaders.getSingleOrEmpty(kOptimizerLevelHeader); + VELOX_USER_CHECK_EQ( + optimizerLevel, + kOptimized, + "Optimizer level should be OPTIMIZED, received {}.", + optimizerLevel); + + const auto& timezone = 
httpHeaders.getSingleOrEmpty(kTimezoneHeader); + std::unordered_map config( + {{velox::core::QueryConfig::kSessionTimezone, timezone}, + {velox::core::QueryConfig::kAdjustTimestampToTimezone, "true"}}); + auto queryConfig = velox::core::QueryConfig{std::move(config)}; + auto queryCtx = + velox::core::QueryCtx::create(driverExecutor, std::move(queryConfig)); + + TypeParser typeParser; + const VeloxExprConverter veloxExprConverter(pool, &typeParser); + const expression::VeloxToPrestoExprConverter veloxToPrestoExprConverter(pool); + const auto numExpr = inputRowExpressions.size(); + json j; + json result = json::array(); + + for (auto i = 0; i < numExpr; i++) { + const RowExpressionPtr inputRowExpr = inputRowExpressions[i]; + protocol::to_json(j, inputRowExpr); + auto expr = veloxExprConverter.toVeloxExpr(inputRowExpr); + auto optimized = + velox::expression::optimize(expr, queryCtx.get(), pool, kMakeFailExpr); + + auto resultExpression = + veloxToPrestoExprConverter.getRowExpression(optimized, inputRowExpr); + protocol::to_json(j, resultExpression); + result.push_back(resultExpression); + } + + http::sendOkResponse(downstream, result); +} + +} // namespace facebook::presto::expression diff --git a/presto-native-execution/presto_cpp/main/types/ExpressionOptimizer.h b/presto-native-execution/presto_cpp/main/types/ExpressionOptimizer.h new file mode 100644 index 0000000000000..b7b7afbec1556 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/ExpressionOptimizer.h @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "presto_cpp/external/json/nlohmann/json.hpp" +#include "velox/common/memory/MemoryPool.h" + +namespace facebook::presto::expression { + +/// Optimizes RowExpressions received in a http request and returns a http +/// response containing the result of expression optimization. +/// @param httpHeaders Headers from the http request, contains the timezone +/// from Presto coordinator and the expression optimizer level. +/// @param inputRowExpressions List of RowExpressions to be optimized. +/// @param downstream Returns the result of expression optimization as a http +/// response: the list of optimized RowExpressions, serialized as an array of +/// JSON objects, with 200 response code. Only the `OPTIMIZED` expression +/// optimizer level is accepted; any other level, including `EVALUATED`, fails +/// a user check, so the `EVALUATED` failure path (error message returned with +/// a 500 response code) is not implemented by this function. +/// @param driverExecutor Driver CPU executor. +/// @param pool Memory pool. 
+void optimizeExpressions( + const proxygen::HTTPHeaders& httpHeaders, + const nlohmann::json::array_t& inputRowExpressions, + proxygen::ResponseHandler* downstream, + folly::Executor* driverExecutor, + velox::memory::MemoryPool* pool); +} // namespace facebook::presto::expression From 02674d61e207731af321928a3e42a7befa47736e Mon Sep 17 00:00:00 2001 From: Pramod Satya Date: Mon, 21 Apr 2025 15:08:42 -0700 Subject: [PATCH 3/4] test(native): Add e2e tests for expression optimization endpoint --- .../presto/sql/TestExpressionInterpreter.java | 1703 +---------------- .../AbstractTestExpressionInterpreter.java | 1697 ++++++++++++++++ presto-native-execution/pom.xml | 44 + .../TestNativeExpressionInterpreter.java | 620 ++++++ 4 files changed, 2410 insertions(+), 1654 deletions(-) create mode 100644 presto-main-base/src/test/java/com/facebook/presto/sql/expressions/AbstractTestExpressionInterpreter.java create mode 100644 presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestNativeExpressionInterpreter.java diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java b/presto-main-base/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java index c4d9d050dbc3d..221685a9cf85d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/TestExpressionInterpreter.java @@ -13,45 +13,19 @@ */ package com.facebook.presto.sql; -import com.facebook.presto.common.CatalogSchemaName; -import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.block.Block; -import com.facebook.presto.common.block.BlockEncodingManager; -import com.facebook.presto.common.block.BlockEncodingSerde; -import com.facebook.presto.common.block.BlockSerdeUtil; -import com.facebook.presto.common.type.ArrayType; -import com.facebook.presto.common.type.Decimals; -import 
com.facebook.presto.common.type.SqlTimestampWithTimeZone; -import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; -import com.facebook.presto.common.type.VarbinaryType; import com.facebook.presto.functionNamespace.json.JsonFileBasedFunctionNamespaceManagerFactory; import com.facebook.presto.metadata.FunctionAndTypeManager; -import com.facebook.presto.metadata.Metadata; -import com.facebook.presto.metadata.MetadataManager; import com.facebook.presto.operator.scalar.FunctionAssertions; -import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.WarningCollector; -import com.facebook.presto.spi.function.AggregationFunctionMetadata; -import com.facebook.presto.spi.function.FunctionKind; -import com.facebook.presto.spi.function.Parameter; -import com.facebook.presto.spi.function.RoutineCharacteristics; -import com.facebook.presto.spi.function.SqlInvokedFunction; -import com.facebook.presto.spi.relation.CallExpression; -import com.facebook.presto.spi.relation.ConstantExpression; -import com.facebook.presto.spi.relation.InputReferenceExpression; -import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; import com.facebook.presto.spi.relation.RowExpression; -import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; -import com.facebook.presto.sql.parser.ParsingOptions; -import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.expressions.AbstractTestExpressionInterpreter; import com.facebook.presto.sql.planner.ExpressionInterpreter; import com.facebook.presto.sql.planner.RowExpressionInterpreter; import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.TypeProvider; -import com.facebook.presto.sql.relational.FunctionResolution; -import com.facebook.presto.sql.tree.EnumLiteral; import 
com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; import com.facebook.presto.sql.tree.ExpressionTreeRewriter; @@ -64,122 +38,32 @@ import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; import io.airlift.slice.Slices; import org.intellij.lang.annotations.Language; -import org.joda.time.DateTime; -import org.joda.time.DateTimeZone; -import org.joda.time.LocalDate; -import org.joda.time.LocalTime; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import java.math.BigInteger; import java.util.Map; import java.util.Optional; -import java.util.concurrent.TimeUnit; import java.util.stream.IntStream; import static com.facebook.presto.SessionTestUtils.TEST_SESSION; -import static com.facebook.presto.common.type.BigintType.BIGINT; -import static com.facebook.presto.common.type.BooleanType.BOOLEAN; -import static com.facebook.presto.common.type.DateType.DATE; -import static com.facebook.presto.common.type.DecimalType.createDecimalType; -import static com.facebook.presto.common.type.DoubleType.DOUBLE; -import static com.facebook.presto.common.type.IntegerType.INTEGER; -import static com.facebook.presto.common.type.TimeType.TIME; -import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey; -import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; -import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.common.type.VarcharType.VARCHAR; -import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; -import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; -import static 
com.facebook.presto.spi.function.FunctionVersion.notVersioned; -import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; -import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; -import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; -import static com.facebook.presto.sql.ExpressionFormatter.formatExpression; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionInterpreter; import static com.facebook.presto.sql.planner.ExpressionInterpreter.expressionOptimizer; import static com.facebook.presto.sql.planner.RowExpressionInterpreter.rowExpressionInterpreter; -import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; -import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; -import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone; -import static io.airlift.slice.Slices.utf8Slice; import static java.lang.String.format; import static java.util.Collections.emptyMap; -import static java.util.Locale.ENGLISH; import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertThrows; -import static org.testng.Assert.assertTrue; -import static org.testng.Assert.fail; public class TestExpressionInterpreter + extends AbstractTestExpressionInterpreter { - public static final SqlInvokedFunction SQUARE_UDF_CPP = new SqlInvokedFunction( - QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "square"), - ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), - parseTypeSignature(StandardTypes.BIGINT), - "Integer 
square", - RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), - "", - notVersioned()); - - public static final SqlInvokedFunction AVG_UDAF_CPP = new SqlInvokedFunction( - QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "avg"), - ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.DOUBLE))), - parseTypeSignature(StandardTypes.DOUBLE), - "Returns mean of doubles", - RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), - "", - notVersioned(), - FunctionKind.AGGREGATE, - Optional.of(new AggregationFunctionMetadata(parseTypeSignature("ROW(double, int)"), false))); - - private static final int TEST_VARCHAR_TYPE_LENGTH = 17; - private static final TypeProvider SYMBOL_TYPES = TypeProvider.viewOf(ImmutableMap.builder() - .put("bound_integer", INTEGER) - .put("bound_long", BIGINT) - .put("bound_string", createVarcharType(TEST_VARCHAR_TYPE_LENGTH)) - .put("bound_varbinary", VarbinaryType.VARBINARY) - .put("bound_double", DOUBLE) - .put("bound_boolean", BOOLEAN) - .put("bound_date", DATE) - .put("bound_time", TIME) - .put("bound_timestamp", TIMESTAMP) - .put("bound_pattern", VARCHAR) - .put("bound_null_string", VARCHAR) - .put("bound_decimal_short", createDecimalType(5, 2)) - .put("bound_decimal_long", createDecimalType(23, 3)) - .put("time", BIGINT) // for testing reserved identifiers - .put("unbound_integer", INTEGER) - .put("unbound_long", BIGINT) - .put("unbound_long2", BIGINT) - .put("unbound_long3", BIGINT) - .put("unbound_string", VARCHAR) - .put("unbound_double", DOUBLE) - .put("unbound_boolean", BOOLEAN) - .put("unbound_date", DATE) - .put("unbound_time", TIME) - .put("unbound_array", new ArrayType(BIGINT)) - .put("unbound_timestamp", TIMESTAMP) - .put("unbound_interval", INTERVAL_DAY_TIME) - .put("unbound_pattern", VARCHAR) - .put("unbound_null_string", VARCHAR) - .build()); - - private static final SqlParser SQL_PARSER = new SqlParser(); - private 
static final Metadata METADATA = MetadataManager.createTestMetadataManager(); - private static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); - private static final BlockEncodingSerde blockEncodingSerde = new BlockEncodingManager(); - @BeforeClass public void setup() { @@ -187,247 +71,6 @@ public void setup() setupJsonFunctionNamespaceManager(METADATA.getFunctionAndTypeManager()); } - @Test - public void testAnd() - { - assertOptimizedEquals("true and false", "false"); - assertOptimizedEquals("false and true", "false"); - assertOptimizedEquals("false and false", "false"); - - assertOptimizedEquals("true and null", "null"); - assertOptimizedEquals("false and null", "false"); - assertOptimizedEquals("null and true", "null"); - assertOptimizedEquals("null and false", "false"); - assertOptimizedEquals("null and null", "null"); - - assertOptimizedEquals("unbound_string='z' and true", "unbound_string='z'"); - assertOptimizedEquals("unbound_string='z' and false", "false"); - assertOptimizedEquals("true and unbound_string='z'", "unbound_string='z'"); - assertOptimizedEquals("false and unbound_string='z'", "false"); - - assertOptimizedEquals("bound_string='z' and bound_long=1+1", "bound_string='z' and bound_long=2"); - assertOptimizedEquals("random() > 0 and random() > 0", "random() > 0 and random() > 0"); - } - - @Test - public void testOr() - { - assertOptimizedEquals("true or true", "true"); - assertOptimizedEquals("true or false", "true"); - assertOptimizedEquals("false or true", "true"); - assertOptimizedEquals("false or false", "false"); - - assertOptimizedEquals("true or null", "true"); - assertOptimizedEquals("null or true", "true"); - assertOptimizedEquals("null or null", "null"); - - assertOptimizedEquals("false or null", "null"); - assertOptimizedEquals("null or false", "null"); - - assertOptimizedEquals("bound_string='z' or true", "true"); - assertOptimizedEquals("bound_string='z' or false", "bound_string='z'"); - 
assertOptimizedEquals("true or bound_string='z'", "true"); - assertOptimizedEquals("false or bound_string='z'", "bound_string='z'"); - - assertOptimizedEquals("bound_string='z' or bound_long=1+1", "bound_string='z' or bound_long=2"); - assertOptimizedEquals("random() > 0 or random() > 0", "random() > 0 or random() > 0"); - } - - @Test - public void testComparison() - { - assertOptimizedEquals("null = null", "null"); - - assertOptimizedEquals("'a' = 'b'", "false"); - assertOptimizedEquals("'a' = 'a'", "true"); - assertOptimizedEquals("'a' = null", "null"); - assertOptimizedEquals("null = 'a'", "null"); - assertOptimizedEquals("bound_integer = 1234", "true"); - assertOptimizedEquals("bound_integer = 12340000000", "false"); - assertOptimizedEquals("bound_long = BIGINT '1234'", "true"); - assertOptimizedEquals("bound_long = 1234", "true"); - assertOptimizedEquals("bound_double = 12.34", "true"); - assertOptimizedEquals("bound_string = 'hello'", "true"); - assertOptimizedEquals("bound_long = unbound_long", "1234 = unbound_long"); - - assertOptimizedEquals("10151082135029368 = 10151082135029369", "false"); - - assertOptimizedEquals("bound_varbinary = X'a b'", "true"); - assertOptimizedEquals("bound_varbinary = X'a d'", "false"); - - assertOptimizedEquals("1.1 = 1.1", "true"); - assertOptimizedEquals("9876543210.9874561203 = 9876543210.9874561203", "true"); - assertOptimizedEquals("bound_decimal_short = 123.45", "true"); - assertOptimizedEquals("bound_decimal_long = 12345678901234567890.123", "true"); - } - - @Test - public void testIsDistinctFrom() - { - assertOptimizedEquals("null is distinct from null", "false"); - - assertOptimizedEquals("3 is distinct from 4", "true"); - assertOptimizedEquals("3 is distinct from BIGINT '4'", "true"); - assertOptimizedEquals("3 is distinct from 4000000000", "true"); - assertOptimizedEquals("3 is distinct from 3", "false"); - assertOptimizedEquals("3 is distinct from null", "true"); - assertOptimizedEquals("null is distinct from 3", 
"true"); - - assertOptimizedEquals("10151082135029368 is distinct from 10151082135029369", "true"); - - assertOptimizedEquals("1.1 is distinct from 1.1", "false"); - assertOptimizedEquals("9876543210.9874561203 is distinct from NULL", "true"); - assertOptimizedEquals("bound_decimal_short is distinct from NULL", "true"); - assertOptimizedEquals("bound_decimal_long is distinct from 12345678901234567890.123", "false"); - } - - @Test - public void testIsNull() - { - assertOptimizedEquals("null is null", "true"); - assertOptimizedEquals("1 is null", "false"); - assertOptimizedEquals("10000000000 is null", "false"); - assertOptimizedEquals("BIGINT '1' is null", "false"); - assertOptimizedEquals("1.0 is null", "false"); - assertOptimizedEquals("'a' is null", "false"); - assertOptimizedEquals("true is null", "false"); - assertOptimizedEquals("null+1 is null", "true"); - assertOptimizedEquals("unbound_string is null", "unbound_string is null"); - assertOptimizedEquals("unbound_long+(1+1) is null", "unbound_long+2 is null"); - assertOptimizedEquals("1.1 is null", "false"); - assertOptimizedEquals("9876543210.9874561203 is null", "false"); - assertOptimizedEquals("bound_decimal_short is null", "false"); - assertOptimizedEquals("bound_decimal_long is null", "false"); - } - - @Test - public void testIsNotNull() - { - assertOptimizedEquals("null is not null", "false"); - assertOptimizedEquals("1 is not null", "true"); - assertOptimizedEquals("10000000000 is not null", "true"); - assertOptimizedEquals("BIGINT '1' is not null", "true"); - assertOptimizedEquals("1.0 is not null", "true"); - assertOptimizedEquals("'a' is not null", "true"); - assertOptimizedEquals("true is not null", "true"); - assertOptimizedEquals("null+1 is not null", "false"); - assertOptimizedEquals("unbound_string is not null", "unbound_string is not null"); - assertOptimizedEquals("unbound_long+(1+1) is not null", "unbound_long+2 is not null"); - assertOptimizedEquals("1.1 is not null", "true"); - 
assertOptimizedEquals("9876543210.9874561203 is not null", "true"); - assertOptimizedEquals("bound_decimal_short is not null", "true"); - assertOptimizedEquals("bound_decimal_long is not null", "true"); - } - - @Test - public void testNullIf() - { - assertOptimizedEquals("nullif(true, true)", "null"); - assertOptimizedEquals("nullif(true, false)", "true"); - assertOptimizedEquals("nullif(null, false)", "null"); - assertOptimizedEquals("nullif(true, null)", "true"); - - assertOptimizedEquals("nullif('a', 'a')", "null"); - assertOptimizedEquals("nullif('a', 'b')", "'a'"); - assertOptimizedEquals("nullif(null, 'b')", "null"); - assertOptimizedEquals("nullif('a', null)", "'a'"); - - assertOptimizedEquals("nullif(1, 1)", "null"); - assertOptimizedEquals("nullif(1, 2)", "1"); - assertOptimizedEquals("nullif(1, BIGINT '2')", "1"); - assertOptimizedEquals("nullif(1, 20000000000)", "1"); - assertOptimizedEquals("nullif(1.0E0, 1)", "null"); - assertOptimizedEquals("nullif(10000000000.0E0, 10000000000)", "null"); - assertOptimizedEquals("nullif(1.1E0, 1)", "1.1E0"); - assertOptimizedEquals("nullif(1.1E0, 1.1E0)", "null"); - assertOptimizedEquals("nullif(1, 2-1)", "null"); - assertOptimizedEquals("nullif(null, null)", "null"); - assertOptimizedEquals("nullif(1, null)", "1"); - assertOptimizedEquals("nullif(unbound_long, 1)", "nullif(unbound_long, 1)"); - assertOptimizedEquals("nullif(unbound_long, unbound_long2)", "nullif(unbound_long, unbound_long2)"); - assertOptimizedEquals("nullif(unbound_long, unbound_long2+(1+1))", "nullif(unbound_long, unbound_long2+2)"); - - assertOptimizedEquals("nullif(1.1, 1.2)", "1.1"); - assertOptimizedEquals("nullif(9876543210.9874561203, 9876543210.9874561203)", "null"); - assertOptimizedEquals("nullif(bound_decimal_short, 123.45)", "null"); - assertOptimizedEquals("nullif(bound_decimal_long, 12345678901234567890.123)", "null"); - assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(1 AS BIGINT)]) IS NULL", "true"); - 
assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false"); - assertOptimizedEquals("nullif(ARRAY[CAST(NULL AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false"); - } - - @Test - public void testNegative() - { - assertOptimizedEquals("-(1)", "-1"); - assertOptimizedEquals("-(BIGINT '1')", "BIGINT '-1'"); - assertOptimizedEquals("-(unbound_long+1)", "-(unbound_long+1)"); - assertOptimizedEquals("-(1+1)", "-2"); - assertOptimizedEquals("-(1+ BIGINT '1')", "BIGINT '-2'"); - assertOptimizedEquals("-(CAST(NULL AS BIGINT))", "null"); - assertOptimizedEquals("-(unbound_long+(1+1))", "-(unbound_long+2)"); - assertOptimizedEquals("-(1.1+1.2)", "-2.3"); - assertOptimizedEquals("-(9876543210.9874561203-9876543210.9874561203)", "CAST(0 AS DECIMAL(20,10))"); - assertOptimizedEquals("-(bound_decimal_short+123.45)", "-246.90"); - assertOptimizedEquals("-(bound_decimal_long-12345678901234567890.123)", "CAST(0 AS DECIMAL(20,10))"); - } - - @Test - public void testNot() - { - assertOptimizedEquals("not true", "false"); - assertOptimizedEquals("not false", "true"); - assertOptimizedEquals("not null", "null"); - assertOptimizedEquals("not 1=1", "false"); - assertOptimizedEquals("not 1=BIGINT '1'", "false"); - assertOptimizedEquals("not 1!=1", "true"); - assertOptimizedEquals("not unbound_long=1", "not unbound_long=1"); - assertOptimizedEquals("not unbound_long=(1+1)", "not unbound_long=2"); - } - - @Test - public void testFunctionCall() - { - assertOptimizedEquals("abs(-5)", "5"); - assertOptimizedEquals("abs(-10-5)", "15"); - assertOptimizedEquals("abs(-bound_integer + 1)", "1233"); - assertOptimizedEquals("abs(-bound_long + 1)", "1233"); - assertOptimizedEquals("abs(-bound_long + BIGINT '1')", "1233"); - assertOptimizedEquals("abs(-bound_long)", "1234"); - assertOptimizedEquals("abs(unbound_long)", "abs(unbound_long)"); - assertOptimizedEquals("abs(unbound_long + 1)", "abs(unbound_long + 1)"); - 
assertOptimizedEquals("cast(json_parse(unbound_string) as map(varchar, varchar))", "cast(json_parse(unbound_string) as map(varchar, varchar))"); - assertOptimizedEquals("cast(json_parse(unbound_string) as array(varchar))", "cast(json_parse(unbound_string) as array(varchar))"); - assertOptimizedEquals("cast(json_parse(unbound_string) as row(bigint, varchar))", "cast(json_parse(unbound_string) as row(bigint, varchar))"); - } - - @Test - public void testNonDeterministicFunctionCall() - { - // optimize should do nothing - assertOptimizedEquals("random()", "random()"); - - // evaluate should execute - Object value = evaluate("random()", false); - assertTrue(value instanceof Double); - double randomValue = (double) value; - assertTrue(0 <= randomValue && randomValue < 1); - } - - @Test - public void testCppFunctionCall() - { - METADATA.getFunctionAndTypeManager().createFunction(SQUARE_UDF_CPP, false); - assertOptimizedEquals("json.test_schema.square(-5)", "json.test_schema.square(-5)"); - } - - @Test - public void testCppAggregateFunctionCall() - { - METADATA.getFunctionAndTypeManager().createFunction(AVG_UDAF_CPP, false); - assertOptimizedEquals("json.test_schema.avg(1.0)", "json.test_schema.avg(1.0)"); - } - // Run this method exactly once. 
private void setupJsonFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) { @@ -440,997 +83,11 @@ private void setupJsonFunctionNamespaceManager(FunctionAndTypeManager functionAn } @Test - public void testBetween() - { - assertOptimizedEquals("3 between 2 and 4", "true"); - assertOptimizedEquals("2 between 3 and 4", "false"); - assertOptimizedEquals("null between 2 and 4", "null"); - assertOptimizedEquals("3 between null and 4", "null"); - assertOptimizedEquals("3 between 2 and null", "null"); - - assertOptimizedEquals("'cc' between 'b' and 'd'", "true"); - assertOptimizedEquals("'b' between 'cc' and 'd'", "false"); - assertOptimizedEquals("null between 'b' and 'd'", "null"); - assertOptimizedEquals("'cc' between null and 'd'", "null"); - assertOptimizedEquals("'cc' between 'b' and null", "null"); - - assertOptimizedEquals("bound_integer between 1000 and 2000", "true"); - assertOptimizedEquals("bound_integer between 3 and 4", "false"); - assertOptimizedEquals("bound_long between 1000 and 2000", "true"); - assertOptimizedEquals("bound_long between 3 and 4", "false"); - assertOptimizedEquals("bound_long between bound_integer and (bound_long + 1)", "true"); - assertOptimizedEquals("bound_string between 'e' and 'i'", "true"); - assertOptimizedEquals("bound_string between 'a' and 'b'", "false"); - - assertOptimizedEquals("bound_long between unbound_long and 2000 + 1", "1234 between unbound_long and 2001"); - assertOptimizedEquals( - "bound_string between unbound_string and 'bar'", - format("CAST('hello' AS VARCHAR(%s)) between unbound_string and 'bar'", TEST_VARCHAR_TYPE_LENGTH)); - - assertOptimizedEquals("1.15 between 1.1 and 1.2", "true"); - assertOptimizedEquals("9876543210.98745612035 between 9876543210.9874561203 and 9876543210.9874561204", "true"); - assertOptimizedEquals("123.455 between bound_decimal_short and 123.46", "true"); - assertOptimizedEquals("12345678901234567890.1235 between bound_decimal_long and 12345678901234567890.123", 
"false"); - } - - @Test - public void testExtract() - { - DateTime dateTime = new DateTime(2001, 8, 22, 3, 4, 5, 321, getDateTimeZone(TEST_SESSION.getTimeZoneKey())); - double seconds = dateTime.getMillis() / 1000.0; - - assertOptimizedEquals("extract (YEAR from from_unixtime(" + seconds + "))", "2001"); - assertOptimizedEquals("extract (QUARTER from from_unixtime(" + seconds + "))", "3"); - assertOptimizedEquals("extract (MONTH from from_unixtime(" + seconds + "))", "8"); - assertOptimizedEquals("extract (WEEK from from_unixtime(" + seconds + "))", "34"); - assertOptimizedEquals("extract (DOW from from_unixtime(" + seconds + "))", "3"); - assertOptimizedEquals("extract (DOY from from_unixtime(" + seconds + "))", "234"); - assertOptimizedEquals("extract (DAY from from_unixtime(" + seconds + "))", "22"); - assertOptimizedEquals("extract (HOUR from from_unixtime(" + seconds + "))", "3"); - assertOptimizedEquals("extract (MINUTE from from_unixtime(" + seconds + "))", "4"); - assertOptimizedEquals("extract (SECOND from from_unixtime(" + seconds + "))", "5"); - assertOptimizedEquals("extract (TIMEZONE_HOUR from from_unixtime(" + seconds + ", 7, 9))", "7"); - assertOptimizedEquals("extract (TIMEZONE_MINUTE from from_unixtime(" + seconds + ", 7, 9))", "9"); - - assertOptimizedEquals("extract (YEAR from bound_timestamp)", "2001"); - assertOptimizedEquals("extract (QUARTER from bound_timestamp)", "3"); - assertOptimizedEquals("extract (MONTH from bound_timestamp)", "8"); - assertOptimizedEquals("extract (WEEK from bound_timestamp)", "34"); - assertOptimizedEquals("extract (DOW from bound_timestamp)", "2"); - assertOptimizedEquals("extract (DOY from bound_timestamp)", "233"); - assertOptimizedEquals("extract (DAY from bound_timestamp)", "21"); - assertOptimizedEquals("extract (HOUR from bound_timestamp)", "16"); - assertOptimizedEquals("extract (MINUTE from bound_timestamp)", "4"); - assertOptimizedEquals("extract (SECOND from bound_timestamp)", "5"); - // todo reenable when 
cast as timestamp with time zone is implemented - // todo add bound timestamp with time zone - //assertOptimizedEquals("extract (TIMEZONE_HOUR from bound_timestamp)", "0"); - //assertOptimizedEquals("extract (TIMEZONE_MINUTE from bound_timestamp)", "0"); - - assertOptimizedEquals("extract (YEAR from unbound_timestamp)", "extract (YEAR from unbound_timestamp)"); - assertOptimizedEquals("extract (SECOND from bound_timestamp + INTERVAL '3' SECOND)", "8"); - } - - @Test - public void testIn() - { - assertOptimizedEquals("3 in (2, 4, 3, 5)", "true"); - assertOptimizedEquals("3 in (2, 4, 9, 5)", "false"); - assertOptimizedEquals("3 in (2, null, 3, 5)", "true"); - - assertOptimizedEquals("'foo' in ('bar', 'baz', 'foo', 'blah')", "true"); - assertOptimizedEquals("'foo' in ('bar', 'baz', 'buz', 'blah')", "false"); - assertOptimizedEquals("'foo' in ('bar', null, 'foo', 'blah')", "true"); - - assertOptimizedEquals("null in (2, null, 3, 5)", "null"); - assertOptimizedEquals("3 in (2, null)", "null"); - - assertOptimizedEquals("bound_integer in (2, 1234, 3, 5)", "true"); - assertOptimizedEquals("bound_integer in (2, 4, 3, 5)", "false"); - assertOptimizedEquals("1234 in (2, bound_integer, 3, 5)", "true"); - assertOptimizedEquals("99 in (2, bound_integer, 3, 5)", "false"); - assertOptimizedEquals("bound_integer in (2, bound_integer, 3, 5)", "true"); - - assertOptimizedEquals("bound_long in (2, 1234, 3, 5)", "true"); - assertOptimizedEquals("bound_long in (2, 4, 3, 5)", "false"); - assertOptimizedEquals("1234 in (2, bound_long, 3, 5)", "true"); - assertOptimizedEquals("99 in (2, bound_long, 3, 5)", "false"); - assertOptimizedEquals("bound_long in (2, bound_long, 3, 5)", "true"); - - assertOptimizedEquals("bound_string in ('bar', 'hello', 'foo', 'blah')", "true"); - assertOptimizedEquals("bound_string in ('bar', 'baz', 'foo', 'blah')", "false"); - assertOptimizedEquals("'hello' in ('bar', bound_string, 'foo', 'blah')", "true"); - assertOptimizedEquals("'baz' in ('bar', 
bound_string, 'foo', 'blah')", "false"); - - assertOptimizedEquals("bound_long in (2, 1234, unbound_long, 5)", "true"); - assertOptimizedEquals("bound_string in ('bar', 'hello', unbound_string, 'blah')", "true"); - - assertOptimizedEquals("bound_long in (2, 4, unbound_long, unbound_long2, 9)", "1234 in (unbound_long, unbound_long2)"); - assertOptimizedEquals("unbound_long in (2, 4, bound_long, unbound_long2, 5)", "unbound_long in (2, 4, 1234, unbound_long2, 5)"); - - assertOptimizedEquals("1.15 in (1.1, 1.2, 1.3, 1.15)", "true"); - assertOptimizedEquals("9876543210.98745612035 in (9876543210.9874561203, 9876543210.9874561204, 9876543210.98745612035)", "true"); - assertOptimizedEquals("bound_decimal_short in (123.455, 123.46, 123.45)", "true"); - assertOptimizedEquals("bound_decimal_long in (12345678901234567890.123, 9876543210.9874561204, 9876543210.98745612035)", "true"); - assertOptimizedEquals("bound_decimal_long in (9876543210.9874561204, null, 9876543210.98745612035)", "null"); - } - - @Test - public void testInComplexTypes() - { - assertEvaluatedEquals("ARRAY[null] IN (ARRAY[null])", "null"); - assertEvaluatedEquals("ARRAY[1] IN (ARRAY[null])", "null"); - assertEvaluatedEquals("ARRAY[null] IN (ARRAY[1])", "null"); - assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null])", "null"); - assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[2, null])", "false"); - assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null])", "null"); - assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null], ARRAY[1, null])", "null"); - assertEvaluatedEquals("ARRAY[ARRAY[1, 2], ARRAY[3, 4]] in (ARRAY[ARRAY[1, 2], ARRAY[3, NULL]])", "null"); - - assertEvaluatedEquals("ROW(1) IN (ROW(1))", "true"); - assertEvaluatedEquals("ROW(1) IN (ROW(2))", "false"); - assertEvaluatedEquals("ROW(1) IN (ROW(2), ROW(1), ROW(2))", "true"); - assertEvaluatedEquals("ROW(1) IN (null)", "null"); - assertEvaluatedEquals("ROW(1) IN (null, ROW(1))", "true"); - 
assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null), null)", "null"); - assertEvaluatedEquals("ROW(null) IN (ROW(null))", "null"); - assertEvaluatedEquals("ROW(1) IN (ROW(null))", "null"); - assertEvaluatedEquals("ROW(null) IN (ROW(1))", "null"); - assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null))", "null"); - assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null))", "false"); - assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null))", "null"); - assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null), ROW(1, null))", "null"); - - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[1]))", "true"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null)", "null"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null, MAP(ARRAY[1], ARRAY[1]))", "true"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "false"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]), null)", "null"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[null]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[null]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[1]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]))", "false"); - assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]))", "null"); - 
assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]), MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); - } - - @Test - public void testCurrentTimestamp() - { - double current = TEST_SESSION.getStartTime() / 1000.0; - assertOptimizedEquals("current_timestamp = from_unixtime(" + current + ")", "true"); - double future = current + TimeUnit.MINUTES.toSeconds(1); - assertOptimizedEquals("current_timestamp > from_unixtime(" + future + ")", "false"); - } - - @Test - public void testCurrentUser() - throws Exception - { - assertOptimizedEquals("current_user", "'" + TEST_SESSION.getUser() + "'"); - } - - @Test - public void testCastToString() - { - // integer - assertOptimizedEquals("cast(123 as VARCHAR(20))", "'123'"); - assertOptimizedEquals("cast(-123 as VARCHAR(20))", "'-123'"); - - // bigint - assertOptimizedEquals("cast(BIGINT '123' as VARCHAR)", "'123'"); - assertOptimizedEquals("cast(12300000000 as VARCHAR)", "'12300000000'"); - assertOptimizedEquals("cast(-12300000000 as VARCHAR)", "'-12300000000'"); - - // double - assertOptimizedEquals("cast(123.0E0 as VARCHAR)", "'123.0'"); - assertOptimizedEquals("cast(-123.0E0 as VARCHAR)", "'-123.0'"); - assertOptimizedEquals("cast(123.456E0 as VARCHAR)", "'123.456'"); - assertOptimizedEquals("cast(-123.456E0 as VARCHAR)", "'-123.456'"); - - // boolean - assertOptimizedEquals("cast(true as VARCHAR)", "'true'"); - assertOptimizedEquals("cast(false as VARCHAR)", "'false'"); - - // string - assertOptimizedEquals("cast('xyz' as VARCHAR)", "'xyz'"); - assertOptimizedEquals("cast(cast('abcxyz' as VARCHAR(3)) as VARCHAR(5))", "'abc'"); - - // null - assertOptimizedEquals("cast(null as VARCHAR)", "null"); - - // decimal - assertOptimizedEquals("cast(1.1 as VARCHAR)", "'1.1'"); - // TODO enabled when DECIMAL is default for literal: assertOptimizedEquals("cast(12345678901234567890.123 as VARCHAR)", "'12345678901234567890.123'"); - } - - @Test - public void 
testCastBigintToBoundedVarchar() - { - assertEvaluatedEquals("CAST(12300000000 AS varchar(11))", "'12300000000'"); - assertEvaluatedEquals("CAST(12300000000 AS varchar(50))", "'12300000000'"); - - try { - evaluate("CAST(12300000000 AS varchar(3))", true); - fail("Expected to throw an INVALID_CAST_ARGUMENT exception"); - } - catch (PrestoException e) { - try { - assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); - assertEquals(e.getMessage(), "Value 12300000000 cannot be represented as varchar(3)"); - } - catch (Throwable failure) { - failure.addSuppressed(e); - throw failure; - } - } - - try { - evaluate("CAST(-12300000000 AS varchar(3))", true); - } - catch (PrestoException e) { - try { - assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); - assertEquals(e.getMessage(), "Value -12300000000 cannot be represented as varchar(3)"); - } - catch (Throwable failure) { - failure.addSuppressed(e); - throw failure; - } - } - } - - @Test - public void testCastToBoolean() - { - // integer - assertOptimizedEquals("cast(123 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(-123 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(0 as BOOLEAN)", "false"); - - // bigint - assertOptimizedEquals("cast(12300000000 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(-12300000000 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(BIGINT '0' as BOOLEAN)", "false"); - - // boolean - assertOptimizedEquals("cast(true as BOOLEAN)", "true"); - assertOptimizedEquals("cast(false as BOOLEAN)", "false"); - - // string - assertOptimizedEquals("cast('true' as BOOLEAN)", "true"); - assertOptimizedEquals("cast('false' as BOOLEAN)", "false"); - assertOptimizedEquals("cast('t' as BOOLEAN)", "true"); - assertOptimizedEquals("cast('f' as BOOLEAN)", "false"); - assertOptimizedEquals("cast('1' as BOOLEAN)", "true"); - assertOptimizedEquals("cast('0' as BOOLEAN)", "false"); - - // null - assertOptimizedEquals("cast(null as BOOLEAN)", "null"); - - // double - 
assertOptimizedEquals("cast(123.45E0 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(-123.45E0 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(0.0E0 as BOOLEAN)", "false"); - - // decimal - assertOptimizedEquals("cast(0.00 as BOOLEAN)", "false"); - assertOptimizedEquals("cast(7.8 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(12345678901234567890.123 as BOOLEAN)", "true"); - assertOptimizedEquals("cast(00000000000000000000.000 as BOOLEAN)", "false"); - } - - @Test - public void testCastToBigint() - { - // integer - assertOptimizedEquals("cast(0 as BIGINT)", "0"); - assertOptimizedEquals("cast(123 as BIGINT)", "123"); - assertOptimizedEquals("cast(-123 as BIGINT)", "-123"); - - // bigint - assertOptimizedEquals("cast(BIGINT '0' as BIGINT)", "0"); - assertOptimizedEquals("cast(BIGINT '123' as BIGINT)", "123"); - assertOptimizedEquals("cast(BIGINT '-123' as BIGINT)", "-123"); - - // double - assertOptimizedEquals("cast(123.0E0 as BIGINT)", "123"); - assertOptimizedEquals("cast(-123.0E0 as BIGINT)", "-123"); - assertOptimizedEquals("cast(123.456E0 as BIGINT)", "123"); - assertOptimizedEquals("cast(-123.456E0 as BIGINT)", "-123"); - - // boolean - assertOptimizedEquals("cast(true as BIGINT)", "1"); - assertOptimizedEquals("cast(false as BIGINT)", "0"); - - // string - assertOptimizedEquals("cast('123' as BIGINT)", "123"); - assertOptimizedEquals("cast('-123' as BIGINT)", "-123"); - - // null - assertOptimizedEquals("cast(null as BIGINT)", "null"); - - // decimal - assertOptimizedEquals("cast(DECIMAL '1.01' as BIGINT)", "1"); - assertOptimizedEquals("cast(DECIMAL '7.8' as BIGINT)", "8"); - assertOptimizedEquals("cast(DECIMAL '1234567890.123' as BIGINT)", "1234567890"); - assertOptimizedEquals("cast(DECIMAL '00000000000000000000.000' as BIGINT)", "0"); - } - - @Test - public void testCastToInteger() - { - // integer - assertOptimizedEquals("cast(0 as INTEGER)", "0"); - assertOptimizedEquals("cast(123 as INTEGER)", "123"); - assertOptimizedEquals("cast(-123 
as INTEGER)", "-123"); - - // bigint - assertOptimizedEquals("cast(BIGINT '0' as INTEGER)", "0"); - assertOptimizedEquals("cast(BIGINT '123' as INTEGER)", "123"); - assertOptimizedEquals("cast(BIGINT '-123' as INTEGER)", "-123"); - - // double - assertOptimizedEquals("cast(123.0E0 as INTEGER)", "123"); - assertOptimizedEquals("cast(-123.0E0 as INTEGER)", "-123"); - assertOptimizedEquals("cast(123.456E0 as INTEGER)", "123"); - assertOptimizedEquals("cast(-123.456E0 as INTEGER)", "-123"); - - // boolean - assertOptimizedEquals("cast(true as INTEGER)", "1"); - assertOptimizedEquals("cast(false as INTEGER)", "0"); - - // string - assertOptimizedEquals("cast('123' as INTEGER)", "123"); - assertOptimizedEquals("cast('-123' as INTEGER)", "-123"); - - // null - assertOptimizedEquals("cast(null as INTEGER)", "null"); - } - - @Test - public void testCastToDouble() - { - // integer - assertOptimizedEquals("cast(0 as DOUBLE)", "0.0E0"); - assertOptimizedEquals("cast(123 as DOUBLE)", "123.0E0"); - assertOptimizedEquals("cast(-123 as DOUBLE)", "-123.0E0"); - - // bigint - assertOptimizedEquals("cast(BIGINT '0' as DOUBLE)", "0.0E0"); - assertOptimizedEquals("cast(12300000000 as DOUBLE)", "12300000000.0E0"); - assertOptimizedEquals("cast(-12300000000 as DOUBLE)", "-12300000000.0E0"); - - // double - assertOptimizedEquals("cast(123.0E0 as DOUBLE)", "123.0E0"); - assertOptimizedEquals("cast(-123.0E0 as DOUBLE)", "-123.0E0"); - assertOptimizedEquals("cast(123.456E0 as DOUBLE)", "123.456E0"); - assertOptimizedEquals("cast(-123.456E0 as DOUBLE)", "-123.456E0"); - - // string - assertOptimizedEquals("cast('0' as DOUBLE)", "0.0E0"); - assertOptimizedEquals("cast('123' as DOUBLE)", "123.0E0"); - assertOptimizedEquals("cast('-123' as DOUBLE)", "-123.0E0"); - assertOptimizedEquals("cast('123.0E0' as DOUBLE)", "123.0E0"); - assertOptimizedEquals("cast('-123.0E0' as DOUBLE)", "-123.0E0"); - assertOptimizedEquals("cast('123.456E0' as DOUBLE)", "123.456E0"); - 
assertOptimizedEquals("cast('-123.456E0' as DOUBLE)", "-123.456E0"); - - // null - assertOptimizedEquals("cast(null as DOUBLE)", "null"); - - // boolean - assertOptimizedEquals("cast(true as DOUBLE)", "1.0E0"); - assertOptimizedEquals("cast(false as DOUBLE)", "0.0E0"); - - // decimal - assertOptimizedEquals("cast(1.01 as DOUBLE)", "DOUBLE '1.01'"); - assertOptimizedEquals("cast(7.8 as DOUBLE)", "DOUBLE '7.8'"); - assertOptimizedEquals("cast(1234567890.123 as DOUBLE)", "DOUBLE '1234567890.123'"); - assertOptimizedEquals("cast(00000000000000000000.000 as DOUBLE)", "DOUBLE '0.0'"); - } - - @Test - public void testCastToDecimal() - { - // long - assertOptimizedEquals("cast(0 as DECIMAL(1,0))", "DECIMAL '0'"); - assertOptimizedEquals("cast(123 as DECIMAL(3,0))", "DECIMAL '123'"); - assertOptimizedEquals("cast(-123 as DECIMAL(3,0))", "DECIMAL '-123'"); - assertOptimizedEquals("cast(-123 as DECIMAL(20,10))", "cast(-123 as DECIMAL(20,10))"); - - // double - assertOptimizedEquals("cast(0E0 as DECIMAL(1,0))", "DECIMAL '0'"); - assertOptimizedEquals("cast(123.2E0 as DECIMAL(4,1))", "DECIMAL '123.2'"); - assertOptimizedEquals("cast(-123.0E0 as DECIMAL(3,0))", "DECIMAL '-123'"); - assertOptimizedEquals("cast(-123.55E0 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); - - // string - assertOptimizedEquals("cast('0' as DECIMAL(1,0))", "DECIMAL '0'"); - assertOptimizedEquals("cast('123.2' as DECIMAL(4,1))", "DECIMAL '123.2'"); - assertOptimizedEquals("cast('-123.0' as DECIMAL(3,0))", "DECIMAL '-123'"); - assertOptimizedEquals("cast('-123.55' as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); - - // null - assertOptimizedEquals("cast(null as DECIMAL(1,0))", "null"); - assertOptimizedEquals("cast(null as DECIMAL(20,10))", "null"); - - // boolean - assertOptimizedEquals("cast(true as DECIMAL(1,0))", "DECIMAL '1'"); - assertOptimizedEquals("cast(false as DECIMAL(4,1))", "DECIMAL '000.0'"); - assertOptimizedEquals("cast(true as DECIMAL(3,0))", "DECIMAL '001'"); - 
assertOptimizedEquals("cast(false as DECIMAL(20,10))", "cast(0 as DECIMAL(20,10))"); - - // decimal - assertOptimizedEquals("cast(0.0 as DECIMAL(1,0))", "DECIMAL '0'"); - assertOptimizedEquals("cast(123.2 as DECIMAL(4,1))", "DECIMAL '123.2'"); - assertOptimizedEquals("cast(-123.0 as DECIMAL(3,0))", "DECIMAL '-123'"); - assertOptimizedEquals("cast(-123.55 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); - } - - @Test - public void testCastOptimization() - { - assertOptimizedEquals("cast(unbound_string as VARCHAR)", "cast(unbound_string as VARCHAR)"); - assertOptimizedMatches("cast(unbound_string as VARCHAR)", "unbound_string"); - assertOptimizedMatches("cast(unbound_integer as INTEGER)", "unbound_integer"); - assertOptimizedMatches("cast(unbound_string as VARCHAR(10))", "cast(unbound_string as VARCHAR(10))"); - } - - @Test - public void testTryCast() - { - assertOptimizedEquals("try_cast(null as BIGINT)", "null"); - assertOptimizedEquals("try_cast(123 as BIGINT)", "123"); - assertOptimizedEquals("try_cast(null as INTEGER)", "null"); - assertOptimizedEquals("try_cast(123 as INTEGER)", "123"); - assertOptimizedEquals("try_cast('foo' as VARCHAR)", "'foo'"); - assertOptimizedEquals("try_cast('foo' as BIGINT)", "null"); - assertOptimizedEquals("try_cast(unbound_string as BIGINT)", "try_cast(unbound_string as BIGINT)"); - assertOptimizedEquals("try_cast('foo' as DECIMAL(2,1))", "null"); - } - - @Test - public void testReservedWithDoubleQuotes() - { - assertOptimizedEquals("\"time\"", "\"time\""); - } - - @Test - public void testEnumLiteralFormattingWithTypeAndValue() - { - java.util.function.BiFunction createEnumLiteral = (type, value) -> new EnumLiteral(Optional.empty(), type, value); - assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("color", "RED"), Optional.empty()), "color: RED"); - assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("level", 1), Optional.empty()), "level: 1"); - 
assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("StatusType", "Active"), Optional.empty()), "StatusType: Active"); - assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("priority", "HIGH PRIORITY"), Optional.empty()), "priority: HIGH PRIORITY"); - assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("lang", "枚举"), Optional.empty()), "lang: 枚举"); - assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("special", "DOLLAR$"), Optional.empty()), "special: DOLLAR$"); - assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("enum_type", "VALUE_1"), Optional.empty()), "enum_type: VALUE_1"); - assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("flag", true), Optional.empty()), "flag: true"); - } - - @Test - public void testSearchCase() - { - assertOptimizedEquals("case " + - "when true then 33 " + - "end", - "33"); - assertOptimizedEquals("case " + - "when false then 1 " + - "else 33 " + - "end", - "33"); - - assertOptimizedEquals("case " + - "when false then 10000000000 " + - "else 33 " + - "end", - "33"); - - assertOptimizedEquals("case " + - "when bound_long = 1234 then 33 " + - "end", - "33"); - assertOptimizedEquals("case " + - "when true then bound_long " + - "end", - "1234"); - assertOptimizedEquals("case " + - "when false then 1 " + - "else bound_long " + - "end", - "1234"); - - assertOptimizedEquals("case " + - "when bound_integer = 1234 then 33 " + - "end", - "33"); - assertOptimizedEquals("case " + - "when true then bound_integer " + - "end", - "1234"); - assertOptimizedEquals("case " + - "when false then 1 " + - "else bound_integer " + - "end", - "1234"); - - assertOptimizedEquals("case " + - "when bound_long = 1234 then 33 " + - "else unbound_long " + - "end", - "33"); - assertOptimizedEquals("case " + - "when true then bound_long " + - "else unbound_long " + - "end", - "1234"); - assertOptimizedEquals("case " + - "when false then 
unbound_long " + - "else bound_long " + - "end", - "1234"); - - assertOptimizedEquals("case " + - "when bound_integer = 1234 then 33 " + - "else unbound_integer " + - "end", - "33"); - assertOptimizedEquals("case " + - "when true then bound_integer " + - "else unbound_integer " + - "end", - "1234"); - assertOptimizedEquals("case " + - "when false then unbound_integer " + - "else bound_integer " + - "end", - "1234"); - - assertOptimizedEquals("case " + - "when unbound_long = 1234 then 33 " + - "else 1 " + - "end", - "" + - "case " + - "when unbound_long = 1234 then 33 " + - "else 1 " + - "end"); - - assertOptimizedMatches("if(false, 1, 0 / 0)", "cast(fail(8, 'ignored failure message') as integer)"); - - assertOptimizedEquals("case " + - "when false then 2.2 " + - "when true then 2.2 " + - "end", - "2.2"); - - assertOptimizedEquals("case " + - "when false then 1234567890.0987654321 " + - "when true then 3.3 " + - "end", - "CAST(3.3 AS DECIMAL(20,10))"); - - assertOptimizedEquals("case " + - "when false then 1 " + - "when true then 2.2 " + - "end", - "2.2"); - - assertOptimizedEquals("case when ARRAY[CAST(1 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'"); - assertOptimizedEquals("case when ARRAY[CAST(2 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); - assertOptimizedEquals("case when ARRAY[CAST(null AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); - } - - @Test - public void testSimpleCase() - { - assertOptimizedEquals("case 1 " + - "when 1 then 32 + 1 " + - "when 1 then 34 " + - "end", - "33"); - - assertOptimizedEquals("case null " + - "when true then 33 " + - "end", - "null"); - assertOptimizedEquals("case null " + - "when true then 33 " + - "else 33 " + - "end", - "33"); - assertOptimizedEquals("case 33 " + - "when null then 1 " + - "else 33 " + - "end", - "33"); - - assertOptimizedEquals("case null " + - "when true then 
3300000000 " + - "end", - "null"); - assertOptimizedEquals("case null " + - "when true then 3300000000 " + - "else 3300000000 " + - "end", - "3300000000"); - assertOptimizedEquals("case 33 " + - "when null then 3300000000 " + - "else 33 " + - "end", - "33"); - - assertOptimizedEquals("case true " + - "when true then 33 " + - "end", - "33"); - assertOptimizedEquals("case true " + - "when false then 1 " + - "else 33 end", - "33"); - - assertOptimizedEquals("case bound_long " + - "when 1234 then 33 " + - "end", - "33"); - assertOptimizedEquals("case 1234 " + - "when bound_long then 33 " + - "end", - "33"); - assertOptimizedEquals("case true " + - "when true then bound_long " + - "end", - "1234"); - assertOptimizedEquals("case true " + - "when false then 1 " + - "else bound_long " + - "end", - "1234"); - - assertOptimizedEquals("case bound_integer " + - "when 1234 then 33 " + - "end", - "33"); - assertOptimizedEquals("case 1234 " + - "when bound_integer then 33 " + - "end", - "33"); - assertOptimizedEquals("case true " + - "when true then bound_integer " + - "end", - "1234"); - assertOptimizedEquals("case true " + - "when false then 1 " + - "else bound_integer " + - "end", - "1234"); - - assertOptimizedEquals("case bound_long " + - "when 1234 then 33 " + - "else unbound_long " + - "end", - "33"); - assertOptimizedEquals("case true " + - "when true then bound_long " + - "else unbound_long " + - "end", - "1234"); - assertOptimizedEquals("case true " + - "when false then unbound_long " + - "else bound_long " + - "end", - "1234"); - - assertOptimizedEquals("case unbound_long " + - "when 1234 then 33 " + - "else 1 " + - "end", - "" + - "case unbound_long " + - "when 1234 then 33 " + - "else 1 " + - "end"); - - assertOptimizedEquals("case 33 " + - "when 0 then 0 " + - "when 33 then unbound_long " + - "else 1 " + - "end", - "unbound_long"); - assertOptimizedEquals("case 33 " + - "when 0 then 0 " + - "when 33 then 1 " + - "when unbound_long then 2 " + - "else 1 " + - "end", - 
"1"); - assertOptimizedEquals("case 33 " + - "when unbound_long then 0 " + - "when 1 then 1 " + - "when 33 then 2 " + - "else 0 " + - "end", - "case 33 " + - "when unbound_long then 0 " + - "else 2 " + - "end"); - assertOptimizedEquals("case 33 " + - "when 0 then 0 " + - "when 1 then 1 " + - "else unbound_long " + - "end", - "unbound_long"); - assertOptimizedEquals("case 33 " + - "when unbound_long then 0 " + - "when 1 then 1 " + - "when unbound_long2 then 2 " + - "else 3 " + - "end", - "case 33 " + - "when unbound_long then 0 " + - "when unbound_long2 then 2 " + - "else 3 " + - "end"); - - assertOptimizedEquals("case true " + - "when unbound_long = 1 then 1 " + - "when 0 / 0 = 0 then 2 " + - "else 33 end", - "" + - "case true " + - "when unbound_long = 1 then 1 " + - "when 0 / 0 = 0 then 2 else 33 " + - "end"); - - assertOptimizedEquals("case bound_long " + - "when 123 * 10 + unbound_long then 1 = 1 " + - "else 1 = 2 " + - "end", - "" + - "case bound_long when 1230 + unbound_long then true " + - "else false " + - "end"); - - assertOptimizedEquals("case bound_long " + - "when unbound_long then 2 + 2 " + - "end", - "" + - "case bound_long " + - "when unbound_long then 4 " + - "end"); - - assertOptimizedEquals("case bound_long " + - "when unbound_long then 2 + 2 " + - "when 1 then null " + - "when 2 then null " + - "end", - "" + - "case bound_long " + - "when unbound_long then 4 " + - "end"); - - assertOptimizedMatches("case 1 " + - "when unbound_long then 1 " + - "when 0 / 0 then 2 " + - "else 1 " + - "end", - "" + - "case BIGINT '1' " + - "when unbound_long then 1 " + - "when cast(fail(8, 'ignored failure message') AS integer) then 2 " + - "else 1 " + - "end"); - - assertOptimizedMatches("case 1 " + - "when 0 / 0 then 1 " + - "when 0 / 0 then 2 " + - "else 1 " + - "end", - "" + - "case 1 " + - "when cast(fail(8, 'ignored failure message') as integer) then 1 " + - "when cast(fail(8, 'ignored failure message') as integer) then 2 " + - "else 1 " + - "end"); - - 
assertOptimizedEquals("case true " + - "when false then 2.2 " + - "when true then 2.2 " + - "end", - "2.2"); - - // TODO enabled when DECIMAL is default for literal: -// assertOptimizedEquals("case true " + -// "when false then 1234567890.0987654321 " + -// "when true then 3.3 " + -// "end", -// "CAST(3.3 AS DECIMAL(20,10))"); - - assertOptimizedEquals("case true " + - "when false then 1 " + - "when true then 2.2 " + - "end", - "2.2"); - - assertOptimizedEquals("case ARRAY[CAST(1 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'"); - assertOptimizedEquals("case ARRAY[CAST(2 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); - assertOptimizedEquals("case ARRAY[CAST(null AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); - } - - @Test - public void testCoalesce() - { - assertOptimizedEquals("coalesce(null, null)", "coalesce(null, null)"); - assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1 - 1, null)", "coalesce(6 * unbound_long, 0)"); - assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_long, 0.5E0)"); - assertOptimizedEquals("coalesce(unbound_long, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_long, 2.0E0, 0.5E0, 12.34E0)"); - assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1 - 1, null)", "coalesce(6 * unbound_integer, 0)"); - assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_integer, 0.5E0)"); - assertOptimizedEquals("coalesce(unbound_integer, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_integer, 2.0E0, 0.5E0, 12.34E0)"); - assertOptimizedMatches("coalesce(0 / 0 > 1, unbound_boolean, 0 / 0 = 0)", - "coalesce(cast(fail(8, 'ignored failure message') as boolean), unbound_boolean)"); - assertOptimizedMatches("coalesce(unbound_long, unbound_long)", "unbound_long"); - assertOptimizedMatches("coalesce(2 * 
unbound_long, 2 * unbound_long)", "BIGINT '2' * unbound_long"); - assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long)", "coalesce(unbound_long, unbound_long2)"); - assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long, unbound_long3)", "coalesce(unbound_long, unbound_long2, unbound_long3)"); - assertOptimizedEquals("coalesce(6, unbound_long2, unbound_long, unbound_long3)", "6"); - assertOptimizedEquals("coalesce(2 * 3, unbound_long2, unbound_long, unbound_long3)", "6"); - assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')"); - assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')"); - assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')"); - assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')"); - assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())"); - assertOptimizedMatches("coalesce(random(), random(), 5)", "coalesce(random(), random(), 5E0)"); - assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')"); - assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')"); - assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')"); - assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')"); - assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())"); - } - - @Test - 
public void testIf() - { - assertOptimizedEquals("IF(2 = 2, 3, 4)", "3"); - assertOptimizedEquals("IF(1 = 2, 3, 4)", "4"); - assertOptimizedEquals("IF(1 = 2, BIGINT '3', 4)", "4"); - assertOptimizedEquals("IF(1 = 2, 3000000000, 4)", "4"); - - assertOptimizedEquals("IF(true, 3, 4)", "3"); - assertOptimizedEquals("IF(false, 3, 4)", "4"); - assertOptimizedEquals("IF(null, 3, 4)", "4"); - - assertOptimizedEquals("IF(true, 3, null)", "3"); - assertOptimizedEquals("IF(false, 3, null)", "null"); - assertOptimizedEquals("IF(true, null, 4)", "null"); - assertOptimizedEquals("IF(false, null, 4)", "4"); - assertOptimizedEquals("IF(true, null, null)", "null"); - assertOptimizedEquals("IF(false, null, null)", "null"); - - assertOptimizedEquals("IF(true, 3.5E0, 4.2E0)", "3.5E0"); - assertOptimizedEquals("IF(false, 3.5E0, 4.2E0)", "4.2E0"); - - assertOptimizedEquals("IF(true, 'foo', 'bar')", "'foo'"); - assertOptimizedEquals("IF(false, 'foo', 'bar')", "'bar'"); - - assertOptimizedEquals("IF(true, 1.01, 1.02)", "1.01"); - assertOptimizedEquals("IF(false, 1.01, 1.02)", "1.02"); - assertOptimizedEquals("IF(true, 1234567890.123, 1.02)", "1234567890.123"); - assertOptimizedEquals("IF(false, 1.01, 1234567890.123)", "1234567890.123"); - } - - @Test - public void testLike() - { - assertOptimizedEquals("'a' LIKE 'a'", "true"); - assertOptimizedEquals("'' LIKE 'a'", "false"); - assertOptimizedEquals("'abc' LIKE 'a'", "false"); - - assertOptimizedEquals("'a' LIKE '_'", "true"); - assertOptimizedEquals("'' LIKE '_'", "false"); - assertOptimizedEquals("'abc' LIKE '_'", "false"); - - assertOptimizedEquals("'a' LIKE '%'", "true"); - assertOptimizedEquals("'' LIKE '%'", "true"); - assertOptimizedEquals("'abc' LIKE '%'", "true"); - - assertOptimizedEquals("'abc' LIKE '___'", "true"); - assertOptimizedEquals("'ab' LIKE '___'", "false"); - assertOptimizedEquals("'abcd' LIKE '___'", "false"); - - assertOptimizedEquals("'abc' LIKE 'abc'", "true"); - assertOptimizedEquals("'xyz' LIKE 'abc'", "false"); 
- assertOptimizedEquals("'abc0' LIKE 'abc'", "false"); - assertOptimizedEquals("'0abc' LIKE 'abc'", "false"); - - assertOptimizedEquals("'abc' LIKE 'abc%'", "true"); - assertOptimizedEquals("'abc0' LIKE 'abc%'", "true"); - assertOptimizedEquals("'0abc' LIKE 'abc%'", "false"); - - assertOptimizedEquals("'abc' LIKE '%abc'", "true"); - assertOptimizedEquals("'0abc' LIKE '%abc'", "true"); - assertOptimizedEquals("'abc0' LIKE '%abc'", "false"); - - assertOptimizedEquals("'abc' LIKE '%abc%'", "true"); - assertOptimizedEquals("'0abc' LIKE '%abc%'", "true"); - assertOptimizedEquals("'abc0' LIKE '%abc%'", "true"); - assertOptimizedEquals("'0abc0' LIKE '%abc%'", "true"); - assertOptimizedEquals("'xyzw' LIKE '%abc%'", "false"); - - assertOptimizedEquals("'abc' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'0abc' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'abc0' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'0abc0' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'ab01c' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'0ab01c' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'ab01c0' LIKE '%ab%c%'", "true"); - assertOptimizedEquals("'0ab01c0' LIKE '%ab%c%'", "true"); - - assertOptimizedEquals("'xyzw' LIKE '%ab%c%'", "false"); - - // ensure regex chars are escaped - assertOptimizedEquals("'\' LIKE '\'", "true"); - assertOptimizedEquals("'.*' LIKE '.*'", "true"); - assertOptimizedEquals("'[' LIKE '['", "true"); - assertOptimizedEquals("']' LIKE ']'", "true"); - assertOptimizedEquals("'{' LIKE '{'", "true"); - assertOptimizedEquals("'}' LIKE '}'", "true"); - assertOptimizedEquals("'?' 
LIKE '?'", "true"); - assertOptimizedEquals("'+' LIKE '+'", "true"); - assertOptimizedEquals("'(' LIKE '('", "true"); - assertOptimizedEquals("')' LIKE ')'", "true"); - assertOptimizedEquals("'|' LIKE '|'", "true"); - assertOptimizedEquals("'^' LIKE '^'", "true"); - assertOptimizedEquals("'$' LIKE '$'", "true"); - - assertOptimizedEquals("null LIKE '%'", "null"); - assertOptimizedEquals("'a' LIKE null", "null"); - assertOptimizedEquals("'a' LIKE '%' ESCAPE null", "null"); - assertOptimizedEquals("'a' LIKE unbound_string ESCAPE null", "null"); - - assertOptimizedEquals("'%' LIKE 'z%' ESCAPE 'z'", "true"); - - assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%' ESCAPE 'z'", "true"); - assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%'", "false"); - } - - @Test - public void testLikeOptimization() - { - assertOptimizedEquals("unbound_string LIKE 'abc'", "unbound_string = CAST('abc' AS VARCHAR)"); - - assertOptimizedEquals("unbound_string LIKE '' ESCAPE '#'", "unbound_string LIKE '' ESCAPE '#'"); - assertOptimizedEquals("unbound_string LIKE 'abc' ESCAPE '#'", "unbound_string = CAST('abc' AS VARCHAR)"); - assertOptimizedEquals("unbound_string LIKE 'a#_b' ESCAPE '#'", "unbound_string = CAST('a_b' AS VARCHAR)"); - assertOptimizedEquals("unbound_string LIKE 'a#%b' ESCAPE '#'", "unbound_string = CAST('a%b' AS VARCHAR)"); - assertOptimizedEquals("unbound_string LIKE 'a#_##b' ESCAPE '#'", "unbound_string = CAST('a_#b' AS VARCHAR)"); - assertOptimizedEquals("unbound_string LIKE 'a#__b' ESCAPE '#'", "unbound_string LIKE 'a#__b' ESCAPE '#'"); - assertOptimizedEquals("unbound_string LIKE 'a##%b' ESCAPE '#'", "unbound_string LIKE 'a##%b' ESCAPE '#'"); - - assertOptimizedEquals("bound_string LIKE bound_pattern", "true"); - assertOptimizedEquals("'abc' LIKE bound_pattern", "false"); - - assertOptimizedEquals("unbound_string LIKE bound_pattern", "unbound_string LIKE bound_pattern"); - assertDoNotOptimize("unbound_string LIKE 'abc%'", SERIALIZABLE); - - 
assertOptimizedEquals("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE unbound_pattern ESCAPE unbound_string"); - } - - @Test - public void testInvalidLike() + public void testBind() { - assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE ''")); - assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE 'bc'")); - assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#' ESCAPE '#'")); - assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#abc' ESCAPE '#'")); - assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'ab#' ESCAPE '#'")); + assertOptimizedEquals("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", "apply(90, \"$internal$bind\"(9, (x, y) -> x + y))"); + evaluate("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", true); + evaluate("apply(900, \"$internal$bind\"(90, 9, (x, y, z) -> x + y + z))", true); } @Test @@ -1450,42 +107,6 @@ public void testLambda() assertEquals(evaluate("reduce(ARRAY[1, 5], 0, (x, y) -> x + y, x -> x)", true), 6L); } - @Test - public void testBind() - { - assertOptimizedEquals("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", "apply(90, \"$internal$bind\"(9, (x, y) -> x + y))"); - evaluate("apply(90, \"$internal$bind\"(9, (x, y) -> x + y))", true); - evaluate("apply(900, \"$internal$bind\"(90, 9, (x, y, z) -> x + y + z))", true); - } - - @Test - public void testFailedExpressionOptimization() - { - assertOptimizedMatches("CASE unbound_long WHEN 1 THEN 1 WHEN 0 / 0 THEN 2 END", - "CASE unbound_long WHEN BIGINT '1' THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 END"); - - assertOptimizedMatches("CASE unbound_boolean WHEN true THEN 1 ELSE 0 / 0 END", - "CASE unbound_boolean WHEN true THEN 1 ELSE cast(fail(8, 'ignored failure message') as integer) END"); - - assertOptimizedMatches("CASE bound_long WHEN unbound_long THEN 1 WHEN 0 / 0 THEN 2 ELSE 1 END", - "CASE 
BIGINT '1234' WHEN unbound_long THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 ELSE 1 END"); - - assertOptimizedMatches("case when unbound_boolean then 1 when 0 / 0 = 0 then 2 end", - "case when unbound_boolean then 1 when cast(fail(8, 'ignored failure message') as boolean) then 2 end"); - - assertOptimizedMatches("case when unbound_boolean then 1 else 0 / 0 end", - "case when unbound_boolean then 1 else cast(fail(8, 'ignored failure message') as integer) end"); - - assertOptimizedMatches("case when unbound_boolean then 0 / 0 else 1 end", - "case when unbound_boolean then cast(fail(8, 'ignored failure message') as integer) else 1 end"); - } - - @Test(expectedExceptions = PrestoException.class) - public void testOptimizeDivideByZero() - { - optimize("0 / 0"); - } - @Test public void testMassiveArray() { @@ -1500,90 +121,6 @@ public void testMassiveArray() optimize(format("ARRAY [%s]", Joiner.on(", ").join(IntStream.range(0, 10_000).mapToObj(i -> "ARRAY['" + i + "']").iterator()))); } - @Test - public void testArrayConstructor() - { - optimize("ARRAY []"); - assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]", - "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))"); - assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]", - "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))"); - assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]", - "array_constructor((bound_long + 0), (unbound_long + 1), NULL)"); - } - - @Test - public void testRowConstructor() - { - optimize("ROW(NULL)"); - optimize("ROW(1)"); - optimize("ROW(unbound_long + 0)"); - optimize("ROW(unbound_long + unbound_long2, unbound_string, unbound_double)"); - optimize("ROW(unbound_boolean, FALSE, ARRAY[unbound_long, unbound_long2], unbound_null_string, unbound_interval)"); - optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(unbound_string, 
0.0E0)]"); - optimize("ARRAY [ROW('string', unbound_double), ROW('string', bound_double)]"); - optimize("ROW(ROW(NULL), ROW(ROW(ROW(ROW('rowception')))))"); - optimize("ROW(unbound_string, bound_string)"); - - optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(CAST(bound_string AS VARCHAR), 0.0E0)]"); - optimize("ARRAY [ROW(CAST(bound_string AS VARCHAR), 0.0E0), ROW(unbound_string, unbound_double)]"); - - optimize("ARRAY [ROW(unbound_string, unbound_double), CAST(NULL AS ROW(VARCHAR, DOUBLE))]"); - optimize("ARRAY [CAST(NULL AS ROW(VARCHAR, DOUBLE)), ROW(unbound_string, unbound_double)]"); - } - - @Test - public void testDereference() - { - optimize("ARRAY []"); - assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]", - "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))"); - assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]", - "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))"); - assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]", - "array_constructor((bound_long + 0), (unbound_long + 1), NULL)"); - } - - @Test - public void testRowDereference() - { - optimize("CAST(null AS ROW(a VARCHAR, b BIGINT)).a"); - } - - @Test - public void testRowSubscript() - { - assertOptimizedEquals("ROW (1, 'a', true)[3]", "true"); - assertOptimizedEquals("ROW (1, 'a', ROW (2, 'b', ROW (3, 'c')))[3][3][2]", "'c'"); - } - - @Test(expectedExceptions = PrestoException.class) - public void testArraySubscriptConstantNegativeIndex() - { - optimize("ARRAY [1, 2, 3][-1]"); - } - - @Test(expectedExceptions = PrestoException.class) - public void testArraySubscriptConstantZeroIndex() - { - optimize("ARRAY [1, 2, 3][0]"); - } - - @Test(expectedExceptions = PrestoException.class) - public void testMapSubscriptMissingKey() - { - optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[-1]"); - } - - @Test - public void testMapSubscriptConstantIndexes() - { - 
optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[1]"); - optimize("MAP(ARRAY [BIGINT '1', 2], ARRAY [3, 4])[1]"); - optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[2]"); - optimize("MAP(ARRAY [ARRAY[1,1]], ARRAY['a'])[ARRAY[1,1]]"); - } - @Test(timeOut = 60000) public void testLikeInvalidUtf8() { @@ -1592,21 +129,13 @@ public void testLikeInvalidUtf8() } @Test - public void testLiterals() + public void testLikeSerializable() { - optimize("date '2013-04-03' + unbound_interval"); - optimize("time '03:04:05.321' + unbound_interval"); - optimize("time '03:04:05.321 UTC' + unbound_interval"); - optimize("timestamp '2013-04-03 03:04:05.321' + unbound_interval"); - optimize("timestamp '2013-04-03 03:04:05.321 UTC' + unbound_interval"); - - optimize("interval '3' day * unbound_long"); - optimize("interval '3' year * unbound_long"); - - assertEquals(optimize("X'1234'"), Slices.wrappedBuffer((byte) 0x12, (byte) 0x34)); + assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%' ESCAPE 'z'", "true"); + assertRowExpressionEquals(SERIALIZABLE, "'%' LIKE 'z%'", "false"); } - private static void assertLike(byte[] value, String pattern, boolean expected) + private void assertLike(byte[] value, String pattern, boolean expected) { Expression predicate = new LikePredicate( rawStringLiteral(Slices.wrappedBuffer(value)), @@ -1627,12 +156,23 @@ public Slice getSlice() }; } - private static void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + @Override + public void assertDoNotOptimize(@Language("SQL") String expression, ExpressionOptimizer.Level optimizationLevel) { - assertEquals(optimize(actual), optimize(expected)); + assertRoundTrip(expression); + Expression translatedExpression = expression(expression); + RowExpression rowExpression = toRowExpression(translatedExpression); + + Object expressionResult = optimize(translatedExpression); + if (expressionResult instanceof Expression) { + expressionResult = toRowExpression((Expression) expressionResult); + } + 
Object rowExpressionResult = optimize(rowExpression, optimizationLevel); + assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); + assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); } - private static void assertRowExpressionEquals(Level level, @Language("SQL") String actual, @Language("SQL") String expected) + private void assertRowExpressionEquals(ExpressionOptimizer.Level level, @Language("SQL") String actual, @Language("SQL") String expected) { Object actualResult = optimize(toRowExpression(expression(actual)), level); Object expectedResult = optimize(toRowExpression(expression(expected)), level); @@ -1643,7 +183,14 @@ private static void assertRowExpressionEquals(Level level, @Language("SQL") Stri assertEquals(actualResult, expectedResult); } - private static void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + @Override + public void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + assertEquals(optimize(actual), optimize(expected)); + } + + @Override + public void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) { // replaces FunctionCalls to FailureFunction by fail() Object actualOptimized = optimize(actual); @@ -1655,7 +202,8 @@ private static void assertOptimizedMatches(@Language("SQL") String actual, @Lang rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expected))); } - private static Object optimize(@Language("SQL") String expression) + @Override + public Object optimize(@Language("SQL") String expression) { assertRoundTrip(expression); @@ -1668,186 +216,40 @@ private static Object optimize(@Language("SQL") String expression) return expressionResult; } - private static Expression expression(String expression) + @Override + public void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected) { - return 
FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); - } - - private static RowExpression toRowExpression(Expression expression) - { - return TRANSLATOR.translate(expression, SYMBOL_TYPES); + assertEquals(evaluate(actual, true), evaluate(expected, true)); } - private static Object optimize(Expression expression) + private Object optimize(RowExpression expression, ExpressionOptimizer.Level level) { - Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); - ExpressionInterpreter interpreter = expressionOptimizer(expression, METADATA, TEST_SESSION, expressionTypes); - return interpreter.optimize(variable -> { + return new RowExpressionInterpreter(expression, METADATA, TEST_SESSION.toConnectorSession(), level).optimize(variable -> { Symbol symbol = new Symbol(variable.getName()); Object value = symbolConstant(symbol); if (value == null) { - return symbol.toSymbolReference(); + return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); } return value; }); } - private static Object optimize(RowExpression expression, Level level) + private Object optimize(Expression expression) { - return new RowExpressionInterpreter(expression, METADATA, TEST_SESSION.toConnectorSession(), level).optimize(variable -> { + Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); + ExpressionInterpreter interpreter = expressionOptimizer(expression, METADATA, TEST_SESSION, expressionTypes); + return interpreter.optimize(variable -> { Symbol symbol = new Symbol(variable.getName()); Object value = symbolConstant(symbol); if (value == null) { - return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); + return symbol.toSymbolReference(); } return value; }); } - private 
static void assertDoNotOptimize(@Language("SQL") String expression, Level optimizationLevel) - { - assertRoundTrip(expression); - Expression translatedExpression = expression(expression); - RowExpression rowExpression = toRowExpression(translatedExpression); - - Object expressionResult = optimize(translatedExpression); - if (expressionResult instanceof Expression) { - expressionResult = toRowExpression((Expression) expressionResult); - } - Object rowExpressionResult = optimize(rowExpression, optimizationLevel); - assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); - assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); - } - - private static Object symbolConstant(Symbol symbol) - { - switch (symbol.getName().toLowerCase(ENGLISH)) { - case "bound_integer": - return 1234L; - case "bound_long": - return 1234L; - case "bound_string": - return utf8Slice("hello"); - case "bound_double": - return 12.34; - case "bound_date": - return new LocalDate(2001, 8, 22).toDateMidnight(DateTimeZone.UTC).getMillis(); - case "bound_time": - return new LocalTime(3, 4, 5, 321).toDateTime(new DateTime(0, DateTimeZone.UTC)).getMillis(); - case "bound_timestamp": - return new DateTime(2001, 8, 22, 3, 4, 5, 321, DateTimeZone.UTC).getMillis(); - case "bound_pattern": - return utf8Slice("%el%"); - case "bound_timestamp_with_timezone": - return new SqlTimestampWithTimeZone(new DateTime(1970, 1, 1, 1, 0, 0, 999, DateTimeZone.UTC).getMillis(), getTimeZoneKey("Z")); - case "bound_varbinary": - return Slices.wrappedBuffer((byte) 0xab); - case "bound_decimal_short": - return 12345L; - case "bound_decimal_long": - return Decimals.encodeUnscaledValue(new BigInteger("12345678901234567890123")); - } - return null; - } - - private static void assertExpressionAndRowExpressionEquals(Object expressionResult, Object rowExpressionResult) - { - if (rowExpressionResult instanceof RowExpression) { - // Cannot be completely evaluated into a constant; compare expressions 
- assertTrue(expressionResult instanceof Expression); - - // It is tricky to check the equivalence of an expression and a row expression. - // We rely on the optimized translator to fill the gap. - RowExpression translated = TRANSLATOR.translateAndOptimize((Expression) expressionResult, SYMBOL_TYPES); - assertRowExpressionEvaluationEquals(translated, rowExpressionResult); - } - else { - // We have constants; directly compare - assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); - } - } - - /** - * Assert the evaluation result of two row expressions equivalent - * no matter they are constants or remaining row expressions. - */ - private static void assertRowExpressionEvaluationEquals(Object left, Object right) - { - if (right instanceof RowExpression) { - assertTrue(left instanceof RowExpression); - // assertEquals(((RowExpression) left).getType(), ((RowExpression) right).getType()); - if (left instanceof ConstantExpression) { - if (isRemovableCast(right)) { - assertRowExpressionEvaluationEquals(left, ((CallExpression) right).getArguments().get(0)); - return; - } - assertTrue(right instanceof ConstantExpression); - assertRowExpressionEvaluationEquals(((ConstantExpression) left).getValue(), ((ConstantExpression) left).getValue()); - } - else if (left instanceof InputReferenceExpression || left instanceof VariableReferenceExpression) { - assertEquals(left, right); - } - else if (left instanceof CallExpression) { - assertTrue(right instanceof CallExpression); - assertEquals(((CallExpression) left).getFunctionHandle(), ((CallExpression) right).getFunctionHandle()); - assertEquals(((CallExpression) left).getArguments().size(), ((CallExpression) right).getArguments().size()); - for (int i = 0; i < ((CallExpression) left).getArguments().size(); i++) { - assertRowExpressionEvaluationEquals(((CallExpression) left).getArguments().get(i), ((CallExpression) right).getArguments().get(i)); - } - } - else if (left instanceof SpecialFormExpression) { - 
assertTrue(right instanceof SpecialFormExpression); - assertEquals(((SpecialFormExpression) left).getForm(), ((SpecialFormExpression) right).getForm()); - assertEquals(((SpecialFormExpression) left).getArguments().size(), ((SpecialFormExpression) right).getArguments().size()); - for (int i = 0; i < ((SpecialFormExpression) left).getArguments().size(); i++) { - assertRowExpressionEvaluationEquals(((SpecialFormExpression) left).getArguments().get(i), ((SpecialFormExpression) right).getArguments().get(i)); - } - } - else { - assertTrue(left instanceof LambdaDefinitionExpression); - assertTrue(right instanceof LambdaDefinitionExpression); - assertEquals(((LambdaDefinitionExpression) left).getArguments(), ((LambdaDefinitionExpression) right).getArguments()); - assertEquals(((LambdaDefinitionExpression) left).getArgumentTypes(), ((LambdaDefinitionExpression) right).getArgumentTypes()); - assertRowExpressionEvaluationEquals(((LambdaDefinitionExpression) left).getBody(), ((LambdaDefinitionExpression) right).getBody()); - } - } - else { - // We have constants; directly compare - if (left instanceof Block) { - assertTrue(right instanceof Block); - assertEquals(blockToSlice((Block) left), blockToSlice((Block) right)); - } - else { - assertEquals(left, right); - } - } - } - - private static boolean isRemovableCast(Object value) - { - if (value instanceof CallExpression && - new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()).isCastFunction(((CallExpression) value).getFunctionHandle())) { - Type targetType = ((CallExpression) value).getType(); - Type sourceType = ((CallExpression) value).getArguments().get(0).getType(); - return METADATA.getFunctionAndTypeManager().canCoerce(sourceType, targetType); - } - return false; - } - - private static Slice blockToSlice(Block block) - { - // This function is strictly for testing use only - SliceOutput sliceOutput = new DynamicSliceOutput(1000); - BlockSerdeUtil.writeBlock(blockEncodingSerde, 
sliceOutput, block); - return sliceOutput.slice(); - } - - private static void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected) - { - assertEquals(evaluate(actual, true), evaluate(expected, true)); - } - - private static Object evaluate(String expression, boolean deterministic) + @Override + public Object evaluate(@Language("SQL") String expression, boolean deterministic) { assertRoundTrip(expression); @@ -1856,14 +258,7 @@ private static Object evaluate(String expression, boolean deterministic) return evaluate(parsedExpression, deterministic); } - private static void assertRoundTrip(String expression) - { - ParsingOptions parsingOptions = createParsingOptions(TEST_SESSION); - assertEquals(SQL_PARSER.createExpression(expression, parsingOptions), - SQL_PARSER.createExpression(formatExpression(SQL_PARSER.createExpression(expression, parsingOptions), Optional.empty()), parsingOptions)); - } - - private static Object evaluate(Expression expression, boolean deterministic) + private Object evaluate(Expression expression, boolean deterministic) { Map, Type> expressionTypes = getExpressionTypes(TEST_SESSION, METADATA, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP); Object expressionResult = expressionInterpreter(expression, METADATA, TEST_SESSION, expressionTypes).evaluate(); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/AbstractTestExpressionInterpreter.java b/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/AbstractTestExpressionInterpreter.java new file mode 100644 index 0000000000000..ad2433e4a9ef6 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/expressions/AbstractTestExpressionInterpreter.java @@ -0,0 +1,1697 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.expressions; + +import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.block.BlockSerdeUtil; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.Decimals; +import com.facebook.presto.common.type.SqlTimestampWithTimeZone; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarbinaryType; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.operator.scalar.FunctionAssertions; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.AggregationFunctionMetadata; +import com.facebook.presto.spi.function.FunctionKind; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.InputReferenceExpression; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import 
com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.ExpressionFormatter; +import com.facebook.presto.sql.TestingRowExpressionTranslator; +import com.facebook.presto.sql.parser.ParsingOptions; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.tree.EnumLiteral; +import com.facebook.presto.sql.tree.Expression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceOutput; +import io.airlift.slice.Slices; +import org.intellij.lang.annotations.Language; +import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; +import org.joda.time.LocalDate; +import org.joda.time.LocalTime; +import org.testng.annotations.Test; + +import java.math.BigInteger; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.DecimalType.createDecimalType; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; 
+import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT; +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; +import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE; +import static com.facebook.presto.sql.ExpressionFormatter.formatExpression; +import static com.facebook.presto.type.IntervalDayTimeType.INTERVAL_DAY_TIME; +import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions; +import static com.facebook.presto.util.DateTimeZoneIndex.getDateTimeZone; +import static io.airlift.slice.Slices.utf8Slice; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; + +public abstract class AbstractTestExpressionInterpreter +{ + public static final SqlInvokedFunction SQUARE_UDF_CPP = new SqlInvokedFunction( + QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "square"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "Integer square", + RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), + "", + notVersioned()); + + public static final SqlInvokedFunction AVG_UDAF_CPP = new SqlInvokedFunction( + QualifiedObjectName.valueOf(new CatalogSchemaName("json", "test_schema"), "avg"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.DOUBLE))), + parseTypeSignature(StandardTypes.DOUBLE), + "Returns mean of 
doubles", + RoutineCharacteristics.builder().setDeterminism(DETERMINISTIC).setLanguage(CPP).build(), + "", + notVersioned(), + FunctionKind.AGGREGATE, + Optional.of(new AggregationFunctionMetadata(parseTypeSignature("ROW(double, int)"), false))); + + public static final int TEST_VARCHAR_TYPE_LENGTH = 17; + public static final TypeProvider SYMBOL_TYPES = TypeProvider.viewOf(ImmutableMap.builder() + .put("bound_integer", INTEGER) + .put("bound_long", BIGINT) + .put("bound_string", createVarcharType(TEST_VARCHAR_TYPE_LENGTH)) + .put("bound_varbinary", VarbinaryType.VARBINARY) + .put("bound_double", DOUBLE) + .put("bound_boolean", BOOLEAN) + .put("bound_date", DATE) + .put("bound_time", TIME) + .put("bound_timestamp", TIMESTAMP) + .put("bound_pattern", VARCHAR) + .put("bound_null_string", VARCHAR) + .put("bound_decimal_short", createDecimalType(5, 2)) + .put("bound_decimal_long", createDecimalType(23, 3)) + .put("time", BIGINT) // for testing reserved identifiers + .put("unbound_integer", INTEGER) + .put("unbound_long", BIGINT) + .put("unbound_long2", BIGINT) + .put("unbound_long3", BIGINT) + .put("unbound_string", VARCHAR) + .put("unbound_double", DOUBLE) + .put("unbound_boolean", BOOLEAN) + .put("unbound_date", DATE) + .put("unbound_time", TIME) + .put("unbound_array", new ArrayType(BIGINT)) + .put("unbound_timestamp", TIMESTAMP) + .put("unbound_interval", INTERVAL_DAY_TIME) + .put("unbound_pattern", VARCHAR) + .put("unbound_null_string", VARCHAR) + .build()); + + public static final SqlParser SQL_PARSER = new SqlParser(); + public static final Metadata METADATA = MetadataManager.createTestMetadataManager(); + public static final TestingRowExpressionTranslator TRANSLATOR = new TestingRowExpressionTranslator(METADATA); + public static final BlockEncodingSerde blockEncodingSerde = new BlockEncodingManager(); + + @Test + public void testAnd() + { + assertOptimizedEquals("true and false", "false"); + assertOptimizedEquals("false and true", "false"); + 
assertOptimizedEquals("false and false", "false"); + + assertOptimizedEquals("true and null", "null"); + assertOptimizedEquals("false and null", "false"); + assertOptimizedEquals("null and true", "null"); + assertOptimizedEquals("null and false", "false"); + assertOptimizedEquals("null and null", "null"); + + assertOptimizedEquals("unbound_string='z' and true", "unbound_string='z'"); + assertOptimizedEquals("unbound_string='z' and false", "false"); + assertOptimizedEquals("true and unbound_string='z'", "unbound_string='z'"); + assertOptimizedEquals("false and unbound_string='z'", "false"); + + assertOptimizedEquals("bound_string='z' and bound_long=1+1", "bound_string='z' and bound_long=2"); + assertOptimizedEquals("random() > 0 and random() > 0", "random() > 0 and random() > 0"); + } + + @Test + public void testOr() + { + assertOptimizedEquals("true or true", "true"); + assertOptimizedEquals("true or false", "true"); + assertOptimizedEquals("false or true", "true"); + assertOptimizedEquals("false or false", "false"); + + assertOptimizedEquals("true or null", "true"); + assertOptimizedEquals("null or true", "true"); + assertOptimizedEquals("null or null", "null"); + + assertOptimizedEquals("false or null", "null"); + assertOptimizedEquals("null or false", "null"); + + assertOptimizedEquals("bound_string='z' or true", "true"); + assertOptimizedEquals("bound_string='z' or false", "bound_string='z'"); + assertOptimizedEquals("true or bound_string='z'", "true"); + assertOptimizedEquals("false or bound_string='z'", "bound_string='z'"); + + assertOptimizedEquals("bound_string='z' or bound_long=1+1", "bound_string='z' or bound_long=2"); + assertOptimizedEquals("random() > 0 or random() > 0", "random() > 0 or random() > 0"); + } + + @Test + public void testComparison() + { + assertOptimizedEquals("null = null", "null"); + + assertOptimizedEquals("'a' = 'b'", "false"); + assertOptimizedEquals("'a' = 'a'", "true"); + assertOptimizedEquals("'a' = null", "null"); + 
assertOptimizedEquals("null = 'a'", "null"); + assertOptimizedEquals("bound_integer = 1234", "true"); + assertOptimizedEquals("bound_integer = 12340000000", "false"); + assertOptimizedEquals("bound_long = BIGINT '1234'", "true"); + assertOptimizedEquals("bound_long = 1234", "true"); + assertOptimizedEquals("bound_double = 12.34", "true"); + assertOptimizedEquals("bound_string = 'hello'", "true"); + assertOptimizedEquals("bound_long = unbound_long", "1234 = unbound_long"); + + assertOptimizedEquals("10151082135029368 = 10151082135029369", "false"); + + assertOptimizedEquals("bound_varbinary = X'a b'", "true"); + assertOptimizedEquals("bound_varbinary = X'a d'", "false"); + + assertOptimizedEquals("1.1 = 1.1", "true"); + assertOptimizedEquals("9876543210.9874561203 = 9876543210.9874561203", "true"); + assertOptimizedEquals("bound_decimal_short = 123.45", "true"); + assertOptimizedEquals("bound_decimal_long = 12345678901234567890.123", "true"); + } + + @Test + public void testIsDistinctFrom() + { + assertOptimizedEquals("null is distinct from null", "false"); + + assertOptimizedEquals("3 is distinct from 4", "true"); + assertOptimizedEquals("3 is distinct from BIGINT '4'", "true"); + assertOptimizedEquals("3 is distinct from 4000000000", "true"); + assertOptimizedEquals("3 is distinct from 3", "false"); + assertOptimizedEquals("3 is distinct from null", "true"); + assertOptimizedEquals("null is distinct from 3", "true"); + + assertOptimizedEquals("10151082135029368 is distinct from 10151082135029369", "true"); + + assertOptimizedEquals("1.1 is distinct from 1.1", "false"); + assertOptimizedEquals("9876543210.9874561203 is distinct from NULL", "true"); + assertOptimizedEquals("bound_decimal_short is distinct from NULL", "true"); + assertOptimizedEquals("bound_decimal_long is distinct from 12345678901234567890.123", "false"); + } + + @Test + public void testIsNull() + { + assertOptimizedEquals("null is null", "true"); + assertOptimizedEquals("1 is null", "false"); + 
assertOptimizedEquals("10000000000 is null", "false"); + assertOptimizedEquals("BIGINT '1' is null", "false"); + assertOptimizedEquals("1.0 is null", "false"); + assertOptimizedEquals("'a' is null", "false"); + assertOptimizedEquals("true is null", "false"); + assertOptimizedEquals("null+1 is null", "true"); + assertOptimizedEquals("unbound_string is null", "unbound_string is null"); + assertOptimizedEquals("unbound_long+(1+1) is null", "unbound_long+2 is null"); + assertOptimizedEquals("1.1 is null", "false"); + assertOptimizedEquals("9876543210.9874561203 is null", "false"); + assertOptimizedEquals("bound_decimal_short is null", "false"); + assertOptimizedEquals("bound_decimal_long is null", "false"); + } + + @Test + public void testIsNotNull() + { + assertOptimizedEquals("null is not null", "false"); + assertOptimizedEquals("1 is not null", "true"); + assertOptimizedEquals("10000000000 is not null", "true"); + assertOptimizedEquals("BIGINT '1' is not null", "true"); + assertOptimizedEquals("1.0 is not null", "true"); + assertOptimizedEquals("'a' is not null", "true"); + assertOptimizedEquals("true is not null", "true"); + assertOptimizedEquals("null+1 is not null", "false"); + assertOptimizedEquals("unbound_string is not null", "unbound_string is not null"); + assertOptimizedEquals("unbound_long+(1+1) is not null", "unbound_long+2 is not null"); + assertOptimizedEquals("1.1 is not null", "true"); + assertOptimizedEquals("9876543210.9874561203 is not null", "true"); + assertOptimizedEquals("bound_decimal_short is not null", "true"); + assertOptimizedEquals("bound_decimal_long is not null", "true"); + } + + @Test + public void testNullIf() + { + assertOptimizedEquals("nullif(true, true)", "null"); + assertOptimizedEquals("nullif(true, false)", "true"); + assertOptimizedEquals("nullif(null, false)", "null"); + assertOptimizedEquals("nullif(true, null)", "true"); + + assertOptimizedEquals("nullif('a', 'a')", "null"); + assertOptimizedEquals("nullif('a', 'b')", 
"'a'"); + assertOptimizedEquals("nullif(null, 'b')", "null"); + assertOptimizedEquals("nullif('a', null)", "'a'"); + + assertOptimizedEquals("nullif(1, 1)", "null"); + assertOptimizedEquals("nullif(1, 2)", "1"); + assertOptimizedEquals("nullif(1, BIGINT '2')", "1"); + assertOptimizedEquals("nullif(1, 20000000000)", "1"); + assertOptimizedEquals("nullif(1.0E0, 1)", "null"); + assertOptimizedEquals("nullif(10000000000.0E0, 10000000000)", "null"); + assertOptimizedEquals("nullif(1.1E0, 1)", "1.1E0"); + assertOptimizedEquals("nullif(1.1E0, 1.1E0)", "null"); + assertOptimizedEquals("nullif(1, 2-1)", "null"); + assertOptimizedEquals("nullif(null, null)", "null"); + assertOptimizedEquals("nullif(1, null)", "1"); + assertOptimizedEquals("nullif(unbound_long, 1)", "nullif(unbound_long, 1)"); + assertOptimizedEquals("nullif(unbound_long, unbound_long2)", "nullif(unbound_long, unbound_long2)"); + assertOptimizedEquals("nullif(unbound_long, unbound_long2+(1+1))", "nullif(unbound_long, unbound_long2+2)"); + + assertOptimizedEquals("nullif(1.1, 1.2)", "1.1"); + assertOptimizedEquals("nullif(9876543210.9874561203, 9876543210.9874561203)", "null"); + assertOptimizedEquals("nullif(bound_decimal_short, 123.45)", "null"); + assertOptimizedEquals("nullif(bound_decimal_long, 12345678901234567890.123)", "null"); + assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(1 AS BIGINT)]) IS NULL", "true"); + assertOptimizedEquals("nullif(ARRAY[CAST(1 AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false"); + assertOptimizedEquals("nullif(ARRAY[CAST(NULL AS BIGINT)], ARRAY[CAST(NULL AS BIGINT)]) IS NULL", "false"); + } + + @Test + public void testNegative() + { + assertOptimizedEquals("-(1)", "-1"); + assertOptimizedEquals("-(BIGINT '1')", "BIGINT '-1'"); + assertOptimizedEquals("-(unbound_long+1)", "-(unbound_long+1)"); + assertOptimizedEquals("-(1+1)", "-2"); + assertOptimizedEquals("-(1+ BIGINT '1')", "BIGINT '-2'"); + assertOptimizedEquals("-(CAST(NULL AS BIGINT))", 
"null"); + assertOptimizedEquals("-(unbound_long+(1+1))", "-(unbound_long+2)"); + assertOptimizedEquals("-(1.1+1.2)", "-2.3"); + assertOptimizedEquals("-(9876543210.9874561203-9876543210.9874561203)", "CAST(0 AS DECIMAL(20,10))"); + assertOptimizedEquals("-(bound_decimal_short+123.45)", "-246.90"); + assertOptimizedEquals("-(bound_decimal_long-12345678901234567890.123)", "CAST(0 AS DECIMAL(20,10))"); + } + + @Test + public void testNot() + { + assertOptimizedEquals("not true", "false"); + assertOptimizedEquals("not false", "true"); + assertOptimizedEquals("not null", "null"); + assertOptimizedEquals("not 1=1", "false"); + assertOptimizedEquals("not 1=BIGINT '1'", "false"); + assertOptimizedEquals("not 1!=1", "true"); + assertOptimizedEquals("not unbound_long=1", "not unbound_long=1"); + assertOptimizedEquals("not unbound_long=(1+1)", "not unbound_long=2"); + } + + @Test + public void testFunctionCall() + { + assertOptimizedEquals("abs(-5)", "5"); + assertOptimizedEquals("abs(-10-5)", "15"); + assertOptimizedEquals("abs(-bound_integer + 1)", "1233"); + assertOptimizedEquals("abs(-bound_long + 1)", "1233"); + assertOptimizedEquals("abs(-bound_long + BIGINT '1')", "1233"); + assertOptimizedEquals("abs(-bound_long)", "1234"); + assertOptimizedEquals("abs(unbound_long)", "abs(unbound_long)"); + assertOptimizedEquals("abs(unbound_long + 1)", "abs(unbound_long + 1)"); + assertOptimizedEquals("cast(json_parse(unbound_string) as map(varchar, varchar))", "cast(json_parse(unbound_string) as map(varchar, varchar))"); + assertOptimizedEquals("cast(json_parse(unbound_string) as array(varchar))", "cast(json_parse(unbound_string) as array(varchar))"); + assertOptimizedEquals("cast(json_parse(unbound_string) as row(bigint, varchar))", "cast(json_parse(unbound_string) as row(bigint, varchar))"); + } + + @Test + public void testNonDeterministicFunctionCall() + { + // optimize should do nothing + assertOptimizedEquals("random()", "random()"); + + // evaluate should execute + Object 
value = evaluate("random()", false); + assertTrue(value instanceof Double); + double randomValue = (double) value; + assertTrue(0 <= randomValue && randomValue < 1); + } + + @Test + public void testCppFunctionCall() + { + METADATA.getFunctionAndTypeManager().createFunction(SQUARE_UDF_CPP, false); + assertOptimizedEquals("json.test_schema.square(-5)", "json.test_schema.square(-5)"); + } + + @Test + public void testCppAggregateFunctionCall() + { + METADATA.getFunctionAndTypeManager().createFunction(AVG_UDAF_CPP, false); + assertOptimizedEquals("json.test_schema.avg(1.0)", "json.test_schema.avg(1.0)"); + } + + @Test + public void testBetween() + { + assertOptimizedEquals("3 between 2 and 4", "true"); + assertOptimizedEquals("2 between 3 and 4", "false"); + assertOptimizedEquals("null between 2 and 4", "null"); + assertOptimizedEquals("3 between null and 4", "null"); + assertOptimizedEquals("3 between 2 and null", "null"); + + assertOptimizedEquals("'cc' between 'b' and 'd'", "true"); + assertOptimizedEquals("'b' between 'cc' and 'd'", "false"); + assertOptimizedEquals("null between 'b' and 'd'", "null"); + assertOptimizedEquals("'cc' between null and 'd'", "null"); + assertOptimizedEquals("'cc' between 'b' and null", "null"); + + assertOptimizedEquals("bound_integer between 1000 and 2000", "true"); + assertOptimizedEquals("bound_integer between 3 and 4", "false"); + assertOptimizedEquals("bound_long between 1000 and 2000", "true"); + assertOptimizedEquals("bound_long between 3 and 4", "false"); + assertOptimizedEquals("bound_long between bound_integer and (bound_long + 1)", "true"); + assertOptimizedEquals("bound_string between 'e' and 'i'", "true"); + assertOptimizedEquals("bound_string between 'a' and 'b'", "false"); + + assertOptimizedEquals("bound_long between unbound_long and 2000 + 1", "1234 between unbound_long and 2001"); + assertOptimizedEquals( + "bound_string between unbound_string and 'bar'", + format("CAST('hello' AS VARCHAR(%s)) between unbound_string 
and 'bar'", TEST_VARCHAR_TYPE_LENGTH)); + + assertOptimizedEquals("1.15 between 1.1 and 1.2", "true"); + assertOptimizedEquals("9876543210.98745612035 between 9876543210.9874561203 and 9876543210.9874561204", "true"); + assertOptimizedEquals("123.455 between bound_decimal_short and 123.46", "true"); + assertOptimizedEquals("12345678901234567890.1235 between bound_decimal_long and 12345678901234567890.123", "false"); + } + + @Test + public void testExtract() + { + DateTime dateTime = new DateTime(2001, 8, 22, 3, 4, 5, 321, getDateTimeZone(TEST_SESSION.getTimeZoneKey())); + double seconds = dateTime.getMillis() / 1000.0; + + assertOptimizedEquals("extract (YEAR from from_unixtime(" + seconds + "))", "2001"); + assertOptimizedEquals("extract (QUARTER from from_unixtime(" + seconds + "))", "3"); + assertOptimizedEquals("extract (MONTH from from_unixtime(" + seconds + "))", "8"); + assertOptimizedEquals("extract (WEEK from from_unixtime(" + seconds + "))", "34"); + assertOptimizedEquals("extract (DOW from from_unixtime(" + seconds + "))", "3"); + assertOptimizedEquals("extract (DOY from from_unixtime(" + seconds + "))", "234"); + assertOptimizedEquals("extract (DAY from from_unixtime(" + seconds + "))", "22"); + assertOptimizedEquals("extract (HOUR from from_unixtime(" + seconds + "))", "3"); + assertOptimizedEquals("extract (MINUTE from from_unixtime(" + seconds + "))", "4"); + assertOptimizedEquals("extract (SECOND from from_unixtime(" + seconds + "))", "5"); + assertOptimizedEquals("extract (TIMEZONE_HOUR from from_unixtime(" + seconds + ", 7, 9))", "7"); + assertOptimizedEquals("extract (TIMEZONE_MINUTE from from_unixtime(" + seconds + ", 7, 9))", "9"); + + assertOptimizedEquals("extract (YEAR from bound_timestamp)", "2001"); + assertOptimizedEquals("extract (QUARTER from bound_timestamp)", "3"); + assertOptimizedEquals("extract (MONTH from bound_timestamp)", "8"); + assertOptimizedEquals("extract (WEEK from bound_timestamp)", "34"); + assertOptimizedEquals("extract 
(DOW from bound_timestamp)", "2"); + assertOptimizedEquals("extract (DOY from bound_timestamp)", "233"); + assertOptimizedEquals("extract (DAY from bound_timestamp)", "21"); + assertOptimizedEquals("extract (HOUR from bound_timestamp)", "16"); + assertOptimizedEquals("extract (MINUTE from bound_timestamp)", "4"); + assertOptimizedEquals("extract (SECOND from bound_timestamp)", "5"); + // todo reenable when cast as timestamp with time zone is implemented + // todo add bound timestamp with time zone + //assertOptimizedEquals("extract (TIMEZONE_HOUR from bound_timestamp)", "0"); + //assertOptimizedEquals("extract (TIMEZONE_MINUTE from bound_timestamp)", "0"); + + assertOptimizedEquals("extract (YEAR from unbound_timestamp)", "extract (YEAR from unbound_timestamp)"); + assertOptimizedEquals("extract (SECOND from bound_timestamp + INTERVAL '3' SECOND)", "8"); + } + + @Test + public void testIn() + { + assertOptimizedEquals("3 in (2, 4, 3, 5)", "true"); + assertOptimizedEquals("3 in (2, 4, 9, 5)", "false"); + assertOptimizedEquals("3 in (2, null, 3, 5)", "true"); + + assertOptimizedEquals("'foo' in ('bar', 'baz', 'foo', 'blah')", "true"); + assertOptimizedEquals("'foo' in ('bar', 'baz', 'buz', 'blah')", "false"); + assertOptimizedEquals("'foo' in ('bar', null, 'foo', 'blah')", "true"); + + assertOptimizedEquals("null in (2, null, 3, 5)", "null"); + assertOptimizedEquals("3 in (2, null)", "null"); + + assertOptimizedEquals("bound_integer in (2, 1234, 3, 5)", "true"); + assertOptimizedEquals("bound_integer in (2, 4, 3, 5)", "false"); + assertOptimizedEquals("1234 in (2, bound_integer, 3, 5)", "true"); + assertOptimizedEquals("99 in (2, bound_integer, 3, 5)", "false"); + assertOptimizedEquals("bound_integer in (2, bound_integer, 3, 5)", "true"); + + assertOptimizedEquals("bound_long in (2, 1234, 3, 5)", "true"); + assertOptimizedEquals("bound_long in (2, 4, 3, 5)", "false"); + assertOptimizedEquals("1234 in (2, bound_long, 3, 5)", "true"); + assertOptimizedEquals("99 in (2, 
bound_long, 3, 5)", "false"); + assertOptimizedEquals("bound_long in (2, bound_long, 3, 5)", "true"); + + assertOptimizedEquals("bound_string in ('bar', 'hello', 'foo', 'blah')", "true"); + assertOptimizedEquals("bound_string in ('bar', 'baz', 'foo', 'blah')", "false"); + assertOptimizedEquals("'hello' in ('bar', bound_string, 'foo', 'blah')", "true"); + assertOptimizedEquals("'baz' in ('bar', bound_string, 'foo', 'blah')", "false"); + + assertOptimizedEquals("bound_long in (2, 1234, unbound_long, 5)", "true"); + assertOptimizedEquals("bound_string in ('bar', 'hello', unbound_string, 'blah')", "true"); + + assertOptimizedEquals("bound_long in (2, 4, unbound_long, unbound_long2, 9)", "1234 in (unbound_long, unbound_long2)"); + assertOptimizedEquals("unbound_long in (2, 4, bound_long, unbound_long2, 5)", "unbound_long in (2, 4, 1234, unbound_long2, 5)"); + + assertOptimizedEquals("1.15 in (1.1, 1.2, 1.3, 1.15)", "true"); + assertOptimizedEquals("9876543210.98745612035 in (9876543210.9874561203, 9876543210.9874561204, 9876543210.98745612035)", "true"); + assertOptimizedEquals("bound_decimal_short in (123.455, 123.46, 123.45)", "true"); + assertOptimizedEquals("bound_decimal_long in (12345678901234567890.123, 9876543210.9874561204, 9876543210.98745612035)", "true"); + assertOptimizedEquals("bound_decimal_long in (9876543210.9874561204, null, 9876543210.98745612035)", "null"); + } + + @Test + public void testInComplexTypes() + { + assertEvaluatedEquals("ARRAY[null] IN (ARRAY[null])", "null"); + assertEvaluatedEquals("ARRAY[1] IN (ARRAY[null])", "null"); + assertEvaluatedEquals("ARRAY[null] IN (ARRAY[1])", "null"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null])", "null"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[2, null])", "false"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null])", "null"); + assertEvaluatedEquals("ARRAY[1, null] IN (ARRAY[1, null], ARRAY[2, null], ARRAY[1, null])", "null"); + 
assertEvaluatedEquals("ARRAY[ARRAY[1, 2], ARRAY[3, 4]] in (ARRAY[ARRAY[1, 2], ARRAY[3, NULL]])", "null"); + + assertEvaluatedEquals("ROW(1) IN (ROW(1))", "true"); + assertEvaluatedEquals("ROW(1) IN (ROW(2))", "false"); + assertEvaluatedEquals("ROW(1) IN (ROW(2), ROW(1), ROW(2))", "true"); + assertEvaluatedEquals("ROW(1) IN (null)", "null"); + assertEvaluatedEquals("ROW(1) IN (null, ROW(1))", "true"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null), null)", "null"); + assertEvaluatedEquals("ROW(null) IN (ROW(null))", "null"); + assertEvaluatedEquals("ROW(1) IN (ROW(null))", "null"); + assertEvaluatedEquals("ROW(null) IN (ROW(1))", "null"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null))", "null"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(2, null))", "false"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null))", "null"); + assertEvaluatedEquals("ROW(1, null) IN (ROW(1, null), ROW(2, null), ROW(1, null))", "null"); + + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[1]))", "true"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null)", "null"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (null, MAP(ARRAY[1], ARRAY[1]))", "true"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]), null)", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[1]) IN (MAP(ARRAY[1], ARRAY[null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1], ARRAY[null]) IN (MAP(ARRAY[1], ARRAY[1]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN 
(MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 3], ARRAY[1, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[2, null]))", "false"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]))", "null"); + assertEvaluatedEquals("MAP(ARRAY[1, 2], ARRAY[1, null]) IN (MAP(ARRAY[1, 2], ARRAY[1, null]), MAP(ARRAY[1, 2], ARRAY[2, null]), MAP(ARRAY[1, 2], ARRAY[1, null]))", "null"); + } + + @Test + public void testCurrentTimestamp() + { + double current = TEST_SESSION.getStartTime() / 1000.0; + assertOptimizedEquals("current_timestamp = from_unixtime(" + current + ")", "true"); + double future = current + TimeUnit.MINUTES.toSeconds(1); + assertOptimizedEquals("current_timestamp > from_unixtime(" + future + ")", "false"); + } + + @Test + public void testCurrentUser() + throws Exception + { + assertOptimizedEquals("current_user", "'" + TEST_SESSION.getUser() + "'"); + } + + @Test + public void testCastToString() + { + // integer + assertOptimizedEquals("cast(123 as VARCHAR(20))", "'123'"); + assertOptimizedEquals("cast(-123 as VARCHAR(20))", "'-123'"); + + // bigint + assertOptimizedEquals("cast(BIGINT '123' as VARCHAR)", "'123'"); + assertOptimizedEquals("cast(12300000000 as VARCHAR)", "'12300000000'"); + assertOptimizedEquals("cast(-12300000000 as VARCHAR)", "'-12300000000'"); + + // double + assertOptimizedEquals("cast(123.0E0 as VARCHAR)", "'123.0'"); + assertOptimizedEquals("cast(-123.0E0 as VARCHAR)", "'-123.0'"); + assertOptimizedEquals("cast(123.456E0 as VARCHAR)", "'123.456'"); + assertOptimizedEquals("cast(-123.456E0 as VARCHAR)", "'-123.456'"); + + // boolean + assertOptimizedEquals("cast(true as VARCHAR)", "'true'"); + assertOptimizedEquals("cast(false as VARCHAR)", "'false'"); + + // string + assertOptimizedEquals("cast('xyz' as VARCHAR)", "'xyz'"); + 
assertOptimizedEquals("cast(cast('abcxyz' as VARCHAR(3)) as VARCHAR(5))", "'abc'"); + + // null + assertOptimizedEquals("cast(null as VARCHAR)", "null"); + + // decimal + assertOptimizedEquals("cast(1.1 as VARCHAR)", "'1.1'"); + // TODO enabled when DECIMAL is default for literal: assertOptimizedEquals("cast(12345678901234567890.123 as VARCHAR)", "'12345678901234567890.123'"); + } + + /** + * Casting a BIGINT to a bounded VARCHAR succeeds when the textual form fits the bound and + * must raise INVALID_CAST_ARGUMENT (with the exact message asserted below) when it does not, + * for both positive and negative values. + */ + @Test + public void testCastBigintToBoundedVarchar() + { + assertEvaluatedEquals("CAST(12300000000 AS varchar(11))", "'12300000000'"); + assertEvaluatedEquals("CAST(12300000000 AS varchar(50))", "'12300000000'"); + + try { + evaluate("CAST(12300000000 AS varchar(3))", true); + fail("Expected to throw an INVALID_CAST_ARGUMENT exception"); + } + catch (PrestoException e) { + try { + assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); + assertEquals(e.getMessage(), "Value 12300000000 cannot be represented as varchar(3)"); + } + catch (Throwable failure) { + // keep the original cast exception attached to the assertion failure for diagnosis + failure.addSuppressed(e); + throw failure; + } + } + + try { + evaluate("CAST(-12300000000 AS varchar(3))", true); + // without this guard the negative case would pass silently if the cast stopped throwing + fail("Expected to throw an INVALID_CAST_ARGUMENT exception"); + } + catch (PrestoException e) { + try { + assertEquals(e.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode()); + assertEquals(e.getMessage(), "Value -12300000000 cannot be represented as varchar(3)"); + } + catch (Throwable failure) { + // keep the original cast exception attached to the assertion failure for diagnosis + failure.addSuppressed(e); + throw failure; + } + } + } + + @Test + public void testCastToBoolean() + { + // integer + assertOptimizedEquals("cast(123 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(-123 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(0 as BOOLEAN)", "false"); + + // bigint + assertOptimizedEquals("cast(12300000000 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(-12300000000 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(BIGINT '0' as BOOLEAN)", "false"); + + // boolean + assertOptimizedEquals("cast(true as BOOLEAN)", "true"); + assertOptimizedEquals("cast(false as BOOLEAN)", "false"); + + // string +
assertOptimizedEquals("cast('true' as BOOLEAN)", "true"); + assertOptimizedEquals("cast('false' as BOOLEAN)", "false"); + assertOptimizedEquals("cast('t' as BOOLEAN)", "true"); + assertOptimizedEquals("cast('f' as BOOLEAN)", "false"); + assertOptimizedEquals("cast('1' as BOOLEAN)", "true"); + assertOptimizedEquals("cast('0' as BOOLEAN)", "false"); + + // null + assertOptimizedEquals("cast(null as BOOLEAN)", "null"); + + // double + assertOptimizedEquals("cast(123.45E0 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(-123.45E0 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(0.0E0 as BOOLEAN)", "false"); + + // decimal + assertOptimizedEquals("cast(0.00 as BOOLEAN)", "false"); + assertOptimizedEquals("cast(7.8 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(12345678901234567890.123 as BOOLEAN)", "true"); + assertOptimizedEquals("cast(00000000000000000000.000 as BOOLEAN)", "false"); + } + + @Test + public void testCastToBigint() + { + // integer + assertOptimizedEquals("cast(0 as BIGINT)", "0"); + assertOptimizedEquals("cast(123 as BIGINT)", "123"); + assertOptimizedEquals("cast(-123 as BIGINT)", "-123"); + + // bigint + assertOptimizedEquals("cast(BIGINT '0' as BIGINT)", "0"); + assertOptimizedEquals("cast(BIGINT '123' as BIGINT)", "123"); + assertOptimizedEquals("cast(BIGINT '-123' as BIGINT)", "-123"); + + // double + assertOptimizedEquals("cast(123.0E0 as BIGINT)", "123"); + assertOptimizedEquals("cast(-123.0E0 as BIGINT)", "-123"); + assertOptimizedEquals("cast(123.456E0 as BIGINT)", "123"); + assertOptimizedEquals("cast(-123.456E0 as BIGINT)", "-123"); + + // boolean + assertOptimizedEquals("cast(true as BIGINT)", "1"); + assertOptimizedEquals("cast(false as BIGINT)", "0"); + + // string + assertOptimizedEquals("cast('123' as BIGINT)", "123"); + assertOptimizedEquals("cast('-123' as BIGINT)", "-123"); + + // null + assertOptimizedEquals("cast(null as BIGINT)", "null"); + + // decimal + assertOptimizedEquals("cast(DECIMAL '1.01' as BIGINT)", "1"); + 
assertOptimizedEquals("cast(DECIMAL '7.8' as BIGINT)", "8"); + assertOptimizedEquals("cast(DECIMAL '1234567890.123' as BIGINT)", "1234567890"); + assertOptimizedEquals("cast(DECIMAL '00000000000000000000.000' as BIGINT)", "0"); + } + + @Test + public void testCastToInteger() + { + // integer + assertOptimizedEquals("cast(0 as INTEGER)", "0"); + assertOptimizedEquals("cast(123 as INTEGER)", "123"); + assertOptimizedEquals("cast(-123 as INTEGER)", "-123"); + + // bigint + assertOptimizedEquals("cast(BIGINT '0' as INTEGER)", "0"); + assertOptimizedEquals("cast(BIGINT '123' as INTEGER)", "123"); + assertOptimizedEquals("cast(BIGINT '-123' as INTEGER)", "-123"); + + // double + assertOptimizedEquals("cast(123.0E0 as INTEGER)", "123"); + assertOptimizedEquals("cast(-123.0E0 as INTEGER)", "-123"); + assertOptimizedEquals("cast(123.456E0 as INTEGER)", "123"); + assertOptimizedEquals("cast(-123.456E0 as INTEGER)", "-123"); + + // boolean + assertOptimizedEquals("cast(true as INTEGER)", "1"); + assertOptimizedEquals("cast(false as INTEGER)", "0"); + + // string + assertOptimizedEquals("cast('123' as INTEGER)", "123"); + assertOptimizedEquals("cast('-123' as INTEGER)", "-123"); + + // null + assertOptimizedEquals("cast(null as INTEGER)", "null"); + } + + @Test + public void testCastToDouble() + { + // integer + assertOptimizedEquals("cast(0 as DOUBLE)", "0.0E0"); + assertOptimizedEquals("cast(123 as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast(-123 as DOUBLE)", "-123.0E0"); + + // bigint + assertOptimizedEquals("cast(BIGINT '0' as DOUBLE)", "0.0E0"); + assertOptimizedEquals("cast(12300000000 as DOUBLE)", "12300000000.0E0"); + assertOptimizedEquals("cast(-12300000000 as DOUBLE)", "-12300000000.0E0"); + + // double + assertOptimizedEquals("cast(123.0E0 as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast(-123.0E0 as DOUBLE)", "-123.0E0"); + assertOptimizedEquals("cast(123.456E0 as DOUBLE)", "123.456E0"); + assertOptimizedEquals("cast(-123.456E0 as DOUBLE)", 
"-123.456E0"); + + // string + assertOptimizedEquals("cast('0' as DOUBLE)", "0.0E0"); + assertOptimizedEquals("cast('123' as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast('-123' as DOUBLE)", "-123.0E0"); + assertOptimizedEquals("cast('123.0E0' as DOUBLE)", "123.0E0"); + assertOptimizedEquals("cast('-123.0E0' as DOUBLE)", "-123.0E0"); + assertOptimizedEquals("cast('123.456E0' as DOUBLE)", "123.456E0"); + assertOptimizedEquals("cast('-123.456E0' as DOUBLE)", "-123.456E0"); + + // null + assertOptimizedEquals("cast(null as DOUBLE)", "null"); + + // boolean + assertOptimizedEquals("cast(true as DOUBLE)", "1.0E0"); + assertOptimizedEquals("cast(false as DOUBLE)", "0.0E0"); + + // decimal + assertOptimizedEquals("cast(1.01 as DOUBLE)", "DOUBLE '1.01'"); + assertOptimizedEquals("cast(7.8 as DOUBLE)", "DOUBLE '7.8'"); + assertOptimizedEquals("cast(1234567890.123 as DOUBLE)", "DOUBLE '1234567890.123'"); + assertOptimizedEquals("cast(00000000000000000000.000 as DOUBLE)", "DOUBLE '0.0'"); + } + + @Test + public void testCastToDecimal() + { + // long + assertOptimizedEquals("cast(0 as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast(123 as DECIMAL(3,0))", "DECIMAL '123'"); + assertOptimizedEquals("cast(-123 as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast(-123 as DECIMAL(20,10))", "cast(-123 as DECIMAL(20,10))"); + + // double + assertOptimizedEquals("cast(0E0 as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast(123.2E0 as DECIMAL(4,1))", "DECIMAL '123.2'"); + assertOptimizedEquals("cast(-123.0E0 as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast(-123.55E0 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); + + // string + assertOptimizedEquals("cast('0' as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast('123.2' as DECIMAL(4,1))", "DECIMAL '123.2'"); + assertOptimizedEquals("cast('-123.0' as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast('-123.55' as DECIMAL(20,10))", 
"cast(-123.55 as DECIMAL(20,10))"); + + // null + assertOptimizedEquals("cast(null as DECIMAL(1,0))", "null"); + assertOptimizedEquals("cast(null as DECIMAL(20,10))", "null"); + + // boolean + assertOptimizedEquals("cast(true as DECIMAL(1,0))", "DECIMAL '1'"); + assertOptimizedEquals("cast(false as DECIMAL(4,1))", "DECIMAL '000.0'"); + assertOptimizedEquals("cast(true as DECIMAL(3,0))", "DECIMAL '001'"); + assertOptimizedEquals("cast(false as DECIMAL(20,10))", "cast(0 as DECIMAL(20,10))"); + + // decimal + assertOptimizedEquals("cast(0.0 as DECIMAL(1,0))", "DECIMAL '0'"); + assertOptimizedEquals("cast(123.2 as DECIMAL(4,1))", "DECIMAL '123.2'"); + assertOptimizedEquals("cast(-123.0 as DECIMAL(3,0))", "DECIMAL '-123'"); + assertOptimizedEquals("cast(-123.55 as DECIMAL(20,10))", "cast(-123.55 as DECIMAL(20,10))"); + } + + @Test + public void testCastOptimization() + { + assertOptimizedEquals("cast(unbound_string as VARCHAR)", "cast(unbound_string as VARCHAR)"); + assertOptimizedMatches("cast(unbound_string as VARCHAR)", "unbound_string"); + assertOptimizedMatches("cast(unbound_integer as INTEGER)", "unbound_integer"); + assertOptimizedMatches("cast(unbound_string as VARCHAR(10))", "cast(unbound_string as VARCHAR(10))"); + } + + @Test + public void testTryCast() + { + assertOptimizedEquals("try_cast(null as BIGINT)", "null"); + assertOptimizedEquals("try_cast(123 as BIGINT)", "123"); + assertOptimizedEquals("try_cast(null as INTEGER)", "null"); + assertOptimizedEquals("try_cast(123 as INTEGER)", "123"); + assertOptimizedEquals("try_cast('foo' as VARCHAR)", "'foo'"); + assertOptimizedEquals("try_cast('foo' as BIGINT)", "null"); + assertOptimizedEquals("try_cast(unbound_string as BIGINT)", "try_cast(unbound_string as BIGINT)"); + assertOptimizedEquals("try_cast('foo' as DECIMAL(2,1))", "null"); + } + + @Test + public void testReservedWithDoubleQuotes() + { + assertOptimizedEquals("\"time\"", "\"time\""); + } + + @Test + public void 
testEnumLiteralFormattingWithTypeAndValue() + { + java.util.function.BiFunction createEnumLiteral = (type, value) -> new EnumLiteral(Optional.empty(), type, value); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("color", "RED"), Optional.empty()), "color: RED"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("level", 1), Optional.empty()), "level: 1"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("StatusType", "Active"), Optional.empty()), "StatusType: Active"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("priority", "HIGH PRIORITY"), Optional.empty()), "priority: HIGH PRIORITY"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("lang", "枚举"), Optional.empty()), "lang: 枚举"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("special", "DOLLAR$"), Optional.empty()), "special: DOLLAR$"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("enum_type", "VALUE_1"), Optional.empty()), "enum_type: VALUE_1"); + assertEquals(ExpressionFormatter.formatExpression(createEnumLiteral.apply("flag", true), Optional.empty()), "flag: true"); + } + + @Test + public void testSearchCase() + { + assertOptimizedEquals("case " + + "when true then 33 " + + "end", + "33"); + assertOptimizedEquals("case " + + "when false then 1 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case " + + "when false then 10000000000 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case " + + "when bound_long = 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_long " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then 1 " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when bound_integer = 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then 
bound_integer " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then 1 " + + "else bound_integer " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when bound_long = 1234 then 33 " + + "else unbound_long " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_long " + + "else unbound_long " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then unbound_long " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when bound_integer = 1234 then 33 " + + "else unbound_integer " + + "end", + "33"); + assertOptimizedEquals("case " + + "when true then bound_integer " + + "else unbound_integer " + + "end", + "1234"); + assertOptimizedEquals("case " + + "when false then unbound_integer " + + "else bound_integer " + + "end", + "1234"); + + assertOptimizedEquals("case " + + "when unbound_long = 1234 then 33 " + + "else 1 " + + "end", + "" + + "case " + + "when unbound_long = 1234 then 33 " + + "else 1 " + + "end"); + + assertOptimizedEquals("case " + + "when false then 2.2 " + + "when true then 2.2 " + + "end", + "2.2"); + + assertOptimizedEquals("case " + + "when false then 1234567890.0987654321 " + + "when true then 3.3 " + + "end", + "CAST(3.3 AS DECIMAL(20,10))"); + + assertOptimizedEquals("case " + + "when false then 1 " + + "when true then 2.2 " + + "end", + "2.2"); + + assertOptimizedEquals("case when ARRAY[CAST(1 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'"); + assertOptimizedEquals("case when ARRAY[CAST(2 AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + assertOptimizedEquals("case when ARRAY[CAST(null AS BIGINT)] = ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + } + + @Test + public void testSimpleCase() + { + assertOptimizedEquals("case 1 " + + "when 1 then 32 + 1 " + + "when 1 then 34 " + + "end", + "33"); + + 
assertOptimizedEquals("case null " + + "when true then 33 " + + "end", + "null"); + assertOptimizedEquals("case null " + + "when true then 33 " + + "else 33 " + + "end", + "33"); + assertOptimizedEquals("case 33 " + + "when null then 1 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case null " + + "when true then 3300000000 " + + "end", + "null"); + assertOptimizedEquals("case null " + + "when true then 3300000000 " + + "else 3300000000 " + + "end", + "3300000000"); + assertOptimizedEquals("case 33 " + + "when null then 3300000000 " + + "else 33 " + + "end", + "33"); + + assertOptimizedEquals("case true " + + "when true then 33 " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when false then 1 " + + "else 33 end", + "33"); + + assertOptimizedEquals("case bound_long " + + "when 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case 1234 " + + "when bound_long then 33 " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when true then bound_long " + + "end", + "1234"); + assertOptimizedEquals("case true " + + "when false then 1 " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case bound_integer " + + "when 1234 then 33 " + + "end", + "33"); + assertOptimizedEquals("case 1234 " + + "when bound_integer then 33 " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when true then bound_integer " + + "end", + "1234"); + assertOptimizedEquals("case true " + + "when false then 1 " + + "else bound_integer " + + "end", + "1234"); + + assertOptimizedEquals("case bound_long " + + "when 1234 then 33 " + + "else unbound_long " + + "end", + "33"); + assertOptimizedEquals("case true " + + "when true then bound_long " + + "else unbound_long " + + "end", + "1234"); + assertOptimizedEquals("case true " + + "when false then unbound_long " + + "else bound_long " + + "end", + "1234"); + + assertOptimizedEquals("case unbound_long " + + "when 1234 then 33 " + + "else 1 " + + "end", + "" + + "case 
unbound_long " + + "when 1234 then 33 " + + "else 1 " + + "end"); + + assertOptimizedEquals("case 33 " + + "when 0 then 0 " + + "when 33 then unbound_long " + + "else 1 " + + "end", + "unbound_long"); + assertOptimizedEquals("case 33 " + + "when 0 then 0 " + + "when 33 then 1 " + + "when unbound_long then 2 " + + "else 1 " + + "end", + "1"); + assertOptimizedEquals("case 33 " + + "when unbound_long then 0 " + + "when 1 then 1 " + + "when 33 then 2 " + + "else 0 " + + "end", + "case 33 " + + "when unbound_long then 0 " + + "else 2 " + + "end"); + assertOptimizedEquals("case 33 " + + "when 0 then 0 " + + "when 1 then 1 " + + "else unbound_long " + + "end", + "unbound_long"); + assertOptimizedEquals("case 33 " + + "when unbound_long then 0 " + + "when 1 then 1 " + + "when unbound_long2 then 2 " + + "else 3 " + + "end", + "case 33 " + + "when unbound_long then 0 " + + "when unbound_long2 then 2 " + + "else 3 " + + "end"); + + assertOptimizedEquals("case true " + + "when unbound_long = 1 then 1 " + + "when 0 / 0 = 0 then 2 " + + "else 33 end", + "" + + "case true " + + "when unbound_long = 1 then 1 " + + "when 0 / 0 = 0 then 2 else 33 " + + "end"); + + assertOptimizedEquals("case bound_long " + + "when 123 * 10 + unbound_long then 1 = 1 " + + "else 1 = 2 " + + "end", + "" + + "case bound_long when 1230 + unbound_long then true " + + "else false " + + "end"); + + assertOptimizedEquals("case bound_long " + + "when unbound_long then 2 + 2 " + + "end", + "" + + "case bound_long " + + "when unbound_long then 4 " + + "end"); + + assertOptimizedEquals("case bound_long " + + "when unbound_long then 2 + 2 " + + "when 1 then null " + + "when 2 then null " + + "end", + "" + + "case bound_long " + + "when unbound_long then 4 " + + "end"); + + assertOptimizedEquals("case true " + + "when false then 2.2 " + + "when true then 2.2 " + + "end", + "2.2"); + + // TODO enabled when DECIMAL is default for literal: +// assertOptimizedEquals("case true " + +// "when false then 
1234567890.0987654321 " + +// "when true then 3.3 " + +// "end", +// "CAST(3.3 AS DECIMAL(20,10))"); + + assertOptimizedEquals("case true " + + "when false then 1 " + + "when true then 2.2 " + + "end", + "2.2"); + + assertOptimizedEquals("case ARRAY[CAST(1 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'matched'"); + assertOptimizedEquals("case ARRAY[CAST(2 AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + assertOptimizedEquals("case ARRAY[CAST(null AS BIGINT)] when ARRAY[CAST(1 AS BIGINT)] then 'matched' else 'not_matched' end", "'not_matched'"); + } + + @Test + public void testCoalesce() + { + assertOptimizedEquals("coalesce(null, null)", "coalesce(null, null)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1 - 1, null)", "coalesce(6 * unbound_long, 0)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_long, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_long, 0.5E0)"); + assertOptimizedEquals("coalesce(unbound_long, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_long, 2.0E0, 0.5E0, 12.34E0)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1 - 1, null)", "coalesce(6 * unbound_integer, 0)"); + assertOptimizedEquals("coalesce(2 * 3 * unbound_integer, 1.0E0/2.0E0, null)", "coalesce(6 * unbound_integer, 0.5E0)"); + assertOptimizedEquals("coalesce(unbound_integer, 2, 1.0E0/2.0E0, 12.34E0, null)", "coalesce(unbound_integer, 2.0E0, 0.5E0, 12.34E0)"); + assertOptimizedMatches("coalesce(unbound_long, unbound_long)", "unbound_long"); + assertOptimizedMatches("coalesce(2 * unbound_long, 2 * unbound_long)", "BIGINT '2' * unbound_long"); + assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long)", "coalesce(unbound_long, unbound_long2)"); + assertOptimizedMatches("coalesce(unbound_long, unbound_long2, unbound_long, unbound_long3)", "coalesce(unbound_long, unbound_long2, unbound_long3)"); + assertOptimizedEquals("coalesce(6, unbound_long2, 
unbound_long, unbound_long3)", "6"); + assertOptimizedEquals("coalesce(2 * 3, unbound_long2, unbound_long, unbound_long3)", "6"); + assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())"); + assertOptimizedMatches("coalesce(random(), random(), 5)", "coalesce(random(), random(), 5E0)"); + assertOptimizedMatches("coalesce(unbound_long, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long, 1)), unbound_long2)", "coalesce(unbound_long, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_long, 2, coalesce(unbound_long, 1))", "coalesce(unbound_long, BIGINT '2')"); + assertOptimizedMatches("coalesce(coalesce(unbound_long, coalesce(unbound_long2, unbound_long3)), 1)", "coalesce(unbound_long, unbound_long2, unbound_long3, BIGINT '1')"); + assertOptimizedMatches("coalesce(unbound_double, coalesce(random(), unbound_double))", "coalesce(unbound_double, random())"); + } + + @Test + public void testIf() + { + assertOptimizedEquals("IF(2 = 2, 3, 4)", "3"); + assertOptimizedEquals("IF(1 = 2, 3, 4)", "4"); + assertOptimizedEquals("IF(1 = 2, BIGINT '3', 4)", "4"); + assertOptimizedEquals("IF(1 = 2, 3000000000, 4)", "4"); + + assertOptimizedEquals("IF(true, 3, 4)", "3"); + assertOptimizedEquals("IF(false, 3, 4)", "4"); + assertOptimizedEquals("IF(null, 3, 4)", "4"); + + 
assertOptimizedEquals("IF(true, 3, null)", "3"); + assertOptimizedEquals("IF(false, 3, null)", "null"); + assertOptimizedEquals("IF(true, null, 4)", "null"); + assertOptimizedEquals("IF(false, null, 4)", "4"); + assertOptimizedEquals("IF(true, null, null)", "null"); + assertOptimizedEquals("IF(false, null, null)", "null"); + + assertOptimizedEquals("IF(true, 3.5E0, 4.2E0)", "3.5E0"); + assertOptimizedEquals("IF(false, 3.5E0, 4.2E0)", "4.2E0"); + + assertOptimizedEquals("IF(true, 'foo', 'bar')", "'foo'"); + assertOptimizedEquals("IF(false, 'foo', 'bar')", "'bar'"); + + assertOptimizedEquals("IF(true, 1.01, 1.02)", "1.01"); + assertOptimizedEquals("IF(false, 1.01, 1.02)", "1.02"); + assertOptimizedEquals("IF(true, 1234567890.123, 1.02)", "1234567890.123"); + assertOptimizedEquals("IF(false, 1.01, 1234567890.123)", "1234567890.123"); + } + + @Test + public void testLike() + { + assertOptimizedEquals("'a' LIKE 'a'", "true"); + assertOptimizedEquals("'' LIKE 'a'", "false"); + assertOptimizedEquals("'abc' LIKE 'a'", "false"); + + assertOptimizedEquals("'a' LIKE '_'", "true"); + assertOptimizedEquals("'' LIKE '_'", "false"); + assertOptimizedEquals("'abc' LIKE '_'", "false"); + + assertOptimizedEquals("'a' LIKE '%'", "true"); + assertOptimizedEquals("'' LIKE '%'", "true"); + assertOptimizedEquals("'abc' LIKE '%'", "true"); + + assertOptimizedEquals("'abc' LIKE '___'", "true"); + assertOptimizedEquals("'ab' LIKE '___'", "false"); + assertOptimizedEquals("'abcd' LIKE '___'", "false"); + + assertOptimizedEquals("'abc' LIKE 'abc'", "true"); + assertOptimizedEquals("'xyz' LIKE 'abc'", "false"); + assertOptimizedEquals("'abc0' LIKE 'abc'", "false"); + assertOptimizedEquals("'0abc' LIKE 'abc'", "false"); + + assertOptimizedEquals("'abc' LIKE 'abc%'", "true"); + assertOptimizedEquals("'abc0' LIKE 'abc%'", "true"); + assertOptimizedEquals("'0abc' LIKE 'abc%'", "false"); + + assertOptimizedEquals("'abc' LIKE '%abc'", "true"); + assertOptimizedEquals("'0abc' LIKE '%abc'", "true"); + 
assertOptimizedEquals("'abc0' LIKE '%abc'", "false"); + + assertOptimizedEquals("'abc' LIKE '%abc%'", "true"); + assertOptimizedEquals("'0abc' LIKE '%abc%'", "true"); + assertOptimizedEquals("'abc0' LIKE '%abc%'", "true"); + assertOptimizedEquals("'0abc0' LIKE '%abc%'", "true"); + assertOptimizedEquals("'xyzw' LIKE '%abc%'", "false"); + + assertOptimizedEquals("'abc' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0abc' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'abc0' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0abc0' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'ab01c' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0ab01c' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'ab01c0' LIKE '%ab%c%'", "true"); + assertOptimizedEquals("'0ab01c0' LIKE '%ab%c%'", "true"); + + assertOptimizedEquals("'xyzw' LIKE '%ab%c%'", "false"); + + // ensure regex chars are escaped + assertOptimizedEquals("'\' LIKE '\'", "true"); + assertOptimizedEquals("'.*' LIKE '.*'", "true"); + assertOptimizedEquals("'[' LIKE '['", "true"); + assertOptimizedEquals("']' LIKE ']'", "true"); + assertOptimizedEquals("'{' LIKE '{'", "true"); + assertOptimizedEquals("'}' LIKE '}'", "true"); + assertOptimizedEquals("'?' 
LIKE '?'", "true"); + assertOptimizedEquals("'+' LIKE '+'", "true"); + assertOptimizedEquals("'(' LIKE '('", "true"); + assertOptimizedEquals("')' LIKE ')'", "true"); + assertOptimizedEquals("'|' LIKE '|'", "true"); + assertOptimizedEquals("'^' LIKE '^'", "true"); + assertOptimizedEquals("'$' LIKE '$'", "true"); + + assertOptimizedEquals("'%' LIKE 'z%' ESCAPE 'z'", "true"); + } + + @Test + public void testLikeNullOptimization() + { + assertOptimizedEquals("null LIKE '%'", "null"); + assertOptimizedEquals("'a' LIKE null", "null"); + assertOptimizedEquals("'a' LIKE '%' ESCAPE null", "null"); + assertOptimizedEquals("'a' LIKE unbound_string ESCAPE null", "null"); + } + + @Test + public void testLikeOptimization() + { + assertOptimizedEquals("unbound_string LIKE 'abc'", "unbound_string = CAST('abc' AS VARCHAR)"); + + assertOptimizedEquals("unbound_string LIKE '' ESCAPE '#'", "unbound_string LIKE '' ESCAPE '#'"); + assertOptimizedEquals("unbound_string LIKE 'abc' ESCAPE '#'", "unbound_string = CAST('abc' AS VARCHAR)"); + assertOptimizedEquals("unbound_string LIKE 'a#_b' ESCAPE '#'", "unbound_string = CAST('a_b' AS VARCHAR)"); + assertOptimizedEquals("unbound_string LIKE 'a#%b' ESCAPE '#'", "unbound_string = CAST('a%b' AS VARCHAR)"); + assertOptimizedEquals("unbound_string LIKE 'a#_##b' ESCAPE '#'", "unbound_string = CAST('a_#b' AS VARCHAR)"); + assertOptimizedEquals("unbound_string LIKE 'a#__b' ESCAPE '#'", "unbound_string LIKE 'a#__b' ESCAPE '#'"); + assertOptimizedEquals("unbound_string LIKE 'a##%b' ESCAPE '#'", "unbound_string LIKE 'a##%b' ESCAPE '#'"); + + assertOptimizedEquals("bound_string LIKE bound_pattern", "true"); + assertOptimizedEquals("'abc' LIKE bound_pattern", "false"); + + assertOptimizedEquals("unbound_string LIKE bound_pattern", "unbound_string LIKE bound_pattern"); + assertDoNotOptimize("unbound_string LIKE 'abc%'", SERIALIZABLE); + + assertOptimizedEquals("unbound_string LIKE unbound_pattern ESCAPE unbound_string", "unbound_string LIKE 
unbound_pattern ESCAPE unbound_string"); + } + + @Test + public void testInvalidLike() + { + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE ''")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'abc' ESCAPE 'bc'")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#' ESCAPE '#'")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE '#abc' ESCAPE '#'")); + assertThrows(PrestoException.class, () -> optimize("unbound_string LIKE 'ab#' ESCAPE '#'")); + } + + @Test + public void testFailedExpressionOptimization() + { + assertFailedMatches("coalesce(0 / 0 > 1, unbound_boolean, 0 / 0 = 0)", + "coalesce(cast(fail(8, 'ignored failure message') as boolean), unbound_boolean)"); + + assertFailedMatches("if(false, 1, 0 / 0)", "cast(fail(8, 'ignored failure message') as integer)"); + + assertFailedMatches("CASE unbound_long WHEN 1 THEN 1 WHEN 0 / 0 THEN 2 END", + "CASE unbound_long WHEN BIGINT '1' THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 END"); + + assertFailedMatches("CASE unbound_boolean WHEN true THEN 1 ELSE 0 / 0 END", + "CASE unbound_boolean WHEN true THEN 1 ELSE cast(fail(8, 'ignored failure message') as integer) END"); + + assertFailedMatches("CASE bound_long WHEN unbound_long THEN 1 WHEN 0 / 0 THEN 2 ELSE 1 END", + "CASE BIGINT '1234' WHEN unbound_long THEN 1 WHEN cast(fail(8, 'ignored failure message') as bigint) THEN 2 ELSE 1 END"); + + assertFailedMatches("case when unbound_boolean then 1 when 0 / 0 = 0 then 2 end", + "case when unbound_boolean then 1 when cast(fail(8, 'ignored failure message') as boolean) then 2 end"); + + assertFailedMatches("case when unbound_boolean then 1 else 0 / 0 end", + "case when unbound_boolean then 1 else cast(fail(8, 'ignored failure message') as integer) end"); + + assertFailedMatches("case when unbound_boolean then 0 / 0 else 1 end", + "case when unbound_boolean then cast(fail(8, 'ignored 
failure message') as integer) else 1 end"); + + assertFailedMatches("case true " + + "when unbound_long = 1 then 1 " + + "when 0 / 0 = 0 then 2 " + + "else 33 end", + "case true " + + "when unbound_long = BIGINT '1' then 1 " + + "when CAST(fail(8, 'ignored failure message') AS boolean) then 2 else 33 " + + "end"); + + assertFailedMatches("case 1 " + + "when 0 / 0 then 1 " + + "when 0 / 0 then 2 " + + "else 1 " + + "end", + "case 1 " + + "when cast(fail(8, 'ignored failure message') as integer) then 1 " + + "when cast(fail(8, 'ignored failure message') as integer) then 2 " + + "else 1 " + + "end"); + + assertFailedMatches("case 1 " + + "when unbound_long then 1 " + + "when 0 / 0 then 2 " + + "else 1 " + + "end", + "" + + "case BIGINT '1' " + + "when unbound_long then 1 " + + "when cast(fail(8, 'ignored failure message') AS integer) then 2 " + + "else 1 " + + "end"); + } + + @Test + public void testArrayConstructor() + { + optimize("ARRAY []"); + assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]", + "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]", + "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]", + "array_constructor((bound_long + 0), (unbound_long + 1), NULL)"); + } + + @Test + public void testRowConstructor() + { + optimize("ROW(NULL)"); + optimize("ROW(1)"); + optimize("ROW(unbound_long + 0)"); + optimize("ROW(unbound_long + unbound_long2, unbound_string, unbound_double)"); + optimize("ROW(unbound_boolean, FALSE, ARRAY[unbound_long, unbound_long2], unbound_null_string, unbound_interval)"); + optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(unbound_string, 0.0E0)]"); + optimize("ARRAY [ROW('string', unbound_double), ROW('string', bound_double)]"); + optimize("ROW(ROW(NULL), 
ROW(ROW(ROW(ROW('rowception')))))"); + optimize("ROW(unbound_string, bound_string)"); + + optimize("ARRAY [ROW(unbound_string, unbound_double), ROW(CAST(bound_string AS VARCHAR), 0.0E0)]"); + optimize("ARRAY [ROW(CAST(bound_string AS VARCHAR), 0.0E0), ROW(unbound_string, unbound_double)]"); + + optimize("ARRAY [ROW(unbound_string, unbound_double), CAST(NULL AS ROW(VARCHAR, DOUBLE))]"); + optimize("ARRAY [CAST(NULL AS ROW(VARCHAR, DOUBLE)), ROW(unbound_string, unbound_double)]"); + } + + @Test + public void testDereference() + { + optimize("ARRAY []"); + assertOptimizedEquals("ARRAY [(unbound_long + 0), (unbound_long + 1), (unbound_long + 2)]", + "array_constructor((unbound_long + 0), (unbound_long + 1), (unbound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), (bound_long + 2)]", + "array_constructor((bound_long + 0), (unbound_long + 1), (bound_long + 2))"); + assertOptimizedEquals("ARRAY [(bound_long + 0), (unbound_long + 1), NULL]", + "array_constructor((bound_long + 0), (unbound_long + 1), NULL)"); + } + + @Test + public void testRowDereference() + { + optimize("CAST(null AS ROW(a VARCHAR, b BIGINT)).a"); + } + + @Test + public void testRowSubscript() + { + assertOptimizedEquals("ROW (1, 'a', true)[3]", "true"); + assertOptimizedEquals("ROW (1, 'a', ROW (2, 'b', ROW (3, 'c')))[3][3][2]", "'c'"); + } + + @Test + public void testOptimizeDivideByZero() + { + assertThrows(PrestoException.class, () -> optimize("0 / 0")); + } + + @Test + public void testArraySubscriptConstantNegativeIndex() + { + assertThrows(PrestoException.class, () -> optimize("ARRAY [1, 2, 3][-1]")); + } + + @Test + public void testArraySubscriptConstantZeroIndex() + { + assertThrows(PrestoException.class, () -> optimize("ARRAY [1, 2, 3][0]")); + } + + @Test + public void testMapSubscriptMissingKey() + { + assertThrows(PrestoException.class, () -> optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[-1]")); + } + + @Test + public void testMapSubscriptConstantIndexes() + { + 
optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[1]"); + optimize("MAP(ARRAY [BIGINT '1', 2], ARRAY [3, 4])[1]"); + optimize("MAP(ARRAY [1, 2], ARRAY [3, 4])[2]"); + optimize("MAP(ARRAY [ARRAY[1,1]], ARRAY['a'])[ARRAY[1,1]]"); + } + + @Test + public void testLiterals() + { + optimize("date '2013-04-03' + unbound_interval"); + optimize("time '03:04:05.321' + unbound_interval"); + optimize("time '03:04:05.321 UTC' + unbound_interval"); + optimize("timestamp '2013-04-03 03:04:05.321' + unbound_interval"); + optimize("timestamp '2013-04-03 03:04:05.321 UTC' + unbound_interval"); + + optimize("interval '3' day * unbound_long"); + optimize("interval '3' year * unbound_long"); + } + + @Test + public void testVarbinaryLiteral() + { + assertEquals(optimize("X'1234'"), Slices.wrappedBuffer((byte) 0x12, (byte) 0x34)); + } + + public void assertExpressionAndRowExpressionEquals(Object expressionResult, Object rowExpressionResult) + { + if (rowExpressionResult instanceof RowExpression) { + // Cannot be completely evaluated into a constant; compare expressions + assertTrue(expressionResult instanceof Expression); + + // It is tricky to check the equivalence of an expression and a row expression. + // We rely on the optimized translator to fill the gap. 
+ RowExpression translated = TRANSLATOR.translateAndOptimize((Expression) expressionResult, SYMBOL_TYPES); + assertRowExpressionEvaluationEquals(translated, rowExpressionResult); + } + else { + // We have constants; directly compare + assertRowExpressionEvaluationEquals(expressionResult, rowExpressionResult); + } + } + + public static void assertRowExpressionEvaluationEquals(RowExpression left, RowExpression right) + { + assertTrue(left instanceof RowExpression); + // assertEquals(((RowExpression) left).getType(), ((RowExpression) right).getType()); + if (left instanceof ConstantExpression) { + if (isRemovableCast(right)) { + assertRowExpressionEvaluationEquals(left, ((CallExpression) right).getArguments().get(0)); + return; + } + assertTrue(right instanceof ConstantExpression); + assertRowExpressionEvaluationEquals(((ConstantExpression) left).getValue(), ((ConstantExpression) left).getValue()); + } + else if (left instanceof InputReferenceExpression || left instanceof VariableReferenceExpression) { + assertEquals(left, right); + } + else if (left instanceof CallExpression) { + assertTrue(right instanceof CallExpression); + assertEquals(((CallExpression) left).getFunctionHandle(), ((CallExpression) right).getFunctionHandle()); + assertEquals(((CallExpression) left).getArguments().size(), ((CallExpression) right).getArguments().size()); + for (int i = 0; i < ((CallExpression) left).getArguments().size(); i++) { + assertRowExpressionEvaluationEquals(((CallExpression) left).getArguments().get(i), ((CallExpression) right).getArguments().get(i)); + } + } + else if (left instanceof SpecialFormExpression) { + assertTrue(right instanceof SpecialFormExpression); + assertEquals(((SpecialFormExpression) left).getForm(), ((SpecialFormExpression) right).getForm()); + assertEquals(((SpecialFormExpression) left).getArguments().size(), ((SpecialFormExpression) right).getArguments().size()); + for (int i = 0; i < ((SpecialFormExpression) left).getArguments().size(); i++) { + 
assertRowExpressionEvaluationEquals(((SpecialFormExpression) left).getArguments().get(i), ((SpecialFormExpression) right).getArguments().get(i)); + } + } + else { + assertTrue(left instanceof LambdaDefinitionExpression); + assertTrue(right instanceof LambdaDefinitionExpression); + assertEquals(((LambdaDefinitionExpression) left).getArguments(), ((LambdaDefinitionExpression) right).getArguments()); + assertEquals(((LambdaDefinitionExpression) left).getArgumentTypes(), ((LambdaDefinitionExpression) right).getArgumentTypes()); + assertRowExpressionEvaluationEquals(((LambdaDefinitionExpression) left).getBody(), ((LambdaDefinitionExpression) right).getBody()); + } + } + + /** + * Assert the evaluation result of two row expressions equivalent + * no matter they are constants or remaining row expressions. + */ + public static void assertRowExpressionEvaluationEquals(Object left, Object right) + { + if (right instanceof RowExpression) { + assertRowExpressionEvaluationEquals((RowExpression) left, (RowExpression) right); + } + else { + // We have constants; directly compare + if (left instanceof Block) { + assertTrue(right instanceof Block); + assertEquals(blockToSlice((Block) left), blockToSlice((Block) right)); + } + else { + assertEquals(left, right); + } + } + } + + private static boolean isRemovableCast(Object value) + { + if (value instanceof CallExpression && + new FunctionResolution(METADATA.getFunctionAndTypeManager().getFunctionAndTypeResolver()).isCastFunction(((CallExpression) value).getFunctionHandle())) { + Type targetType = ((CallExpression) value).getType(); + Type sourceType = ((CallExpression) value).getArguments().get(0).getType(); + return METADATA.getFunctionAndTypeManager().canCoerce(sourceType, targetType); + } + return false; + } + + public abstract void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected); + + public void assertRoundTrip(String expression) + { + ParsingOptions parsingOptions = 
createParsingOptions(TEST_SESSION); + assertEquals(SQL_PARSER.createExpression(expression, parsingOptions), + SQL_PARSER.createExpression(formatExpression(SQL_PARSER.createExpression(expression, parsingOptions), Optional.empty()), parsingOptions)); + } + + public static RowExpression toRowExpression(Expression expression) + { + return TRANSLATOR.translate(expression, SYMBOL_TYPES); + } + + public abstract void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected); + + public void assertFailedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + { + assertOptimizedMatches(actual, expected); + } + + public abstract Object optimize(@Language("SQL") String expression); + + public static Expression expression(String expression) + { + return FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); + } + + public abstract void assertDoNotOptimize(@Language("SQL") String expression, ExpressionOptimizer.Level optimizationLevel); + + public static Object symbolConstant(Symbol symbol) + { + switch (symbol.getName().toLowerCase(ENGLISH)) { + case "bound_integer": + return 1234L; + case "bound_long": + return 1234L; + case "bound_string": + return utf8Slice("hello"); + case "bound_double": + return 12.34; + case "bound_date": + return new LocalDate(2001, 8, 22).toDateMidnight(DateTimeZone.UTC).getMillis(); + case "bound_time": + return new LocalTime(3, 4, 5, 321).toDateTime(new DateTime(0, DateTimeZone.UTC)).getMillis(); + case "bound_timestamp": + return new DateTime(2001, 8, 22, 3, 4, 5, 321, DateTimeZone.UTC).getMillis(); + case "bound_pattern": + return utf8Slice("%el%"); + case "bound_timestamp_with_timezone": + return new SqlTimestampWithTimeZone(new DateTime(1970, 1, 1, 1, 0, 0, 999, DateTimeZone.UTC).getMillis(), getTimeZoneKey("Z")); + case "bound_varbinary": + return Slices.wrappedBuffer((byte) 0xab); + case "bound_decimal_short": + return 12345L; + case "bound_decimal_long": + return 
Decimals.encodeUnscaledValue(new BigInteger("12345678901234567890123")); + } + return null; + } + + public static Slice blockToSlice(Block block) + { + // This function is strictly for testing use only + SliceOutput sliceOutput = new DynamicSliceOutput(1000); + BlockSerdeUtil.writeBlock(blockEncodingSerde, sliceOutput, block); + return sliceOutput.slice(); + } + + public abstract void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected); + + public abstract Object evaluate(@Language("SQL") String expression, boolean deterministic); +} diff --git a/presto-native-execution/pom.xml b/presto-native-execution/pom.xml index 9eeb4b10250e6..fb847c6f9a999 100644 --- a/presto-native-execution/pom.xml +++ b/presto-native-execution/pom.xml @@ -326,6 +326,31 @@ + + + com.facebook.airlift + bootstrap + test + + + + com.facebook.airlift + configuration + test + + + + com.facebook.airlift + http-client + test + + + + com.facebook.airlift + json + test + + com.facebook.airlift log @@ -343,6 +368,13 @@ testing test + + + com.facebook.airlift.drift + drift-codec + test + + org.testcontainers testcontainers @@ -372,6 +404,18 @@ ${project.version} test + + + com.fasterxml.jackson.core + jackson-core + test + + + + com.fasterxml.jackson.core + jackson-databind + test + diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestNativeExpressionInterpreter.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestNativeExpressionInterpreter.java new file mode 100644 index 0000000000000..c02996d61c70c --- /dev/null +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestNativeExpressionInterpreter.java @@ -0,0 +1,620 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.nativeworker; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.http.client.HttpUriBuilder; +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonModule; +import com.facebook.airlift.log.Logger; +import com.facebook.drift.codec.guice.ThriftCodecModule; +import com.facebook.presto.block.BlockJsonSerde; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockEncoding; +import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.HandleJsonModule; +import com.facebook.presto.operator.scalar.FunctionAssertions; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.InputReferenceExpression; +import com.facebook.presto.spi.relation.LambdaDefinitionExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.RowExpressionVisitor; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import 
com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.expressions.AbstractTestExpressionInterpreter; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.type.TypeDeserializer; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.Module; +import com.google.inject.Scopes; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.net.ServerSocket; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager; +import static com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils.getNativeQueryRunnerParameters; +import static com.facebook.presto.operator.scalar.ApplyFunction.APPLY_FUNCTION; +import static com.google.common.net.HttpHeaders.ACCEPT; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static 
com.google.inject.multibindings.Multibinder.newSetBinder; +import static java.lang.String.format; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestNativeExpressionInterpreter + extends AbstractTestExpressionInterpreter +{ + private static final Logger log = Logger.get(TestNativeExpressionInterpreter.class); + + private JsonCodec codec; + private TestVisitor visitor; + private Process sidecar; + private URI expressionUri; + + public TestNativeExpressionInterpreter() + { + METADATA.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(APPLY_FUNCTION)); + } + + @BeforeClass + public void setup() + throws Exception + { + codec = getJsonCodec(); + visitor = new TestVisitor(); + int port = findRandomPort(); + HttpUriBuilder sidecarUri = HttpUriBuilder.uriBuilder() + .scheme("http") + .host("127.0.0.1") + .port(port); + expressionUri = sidecarUri.appendPath("/v1/expressions").build(); + sidecar = getSidecarProcess(sidecarUri.build(), port); + + try { + HttpClient client = HttpClient.newHttpClient(); + URI infoUri = sidecarUri.appendPath("/v1/info").build(); + HttpRequest request = HttpRequest.newBuilder() + .uri(infoUri) + .header(ACCEPT, JSON_UTF_8.toString()) + .GET() + .build(); + + long timeoutMs = 15000; + long pollIntervalMs = 1000; + long deadline = System.currentTimeMillis() + timeoutMs; + boolean sidecarProcessStarted = false; + + while (System.currentTimeMillis() < deadline) { + try { + HttpResponse response = client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + if (response.statusCode() != 500) { + sidecarProcessStarted = true; + break; + } + } + catch (IOException e) { + // ignore and retry until deadline + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + + try { + Thread.sleep(pollIntervalMs); + } + catch (InterruptedException e) { + // ignore and retry until deadline + } + } + + assertTrue(sidecarProcessStarted, 
format("Sidecar did not start properly within %d ms", timeoutMs)); + } + catch (Exception e) { + log.error(e, "Failed while waiting for sidecar startup"); + throw new Exception(e); + } + } + + @AfterClass + public void tearDown() + { + sidecar.destroyForcibly(); + } + + /// TODO: Optimization of `'a' LIKE unbound_string ESCAPE null` to `null` is not supported in Velox. + @Override + @Test + public void testLikeNullOptimization() + { + assertOptimizedEquals("null LIKE '%'", "null"); + assertOptimizedEquals("'a' LIKE null", "null"); + assertOptimizedEquals("'a' LIKE '%' ESCAPE null", "null"); + // assertOptimizedEquals("'a' LIKE unbound_string ESCAPE null", "null"); + } + + /// TODO: Certain tests are pending on IN-rewrite in Velox. + @Override + @Test + public void testIn() + { + assertOptimizedEquals("3 in (2, 4, 3, 5)", "true"); + assertOptimizedEquals("3 in (2, 4, 9, 5)", "false"); + assertOptimizedEquals("3 in (2, null, 3, 5)", "true"); + + assertOptimizedEquals("'foo' in ('bar', 'baz', 'foo', 'blah')", "true"); + assertOptimizedEquals("'foo' in ('bar', 'baz', 'buz', 'blah')", "false"); + assertOptimizedEquals("'foo' in ('bar', null, 'foo', 'blah')", "true"); + + assertOptimizedEquals("null in (2, null, 3, 5)", "null"); + assertOptimizedEquals("3 in (2, null)", "null"); + + assertOptimizedEquals("bound_integer in (2, 1234, 3, 5)", "true"); + assertOptimizedEquals("bound_integer in (2, 4, 3, 5)", "false"); + assertOptimizedEquals("1234 in (2, bound_integer, 3, 5)", "true"); + assertOptimizedEquals("99 in (2, bound_integer, 3, 5)", "false"); + assertOptimizedEquals("bound_integer in (2, bound_integer, 3, 5)", "true"); + + assertOptimizedEquals("bound_long in (2, 1234, 3, 5)", "true"); + assertOptimizedEquals("bound_long in (2, 4, 3, 5)", "false"); + assertOptimizedEquals("1234 in (2, bound_long, 3, 5)", "true"); + assertOptimizedEquals("99 in (2, bound_long, 3, 5)", "false"); + assertOptimizedEquals("bound_long in (2, bound_long, 3, 5)", "true"); + + 
assertOptimizedEquals("bound_string in ('bar', 'hello', 'foo', 'blah')", "true"); + assertOptimizedEquals("bound_string in ('bar', 'baz', 'foo', 'blah')", "false"); + assertOptimizedEquals("'hello' in ('bar', bound_string, 'foo', 'blah')", "true"); + assertOptimizedEquals("'baz' in ('bar', bound_string, 'foo', 'blah')", "false"); + + // TODO: Pending on IN rewrite in Velox. +// assertOptimizedEquals("bound_long in (2, 1234, unbound_long, 5)", "true"); +// assertOptimizedEquals("bound_string in ('bar', 'hello', unbound_string, 'blah')", "true"); +// assertOptimizedEquals("bound_long in (2, 4, unbound_long, unbound_long2, 9)", "1234 in (unbound_long, unbound_long2)"); +// assertOptimizedEquals("unbound_long in (2, 4, bound_long, unbound_long2, 5)", "unbound_long in (2, 4, 1234, unbound_long2, 5)"); + + assertOptimizedEquals("1.15 in (1.1, 1.2, 1.3, 1.15)", "true"); + assertOptimizedEquals("9876543210.98745612035 in (9876543210.9874561203, 9876543210.9874561204, 9876543210.98745612035)", "true"); + assertOptimizedEquals("bound_decimal_short in (123.455, 123.46, 123.45)", "true"); + assertOptimizedEquals("bound_decimal_long in (12345678901234567890.123, 9876543210.9874561204, 9876543210.98745612035)", "true"); + assertOptimizedEquals("bound_decimal_long in (9876543210.9874561204, null, 9876543210.98745612035)", "null"); + } + + /// Velox adds an implicit cast to the expression type for failed expression optimizations, so expected expression + /// containing one or more `fail` subexpressions should not be optimized in Velox. The string representation of + /// RowExpression is used to validate `fail` expression replaces all failing subexpressions. + @Override + @Test + public void testFailedExpressionOptimization() + { + // TODO: Velox COALESCE rewrite should be enhanced to deduplicate fail expressions. 
+ assertFailedMatches("coalesce(0 / 0 > 1, unbound_boolean, 0 / 0 = 0)", + "COALESCE\\(presto.default.\\$operator\\$greater_than\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 1\\), unbound_boolean, presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 0\\)\\)"); + + assertFailedMatches("if(false, 1, 0 / 0)", "presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)"); + + assertFailedMatches("CASE unbound_long WHEN 1 THEN 1 WHEN 0 / 0 THEN 2 END", + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(1, unbound_long\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\), unbound_long\\), 2\\), null\\)"); + + assertFailedMatches("CASE unbound_boolean WHEN true THEN 1 ELSE 0 / 0 END", + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(true, unbound_boolean\\), 1\\), presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\)"); + + assertFailedMatches("CASE bound_long WHEN unbound_long THEN 1 WHEN 0 / 0 THEN 2 ELSE 1 END", + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(unbound_long, 1234\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\), 1234\\), 2\\), 1\\)"); + + assertFailedMatches("case when unbound_boolean then 1 when 0 / 0 = 0 then 2 end", + "SWITCH\\(WHEN\\(unbound_boolean, 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 0\\), 2\\), null\\)"); + + assertFailedMatches("case when unbound_boolean then 1 else 0 / 0 end", + "SWITCH\\(WHEN\\(unbound_boolean, 1\\), presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\)"); + + assertFailedMatches("case when unbound_boolean then 0 / 0 else 1 end", + "SWITCH\\(WHEN\\(unbound_boolean, 
presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\), 1\\)"); + + assertFailedMatches("case true " + + "when unbound_long = 1 then 1 " + + "when 0 / 0 = 0 then 2 " + + "else 33 end", + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(unbound_long, 1\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 0\\), 2\\), 33\\)"); + + assertFailedMatches("case 1 " + + "when 0 / 0 then 1 " + + "when 0 / 0 then 2 " + + "else 1 " + + "end", + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 1\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\), 1\\), 2\\), 1\\)"); + + assertFailedMatches("case 1 " + + "when unbound_long then 1 " + + "when 0 / 0 then 2 " + + "else 1 " + + "end", + "SWITCH\\(WHEN\\(presto.default.\\$operator\\$equal\\(unbound_long, 1\\), 1\\), WHEN\\(presto.default.\\$operator\\$equal\\(presto.default.\\$operator\\$cast\\(presto.default.\\$operator\\$cast\\(presto.default.fail\\(.*\\)\\)\\), 1\\), 2\\), 1\\)"); + } + + /// TODO: Optimizer-Level EVALUATED is not supported by the sidecar. + /// Velox permits Bigint to Varchar cast but Presto does not. + @Override + @Test(enabled = false) + public void testCastBigintToBoundedVarchar() {} + + /// TODO: Optimizer-Level EVALUATED is not supported by the sidecar. + @Override + @Test(enabled = false) + public void testInComplexTypes() {} + + /// TODO: Optimizer-Level EVALUATED is not supported by the sidecar. + @Override + @Test(enabled = false) + public void testOptimizeDivideByZero() {} + + /// TODO: Optimizer-Level EVALUATED is not supported by the sidecar. + @Override + @Test(enabled = false) + public void testArraySubscriptConstantNegativeIndex() {} + + /// TODO: Optimizer-Level EVALUATED is not supported by the sidecar. 
+    @Override
+    // Disabled like the sibling overrides: with a bare @Test this empty body would be reported as a passing test.
+    @Test(enabled = false)
+    public void testArraySubscriptConstantZeroIndex() {}
+
+    /// TODO: Fix result mismatch between Presto and Velox for VARBINARY literals.
+    /// java.lang.AssertionError:
+    /// Expected :Slice{base=[B@771ede0d, address=16, length=2}
+    /// Actual :Slice{base=[B@1fd73dcb, address=47, length=2}
+    @Override
+    @Test(enabled = false)
+    public void testVarbinaryLiteral() {}
+
+    /// TODO: current timestamp returns the session timestamp and this should be evaluated on the sidecar plugin.
+    @Override
+    @Test(enabled = false)
+    public void testCurrentTimestamp() {}
+
+    /// TODO: current_user should be evaluated in the sidecar plugin and not in the sidecar.
+    @Override
+    @Test(enabled = false)
+    public void testCurrentUser() {}
+
+    /// This test is disabled because these optimizations for LIKE function are not yet supported in Velox.
+    /// TODO: LIKE function with field references can be optimized in Velox.
+    @Override
+    @Test(enabled = false)
+    public void testLikeOptimization() {}
+
+    /// LIKE function with unbound_string will not be optimized/constant-folded in Velox.
+    /// TODO: LIKE function with field references can be optimized in Velox.
+    @Override
+    @Test(enabled = false)
+    public void testInvalidLike() {}
+
+    /// TODO: NULL_IF special form is unsupported in Presto native.
+    @Override
+    @Test(enabled = false)
+    public void testNullIf() {}
+
+    /// TODO: Json based UDFs are not supported by the native expression optimizer.
+    @Override
+    @Test(enabled = false)
+    public void testCppFunctionCall() {}
+
+    /// TODO: Json based UDFs are not supported by the native expression optimizer.
+    @Override
+    @Test(enabled = false)
+    public void testCppAggregateFunctionCall() {}
+
+    /// TODO: Non-deterministic functions should be evaluated on sidecar when optimizer level is EVALUATED.
+ @Override + @Test(enabled = false) + public void testNonDeterministicFunctionCall() {} + + /// Velox only supports legacy map subscript and returns NULL for missing keys instead of an exception. + @Override + @Test(enabled = false) + public void testMapSubscriptMissingKey() {} + + /// TODO: Nested row subscript is not supported by Presto to Velox expression conversion. + @Override + @Test(enabled = false) + public void testRowSubscript() {} + + @Override + public Object evaluate(@Language("SQL") String expression, boolean deterministic) + { + return evaluate(expression); + } + + private RowExpression evaluate(@Language("SQL") String expression) + { + return optimize(expression, ExpressionOptimizer.Level.EVALUATED); + } + + @Override + public RowExpression optimize(String expression) + { + return optimize(expression, ExpressionOptimizer.Level.OPTIMIZED); + } + + private RowExpression optimize(@Language("SQL") String expression, ExpressionOptimizer.Level level) + { + assertRoundTrip(expression); + RowExpression parsedExpression = sqlToRowExpression(expression); + return optimizeRowExpression(parsedExpression, level); + } + + @Override + public void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + RowExpression optimizedActual = optimize(actual, ExpressionOptimizer.Level.OPTIMIZED); + RowExpression optimizedExpected = optimize(expected, ExpressionOptimizer.Level.OPTIMIZED); + assertRowExpressionEvaluationEquals(optimizedActual, optimizedExpected); + } + + @Override + public void assertOptimizedMatches(@Language("SQL") String actual, @Language("SQL") String expected) + { + assertOptimizedEquals(actual, expected); + } + + /// Checks that the string representation of the failed optimized expression matches expected. 
+ @Override + public void assertFailedMatches(@Language("SQL") String actual, @Language("RegExp") String expected) + { + RowExpression optimized = optimize(actual); + assertTrue(optimized.toString().matches(expected)); + } + + @Override + public void assertDoNotOptimize(@Language("SQL") String expression, ExpressionOptimizer.Level optimizationLevel) + { + assertRoundTrip(expression); + RowExpression rowExpression = sqlToRowExpression(expression); + RowExpression rowExpressionResult = optimizeRowExpression(rowExpression, ExpressionOptimizer.Level.OPTIMIZED); + assertRowExpressionEvaluationEquals(rowExpressionResult, rowExpression); + } + + private RowExpression sqlToRowExpression(String expression) + { + Expression parsedExpression = FunctionAssertions.createExpression(expression, METADATA, SYMBOL_TYPES); + return TRANSLATOR.translate(parsedExpression, SYMBOL_TYPES); + } + + @Override + public void assertEvaluatedEquals(@Language("SQL") String actual, @Language("SQL") String expected) + { + assertRowExpressionEvaluationEquals(evaluate(actual), evaluate(expected)); + } + + private RowExpression optimizeRowExpression(RowExpression expression, ExpressionOptimizer.Level level) + { + expression = expression.accept(visitor, null); + HttpResponse response = null; + try { + response = getSidecarResponse(expression, level); + } + catch (Exception e) { + log.error(e, "Failed to get sidecar response: %s.", e.getMessage()); + throw new RuntimeException(e); + } + + assertEquals(response.statusCode(), 200, "Sidecar returned error."); + String responseBody = response.body(); + ObjectMapper mapper = new ObjectMapper(); + RowExpression result = expression; + try { + // Response should be a JSON array consisting of a single RowExpression. 
+ JsonNode optimizedExpressionList = mapper.readTree(responseBody); + assertTrue(optimizedExpressionList.isArray()); + JsonNode optimizedExpression = optimizedExpressionList.get(0); + result = codec.fromJson(optimizedExpression.toString()); + } + catch (JsonProcessingException e) { + log.error(e, "Failed to decode RowExpression from sidecar response: %s.", e.getMessage()); + throw new RuntimeException(e); + } + + return result; + } + + private HttpResponse getSidecarResponse(RowExpression expression, ExpressionOptimizer.Level level) + throws IOException, InterruptedException + { + String json = String.format("[%s]", codec.toJson(expression)); + HttpClient client = HttpClient.newHttpClient(); + HttpRequest request = HttpRequest.newBuilder() + .uri(expressionUri) + .header(CONTENT_TYPE, JSON_UTF_8.toString()) + .header(ACCEPT, JSON_UTF_8.toString()) + .header("X-Presto-Time-Zone", TEST_SESSION.getSqlFunctionProperties().getTimeZoneKey().getId()) + .header("X-Presto-Expression-Optimizer-Level", level.name()) + .POST(HttpRequest.BodyPublishers.ofString(json, StandardCharsets.UTF_8)) + .build(); + + return client.send(request, HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)); + } + + private JsonCodec getJsonCodec() + { + Module module = binder -> { + binder.install(new JsonModule()); + binder.install(new HandleJsonModule()); + binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON); + binder.install(new ThriftCodecModule()); + configBinder(binder).bindConfig(FeaturesConfig.class); + + FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager(); + binder.bind(TypeManager.class).toInstance(functionAndTypeManager); + jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); + newSetBinder(binder, Type.class); + + binder.bind(BlockEncodingSerde.class).to(BlockEncodingManager.class).in(Scopes.SINGLETON); + newSetBinder(binder, BlockEncoding.class); + 
jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class); + jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class); + jsonCodecBinder(binder).bindJsonCodec(RowExpression.class); + }; + Bootstrap app = new Bootstrap(ImmutableList.of(module)); + Injector injector = app + .doNotInitializeLogging() + .quiet() + .initialize(); + return injector.getInstance(new Key>() {}); + } + + private static Process getSidecarProcess(URI discoveryUri, int port) + throws IOException + { + Path tempDirectoryPath = Files.createTempDirectory(PrestoNativeQueryRunnerUtils.class.getSimpleName()); + log.info("Temp directory for Sidecar: %s", tempDirectoryPath.toString()); + + String configProperties = format("discovery.uri=%s%n" + + "presto.version=testversion%n" + + "system-memory-gb=4%n" + + "native-sidecar=true%n" + + "http-server.http.port=%d", discoveryUri, port); + + Files.write(tempDirectoryPath.resolve("config.properties"), configProperties.getBytes()); + Files.write(tempDirectoryPath.resolve("node.properties"), + format("node.id=%s%n" + + "node.internal-address=127.0.0.1%n" + + "node.environment=testing%n" + + "node.location=test-location", UUID.randomUUID()).getBytes()); + + Path catalogDirectoryPath = tempDirectoryPath.resolve("catalog"); + Files.createDirectory(catalogDirectoryPath); + PrestoNativeQueryRunnerUtils.NativeQueryRunnerParameters nativeQueryRunnerParameters = getNativeQueryRunnerParameters(); + String prestoServerPath = nativeQueryRunnerParameters.serverBinary.toString(); + + return new ProcessBuilder(prestoServerPath, "--logtostderr=1", "--v=1") + .directory(tempDirectoryPath.toFile()) + .redirectErrorStream(true) + .redirectOutput(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("sidecar.out").toFile())) + .redirectError(ProcessBuilder.Redirect.to(tempDirectoryPath.resolve("sidecar.out").toFile())) + .start(); + } + + public static int findRandomPort() + throws IOException + { + try 
(ServerSocket socket = new ServerSocket(0)) { + return socket.getLocalPort(); + } + } + + private static class TestVisitor + implements RowExpressionVisitor + { + @Override + public RowExpression visitInputReference(InputReferenceExpression node, Object context) + { + return node; + } + + @Override + public RowExpression visitConstant(ConstantExpression node, Object context) + { + return node; + } + + /** + * Convert a variable reference to a RowExpression. + * If the symbol has a constant value, return a ConstantExpression of the appropriate type. + * Otherwise return a VariableReferenceExpression as before. + */ + @Override + public RowExpression visitVariableReference(VariableReferenceExpression node, Object context) + { + Symbol symbol = new Symbol(node.getName()); + Object value = symbolConstant(symbol); + if (value == null) { + return new VariableReferenceExpression(Optional.empty(), symbol.getName(), SYMBOL_TYPES.get(symbol.toSymbolReference())); + } + Type type = SYMBOL_TYPES.get(symbol.toSymbolReference()); + return new ConstantExpression(value, type); + } + + @Override + public RowExpression visitCall(CallExpression call, Object context) + { + CallExpression callExpression; + List newArguments = new ArrayList<>(); + for (RowExpression argument : call.getArguments()) { + RowExpression newArgument = argument.accept(this, context); + newArguments.add(newArgument); + } + callExpression = new CallExpression( + call.getSourceLocation(), + call.getDisplayName(), + call.getFunctionHandle(), + call.getType(), + newArguments); + return callExpression; + } + + @Override + public RowExpression visitLambda(LambdaDefinitionExpression lambda, Object context) + { + return lambda; + } + + @Override + public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Object context) + { + SpecialFormExpression result; + List newArguments = new ArrayList<>(); + for (RowExpression argument : specialForm.getArguments()) { + RowExpression newArgument = 
argument.accept(this, context); + newArguments.add(newArgument); + } + result = new SpecialFormExpression( + specialForm.getSourceLocation(), + specialForm.getForm(), + specialForm.getType(), + newArguments); + return result; + } + } +} From 4b77a02f2be407c2c984d8e3b004b8aa3d09e3bc Mon Sep 17 00:00:00 2001 From: Timothy Meehan Date: Wed, 12 Nov 2025 10:18:16 -0800 Subject: [PATCH 4/4] docs(native): Document v1/expressions endpoint Co-authored-by: Pramod Satya --- .../src/main/sphinx/presto_cpp/sidecar.rst | 7 + .../src/main/resources/expressions.yaml | 173 ++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 presto-openapi/src/main/resources/expressions.yaml diff --git a/presto-docs/src/main/sphinx/presto_cpp/sidecar.rst b/presto-docs/src/main/sphinx/presto_cpp/sidecar.rst index 2c00e86c7cac9..1ff704bef9bb5 100644 --- a/presto-docs/src/main/sphinx/presto_cpp/sidecar.rst +++ b/presto-docs/src/main/sphinx/presto_cpp/sidecar.rst @@ -38,6 +38,13 @@ The following HTTP endpoints are implemented by the Presto C++ sidecar. validates the Velox plan. Returns any errors encountered during plan conversion. +.. function:: POST /v1/expressions + + Optimizes a list of ``RowExpression``\s from the http request using + a combination of logical rewrites and constant folding, by leveraging + the ``ExprOptimizer`` from Velox, and returns a list of optimized + ``RowExpression``\s. 
+ Configuration Properties ------------------------ diff --git a/presto-openapi/src/main/resources/expressions.yaml b/presto-openapi/src/main/resources/expressions.yaml new file mode 100644 index 0000000000000..31e1ecde4c466 --- /dev/null +++ b/presto-openapi/src/main/resources/expressions.yaml @@ -0,0 +1,173 @@ +openapi: 3.0.0 +info: + title: Presto Expression API + description: API for evaluating and simplifying row expressions in Presto + version: "1" +servers: + - url: http://localhost:8080 + description: Presto endpoint when running locally +paths: + /v1/expressions: + post: + summary: Simplify the list of row expressions + description: This endpoint takes in a list of row expressions and attempts to simplify them to their simplest logical equivalent expression. + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RowExpressions' + required: true + responses: + '200': + description: Results + content: + application/json: + schema: + $ref: '#/components/schemas/RowExpressions' +components: + schemas: + RowExpressions: + type: array + maxItems: 100 + items: + $ref: "#/components/schemas/RowExpression" + RowExpression: + oneOf: + - $ref: "#/components/schemas/ConstantExpression" + - $ref: "#/components/schemas/VariableReferenceExpression" + - $ref: "#/components/schemas/InputReferenceExpression" + - $ref: "#/components/schemas/LambdaDefinitionExpression" + - $ref: "#/components/schemas/SpecialFormExpression" + - $ref: "#/components/schemas/CallExpression" + RowExpressionParent: + type: object + properties: + sourceLocation: + $ref: "#/components/schemas/SourceLocation" + SourceLocation: + description: The source location of the row expression in the original query, referencing the line and the column of the query. + type: object + properties: + line: + type: integer + column: + type: integer + ConstantExpression: + description: A constant expression is a row expression that represents a constant value. 
The valueBlock attribute holds the constant value, encoded as a string. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["constant"] + typeSignature: + type: string + valueBlock: + type: string + VariableReferenceExpression: + description: A variable reference expression is a row expression that represents a reference to a variable. The name attribute indicates the name of the variable. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["variable"] + typeSignature: + type: string + name: + type: string + InputReferenceExpression: + description: > + An input reference expression is a row expression that represents a reference to a column in the input schema. The field attribute indicates the index of the column in the + input schema. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["input"] + typeSignature: + type: string + field: + type: integer + LambdaDefinitionExpression: + description: > + A lambda definition expression is a row expression that represents a lambda function. The lambda function is defined by a list of argument types, a list of argument names, + and a body expression. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["lambda"] + argumentTypeSignatures: + type: array + items: + type: string + arguments: + type: array + items: + type: string + body: + $ref: "#/components/schemas/RowExpression" + SpecialFormExpression: + description: > + A special form expression is a row expression that represents a special language construct. The form attribute indicates the specific form of the special form, + which is drawn from a well-known list, with each form having its own semantics. 
The arguments attribute is a list of row expressions that are the arguments to the special form, with + each form taking in a specific number of arguments. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["special"] + form: + type: string + enum: ["IF","NULL_IF","SWITCH","WHEN","IS_NULL","COALESCE","IN","AND","OR","DEREFERENCE","ROW_CONSTRUCTOR","BIND"] + returnTypeSignature: + type: string + arguments: + type: array + items: + $ref: "#/components/schemas/RowExpression" + CallExpression: + description: > + A call expression is a row expression that represents a call to a function. The functionHandle attribute is an opaque handle to the function that is being called. + The arguments attribute is a list of row expressions that are the arguments to the function. + allOf: + - $ref: "#/components/schemas/RowExpressionParent" + - type: object + properties: + "@type": + type: string + enum : ["call"] + displayName: + type: string + functionHandle: + $ref: "#/components/schemas/FunctionHandle" + returnTypeSignature: + type: string + arguments: + type: array + items: + $ref: "#/components/schemas/RowExpression" + FunctionHandle: + description: An opaque handle to a function that may be invoked. This is interpreted by the registered function namespace manager. + anyOf: + - $ref: "#/components/schemas/OpaqueFunctionHandle" + - $ref: "#/components/schemas/SqlFunctionHandle" + OpaqueFunctionHandle: + type: object + properties: {} # any opaque object may be passed and interpreted by a function namespace manager + SqlFunctionHandle: + type: object + properties: + functionId: + type: string + version: + type: string