Skip to content

Commit

Permalink
Support other integer types for SubstringUTF8 & RightUTF8 functions (#…
Browse files Browse the repository at this point in the history
…9507) (#9510)

close #9473

Support other integer types for SubstringUTF8 & RightUTF8 functions

Signed-off-by: gengliqi <[email protected]>

Co-authored-by: gengliqi <[email protected]>
  • Loading branch information
ti-chi-bot and gengliqi authored Oct 10, 2024
1 parent 6578d2f commit 8534177
Show file tree
Hide file tree
Showing 6 changed files with 403 additions and 157 deletions.
187 changes: 120 additions & 67 deletions dbms/src/Functions/FunctionsString.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1737,26 +1737,41 @@ class FunctionSubstringUTF8 : public IFunction
bool is_start_type_valid
= getNumberType(block.getByPosition(arguments[1]).type, [&](const auto & start_type, bool) {
using StartType = std::decay_t<decltype(start_type)>;
// Int64 / UInt64
using StartFieldType = typename StartType::FieldType;
const ColumnVector<StartFieldType> * column_vector_start
= getInnerColumnVector<StartFieldType>(column_start);
if unlikely (!column_vector_start)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

// vector const const
if (!column_string->isColumnConst() && column_start->isColumnConst()
&& (implicit_length || block.getByPosition(arguments[2]).column->isColumnConst()))
{
auto [is_positive, start_abs]
= getValueFromStartField<StartFieldType>((*block.getByPosition(arguments[1]).column)[0]);
auto [is_positive, start_abs] = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
UInt64 length = 0;
if (!implicit_length)
{
bool is_length_type_valid = getNumberType(
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
using LengthFieldType = typename LengthType::FieldType;
length = getValueFromLengthField<LengthFieldType>(
(*block.getByPosition(arguments[2]).column)[0]);
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(block.getByPosition(arguments[2]).column);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

length = getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
return true;
});

Expand Down Expand Up @@ -1791,15 +1806,15 @@ class FunctionSubstringUTF8 : public IFunction
if (column_start->isColumnConst())
{
// func always return const value
auto start_const = getValueFromStartField<StartFieldType>((*column_start)[0]);
auto start_const = getValueFromStartColumn<StartFieldType>(*column_vector_start, 0);
get_start_func = [start_const](size_t) {
return start_const;
};
}
else
{
get_start_func = [&column_start](size_t i) {
return getValueFromStartField<StartFieldType>((*column_start)[i]);
get_start_func = [column_vector_start](size_t i) {
return getValueFromStartColumn<StartFieldType>(*column_vector_start, i);
};
}

Expand All @@ -1812,26 +1827,36 @@ class FunctionSubstringUTF8 : public IFunction
block.getByPosition(arguments[2]).type,
[&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
using LengthFieldType = typename LengthType::FieldType;
const ColumnVector<LengthFieldType> * column_vector_length
= getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 3 of function {}",
block.getByPosition(arguments[2]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

if (column_length->isColumnConst())
{
// func always return const value
auto length_const = getValueFromLengthField<LengthFieldType>((*column_length)[0]);
auto length_const
= getValueFromLengthColumn<LengthFieldType>(*column_vector_length, 0);
get_length_func = [length_const](size_t) {
return length_const;
};
}
else
{
get_length_func = [column_length](size_t i) {
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
get_length_func = [column_vector_length](size_t i) {
return getValueFromLengthColumn<LengthFieldType>(*column_vector_length, i);
};
}
return true;
});

if (!is_length_type_valid)
if unlikely (!is_length_type_valid)
throw Exception(
fmt::format("3nd argument of function {} must have UInt/Int type.", getName()));
}
Expand Down Expand Up @@ -1869,10 +1894,38 @@ class FunctionSubstringUTF8 : public IFunction
return true;
});

if (!is_start_type_valid)
if unlikely (!is_start_type_valid)
throw Exception(fmt::format("2nd argument of function {} must have UInt/Int type.", getName()));
}

template <typename Integer>
static const ColumnVector<Integer> * getInnerColumnVector(const ColumnPtr & column)
{
if (column->isColumnConst())
return checkAndGetColumn<ColumnVector<Integer>>(
checkAndGetColumn<ColumnConst>(column.get())->getDataColumnPtr().get());
return checkAndGetColumn<ColumnVector<Integer>>(column.get());
}

template <typename Integer>
static size_t getValueFromLengthColumn(const ColumnVector<Integer> & column, size_t index)
{
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
return val < 0 ? 0 : val;
}
else
{
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return val;
}
}

private:
using VectorConstConstFunc = std::function<void(
const ColumnString::Chars_t &,
Expand All @@ -1896,49 +1949,40 @@ class FunctionSubstringUTF8 : public IFunction
}
}

