Skip to content

Commit

Permalink
remove include files, add functions supported float types and more fu…
Browse files Browse the repository at this point in the history
…nction tests, the unit tests are remaining to be added.
  • Loading branch information
kche0169 committed Jan 9, 2025
1 parent b14d2a6 commit 655ca81
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 24 deletions.
13 changes: 8 additions & 5 deletions python/test_pysdk/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,15 +1028,18 @@ def test_select_truncate(self, suffix):
db_obj = self.infinity_obj.get_database("default_db")
db_obj.drop_table("test_select_truncate" + suffix, ConflictType.Ignore)
db_obj.create_table("test_select_truncate" + suffix,
{"c": {"type": "double"}}, ConflictType.Error)
{"c1": {"type": "double"},
"c2": {"type": "float"}}, ConflictType.Error)
table_obj = db_obj.get_table("test_select_truncate" + suffix)
table_obj.insert(
[{'c': 2.123}, {'c': -2.123}, {'c': 2}, {'c': 2.1}, {'c': float("nan")}, {'c': float("inf")}, {'c':float("-inf")}])
[{"c1": 2.123, "c2":float(2.123)}, {"c1": -2.123, "c2":float(-2.123)}, {"c1": 2, "c2": float(2)}, {"c1": 2.1, "c2": float(2.1)},
{"c1": float("nan"), "c2": float("nan")}, {"c1":float("-inf"), "c2": float("-inf")}, {"c1":float("-inf"), "c2": float("-inf")}])

res, extra_res = table_obj.output(["trunc(c, 2)"]).to_df()
res, extra_res = table_obj.output(["trunc(c1, 2)", "trunc(c2, 2)"]).to_df()
print(res)
pd.testing.assert_frame_equal(res, pd.DataFrame({'(c trunc 2)': ("2.12", "-2.12", "2.00", "2.10", "NaN", "Inf", "Inf")})
.astype({'(c trunc 2)': dtype('str_')}))
pd.testing.assert_frame_equal(res, pd.DataFrame({'(c1 trunc 2)': ("2.12", "-2.12", "2.00", "2.10", "NaN", "Inf", "Inf"),
'(c2 trunc 2)': ("2.12", "-2.12", "2.00", "2.10", "NaN", "Inf", "Inf")})
.astype({'(c1 trunc 2)': dtype('str_'), '(c2 trunc 2)': dtype('str_')}))


res = db_obj.drop_table("test_select_truncate" + suffix)
Expand Down
9 changes: 8 additions & 1 deletion src/common/stl.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ module;
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <filesystem>
#include <forward_list>
#include <functional>
#include <iomanip>
#include <ios>
#include <iostream>
#include <iterator>
#include <list>
Expand All @@ -40,6 +42,7 @@ module;
#include <optional>
#include <random>
#include <ranges>
#include <sched.h>
#include <semaphore>
#include <set>
#include <shared_mutex>
Expand Down Expand Up @@ -160,6 +163,11 @@ using std::stable_sort;
using std::tie;
using std::transform;
using std::unique;
using std::setprecision;
using std::fixed;

using std::string;
using std::stringstream;

