Skip to content

Commit

Permalink
Support new agg function APPROX_COUNT_DISTINCT (#798)
Browse files Browse the repository at this point in the history
Signed-off-by: Tong Zhigao <[email protected]>

Co-authored-by: Tong Zhigao <[email protected]>
  • Loading branch information
ti-srebot and solotzg authored Jun 19, 2020
1 parent dc04625 commit ffc9204
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 26 deletions.
7 changes: 5 additions & 2 deletions dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}

extern const String UniqRawResName = "uniqRawRes";

void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness)
{
Expand Down Expand Up @@ -64,10 +65,12 @@ AggregateFunctionPtr AggregateFunctionFactory::get(

AggregateFunctionPtr nested_function;

const static std::unordered_set<String> check_names = {"count", UniqRawResName};

/// A little hack - if we have NULL arguments, don't even create nested function.
/// Combinator will check if nested_function was created.
if (name == "count" || std::none_of(argument_types.begin(), argument_types.end(),
[](const auto & type) { return type->onlyNull(); }))
if (check_names.count(name)
|| std::none_of(argument_types.begin(), argument_types.end(), [](const auto & type) { return type->onlyNull(); }))
nested_function = getImpl(name, nested_types, parameters, recursion_level);

return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
Expand Down
26 changes: 17 additions & 9 deletions dbms/src/AggregateFunctions/AggregateFunctionNull.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#include <DataTypes/DataTypeNullable.h>
#include <AggregateFunctions/AggregateFunctionNull.h>
#include <AggregateFunctions/AggregateFunctionNothing.h>
#include <AggregateFunctions/AggregateFunctionCount.h>
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/AggregateFunctionCount.h>
#include <AggregateFunctions/AggregateFunctionNothing.h>
#include <AggregateFunctions/AggregateFunctionNull.h>
#include <DataTypes/DataTypeNullable.h>


namespace DB
{

namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}

extern const String UniqRawResName;

class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinator
{
public:
Expand Down Expand Up @@ -50,7 +52,7 @@ class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinato
/// - that means - count number of calls, when all arguments are not NULL.
if (nested_function && nested_function->getName() == "count")
{
if(has_nullable_types)
if (has_nullable_types)
{
if (arguments.size() == 1)
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0]);
Expand All @@ -63,10 +65,16 @@ class AggregateFunctionCombinatorNull final : public IAggregateFunctionCombinato
}
}

if (has_null_types)
bool can_output_be_null = true;
if (nested_function && nested_function->getName() == UniqRawResName)
{
can_output_be_null = false;
}

if (has_null_types && can_output_be_null)
return std::make_shared<AggregateFunctionNothing>();

bool return_type_is_nullable = nested_function->getReturnType()->canBeInsideNullable();
bool return_type_is_nullable = can_output_be_null && nested_function->getReturnType()->canBeInsideNullable();

if (arguments.size() == 1)
{
Expand Down Expand Up @@ -100,4 +108,4 @@ void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorNull>());
}

}
} // namespace DB
19 changes: 19 additions & 0 deletions dbms/src/AggregateFunctions/AggregateFunctionUniq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,23 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
return std::make_shared<AggregateFunctionUniqVariadic<DataForVariadic, false>>(argument_types);
}

AggregateFunctionPtr createAggregateFunctionUniqRawRes(const std::string & name, const DataTypes & argument_types, const Array & params)
{
assertNoParameters(name, params);

if (argument_types.empty())
throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

/// If there are several arguments, then no tuples allowed among them.
for (const auto & type : argument_types)
if (typeid_cast<const DataTypeTuple *>(type.get()))
throw Exception("Tuple argument of function " + name + " must be the only argument", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

/// "Variadic" method also works as a fallback generic case for single argument.
return std::make_shared<AggregateFunctionUniqVariadic<AggregateFunctionUniqUniquesHashSetDataForVariadicRawRes, false, true>>(
argument_types);
}

}

void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory)
Expand All @@ -125,6 +142,8 @@ void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory)

