diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 754e058fe72..a4f46ef813b 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -21,15 +21,15 @@ namespace DB { - namespace ErrorCodes { extern const int COP_BAD_DAG_REQUEST; extern const int UNSUPPORTED_METHOD; } // namespace ErrorCodes -static String genFuncString(const String & func_name, const Names & argument_names) +static String genFuncString(const String & func_name, const Names & argument_names, const TiDB::TiDBCollators & collators) { + assert(!collators.empty()); std::stringstream ss; ss << func_name << "("; bool first = true; @@ -45,7 +45,15 @@ static String genFuncString(const String & func_name, const Names & argument_nam } ss << argument_name; } - ss << ") "; + ss << ")_collator"; + for (const auto & collator : collators) + { + if (collator == nullptr) + ss << "_0"; + else + ss << "_" << collator->getCollatorId(); + } + ss << " "; return ss.str(); } @@ -206,10 +214,14 @@ static String buildLeftUTF8Function(DAGExpressionAnalyzer * analyzer, const tipb static const String tidb_cast_name = "tidb_cast"; -static String buildCastFunctionInternal(DAGExpressionAnalyzer * analyzer, const Names & argument_names, bool in_union, - const tipb::FieldType & field_type, ExpressionActionsPtr & actions) +static String buildCastFunctionInternal( + DAGExpressionAnalyzer * analyzer, + const Names & argument_names, + bool in_union, + const tipb::FieldType & field_type, + ExpressionActionsPtr & actions) { - String result_name = genFuncString(tidb_cast_name, argument_names); + String result_name = genFuncString(tidb_cast_name, argument_names, {nullptr}); if (actions->getSampleBlock().has(result_name)) return result_name; @@ -247,16 +259,29 @@ struct DateAdd static constexpr auto name = "date_add"; static const std::unordered_map unit_to_func_name_map; }; -const std::unordered_map DateAdd::unit_to_func_name_map = {{"DAY", "addDays"}, {"WEEK", "addWeeks"}, {"MONTH", "addMonths"}, - {"YEAR", "addYears"}, {"HOUR", "addHours"}, {"MINUTE", "addMinutes"}, {"SECOND", "addSeconds"}}; +const std::unordered_map DateAdd::unit_to_func_name_map + = { + {"DAY", "addDays"}, + {"WEEK", "addWeeks"}, + {"MONTH", "addMonths"}, + {"YEAR", "addYears"}, + {"HOUR", "addHours"}, + {"MINUTE", "addMinutes"}, + {"SECOND", "addSeconds"}}; struct DateSub { static constexpr auto name = "date_sub"; static const std::unordered_map unit_to_func_name_map; }; const std::unordered_map DateSub::unit_to_func_name_map - = {{"DAY", "subtractDays"}, {"WEEK", "subtractWeeks"}, {"MONTH", "subtractMonths"}, {"YEAR", "subtractYears"}, - {"HOUR", "subtractHours"}, {"MINUTE", "subtractMinutes"}, {"SECOND", "subtractSeconds"}}; + = { + {"DAY", "subtractDays"}, + {"WEEK", "subtractWeeks"}, + {"MONTH", "subtractMonths"}, + {"YEAR", "subtractYears"}, + {"HOUR", "subtractHours"}, + {"MINUTE", "subtractMinutes"}, + {"SECOND", "subtractSeconds"}}; template static String buildDateAddOrSubFunction(DAGExpressionAnalyzer * analyzer, const tipb::Expr & expr, ExpressionActionsPtr & actions) @@ -270,12 +295,14 @@ static String buildDateAddOrSubFunction(DAGExpressionAnalyzer * analyzer, const if (expr.children(2).tp() != tipb::ExprType::String) { throw TiFlashException( - std::string() + "3rd argument of " + Impl::name + " function must be string literal", Errors::Coprocessor::BadRequest); + std::string() + "3rd argument of " + Impl::name + " function must be string literal", + Errors::Coprocessor::BadRequest); } String unit = expr.children(2).val(); if (Impl::unit_to_func_name_map.find(unit) == Impl::unit_to_func_name_map.end()) throw TiFlashException( - std::string() + Impl::name + " function does not support unit " + unit + " yet.", Errors::Coprocessor::Unimplemented); + std::string() + Impl::name + " function does not support unit " + unit + " yet.", + Errors::Coprocessor::Unimplemented); String func_name = Impl::unit_to_func_name_map.find(unit)->second; const auto & date_column_type = removeNullable(actions->getSampleBlock().getByName(date_column).type); if (!date_column_type->isDateOrDateTime()) @@ -350,21 +377,32 @@ static std::unordered_map}, {"date_sub", buildDateAddOrSubFunction}}); DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector && source_columns_, const Context & context_) - : source_columns(std::move(source_columns_)), context(context_), after_agg(false), implicit_cast_count(0) + : source_columns(std::move(source_columns_)) + , context(context_) + , after_agg(false) + , implicit_cast_count(0) { settings = context.getSettings(); } DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector & source_columns_, const Context & context_) - : source_columns(source_columns_), context(context_), after_agg(false), implicit_cast_count(0) + : source_columns(source_columns_) + , context(context_) + , after_agg(false) + , implicit_cast_count(0) { settings = context.getSettings(); } extern const String CountSecondStage; -void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, const tipb::Aggregation & agg, Names & aggregation_keys, - TiDB::TiDBCollators & collators, AggregateDescriptions & aggregate_descriptions, bool group_by_collation_sensitive) +void DAGExpressionAnalyzer::appendAggregation( + ExpressionActionsChain & chain, + const tipb::Aggregation & agg, + Names & aggregation_keys, + TiDB::TiDBCollators & collators, + AggregateDescriptions & aggregate_descriptions, + bool group_by_collation_sensitive) { if (agg.group_by_size() == 0 && agg.agg_func_size() == 0) { @@ -401,7 +439,8 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co aggregate.argument_names[i] = arg_name; step.required_output.push_back(arg_name); } - String func_string = genFuncString(agg_func_name, aggregate.argument_names); + auto function_collator = getCollatorFromExpr(expr); + String func_string = genFuncString(agg_func_name, aggregate.argument_names, {function_collator}); bool duplicate = false; for (const auto & pre_agg : aggregate_descriptions) { @@ -418,7 +457,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co aggregate.parameters = Array(); /// if there is group by clause, there is no need to consider the empty input case aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, agg.group_by_size() == 0); - aggregate.function->setCollator(getCollatorFromExpr(expr)); + aggregate.function->setCollator(function_collator); aggregate_descriptions.push_back(aggregate); DataTypePtr result_type = aggregate.function->getReturnType(); // this is a temp result since implicit cast maybe added on these aggregated_columns @@ -463,7 +502,8 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co types[0] = type; aggregate.argument_names[0] = name; - String func_string = genFuncString(agg_func_name, aggregate.argument_names); + auto function_collator = getCollatorFromExpr(expr); + String func_string = genFuncString(agg_func_name, aggregate.argument_names, {function_collator}); bool duplicate = false; for (const auto & pre_agg : aggregate_descriptions) { @@ -479,7 +519,7 @@ void DAGExpressionAnalyzer::appendAggregation(ExpressionActionsChain & chain, co aggregate.column_name = func_string; aggregate.parameters = Array(); aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types, {}, 0, false); - aggregate.function->setCollator(getCollatorFromExpr(expr)); + aggregate.function->setCollator(function_collator); aggregate_descriptions.push_back(aggregate); DataTypePtr result_type = aggregate.function->getReturnType(); // this is a temp result since implicit cast maybe added on these aggregated_columns @@ -506,9 +546,13 @@ bool isUInt8Type(const DataTypePtr & type) } String DAGExpressionAnalyzer::applyFunction( - const String & func_name, const Names & arg_names, ExpressionActionsPtr & actions, std::shared_ptr collator) + const String & func_name, + const Names & arg_names, + ExpressionActionsPtr & actions, + std::shared_ptr collator) + { - String result_name = genFuncString(func_name, arg_names); + String result_name = genFuncString(func_name, arg_names, {collator}); if (actions->getSampleBlock().has(result_name)) return result_name; const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(func_name, context); @@ -518,7 +562,9 @@ String DAGExpressionAnalyzer::applyFunction( } void DAGExpressionAnalyzer::appendWhere( - ExpressionActionsChain & chain, const std::vector & conditions, String & filter_column_name) + ExpressionActionsChain & chain, + const std::vector & conditions, + String & filter_column_name) { initChain(chain, getCurrentInputColumns()); ExpressionActionsChain::Step & last_step = chain.steps.back(); @@ -605,7 +651,9 @@ String DAGExpressionAnalyzer::convertToUInt8(ExpressionActionsPtr & actions, con } void DAGExpressionAnalyzer::appendOrderBy( - ExpressionActionsChain & chain, const tipb::TopN & topN, std::vector & order_columns) + ExpressionActionsChain & chain, + const tipb::TopN & topN, + std::vector & order_columns) { if (topN.order_by_size() == 0) { @@ -645,7 +693,10 @@ void constructTZExpr(tipb::Expr & tz_expr, const TimezoneInfo & dag_timezone_inf } String DAGExpressionAnalyzer::appendTimeZoneCast( - const String & tz_col, const String & ts_col, const String & func_name, ExpressionActionsPtr & actions) + const String & tz_col, + const String & ts_col, + const String & func_name, + ExpressionActionsPtr & actions) { String cast_expr_name = applyFunction(func_name, {ts_col, tz_col}, actions, nullptr); return cast_expr_name; @@ -688,16 +739,24 @@ bool DAGExpressionAnalyzer::appendTimeZoneCastsAfterTS(ExpressionActionsChain & } void DAGExpressionAnalyzer::appendJoin( - ExpressionActionsChain & chain, SubqueryForSet & join_query, const NamesAndTypesList & columns_added_by_join) + ExpressionActionsChain & chain, + SubqueryForSet & join_query, + const NamesAndTypesList & columns_added_by_join) { initChain(chain, getCurrentInputColumns()); ExpressionActionsPtr actions = chain.getLastActions(); actions->add(ExpressionAction::ordinaryJoin(join_query.join, columns_added_by_join)); } /// return true if some actions is needed -bool DAGExpressionAnalyzer::appendJoinKeyAndJoinFilters(ExpressionActionsChain & chain, - const google::protobuf::RepeatedPtrField & keys, const DataTypes & key_types, Names & key_names, bool left, - bool is_right_out_join, const google::protobuf::RepeatedPtrField & filters, String & filter_column_name) +bool DAGExpressionAnalyzer::appendJoinKeyAndJoinFilters( + ExpressionActionsChain & chain, + const google::protobuf::RepeatedPtrField & keys, + const DataTypes & key_types, + Names & key_names, + bool left, + bool is_right_out_join, + const google::protobuf::RepeatedPtrField & filters, + String & filter_column_name) { bool ret = false; initChain(chain, getCurrentInputColumns()); @@ -849,8 +908,12 @@ void DAGExpressionAnalyzer::appendAggSelect(ExpressionActionsChain & chain, cons } } -void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain, const std::vector & schema, - const std::vector & output_offsets, const String & column_prefix, bool keep_session_timezone_info, +void DAGExpressionAnalyzer::generateFinalProject( + ExpressionActionsChain & chain, + const std::vector & schema, + const std::vector & output_offsets, + const String & column_prefix, + bool keep_session_timezone_info, NamesWithAliases & final_project) { if (unlikely(!keep_session_timezone_info && output_offsets.empty())) @@ -888,7 +951,8 @@ void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain, for (auto i : output_offsets) { final_project.emplace_back( - current_columns[i].name, unique_name_generator.toUniqueName(column_prefix + current_columns[i].name)); + current_columns[i].name, + unique_name_generator.toUniqueName(column_prefix + current_columns[i].name)); } } else @@ -945,7 +1009,8 @@ void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain, else { final_project.emplace_back( - current_columns[i].name, unique_name_generator.toUniqueName(column_prefix + current_columns[i].name)); + current_columns[i].name, + unique_name_generator.toUniqueName(column_prefix + current_columns[i].name)); } } } @@ -961,7 +1026,10 @@ void DAGExpressionAnalyzer::generateFinalProject(ExpressionActionsChain & chain, * @return */ String DAGExpressionAnalyzer::alignReturnType( - const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool force_uint8) + const tipb::Expr & expr, + ExpressionActionsPtr & actions, + const String & expr_name, + bool force_uint8) { DataTypePtr orig_type = actions->getSampleBlock().getByName(expr_name).type; if (force_uint8 && isUInt8Type(orig_type)) @@ -985,7 +1053,10 @@ String DAGExpressionAnalyzer::appendCast(const DataTypePtr & target_type, Expres } String DAGExpressionAnalyzer::appendCastIfNeeded( - const tipb::Expr & expr, ExpressionActionsPtr & actions, const String & expr_name, bool explicit_cast) + const tipb::Expr & expr, + ExpressionActionsPtr & actions, + const String & expr_name, + bool explicit_cast) { if (!isFunctionExpr(expr)) return expr_name; @@ -1000,7 +1071,6 @@ String DAGExpressionAnalyzer::appendCastIfNeeded( DataTypePtr actual_type = actions->getSampleBlock().getByName(expr_name).type; if (expected_type->getName() != actual_type->getName()) { - implicit_cast_count += !explicit_cast; return appendCast(expected_type, actions, expr_name); } @@ -1013,7 +1083,10 @@ String DAGExpressionAnalyzer::appendCastIfNeeded( } void DAGExpressionAnalyzer::makeExplicitSet( - const tipb::Expr & expr, const Block & sample_block, bool create_ordered_set, const String & left_arg_name) + const tipb::Expr & expr, + const Block & sample_block, + bool create_ordered_set, + const String & left_arg_name) { if (prepared_sets.count(&expr)) { diff --git a/tests/tidb-ci/new_collation_fullstack/function_collator.test b/tests/tidb-ci/new_collation_fullstack/function_collator.test new file mode 100644 index 00000000000..a4c0935d785 --- /dev/null +++ b/tests/tidb-ci/new_collation_fullstack/function_collator.test @@ -0,0 +1,15 @@ +mysql> drop table if exists test.t1 +mysql> drop table if exists test.t2 +mysql> create table test.t1(col_varchar_20_key_signed varchar(20) COLLATE utf8mb4_general_ci, col_varbinary_20_key_signed varbinary(20), col_varbinary_20_undef_signed varbinary(20)); +mysql> create table test.t2(col_char_20_key_signed char(20) COLLATE utf8mb4_general_ci, col_varchar_20_undef_signed varchar(20) COLLATE utf8mb4_general_ci); +mysql> alter table test.t1 set tiflash replica 1 +mysql> alter table test.t2 set tiflash replica 1 +mysql> insert into test.t1 values('Abc',0x62,0x616263); +mysql> insert into test.t2 values('abc','b'); +func> wait_table test t1 +func> wait_table test t2 + +mysql> set session tidb_enforce_mpp=1; select * from test.t1 where t1.col_varchar_20_key_signed not in (select col_char_20_key_signed from test.t2 where t1.col_varchar_20_key_signed not in ( t1.col_varbinary_20_key_signed, t1.col_varbinary_20_undef_signed,col_varchar_20_undef_signed,col_char_20_key_signed)); + +mysql> drop table if exists test.t1; +mysql> drop table if exists test.t2;