namespace ranges {

Expand Down Expand Up @@ -203,7 +211,6 @@ using std::chrono::days;
using std::chrono::day;
using std::chrono::month;
using std::chrono::year;

using std::chrono::sys_days;
using std::chrono::system_clock;
using std::chrono::year_month_day;
Expand Down
57 changes: 45 additions & 12 deletions src/function/scalar/truncate.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@

// Copyright(C) 2025 InfiniFlow, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

module;
#include <type_traits>
#include "type/internal_types.h"
#include <ostream>
#include "type/logical_type.h"
#include <cstddef>
#include <cmath>
#include <iostream>
#include <sstream>
#include <string>
#include <iomanip>
module trunc;
import stl;
import catalog;
Expand Down Expand Up @@ -41,7 +43,7 @@ inline void TruncFunction::Run(DoubleT left, BigIntT right, VarcharT &result, Co
ss << left;
std::string str = ss.str();
std::string truncated_str;
size_t i = str.find_first_of('.');
int i = str.find_first_of('.');
if (right < static_cast<BigIntT>(0) || std::isnan(right) || std::isinf(right)) {
Status status = Status::InvalidDataType();
RecoverableError(status);
Expand All @@ -58,6 +60,30 @@ inline void TruncFunction::Run(DoubleT left, BigIntT right, VarcharT &result, Co
result_ptr->AppendVarcharInner(truncated_str, result);
}

template <>
inline void TruncFunction::Run(FloatT left, BigIntT right, VarcharT &result, ColumnVector *result_ptr) {
std::stringstream ss;
ss << std::fixed << std::setprecision(6);
ss << left;
std::string str = ss.str();
std::string truncated_str;
int i = str.find_first_of('.');
if (right < static_cast<BigIntT>(0) || std::isnan(right) || std::isinf(right)) {
Status status = Status::InvalidDataType();
RecoverableError(status);
return;
} else if (std::isnan(left)) {
truncated_str = "NaN";
} else if (std::isinf(left)) {
truncated_str = "Inf";
} else if (right > static_cast<BigIntT>(7) || static_cast<BigIntT>(str.size() - i) < right || right == static_cast<BigIntT>(0)) {
truncated_str = str.substr(0, i);
} else {
truncated_str = str.substr(0, i + right + 1);
}
result_ptr->AppendVarcharInner(truncated_str, result);
}

void RegisterTruncFunction(const UniquePtr<Catalog> &catalog_ptr) {
String func_name = "trunc";

Expand All @@ -69,6 +95,13 @@ void RegisterTruncFunction(const UniquePtr<Catalog> &catalog_ptr) {
&ScalarFunction::BinaryFunctionToVarlen<DoubleT, BigIntT, VarcharT, TruncFunction>);
function_set_ptr->AddFunction(truncate_double_bigint);

ScalarFunction truncate_float_bigint(func_name,
{DataType(LogicalType::kFloat), DataType(LogicalType::kBigInt)},
DataType(LogicalType::kVarchar),
&ScalarFunction::BinaryFunctionToVarlen<FloatT, BigIntT, VarcharT, TruncFunction>);
function_set_ptr->AddFunction(truncate_float_bigint);


Catalog::AddFunctionSet(catalog_ptr.get(), function_set_ptr);
}

Expand Down
21 changes: 17 additions & 4 deletions src/unit_test/function/scalar/truncate_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import column_vector;
import logical_type;
import internal_types;
import data_type;
#if 0
using namespace infinity;
class TruncateFunctionsTest : public BaseTestParamStr {};

Expand All @@ -52,12 +51,26 @@ TEST_P(TruncateFunctionsTest, truncate_func) {

RegisterAbsFunction(catalog_ptr);

String op = "truncate";
String op = "trunc";

SharedPtr<FunctionSet> function_set = Catalog::GetFunctionSetByName(catalog_ptr.get(), op);
EXPECT_EQ(function_set->type_, FunctionType::kScalar);
SharedPtr<ScalarFunctionSet> scalar_function_set = std::static_pointer_cast<ScalarFunctionSet>(function_set);

{}
{
Vector<SharedPtr<BaseExpression>> inputs;

DataType data_type1(LogicalType::kFloat);
DataType data_type2(LogicalType::kBigInt);
DataType result_type(LogicalType::kVarchar);
SharedPtr<ColumnExpression> col1_expr_ptr = MakeShared<ColumnExpression>(data_type1, "t1", 1, "c1", 0, 0);
SharedPtr<ColumnExpression> col2_expr_ptr = MakeShared<ColumnExpression>(data_type2, "t1", 1, "c2", 1, 0);

inputs.emplace_back(col1_expr_ptr);
inputs.emplace_back(col2_expr_ptr);

ScalarFunction func = scalar_function_set->GetMostMatchFunction(inputs);
EXPECT_STREQ("POW(Heterogeneous, Double)->Double", func.ToString().c_str());
}
}
#endif

4 changes: 2 additions & 2 deletions tools/run_pysdk_remote_infinity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
python_executable = sys.executable


def python_sdk_test(python_test_dir: str, pytest_mark: str):
def python_sdk_test(python_test_dir: str, pytest_mark: str, single_file_name="/test_select.py"):
print("python test path is {}".format(python_test_dir))
# run test
print(f"start pysdk test with {pytest_mark}")
Expand All @@ -22,7 +22,7 @@ def python_sdk_test(python_test_dir: str, pytest_mark: str):
"-x",
"-m",
pytest_mark,
f"{python_test_dir}/test_pysdk",
f"{python_test_dir}/test_pysdk{single_file_name}",
]
quoted_args = ['"' + arg + '"' if " " in arg else arg for arg in args]
print(" ".join(quoted_args))
Expand Down

0 comments on commit 655ca81

Please sign in to comment.