Skip to content

Commit

Permalink
function result name should contain collator info (#2808) (#3018)
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot authored Sep 13, 2021
1 parent 3e0e679 commit ed1fecd
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 21 deletions.
76 changes: 55 additions & 21 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@

namespace DB
{

namespace ErrorCodes
{
extern const int COP_BAD_DAG_REQUEST;
extern const int UNSUPPORTED_METHOD;
} // namespace ErrorCodes

static String genFuncString(const String & func_name, const Names & argument_names)
static String genFuncString(const String & func_name, const Names & argument_names, const TiDB::TiDBCollators & collators)
{
assert(!collators.empty());
std::stringstream ss;
ss << func_name << "(";
bool first = true;
Expand All @@ -46,7 +46,15 @@ static String genFuncString(const String & func_name, const Names & argument_nam
}
ss << argument_name;
}
ss << ") ";
ss << ")_collator";
for (const auto & collator : collators)
{
if (collator == nullptr)
ss << "_0";
else
ss << "_" << collator->getCollatorId();
}
ss << " ";
return ss.str();
}

Expand Down Expand Up @@ -162,10 +170,14 @@ static String buildLogicalFunction(DAGExpressionAnalyzer * analyzer, const tipb:

static const String tidb_cast_name = "tidb_cast";

static String buildCastFunctionInternal(DAGExpressionAnalyzer * analyzer, const Names & argument_names, bool in_union,
const tipb::FieldType & field_type, ExpressionActionsPtr & actions)
static String buildCastFunctionInternal(
DAGExpressionAnalyzer * analyzer,
const Names & argument_names,
bool in_union,
const tipb::FieldType & field_type,
ExpressionActionsPtr & actions)
{
String result_name = genFuncString(tidb_cast_name, argument_names);
String result_name = genFuncString(tidb_cast_name, argument_names, {nullptr});
if (actions->getSampleBlock().has(result_name))
return result_name;

Expand Down Expand Up @@ -200,7 +212,6 @@ static String buildCastFunction(DAGExpressionAnalyzer * analyzer, const tipb::Ex

static String buildDateAddFunction(DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, ExpressionActionsPtr & actions)
{

static const std::unordered_map<String, String> unit_to_func_name_map({{"DAY", "addDays"}, {"WEEK", "addWeeks"}, {"MONTH", "addMonths"},
{"YEAR", "addYears"}, {"HOUR", "addHours"}, {"MINUTE", "addMinutes"}, {"SECOND", "addSeconds"}});
if (expr.children_size() != 3)
Expand Down Expand Up @@ -303,7 +314,10 @@ static std::unordered_map<String, std::function<String(DAGExpressionAnalyzer *,
});

DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector<NameAndTypePair> && source_columns_, const Context & context_)
: source_columns(std::move(source_columns_)), context(context_), after_agg(false), implicit_cast_count(0)
: source_columns(std::move(source_columns_))
, context(context_)
, after_agg(false)
, implicit_cast_count(0)
{
settings = context.getSettings();
}
Expand Down Expand Up @@ -332,7 +346,8 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
aggregate.argument_names[i] = arg_name;
step.required_output.push_back(arg_name);
}
String func_string = genFuncString(agg_func_name, aggregate.argument_names);
auto function_collator = getCollatorFromExpr(expr);
String func_string = genFuncString(agg_func_name, aggregate.argument_names, {function_collator});
bool duplicate = false;
for (const auto & pre_agg : aggregate_descriptions)
{
Expand All @@ -349,7 +364,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
aggregate.parameters = Array();
/// if there is group by clause, there is no need to consider the empty input case
aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, agg.group_by_size() == 0);
aggregate.function->setCollator(getCollatorFromExpr(expr));
aggregate.function->setCollator(function_collator);
aggregate_descriptions.push_back(aggregate);
DataTypePtr result_type = aggregate.function->getReturnType();
// this is a temp result since implicit cast maybe added on these aggregated_columns
Expand Down Expand Up @@ -386,7 +401,8 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
types[0] = type;
aggregate.argument_names[0] = name;

String func_string = genFuncString(agg_func_name, aggregate.argument_names);
auto function_collator = getCollatorFromExpr(expr);
String func_string = genFuncString(agg_func_name, aggregate.argument_names, {function_collator});
bool duplicate = false;
for (const auto & pre_agg : aggregate_descriptions)
{
Expand All @@ -402,7 +418,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co
aggregate.column_name = func_string;
aggregate.parameters = Array();
aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, false);
aggregate.function->setCollator(getCollatorFromExpr(expr));
aggregate.function->setCollator(function_collator);
aggregate_descriptions.push_back(aggregate);
DataTypePtr result_type = aggregate.function->getReturnType();
// this is a temp result since implicit cast maybe added on these aggregated_columns
Expand Down Expand Up @@ -431,7 +447,7 @@ bool isUInt8Type(const DataTypePtr & type)
String DAGExpressionAnalyzer::applyFunction(
const String & func_name, const Names & arg_names, ExpressionActionsPtr & actions, std::shared_ptr<TiDB::ITiDBCollator> collator)
{
String result_name = genFuncString(func_name, arg_names);
String result_name = genFuncString(func_name, arg_names, {collator});
if (actions->getSampleBlock().has(result_name))
return result_name;
const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(func_name, context);
Expand All @@ -441,7 +457,9 @@ String DAGExpressionAnalyzer::applyFunction(
}

void DAGExpressionAnalyzer::appendWhere(
ExpressionActionsChain & chain, const std::vector<const tipb::Expr *> & conditions, String & filter_column_name)
ExpressionActionsChain & chain,
const std::vector<const tipb::Expr *> & conditions,
String & filter_column_name)
{
initChain(chain, getCurrentInputColumns());
ExpressionActionsChain::Step & last_step = chain.steps.back();
Expand Down Expand Up @@ -528,7 +546,9 @@ String DAGExpressionAnalyzer::convertToUInt8(ExpressionActionsPtr & actions, con
}

void DAGExpressionAnalyzer::appendOrderBy(
ExpressionActionsChain & chain, const tipb::TopN & topN, std::vector<NameAndTypePair> & order_columns)
ExpressionActionsChain & chain,
const tipb::TopN & topN,
std::vector<NameAndTypePair> & order_columns)
{
if (topN.order_by_size() == 0)
{
Expand Down Expand Up @@ -568,7 +588,10 @@ void constructTZExpr(tipb::Expr & tz_expr, const TimezoneInfo & dag_timezone_inf
}

String DAGExpressionAnalyzer::appendTimeZoneCast(
const String & tz_col, const String & ts_col, const String & func_name, ExpressionActionsPtr & actions)
const String & tz_col,
const String & ts_col,
const String & func_name,
ExpressionActionsPtr & actions)
{
String cast_expr_name = applyFunction(func_name, {ts_col, tz_col}, actions, nullptr);
return cast_expr_name;
Expand Down Expand Up @@ -614,12 +637,15 @@ bool DAGExpressionAnalyzer::appendTimeZoneCastsAfterTS(ExpressionActionsChain &
}

void DAGExpressionAnalyzer::appendJoin(
ExpressionActionsChain & chain, SubqueryForSet & join_query, const NamesAndTypesList & columns_added_by_join)
ExpressionActionsChain & chain,
SubqueryForSet & join_query,
const NamesAndTypesList & columns_added_by_join)
{
initChain(chain, getCurrentInputColumns());
ExpressionActionsPtr actions = chain.getLastActions();
actions->add(ExpressionAction::ordinaryJoin(join_query.join, columns_added_by_join));
}

/// return true if some actions is needed
bool DAGExpressionAnalyzer::appendJoinKey(ExpressionActionsChain & chain, const google::protobuf::RepeatedPtrField<tipb::Expr> & keys,
const DataTypes & key_types, Names & key_names, bool left, bool is_right_out_join)
Expand Down Expand Up @@ -774,7 +800,10 @@ void DAGExpressionAnalyzer::appendAggSelect(
* @return
*/
String DAGExpressionAnalyzer::alignReturnType(
const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool force_uint8)
const tipb::Expr & expr,
ExpressionActionsPtr & actions,
const String & expr_name,
bool force_uint8)
{
DataTypePtr orig_type = actions->getSampleBlock().getByName(expr_name).type;
if (force_uint8 && isUInt8Type(orig_type))
Expand All @@ -798,7 +827,10 @@ String DAGExpressionAnalyzer::appendCast(const DataTypePtr & target_type, Expres
}

String DAGExpressionAnalyzer::appendCastIfNeeded(
const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool explicit_cast)
const tipb::Expr & expr,
ExpressionActionsPtr & actions,
const String & expr_name,
bool explicit_cast)
{
if (!isFunctionExpr(expr))
return expr_name;
Expand All @@ -813,7 +845,6 @@ String DAGExpressionAnalyzer::appendCastIfNeeded(
DataTypePtr actual_type = actions->getSampleBlock().getByName(expr_name).type;
if (expected_type->getName() != actual_type->getName())
{

implicit_cast_count += !explicit_cast;
return appendCast(expected_type, actions, expr_name);
}
Expand All @@ -826,7 +857,10 @@ String DAGExpressionAnalyzer::appendCastIfNeeded(
}

void DAGExpressionAnalyzer::makeExplicitSet(
const tipb::Expr & expr, const Block & sample_block, bool create_ordered_set, const String & left_arg_name)
const tipb::Expr & expr,
const Block & sample_block,
bool create_ordered_set,
const String & left_arg_name)
{
if (prepared_sets.count(&expr))
{
Expand Down
15 changes: 15 additions & 0 deletions tests/tidb-ci/new_collation_fullstack/function_collator.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
mysql> drop table if exists test.t1
mysql> drop table if exists test.t2
mysql> create table test.t1(col_varchar_20_key_signed varchar(20) COLLATE utf8mb4_general_ci, col_varbinary_20_key_signed varbinary(20), col_varbinary_20_undef_signed varbinary(20));
mysql> create table test.t2(col_char_20_key_signed char(20) COLLATE utf8mb4_general_ci, col_varchar_20_undef_signed varchar(20) COLLATE utf8mb4_general_ci);
mysql> alter table test.t1 set tiflash replica 1
mysql> alter table test.t2 set tiflash replica 1
mysql> insert into test.t1 values('Abc',0x62,0x616263);
mysql> insert into test.t2 values('abc','b');
func> wait_table test t1
func> wait_table test t2

mysql> set @@tidb_isolation_read_engines='tiflash'; select * from test.t1 where t1.col_varchar_20_key_signed not in (select col_char_20_key_signed from test.t2 where t1.col_varchar_20_key_signed not in ( t1.col_varbinary_20_key_signed, t1.col_varbinary_20_undef_signed,col_varchar_20_undef_signed,col_char_20_key_signed));

mysql> drop table if exists test.t1;
mysql> drop table if exists test.t2;

0 comments on commit ed1fecd

Please sign in to comment.