diff --git a/bolt/functions/sparksql/DecimalVectorFunctions.cpp b/bolt/functions/sparksql/DecimalVectorFunctions.cpp index a2503c0c6..0fa8d5f68 100644 --- a/bolt/functions/sparksql/DecimalVectorFunctions.cpp +++ b/bolt/functions/sparksql/DecimalVectorFunctions.cpp @@ -234,7 +234,9 @@ roundUpAndDown(R& r, const A& a, const B& b, bool noRoundUp, uint8_t aRescale) { B unsignedDivisor(b); bool roundUpSign = ((noRoundUp && a > 0) || (!noRoundUp && a < 0)); R quotient = unsignedDividendRescaled / unsignedDivisor; - R remainder = unsignedDividendRescaled % unsignedDivisor; + // Keep the remainder as wide as the divisor: long decimal inputs can have + // fractional remainders that overflow short decimal result storage. + B remainder = unsignedDividendRescaled % unsignedDivisor; if (roundUpSign && static_cast(remainder) > 0) { ++quotient; } diff --git a/bolt/functions/sparksql/tests/CMakeLists.txt b/bolt/functions/sparksql/tests/CMakeLists.txt index 5e1a07197..0e2931673 100644 --- a/bolt/functions/sparksql/tests/CMakeLists.txt +++ b/bolt/functions/sparksql/tests/CMakeLists.txt @@ -47,6 +47,7 @@ add_executable( DecimalCompareTest.cpp DecimalRoundTest.cpp DecimalUtilTest.cpp + DecimalVectorFunctionsTest.cpp ElementAtTest.cpp FactorialTest.cpp FromJsonTest.cpp diff --git a/bolt/functions/sparksql/tests/DecimalCompareTest.cpp b/bolt/functions/sparksql/tests/DecimalCompareTest.cpp index 558a9b3bb..a024ff858 100644 --- a/bolt/functions/sparksql/tests/DecimalCompareTest.cpp +++ b/bolt/functions/sparksql/tests/DecimalCompareTest.cpp @@ -983,5 +983,38 @@ TEST_F(DecimalCompareTest, lte) { 140, [](auto row) { return row % 2 == 0 ? true : false; })); } +TEST_F(DecimalCompareTest, mixedScaleNullable) { + auto dec1 = makeNullableFlatVector( + {124000000L, 341234567L, 391234578L, std::nullopt, 345L}, DECIMAL(18, 7)); + auto dec2 = makeNullableFlatVector( + {1240000, 12845678L, 1298765L, 123L, std::nullopt}, DECIMAL(18, 5)); + + testCompareExpr( + "decimal_equalto(c0, c1)", + {dec1, dec2}, + makeNullableFlatVector( + {true, false, false, std::nullopt, std::nullopt})); + testCompareExpr( + "decimal_lessthan(c0, c1)", + {dec1, dec2}, + makeNullableFlatVector( + {false, true, false, std::nullopt, std::nullopt})); + testCompareExpr( + "decimal_lessthanorequal(c0, c1)", + {dec1, dec2}, + makeNullableFlatVector( + {true, true, false, std::nullopt, std::nullopt})); + testCompareExpr( + "decimal_greaterthan(c0, c1)", + {dec1, dec2}, + makeNullableFlatVector( + {false, false, true, std::nullopt, std::nullopt})); + testCompareExpr( + "decimal_greaterthanorequal(c0, c1)", + {dec1, dec2}, + makeNullableFlatVector( + {true, false, true, std::nullopt, std::nullopt})); +} + } // namespace } // namespace bytedance::bolt::functions::sparksql::test diff --git a/bolt/functions/sparksql/tests/DecimalRoundTest.cpp b/bolt/functions/sparksql/tests/DecimalRoundTest.cpp index c898eb0af..1c6bffc6e 100644 --- a/bolt/functions/sparksql/tests/DecimalRoundTest.cpp +++ b/bolt/functions/sparksql/tests/DecimalRoundTest.cpp @@ -86,6 +86,13 @@ TEST_F(DecimalRoundTest, round) { makeFlatVector({123, 552, -999, 0}, DECIMAL(3, 3)), 3, makeFlatVector({123, 552, -999, 0}, DECIMAL(4, 3))); + // The result can be obtained by Spark spark-shell CLI. + // scala> spark.sql("select round(cast(0.123 as decimal(3,3)), 30)") + // decimal(4,3) + testDecimalRound( + makeFlatVector({123, 552, -999, 0}, DECIMAL(3, 3)), + 30, + makeFlatVector({123, 552, -999, 0}, DECIMAL(4, 3))); // Round to 'scale - 1'. testDecimalRound( diff --git a/bolt/functions/sparksql/tests/DecimalVectorFunctionTest.cpp b/bolt/functions/sparksql/tests/DecimalVectorFunctionTest.cpp deleted file mode 100644 index 1e3699f32..000000000 --- a/bolt/functions/sparksql/tests/DecimalVectorFunctionTest.cpp +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (c) ByteDance Ltd. and/or its affiliates - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "bolt/common/base/tests/GTestUtils.h" -#include "bolt/functions/sparksql/tests/SparkFunctionBaseTest.h" -#include "bolt/type/Type.h" -#include "bolt/vector/BaseVector.h" -using namespace bytedance::bolt; -using namespace bytedance::bolt::test; -using namespace bytedance::bolt::functions::test; -namespace bytedance::bolt::functions::sparksql::test { -namespace { -class DecimalVectorFunctionTest : public SparkFunctionBaseTest { - protected: - template - void testDecimalExpr( - const VectorPtr& expected, - const std::string& expression, - const std::vector& input) { - using EvalType = typename bolt::TypeTraits::NativeType; - auto result = - evaluate>(expression, makeRowVector(input)); - assertEqualVectors(expected, result); - testOpDictVectors(expression, expected, input); - } - - template - void testOpDictVectors( - const std::string& operation, - const VectorPtr& expected, - const std::vector& flatVector) { - // Dictionary vectors as arguments. - auto newSize = flatVector[0]->size() * 2; - std::vector dictVectors; - for (auto i = 0; i < flatVector.size(); ++i) { - auto indices = makeIndices(newSize, [&](int row) { return row / 2; }); - dictVectors.push_back( - VectorTestBase::wrapInDictionary(indices, newSize, flatVector[i])); - } - auto resultIndices = makeIndices(newSize, [&](int row) { return row / 2; }); - auto expectedResultDictionary = - VectorTestBase::wrapInDictionary(resultIndices, newSize, expected); - auto actual = - evaluate>(operation, makeRowVector(dictVectors)); - assertEqualVectors(expectedResultDictionary, actual); - } -}; - -TEST_F(DecimalVectorFunctionTest, DISABLED_makeDecimal) { - testDecimalExpr( - {makeFlatVector({1111, -1112, 9999, 0}, DECIMAL(5, 1))}, - "make_decimal_by_unscaled_value(c0, c1, true)", - {makeFlatVector({1111, -1112, 9999, 0}), - makeConstant(0, 4, DECIMAL(5, 1))}); - testDecimalExpr( - {makeFlatVector( - {11111111, -11112112, 99999999, DecimalUtil::kShortDecimalMax + 1}, - DECIMAL(38, 19))}, - "make_decimal_by_unscaled_value(c0, c1, true)", - {makeFlatVector( - {11111111, -11112112, 99999999, DecimalUtil::kShortDecimalMax + 1}), - makeConstant(0, 4, DECIMAL(38, 19))}); - - testDecimalExpr( - {makeNullableFlatVector( - {101, std::nullopt, std::nullopt}, DECIMAL(3, 1))}, - "make_decimal_by_unscaled_value(c0, c1, true)", - {makeNullableFlatVector({101, std::nullopt, 1000}), - makeConstant(0, 3, DECIMAL(3, 1))}); -} - -TEST_F(DecimalVectorFunctionTest, DISABLED_unscale) { - auto expected = - makeNullableFlatVector({1111111, std::nullopt, 10000}); - auto input = makeNullableFlatVector( - {1111111, std::nullopt, 10000}, DECIMAL(15, 10)); - auto result = evaluate>( - "unscaled_value(c0)", makeRowVector({input})); - assertEqualVectors(result, expected); -} - -TEST_F(DecimalVectorFunctionTest, compare) { - auto dec1 = makeNullableFlatVector( - {124000000L, 341234567L, 391234578L, std::nullopt, 345L}, DECIMAL(18, 7)); - auto dec2 = makeNullableFlatVector( - {1240000, 12845678L, 1298765L, 123L, std::nullopt}, DECIMAL(18, 5)); - { - auto expectedResult = makeNullableFlatVector( - {true, false, false, std::nullopt, std::nullopt}); - auto result = evaluate>( - "decimal_eq(c0, c1)", makeRowVector({dec1, dec2})); - assertEqualVectors(expectedResult, result); - } - - { - auto expectedResult = makeNullableFlatVector( - {false, true, false, std::nullopt, std::nullopt}); - auto result = evaluate>( - "decimal_lt(c0, c1)", makeRowVector({dec1, dec2})); - assertEqualVectors(expectedResult, result); - } - - { - auto expectedResult = makeNullableFlatVector( - {true, true, false, std::nullopt, std::nullopt}); - auto result = evaluate>( - "decimal_lte(c0, c1)", makeRowVector({dec1, dec2})); - assertEqualVectors(expectedResult, result); - } - - { - auto expectedResult = makeNullableFlatVector( - {false, false, true, std::nullopt, std::nullopt}); - auto result = evaluate>( - "decimal_gt(c0, c1)", makeRowVector({dec1, dec2})); - assertEqualVectors(expectedResult, result); - } - - { - auto expectedResult = makeNullableFlatVector( - {true, false, true, std::nullopt, std::nullopt}); - auto result = evaluate>( - "decimal_gte(c0, c1)", makeRowVector({dec1, dec2})); - assertEqualVectors(expectedResult, result); - } -} - -TEST_F(DecimalVectorFunctionTest, round) { - auto input = - makeFlatVector({111111, 123456, 123478}, DECIMAL(15, 2)); - - { - auto expected = makeFlatVector({1112, 1235, 1235}, DECIMAL(14, 0)); - auto result = - evaluate>("ceil(c0)", makeRowVector({input})); - assertEqualVectors(result, expected); - } - - { - auto expected = makeFlatVector({1111, 1234, 1234}, DECIMAL(14, 0)); - auto result = - evaluate>("floor(c0)", makeRowVector({input})); - assertEqualVectors(result, expected); - } -} -} // namespace -} // namespace bytedance::bolt::functions::sparksql::test diff --git a/bolt/functions/sparksql/tests/DecimalVectorFunctionsTest.cpp b/bolt/functions/sparksql/tests/DecimalVectorFunctionsTest.cpp index 02be1ab29..c5350b46e 100644 --- a/bolt/functions/sparksql/tests/DecimalVectorFunctionsTest.cpp +++ b/bolt/functions/sparksql/tests/DecimalVectorFunctionsTest.cpp @@ -54,89 +54,6 @@ class DecimalVectorFunctionsTest : public SparkFunctionBaseTest { } }; -// The result can be obtained by Spark spark-shell CLI. -// scala> val df = spark.sql("select round(cast(0.123 as decimal(3,3)), 30);") -// df: org.apache.spark.sql.DataFrame = [round(CAST(0.123 AS DECIMAL(3,3)), 30): -// decimal(4,3)] -TEST_F(DecimalVectorFunctionsTest, round) { - // Round up to 'scale' decimal places. - testDecimalExpr( - {makeFlatVector({123, 552, -999, 0}, DECIMAL(4, 3))}, - "decimal_round(c0, CAST(30 as integer))", - {makeFlatVector({123, 552, -999, 0}, DECIMAL(3, 3))}); - - // Round up to scale-1 decimal places. - testDecimalExpr( - {makeFlatVector({12, 55, -100, 0}, DECIMAL(3, 2))}, - "decimal_round(c0, CAST(2 as integer))", - {makeFlatVector({123, 552, -999, 0}, DECIMAL(3, 3))}); - // Round up to 0 decimal places. - testDecimalExpr( - {makeFlatVector({0, 1, -1, 0}, DECIMAL(1, 0))}, - "decimal_round(c0, CAST(0 as integer))", - {makeFlatVector({123, 552, -999, 0}, DECIMAL(3, 3))}); - testDecimalExpr( - {makeFlatVector({1, 6, -10, 0}, DECIMAL(2, 0))}, - "decimal_round(c0, CAST(0 as integer))", - {makeFlatVector({123, 552, -999, 0}, DECIMAL(3, 2))}); - testDecimalExpr( - {makeFlatVector({1, 6, -10, 0}, DECIMAL(2, 0))}, - "decimal_round(c0)", - {makeFlatVector({123, 552, -999, 0}, DECIMAL(3, 2))}); - // Round up to -1 decimal places. - testDecimalExpr( - {makeFlatVector({10, 60, -100, 0}, DECIMAL(3, 0))}, - "decimal_round(c0, CAST(-1 as integer))", - {makeFlatVector({123, 552, -999, 0}, DECIMAL(3, 1))}); - // Round up to -2 decimal places. Here precision == -scale + 1. - testDecimalExpr( - {makeFlatVector({0, 0, 0, 0}, DECIMAL(4, 0))}, - "decimal_round(c0, CAST(-3 as integer))", - {makeFlatVector({123, 552, -999, 0}, DECIMAL(3, 1))}); - - // Round up long decimals to short decimals. - testDecimalExpr( - {makeNullableFlatVector( - {12345678901235, 50000000000000, -10'000'000'000'000, 0}, - DECIMAL(15, 14))}, - "decimal_round(c0, CAST(14 as integer))", - {makeFlatVector( - {1234567890123456789, 5000000000000000000, -999999999999999999, 0}, - DECIMAL(19, 19))}); - testDecimalExpr( - {makeFlatVector( - {12346000000000, 55556000000000, -10000000000000, 0}, - DECIMAL(15, 0))}, - "decimal_round(c0, CAST(-9 as integer))", - {makeFlatVector( - {1234567890123456789, 5555555555555555555, -999999999999999999, 0}, - DECIMAL(19, 5))}); - // Round up long decimals to long decimals. - testDecimalExpr( - {makeFlatVector( - {1234567890123456789, 5555555555555555555, -999999999999999999, 0}, - DECIMAL(20, 5))}, - "decimal_round(c0, CAST(14 as integer))", - {makeFlatVector( - {1234567890123456789, 5555555555555555555, -999999999999999999, 0}, - DECIMAL(19, 5))}); - testDecimalExpr( - {makeFlatVector( - {12346000000000, 55556000000000, -10000000000000, 0}, - DECIMAL(28, 0))}, - "decimal_round(c0, CAST(-9 as integer))", - {makeFlatVector( - {1234567890123456789, 5555555555555555555, -999999999999999999, 0}, - DECIMAL(32, 5))}); - // Result precision is 38 - testDecimalExpr( - {makeFlatVector({0, 0, 0, 0}, DECIMAL(38, 0))}, - "decimal_round(c0, CAST(-38 as integer))", - {makeFlatVector( - {1234567890123456789, 5555555555555555555, -999999999999999999, 0}, - DECIMAL(32, 0))}); -} - TEST_F(DecimalVectorFunctionsTest, ceil) { testDecimalExpr( {makeFlatVector({3, 6, -9, 0}, DECIMAL(2, 0))}, @@ -148,6 +65,13 @@ TEST_F(DecimalVectorFunctionsTest, ceil) { {makeFlatVector( {1234567890123456789, 5000000000000000000, -999999999999999999, 0}, DECIMAL(19, 19))}); + testDecimalExpr( + {makeFlatVector(std::vector{101}, DECIMAL(4, 0))}, + "ceil(c0)", + {makeFlatVector( + std::vector{ + int128_t(100100100100100100) * 100000000000 + 10010010010}, + DECIMAL(29, 26))}); } TEST_F(DecimalVectorFunctionsTest, floor) { @@ -196,5 +120,24 @@ TEST_F(DecimalVectorFunctionsTest, negative) { {1234567890123456789, 5000000000000000000, -999999999999999999, 0}, DECIMAL(19, 19))}); } + +TEST_F(DecimalVectorFunctionsTest, ceilFloor) { + auto input = + makeFlatVector({111111, 123456, 123478}, DECIMAL(15, 2)); + + { + auto expected = makeFlatVector({1112, 1235, 1235}, DECIMAL(14, 0)); + auto result = + evaluate>("ceil(c0)", makeRowVector({input})); + bolt::test::assertEqualVectors(result, expected); + } + + { + auto expected = makeFlatVector({1111, 1234, 1234}, DECIMAL(14, 0)); + auto result = + evaluate>("floor(c0)", makeRowVector({input})); + bolt::test::assertEqualVectors(result, expected); + } +} } // namespace } // namespace bytedance::bolt::functions::sparksql::test diff --git a/bolt/functions/sparksql/tests/UnscaledValueFunctionTest.cpp b/bolt/functions/sparksql/tests/UnscaledValueFunctionTest.cpp index 0e072a7d8..5aa19cd64 100644 --- a/bolt/functions/sparksql/tests/UnscaledValueFunctionTest.cpp +++ b/bolt/functions/sparksql/tests/UnscaledValueFunctionTest.cpp @@ -51,5 +51,14 @@ TEST_F(UnscaledValueFunctionTest, unscaledValue) { testUnscaledValue({1000, 2000, -3000, -4000}, DECIMAL(20, 3)), "Expect short decimal type, but got: DECIMAL(20, 3)"); } + +TEST_F(UnscaledValueFunctionTest, nullableShortDecimal) { + auto input = makeNullableFlatVector( + {1111111, std::nullopt, 10000}, DECIMAL(15, 10)); + auto expected = + makeNullableFlatVector({1111111, std::nullopt, 10000}); + auto result = evaluate("unscaled_value(c0)", makeRowVector({input})); + assertEqualVectors(expected, result); +} } // namespace } // namespace bytedance::bolt::functions::sparksql::test