factory.registerFunction("uniqCombined",
createAggregateFunctionUniq<AggregateFunctionUniqCombinedData, AggregateFunctionUniqCombinedData<UInt64>>);

factory.registerFunction(UniqRawResName, createAggregateFunctionUniqRawRes);
}

}
31 changes: 27 additions & 4 deletions dbms/src/AggregateFunctions/AggregateFunctionUniq.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>

#include <Columns/ColumnString.h>

#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeTuple.h>

#include <Interpreters/AggregationCommon.h>
Expand All @@ -29,6 +32,8 @@ namespace DB

/// uniq

extern const String UniqRawResName;

struct AggregateFunctionUniqUniquesHashSetData
{
using Set = UniquesHashSet<DefaultHash<UInt64>>;
Expand All @@ -46,6 +51,13 @@ struct AggregateFunctionUniqUniquesHashSetDataForVariadic
static String getName() { return "uniq"; }
};

struct AggregateFunctionUniqUniquesHashSetDataForVariadicRawRes
{
using Set = UniquesHashSet<TrivialHash, false>;
Set set;

static String getName() { return UniqRawResName; }
};

/// uniqHLL12

Expand Down Expand Up @@ -341,8 +353,9 @@ class AggregateFunctionUniq final : public IAggregateFunctionDataHelper<Data, Ag
* You can pass multiple arguments as is; You can also pass one argument - a tuple.
* But (for the possibility of efficient implementation), you can not pass several arguments, among which there are tuples.
*/
template <typename Data, bool argument_is_tuple>
class AggregateFunctionUniqVariadic final : public IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data, argument_is_tuple>>
template <typename Data, bool argument_is_tuple, bool raw_result = false>
class AggregateFunctionUniqVariadic final
: public IAggregateFunctionDataHelper<Data, AggregateFunctionUniqVariadic<Data, argument_is_tuple, raw_result>>
{
private:
static constexpr bool is_exact = std::is_same_v<Data, AggregateFunctionUniqExactData<String>>;
Expand All @@ -362,7 +375,10 @@ class AggregateFunctionUniqVariadic final : public IAggregateFunctionDataHelper<

DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeUInt64>();
if constexpr (raw_result)
return std::make_shared<DataTypeString>();
else
return std::make_shared<DataTypeUInt64>();
}

void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
Expand All @@ -387,7 +403,14 @@ class AggregateFunctionUniqVariadic final : public IAggregateFunctionDataHelper<

void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
if constexpr (raw_result)
{
WriteBufferFromOwnString buf;
serialize(place, buf);
static_cast<ColumnString &>(to).insertData(buf.str().data(), buf.count());
}
else
static_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
}

const char * getHeaderFilePath() const override { return __FILE__; }
Expand Down
5 changes: 3 additions & 2 deletions dbms/src/AggregateFunctions/UniquesHashSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct UniquesHashSetDefaultHash
};


