Skip to content

Commit

Permalink
function result name should contain collator info (#2808) (#2818)
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot authored Sep 3, 2021
1 parent 9d95db0 commit ad24043
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 35 deletions.
138 changes: 103 additions & 35 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@

namespace DB
{

namespace ErrorCodes
{
extern const int COP_BAD_DAG_REQUEST;
extern const int UNSUPPORTED_METHOD;
} // namespace ErrorCodes

static String genFuncString(const String & func_name, const Names & argument_names)
static String genFuncString(const String & func_name, const Names & argument_names, const TiDB::TiDBCollators & collators)
{
assert(!collators.empty());
std::stringstream ss;
ss << func_name << "(";
bool first = true;
Expand All @@ -45,7 +45,15 @@ static String genFuncString(const String & func_name, const Names & argument_nam
}
ss << argument_name;
}
ss << ") ";
ss << ")_collator";
for (const auto & collator : collators)
{
if (collator == nullptr)
ss << "_0";
else
ss << "_" << collator->getCollatorId();
}
ss << " ";
return ss.str();
}

Expand Down Expand Up @@ -206,10 +214,14 @@ static String buildLeftUTF8Function(DAGExpressionAnalyzer * analyzer, const tipb

static const String tidb_cast_name = "tidb_cast";

static String buildCastFunctionInternal(DAGExpressionAnalyzer * analyzer, const Names & argument_names, bool in_union,
const tipb::FieldType & field_type, ExpressionActionsPtr & actions)
static String buildCastFunctionInternal(
DAGExpressionAnalyzer * analyzer,
const Names & argument_names,
bool in_union,
const tipb::FieldType & field_type,
ExpressionActionsPtr & actions)
{
String result_name = genFuncString(tidb_cast_name, argument_names);
String result_name = genFuncString(tidb_cast_name, argument_names, {nullptr});
if (actions->getSampleBlock().has(result_name))
return result_name;

Expand Down Expand Up @@ -247,16 +259,29 @@ struct DateAdd
static constexpr auto name = "date_add";
static const std::unordered_map<String, String> unit_to_func_name_map;
};
const std::unordered_map<String, String> DateAdd::unit_to_func_name_map = {{"DAY", "addDays"}, {"WEEK", "addWeeks"}, {"MONTH", "addMonths"},
{"YEAR", "addYears"}, {"HOUR", "addHours"}, {"MINUTE", "addMinutes"}, {"SECOND", "addSeconds"}};
const std::unordered_map<String, String> DateAdd::unit_to_func_name_map
= {
{"DAY", "addDays"},
{"WEEK", "addWeeks"},
{"MONTH", "addMonths"},
{"YEAR", "addYears"},
{"HOUR", "addHours"},
{"MINUTE", "addMinutes"},
{"SECOND", "addSeconds"}};
struct DateSub
{
static constexpr auto name = "date_sub";
static const std::unordered_map<String, String> unit_to_func_name_map;
};
const std::unordered_map<String, String> DateSub::unit_to_func_name_map
= {{"DAY", "subtractDays"}, {"WEEK", "subtractWeeks"}, {"MONTH", "subtractMonths"}, {"YEAR", "subtractYears"},
{"HOUR", "subtractHours"}, {"MINUTE", "subtractMinutes"}, {"SECOND", "subtractSeconds"}};
= {
{"DAY", "subtractDays"},
{"WEEK", "subtractWeeks"},
{"MONTH", "subtractMonths"},
{"YEAR", "subtractYears"},
{"HOUR", "subtractHours"},
{"MINUTE", "subtractMinutes"},
{"SECOND", "subtractSeconds"}};

template <typename Impl>
static String buildDateAddOrSubFunction(DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, ExpressionActionsPtr & actions)
Expand All @@ -270,12 +295,14 @@ static String buildDateAddOrSubFunction(DAGExpressionAnalyzer * analyzer, const
if (expr.children(2).tp() != tipb::ExprType::String)
{
throw TiFlashException(
std::string() + "3rd argument of " + Impl::name + " function must be string literal", Errors::Coprocessor::BadRequest);
std::string() + "3rd argument of " + Impl::name + " function must be string literal",
Errors::Coprocessor::BadRequest);
}
String unit = expr.children(2).val();
if (Impl::unit_to_func_name_map.find(unit) == Impl::unit_to_func_name_map.end())
throw TiFlashException(
std::string() + Impl::name + " function does not support unit " + unit + " yet.", Errors::Coprocessor::Unimplemented);
std::string() + Impl::name + " function does not support unit " + unit + " yet.",
Errors::Coprocessor::Unimplemented);
String func_name = Impl::unit_to_func_name_map.find(unit)->second;
const auto & date_column_type = removeNullable(actions->getSampleBlock().getByName(date_column).type);
if (!date_column_type->isDateOrDateTime())
Expand Down Expand Up @@ -350,21 +377,32 @@ static std::unordered_map<String, std::function<String(DAGExpressionAnalyzer *,
{"date_add", buildDateAddOrSubFunction<DateAdd>}, {"date_sub", buildDateAddOrSubFunction<DateSub>}});

DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector<NameAndTypePair> && source_columns_, const Context & context_)
: source_columns(std::move(source_columns_)), context(context_), after_agg(false), implicit_cast_count(0)
: source_columns(std::move(source_columns_))
, context(context_)
, after_agg(false)
, implicit_cast_count(0)
{
settings = context.getSettings();
}

DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector<NameAndTypePair> & source_columns_, const Context & context_)
: source_columns(source_columns_), context(context_), after_agg(false), implicit_cast_count(0)
: source_columns(source_columns_)
, context(context_)
, after_agg(false)
, implicit_cast_count(0)
{
settings = context.getSettings();
}

extern const String CountSecondStage;

void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, const tipb::Aggregation & agg, Names & aggregation_keys,
TiDB::TiDBCollators & collators, AggregateDescriptions & aggregate_descriptions, bool group_by_collation_sensitive)
void DAGExpressionAnalyzer::appendAggregation(
ExpressionActionsChain & chain,
const tipb::Aggregation & agg,
Names & aggregation_keys,
TiDB::TiDBCollators & collators,
AggregateDescriptions & aggregate_descriptions,
bool group_by_collation_sensitive)
{
if (agg.group_by_size() == 0 && agg.agg_func_size() == 0)
{
Expand Down Expand Up @@ -401,7 +439,8 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
aggregate.argument_names[i] = arg_name;
step.required_output.push_back(arg_name);
}
String func_string = genFuncString(agg_func_name, aggregate.argument_names);
auto function_collator = getCollatorFromExpr(expr);
String func_string = genFuncString(agg_func_name, aggregate.argument_names, {function_collator});
bool duplicate = false;
for (const auto & pre_agg : aggregate_descriptions)
{
Expand All @@ -418,7 +457,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
aggregate.parameters = Array();
/// if there is group by clause, there is no need to consider the empty input case
aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, agg.group_by_size() == 0);
aggregate.function->setCollator(getCollatorFromExpr(expr));
aggregate.function->setCollator(function_collator);
aggregate_descriptions.push_back(aggregate);
DataTypePtr result_type = aggregate.function->getReturnType();
// this is a temp result since implicit cast maybe added on these aggregated_columns
Expand Down Expand Up @@ -463,7 +502,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
types[0] = type;
aggregate.argument_names[0] = name;

String func_string = genFuncString(agg_func_name, aggregate.argument_names);
String func_string = genFuncString(agg_func_name, aggregate.argument_names, {collator});
bool duplicate = false;
for (const auto & pre_agg : aggregate_descriptions)
{
Expand Down Expand Up @@ -508,7 +547,7 @@ bool isUInt8Type(const DataTypePtr & type)
String DAGExpressionAnalyzer::applyFunction(
const String & func_name, const Names & arg_names, ExpressionActionsPtr & actions, std::shared_ptr<TiDB::ITiDBCollator> collator)
{
String result_name = genFuncString(func_name, arg_names);
String result_name = genFuncString(func_name, arg_names, {collator});
if (actions->getSampleBlock().has(result_name))
return result_name;
const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(func_name, context);
Expand All @@ -518,7 +557,9 @@ String DAGExpressionAnalyzer::applyFunction(
}

void DAGExpressionAnalyzer::appendWhere(
ExpressionActionsChain & chain, const std::vector<const tipb::Expr *> & conditions, String & filter_column_name)
ExpressionActionsChain & chain,
const std::vector<const tipb::Expr *> & conditions,
String & filter_column_name)
{
initChain(chain, getCurrentInputColumns());
ExpressionActionsChain::Step & last_step = chain.steps.back();
Expand Down Expand Up @@ -605,7 +646,9 @@ String DAGExpressionAnalyzer::convertToUInt8(ExpressionActionsPtr & actions, con
}

void DAGExpressionAnalyzer::appendOrderBy(
ExpressionActionsChain & chain, const tipb::TopN & topN, std::vector<NameAndTypePair> & order_columns)
ExpressionActionsChain & chain,
const tipb::TopN & topN,
std::vector<NameAndTypePair> & order_columns)
{
if (topN.order_by_size() == 0)
{
Expand Down Expand Up @@ -645,7 +688,10 @@ void constructTZExpr(tipb::Expr & tz_expr, const TimezoneInfo & dag_timezone_inf
}

String DAGExpressionAnalyzer::appendTimeZoneCast(
const String & tz_col, const String & ts_col, const String & func_name, ExpressionActionsPtr & actions)
const String & tz_col,
const String & ts_col,
const String & func_name,
ExpressionActionsPtr & actions)
{
String cast_expr_name = applyFunction(func_name, {ts_col, tz_col}, actions, nullptr);
return cast_expr_name;
Expand Down Expand Up @@ -688,16 +734,24 @@ bool DAGExpressionAnalyzer::appendTimeZoneCastsAfterTS(ExpressionActionsChain &
}

void DAGExpressionAnalyzer::appendJoin(
ExpressionActionsChain & chain, SubqueryForSet & join_query, const NamesAndTypesList & columns_added_by_join)
ExpressionActionsChain & chain,
SubqueryForSet & join_query,
const NamesAndTypesList & columns_added_by_join)
{
initChain(chain, getCurrentInputColumns());
ExpressionActionsPtr actions = chain.getLastActions();
actions->add(ExpressionAction::ordinaryJoin(join_query.join, columns_added_by_join));
}
/// return true if some actions is needed
bool DAGExpressionAnalyzer::appendJoinKeyAndJoinFilters(ExpressionActionsChain & chain,
const google::protobuf::RepeatedPtrField<tipb::Expr> & keys, const DataTypes & key_types, Names & key_names, bool left,
bool is_right_out_join, const google::protobuf::RepeatedPtrField<tipb::Expr> & filters, String & filter_column_name)
bool DAGExpressionAnalyzer::appendJoinKeyAndJoinFilters(
ExpressionActionsChain & chain,
const google::protobuf::RepeatedPtrField<tipb::Expr> & keys,
const DataTypes & key_types,
Names & key_names,
bool left,
bool is_right_out_join,
const google::protobuf::RepeatedPtrField<tipb::Expr> & filters,
String & filter_column_name)
{
bool ret = false;
initChain(chain, getCurrentInputColumns());
Expand Down Expand Up @@ -849,8 +903,12 @@ void DAGExpressionAnalyzer::appendAggSelect(ExpressionActionsChain & chain, cons
}
}

void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain, const std::vector<tipb::FieldType> & schema,
const std::vector<Int32> & output_offsets, const String & column_prefix, bool keep_session_timezone_info,
void DAGExpressionAnalyzer::generateFinalProject(
ExpressionActionsChain & chain,
const std::vector<tipb::FieldType> & schema,
const std::vector<Int32> & output_offsets,
const String & column_prefix,
bool keep_session_timezone_info,
NamesWithAliases & final_project)
{
if (unlikely(!keep_session_timezone_info && output_offsets.empty()))
Expand Down Expand Up @@ -888,7 +946,8 @@ void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain,
for (auto i : output_offsets)
{
final_project.emplace_back(
current_columns[i].name, unique_name_generator.toUniqueName(column_prefix + current_columns[i].name));
current_columns[i].name,
unique_name_generator.toUniqueName(column_prefix + current_columns[i].name));
}
}
else
Expand Down Expand Up @@ -945,7 +1004,8 @@ void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain,
else
{
final_project.emplace_back(
current_columns[i].name, unique_name_generator.toUniqueName(column_prefix + current_columns[i].name));
current_columns[i].name,
unique_name_generator.toUniqueName(column_prefix + current_columns[i].name));
}
}
}
Expand All @@ -961,7 +1021,10 @@ void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain,
* @return
*/
String DAGExpressionAnalyzer::alignReturnType(
const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool force_uint8)
const tipb::Expr & expr,
ExpressionActionsPtr & actions,
const String & expr_name,
bool force_uint8)
{
DataTypePtr orig_type = actions->getSampleBlock().getByName(expr_name).type;
if (force_uint8 && isUInt8Type(orig_type))
Expand All @@ -985,7 +1048,10 @@ String DAGExpressionAnalyzer::appendCast(const DataTypePtr & target_type, Expres
}

String DAGExpressionAnalyzer::appendCastIfNeeded(
const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool explicit_cast)
const tipb::Expr & expr,
ExpressionActionsPtr & actions,
const String & expr_name,
bool explicit_cast)
{
if (!isFunctionExpr(expr))
return expr_name;
Expand All @@ -1000,7 +1066,6 @@ String DAGExpressionAnalyzer::appendCastIfNeeded(
DataTypePtr actual_type = actions->getSampleBlock().getByName(expr_name).type;
if (expected_type->getName() != actual_type->getName())
{

implicit_cast_count += !explicit_cast;
return appendCast(expected_type, actions, expr_name);
}
Expand All @@ -1013,7 +1078,10 @@ String DAGExpressionAnalyzer::appendCastIfNeeded(
}

void DAGExpressionAnalyzer::makeExplicitSet(
const tipb::Expr & expr, const Block & sample_block, bool create_ordered_set, const String & left_arg_name)
const tipb::Expr & expr,
const Block & sample_block,
bool create_ordered_set,
const String & left_arg_name)
{
if (prepared_sets.count(&expr))
{
Expand Down
15 changes: 15 additions & 0 deletions tests/tidb-ci/new_collation_fullstack/function_collator.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
mysql> drop table if exists test.t1
mysql> drop table if exists test.t2
mysql> create table test.t1(col_varchar_20_key_signed varchar(20) COLLATE utf8mb4_general_ci, col_varbinary_20_key_signed varbinary(20), col_varbinary_20_undef_signed varbinary(20));
mysql> create table test.t2(col_char_20_key_signed char(20) COLLATE utf8mb4_general_ci, col_varchar_20_undef_signed varchar(20) COLLATE utf8mb4_general_ci);
mysql> alter table test.t1 set tiflash replica 1
mysql> alter table test.t2 set tiflash replica 1
mysql> insert into test.t1 values('Abc',0x62,0x616263);
mysql> insert into test.t2 values('abc','b');
func> wait_table test t1
func> wait_table test t2

mysql> set session tidb_enforce_mpp=1; select * from test.t1 where t1.col_varchar_20_key_signed not in (select col_char_20_key_signed from test.t2 where t1.col_varchar_20_key_signed not in ( t1.col_varbinary_20_key_signed, t1.col_varbinary_20_undef_signed,col_varchar_20_undef_signed,col_char_20_key_signed));

mysql> drop table if exists test.t1;
mysql> drop table if exists test.t2;

0 comments on commit ad24043

Please sign in to comment.