From c9626dfd2e3d962119c2218631e029280f6f940d Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Thu, 5 Sep 2024 17:00:56 -0400 Subject: [PATCH 01/25] CountMinSketch Additions --- src/commands/cmd_cms.cc | 189 ++++++++++++++++++++++++++++++++++++++++ src/types/cms.cc | 125 ++++++++++++++++++++++++++ src/types/cms.h | 74 ++++++++++++++++ src/types/redis_cms.cc | 163 ++++++++++++++++++++++++++++++++++ src/types/redis_cms.h | 48 ++++++++++ 5 files changed, 599 insertions(+) create mode 100644 src/commands/cmd_cms.cc create mode 100644 src/types/cms.cc create mode 100644 src/types/cms.h create mode 100644 src/types/redis_cms.cc create mode 100644 src/types/redis_cms.h diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc new file mode 100644 index 00000000000..9399d0cace2 --- /dev/null +++ b/src/commands/cmd_cms.cc @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include + +#include + +#include "commander.h" +#include "commands/command_parser.h" +#include "server/redis_reply.h" +#include "server/server.h" + +namespace redis { + +// CMS.INCRBY key item increment [item increment ...] +class CommandIncrBy final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + if ((args_.size() - 2) % 2 != 0) { + return Status::RedisTryAgain; + } + + redis::CMS cms(srv->storage, conn->GetNamespace()); + rocksdb::Status s; + std::unordered_map elements; + for (int i = 2; i < args_.size(); i += 2) { + std::string key = args_[i]; + uint64_t value = 0; + try { + value = std::stoull(args_[i + 1]); + } catch (const std::exception &e) { + return Status::InvalidArgument; + } + elements[key] = value; + } + + s = cms.IncrBy(args_[1], elements); + + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + return Status::OK(); + } +}; + +// CMS.INFO key +class CommandInfo final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + rocksdb::Status s; + std::unordered_map elements; + std::vector ret{}; + + s = cms.Info(args_[1], &ret); + + if (!s.ok() && !s.IsNotFound()) { + return {Status::RedisExecErr, s.ToString()}; + } + + *output = redis::Array({ + redis::BulkString("width"), + redis::Integer(ret[0]), + redis::BulkString("depth"), + redis::Integer(ret[1]), + redis::BulkString("count"), + redis::Integer(ret[2]) + }); + + return Status::OK(); + } +}; + +// CMS.INITBYDIM key width depth +class CommandInitByDim final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + rocksdb::Status s; + + try { + uint64_t width = std::stoull(args_[2]); + uint64_t depth = std::stoull(args_[3]); + + s = cms.InitByDim(args_[1], width, depth); + + } catch (const std::exception &e) { + return {Status::RedisExecErr, "Invalid dimensions: " + std::string(e.what())}; + } + + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + return Status::OK(); + } +}; + +// CMS.INITBYPROB key error probability +class CommandInitByProb final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + rocksdb::Status s; + + try { + uint64_t error = std::stoull(args_[2]); + uint64_t delta = std::stoull(args_[3]); + + s = cms.InitByDim(args_[1], error, delta); + + } catch (const std::exception &e) { + return {Status::RedisExecErr, "Invalid dimensions: " + std::string(e.what())}; + } + + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + return Status::OK(); + } +}; + +// CMS.MERGE destination numKeys source [source ...] [WEIGHTS weight [weight ...]] +// class CommandMerge final : public Commander { +// public: +// Status Execute(Server *srv, Connection *conn, std::string *output) override { + +// } +// }; + +// CMS.QUERY key item [item ...] +class CommandQuery final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + rocksdb::Status s; + + std::vector counters{}; + std::vector elements; + + for (int i = 2; i < args_.size(); ++i) { + elements.emplace_back(args_[i]); + } + + s = cms.Query(args_[1], elements, counters); + + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + std::vector output_values; + output_values.reserve(counters.size()); + for (const auto &counter : counters) { + output_values.push_back(std::to_string(counter)); + } + + *output = redis::ArrayOfBulkStrings(output_values); + + return Status::OK(); + } +}; + + +REDIS_REGISTER_COMMANDS(CMS, MakeCmdAttr("incrby", -3, "write", 0, 0, 0), + MakeCmdAttr("info", 1, "read-only", 0, 0, 0), + MakeCmdAttr("initbydim", 3, "write", 0, 0, 0), + MakeCmdAttr("initbyprob", 3, "write", 0, 0, 0), + // MakeCmdAttr("merge", -3, "write", 0, 0, 0), + MakeCmdAttr("query", -2, "read-only", 0, 0, 0), ); + +} // namespace redis \ No newline at end of file diff --git a/src/types/cms.cc b/src/types/cms.cc new file mode 100644 index 00000000000..25c168d29d7 --- /dev/null +++ b/src/types/cms.cc @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include "cms.h" +#include +#include +#include +#include + +void CMSketch::CMSDimFromProb(double error, double delta, size_t &width, size_t &depth) { + width = std::ceil(2 / error); + depth = std::ceil(std::log10(delta) / std::log10(0.5)); +} + + +size_t CMSketch::IncrBy(const char* item, size_t item_len, size_t value) { + size_t min_count = std::numeric_limits::max(); + + for (size_t i = 0; i < depth_; ++i) { + // check over the static casting the parameter to an int + uint32_t hash = HllMurMurHash64A(item, static_cast(item_len), i); + size_t loc = (hash % width_) + (i * width_); + array_[loc] += value; + if (array_[loc] < value) { + array_[loc] = UINT32_MAX; + } + min_count = std::min(min_count, static_cast(array_[loc])); + } + counter_ += value; + return min_count; +} + +size_t CMSketch::Query(const char* item, size_t item_len) const { + size_t min_count = std::numeric_limits::max(); + + for (size_t i = 0; i < depth_; ++i) { + uint32_t hash = HllMurMurHash64A(item, static_cast(item_len), i); + min_count = std::min(min_count, static_cast(array_[(hash % width_) + (i * width_)])); + } + return min_count; +} + +int CMSketch::Merge(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights) { + + if (checkOverflow(dest, quantity, src, weights) != 0) { + return -1; + } + + for (size_t i = 0; i < dest->GetDepth(); ++i) { + for (size_t j = 0; j < dest->GetWidth(); ++j) { + int64_t item_count = 0; + for (size_t k = 0; k < quantity; ++k) { + item_count += static_cast(src[k]->array_[(i * dest->GetWidth()) + j]) * weights[k]; + } + dest->GetArray()[(i * dest->GetWidth()) + j] = item_count; + } + } + + for (size_t i = 0; i < quantity; ++i) { + dest->GetCounter() += src[i]->GetCounter() * weights[i]; + } + + return 0; +} + +int CMSMergeParams(const CMSketch::MergeParams& params) { + return CMSketch::Merge(params.dest, params.num_keys, params.cms_array, params.weights); +} + +int CMSketch::checkOverflow(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights) { + int64_t item_count = 0; + int64_t cms_count = 0; + size_t width = dest->GetWidth(); + size_t depth = dest->GetDepth(); + + for (size_t i = 0; i < depth; ++i) { + for (size_t j = 0; j < width; ++j) { + item_count = 0; + for (size_t k = 0; k < quantity; ++k) { + int64_t mul = 0; + + if (__builtin_mul_overflow(src[k]->GetArray()[(i * width) + j], weights[k], &mul) || + (__builtin_add_overflow(item_count, mul, &item_count))) { + return -1; + } + } + + if (item_count < 0 || item_count > UINT32_MAX) { + return -1; + } + } + } + + for (size_t i = 0; i < quantity; ++i) { + int64_t mul = 0; + + if (__builtin_mul_overflow(src[i]->GetCounter(), weights[i], &mul) || + (__builtin_add_overflow(cms_count, mul, &cms_count))) { + return -1; + } + } + + if (cms_count < 0) { + return -1; + } + + return 0; +} \ No newline at end of file diff --git a/src/types/cms.h b/src/types/cms.h new file mode 100644 index 00000000000..08040692c39 --- /dev/null +++ b/src/types/cms.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include +#include +#include +#include "vendor/murmurhash2.h" +#include + + +class CMSketch { + public: + explicit CMSketch(uint32_t width = 0, uint32_t depth = 0, uint64_t counter = 0, std::vector array = {}) + : width_(width), depth_(depth), counter_(counter), array_(array.empty() ? std::vector(width * depth, 0) : std::move(array)) {} + + ~CMSketch() = default; + + static CMSketch* NewCMSketch(size_t width, size_t depth) { return new CMSketch(width, depth); } + + static void CMSDimFromProb(double error, double delta, size_t &width, size_t &depth); + + size_t IncrBy(const char* item, size_t item_len, size_t value); + + size_t Query(const char* item, size_t item_len) const; + + static int Merge(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights); + + struct MergeParams { + CMSketch* dest; + size_t num_keys; + std::vector cms_array; + std::vector weights; + }; + + int CMSMergeParams(const MergeParams& params); + + uint64_t& GetCounter() { return counter_; } + std::vector& GetArray() { return array_; } + + const uint64_t& GetCounter() const { return counter_; } + const std::vector& GetArray() const { return array_; } + + size_t GetWidth() const { return width_; } + size_t GetDepth() const { return depth_; } + + private: + size_t width_; + size_t depth_; + uint64_t counter_; + std::vector array_; + + static uint32_t hllMurMurHash64A(const char* item, size_t item_len, size_t i) { return HllMurMurHash64A(item, static_cast(item_len), i); } + + static int checkOverflow(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights); +}; \ No newline at end of file diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc new file mode 100644 index 00000000000..9990df78f28 --- /dev/null +++ b/src/types/redis_cms.cc @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + + #include "redis_cms.h" + + #include "cms.h" + #include + + #include "cms.h" + #include "vendor/murmurhash2.h" + +namespace redis { + +rocksdb::Status CMS::GetMetadata(Database::GetOptions get_options, const Slice &ns_key, + CountMinSketchMetadata *metadata) { + return Database::GetMetadata(get_options, {kRedisCountMinSketch}, ns_key, metadata); +} + +rocksdb::Status CMS::IncrBy(const Slice &user_key, const std::unordered_map &elements) { + std::string ns_key = AppendNamespacePrefix(user_key); + + LockGuard guard(storage_->GetLockManager(), ns_key); + CountMinSketchMetadata metadata{}; + rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + if (!s.ok() && !s.IsNotFound()) { return s; } + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisCountMinSketch); + batch->PutLogData(log_data.Encode()); + + CMSketch cms(metadata.width, metadata.depth, metadata.counter,metadata.array); + + if (elements.empty()) { + return rocksdb::Status::OK(); + } + + for (auto &element : elements) { + cms.IncrBy(element.first.data(), element.first.size(), element.second); + metadata.counter += element.second; + } + + metadata.array = std::move(cms.GetArray()); + + std::string bytes; + metadata.Encode(&bytes); + batch->Put(metadata_cf_handle_, ns_key, bytes); + + return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + +rocksdb::Status CMS::Info(const Slice &user_key, std::vector *ret) { + std::string ns_key = AppendNamespacePrefix(user_key); + + LockGuard guard(storage_->GetLockManager(), ns_key); + CountMinSketchMetadata metadata{}; + rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + if (!s.ok() && !s.IsNotFound()) { return s; } + + ret->emplace_back(metadata.width); + ret->emplace_back(metadata.depth); + ret->emplace_back(metadata.counter); + + return rocksdb::Status::OK(); +}; + +rocksdb::Status CMS::InitByDim(const Slice &user_key, uint32_t width, uint32_t depth) { + std::string ns_key = AppendNamespacePrefix(user_key); + + LockGuard guard(storage_->GetLockManager(), ns_key); + CountMinSketchMetadata metadata{}; + + rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + if (!s.IsNotFound()) { + return s; + } + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisCountMinSketch); + batch->PutLogData(log_data.Encode()); + + metadata.counter = 0; + metadata.width = width; + metadata.depth = depth; + metadata.array = std::vector(width*depth, 0); + + std::string bytes; + metadata.Encode(&bytes); + batch->Put(metadata_cf_handle_, ns_key, bytes); + + return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +}; + +rocksdb::Status CMS::InitByProb(const Slice &user_key, double error, double delta) { + std::string ns_key = AppendNamespacePrefix(user_key); + + LockGuard guard(storage_->GetLockManager(), ns_key); + CountMinSketchMetadata metadata{}; + + rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + if (!s.IsNotFound()) { + return s; + } + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisCountMinSketch); + batch->PutLogData(log_data.Encode()); + + CMSketch cms; + size_t width = 0; + size_t depth = 0; + cms.CMSDimFromProb(error, delta, width, depth); + + metadata.width = width; + metadata.depth = depth; + metadata.counter = cms.GetCounter(); + metadata.array = std::move(cms.GetArray()); + + std::string bytes; + metadata.Encode(&bytes); + batch->Put(metadata_cf_handle_, ns_key, bytes); + + return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +}; + + +rocksdb::Status CMS::Query(const Slice &user_key, const std::vector &elements, std::vector &counters) { + std::string ns_key = AppendNamespacePrefix(user_key); + + LockGuard guard(storage_->GetLockManager(), ns_key); + CountMinSketchMetadata metadata{}; + + rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + if (!s.ok() && !s.IsNotFound()) { return s; } + if (s.IsNotFound()) { + counters.assign(elements.size(), 0); + return rocksdb::Status::NotFound(); + } + + CMSketch cms(metadata.width, metadata.depth, metadata.counter, metadata.array); + + for (auto &element : elements) { + counters.push_back(cms.Query(element.data(), element.size())); + } + return rocksdb::Status::OK(); +}; + +} // namespace redis + diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h new file mode 100644 index 00000000000..5b4d777f3ee --- /dev/null +++ b/src/types/redis_cms.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#pragma once + +#include "cms.h" +#include "storage/redis_db.h" +#include "storage/redis_metadata.h" + +namespace redis { + +class CMS : public Database { + public: + explicit CMS(engine::Storage *storage, const std::string &ns): Database(storage, ns) {} + + rocksdb::Status IncrBy(const Slice &user_key, const std::unordered_map &elements); + rocksdb::Status Info(const Slice &user_key, std::vector *ret); + rocksdb::Status InitByDim(const Slice &user_key, uint32_t width, uint32_t depth); + rocksdb::Status InitByProb(const Slice &user_key, double error, double delta); + // rocksdb::Status Merge(const std::vector &user_keys, const std::vector &source_user_keys); + rocksdb::Status Query(const Slice &user_key, const std::vector &elements, std::vector &counters); + + private: + [[nodiscard]] rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice &ns_key, + CountMinSketchMetadata *metadata); + + // [[nodiscard]] rocksdb::Status mergeUserKeys(Database::GetOptions get_options, const std::vector &user_keys, + // std::vector *register_segments); +}; + +} // namespace redis From bb48c5b18d2f77af52243096afd7d57ba0a73ea0 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Fri, 6 Sep 2024 22:50:48 -0400 Subject: [PATCH 02/25] adding redis_cms --- src/commands/cmd_cms.cc | 236 ++++++++++++++------------------ src/commands/command_parser.h | 2 +- src/commands/commander.h | 1 + src/storage/redis_metadata.cc | 41 +++++- src/storage/redis_metadata.h | 19 ++- src/types/cms.cc | 155 +++++++++++---------- src/types/cms.h | 56 ++++---- src/types/redis_cms.cc | 65 +++++---- src/types/redis_cms.h | 13 +- tests/cppunit/types/cms_test.cc | 132 ++++++++++++++++++ 10 files changed, 449 insertions(+), 271 deletions(-) create mode 100644 tests/cppunit/types/cms_test.cc diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc index 9399d0cace2..3a79577ce18 100644 --- a/src/commands/cmd_cms.cc +++ b/src/commands/cmd_cms.cc @@ -19,7 +19,6 @@ */ #include - #include #include "commander.h" @@ -29,161 +28,136 @@ namespace redis { -// CMS.INCRBY key item increment [item increment ...] -class CommandIncrBy final : public Commander { - public: - Status Execute(Server *srv, Connection *conn, std::string *output) override { - if ((args_.size() - 2) % 2 != 0) { - return Status::RedisTryAgain; - } - - redis::CMS cms(srv->storage, conn->GetNamespace()); - rocksdb::Status s; - std::unordered_map elements; - for (int i = 2; i < args_.size(); i += 2) { - std::string key = args_[i]; - uint64_t value = 0; - try { - value = std::stoull(args_[i + 1]); - } catch (const std::exception &e) { - return Status::InvalidArgument; - } - elements[key] = value; +/// CMS.INCRBY key item increment [item increment ...] +class CommandCMSIncrBy final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + if ((args_.size() - 2) % 2 != 0) { + return Status::RedisTryAgain; + } + redis::CMS cms(srv->storage, conn->GetNamespace()); + rocksdb::Status s; + std::unordered_map elements; + for (size_t i = 2; i < args_.size(); i += 2) { + std::string key = args_[i]; + uint64_t value = 0; + try { + value = std::stoull(args_[i + 1]); + } catch (const std::exception &e) { + return Status::InvalidArgument; } + elements[key] = value; + } - s = cms.IncrBy(args_[1], elements); - - if (!s.ok()) { - return {Status::RedisExecErr, s.ToString()}; - } - return Status::OK(); + s = cms.IncrBy(args_[1], elements); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; } -}; -// CMS.INFO key -class CommandInfo final : public Commander { - public: - Status Execute(Server *srv, Connection *conn, std::string *output) override { - redis::CMS cms(srv->storage, conn->GetNamespace()); - rocksdb::Status s; - std::unordered_map elements; - std::vector ret{}; + *output = redis::SimpleString("OK"); + return Status::OK(); + } +}; - s = cms.Info(args_[1], &ret); +/// CMS.INFO key +class CommandCMSInfo final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + rocksdb::Status s; + std::unordered_map elements; + std::vector ret{}; - if (!s.ok() && !s.IsNotFound()) { - return {Status::RedisExecErr, s.ToString()}; - } + s = cms.Info(args_[1], &ret); - *output = redis::Array({ - redis::BulkString("width"), - redis::Integer(ret[0]), - redis::BulkString("depth"), - redis::Integer(ret[1]), - redis::BulkString("count"), - redis::Integer(ret[2]) - }); - - return Status::OK(); + if (s.IsNotFound()) { + return {Status::RedisExecErr, s.ToString()}; } -}; -// CMS.INITBYDIM key width depth -class CommandInitByDim final : public Commander { - public: - Status Execute(Server *srv, Connection *conn, std::string *output) override { - redis::CMS cms(srv->storage, conn->GetNamespace()); - rocksdb::Status s; - - try { - uint64_t width = std::stoull(args_[2]); - uint64_t depth = std::stoull(args_[3]); - - s = cms.InitByDim(args_[1], width, depth); + if (!s.ok() && !s.IsNotFound()) { + return {Status::RedisExecErr, s.ToString()}; + } - } catch (const std::exception &e) { - return {Status::RedisExecErr, "Invalid dimensions: " + std::string(e.what())}; - } + *output = redis::Array({redis::BulkString("width"), redis::Integer(ret[0]), redis::BulkString("depth"), + redis::Integer(ret[1]), redis::BulkString("count"), redis::Integer(ret[2])}); - if (!s.ok()) { - return {Status::RedisExecErr, s.ToString()}; - } + return Status::OK(); + } +}; - return Status::OK(); +/// CMS.INITBYDIM key width depth +class CommandCMSInitByDim final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + rocksdb::Status s; + uint64_t width = std::stoull(args_[2]); + uint64_t depth = std::stoull(args_[3]); + + s = cms.InitByDim(args_[1], width, depth); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; } -}; -// CMS.INITBYPROB key error probability -class CommandInitByProb final : public Commander { - public: - Status Execute(Server *srv, Connection *conn, std::string *output) override { - redis::CMS cms(srv->storage, conn->GetNamespace()); - rocksdb::Status s; + *output = redis::SimpleString("OK"); + return Status::OK(); + } +}; - try { - uint64_t error = std::stoull(args_[2]); - uint64_t delta = std::stoull(args_[3]); +/// CMS.INITBYPROB key error probability +class CommandCMSInitByProb final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + rocksdb::Status s; + double error = std::stod(args_[2]); + double delta = std::stod(args_[3]); + + s = cms.InitByProb(args_[1], error, delta); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } - s = cms.InitByDim(args_[1], error, delta); + *output = redis::SimpleString("OK"); + return Status::OK(); + } +}; - } catch (const std::exception &e) { - return {Status::RedisExecErr, "Invalid dimensions: " + std::string(e.what())}; - } +/// CMS.QUERY key item [item ...] +class CommandCMSQuery final : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + rocksdb::Status s; - if (!s.ok()) { - return {Status::RedisExecErr, s.ToString()}; - } + std::vector counters{}; + std::vector elements; - return Status::OK(); + for (size_t i = 2; i < args_.size(); ++i) { + elements.emplace_back(args_[i]); } -}; - -// CMS.MERGE destination numKeys source [source ...] [WEIGHTS weight [weight ...]] -// class CommandMerge final : public Commander { -// public: -// Status Execute(Server *srv, Connection *conn, std::string *output) override { - -// } -// }; - -// CMS.QUERY key item [item ...] -class CommandQuery final : public Commander { - public: - Status Execute(Server *srv, Connection *conn, std::string *output) override { - redis::CMS cms(srv->storage, conn->GetNamespace()); - rocksdb::Status s; - - std::vector counters{}; - std::vector elements; - - for (int i = 2; i < args_.size(); ++i) { - elements.emplace_back(args_[i]); - } - s = cms.Query(args_[1], elements, counters); - - if (!s.ok()) { - return {Status::RedisExecErr, s.ToString()}; - } + s = cms.Query(args_[1], elements, counters); - std::vector output_values; - output_values.reserve(counters.size()); - for (const auto &counter : counters) { - output_values.push_back(std::to_string(counter)); - } - - *output = redis::ArrayOfBulkStrings(output_values); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } - return Status::OK(); + std::vector output_values; + output_values.reserve(counters.size()); + for (const auto &counter : counters) { + output_values.push_back(std::to_string(counter)); } -}; + *output = redis::ArrayOfBulkStrings(output_values); -REDIS_REGISTER_COMMANDS(CMS, MakeCmdAttr("incrby", -3, "write", 0, 0, 0), - MakeCmdAttr("info", 1, "read-only", 0, 0, 0), - MakeCmdAttr("initbydim", 3, "write", 0, 0, 0), - MakeCmdAttr("initbyprob", 3, "write", 0, 0, 0), - // MakeCmdAttr("merge", -3, "write", 0, 0, 0), - MakeCmdAttr("query", -2, "read-only", 0, 0, 0), ); + return Status::OK(); + } +}; +REDIS_REGISTER_COMMANDS(CMS, MakeCmdAttr("cms.incrby", -4, "write", 0, 0, 0), + MakeCmdAttr("cms.info", 2, "read-only", 0, 0, 0), + MakeCmdAttr("cms.initbydim", 4, "write", 0, 0, 0), + MakeCmdAttr("cms.initbyprob", 4, "write", 0, 0, 0), + MakeCmdAttr("cms.query", -3, "read-only", 0, 0, 0), ); } // namespace redis \ No newline at end of file diff --git a/src/commands/command_parser.h b/src/commands/command_parser.h index a4e06e1c563..673a733e362 100644 --- a/src/commands/command_parser.h +++ b/src/commands/command_parser.h @@ -33,7 +33,7 @@ template struct MoveIterator : Iter { - explicit MoveIterator(Iter iter) : Iter(iter){}; + explicit MoveIterator(Iter iter) : Iter(iter) {}; typename Iter::value_type&& operator*() const { return std::move(this->Iter::operator*()); } }; diff --git a/src/commands/commander.h b/src/commands/commander.h index 39d55bfc6f2..d4aa18f1c36 100644 --- a/src/commands/commander.h +++ b/src/commands/commander.h @@ -73,6 +73,7 @@ enum class CommandCategory : uint8_t { Bit, BloomFilter, Cluster, + CMS, Function, Geo, Hash, diff --git a/src/storage/redis_metadata.cc b/src/storage/redis_metadata.cc index 76403faaef3..1dcf8730ac3 100644 --- a/src/storage/redis_metadata.cc +++ b/src/storage/redis_metadata.cc @@ -329,7 +329,8 @@ bool Metadata::ExpireAt(uint64_t expired_ts) const { bool Metadata::IsSingleKVType() const { return Type() == kRedisString || Type() == kRedisJson; } bool Metadata::IsEmptyableType() const { - return IsSingleKVType() || Type() == kRedisStream || Type() == kRedisBloomFilter || Type() == kRedisHyperLogLog; + return IsSingleKVType() || Type() == kRedisStream || Type() == kRedisBloomFilter || Type() == kRedisHyperLogLog || + Type() == kRedisCountMinSketch; } bool Metadata::Expired() const { return ExpireAt(util::GetTimeStampMS()); } @@ -495,3 +496,41 @@ rocksdb::Status HyperLogLogMetadata::Decode(Slice *input) { return rocksdb::Status::OK(); } + +void CountMinSketchMetadata::Encode(std::string *dst) const { + Metadata::Encode(dst); + PutFixed32(dst, width); + PutFixed32(dst, depth); + PutFixed64(dst, counter); + for (const auto &count : array) { + PutFixed32(dst, count); + } +} + +rocksdb::Status CountMinSketchMetadata::Decode(Slice *input) { + if (auto s = Metadata::Decode(input); !s.ok()) { + return s; + } + if (!GetFixed32(input, &width)) { + return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); + } + if (!GetFixed32(input, &depth)) { + return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); + } + if (!GetFixed64(input, &counter)) { + return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); + } + + size_t array_size = width * depth; + array.resize(array_size); + + for (size_t i = 0; i < array_size; ++i) { + uint32_t count = 0; + if (!GetFixed32(input, &count)) { + return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); + } + array[i] = count; + } + + return rocksdb::Status::OK(); +} diff --git a/src/storage/redis_metadata.h b/src/storage/redis_metadata.h index 5590609be37..f7aa8e091a0 100644 --- a/src/storage/redis_metadata.h +++ b/src/storage/redis_metadata.h @@ -50,6 +50,7 @@ enum RedisType : uint8_t { kRedisBloomFilter = 9, kRedisJson = 10, kRedisHyperLogLog = 11, + kRedisCountMinSketch = 12, }; struct RedisTypes { @@ -91,9 +92,9 @@ enum RedisCommand { kRedisCmdLMove, }; -const std::vector RedisTypeNames = {"none", "string", "hash", "list", - "set", "zset", "bitmap", "sortedint", - "stream", "MBbloom--", "ReJSON-RL", "hyperloglog"}; +const std::vector RedisTypeNames = {"none", "string", "hash", "list", "set", + "zset", "bitmap", "sortedint", "stream", "MBbloom--", + "ReJSON-RL", "hyperloglog", "countminsketch"}; constexpr const char *kErrMsgWrongType = "WRONGTYPE Operation against a key holding the wrong kind of value"; constexpr const char *kErrMsgKeyExpired = "the key was expired"; @@ -335,3 +336,15 @@ class HyperLogLogMetadata : public Metadata { EncodeType encode_type = EncodeType::DENSE; }; + +class CountMinSketchMetadata : public Metadata { + public: + uint32_t width; + uint32_t depth; + uint64_t counter; + std::vector array; + + explicit CountMinSketchMetadata(bool generate_version = true) : Metadata(kRedisCountMinSketch, generate_version) {} + void Encode(std::string *dst) const override; + rocksdb::Status Decode(Slice *input) override; +}; diff --git a/src/types/cms.cc b/src/types/cms.cc index 25c168d29d7..66809ea060e 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -19,107 +19,110 @@ */ #include "cms.h" -#include + +#include #include #include -#include +#include -void CMSketch::CMSDimFromProb(double error, double delta, size_t &width, size_t &depth) { - width = std::ceil(2 / error); - depth = std::ceil(std::log10(delta) / std::log10(0.5)); -} +#include "glog/logging.h" +void CMSketch::CMSDimFromProb(double error, double delta, uint32_t& width, uint32_t& depth) { + LOG(INFO) << error << delta; + width = std::ceil(2 / error); + depth = std::ceil(std::log10(delta) / std::log10(0.5)); +} size_t CMSketch::IncrBy(const char* item, size_t item_len, size_t value) { - size_t min_count = std::numeric_limits::max(); - - for (size_t i = 0; i < depth_; ++i) { - // check over the static casting the parameter to an int - uint32_t hash = HllMurMurHash64A(item, static_cast(item_len), i); - size_t loc = (hash % width_) + (i * width_); - array_[loc] += value; - if (array_[loc] < value) { - array_[loc] = UINT32_MAX; - } - min_count = std::min(min_count, static_cast(array_[loc])); + size_t min_count = std::numeric_limits::max(); + + for (size_t i = 0; i < depth_; ++i) { + uint32_t hash = HllMurMurHash64A(item, static_cast(item_len), i); + size_t loc = (hash % width_) + (i * width_); + array_[loc] += value; + if (array_[loc] < value) { + array_[loc] = UINT32_MAX; } - counter_ += value; - return min_count; + min_count = std::min(min_count, static_cast(array_[loc])); + } + counter_ += value; + return min_count; } size_t CMSketch::Query(const char* item, size_t item_len) const { - size_t min_count = std::numeric_limits::max(); + size_t min_count = std::numeric_limits::max(); - for (size_t i = 0; i < depth_; ++i) { - uint32_t hash = HllMurMurHash64A(item, static_cast(item_len), i); - min_count = std::min(min_count, static_cast(array_[(hash % width_) + (i * width_)])); - } - return min_count; + for (size_t i = 0; i < depth_; ++i) { + uint32_t hash = HllMurMurHash64A(item, static_cast(item_len), i); + min_count = std::min(min_count, static_cast(array_[(hash % width_) + (i * width_)])); + } + return min_count; } -int CMSketch::Merge(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights) { - - if (checkOverflow(dest, quantity, src, weights) != 0) { - return -1; +int CMSketch::Merge(CMSketch* dest, size_t quantity, const std::vector& src, + const std::vector& weights) { + if (checkOverflow(dest, quantity, src, weights) != 0) { + return -1; + } + + for (size_t i = 0; i < dest->GetDepth(); ++i) { + for (size_t j = 0; j < dest->GetWidth(); ++j) { + int64_t item_count = 0; + for (size_t k = 0; k < quantity; ++k) { + item_count += static_cast(src[k]->array_[(i * dest->GetWidth()) + j]) * weights[k]; + } + dest->GetArray()[(i * dest->GetWidth()) + j] = item_count; } + } - for (size_t i = 0; i < dest->GetDepth(); ++i) { - for (size_t j = 0; j < dest->GetWidth(); ++j) { - int64_t item_count = 0; - for (size_t k = 0; k < quantity; ++k) { - item_count += static_cast(src[k]->array_[(i * dest->GetWidth()) + j]) * weights[k]; - } - dest->GetArray()[(i * dest->GetWidth()) + j] = item_count; - } - } - - for (size_t i = 0; i < quantity; ++i) { - dest->GetCounter() += src[i]->GetCounter() * weights[i]; - } + for (size_t i = 0; i < quantity; ++i) { + dest->GetCounter() += src[i]->GetCounter() * weights[i]; + } - return 0; + return 0; } -int CMSMergeParams(const CMSketch::MergeParams& params) { - return CMSketch::Merge(params.dest, params.num_keys, params.cms_array, params.weights); +int CMSMergeParams(const CMSketch::MergeParams& params) { + return CMSketch::Merge(params.dest, params.num_keys, params.cms_array, params.weights); } -int CMSketch::checkOverflow(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights) { - int64_t item_count = 0; - int64_t cms_count = 0; - size_t width = dest->GetWidth(); - size_t depth = dest->GetDepth(); - - for (size_t i = 0; i < depth; ++i) { - for (size_t j = 0; j < width; ++j) { - item_count = 0; - for (size_t k = 0; k < quantity; ++k) { - int64_t mul = 0; - - if (__builtin_mul_overflow(src[k]->GetArray()[(i * width) + j], weights[k], &mul) || - (__builtin_add_overflow(item_count, mul, &item_count))) { - return -1; - } - } - - if (item_count < 0 || item_count > UINT32_MAX) { - return -1; - } - } - } - - for (size_t i = 0; i < quantity; ++i) { +int CMSketch::checkOverflow(CMSketch* dest, size_t quantity, const std::vector& src, + const std::vector& weights) { + int64_t item_count = 0; + int64_t cms_count = 0; + size_t width = dest->GetWidth(); + size_t depth = dest->GetDepth(); + + for (size_t i = 0; i < depth; ++i) { + for (size_t j = 0; j < width; ++j) { + item_count = 0; + for (size_t k = 0; k < quantity; ++k) { int64_t mul = 0; - if (__builtin_mul_overflow(src[i]->GetCounter(), weights[i], &mul) || - (__builtin_add_overflow(cms_count, mul, &cms_count))) { - return -1; + if (__builtin_mul_overflow(src[k]->GetArray()[(i * width) + j], weights[k], &mul) || + (__builtin_add_overflow(item_count, mul, &item_count))) { + return -1; } - } + } - if (cms_count < 0) { + if (item_count < 0 || item_count > UINT32_MAX) { return -1; + } } + } + + for (size_t i = 0; i < quantity; ++i) { + int64_t mul = 0; + + if (__builtin_mul_overflow(src[i]->GetCounter(), weights[i], &mul) || + (__builtin_add_overflow(cms_count, mul, &cms_count))) { + return -1; + } + } + + if (cms_count < 0) { + return -1; + } - return 0; + return 0; } \ No newline at end of file diff --git a/src/types/cms.h b/src/types/cms.h index 08040692c39..a95295d0854 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -21,38 +21,43 @@ #pragma once #include + #include -#include -#include "vendor/murmurhash2.h" #include +#include +#include "vendor/murmurhash2.h" class CMSketch { - public: - explicit CMSketch(uint32_t width = 0, uint32_t depth = 0, uint64_t counter = 0, std::vector array = {}) - : width_(width), depth_(depth), counter_(counter), array_(array.empty() ? std::vector(width * depth, 0) : std::move(array)) {} + public: + explicit CMSketch(uint32_t width = 0, uint32_t depth = 0, uint64_t counter = 0, std::vector array = {}) + : width_(width), + depth_(depth), + counter_(counter), + array_(array.empty() ? std::vector(width * depth, 0) : std::move(array)) {} - ~CMSketch() = default; + ~CMSketch() = default; - static CMSketch* NewCMSketch(size_t width, size_t depth) { return new CMSketch(width, depth); } + static CMSketch* NewCMSketch(size_t width, size_t depth) { return new CMSketch(width, depth); } - static void CMSDimFromProb(double error, double delta, size_t &width, size_t &depth); + static void CMSDimFromProb(double error, double delta, uint32_t& width, uint32_t& depth); - size_t IncrBy(const char* item, size_t item_len, size_t value); + size_t IncrBy(const char* item, size_t item_len, size_t value); - size_t Query(const char* item, size_t item_len) const; + size_t Query(const char* item, size_t item_len) const; - static int Merge(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights); + static int Merge(CMSketch* dest, size_t quantity, const std::vector& src, + const std::vector& weights); - struct MergeParams { - CMSketch* dest; - size_t num_keys; - std::vector cms_array; - std::vector weights; + struct MergeParams { + CMSketch* dest; + size_t num_keys; + std::vector cms_array; + std::vector weights; }; int CMSMergeParams(const MergeParams& params); - + uint64_t& GetCounter() { return counter_; } std::vector& GetArray() { return array_; } @@ -62,13 +67,16 @@ class CMSketch { size_t GetWidth() const { return width_; } size_t GetDepth() const { return depth_; } - private: - size_t width_; - size_t depth_; - uint64_t counter_; - std::vector array_; + private: + size_t width_; + size_t depth_; + uint64_t counter_; + std::vector array_; - static uint32_t hllMurMurHash64A(const char* item, size_t item_len, size_t i) { return HllMurMurHash64A(item, static_cast(item_len), i); } + static uint32_t hllMurMurHash64A(const char* item, size_t item_len, size_t i) { + return HllMurMurHash64A(item, static_cast(item_len), i); + } - static int checkOverflow(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights); + static int checkOverflow(CMSketch* dest, size_t quantity, const std::vector& src, + const std::vector& weights); }; \ No newline at end of file diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index 9990df78f28..c38a1a4fbf0 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -18,35 +18,35 @@ * */ - #include "redis_cms.h" - - #include "cms.h" - #include +#include "redis_cms.h" - #include "cms.h" - #include "vendor/murmurhash2.h" +#include + +#include "cms.h" namespace redis { rocksdb::Status CMS::GetMetadata(Database::GetOptions get_options, const Slice &ns_key, - CountMinSketchMetadata *metadata) { + CountMinSketchMetadata *metadata) { return Database::GetMetadata(get_options, {kRedisCountMinSketch}, ns_key, metadata); } rocksdb::Status CMS::IncrBy(const Slice &user_key, const std::unordered_map &elements) { std::string ns_key = AppendNamespacePrefix(user_key); - + LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); - if (!s.ok() && !s.IsNotFound()) { return s; } + if (!s.ok() && !s.IsNotFound()) { + return s; + } auto batch = storage_->GetWriteBatchBase(); WriteBatchLogData log_data(kRedisCountMinSketch); batch->PutLogData(log_data.Encode()); - CMSketch cms(metadata.width, metadata.depth, metadata.counter,metadata.array); - + CMSketch cms(metadata.width, metadata.depth, metadata.counter, metadata.array); + if (elements.empty()) { return rocksdb::Status::OK(); } @@ -67,11 +67,14 @@ rocksdb::Status CMS::IncrBy(const Slice &user_key, const std::unordered_map *ret) { std::string ns_key = AppendNamespacePrefix(user_key); - + LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); - if (!s.ok() && !s.IsNotFound()) { return s; } + + if (!s.ok() || s.IsNotFound()) { + return rocksdb::Status::NotFound(); + } ret->emplace_back(metadata.width); ret->emplace_back(metadata.depth); @@ -82,22 +85,24 @@ rocksdb::Status CMS::Info(const Slice &user_key, std::vector *ret) { rocksdb::Status CMS::InitByDim(const Slice &user_key, uint32_t width, uint32_t depth) { std::string ns_key = AppendNamespacePrefix(user_key); - + LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + if (!s.IsNotFound()) { return s; } + auto batch = storage_->GetWriteBatchBase(); WriteBatchLogData log_data(kRedisCountMinSketch); batch->PutLogData(log_data.Encode()); - metadata.counter = 0; metadata.width = width; metadata.depth = depth; - metadata.array = std::vector(width*depth, 0); + metadata.counter = 0; + metadata.array = std::vector(width * depth, 0); std::string bytes; metadata.Encode(&bytes); @@ -108,7 +113,7 @@ rocksdb::Status CMS::InitByDim(const Slice &user_key, uint32_t width, uint32_t d rocksdb::Status CMS::InitByProb(const Slice &user_key, double error, double delta) { std::string ns_key = AppendNamespacePrefix(user_key); - + LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; @@ -121,14 +126,15 @@ rocksdb::Status CMS::InitByProb(const Slice &user_key, double error, double delt batch->PutLogData(log_data.Encode()); CMSketch cms; - size_t width = 0; - size_t depth = 0; + uint32_t width = 0; + uint32_t depth = 0; cms.CMSDimFromProb(error, delta, width, depth); metadata.width = width; metadata.depth = depth; - metadata.counter = cms.GetCounter(); - metadata.array = std::move(cms.GetArray()); + metadata.counter = 0; + metadata.array = std::vector(width * depth, 0); + ; std::string bytes; metadata.Encode(&bytes); @@ -137,18 +143,20 @@ rocksdb::Status CMS::InitByProb(const Slice &user_key, double error, double delt return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); }; - -rocksdb::Status CMS::Query(const Slice &user_key, const std::vector &elements, std::vector &counters) { +rocksdb::Status CMS::Query(const Slice &user_key, const std::vector &elements, + std::vector &counters) { std::string ns_key = AppendNamespacePrefix(user_key); - + LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); - if (!s.ok() && !s.IsNotFound()) { return s; } + if (s.IsNotFound()) { - counters.assign(elements.size(), 0); - return rocksdb::Status::NotFound(); + counters.assign(elements.size(), 0); + return rocksdb::Status::OK(); + } else if (!s.ok()) { + return s; } CMSketch cms(metadata.width, metadata.depth, metadata.counter, metadata.array); @@ -159,5 +167,4 @@ rocksdb::Status CMS::Query(const Slice &user_key, const std::vector return rocksdb::Status::OK(); }; -} // namespace redis - +} // namespace redis diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h index 5b4d777f3ee..4139de1a29e 100644 --- a/src/types/redis_cms.h +++ b/src/types/redis_cms.h @@ -20,7 +20,7 @@ #pragma once -#include "cms.h" +#include "cms.h" #include "storage/redis_db.h" #include "storage/redis_metadata.h" @@ -28,21 +28,22 @@ namespace redis { class CMS : public Database { public: - explicit CMS(engine::Storage *storage, const std::string &ns): Database(storage, ns) {} + explicit CMS(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {} rocksdb::Status IncrBy(const Slice &user_key, const std::unordered_map &elements); rocksdb::Status Info(const Slice &user_key, std::vector *ret); rocksdb::Status InitByDim(const Slice &user_key, uint32_t width, uint32_t depth); rocksdb::Status InitByProb(const Slice &user_key, double error, double delta); - // rocksdb::Status Merge(const std::vector &user_keys, const std::vector &source_user_keys); - rocksdb::Status Query(const Slice &user_key, const std::vector &elements, std::vector &counters); + rocksdb::Status Query(const Slice &user_key, const std::vector &elements, + std::vector &counters); private: [[nodiscard]] rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice &ns_key, CountMinSketchMetadata *metadata); - // [[nodiscard]] rocksdb::Status mergeUserKeys(Database::GetOptions get_options, const std::vector &user_keys, - // std::vector *register_segments); + // TODO (jonathanc-n) + [[nodiscard]] rocksdb::Status mergeUserKeys(Database::GetOptions get_options, const std::vector &user_keys, + std::vector *register_segments); }; } // namespace redis diff --git a/tests/cppunit/types/cms_test.cc b/tests/cppunit/types/cms_test.cc new file mode 100644 index 00000000000..274c3c72166 --- /dev/null +++ b/tests/cppunit/types/cms_test.cc @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + * + */ + +#include + +#include + +#include "test_base.h" +#include "types/redis_cms.h" + +class RedisCMSketchTest : public TestBase { + protected: + explicit RedisCMSketchTest() : TestBase() { cms_ = std::make_unique(storage_.get(), "cms_ns"); } + ~RedisCMSketchTest() override = default; + + void SetUp() override { + TestBase::SetUp(); + [[maybe_unused]] auto s = cms_->Del("cms"); + for (int x = 1; x <= 3; x++) { + s = cms_->Del("cms" + std::to_string(x)); + } + } + + void TearDown() override { + TestBase::TearDown(); + [[maybe_unused]] auto s = cms_->Del("cms"); + for (int x = 1; x <= 3; x++) { + s = cms_->Del("cms" + std::to_string(x)); + } + } + + std::unique_ptr cms_; +}; + +TEST_F(RedisCMSketchTest, CMSInitByDim) { + ASSERT_TRUE(cms_->InitByDim("cms", 100, 5).ok()); + std::vector info; + ASSERT_TRUE(cms_->Info("cms", &info).ok()); + ASSERT_EQ(info[0], 100); + ASSERT_EQ(info[1], 5); + ASSERT_EQ(info[2], 0); +} + +TEST_F(RedisCMSketchTest, CMSIncrBy) { + std::unordered_map elements = {{"apple", 2}, {"banana", 3}, {"cherry", 1}}; + ASSERT_TRUE(cms_->InitByDim("cms", 100, 5).ok()); + ASSERT_TRUE(cms_->IncrBy("cms", elements).ok()); + + std::vector counts; + ASSERT_TRUE(cms_->Query("cms", {"apple", "banana", "cherry"}, counts).ok()); + + ASSERT_EQ(counts[0], 2); + ASSERT_EQ(counts[1], 3); + ASSERT_EQ(counts[2], 1); + + std::vector info; + ASSERT_TRUE(cms_->Info("cms", &info).ok()); + ASSERT_EQ(info[2], 6); +} + +TEST_F(RedisCMSketchTest, CMSQuery) { + std::unordered_map elements = {{"orange", 5}, {"grape", 3}, {"melon", 2}}; + ASSERT_TRUE(cms_->InitByDim("cms", 100, 5).ok()); + ASSERT_TRUE(cms_->IncrBy("cms", elements).ok()); + + std::vector counts; + ASSERT_TRUE(cms_->Query("cms", {"orange", "grape", "melon", "nonexistent"}, counts).ok()); + + ASSERT_EQ(counts[0], 5); + ASSERT_EQ(counts[1], 3); + ASSERT_EQ(counts[2], 2); + ASSERT_EQ(counts[3], 0); +} + +TEST_F(RedisCMSketchTest, CMSInfo) { + ASSERT_TRUE(cms_->InitByDim("cms", 200, 10).ok()); + + std::vector info; + ASSERT_TRUE(cms_->Info("cms", &info).ok()); + + ASSERT_EQ(info[0], 200); + ASSERT_EQ(info[1], 10); + ASSERT_EQ(info[2], 0); +} + +TEST_F(RedisCMSketchTest, CMSInitByProb) { + ASSERT_TRUE(cms_->InitByProb("cms", 0.001, 0.1).ok()); + + std::vector info; + ASSERT_TRUE(cms_->Info("cms", &info).ok()); + + ASSERT_EQ(info[0], 2000); + ASSERT_EQ(info[1], 4); + ASSERT_EQ(info[2], 0); +} + +TEST_F(RedisCMSketchTest, CMSMultipleKeys) { + std::unordered_map elements1 = {{"apple", 2}, {"banana", 3}}; + std::unordered_map elements2 = {{"cherry", 1}, {"date", 4}}; + + ASSERT_TRUE(cms_->InitByDim("cms1", 100, 5).ok()); + ASSERT_TRUE(cms_->InitByDim("cms2", 100, 5).ok()); + + ASSERT_TRUE(cms_->IncrBy("cms1", elements1).ok()); + ASSERT_TRUE(cms_->IncrBy("cms2", elements2).ok()); + + std::vector counts1, counts2; + ASSERT_TRUE(cms_->Query("cms1", {"apple", "banana"}, counts1).ok()); + ASSERT_TRUE(cms_->Query("cms2", {"cherry", "date"}, counts2).ok()); + + ASSERT_EQ(counts1[0], 2); + ASSERT_EQ(counts1[1], 3); + ASSERT_EQ(counts2[0], 1); + ASSERT_EQ(counts2[1], 4); +} From a1b904a635aa75e73e5d2a94fa0b897040edf3fd Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Fri, 6 Sep 2024 23:30:12 -0400 Subject: [PATCH 03/25] small changes --- src/types/cms.cc | 3 --- src/types/cms.h | 5 ----- 2 files changed, 8 deletions(-) diff --git a/src/types/cms.cc b/src/types/cms.cc index 66809ea060e..834f6eff956 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -25,10 +25,7 @@ #include #include -#include "glog/logging.h" - void CMSketch::CMSDimFromProb(double error, double delta, uint32_t& width, uint32_t& depth) { - LOG(INFO) << error << delta; width = std::ceil(2 / error); depth = std::ceil(std::log10(delta) / std::log10(0.5)); } diff --git a/src/types/cms.h b/src/types/cms.h index a95295d0854..ce57a2ad55b 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -20,12 +20,7 @@ #pragma once -#include - -#include -#include #include - #include "vendor/murmurhash2.h" class CMSketch { From ecbaa440bd2f5b0ab776d575a77772ab8433c693 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sun, 8 Sep 2024 18:26:59 -0400 Subject: [PATCH 04/25] Resolved changes --- src/commands/cmd_cms.cc | 15 +++++++---- src/types/cms.cc | 17 +++++++----- src/types/cms.h | 27 ++++++++++--------- src/types/redis_cms.cc | 47 +++++++++++++++------------------ src/types/redis_cms.h | 14 +++++----- tests/cppunit/types/cms_test.cc | 46 ++++++++++++++++---------------- 6 files changed, 87 insertions(+), 79 deletions(-) diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc index 3a79577ce18..53664e74ea1 100644 --- a/src/commands/cmd_cms.cc +++ b/src/commands/cmd_cms.cc @@ -36,6 +36,7 @@ class CommandCMSIncrBy final : public Commander { return Status::RedisTryAgain; } redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); rocksdb::Status s; std::unordered_map elements; for (size_t i = 2; i < args_.size(); i += 2) { @@ -49,7 +50,7 @@ class CommandCMSIncrBy final : public Commander { elements[key] = value; } - s = cms.IncrBy(args_[1], elements); + s = cms.IncrBy(ctx, args_[1], elements); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } @@ -64,11 +65,12 @@ class CommandCMSInfo final : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); rocksdb::Status s; std::unordered_map elements; std::vector ret{}; - s = cms.Info(args_[1], &ret); + s = cms.Info(ctx, args_[1], &ret); if (s.IsNotFound()) { return {Status::RedisExecErr, s.ToString()}; @@ -90,11 +92,12 @@ class CommandCMSInitByDim final : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); rocksdb::Status s; uint64_t width = std::stoull(args_[2]); uint64_t depth = std::stoull(args_[3]); - s = cms.InitByDim(args_[1], width, depth); + s = cms.InitByDim(ctx, args_[1], width, depth); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } @@ -109,11 +112,12 @@ class CommandCMSInitByProb final : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); rocksdb::Status s; double error = std::stod(args_[2]); double delta = std::stod(args_[3]); - s = cms.InitByProb(args_[1], error, delta); + s = cms.InitByProb(ctx, args_[1], error, delta); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } @@ -128,6 +132,7 @@ class CommandCMSQuery final : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); rocksdb::Status s; std::vector counters{}; @@ -137,7 +142,7 @@ class CommandCMSQuery final : public Commander { elements.emplace_back(args_[i]); } - s = cms.Query(args_[1], elements, counters); + s = cms.Query(ctx, args_[1], elements, counters); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; diff --git a/src/types/cms.cc b/src/types/cms.cc index 834f6eff956..4699b139f2c 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -25,16 +25,19 @@ #include #include -void CMSketch::CMSDimFromProb(double error, double delta, uint32_t& width, uint32_t& depth) { - width = std::ceil(2 / error); - depth = std::ceil(std::log10(delta) / std::log10(0.5)); + +CMSketch::CMSketchDimensions CMSketch::CMSDimFromProb(double error, double delta) { + CMSketchDimensions dims; + dims.width = std::ceil(2 / error); + dims.depth = std::ceil(std::log10(delta) / std::log10(0.5)); + return dims; } -size_t CMSketch::IncrBy(const char* item, size_t item_len, size_t value) { +size_t CMSketch::IncrBy(std::string_view item, size_t value) { size_t min_count = std::numeric_limits::max(); for (size_t i = 0; i < depth_; ++i) { - uint32_t hash = HllMurMurHash64A(item, static_cast(item_len), i); + uint64_t hash = HllMurMurHash64A(item.data(), static_cast(item.size()), i); size_t loc = (hash % width_) + (i * width_); array_[loc] += value; if (array_[loc] < value) { @@ -46,11 +49,11 @@ size_t CMSketch::IncrBy(const char* item, size_t item_len, size_t value) { return min_count; } -size_t CMSketch::Query(const char* item, size_t item_len) const { +size_t CMSketch::Query(std::string_view item) const { size_t min_count = std::numeric_limits::max(); for (size_t i = 0; i < depth_; ++i) { - uint32_t hash = HllMurMurHash64A(item, static_cast(item_len), i); + uint64_t hash = HllMurMurHash64A(item.data(), static_cast(item.size()), i); min_count = std::min(min_count, static_cast(array_[(hash % width_) + (i * width_)])); } return min_count; diff --git a/src/types/cms.h b/src/types/cms.h index ce57a2ad55b..02283b3f085 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -21,25 +21,32 @@ #pragma once #include +#include + #include "vendor/murmurhash2.h" class CMSketch { public: - explicit CMSketch(uint32_t width = 0, uint32_t depth = 0, uint64_t counter = 0, std::vector array = {}) + explicit CMSketch(uint32_t width, uint32_t depth, uint64_t counter, std::vector array = {}) : width_(width), depth_(depth), counter_(counter), array_(array.empty() ? std::vector(width * depth, 0) : std::move(array)) {} - ~CMSketch() = default; + static std::unique_ptr NewCMSketch(uint32_t width, int32_t depth) { + return std::make_unique(width, depth, 0); + } - static CMSketch* NewCMSketch(size_t width, size_t depth) { return new CMSketch(width, depth); } + struct CMSketchDimensions { + uint32_t width; + uint32_t depth; + }; - static void CMSDimFromProb(double error, double delta, uint32_t& width, uint32_t& depth); + static CMSketchDimensions CMSDimFromProb(double error, double delta); - size_t IncrBy(const char* item, size_t item_len, size_t value); + size_t IncrBy(std::string_view item, size_t value); - size_t Query(const char* item, size_t item_len) const; + size_t Query(std::string_view item) const; static int Merge(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights); @@ -59,8 +66,8 @@ class CMSketch { const uint64_t& GetCounter() const { return counter_; } const std::vector& GetArray() const { return array_; } - size_t GetWidth() const { return width_; } - size_t GetDepth() const { return depth_; } + uint32_t GetWidth() const { return width_; } + uint32_t GetDepth() const { return depth_; } private: size_t width_; @@ -68,10 +75,6 @@ class CMSketch { uint64_t counter_; std::vector array_; - static uint32_t hllMurMurHash64A(const char* item, size_t item_len, size_t i) { - return HllMurMurHash64A(item, static_cast(item_len), i); - } - static int checkOverflow(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights); }; \ No newline at end of file diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index c38a1a4fbf0..b75ce2d11f1 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -26,17 +26,17 @@ namespace redis { -rocksdb::Status CMS::GetMetadata(Database::GetOptions get_options, const Slice &ns_key, +rocksdb::Status CMS::GetMetadata(engine::Context &ctx, const Slice &ns_key, CountMinSketchMetadata *metadata) { - return Database::GetMetadata(get_options, {kRedisCountMinSketch}, ns_key, metadata); + return Database::GetMetadata(ctx, {kRedisCountMinSketch}, ns_key, metadata); } -rocksdb::Status CMS::IncrBy(const Slice &user_key, const std::unordered_map &elements) { +rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, const std::unordered_map &elements) { std::string ns_key = AppendNamespacePrefix(user_key); LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; - rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); if (!s.ok() && !s.IsNotFound()) { return s; } @@ -52,7 +52,7 @@ rocksdb::Status CMS::IncrBy(const Slice &user_key, const std::unordered_mapPut(metadata_cf_handle_, ns_key, bytes); - return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); } -rocksdb::Status CMS::Info(const Slice &user_key, std::vector *ret) { +rocksdb::Status CMS::Info(engine::Context &ctx, const Slice &user_key, std::vector *ret) { std::string ns_key = AppendNamespacePrefix(user_key); LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; - rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); if (!s.ok() || s.IsNotFound()) { return rocksdb::Status::NotFound(); @@ -83,13 +83,13 @@ rocksdb::Status CMS::Info(const Slice &user_key, std::vector *ret) { return rocksdb::Status::OK(); }; -rocksdb::Status CMS::InitByDim(const Slice &user_key, uint32_t width, uint32_t depth) { +rocksdb::Status CMS::InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth) { std::string ns_key = AppendNamespacePrefix(user_key); LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; - rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); if (!s.IsNotFound()) { return s; @@ -108,16 +108,16 @@ rocksdb::Status CMS::InitByDim(const Slice &user_key, uint32_t width, uint32_t d metadata.Encode(&bytes); batch->Put(metadata_cf_handle_, ns_key, bytes); - return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); }; -rocksdb::Status CMS::InitByProb(const Slice &user_key, double error, double delta) { +rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta) { std::string ns_key = AppendNamespacePrefix(user_key); LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; - rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); if (!s.IsNotFound()) { return s; } @@ -125,32 +125,29 @@ rocksdb::Status CMS::InitByProb(const Slice &user_key, double error, double delt WriteBatchLogData log_data(kRedisCountMinSketch); batch->PutLogData(log_data.Encode()); - CMSketch cms; - uint32_t width = 0; - uint32_t depth = 0; - cms.CMSDimFromProb(error, delta, width, depth); + CMSketch cms { 0, 0, 0 }; + CMSketch::CMSketchDimensions dim = cms.CMSDimFromProb(error, delta); - metadata.width = width; - metadata.depth = depth; + metadata.width = dim.width; + metadata.depth = dim.depth; metadata.counter = 0; - metadata.array = std::vector(width * depth, 0); - ; + metadata.array = std::vector(dim.width * dim.depth, 0); std::string bytes; metadata.Encode(&bytes); batch->Put(metadata_cf_handle_, ns_key, bytes); - return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); }; -rocksdb::Status CMS::Query(const Slice &user_key, const std::vector &elements, +rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, std::vector &counters) { std::string ns_key = AppendNamespacePrefix(user_key); LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; - rocksdb::Status s = GetMetadata(GetOptions(), ns_key, &metadata); + rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); if (s.IsNotFound()) { counters.assign(elements.size(), 0); @@ -162,7 +159,7 @@ rocksdb::Status CMS::Query(const Slice &user_key, const std::vector CMSketch cms(metadata.width, metadata.depth, metadata.counter, metadata.array); for (auto &element : elements) { - counters.push_back(cms.Query(element.data(), element.size())); + counters.push_back(cms.Query(element.data())); } return rocksdb::Status::OK(); }; diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h index 4139de1a29e..6428a040189 100644 --- a/src/types/redis_cms.h +++ b/src/types/redis_cms.h @@ -30,19 +30,19 @@ class CMS : public Database { public: explicit CMS(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {} - rocksdb::Status IncrBy(const Slice &user_key, const std::unordered_map &elements); - rocksdb::Status Info(const Slice &user_key, std::vector *ret); - rocksdb::Status InitByDim(const Slice &user_key, uint32_t width, uint32_t depth); - rocksdb::Status InitByProb(const Slice &user_key, double error, double delta); - rocksdb::Status Query(const Slice &user_key, const std::vector &elements, + rocksdb::Status IncrBy(engine::Context &ctx, const Slice &user_key, const std::unordered_map &elements); + rocksdb::Status Info(engine::Context &ctx, const Slice &user_key, std::vector *ret); + rocksdb::Status InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth); + rocksdb::Status InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta); + rocksdb::Status Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, std::vector &counters); private: - [[nodiscard]] rocksdb::Status GetMetadata(Database::GetOptions get_options, const Slice &ns_key, + [[nodiscard]] rocksdb::Status GetMetadata(engine::Context &ctx, const Slice &ns_key, CountMinSketchMetadata *metadata); // TODO (jonathanc-n) - [[nodiscard]] rocksdb::Status mergeUserKeys(Database::GetOptions get_options, const std::vector &user_keys, + [[nodiscard]] rocksdb::Status mergeUserKeys(engine::Context &ctx, const std::vector &user_keys, std::vector *register_segments); }; diff --git a/tests/cppunit/types/cms_test.cc b/tests/cppunit/types/cms_test.cc index 274c3c72166..dfae20f1ab3 100644 --- a/tests/cppunit/types/cms_test.cc +++ b/tests/cppunit/types/cms_test.cc @@ -32,17 +32,17 @@ class RedisCMSketchTest : public TestBase { void SetUp() override { TestBase::SetUp(); - [[maybe_unused]] auto s = cms_->Del("cms"); + [[maybe_unused]] auto s = cms_->Del(*ctx_, "cms"); for (int x = 1; x <= 3; x++) { - s = cms_->Del("cms" + std::to_string(x)); + s = cms_->Del(*ctx_, "cms" + std::to_string(x)); } } void TearDown() override { TestBase::TearDown(); - [[maybe_unused]] auto s = cms_->Del("cms"); + [[maybe_unused]] auto s = cms_->Del(*ctx_, "cms"); for (int x = 1; x <= 3; x++) { - s = cms_->Del("cms" + std::to_string(x)); + s = cms_->Del(*ctx_, "cms" + std::to_string(x)); } } @@ -50,9 +50,9 @@ class RedisCMSketchTest : public TestBase { }; TEST_F(RedisCMSketchTest, CMSInitByDim) { - ASSERT_TRUE(cms_->InitByDim("cms", 100, 5).ok()); + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms", 100, 5).ok()); std::vector info; - ASSERT_TRUE(cms_->Info("cms", &info).ok()); + ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); ASSERT_EQ(info[0], 100); ASSERT_EQ(info[1], 5); ASSERT_EQ(info[2], 0); @@ -60,28 +60,28 @@ TEST_F(RedisCMSketchTest, CMSInitByDim) { TEST_F(RedisCMSketchTest, CMSIncrBy) { std::unordered_map elements = {{"apple", 2}, {"banana", 3}, {"cherry", 1}}; - ASSERT_TRUE(cms_->InitByDim("cms", 100, 5).ok()); - ASSERT_TRUE(cms_->IncrBy("cms", elements).ok()); + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms", 100, 5).ok()); + ASSERT_TRUE(cms_->IncrBy(*ctx_, "cms", elements).ok()); std::vector counts; - ASSERT_TRUE(cms_->Query("cms", {"apple", "banana", "cherry"}, counts).ok()); + ASSERT_TRUE(cms_->Query(*ctx_, "cms", {"apple", "banana", "cherry"}, counts).ok()); ASSERT_EQ(counts[0], 2); ASSERT_EQ(counts[1], 3); ASSERT_EQ(counts[2], 1); std::vector info; - ASSERT_TRUE(cms_->Info("cms", &info).ok()); + ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); ASSERT_EQ(info[2], 6); } TEST_F(RedisCMSketchTest, CMSQuery) { std::unordered_map elements = {{"orange", 5}, {"grape", 3}, {"melon", 2}}; - ASSERT_TRUE(cms_->InitByDim("cms", 100, 5).ok()); - ASSERT_TRUE(cms_->IncrBy("cms", elements).ok()); + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms", 100, 5).ok()); + ASSERT_TRUE(cms_->IncrBy(*ctx_, "cms", elements).ok()); std::vector counts; - ASSERT_TRUE(cms_->Query("cms", {"orange", "grape", "melon", "nonexistent"}, counts).ok()); + ASSERT_TRUE(cms_->Query(*ctx_, "cms", {"orange", "grape", "melon", "nonexistent"}, counts).ok()); ASSERT_EQ(counts[0], 5); ASSERT_EQ(counts[1], 3); @@ -90,10 +90,10 @@ TEST_F(RedisCMSketchTest, CMSQuery) { } TEST_F(RedisCMSketchTest, CMSInfo) { - ASSERT_TRUE(cms_->InitByDim("cms", 200, 10).ok()); + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms", 200, 10).ok()); std::vector info; - ASSERT_TRUE(cms_->Info("cms", &info).ok()); + ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); ASSERT_EQ(info[0], 200); ASSERT_EQ(info[1], 10); @@ -101,10 +101,10 @@ TEST_F(RedisCMSketchTest, CMSInfo) { } TEST_F(RedisCMSketchTest, CMSInitByProb) { - ASSERT_TRUE(cms_->InitByProb("cms", 0.001, 0.1).ok()); + ASSERT_TRUE(cms_->InitByProb(*ctx_, "cms", 0.001, 0.1).ok()); std::vector info; - ASSERT_TRUE(cms_->Info("cms", &info).ok()); + ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); ASSERT_EQ(info[0], 2000); ASSERT_EQ(info[1], 4); @@ -115,15 +115,15 @@ TEST_F(RedisCMSketchTest, CMSMultipleKeys) { std::unordered_map elements1 = {{"apple", 2}, {"banana", 3}}; std::unordered_map elements2 = {{"cherry", 1}, {"date", 4}}; - ASSERT_TRUE(cms_->InitByDim("cms1", 100, 5).ok()); - ASSERT_TRUE(cms_->InitByDim("cms2", 100, 5).ok()); + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms1", 100, 5).ok()); + ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms2", 100, 5).ok()); - ASSERT_TRUE(cms_->IncrBy("cms1", elements1).ok()); - ASSERT_TRUE(cms_->IncrBy("cms2", elements2).ok()); + ASSERT_TRUE(cms_->IncrBy(*ctx_, "cms1", elements1).ok()); + ASSERT_TRUE(cms_->IncrBy(*ctx_, "cms2", elements2).ok()); std::vector counts1, counts2; - ASSERT_TRUE(cms_->Query("cms1", {"apple", "banana"}, counts1).ok()); - ASSERT_TRUE(cms_->Query("cms2", {"cherry", "date"}, counts2).ok()); + ASSERT_TRUE(cms_->Query(*ctx_, "cms1", {"apple", "banana"}, counts1).ok()); + ASSERT_TRUE(cms_->Query(*ctx_, "cms2", {"cherry", "date"}, counts2).ok()); ASSERT_EQ(counts1[0], 2); ASSERT_EQ(counts1[1], 3); From 3fb9f1f52c703343559a1615f067373d278a273d Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Mon, 9 Sep 2024 20:38:25 -0400 Subject: [PATCH 05/25] small tweaks --- src/commands/cmd_cms.cc | 47 ++++++++++++++++++++++++--------- src/storage/redis_metadata.h | 2 +- src/types/cms.cc | 9 ++++--- src/types/cms.h | 13 ++++++++- src/types/redis_cms.cc | 17 ++++++++---- src/types/redis_cms.h | 2 +- tests/cppunit/types/cms_test.cc | 29 ++++++++++---------- 7 files changed, 80 insertions(+), 39 deletions(-) diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc index 53664e74ea1..1d81477be2e 100644 --- a/src/commands/cmd_cms.cc +++ b/src/commands/cmd_cms.cc @@ -21,6 +21,7 @@ #include #include +#include "parse_util.h" #include "commander.h" #include "commands/command_parser.h" #include "server/redis_reply.h" @@ -41,12 +42,11 @@ class CommandCMSIncrBy final : public Commander { std::unordered_map elements; for (size_t i = 2; i < args_.size(); i += 2) { std::string key = args_[i]; - uint64_t value = 0; - try { - value = std::stoull(args_[i + 1]); - } catch (const std::exception &e) { - return Status::InvalidArgument; + auto parse_result = ParseInt(args_[i + 1]); + if (!parse_result) { + return {Status::RedisParseErr, errValueNotInteger}; } + uint64_t value = *parse_result; elements[key] = value; } @@ -67,8 +67,7 @@ class CommandCMSInfo final : public Commander { redis::CMS cms(srv->storage, conn->GetNamespace()); engine::Context ctx(srv->storage); rocksdb::Status s; - std::unordered_map elements; - std::vector ret{}; + CMSketch::CMSInfo ret{}; s = cms.Info(ctx, args_[1], &ret); @@ -80,8 +79,11 @@ class CommandCMSInfo final : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = redis::Array({redis::BulkString("width"), redis::Integer(ret[0]), redis::BulkString("depth"), - redis::Integer(ret[1]), redis::BulkString("count"), redis::Integer(ret[2])}); + *output = redis::Array({ + redis::BulkString("width"), redis::Integer(ret.width), + redis::BulkString("depth"), redis::Integer(ret.depth), + redis::BulkString("count"), redis::Integer(ret.count) + }); return Status::OK(); } @@ -94,8 +96,17 @@ class CommandCMSInitByDim final : public Commander { redis::CMS cms(srv->storage, conn->GetNamespace()); engine::Context ctx(srv->storage); rocksdb::Status s; - uint64_t width = std::stoull(args_[2]); - uint64_t depth = std::stoull(args_[3]); + auto width_result = ParseInt(this->args_[2]); + if (!width_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + uint32_t width = *width_result; + + auto depth_result = ParseInt(this->args_[3]); + if (!depth_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + uint32_t depth = *depth_result; s = cms.InitByDim(ctx, args_[1], width, depth); if (!s.ok()) { @@ -114,8 +125,18 @@ class CommandCMSInitByProb final : public Commander { redis::CMS cms(srv->storage, conn->GetNamespace()); engine::Context ctx(srv->storage); rocksdb::Status s; - double error = std::stod(args_[2]); - double delta = std::stod(args_[3]); + + auto error_result = ParseFloat(args_[2]); + if (!error_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + double error = *error_result; + + auto delta_result = ParseFloat(args_[3]); + if (!delta_result) { + return {Status::RedisParseErr, errValueNotInteger}; + } + double delta = *delta_result; s = cms.InitByProb(ctx, args_[1], error, delta); if (!s.ok()) { diff --git a/src/storage/redis_metadata.h b/src/storage/redis_metadata.h index f7aa8e091a0..224b9bcff04 100644 --- a/src/storage/redis_metadata.h +++ b/src/storage/redis_metadata.h @@ -341,7 +341,7 @@ class CountMinSketchMetadata : public Metadata { public: uint32_t width; uint32_t depth; - uint64_t counter; + uint64_t counter = 0; std::vector array; explicit CountMinSketchMetadata(bool generate_version = true) : Metadata(kRedisCountMinSketch, generate_version) {} diff --git a/src/types/cms.cc b/src/types/cms.cc index 4699b139f2c..86974c20a45 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -33,15 +33,16 @@ CMSketch::CMSketchDimensions CMSketch::CMSDimFromProb(double error, double delta return dims; } -size_t CMSketch::IncrBy(std::string_view item, size_t value) { +size_t CMSketch::IncrBy(std::string_view item, uint32_t value) { size_t min_count = std::numeric_limits::max(); for (size_t i = 0; i < depth_; ++i) { uint64_t hash = HllMurMurHash64A(item.data(), static_cast(item.size()), i); - size_t loc = (hash % width_) + (i * width_); - array_[loc] += value; - if (array_[loc] < value) { + size_t loc = GetLocation(hash, i); + if (array_[loc] > UINT32_MAX - value) { array_[loc] = UINT32_MAX; + } else { + array_[loc] += value; } min_count = std::min(min_count, static_cast(array_[loc])); } diff --git a/src/types/cms.h b/src/types/cms.h index 02283b3f085..d1b8d1b4b7b 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -23,6 +23,7 @@ #include #include +#include "server/redis_reply.h" #include "vendor/murmurhash2.h" class CMSketch { @@ -37,6 +38,12 @@ class CMSketch { return std::make_unique(width, depth, 0); } + struct CMSInfo { + uint64_t width; + uint64_t depth; + uint64_t count; + }; + struct CMSketchDimensions { uint32_t width; uint32_t depth; @@ -44,7 +51,7 @@ class CMSketch { static CMSketchDimensions CMSDimFromProb(double error, double delta); - size_t IncrBy(std::string_view item, size_t value); + size_t IncrBy(std::string_view item, uint32_t value); size_t Query(std::string_view item) const; @@ -60,6 +67,10 @@ class CMSketch { int CMSMergeParams(const MergeParams& params); + size_t GetLocation(uint64_t hash, size_t i) const { + return (hash % width_) + (i * width_); + } + uint64_t& GetCounter() { return counter_; } std::vector& GetArray() { return array_; } diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index b75ce2d11f1..e269f7d9c54 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -18,7 +18,7 @@ * */ -#include "redis_cms.h" + #include "redis_cms.h" #include @@ -65,7 +65,7 @@ rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, const s return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); } -rocksdb::Status CMS::Info(engine::Context &ctx, const Slice &user_key, std::vector *ret) { +rocksdb::Status CMS::Info(engine::Context &ctx, const Slice &user_key, CMSketch::CMSInfo *ret) { std::string ns_key = AppendNamespacePrefix(user_key); LockGuard guard(storage_->GetLockManager(), ns_key); @@ -76,9 +76,9 @@ rocksdb::Status CMS::Info(engine::Context &ctx, const Slice &user_key, std::vect return rocksdb::Status::NotFound(); } - ret->emplace_back(metadata.width); - ret->emplace_back(metadata.depth); - ret->emplace_back(metadata.counter); + ret->width = metadata.width; + ret->depth = metadata.depth; + ret->count = metadata.counter; return rocksdb::Status::OK(); }; @@ -112,6 +112,13 @@ rocksdb::Status CMS::InitByDim(engine::Context &ctx, const Slice &user_key, uint }; rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta) { + if (error <= 0 || error >= 1) { + return rocksdb::Status::InvalidArgument("Error must be between 0 and 1 (exclusive)."); + } + if (delta <= 0 || delta >= 1) { + return rocksdb::Status::InvalidArgument("Delta must be between 0 and 1 (exclusive)."); + } + std::string ns_key = AppendNamespacePrefix(user_key); LockGuard guard(storage_->GetLockManager(), ns_key); diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h index 6428a040189..cc9d54241b1 100644 --- a/src/types/redis_cms.h +++ b/src/types/redis_cms.h @@ -31,7 +31,7 @@ class CMS : public Database { explicit CMS(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {} rocksdb::Status IncrBy(engine::Context &ctx, const Slice &user_key, const std::unordered_map &elements); - rocksdb::Status Info(engine::Context &ctx, const Slice &user_key, std::vector *ret); + rocksdb::Status Info(engine::Context &ctx, const Slice &user_key, CMSketch::CMSInfo *ret); rocksdb::Status InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth); rocksdb::Status InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta); rocksdb::Status Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, diff --git a/tests/cppunit/types/cms_test.cc b/tests/cppunit/types/cms_test.cc index dfae20f1ab3..e97dbdaa320 100644 --- a/tests/cppunit/types/cms_test.cc +++ b/tests/cppunit/types/cms_test.cc @@ -23,6 +23,7 @@ #include #include "test_base.h" +#include "types/cms.h" #include "types/redis_cms.h" class RedisCMSketchTest : public TestBase { @@ -51,11 +52,11 @@ class RedisCMSketchTest : public TestBase { TEST_F(RedisCMSketchTest, CMSInitByDim) { ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms", 100, 5).ok()); - std::vector info; + CMSketch::CMSInfo info; ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); - ASSERT_EQ(info[0], 100); - ASSERT_EQ(info[1], 5); - ASSERT_EQ(info[2], 0); + ASSERT_EQ(info.width, 100); + ASSERT_EQ(info.depth, 5); + ASSERT_EQ(info.count, 0); } TEST_F(RedisCMSketchTest, CMSIncrBy) { @@ -70,9 +71,9 @@ TEST_F(RedisCMSketchTest, CMSIncrBy) { ASSERT_EQ(counts[1], 3); ASSERT_EQ(counts[2], 1); - std::vector info; + CMSketch::CMSInfo info; ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); - ASSERT_EQ(info[2], 6); + ASSERT_EQ(info.count, 6); } TEST_F(RedisCMSketchTest, CMSQuery) { @@ -92,23 +93,23 @@ TEST_F(RedisCMSketchTest, CMSQuery) { TEST_F(RedisCMSketchTest, CMSInfo) { ASSERT_TRUE(cms_->InitByDim(*ctx_, "cms", 200, 10).ok()); - std::vector info; + CMSketch::CMSInfo info; ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); - ASSERT_EQ(info[0], 200); - ASSERT_EQ(info[1], 10); - ASSERT_EQ(info[2], 0); + ASSERT_EQ(info.width, 200); + ASSERT_EQ(info.depth, 10); + ASSERT_EQ(info.count, 0); } TEST_F(RedisCMSketchTest, CMSInitByProb) { ASSERT_TRUE(cms_->InitByProb(*ctx_, "cms", 0.001, 0.1).ok()); - std::vector info; + CMSketch::CMSInfo info; ASSERT_TRUE(cms_->Info(*ctx_, "cms", &info).ok()); - ASSERT_EQ(info[0], 2000); - ASSERT_EQ(info[1], 4); - ASSERT_EQ(info[2], 0); + ASSERT_EQ(info.width, 2000); + ASSERT_EQ(info.depth, 4); + ASSERT_EQ(info.count, 0); } TEST_F(RedisCMSketchTest, CMSMultipleKeys) { From a92dc08f72ec27f76f59bf2340c8e9a62e3ae960 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Tue, 10 Sep 2024 00:55:15 -0400 Subject: [PATCH 06/25] parse change --- src/commands/cmd_cms.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc index 1d81477be2e..050d7842fc8 100644 --- a/src/commands/cmd_cms.cc +++ b/src/commands/cmd_cms.cc @@ -34,7 +34,7 @@ class CommandCMSIncrBy final : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { if ((args_.size() - 2) % 2 != 0) { - return Status::RedisTryAgain; + return {Status::RedisParseErr, errWrongNumOfArguments}; } redis::CMS cms(srv->storage, conn->GetNamespace()); engine::Context ctx(srv->storage); From b8b0bf295887fb2fa96731c7af990f793fbba2b7 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Wed, 11 Sep 2024 21:34:31 -0400 Subject: [PATCH 07/25] format changes --- src/commands/cmd_cms.cc | 13 +++++-------- src/types/cms.cc | 1 - src/types/cms.h | 8 +++----- src/types/redis_cms.cc | 10 +++++----- src/types/redis_cms.h | 3 ++- tests/cppunit/types/cms_test.cc | 3 ++- 6 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc index 050d7842fc8..0e3132bbad2 100644 --- a/src/commands/cmd_cms.cc +++ b/src/commands/cmd_cms.cc @@ -21,9 +21,9 @@ #include #include -#include "parse_util.h" #include "commander.h" #include "commands/command_parser.h" +#include "parse_util.h" #include "server/redis_reply.h" #include "server/server.h" @@ -79,11 +79,8 @@ class CommandCMSInfo final : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = redis::Array({ - redis::BulkString("width"), redis::Integer(ret.width), - redis::BulkString("depth"), redis::Integer(ret.depth), - redis::BulkString("count"), redis::Integer(ret.count) - }); + *output = redis::Array({redis::BulkString("width"), redis::Integer(ret.width), redis::BulkString("depth"), + redis::Integer(ret.depth), redis::BulkString("count"), redis::Integer(ret.count)}); return Status::OK(); } @@ -106,7 +103,7 @@ class CommandCMSInitByDim final : public Commander { if (!depth_result) { return {Status::RedisParseErr, errValueNotInteger}; } - uint32_t depth = *depth_result; + uint32_t depth = *depth_result; s = cms.InitByDim(ctx, args_[1], width, depth); if (!s.ok()) { @@ -130,7 +127,7 @@ class CommandCMSInitByProb final : public Commander { if (!error_result) { return {Status::RedisParseErr, errValueNotInteger}; } - double error = *error_result; + double error = *error_result; auto delta_result = ParseFloat(args_[3]); if (!delta_result) { diff --git a/src/types/cms.cc b/src/types/cms.cc index 86974c20a45..a7c26d07fad 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -25,7 +25,6 @@ #include #include - CMSketch::CMSketchDimensions CMSketch::CMSDimFromProb(double error, double delta) { CMSketchDimensions dims; dims.width = std::ceil(2 / error); diff --git a/src/types/cms.h b/src/types/cms.h index d1b8d1b4b7b..a87f89fecb2 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -20,8 +20,8 @@ #pragma once -#include #include +#include #include "server/redis_reply.h" #include "vendor/murmurhash2.h" @@ -35,7 +35,7 @@ class CMSketch { array_(array.empty() ? std::vector(width * depth, 0) : std::move(array)) {} static std::unique_ptr NewCMSketch(uint32_t width, int32_t depth) { - return std::make_unique(width, depth, 0); + return std::make_unique(width, depth, 0); } struct CMSInfo { @@ -67,9 +67,7 @@ class CMSketch { int CMSMergeParams(const MergeParams& params); - size_t GetLocation(uint64_t hash, size_t i) const { - return (hash % width_) + (i * width_); - } + size_t GetLocation(uint64_t hash, size_t i) const { return (hash % width_) + (i * width_); } uint64_t& GetCounter() { return counter_; } std::vector& GetArray() { return array_; } diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index e269f7d9c54..b245b9113d5 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -18,7 +18,7 @@ * */ - #include "redis_cms.h" +#include "redis_cms.h" #include @@ -26,12 +26,12 @@ namespace redis { -rocksdb::Status CMS::GetMetadata(engine::Context &ctx, const Slice &ns_key, - CountMinSketchMetadata *metadata) { +rocksdb::Status CMS::GetMetadata(engine::Context &ctx, const Slice &ns_key, CountMinSketchMetadata *metadata) { return Database::GetMetadata(ctx, {kRedisCountMinSketch}, ns_key, metadata); } -rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, const std::unordered_map &elements) { +rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, + const std::unordered_map &elements) { std::string ns_key = AppendNamespacePrefix(user_key); LockGuard guard(storage_->GetLockManager(), ns_key); @@ -132,7 +132,7 @@ rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, dou WriteBatchLogData log_data(kRedisCountMinSketch); batch->PutLogData(log_data.Encode()); - CMSketch cms { 0, 0, 0 }; + CMSketch cms{0, 0, 0}; CMSketch::CMSketchDimensions dim = cms.CMSDimFromProb(error, delta); metadata.width = dim.width; diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h index cc9d54241b1..54c51111197 100644 --- a/src/types/redis_cms.h +++ b/src/types/redis_cms.h @@ -30,7 +30,8 @@ class CMS : public Database { public: explicit CMS(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {} - rocksdb::Status IncrBy(engine::Context &ctx, const Slice &user_key, const std::unordered_map &elements); + rocksdb::Status IncrBy(engine::Context &ctx, const Slice &user_key, + const std::unordered_map &elements); rocksdb::Status Info(engine::Context &ctx, const Slice &user_key, CMSketch::CMSInfo *ret); rocksdb::Status InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth); rocksdb::Status InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta); diff --git a/tests/cppunit/types/cms_test.cc b/tests/cppunit/types/cms_test.cc index e97dbdaa320..b60c1c196d7 100644 --- a/tests/cppunit/types/cms_test.cc +++ b/tests/cppunit/types/cms_test.cc @@ -18,12 +18,13 @@ * */ +#include "types/cms.h" + #include #include #include "test_base.h" -#include "types/cms.h" #include "types/redis_cms.h" class RedisCMSketchTest : public TestBase { From dd1c4f547a3a65d4f4e583b8303483b233f09010 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Fri, 13 Sep 2024 20:53:15 -0400 Subject: [PATCH 08/25] tweaks + lint --- src/storage/redis_metadata.cc | 7 ++++--- src/types/cms.cc | 8 ++++---- src/types/cms.h | 6 +----- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/storage/redis_metadata.cc b/src/storage/redis_metadata.cc index 1dcf8730ac3..07ae5703e14 100644 --- a/src/storage/redis_metadata.cc +++ b/src/storage/redis_metadata.cc @@ -326,11 +326,12 @@ bool Metadata::ExpireAt(uint64_t expired_ts) const { return expire < expired_ts; } -bool Metadata::IsSingleKVType() const { return Type() == kRedisString || Type() == kRedisJson; } +bool Metadata::IsSingleKVType() const { + return Type() == kRedisString || Type() == kRedisJson || Type() == kRedisCountMinSketch; +} bool Metadata::IsEmptyableType() const { - return IsSingleKVType() || Type() == kRedisStream || Type() == kRedisBloomFilter || Type() == kRedisHyperLogLog || - Type() == kRedisCountMinSketch; + return IsSingleKVType() || Type() == kRedisStream || Type() == kRedisBloomFilter || Type() == kRedisHyperLogLog; } bool Metadata::Expired() const { return ExpireAt(util::GetTimeStampMS()); } diff --git a/src/types/cms.cc b/src/types/cms.cc index a7c26d07fad..f9e8e8691cd 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -36,7 +36,7 @@ size_t CMSketch::IncrBy(std::string_view item, uint32_t value) { size_t min_count = std::numeric_limits::max(); for (size_t i = 0; i < depth_; ++i) { - uint64_t hash = HllMurMurHash64A(item.data(), static_cast(item.size()), i); + uint64_t hash = XXH32(item.data(), static_cast(item.size()), i); size_t loc = GetLocation(hash, i); if (array_[loc] > UINT32_MAX - value) { array_[loc] = UINT32_MAX; @@ -53,8 +53,8 @@ size_t CMSketch::Query(std::string_view item) const { size_t min_count = std::numeric_limits::max(); for (size_t i = 0; i < depth_; ++i) { - uint64_t hash = HllMurMurHash64A(item.data(), static_cast(item.size()), i); - min_count = std::min(min_count, static_cast(array_[(hash % width_) + (i * width_)])); + uint64_t hash = XXH32(item.data(), static_cast(item.size()), i); + min_count = std::min(min_count, static_cast(array_[GetLocation(hash, i)])); } return min_count; } @@ -125,4 +125,4 @@ int CMSketch::checkOverflow(CMSketch* dest, size_t quantity, const std::vector #include "server/redis_reply.h" -#include "vendor/murmurhash2.h" +#include "xxhash.h" class CMSketch { public: @@ -34,10 +34,6 @@ class CMSketch { counter_(counter), array_(array.empty() ? std::vector(width * depth, 0) : std::move(array)) {} - static std::unique_ptr NewCMSketch(uint32_t width, int32_t depth) { - return std::make_unique(width, depth, 0); - } - struct CMSInfo { uint64_t width; uint64_t depth; From 509d5ab2e442688588f9a6124f40a5163b1c6e18 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sat, 14 Sep 2024 00:07:20 -0400 Subject: [PATCH 09/25] max memory checks for initbyprob command --- src/types/redis_cms.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index b245b9113d5..e371f427e98 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -135,6 +135,16 @@ rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, dou CMSketch cms{0, 0, 0}; CMSketch::CMSketchDimensions dim = cms.CMSDimFromProb(error, delta); + size_t memory_used = dim.width * dim.depth * sizeof(uint32_t); + const size_t max_memory = 50 * 1024 * 1024; + + if (memory_used == 0) { + return rocksdb::Status::InvalidArgument("Memory usage must be greater than 0."); + } + if (memory_used > max_memory) { + return rocksdb::Status::InvalidArgument("Memory usage exceeds 50MB."); + } + metadata.width = dim.width; metadata.depth = dim.depth; metadata.counter = 0; From 1e4e747e566dc9805c34b82996658b8a8b2f8cfe Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:23:30 -0400 Subject: [PATCH 10/25] Update src/types/redis_cms.h Co-authored-by: mwish --- src/types/redis_cms.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h index 54c51111197..c5e6dcecf76 100644 --- a/src/types/redis_cms.h +++ b/src/types/redis_cms.h @@ -43,6 +43,8 @@ class CMS : public Database { CountMinSketchMetadata *metadata); // TODO (jonathanc-n) + // [[nodiscard]] rocksdb::Status mergeUserKeys(engine::Context &ctx, const std::vector &user_keys, + // std::vector *register_segments); [[nodiscard]] rocksdb::Status mergeUserKeys(engine::Context &ctx, const std::vector &user_keys, std::vector *register_segments); }; From a71d7f73c075fc09ab8f86408b427977412adcac Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Thu, 19 Sep 2024 14:37:20 -0400 Subject: [PATCH 11/25] all fixes + go test case --- src/commands/cmd_cms.cc | 80 ++++++++++ src/types/cms.cc | 6 +- src/types/cms.h | 11 +- src/types/redis_cms.cc | 82 +++++++++- src/types/redis_cms.h | 8 +- tests/gocase/unit/cms/cms_test.go | 245 ++++++++++++++++++++++++++++++ 6 files changed, 418 insertions(+), 14 deletions(-) create mode 100644 tests/gocase/unit/cms/cms_test.go diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc index 0e3132bbad2..1782d91be49 100644 --- a/src/commands/cmd_cms.cc +++ b/src/commands/cmd_cms.cc @@ -51,6 +51,9 @@ class CommandCMSIncrBy final : public Commander { } s = cms.IncrBy(ctx, args_[1], elements); + if (s.IsNotFound()) { + return {Status::RedisExecErr, "Key not found"}; + } if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } @@ -145,6 +148,82 @@ class CommandCMSInitByProb final : public Commander { } }; +/// CMS.MERGE destination numKeys source [source ...] [WEIGHTS weight [weight ...]] +class CommandCMSMerge final : public Commander { + public: + Status Parse(const std::vector &args) override { + CommandParser parser(args, 2); + destination_ = args[1]; // Change to std::string + + StatusOr num_key_result = parser.TakeInt(); + if (!num_key_result || *num_key_result <= 0) { + return {Status::RedisParseErr, "invalid number of source keys"}; + } + num_keys_ = *num_key_result; + + src_keys_.reserve(num_keys_); + for (int i = 0; i < num_keys_; i++) { + auto result = parser.TakeStr(); + if (!result) { + return {Status::RedisParseErr, "Error parsing source key"}; + } + src_keys_.emplace_back(std::move(*result)); + } + + bool weights_found = false; + while (parser.Good()) { + if (parser.EatEqICase("WEIGHTS")) { + if (weights_found) { + return {Status::RedisParseErr, "WEIGHTS option cannot be specified multiple times"}; + } + src_weights_.reserve(num_keys_); + for (int i = 0; i < num_keys_; i++) { + StatusOr weight_result = parser.TakeInt(); + if (!weight_result || *weight_result == 0) { // Adjust condition if needed + return {Status::RedisParseErr, "invalid weight value"}; + } + src_weights_.emplace_back(*weight_result); + } + weights_found = true; + } else { + return {Status::RedisParseErr, "Syntax error: unexpected token"}; + } + } + + if (!weights_found) { + src_weights_.resize(num_keys_, 1); + } + + return Status::OK(); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::CMS cms(srv->storage, conn->GetNamespace()); + engine::Context ctx(srv->storage); + + // Convert std::string to Slice + std::vector src_keys_slices; + src_keys_slices.reserve(src_keys_.size()); + for (const auto &key : src_keys_) { + src_keys_slices.emplace_back(key); + } + + rocksdb::Status s = cms.MergeUserKeys(ctx, Slice(destination_), src_keys_slices, src_weights_); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + *output = redis::SimpleString("OK"); + return Status::OK(); + } + + private: + std::string destination_; // Changed to std::string + int num_keys_; + std::vector src_keys_; // Changed to std::vector + std::vector src_weights_; +}; + /// CMS.QUERY key item [item ...] class CommandCMSQuery final : public Commander { public: @@ -182,5 +261,6 @@ REDIS_REGISTER_COMMANDS(CMS, MakeCmdAttr("cms.incrby", -4, "wr MakeCmdAttr("cms.info", 2, "read-only", 0, 0, 0), MakeCmdAttr("cms.initbydim", 4, "write", 0, 0, 0), MakeCmdAttr("cms.initbyprob", 4, "write", 0, 0, 0), + MakeCmdAttr("cms.merge", -4, "write", 0, 0, 0), MakeCmdAttr("cms.query", -3, "read-only", 0, 0, 0), ); } // namespace redis \ No newline at end of file diff --git a/src/types/cms.cc b/src/types/cms.cc index f9e8e8691cd..4bbea0df4ed 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -20,6 +20,8 @@ #include "cms.h" +#include + #include #include #include @@ -68,10 +70,12 @@ int CMSketch::Merge(CMSketch* dest, size_t quantity, const std::vectorGetDepth(); ++i) { for (size_t j = 0; j < dest->GetWidth(); ++j) { int64_t item_count = 0; + // Sum the weighted counts from all source CMSes for (size_t k = 0; k < quantity; ++k) { item_count += static_cast(src[k]->array_[(i * dest->GetWidth()) + j]) * weights[k]; } - dest->GetArray()[(i * dest->GetWidth()) + j] = item_count; + // accumulates the weighted sum into the destination CMS's array + dest->GetArray()[(i * dest->GetWidth()) + j] += static_cast(item_count); } } diff --git a/src/types/cms.h b/src/types/cms.h index c02b22e09df..f1bdb70ee9f 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -24,19 +24,18 @@ #include #include "server/redis_reply.h" -#include "xxhash.h" class CMSketch { public: - explicit CMSketch(uint32_t width, uint32_t depth, uint64_t counter, std::vector array = {}) + explicit CMSketch(uint32_t width, uint32_t depth, uint64_t counter, std::vector array) : width_(width), depth_(depth), counter_(counter), array_(array.empty() ? std::vector(width * depth, 0) : std::move(array)) {} struct CMSInfo { - uint64_t width; - uint64_t depth; + uint32_t width; + uint32_t depth; uint64_t count; }; @@ -75,8 +74,8 @@ class CMSketch { uint32_t GetDepth() const { return depth_; } private: - size_t width_; - size_t depth_; + uint32_t width_; + uint32_t depth_; uint64_t counter_; std::vector array_; diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index e371f427e98..f386d288be2 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -23,6 +23,7 @@ #include #include "cms.h" +#include "rocksdb/status.h" namespace redis { @@ -37,6 +38,10 @@ rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); + + if (s.IsNotFound()) { + return rocksdb::Status::NotFound(); + } if (!s.ok() && !s.IsNotFound()) { return s; } @@ -132,7 +137,7 @@ rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, dou WriteBatchLogData log_data(kRedisCountMinSketch); batch->PutLogData(log_data.Encode()); - CMSketch cms{0, 0, 0}; + CMSketch cms{0, 0, 0, {}}; CMSketch::CMSketchDimensions dim = cms.CMSDimFromProb(error, delta); size_t memory_used = dim.width * dim.depth * sizeof(uint32_t); @@ -157,6 +162,81 @@ rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, dou return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); }; +rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, const std::vector &src_keys, + const std::vector &src_weights) { + size_t num_sources = src_keys.size(); + if (num_sources == 0) { + return rocksdb::Status::InvalidArgument("No source keys provided for merge."); + } + if (src_weights.size() != num_sources) { + return rocksdb::Status::InvalidArgument("Number of weights must match number of source keys."); + } + + std::string dest_ns_key = AppendNamespacePrefix(user_key); + LockGuard guard(storage_->GetLockManager(), dest_ns_key); + + CountMinSketchMetadata dest_metadata{}; + rocksdb::Status dest_status = GetMetadata(ctx, dest_ns_key, &dest_metadata); + if (!dest_status.ok()) { + if (dest_status.IsNotFound()) { + return rocksdb::Status::InvalidArgument("Destination CMS does not exist."); + } + return dest_status; + } + + CMSketch dest_cms(dest_metadata.width, dest_metadata.depth, dest_metadata.counter, dest_metadata.array); + + std::vector src_cms_objects; + src_cms_objects.reserve(num_sources); + std::vector src_cms_pointers; + src_cms_pointers.reserve(num_sources); + std::vector weights_long; + weights_long.reserve(num_sources); + + for (size_t i = 0; i < num_sources; ++i) { + std::string src_ns_key = AppendNamespacePrefix(src_keys[i]); + LOG(INFO) << "Dest Key: " << dest_ns_key << " | Source Key: " << src_ns_key; + LockGuard guard(storage_->GetLockManager(), src_ns_key); + + CountMinSketchMetadata src_metadata{}; + rocksdb::Status src_status = GetMetadata(ctx, src_ns_key, &src_metadata); + if (!src_status.ok()) { + if (src_status.IsNotFound()) { + return rocksdb::Status::InvalidArgument("Source CMS key not found."); + } + return src_status; + } + + if (src_metadata.width != dest_metadata.width || src_metadata.depth != dest_metadata.depth) { + return rocksdb::Status::InvalidArgument("Source CMS dimensions do not match destination CMS."); + } + + CMSketch src_cms(src_metadata.width, src_metadata.depth, src_metadata.counter, src_metadata.array); + src_cms_objects.emplace_back(std::move(src_cms)); + src_cms_pointers.push_back(&src_cms_objects.back()); + + weights_long.push_back(static_cast(src_weights[i])); + } + + int merge_result = CMSketch::Merge(&dest_cms, num_sources, src_cms_pointers, weights_long); + if (merge_result != 0) { + return rocksdb::Status::InvalidArgument("Merge operation failed due to overflow or invalid dimensions."); + } + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisCountMinSketch); + batch->PutLogData(log_data.Encode()); + + dest_metadata.counter = dest_cms.GetCounter(); + dest_metadata.array = dest_cms.GetArray(); + + std::string encoded_metadata; + dest_metadata.Encode(&encoded_metadata); + batch->Put(metadata_cf_handle_, dest_ns_key, encoded_metadata); + + return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} + rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, std::vector &counters) { std::string ns_key = AppendNamespacePrefix(user_key); diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h index c5e6dcecf76..2071256bf6b 100644 --- a/src/types/redis_cms.h +++ b/src/types/redis_cms.h @@ -37,16 +37,12 @@ class CMS : public Database { rocksdb::Status InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta); rocksdb::Status Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, std::vector &counters); + rocksdb::Status MergeUserKeys(engine::Context &ctx, const Slice &user_key, const std::vector &src_keys, + const std::vector &src_weights); private: [[nodiscard]] rocksdb::Status GetMetadata(engine::Context &ctx, const Slice &ns_key, CountMinSketchMetadata *metadata); - - // TODO (jonathanc-n) - // [[nodiscard]] rocksdb::Status mergeUserKeys(engine::Context &ctx, const std::vector &user_keys, - // std::vector *register_segments); - [[nodiscard]] rocksdb::Status mergeUserKeys(engine::Context &ctx, const std::vector &user_keys, - std::vector *register_segments); }; } // namespace redis diff --git a/tests/gocase/unit/cms/cms_test.go b/tests/gocase/unit/cms/cms_test.go new file mode 100644 index 00000000000..d300b9c19f7 --- /dev/null +++ b/tests/gocase/unit/cms/cms_test.go @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package cms + +import ( + "context" + "fmt" + "strconv" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/stretchr/testify/require" +) + +func TestCountMinSketch(t *testing.T) { + // Define configuration options if needed. + // Adjust or add more configurations as per your CMS requirements. + configOptions := []util.ConfigOptions{ + { + Name: "txn-context-enabled", + Options: []string{"yes", "no"}, + ConfigType: util.YesNo, + }, + // Add more configuration options here if necessary + } + + // Generate all combinations of configurations + configsMatrix, err := util.GenerateConfigsMatrix(configOptions) + require.NoError(t, err) + + // Iterate over each configuration and run CMS tests + for _, configs := range configsMatrix { + testCMS(t, configs) + } +} + +// testCMS sets up the server with the given configurations and runs CMS tests +func testCMS(t *testing.T, configs util.KvrocksServerConfigs) { + srv := util.StartServer(t, configs) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + // Run individual CMS test cases + t.Run("basic add", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + + res := rdb.Do(ctx, "cms.initbydim", "cmsA", 100, 10) + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + require.Equal(t, []interface{}{"width", int64(100), "depth", int64(10), "count", int64(0)}, rdb.Do(ctx, "cms.info", "cmsA").Val()) + + res = rdb.Do(ctx, "cms.incrby", "cmsA", "foo", 1) + require.NoError(t, res.Err()) + addCnt, err := res.Result() + + require.NoError(t, err) + require.Equal(t, string("OK"), addCnt) + + card, err := rdb.Do(ctx, "cms.query", "cmsA", "foo").Result() + require.NoError(t, err) + require.Equal(t, []interface{}([]interface{}{"1"}), card, "The queried count for 'foo' should be 1") + }) + + t.Run("cms.initbyprob - Initialization with Probability Parameters", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsB").Err()) + + res := rdb.Do(ctx, "cms.initbyprob", "cmsB", "0.001", "0.1") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + infoRes := rdb.Do(ctx, "cms.info", "cmsB") + require.NoError(t, infoRes.Err()) + infoSlice, ok := infoRes.Val().([]interface{}) + require.True(t, ok, "Expected cms.info to return a slice") + + infoMap := make(map[string]interface{}) + for i := 0; i < len(infoSlice); i += 2 { + key, ok1 := infoSlice[i].(string) + value, ok2 := infoSlice[i+1].(int64) + require.True(t, ok1 && ok2, "Expected cms.info keys to be strings and values to be int64") + infoMap[key] = value + } + + require.Equal(t, int64(2000), infoMap["width"]) + require.Equal(t, int64(4), infoMap["depth"]) + require.Equal(t, int64(0), infoMap["count"]) + }) + + t.Run("cms.incrby - Basic Increment Operations", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + res := rdb.Do(ctx, "cms.initbydim", "cmsA", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + elements := map[string]string{"apple": "7", "orange": "15", "mango": "3"} + for key, count := range elements { + res = rdb.Do(ctx, "cms.incrby", "cmsA", key, count) + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + } + + for key, expected := range elements { + res = rdb.Do(ctx, "cms.query", "cmsA", key) + require.NoError(t, res.Err()) + countSlice, ok := res.Val().([]interface{}) + require.True(t, ok, "Expected cms.query to return a slice") + require.Len(t, countSlice, 1, "Expected cms.query to return a single count") + count, ok := countSlice[0].(string) + require.True(t, ok, "Expected count to be a string") + require.Equal(t, expected, count, fmt.Sprintf("Count for key '%s' mismatch", key)) + } + + // Verify total count + infoRes := rdb.Do(ctx, "cms.info", "cmsA") + require.NoError(t, infoRes.Err()) + infoSlice, ok := infoRes.Val().([]interface{}) + require.True(t, ok, "Expected cms.info to return a slice") + + // Convert the slice to a map for easier access + infoMap := make(map[string]interface{}) + for i := 0; i < len(infoSlice); i += 2 { + key, ok1 := infoSlice[i].(string) + value, ok2 := infoSlice[i+1].(int64) + require.True(t, ok1 && ok2, "Expected cms.info keys to be strings and values to be int64") + infoMap[key] = value + } + + total := int64(0) + for _, cntStr := range elements { + cnt, err := strconv.ParseInt(cntStr, 10, 64) + require.NoError(t, err, "Failed to parse count string to int64") + total += cnt + } + require.Equal(t, total, infoMap["count"], "Total count mismatch") + }) + + // Increment operation on a non-existent CMS + t.Run("cms.incrby - Increment Non-Existent CMS", func(t *testing.T) { + res := rdb.Do(ctx, "cms.incrby", "nonexistent_cms", "apple", "5") + require.Error(t, res.Err()) + }) + + // Query for non-existent element + t.Run("cms.query - Query Non-Existent Element", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + res := rdb.Do(ctx, "cms.initbydim", "cmsA", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + // Query a non-existent element + res = rdb.Do(ctx, "cms.query", "cmsA", "nonexistent") + require.NoError(t, res.Err()) + countSlice, ok := res.Val().([]interface{}) + require.True(t, ok, "Expected cms.query to return a slice") + require.Len(t, countSlice, 1, "Expected cms.query to return a single count") + count, ok := countSlice[0].(string) + require.True(t, ok, "Expected count to be a string") + require.Equal(t, "0", count, "Non-existent element should return count '0'") + }) + + // Merging CMS structures + t.Run("cms.merge - Basic Merge Operation", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + require.NoError(t, rdb.Do(ctx, "DEL", "cmsB").Err()) + + res := rdb.Do(ctx, "cms.initbydim", "cmsA", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + res = rdb.Do(ctx, "cms.initbydim", "cmsB", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + // Increment elements in cmsA + elementsA := map[string]string{"apple": "7", "orange": "15", "mango": "3"} + for key, count := range elementsA { + res = rdb.Do(ctx, "cms.incrby", "cmsA", key, count) + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + } + + // Increment elements in cmsB + elementsB := map[string]string{"banana": "5", "apple": "4", "grape": "6"} + for key, count := range elementsB { + res = rdb.Do(ctx, "cms.incrby", "cmsB", key, count) + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + } + + // Merge cmsB into cmsA with weights + res = rdb.Do(ctx, "cms.merge", "cmsA", "1", "cmsB", "WEIGHTS", "1") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + // Query counts after merge + expectedCounts := map[string]string{"apple": "11", "orange": "15", "mango": "3", "banana": "5", "grape": "6"} + for key, expected := range expectedCounts { + res = rdb.Do(ctx, "cms.query", "cmsA", key) + require.NoError(t, res.Err()) + countSlice, ok := res.Val().([]interface{}) + require.True(t, ok, "Expected cms.query to return a slice") + require.Len(t, countSlice, 1, "Expected cms.query to return a single count") + count, ok := countSlice[0].(string) + require.True(t, ok, "Expected count to be a string") + require.Equal(t, expected, count, fmt.Sprintf("Count for key '%s' mismatch after merge", key)) + } + + infoRes := rdb.Do(ctx, "cms.info", "cmsA") + require.NoError(t, infoRes.Err()) + infoSlice, ok := infoRes.Val().([]interface{}) + require.True(t, ok, "Expected cms.info to return a slice") + + infoMap := make(map[string]interface{}) + for i := 0; i < len(infoSlice); i += 2 { + key, ok1 := infoSlice[i].(string) + value, ok2 := infoSlice[i+1].(int64) + require.True(t, ok1 && ok2, "Expected cms.info keys to be strings and values to be int64") + infoMap[key] = value + } + + expectedTotal := int64(40) + require.Equal(t, expectedTotal, infoMap["count"], "Total count mismatch after merge") + }) +} From 551f3c848387329c027806d0f9d569950b9acfb5 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sun, 22 Sep 2024 22:31:44 -0400 Subject: [PATCH 12/25] Added additional test cases --- src/types/redis_cms.cc | 2 +- tests/gocase/unit/cms/cms_test.go | 48 +++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index f386d288be2..2e2ce08051b 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -248,7 +248,7 @@ rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const st if (s.IsNotFound()) { counters.assign(elements.size(), 0); - return rocksdb::Status::OK(); + return rocksdb::Status::NotFound(); } else if (!s.ok()) { return s; } diff --git a/tests/gocase/unit/cms/cms_test.go b/tests/gocase/unit/cms/cms_test.go index d300b9c19f7..aabf5692d81 100644 --- a/tests/gocase/unit/cms/cms_test.go +++ b/tests/gocase/unit/cms/cms_test.go @@ -161,6 +161,13 @@ func testCMS(t *testing.T, configs util.KvrocksServerConfigs) { require.Error(t, res.Err()) }) + t.Run("cms.query - Query Non-Existent CMS", func(t *testing.T) { + // Attempt to query a CMS that doesn't exist + res := rdb.Do(ctx, "cms.query", "nonexistent_cms", "foo") + require.Error(t, res.Err()) + require.Contains(t, res.Err().Error(), "ERR NotFound:") + }) + // Query for non-existent element t.Run("cms.query - Query Non-Existent Element", func(t *testing.T) { require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) @@ -242,4 +249,45 @@ func testCMS(t *testing.T, configs util.KvrocksServerConfigs) { expectedTotal := int64(40) require.Equal(t, expectedTotal, infoMap["count"], "Total count mismatch after merge") }) + + t.Run("cms.merge - Merge with Uninitialized Destination CMS", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + require.NoError(t, rdb.Do(ctx, "DEL", "cmsB").Err()) + + // Initialize only the source CMS + res := rdb.Do(ctx, "cms.initbydim", "cmsB", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + // Attempt to merge cmsB into cmsA without initializing cmsA + res = rdb.Do(ctx, "cms.merge", "cmsA", "1", "cmsB", "WEIGHTS", "1") + require.Error(t, res.Err(), "Merging into an uninitialized destination CMS should return an error") + require.Contains(t, res.Err().Error(), "Destination CMS does not exist.", "Expected error message to contain 'Destination CMS does not exist.'") + }) + + t.Run("cms.merge - Merge with Uninitialized Source CMS", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + require.NoError(t, rdb.Do(ctx, "DEL", "cmsB").Err()) + + // Initialize only the destination CMS + res := rdb.Do(ctx, "cms.initbydim", "cmsA", "100", "10") + require.NoError(t, res.Err()) + require.Equal(t, "OK", res.Val()) + + // Attempt to merge a non-initialized cmsB into cmsA + res = rdb.Do(ctx, "cms.merge", "cmsA", "1", "cmsB", "WEIGHTS", "1") + require.Error(t, res.Err(), "Merging from an uninitialized source CMS should return an error") + require.Contains(t, res.Err().Error(), "Source CMS key not found.", "Expected error message to contain 'Source CMS key not found.'") + }) + + t.Run("cms.merge - Merge with Both Destination and Source CMS Uninitialized", func(t *testing.T) { + require.NoError(t, rdb.Do(ctx, "DEL", "cmsA").Err()) + require.NoError(t, rdb.Do(ctx, "DEL", "cmsB").Err()) + + // Attempt to merge two non-initialized CMSes + res := rdb.Do(ctx, "cms.merge", "cmsA", "1", "cmsB", "WEIGHTS", "1") + require.Error(t, res.Err(), "Merging with both destination and source CMS uninitialized should return an error") + errMsg := res.Err().Error() + require.Contains(t, errMsg, "Destination CMS does not exist.") + }) } From a0222d721839317fd60de5b9b8e4599279ca78ed Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sun, 22 Sep 2024 22:59:40 -0400 Subject: [PATCH 13/25] lint fix --- src/types/redis_cms.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index 2e2ce08051b..b3f2a4f84a8 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -57,7 +57,7 @@ rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, } for (auto &element : elements) { - cms.IncrBy(element.first.data(), element.second); + cms.IncrBy(element.first, element.second); metadata.counter += element.second; } @@ -256,7 +256,7 @@ rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const st CMSketch cms(metadata.width, metadata.depth, metadata.counter, metadata.array); for (auto &element : elements) { - counters.push_back(cms.Query(element.data())); + counters.push_back(cms.Query(element)); } return rocksdb::Status::OK(); }; From de0c2ad157c725b99498cd86fafd90de30832ce7 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Sun, 29 Sep 2024 00:00:42 -0400 Subject: [PATCH 14/25] Update src/types/redis_cms.h Co-authored-by: mwish --- src/types/redis_cms.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h index 2071256bf6b..d499ed709c8 100644 --- a/src/types/redis_cms.h +++ b/src/types/redis_cms.h @@ -36,7 +36,7 @@ class CMS : public Database { rocksdb::Status InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth); rocksdb::Status InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta); rocksdb::Status Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, - std::vector &counters); + std::vector *counters); rocksdb::Status MergeUserKeys(engine::Context &ctx, const Slice &user_key, const std::vector &src_keys, const std::vector &src_weights); From 69d6ea4297fcdea98dee0582da7a666e7cc08780 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sun, 29 Sep 2024 00:50:00 -0400 Subject: [PATCH 15/25] Fixes --- src/commands/cmd_cms.cc | 19 +++++-------- src/types/cms.cc | 38 +++++++++++++------------- src/types/cms.h | 7 ++--- src/types/redis_cms.cc | 59 ++++++++++++++++------------------------- 4 files changed, 49 insertions(+), 74 deletions(-) diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc index 1782d91be49..e10b7763ef8 100644 --- a/src/commands/cmd_cms.cc +++ b/src/commands/cmd_cms.cc @@ -75,7 +75,7 @@ class CommandCMSInfo final : public Commander { s = cms.Info(ctx, args_[1], &ret); if (s.IsNotFound()) { - return {Status::RedisExecErr, s.ToString()}; + return {Status::RedisExecErr, "Key not found"}; } if (!s.ok() && !s.IsNotFound()) { @@ -153,7 +153,7 @@ class CommandCMSMerge final : public Commander { public: Status Parse(const std::vector &args) override { CommandParser parser(args, 2); - destination_ = args[1]; // Change to std::string + destination_ = args[1]; StatusOr num_key_result = parser.TakeInt(); if (!num_key_result || *num_key_result <= 0) { @@ -179,7 +179,7 @@ class CommandCMSMerge final : public Commander { src_weights_.reserve(num_keys_); for (int i = 0; i < num_keys_; i++) { StatusOr weight_result = parser.TakeInt(); - if (!weight_result || *weight_result == 0) { // Adjust condition if needed + if (!weight_result || *weight_result == 0) { return {Status::RedisParseErr, "invalid weight value"}; } src_weights_.emplace_back(*weight_result); @@ -201,14 +201,7 @@ class CommandCMSMerge final : public Commander { redis::CMS cms(srv->storage, conn->GetNamespace()); engine::Context ctx(srv->storage); - // Convert std::string to Slice - std::vector src_keys_slices; - src_keys_slices.reserve(src_keys_.size()); - for (const auto &key : src_keys_) { - src_keys_slices.emplace_back(key); - } - - rocksdb::Status s = cms.MergeUserKeys(ctx, Slice(destination_), src_keys_slices, src_weights_); + rocksdb::Status s = cms.MergeUserKeys(ctx, destination_, src_keys_, src_weights_); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } @@ -218,9 +211,9 @@ class CommandCMSMerge final : public Commander { } private: - std::string destination_; // Changed to std::string + Slice destination_; int num_keys_; - std::vector src_keys_; // Changed to std::vector + std::vector src_keys_; std::vector src_weights_; }; diff --git a/src/types/cms.cc b/src/types/cms.cc index 4bbea0df4ed..8301b12554a 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -39,7 +39,7 @@ size_t CMSketch::IncrBy(std::string_view item, uint32_t value) { for (size_t i = 0; i < depth_; ++i) { uint64_t hash = XXH32(item.data(), static_cast(item.size()), i); - size_t loc = GetLocation(hash, i); + size_t loc = GetLocationForHash(hash, i); if (array_[loc] > UINT32_MAX - value) { array_[loc] = UINT32_MAX; } else { @@ -56,38 +56,36 @@ size_t CMSketch::Query(std::string_view item) const { for (size_t i = 0; i < depth_; ++i) { uint64_t hash = XXH32(item.data(), static_cast(item.size()), i); - min_count = std::min(min_count, static_cast(array_[GetLocation(hash, i)])); + min_count = std::min(min_count, static_cast(array_[GetLocationForHash(hash, i)])); } return min_count; } -int CMSketch::Merge(CMSketch* dest, size_t quantity, const std::vector& src, - const std::vector& weights) { - if (checkOverflow(dest, quantity, src, weights) != 0) { - return -1; +Status CMSketch::Merge(const CMSketch::MergeParams& params) { + // Perform overflow check + if (checkOverflow(params.dest, params.num_keys, params.cms_array, params.weights) != 0) { + return {Status::NotOK, "Overflow error."}; } - for (size_t i = 0; i < dest->GetDepth(); ++i) { - for (size_t j = 0; j < dest->GetWidth(); ++j) { + size_t dest_depth = params.dest->GetDepth(); + size_t dest_width = params.dest->GetWidth(); + + // Merge source CMSes into the destination CMS + for (size_t i = 0; i < dest_depth; ++i) { + for (size_t j = 0; j < dest_width; ++j) { int64_t item_count = 0; - // Sum the weighted counts from all source CMSes - for (size_t k = 0; k < quantity; ++k) { - item_count += static_cast(src[k]->array_[(i * dest->GetWidth()) + j]) * weights[k]; + for (size_t k = 0; k < params.num_keys; ++k) { + item_count += static_cast(params.cms_array[k]->array_[(i * dest_width) + j]) * params.weights[k]; } - // accumulates the weighted sum into the destination CMS's array - dest->GetArray()[(i * dest->GetWidth()) + j] += static_cast(item_count); + params.dest->GetArray()[(i * dest_width) + j] += static_cast(item_count); } } - for (size_t i = 0; i < quantity; ++i) { - dest->GetCounter() += src[i]->GetCounter() * weights[i]; + for (size_t i = 0; i < params.num_keys; ++i) { + params.dest->GetCounter() += params.cms_array[i]->GetCounter() * params.weights[i]; } - return 0; -} - -int CMSMergeParams(const CMSketch::MergeParams& params) { - return CMSketch::Merge(params.dest, params.num_keys, params.cms_array, params.weights); + return Status::OK(); } int CMSketch::checkOverflow(CMSketch* dest, size_t quantity, const std::vector& src, diff --git a/src/types/cms.h b/src/types/cms.h index f1bdb70ee9f..95ee0e65a39 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -50,9 +50,6 @@ class CMSketch { size_t Query(std::string_view item) const; - static int Merge(CMSketch* dest, size_t quantity, const std::vector& src, - const std::vector& weights); - struct MergeParams { CMSketch* dest; size_t num_keys; @@ -60,9 +57,9 @@ class CMSketch { std::vector weights; }; - int CMSMergeParams(const MergeParams& params); + static Status Merge(const MergeParams& params); - size_t GetLocation(uint64_t hash, size_t i) const { return (hash % width_) + (i * width_); } + size_t GetLocationForHash(uint64_t hash, size_t i) const { return (hash % width_) + (i * width_); } uint64_t& GetCounter() { return counter_; } std::vector& GetArray() { return array_; } diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index b3f2a4f84a8..976a7b91a0c 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -42,7 +42,7 @@ rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, if (s.IsNotFound()) { return rocksdb::Status::NotFound(); } - if (!s.ok() && !s.IsNotFound()) { + if (!s.ok()) { return s; } @@ -52,11 +52,10 @@ rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, CMSketch cms(metadata.width, metadata.depth, metadata.counter, metadata.array); - if (elements.empty()) { - return rocksdb::Status::OK(); - } - for (auto &element : elements) { + if (element.second > 0 && metadata.counter > std::numeric_limits::max() - element.second) { + return rocksdb::Status::InvalidArgument("Overflow error: IncrBy would result in counter overflow"); + } cms.IncrBy(element.first, element.second); metadata.counter += element.second; } @@ -77,7 +76,7 @@ rocksdb::Status CMS::Info(engine::Context &ctx, const Slice &user_key, CMSketch: CountMinSketchMetadata metadata{}; rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); - if (!s.ok() || s.IsNotFound()) { + if (!s.ok()) { return rocksdb::Status::NotFound(); } @@ -96,6 +95,10 @@ rocksdb::Status CMS::InitByDim(engine::Context &ctx, const Slice &user_key, uint rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); + if (s.ok()) { + return rocksdb::Status::InvalidArgument("Key already exists."); + } + if (!s.IsNotFound()) { return s; } @@ -123,22 +126,11 @@ rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, dou if (delta <= 0 || delta >= 1) { return rocksdb::Status::InvalidArgument("Delta must be between 0 and 1 (exclusive)."); } - - std::string ns_key = AppendNamespacePrefix(user_key); - - LockGuard guard(storage_->GetLockManager(), ns_key); - CountMinSketchMetadata metadata{}; - - rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); - if (!s.IsNotFound()) { + CMSketch::CMSketchDimensions dim = CMSketch::CMSDimFromProb(error, delta); + auto s = InitByDim(ctx, user_key, dim.width, dim.depth); + if (!s.ok()) { return s; } - auto batch = storage_->GetWriteBatchBase(); - WriteBatchLogData log_data(kRedisCountMinSketch); - batch->PutLogData(log_data.Encode()); - - CMSketch cms{0, 0, 0, {}}; - CMSketch::CMSketchDimensions dim = cms.CMSDimFromProb(error, delta); size_t memory_used = dim.width * dim.depth * sizeof(uint32_t); const size_t max_memory = 50 * 1024 * 1024; @@ -149,17 +141,7 @@ rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, dou if (memory_used > max_memory) { return rocksdb::Status::InvalidArgument("Memory usage exceeds 50MB."); } - - metadata.width = dim.width; - metadata.depth = dim.depth; - metadata.counter = 0; - metadata.array = std::vector(dim.width * dim.depth, 0); - - std::string bytes; - metadata.Encode(&bytes); - batch->Put(metadata_cf_handle_, ns_key, bytes); - - return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + return rocksdb::Status::OK(); }; rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, const std::vector &src_keys, @@ -190,12 +172,11 @@ rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, src_cms_objects.reserve(num_sources); std::vector src_cms_pointers; src_cms_pointers.reserve(num_sources); - std::vector weights_long; + std::vector weights_long; weights_long.reserve(num_sources); for (size_t i = 0; i < num_sources; ++i) { std::string src_ns_key = AppendNamespacePrefix(src_keys[i]); - LOG(INFO) << "Dest Key: " << dest_ns_key << " | Source Key: " << src_ns_key; LockGuard guard(storage_->GetLockManager(), src_ns_key); CountMinSketchMetadata src_metadata{}; @@ -218,8 +199,14 @@ rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, weights_long.push_back(static_cast(src_weights[i])); } - int merge_result = CMSketch::Merge(&dest_cms, num_sources, src_cms_pointers, weights_long); - if (merge_result != 0) { + CMSketch::MergeParams merge_params; + merge_params.dest = &dest_cms; + merge_params.num_keys = num_sources; + merge_params.cms_array = src_cms_pointers; + merge_params.weights = weights_long; + + auto merge_result = CMSketch::Merge(merge_params); + if (!merge_result.IsOK()) { return rocksdb::Status::InvalidArgument("Merge operation failed due to overflow or invalid dimensions."); } @@ -240,6 +227,7 @@ rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, std::vector &counters) { std::string ns_key = AppendNamespacePrefix(user_key); + counters.assign(elements.size(), 0); LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; @@ -247,7 +235,6 @@ rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const st rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); if (s.IsNotFound()) { - counters.assign(elements.size(), 0); return rocksdb::Status::NotFound(); } else if (!s.ok()) { return s; From 0cdfb0751b6405c89ed313f2dc6b64e1381583b9 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sun, 29 Sep 2024 01:11:17 -0400 Subject: [PATCH 16/25] Small Changes --- src/commands/cmd_cms.cc | 1 + src/types/cms.cc | 2 +- src/types/redis_cms.cc | 2 +- src/types/redis_cms.h | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc index e10b7763ef8..433e6bac3a2 100644 --- a/src/commands/cmd_cms.cc +++ b/src/commands/cmd_cms.cc @@ -201,6 +201,7 @@ class CommandCMSMerge final : public Commander { redis::CMS cms(srv->storage, conn->GetNamespace()); engine::Context ctx(srv->storage); + rocksdb::Status s = cms.MergeUserKeys(ctx, destination_, src_keys_, src_weights_); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; diff --git a/src/types/cms.cc b/src/types/cms.cc index 8301b12554a..a8d621088d5 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -20,7 +20,7 @@ #include "cms.h" -#include +#include "xxhash.h" #include #include diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index 976a7b91a0c..acfc5c765c6 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -77,7 +77,7 @@ rocksdb::Status CMS::Info(engine::Context &ctx, const Slice &user_key, CMSketch: rocksdb::Status s = GetMetadata(ctx, ns_key, &metadata); if (!s.ok()) { - return rocksdb::Status::NotFound(); + return s; } ret->width = metadata.width; diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h index d499ed709c8..2071256bf6b 100644 --- a/src/types/redis_cms.h +++ b/src/types/redis_cms.h @@ -36,7 +36,7 @@ class CMS : public Database { rocksdb::Status InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth); rocksdb::Status InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta); rocksdb::Status Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, - std::vector *counters); + std::vector &counters); rocksdb::Status MergeUserKeys(engine::Context &ctx, const Slice &user_key, const std::vector &src_keys, const std::vector &src_weights); From 02347df2b07d1072680380814d3b54ca45c3f33f Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sun, 29 Sep 2024 01:37:58 -0400 Subject: [PATCH 17/25] logic fix --- src/commands/cmd_cms.cc | 3 +-- src/types/redis_cms.cc | 7 ++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc index 433e6bac3a2..8e3426a14e2 100644 --- a/src/commands/cmd_cms.cc +++ b/src/commands/cmd_cms.cc @@ -201,7 +201,6 @@ class CommandCMSMerge final : public Commander { redis::CMS cms(srv->storage, conn->GetNamespace()); engine::Context ctx(srv->storage); - rocksdb::Status s = cms.MergeUserKeys(ctx, destination_, src_keys_, src_weights_); if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; @@ -242,7 +241,7 @@ class CommandCMSQuery final : public Commander { std::vector output_values; output_values.reserve(counters.size()); for (const auto &counter : counters) { - output_values.push_back(std::to_string(counter)); + output_values.emplace_back(std::to_string(counter)); } *output = redis::ArrayOfBulkStrings(output_values); diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index acfc5c765c6..00c4d929c96 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -227,7 +227,7 @@ rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const std::vector &elements, std::vector &counters) { std::string ns_key = AppendNamespacePrefix(user_key); - counters.assign(elements.size(), 0); + counters.resize(elements.size(), 0); LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; @@ -242,9 +242,10 @@ rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const st CMSketch cms(metadata.width, metadata.depth, metadata.counter, metadata.array); - for (auto &element : elements) { - counters.push_back(cms.Query(element)); + for (size_t i = 0; i < elements.size(); ++i) { + counters[i] = cms.Query(elements[i]); } + return rocksdb::Status::OK(); }; From 606f9b5568fd0df5976e9e81d17578f51d06f371 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sun, 29 Sep 2024 01:50:21 -0400 Subject: [PATCH 18/25] lint --- src/types/cms.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/types/cms.cc b/src/types/cms.cc index a8d621088d5..5f1a9c6924b 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -20,13 +20,13 @@ #include "cms.h" -#include "xxhash.h" - #include #include #include #include +#include "xxhash.h" + CMSketch::CMSketchDimensions CMSketch::CMSDimFromProb(double error, double delta) { CMSketchDimensions dims; dims.width = std::ceil(2 / error); From 0445b3dd5bcb558849890618c50391d5963fcc4a Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sun, 29 Sep 2024 12:35:16 -0400 Subject: [PATCH 19/25] Quick Fixes --- src/types/cms.cc | 21 +++++++++++---------- src/types/cms.h | 9 +++++---- src/types/redis_cms.cc | 18 +++++++++++------- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/types/cms.cc b/src/types/cms.cc index 5f1a9c6924b..c9eecc5f7e3 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -61,34 +61,35 @@ size_t CMSketch::Query(std::string_view item) const { return min_count; } -Status CMSketch::Merge(const CMSketch::MergeParams& params) { +Status CMSketch::Merge(CMSketch* dest, size_t num_keys, std::vector cms_array, + std::vector weights) { // Perform overflow check - if (checkOverflow(params.dest, params.num_keys, params.cms_array, params.weights) != 0) { + if (CMSketch::CheckOverflow(dest, num_keys, cms_array, weights) != 0) { return {Status::NotOK, "Overflow error."}; } - size_t dest_depth = params.dest->GetDepth(); - size_t dest_width = params.dest->GetWidth(); + size_t dest_depth = dest->GetDepth(); + size_t dest_width = dest->GetWidth(); // Merge source CMSes into the destination CMS for (size_t i = 0; i < dest_depth; ++i) { for (size_t j = 0; j < dest_width; ++j) { int64_t item_count = 0; - for (size_t k = 0; k < params.num_keys; ++k) { - item_count += static_cast(params.cms_array[k]->array_[(i * dest_width) + j]) * params.weights[k]; + for (size_t k = 0; k < num_keys; ++k) { + item_count += static_cast(cms_array[k]->array_[(i * dest_width) + j]) * weights[k]; } - params.dest->GetArray()[(i * dest_width) + j] += static_cast(item_count); + dest->GetArray()[(i * dest_width) + j] += static_cast(item_count); } } - for (size_t i = 0; i < params.num_keys; ++i) { - params.dest->GetCounter() += params.cms_array[i]->GetCounter() * params.weights[i]; + for (size_t i = 0; i < num_keys; ++i) { + dest->GetCounter() += cms_array[i]->GetCounter() * weights[i]; } return Status::OK(); } -int CMSketch::checkOverflow(CMSketch* dest, size_t quantity, const std::vector& src, +int CMSketch::CheckOverflow(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights) { int64_t item_count = 0; int64_t cms_count = 0; diff --git a/src/types/cms.h b/src/types/cms.h index 95ee0e65a39..b750bb02380 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -57,7 +57,8 @@ class CMSketch { std::vector weights; }; - static Status Merge(const MergeParams& params); + static Status Merge(CMSketch* dest, size_t num_keys, std::vector cms_array, + std::vector weights); size_t GetLocationForHash(uint64_t hash, size_t i) const { return (hash % width_) + (i * width_); } @@ -70,12 +71,12 @@ class CMSketch { uint32_t GetWidth() const { return width_; } uint32_t GetDepth() const { return depth_; } + static int CheckOverflow(CMSketch* dest, size_t quantity, const std::vector& src, + const std::vector& weights); + private: uint32_t width_; uint32_t depth_; uint64_t counter_; std::vector array_; - - static int checkOverflow(CMSketch* dest, size_t quantity, const std::vector& src, - const std::vector& weights); }; \ No newline at end of file diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index 00c4d929c96..9f8f7ea28fd 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -90,6 +90,16 @@ rocksdb::Status CMS::Info(engine::Context &ctx, const Slice &user_key, CMSketch: rocksdb::Status CMS::InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth) { std::string ns_key = AppendNamespacePrefix(user_key); + size_t memory_used = width * depth * sizeof(uint32_t); + const size_t max_memory = 50 * 1024 * 1024; + + if (memory_used == 0) { + return rocksdb::Status::InvalidArgument("Memory usage must be greater than 0."); + } + if (memory_used > max_memory) { + return rocksdb::Status::InvalidArgument("Memory usage exceeds 50MB."); + } + LockGuard guard(storage_->GetLockManager(), ns_key); CountMinSketchMetadata metadata{}; @@ -199,13 +209,7 @@ rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, weights_long.push_back(static_cast(src_weights[i])); } - CMSketch::MergeParams merge_params; - merge_params.dest = &dest_cms; - merge_params.num_keys = num_sources; - merge_params.cms_array = src_cms_pointers; - merge_params.weights = weights_long; - - auto merge_result = CMSketch::Merge(merge_params); + auto merge_result = CMSketch::Merge(&dest_cms, num_sources, src_cms_pointers, weights_long); if (!merge_result.IsOK()) { return rocksdb::Status::InvalidArgument("Merge operation failed due to overflow or invalid dimensions."); } From 48e9bc25f1c64f4ceb6f886976cecb657932d352 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sun, 29 Sep 2024 18:50:06 -0400 Subject: [PATCH 20/25] lint fix --- src/commands/command_parser.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/commands/command_parser.h b/src/commands/command_parser.h index 673a733e362..a4e06e1c563 100644 --- a/src/commands/command_parser.h +++ b/src/commands/command_parser.h @@ -33,7 +33,7 @@ template struct MoveIterator : Iter { - explicit MoveIterator(Iter iter) : Iter(iter) {}; + explicit MoveIterator(Iter iter) : Iter(iter){}; typename Iter::value_type&& operator*() const { return std::move(this->Iter::operator*()); } }; From f57714693271dd0cce2541b7861453a051730077 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Mon, 30 Sep 2024 00:33:20 -0400 Subject: [PATCH 21/25] typing fixes --- src/types/cms.cc | 4 ++-- src/types/cms.h | 6 +++--- src/types/redis_cms.cc | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/types/cms.cc b/src/types/cms.cc index c9eecc5f7e3..05d118a149b 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -62,7 +62,7 @@ size_t CMSketch::Query(std::string_view item) const { } Status CMSketch::Merge(CMSketch* dest, size_t num_keys, std::vector cms_array, - std::vector weights) { + std::vector weights) { // Perform overflow check if (CMSketch::CheckOverflow(dest, num_keys, cms_array, weights) != 0) { return {Status::NotOK, "Overflow error."}; @@ -90,7 +90,7 @@ Status CMSketch::Merge(CMSketch* dest, size_t num_keys, std::vector& src, - const std::vector& weights) { + const std::vector& weights) { int64_t item_count = 0; int64_t cms_count = 0; size_t width = dest->GetWidth(); diff --git a/src/types/cms.h b/src/types/cms.h index b750bb02380..3d49532dea8 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -54,11 +54,11 @@ class CMSketch { CMSketch* dest; size_t num_keys; std::vector cms_array; - std::vector weights; + std::vector weights; }; static Status Merge(CMSketch* dest, size_t num_keys, std::vector cms_array, - std::vector weights); + std::vector weights); size_t GetLocationForHash(uint64_t hash, size_t i) const { return (hash % width_) + (i * width_); } @@ -72,7 +72,7 @@ class CMSketch { uint32_t GetDepth() const { return depth_; } static int CheckOverflow(CMSketch* dest, size_t quantity, const std::vector& src, - const std::vector& weights); + const std::vector& weights); private: uint32_t width_; diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index 9f8f7ea28fd..c8be0e71ef4 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -182,7 +182,7 @@ rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, src_cms_objects.reserve(num_sources); std::vector src_cms_pointers; src_cms_pointers.reserve(num_sources); - std::vector weights_long; + std::vector weights_long; weights_long.reserve(num_sources); for (size_t i = 0; i < num_sources; ++i) { @@ -206,7 +206,7 @@ rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, src_cms_objects.emplace_back(std::move(src_cms)); src_cms_pointers.push_back(&src_cms_objects.back()); - weights_long.push_back(static_cast(src_weights[i])); + weights_long.push_back(static_cast(src_weights[i])); } auto merge_result = CMSketch::Merge(&dest_cms, num_sources, src_cms_pointers, weights_long); From ab277e93ac3ab3f753cdaf7d97bd77b3f3e3df14 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Mon, 30 Sep 2024 08:49:18 -0400 Subject: [PATCH 22/25] one mb limit update per key --- src/types/cms.h | 7 ------- src/types/redis_cms.cc | 4 ++-- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/types/cms.h b/src/types/cms.h index 3d49532dea8..09a49668ea6 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -50,13 +50,6 @@ class CMSketch { size_t Query(std::string_view item) const; - struct MergeParams { - CMSketch* dest; - size_t num_keys; - std::vector cms_array; - std::vector weights; - }; - static Status Merge(CMSketch* dest, size_t num_keys, std::vector cms_array, std::vector weights); diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index c8be0e71ef4..5fd4b601fbe 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -91,7 +91,7 @@ rocksdb::Status CMS::InitByDim(engine::Context &ctx, const Slice &user_key, uint std::string ns_key = AppendNamespacePrefix(user_key); size_t memory_used = width * depth * sizeof(uint32_t); - const size_t max_memory = 50 * 1024 * 1024; + const size_t max_memory = 1 * 1024 * 1024; if (memory_used == 0) { return rocksdb::Status::InvalidArgument("Memory usage must be greater than 0."); @@ -143,7 +143,7 @@ rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, dou } size_t memory_used = dim.width * dim.depth * sizeof(uint32_t); - const size_t max_memory = 50 * 1024 * 1024; + const size_t max_memory = 1 * 1024 * 1024; if (memory_used == 0) { return rocksdb::Status::InvalidArgument("Memory usage must be greater than 0."); From f2f1288e8111667f5cf84f81a5c770df498ce4c1 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Mon, 30 Sep 2024 10:18:49 -0400 Subject: [PATCH 23/25] Update src/types/redis_cms.cc Co-authored-by: Twice --- src/types/redis_cms.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index 5fd4b601fbe..31d45515546 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -97,7 +97,7 @@ rocksdb::Status CMS::InitByDim(engine::Context &ctx, const Slice &user_key, uint return rocksdb::Status::InvalidArgument("Memory usage must be greater than 0."); } if (memory_used > max_memory) { - return rocksdb::Status::InvalidArgument("Memory usage exceeds 50MB."); + return rocksdb::Status::InvalidArgument("Memory usage exceeds 1MB."); } LockGuard guard(storage_->GetLockManager(), ns_key); From ca60e13019d4beead3e45cb57fa8eb0ad3835984 Mon Sep 17 00:00:00 2001 From: Jonathan Chen <86070045+jonathanc-n@users.noreply.github.com> Date: Mon, 30 Sep 2024 10:19:06 -0400 Subject: [PATCH 24/25] Update src/types/redis_cms.cc Co-authored-by: Twice --- src/types/redis_cms.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index 31d45515546..f3fa77b68f8 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -149,7 +149,7 @@ rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, dou return rocksdb::Status::InvalidArgument("Memory usage must be greater than 0."); } if (memory_used > max_memory) { - return rocksdb::Status::InvalidArgument("Memory usage exceeds 50MB."); + return rocksdb::Status::InvalidArgument("Memory usage exceeds 1MB."); } return rocksdb::Status::OK(); }; From df7a7b22d973d26c6a04f7f11d6d675de713bd69 Mon Sep 17 00:00:00 2001 From: mwish Date: Mon, 30 Sep 2024 23:53:06 +0800 Subject: [PATCH 25/25] [WIP] Some codereview check 1. Add just IncrBy syntax (wip) 2. Extract a hash function rather than explicit xxh 3. Fix a bug in merge --- src/commands/cmd_cms.cc | 22 ++++++++----- src/types/cms.cc | 25 +++++++++------ src/types/cms.h | 11 +++++-- src/types/redis_cms.cc | 68 ++++++++++++++++++----------------------- src/types/redis_cms.h | 15 +++++++-- 5 files changed, 81 insertions(+), 60 deletions(-) diff --git a/src/commands/cmd_cms.cc b/src/commands/cmd_cms.cc index 8e3426a14e2..79e0bce4202 100644 --- a/src/commands/cmd_cms.cc +++ b/src/commands/cmd_cms.cc @@ -30,6 +30,9 @@ namespace redis { /// CMS.INCRBY key item increment [item increment ...] +/// +/// The `key` should be an existing Count-Min Sketch key, +/// otherwise, the command will return an error. class CommandCMSIncrBy final : public Commander { public: Status Execute(Server *srv, Connection *conn, std::string *output) override { @@ -39,18 +42,20 @@ class CommandCMSIncrBy final : public Commander { redis::CMS cms(srv->storage, conn->GetNamespace()); engine::Context ctx(srv->storage); rocksdb::Status s; - std::unordered_map elements; + // pairs + std::vector elements; + elements.reserve((args_.size() - 2) / 2); for (size_t i = 2; i < args_.size(); i += 2) { - std::string key = args_[i]; - auto parse_result = ParseInt(args_[i + 1]); + std::string_view key = args_[i]; + auto parse_result = ParseInt(args_[i + 1]); if (!parse_result) { return {Status::RedisParseErr, errValueNotInteger}; } - uint64_t value = *parse_result; - elements[key] = value; + int64_t value = *parse_result; + elements.emplace_back(redis::CMS::IncrByPair{key, value}); } - - s = cms.IncrBy(ctx, args_[1], elements); + std::vector counters; + s = cms.IncrBy(ctx, args_[1], elements, &counters); if (s.IsNotFound()) { return {Status::RedisExecErr, "Key not found"}; } @@ -58,6 +63,7 @@ class CommandCMSIncrBy final : public Commander { return {Status::RedisExecErr, s.ToString()}; } + // TODO(mwish): adjust the output *output = redis::SimpleString("OK"); return Status::OK(); } @@ -78,7 +84,7 @@ class CommandCMSInfo final : public Commander { return {Status::RedisExecErr, "Key not found"}; } - if (!s.ok() && !s.IsNotFound()) { + if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } diff --git a/src/types/cms.cc b/src/types/cms.cc index 05d118a149b..ba56b17bf4c 100644 --- a/src/types/cms.cc +++ b/src/types/cms.cc @@ -27,6 +27,10 @@ #include "xxhash.h" +uint64_t CMSketch::CountMinSketchHash(std::string_view item, uint64_t seed) { + return XXH64(item.data(), item.size(), seed); +} + CMSketch::CMSketchDimensions CMSketch::CMSDimFromProb(double error, double delta) { CMSketchDimensions dims; dims.width = std::ceil(2 / error); @@ -34,29 +38,30 @@ CMSketch::CMSketchDimensions CMSketch::CMSDimFromProb(double error, double delta return dims; } -size_t CMSketch::IncrBy(std::string_view item, uint32_t value) { - size_t min_count = std::numeric_limits::max(); +uint32_t CMSketch::IncrBy(std::string_view item, uint32_t value) { + uint32_t min_count = std::numeric_limits::max(); for (size_t i = 0; i < depth_; ++i) { - uint64_t hash = XXH32(item.data(), static_cast(item.size()), i); + uint64_t hash = CountMinSketchHash(item, /*seed=*/i); size_t loc = GetLocationForHash(hash, i); - if (array_[loc] > UINT32_MAX - value) { - array_[loc] = UINT32_MAX; + // Do overflow check + if (array_[loc] > std::numeric_limits::max() - value) { + array_[loc] = std::numeric_limits::max(); } else { array_[loc] += value; } - min_count = std::min(min_count, static_cast(array_[loc])); + min_count = std::min(min_count, array_[loc]); } counter_ += value; return min_count; } -size_t CMSketch::Query(std::string_view item) const { - size_t min_count = std::numeric_limits::max(); +uint32_t CMSketch::Query(std::string_view item) const { + uint32_t min_count = std::numeric_limits::max(); for (size_t i = 0; i < depth_; ++i) { - uint64_t hash = XXH32(item.data(), static_cast(item.size()), i); - min_count = std::min(min_count, static_cast(array_[GetLocationForHash(hash, i)])); + uint64_t hash = CountMinSketchHash(item, /*seed=*/i); + min_count = std::min(min_count, array_[GetLocationForHash(hash, i)]); } return min_count; } diff --git a/src/types/cms.h b/src/types/cms.h index 09a49668ea6..4e3cd467af4 100644 --- a/src/types/cms.h +++ b/src/types/cms.h @@ -46,9 +46,13 @@ class CMSketch { static CMSketchDimensions CMSDimFromProb(double error, double delta); - size_t IncrBy(std::string_view item, uint32_t value); + /// Increment the counter of the given item by the specified increment. + /// + /// \param item The item to increment. Returns UINT32_MAX if the + /// counter overflows. + uint32_t IncrBy(std::string_view item, uint32_t value); - size_t Query(std::string_view item) const; + uint32_t Query(std::string_view item) const; static Status Merge(CMSketch* dest, size_t num_keys, std::vector cms_array, std::vector weights); @@ -67,6 +71,9 @@ class CMSketch { static int CheckOverflow(CMSketch* dest, size_t quantity, const std::vector& src, const std::vector& weights); + private: + static uint64_t CountMinSketchHash(std::string_view item, uint64_t seed); + private: uint32_t width_; uint32_t depth_; diff --git a/src/types/redis_cms.cc b/src/types/redis_cms.cc index f3fa77b68f8..16897f3838b 100644 --- a/src/types/redis_cms.cc +++ b/src/types/redis_cms.cc @@ -31,8 +31,8 @@ rocksdb::Status CMS::GetMetadata(engine::Context &ctx, const Slice &ns_key, Coun return Database::GetMetadata(ctx, {kRedisCountMinSketch}, ns_key, metadata); } -rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, - const std::unordered_map &elements) { +rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, const std::vector &elements, + std::vector *counters) { std::string ns_key = AppendNamespacePrefix(user_key); LockGuard guard(storage_->GetLockManager(), ns_key); @@ -52,12 +52,12 @@ rocksdb::Status CMS::IncrBy(engine::Context &ctx, const Slice &user_key, CMSketch cms(metadata.width, metadata.depth, metadata.counter, metadata.array); - for (auto &element : elements) { - if (element.second > 0 && metadata.counter > std::numeric_limits::max() - element.second) { + for (const auto &element : elements) { + if (element.value > 0 && metadata.counter > std::numeric_limits::max() - element.value) { return rocksdb::Status::InvalidArgument("Overflow error: IncrBy would result in counter overflow"); } - cms.IncrBy(element.first, element.second); - metadata.counter += element.second; + uint32_t local_counter = cms.IncrBy(element.key, element.value); + metadata.counter += element.value; } metadata.array = std::move(cms.GetArray()); @@ -83,20 +83,20 @@ rocksdb::Status CMS::Info(engine::Context &ctx, const Slice &user_key, CMSketch: ret->width = metadata.width; ret->depth = metadata.depth; ret->count = metadata.counter; - return rocksdb::Status::OK(); -}; +} rocksdb::Status CMS::InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth) { std::string ns_key = AppendNamespacePrefix(user_key); size_t memory_used = width * depth * sizeof(uint32_t); - const size_t max_memory = 1 * 1024 * 1024; + // We firstly limit the memory usage to 1MB. + constexpr size_t kMaxMemory = 1 * 1024 * 1024; if (memory_used == 0) { return rocksdb::Status::InvalidArgument("Memory usage must be greater than 0."); } - if (memory_used > max_memory) { + if (memory_used > kMaxMemory) { return rocksdb::Status::InvalidArgument("Memory usage exceeds 1MB."); } @@ -137,21 +137,7 @@ rocksdb::Status CMS::InitByProb(engine::Context &ctx, const Slice &user_key, dou return rocksdb::Status::InvalidArgument("Delta must be between 0 and 1 (exclusive)."); } CMSketch::CMSketchDimensions dim = CMSketch::CMSDimFromProb(error, delta); - auto s = InitByDim(ctx, user_key, dim.width, dim.depth); - if (!s.ok()) { - return s; - } - - size_t memory_used = dim.width * dim.depth * sizeof(uint32_t); - const size_t max_memory = 1 * 1024 * 1024; - - if (memory_used == 0) { - return rocksdb::Status::InvalidArgument("Memory usage must be greater than 0."); - } - if (memory_used > max_memory) { - return rocksdb::Status::InvalidArgument("Memory usage exceeds 1MB."); - } - return rocksdb::Status::OK(); + return InitByDim(ctx, user_key, dim.width, dim.depth); }; rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, const std::vector &src_keys, @@ -165,14 +151,18 @@ rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, } std::string dest_ns_key = AppendNamespacePrefix(user_key); - LockGuard guard(storage_->GetLockManager(), dest_ns_key); + std::vector ns_keys{dest_ns_key}; + for (const auto &src_key : src_keys) { + ns_keys.emplace_back(AppendNamespacePrefix(src_key)); + } + MultiLockGuard guard(storage_->GetLockManager(), ns_keys); CountMinSketchMetadata dest_metadata{}; rocksdb::Status dest_status = GetMetadata(ctx, dest_ns_key, &dest_metadata); + if (dest_status.IsNotFound()) { + return rocksdb::Status::InvalidArgument("Destination CMS does not exist."); + } if (!dest_status.ok()) { - if (dest_status.IsNotFound()) { - return rocksdb::Status::InvalidArgument("Destination CMS does not exist."); - } return dest_status; } @@ -180,18 +170,15 @@ rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, std::vector src_cms_objects; src_cms_objects.reserve(num_sources); - std::vector src_cms_pointers; - src_cms_pointers.reserve(num_sources); std::vector weights_long; weights_long.reserve(num_sources); for (size_t i = 0; i < num_sources; ++i) { - std::string src_ns_key = AppendNamespacePrefix(src_keys[i]); - LockGuard guard(storage_->GetLockManager(), src_ns_key); - + const auto &src_ns_key = ns_keys[i + 1]; CountMinSketchMetadata src_metadata{}; rocksdb::Status src_status = GetMetadata(ctx, src_ns_key, &src_metadata); if (!src_status.ok()) { + // TODO(mwish): check the not found syntax here. if (src_status.IsNotFound()) { return rocksdb::Status::InvalidArgument("Source CMS key not found."); } @@ -204,11 +191,15 @@ rocksdb::Status CMS::MergeUserKeys(engine::Context &ctx, const Slice &user_key, CMSketch src_cms(src_metadata.width, src_metadata.depth, src_metadata.counter, src_metadata.array); src_cms_objects.emplace_back(std::move(src_cms)); - src_cms_pointers.push_back(&src_cms_objects.back()); weights_long.push_back(static_cast(src_weights[i])); } - + // Initialize the destination CMS with the source CMSes after initializations + // since vector might resize and reallocate memory. + std::vector src_cms_pointers(num_sources); + for (size_t i = 0; i < num_sources; ++i) { + src_cms_pointers[i] = &src_cms_objects[i]; + } auto merge_result = CMSketch::Merge(&dest_cms, num_sources, src_cms_pointers, weights_long); if (!merge_result.IsOK()) { return rocksdb::Status::InvalidArgument("Merge operation failed due to overflow or invalid dimensions."); @@ -240,7 +231,8 @@ rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const st if (s.IsNotFound()) { return rocksdb::Status::NotFound(); - } else if (!s.ok()) { + } + if (!s.ok()) { return s; } @@ -251,6 +243,6 @@ rocksdb::Status CMS::Query(engine::Context &ctx, const Slice &user_key, const st } return rocksdb::Status::OK(); -}; +} } // namespace redis diff --git a/src/types/redis_cms.h b/src/types/redis_cms.h index 2071256bf6b..f8fff5e21a0 100644 --- a/src/types/redis_cms.h +++ b/src/types/redis_cms.h @@ -20,6 +20,8 @@ #pragma once +#include + #include "cms.h" #include "storage/redis_db.h" #include "storage/redis_metadata.h" @@ -30,8 +32,17 @@ class CMS : public Database { public: explicit CMS(engine::Storage *storage, const std::string &ns) : Database(storage, ns) {} - rocksdb::Status IncrBy(engine::Context &ctx, const Slice &user_key, - const std::unordered_map &elements); + struct IncrByPair { + std::string_view key; + int64_t value; + }; + + /// Increment the counter of the given item(s) by the specified increment(s). + /// + /// \param[out] counters The counter values after the increment, if the value is UINT32_MAX, + /// it means the item does overflow. + rocksdb::Status IncrBy(engine::Context &ctx, const Slice &user_key, const std::vector &elements, + std::vector *counters); rocksdb::Status Info(engine::Context &ctx, const Slice &user_key, CMSketch::CMSInfo *ret); rocksdb::Status InitByDim(engine::Context &ctx, const Slice &user_key, uint32_t width, uint32_t depth); rocksdb::Status InitByProb(engine::Context &ctx, const Slice &user_key, double error, double delta);