template <typename Hash = UniquesHashSetDefaultHash>
template <typename Hash = UniquesHashSetDefaultHash, bool use_crc32 = true>
class UniquesHashSet : private HashTableAllocatorWithStackMemory<(1ULL << UNIQUES_HASH_SET_INITIAL_SIZE_DEGREE) * sizeof(UInt32)>
{
private:
Expand Down Expand Up @@ -332,7 +332,8 @@ class UniquesHashSet : private HashTableAllocatorWithStackMemory<(1ULL << UNIQUE
/** Pseudo-random remainder - in order to be not visible,
* that the number is divided by the power of two.
*/
res += (intHashCRC32(m_size) & ((1ULL << skip_degree) - 1));

res += (use_crc32 ? intHashCRC32(m_size) : intHash64(m_size)) & ((1ULL << skip_degree) - 1);

/** Correction of a systematic error due to collisions during hashing in UInt32.
* `fixed_res(res)` formula
Expand Down
80 changes: 73 additions & 7 deletions dbms/src/Debug/dbgFuncCoprocessor.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionUniq.h>
#include <Common/typeid_cast.h>
#include <DataStreams/BlocksListBlockInputStream.h>
#include <Debug/MockTiDB.h>
Expand Down Expand Up @@ -55,8 +56,63 @@ struct DAGProperties
Int32 collator = 0;
};

std::tuple<TableID, DAGSchema, tipb::DAGRequest> compileQuery(
using MakeResOutputStream = std::function<BlockInputStreamPtr(BlockInputStreamPtr)>;

std::tuple<TableID, DAGSchema, tipb::DAGRequest, MakeResOutputStream> compileQuery(
Context & context, const String & query, SchemaFetcher schema_fetcher, const DAGProperties & properties);

class UniqRawResReformatBlockOutputStream : public IProfilingBlockInputStream
{
public:
UniqRawResReformatBlockOutputStream(const BlockInputStreamPtr & in_) : in(in_) {}

String getName() const override { return "UniqRawResReformat"; }

Block getHeader() const override { return in->getHeader(); }

protected:
Block readImpl() override
{
while (true)
{
Block block = in->read();
if (!block)
return block;

size_t num_columns = block.columns();
MutableColumns columns(num_columns);
for (size_t i = 0; i < num_columns; ++i)
{
ColumnWithTypeAndName & ori_column = block.getByPosition(i);

if (std::string::npos != ori_column.name.find_first_of(UniqRawResName))
{
MutableColumnPtr mutable_holder = ori_column.column->cloneEmpty();

for (size_t j = 0; j < ori_column.column->size(); ++j)
{
Field field;
ori_column.column->get(j, field);

auto & str_ref = field.safeGet<String>();

ReadBufferFromString in(str_ref);
AggregateFunctionUniqUniquesHashSetDataForVariadicRawRes set;
set.set.read(in);

mutable_holder->insert(std::to_string(set.set.size()));
}
ori_column.column = std::move(mutable_holder);
}
}
return block;
}
}

private:
BlockInputStreamPtr in;
};

tipb::SelectResponse executeDAGRequest(Context & context, const tipb::DAGRequest & dag_request, RegionID region_id, UInt64 region_version,
UInt64 region_conf_version, Timestamp start_ts, std::vector<std::pair<DecodedTiKVKey, DecodedTiKVKey>> & key_ranges);
BlockInputStreamPtr outputDAGResponse(Context & context, const DAGSchema & schema, const tipb::SelectResponse & dag_response);
Expand Down Expand Up @@ -105,7 +161,7 @@ BlockInputStreamPtr dbgFuncDAG(Context & context, const ASTs & args)
DAGProperties properties = getDAGProperties(prop_string);
Timestamp start_ts = context.getTMTContext().getPDClient()->getTS();

auto [table_id, schema, dag_request] = compileQuery(
auto [table_id, schema, dag_request, func_wrap_output_stream] = compileQuery(
context, query,
[&](const String & database_name, const String & table_name) {
auto storage = context.getTable(database_name, table_name);
Expand Down Expand Up @@ -141,7 +197,7 @@ BlockInputStreamPtr dbgFuncDAG(Context & context, const ASTs & args)
tipb::SelectResponse dag_response
= executeDAGRequest(context, dag_request, region->id(), region->version(), region->confVer(), start_ts, key_ranges);

return outputDAGResponse(context, schema, dag_response);
return func_wrap_output_stream(outputDAGResponse(context, schema, dag_response));
}

BlockInputStreamPtr dbgFuncMockDAG(Context & context, const ASTs & args)
Expand All @@ -162,7 +218,7 @@ BlockInputStreamPtr dbgFuncMockDAG(Context & context, const ASTs & args)
prop_string = safeGet<String>(typeid_cast<const ASTLiteral &>(*args[3]).value);
DAGProperties properties = getDAGProperties(prop_string);

auto [table_id, schema, dag_request] = compileQuery(
auto [table_id, schema, dag_request, func_wrap_output_stream] = compileQuery(
context, query,
[&](const String & database_name, const String & table_name) {
return MockTiDB::instance().getTableByName(database_name, table_name)->table_info;
Expand All @@ -179,7 +235,7 @@ BlockInputStreamPtr dbgFuncMockDAG(Context & context, const ASTs & args)
tipb::SelectResponse dag_response
= executeDAGRequest(context, dag_request, region_id, region->version(), region->confVer(), start_ts, key_ranges);

return outputDAGResponse(context, schema, dag_response);
return func_wrap_output_stream(outputDAGResponse(context, schema, dag_response));
}

struct ExecutorCtx
Expand Down Expand Up @@ -412,9 +468,10 @@ void compileFilter(const DAGSchema & input, ASTPtr ast, tipb::Selection * filter
compileExpr(input, ast, cond, referred_columns, col_ref_map, collator_id);
}

std::tuple<TableID, DAGSchema, tipb::DAGRequest> compileQuery(
std::tuple<TableID, DAGSchema, tipb::DAGRequest, MakeResOutputStream> compileQuery(
Context & context, const String & query, SchemaFetcher schema_fetcher, const DAGProperties & properties)
{
MakeResOutputStream func_wrap_output_stream = [](BlockInputStreamPtr in) { return in; };
DAGSchema schema;
tipb::DAGRequest dag_request;
dag_request.set_time_zone_name(properties.tz_name);
Expand Down Expand Up @@ -653,6 +710,15 @@ std::tuple<TableID, DAGSchema, tipb::DAGRequest> compileQuery(
ft->set_tp(agg_func->children(0).field_type().tp());
ft->set_collate(properties.collator);
}
else if (func->name == UniqRawResName)
{
agg_func->set_tp(tipb::ApproxCountDistinct);
auto ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeString);
ft->set_flag(1);
func_wrap_output_stream
= [](BlockInputStreamPtr in) { return std::make_shared<UniqRawResReformatBlockOutputStream>(in); };
}
// TODO: Other agg func.
else
{
Expand Down Expand Up @@ -720,7 +786,7 @@ std::tuple<TableID, DAGSchema, tipb::DAGRequest> compileQuery(
}
}

return std::make_tuple(table_info.id, std::move(schema), std::move(dag_request));
return std::make_tuple(table_info.id, std::move(schema), std::move(dag_request), func_wrap_output_stream);
}

tipb::SelectResponse executeDAGRequest(Context & context, const tipb::DAGRequest & dag_request, RegionID region_id, UInt64 region_version,
Expand Down
6 changes: 5 additions & 1 deletion dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ String exprToString(const tipb::Expr & expr, const std::vector<NameAndTypePair>
case tipb::ExprType::Min:
case tipb::ExprType::Max:
case tipb::ExprType::First:
case tipb::ExprType::ApproxCountDistinct:
if (agg_func_map.find(expr.tp()) == agg_func_map.end())
{
throw Exception(tipb::ExprType_Name(expr.tp()) + " not supported", ErrorCodes::UNSUPPORTED_METHOD);
Expand Down Expand Up @@ -178,6 +179,7 @@ bool isAggFunctionExpr(const tipb::Expr & expr)
case tipb::ExprType::Variance:
case tipb::ExprType::JsonArrayAgg:
case tipb::ExprType::JsonObjectAgg:
case tipb::ExprType::ApproxCountDistinct:
return true;
default:
return false;
Expand Down Expand Up @@ -374,9 +376,11 @@ std::shared_ptr<TiDB::ITiDBCollator> getCollatorFromExpr(const tipb::Expr & expr
return ret;
}

extern const String UniqRawResName;

std::unordered_map<tipb::ExprType, String> agg_func_map({
{tipb::ExprType::Count, "count"}, {tipb::ExprType::Sum, "sum"}, {tipb::ExprType::Min, "min"}, {tipb::ExprType::Max, "max"},
{tipb::ExprType::First, "any"},
{tipb::ExprType::First, "any"}, {tipb::ExprType::ApproxCountDistinct, UniqRawResName},
//{tipb::ExprType::Avg, ""},
//{tipb::ExprType::GroupConcat, ""},
//{tipb::ExprType::Agg_BitAnd, ""},
Expand Down
Loading

0 comments on commit ffc9204

Please sign in to comment.