Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion bolt/functions/sparksql/DecimalVectorFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ roundUpAndDown(R& r, const A& a, const B& b, bool noRoundUp, uint8_t aRescale) {
B unsignedDivisor(b);
bool roundUpSign = ((noRoundUp && a > 0) || (!noRoundUp && a < 0));
R quotient = unsignedDividendRescaled / unsignedDivisor;
R remainder = unsignedDividendRescaled % unsignedDivisor;
// Keep the remainder as wide as the divisor: long decimal inputs can have
// fractional remainders that overflow short decimal result storage.
B remainder = unsignedDividendRescaled % unsignedDivisor;
if (roundUpSign && static_cast<const B>(remainder) > 0) {
++quotient;
}
Expand Down
1 change: 1 addition & 0 deletions bolt/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ add_executable(
DecimalCompareTest.cpp
DecimalRoundTest.cpp
DecimalUtilTest.cpp
DecimalVectorFunctionsTest.cpp
ElementAtTest.cpp
FactorialTest.cpp
FromJsonTest.cpp
Expand Down
33 changes: 33 additions & 0 deletions bolt/functions/sparksql/tests/DecimalCompareTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -983,5 +983,38 @@ TEST_F(DecimalCompareTest, lte) {
140, [](auto row) { return row % 2 == 0 ? true : false; }));
}

TEST_F(DecimalCompareTest, mixedScaleNullable) {
auto dec1 = makeNullableFlatVector<int64_t>(
{124000000L, 341234567L, 391234578L, std::nullopt, 345L}, DECIMAL(18, 7));
auto dec2 = makeNullableFlatVector<int64_t>(
{1240000, 12845678L, 1298765L, 123L, std::nullopt}, DECIMAL(18, 5));

testCompareExpr(
"decimal_equalto(c0, c1)",
{dec1, dec2},
makeNullableFlatVector<bool>(
{true, false, false, std::nullopt, std::nullopt}));
testCompareExpr(
"decimal_lessthan(c0, c1)",
{dec1, dec2},
makeNullableFlatVector<bool>(
{false, true, false, std::nullopt, std::nullopt}));
testCompareExpr(
"decimal_lessthanorequal(c0, c1)",
{dec1, dec2},
makeNullableFlatVector<bool>(
{true, true, false, std::nullopt, std::nullopt}));
testCompareExpr(
"decimal_greaterthan(c0, c1)",
{dec1, dec2},
makeNullableFlatVector<bool>(
{false, false, true, std::nullopt, std::nullopt}));
testCompareExpr(
"decimal_greaterthanorequal(c0, c1)",
{dec1, dec2},
makeNullableFlatVector<bool>(
{true, false, true, std::nullopt, std::nullopt}));
}

} // namespace
} // namespace bytedance::bolt::functions::sparksql::test
7 changes: 7 additions & 0 deletions bolt/functions/sparksql/tests/DecimalRoundTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ TEST_F(DecimalRoundTest, round) {
makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 3)),
3,
makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(4, 3)));
// The result can be obtained by Spark spark-shell CLI.
// scala> spark.sql("select round(cast(0.123 as decimal(3,3)), 30)")
// decimal(4,3)
testDecimalRound(
makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 3)),
30,
makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(4, 3)));

// Round to 'scale - 1'.
testDecimalRound(
Expand Down
162 changes: 0 additions & 162 deletions bolt/functions/sparksql/tests/DecimalVectorFunctionTest.cpp

This file was deleted.

109 changes: 26 additions & 83 deletions bolt/functions/sparksql/tests/DecimalVectorFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,89 +54,6 @@ class DecimalVectorFunctionsTest : public SparkFunctionBaseTest {
}
};

// The result can be obtained by Spark spark-shell CLI.
// scala> val df = spark.sql("select round(cast(0.123 as decimal(3,3)), 30);")
// df: org.apache.spark.sql.DataFrame = [round(CAST(0.123 AS DECIMAL(3,3)), 30):
// decimal(4,3)]
TEST_F(DecimalVectorFunctionsTest, round) {
// Round up to 'scale' decimal places.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(4, 3))},
"decimal_round(c0, CAST(30 as integer))",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 3))});

// Round up to scale-1 decimal places.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({12, 55, -100, 0}, DECIMAL(3, 2))},
"decimal_round(c0, CAST(2 as integer))",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 3))});
// Round up to 0 decimal places.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({0, 1, -1, 0}, DECIMAL(1, 0))},
"decimal_round(c0, CAST(0 as integer))",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 3))});
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({1, 6, -10, 0}, DECIMAL(2, 0))},
"decimal_round(c0, CAST(0 as integer))",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 2))});
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({1, 6, -10, 0}, DECIMAL(2, 0))},
"decimal_round(c0)",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 2))});
// Round up to -1 decimal places.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({10, 60, -100, 0}, DECIMAL(3, 0))},
"decimal_round(c0, CAST(-1 as integer))",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 1))});
// Round up to -2 decimal places. Here precision == -scale + 1.
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({0, 0, 0, 0}, DECIMAL(4, 0))},
"decimal_round(c0, CAST(-3 as integer))",
{makeFlatVector<int64_t>({123, 552, -999, 0}, DECIMAL(3, 1))});