template <typename Integer>
static size_t getValueFromLengthField(const Field & length_field)
{
if constexpr (std::is_same_v<Integer, Int64>)
{
Int64 signed_length = length_field.get<Int64>();
return signed_length < 0 ? 0 : signed_length;
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return length_field.get<UInt64>();
}
}

// return {is_positive, abs}
template <typename Integer>
static std::pair<bool, size_t> getValueFromStartField(const Field & start_field)
static std::pair<bool, size_t> getValueFromStartColumn(const ColumnVector<Integer> & column, size_t index)
{
if constexpr (std::is_same_v<Integer, Int64>)
Integer val = column.getElement(index);
if constexpr (
std::is_same_v<Integer, Int8> || std::is_same_v<Integer, Int16> || std::is_same_v<Integer, Int32>
|| std::is_same_v<Integer, Int64>)
{
Int64 signed_length = start_field.get<Int64>();

if (signed_length < 0)
{
return {false, static_cast<size_t>(-signed_length)};
}
else
{
return {true, static_cast<size_t>(signed_length)};
}
if (val < 0)
return {false, static_cast<size_t>(-val)};
return {true, static_cast<size_t>(val)};
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return {true, start_field.get<UInt64>()};
static_assert(
std::is_same_v<Integer, UInt8> || std::is_same_v<Integer, UInt16> || std::is_same_v<Integer, UInt32>
|| std::is_same_v<Integer, UInt64>);
return {true, val};
}
}

template <typename F>
static bool getNumberType(DataTypePtr type, F && f)
{
return castTypeToEither<DataTypeInt64, DataTypeUInt64>(type.get(), std::forward<F>(f));
return castTypeToEither<
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
}
};

Expand Down Expand Up @@ -1977,16 +2021,28 @@ class FunctionRightUTF8 : public IFunction
bool is_length_type_valid
= getLengthType(block.getByPosition(arguments[1]).type, [&](const auto & length_type, bool) {
using LengthType = std::decay_t<decltype(length_type)>;
// Int64 / UInt64
using LengthFieldType = typename LengthType::FieldType;

const ColumnVector<LengthFieldType> * column_vector_length
= FunctionSubstringUTF8::getInnerColumnVector<LengthFieldType>(column_length);
if unlikely (!column_vector_length)
throw Exception(
fmt::format(
"Illegal type {} of argument 2 of function {}",
block.getByPosition(arguments[1]).type->getName(),
getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);


auto col_res = ColumnString::create();
if (const auto * col_string = checkAndGetColumn<ColumnString>(column_string.get()))
{
if (column_length->isColumnConst())
{
// vector const
size_t length = getValueFromLengthField<LengthFieldType>((*column_length)[0]);
size_t length = FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
0);

// for const 0, return const blank string.
if (0 == length)
Expand All @@ -2006,8 +2062,10 @@ class FunctionRightUTF8 : public IFunction
else
{
// vector vector
auto get_length_func = [&column_length](size_t i) {
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
auto get_length_func = [column_vector_length](size_t i) {
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
i);
};
RightUTF8Impl::vectorVector(
col_string->getChars(),
Expand All @@ -2026,8 +2084,10 @@ class FunctionRightUTF8 : public IFunction
assert(col_string_from_const);
// When useDefaultImplementationForConstants is true, string and length are not both constants
assert(!column_length->isColumnConst());
auto get_length_func = [&column_length](size_t i) {
return getValueFromLengthField<LengthFieldType>((*column_length)[i]);
auto get_length_func = [column_vector_length](size_t i) {
return FunctionSubstringUTF8::getValueFromLengthColumn<LengthFieldType>(
*column_vector_length,
i);
};
RightUTF8Impl::constVector(
column_length->size(),
Expand All @@ -2054,22 +2114,15 @@ class FunctionRightUTF8 : public IFunction
template <typename F>
static bool getLengthType(DataTypePtr type, F && f)
{
return castTypeToEither<DataTypeInt64, DataTypeUInt64>(type.get(), std::forward<F>(f));
}

template <typename Integer>
static size_t getValueFromLengthField(const Field & length_field)
{
if constexpr (std::is_same_v<Integer, Int64>)
{
Int64 signed_length = length_field.get<Int64>();
return signed_length < 0 ? 0 : signed_length;
}
else
{
static_assert(std::is_same_v<Integer, UInt64>);
return length_field.get<UInt64>();
}
return castTypeToEither<
DataTypeUInt8,
DataTypeUInt16,
DataTypeUInt32,
DataTypeUInt64,
DataTypeInt8,
DataTypeInt16,
DataTypeInt32,
DataTypeInt64>(type.get(), std::forward<F>(f));
}
};

Expand Down
Loading

0 comments on commit 8534177

Please sign in to comment.