Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion ydb/core/kqp/opt/rbo/kqp_rbo_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,23 @@ TExprNode::TPtr PruneCast(TExprNode::TPtr node) {
return node;
}

TVector<TInfoUnit> GetHashableKeys(const std::shared_ptr<IOperator> &input) {
if (!input->Type) {
return input->GetOutputIUs();
}

const auto *inputType = input->Type;
TVector<TInfoUnit> hashableKeys;
const auto* structType = inputType->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>();
for (const auto &item : structType->GetItems()) {
if (item->GetItemType()->IsHashable()) {
hashableKeys.push_back(TInfoUnit(TString(item->GetName())));
}
}

return hashableKeys;
}

bool IsNullRejectingPredicate(const TFilterInfo &filter, TExprContext &ctx) {
Y_UNUSED(ctx);
#ifdef DEBUG_PREDICATE
Expand Down Expand Up @@ -576,8 +593,9 @@ bool TAssignStagesRule::TestAndApply(std::shared_ptr<IOperator> &input, TRBOCont
const auto newStageId = props.StageGraph.AddStage();
aggregate->Props.StageId = newStageId;
const bool isInputSourceStage = props.StageGraph.IsSourceStage(inputStageId);
const auto shuffleKeys = aggregate->KeyColumns.size() ? aggregate->KeyColumns : GetHashableKeys(aggregate->GetInput());

props.StageGraph.Connect(inputStageId, newStageId, std::make_shared<TShuffleConnection>(aggregate->KeyColumns, isInputSourceStage));
props.StageGraph.Connect(inputStageId, newStageId, std::make_shared<TShuffleConnection>(shuffleKeys, isInputSourceStage));
YQL_CLOG(TRACE, CoreDq) << "Assign stage to Aggregation ";
} else {
Y_ENSURE(false, "Unknown operator encountered");
Expand Down
8 changes: 4 additions & 4 deletions ydb/core/kqp/opt/rbo/kqp_rbo_type_ann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,11 @@ TStatus ComputeTypes(std::shared_ptr<TOpUnionAll> unionAll, TRBOContext & ctx) {
TStatus ComputeTypes(std::shared_ptr<TOpAggregate> aggregate, TRBOContext& ctx) {
auto inputType = aggregate->GetInput()->Type;
const auto* structType = inputType->Cast<TListExprType>()->GetItemType()->Cast<TStructExprType>();
THashMap<TStringBuf, std::pair<TStringBuf, TStringBuf>> aggTraitsMap;
THashMap<TString, std::pair<TString, TString>> aggTraitsMap;
for (const auto& aggTraits : aggregate->AggregationTraitsList) {
const auto originalColName = aggTraits.OriginalColName.GetFullName();
const auto resultColName = aggTraits.ResultColName.GetFullName();
const auto funcName = aggTraits.AggFunction;
const auto originalColName = TString(aggTraits.OriginalColName.GetFullName());
const auto resultColName = TString(aggTraits.ResultColName.GetFullName());
const auto funcName = TString(aggTraits.AggFunction);
aggTraitsMap[originalColName] = {resultColName, funcName};
}

Expand Down
10 changes: 9 additions & 1 deletion ydb/core/kqp/ut/rbo/kqp_rbo_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,19 @@ Y_UNIT_TEST_SUITE(KqpRbo) {
SET TablePathPrefix = "/Root/";
select t1.b, sum(t1.c) from t1 inner join t2 on t1.a = t2.a group by t1.b order by t1.b;
)",
R"(
--!syntax_pg
SET TablePathPrefix = "/Root/";
select sum(t1.c) from t1 group by t1.b
union all
select sum(t1.b) from t1;
)",
};

std::vector<std::string> results = {
R"([["1";"4"];["2";"6"]])",
R"([["1";"4"];["2";"6"]])"
R"([["1";"4"];["2";"6"]])",
R"([["6"];["4"];["8"]])"
};

for (ui32 i = 0; i < queries.size(); ++i) {
Expand Down
Loading