// Round up long decimals to short decimals.
testDecimalExpr<TypeKind::BIGINT>(
{makeNullableFlatVector<int64_t>(
{12345678901235, 50000000000000, -10'000'000'000'000, 0},
DECIMAL(15, 14))},
"decimal_round(c0, CAST(14 as integer))",
{makeFlatVector<int128_t>(
{1234567890123456789, 5000000000000000000, -999999999999999999, 0},
DECIMAL(19, 19))});
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>(
{12346000000000, 55556000000000, -10000000000000, 0},
DECIMAL(15, 0))},
"decimal_round(c0, CAST(-9 as integer))",
{makeFlatVector<int128_t>(
{1234567890123456789, 5555555555555555555, -999999999999999999, 0},
DECIMAL(19, 5))});
// Round up long decimals to long decimals.
testDecimalExpr<TypeKind::HUGEINT>(
{makeFlatVector<int128_t>(
{1234567890123456789, 5555555555555555555, -999999999999999999, 0},
DECIMAL(20, 5))},
"decimal_round(c0, CAST(14 as integer))",
{makeFlatVector<int128_t>(
{1234567890123456789, 5555555555555555555, -999999999999999999, 0},
DECIMAL(19, 5))});
testDecimalExpr<TypeKind::HUGEINT>(
{makeFlatVector<int128_t>(
{12346000000000, 55556000000000, -10000000000000, 0},
DECIMAL(28, 0))},
"decimal_round(c0, CAST(-9 as integer))",
{makeFlatVector<int128_t>(
{1234567890123456789, 5555555555555555555, -999999999999999999, 0},
DECIMAL(32, 5))});
// Result precision is 38
testDecimalExpr<TypeKind::HUGEINT>(
{makeFlatVector<int128_t>({0, 0, 0, 0}, DECIMAL(38, 0))},
"decimal_round(c0, CAST(-38 as integer))",
{makeFlatVector<int128_t>(
{1234567890123456789, 5555555555555555555, -999999999999999999, 0},
DECIMAL(32, 0))});
}

TEST_F(DecimalVectorFunctionsTest, ceil) {
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>({3, 6, -9, 0}, DECIMAL(2, 0))},
Expand All @@ -148,6 +65,13 @@ TEST_F(DecimalVectorFunctionsTest, ceil) {
{makeFlatVector<int128_t>(
{1234567890123456789, 5000000000000000000, -999999999999999999, 0},
DECIMAL(19, 19))});
testDecimalExpr<TypeKind::BIGINT>(
{makeFlatVector<int64_t>(std::vector<int64_t>{101}, DECIMAL(4, 0))},
"ceil(c0)",
{makeFlatVector<int128_t>(
std::vector<int128_t>{
int128_t(100100100100100100) * 100000000000 + 10010010010},
DECIMAL(29, 26))});
}

TEST_F(DecimalVectorFunctionsTest, floor) {
Expand Down Expand Up @@ -196,5 +120,24 @@ TEST_F(DecimalVectorFunctionsTest, negative) {
{1234567890123456789, 5000000000000000000, -999999999999999999, 0},
DECIMAL(19, 19))});
}

TEST_F(DecimalVectorFunctionsTest, ceilFloor) {
auto input =
makeFlatVector<int64_t>({111111, 123456, 123478}, DECIMAL(15, 2));

{
auto expected = makeFlatVector<int64_t>({1112, 1235, 1235}, DECIMAL(14, 0));
auto result =
evaluate<SimpleVector<int64_t>>("ceil(c0)", makeRowVector({input}));
bolt::test::assertEqualVectors(result, expected);
}

{
auto expected = makeFlatVector<int64_t>({1111, 1234, 1234}, DECIMAL(14, 0));
auto result =
evaluate<SimpleVector<int64_t>>("floor(c0)", makeRowVector({input}));
bolt::test::assertEqualVectors(result, expected);
}
}
} // namespace
} // namespace bytedance::bolt::functions::sparksql::test
9 changes: 9 additions & 0 deletions bolt/functions/sparksql/tests/UnscaledValueFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,14 @@ TEST_F(UnscaledValueFunctionTest, unscaledValue) {
testUnscaledValue({1000, 2000, -3000, -4000}, DECIMAL(20, 3)),
"Expect short decimal type, but got: DECIMAL(20, 3)");
}

TEST_F(UnscaledValueFunctionTest, nullableShortDecimal) {
auto input = makeNullableFlatVector<int64_t>(
{1111111, std::nullopt, 10000}, DECIMAL(15, 10));
auto expected =
makeNullableFlatVector<int64_t>({1111111, std::nullopt, 10000});
auto result = evaluate("unscaled_value(c0)", makeRowVector({input}));
assertEqualVectors(expected, result);
}
} // namespace
} // namespace bytedance::bolt::functions::sparksql::test