diff --git a/.github/workflows/kvrocks.yaml b/.github/workflows/kvrocks.yaml index 25e47cf38bb..84f1a1ab6da 100644 --- a/.github/workflows/kvrocks.yaml +++ b/.github/workflows/kvrocks.yaml @@ -127,6 +127,10 @@ jobs: - name: Ubuntu GCC os: ubuntu-20.04 compiler: gcc + - name: SonarCloud with Coverage + os: ubuntu-22.04 + compiler: gcc + sonarcloud: -DCMAKE_CXX_FLAGS=--coverage - name: Ubuntu Clang os: ubuntu-20.04 compiler: clang @@ -192,6 +196,8 @@ jobs: with_speedb: -DENABLE_SPEEDB=ON runs-on: ${{ matrix.os }} + env: + SONARCLOUD_OUTPUT_DIR: sonarcloud-data steps: - name: Setup macOS if: ${{ startsWith(matrix.os, 'macos') }} @@ -222,6 +228,8 @@ jobs: pushd redis-6.2.7 && BUILD_TLS=yes make -j$NPROC redis-cli && mv src/redis-cli $HOME/local/bin/ && popd - uses: actions/checkout@v3 + with: + fetch-depth: 0 - uses: actions/setup-python@v4 with: python-version: 3.x @@ -229,10 +237,25 @@ jobs: with: go-version-file: 'tests/gocase/go.mod' + - name: Install gcovr 5.0 + run: pip install gcovr==5.0 # 5.1 is not supported + if: ${{ matrix.sonarcloud }} + + - name: Install sonar-scanner and build-wrapper + uses: SonarSource/sonarcloud-github-c-cpp@v2 + if: ${{ matrix.sonarcloud }} + - name: Build Kvrocks + if: ${{ !matrix.sonarcloud }} run: | - ./x.py build -j$NPROC --unittest --compiler ${{ matrix.compiler }} ${{ matrix.without_jemalloc }} ${{ matrix.without_luajit }} \ - ${{ matrix.with_ninja }} ${{ matrix.with_sanitizer }} ${{ matrix.with_openssl }} ${{ matrix.new_encoding }} ${{ matrix.with_speedb }} ${{ env.CMAKE_EXTRA_DEFS }} + ./x.py build -j$NPROC --unittest --compiler ${{ matrix.compiler }} ${{ matrix.without_jemalloc }} \ + ${{ matrix.without_luajit }} ${{ matrix.with_ninja }} ${{ matrix.with_sanitizer }} ${{ matrix.with_openssl }} \ + ${{ matrix.new_encoding }} ${{ matrix.with_speedb }} ${{ env.CMAKE_EXTRA_DEFS }} + + - name: Build Kvrocks (SonarCloud) + if: ${{ matrix.sonarcloud }} + run: | + build-wrapper-linux-x86-64 --out-dir ${{ env.SONARCLOUD_OUTPUT_DIR }} 
./x.py build -j$NPROC --unittest --compiler ${{ matrix.compiler }} ${{ matrix.sonarcloud }} - name: Setup Coredump if: ${{ startsWith(matrix.os, 'ubuntu') }} @@ -292,6 +315,25 @@ jobs: path: | ./build/kvrocks ./coredumps/* + + - name: Collect coverage into one XML report + if: ${{ matrix.sonarcloud }} + run: | + gcovr --sonarqube > ${{ env.SONARCLOUD_OUTPUT_DIR }}/coverage.xml + + - name: Add event information + if: ${{ matrix.sonarcloud }} + env: + GITHUB_EVENT_JSON: ${{ toJson(github.event) }} + run: | + echo "$GITHUB_EVENT_JSON" | tee ${{ env.SONARCLOUD_OUTPUT_DIR }}/github-event.json + + - name: Upload SonarCloud data + if: ${{ matrix.sonarcloud }} + uses: actions/upload-artifact@v3 + with: + name: sonarcloud-data + path: ${{ env.SONARCLOUD_OUTPUT_DIR }} check-docker: name: Check Docker image diff --git a/.github/workflows/sonar.yaml b/.github/workflows/sonar.yaml new file mode 100644 index 00000000000..36a2d2a0f39 --- /dev/null +++ b/.github/workflows/sonar.yaml @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +name: SonarCloud Analysis + +on: + workflow_run: + workflows: [CI] + types: [completed] + +jobs: + sonarcloud: + name: Upload to SonarCloud + runs-on: ubuntu-22.04 + if: github.event.workflow_run.conclusion == 'success' && github.repository_owner == 'apache' + steps: + - uses: actions/checkout@v3 + with: + repository: ${{ github.event.workflow_run.head_repository.full_name }} + ref: ${{ github.event.workflow_run.head_sha }} + fetch-depth: 0 + - name: Install sonar-scanner and build-wrapper + uses: SonarSource/sonarcloud-github-c-cpp@v2 + - name: 'Download code coverage' + uses: actions/github-script@v6 + with: + script: | + let allArtifacts = await github.rest.actions.listWorkflowRunArtifacts({ + owner: context.repo.owner, + repo: context.repo.repo, + run_id: context.payload.workflow_run.id, + }); + let matchArtifact = allArtifacts.data.artifacts.filter((artifact) => { + return artifact.name == "sonarcloud-data" + })[0]; + let download = await github.rest.actions.downloadArtifact({ + owner: context.repo.owner, + repo: context.repo.repo, + artifact_id: matchArtifact.id, + archive_format: 'zip', + }); + let fs = require('fs'); + fs.writeFileSync(`${process.env.GITHUB_WORKSPACE}/sonarcloud-data.zip`, Buffer.from(download.data)); + - name: 'Unzip code coverage' + run: | + unzip sonarcloud-data.zip -d sonarcloud-data + ls -a sonarcloud-data + + - uses: actions/setup-python@v4 + with: + python-version: 3.x + - name: Configure Kvrocks + run: | + ./x.py build -j$(nproc) --compiler gcc --skip-build + + - name: Run sonar-scanner + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + SONAR_TOKEN: ${{ secrets.SONARCLOUD_TOKEN }} + run: | + PR_NUMBER=$(jq -r '.number | select (.!=null)' sonarcloud-data/github-event.json) + echo "The PR number is $PR_NUMBER" + + sonar-scanner \ + --define sonar.cfamily.build-wrapper-output="sonarcloud-data" \ + --define sonar.coverageReportPaths=sonarcloud-data/coverage.xml \ + --define sonar.projectKey=apache_kvrocks \ + --define 
sonar.organization=apache \ + --define sonar.scm.revision=${{ github.event.workflow_run.head_sha }} \ + --define sonar.pullrequest.key=$PR_NUMBER diff --git a/README.md b/README.md index 7bc80f2adca..bd0a6798996 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,12 @@ sudo bash cmake.sh --skip-license --prefix=/usr # enable gcc and make in devtoolset-11 source /opt/rh/devtoolset-11/enable +# openSUSE / SUSE Linux Enterprise +sudo zypper install -y gcc11 gcc11-c++ make wget git autoconf automake python3 curl cmake + +# Arch Linux +sudo pacman -Sy --noconfirm autoconf automake python3 git wget which cmake make gcc + # macOS brew install git cmake autoconf automake libtool openssl # please link openssl by force if it still cannot be found after installing diff --git a/kvrocks.conf b/kvrocks.conf index 3da4da6a055..17fe485426e 100644 --- a/kvrocks.conf +++ b/kvrocks.conf @@ -572,6 +572,17 @@ migrate-sequence-gap 10000 # Default: 4096MB rocksdb.block_cache_size 4096 +# Specify the type of cache used in the block cache. +# Accept value: "lru", "hcc" +# "lru" stands for the cache with the LRU(Least Recently Used) replacement policy. +# +# "hcc" stands for the Hyper Clock Cache, a lock-free cache alternative +# that offers much improved CPU efficiency vs. LRU cache under high parallel +# load or high contention. +# +# default lru +rocksdb.block_cache_type lru + # A global cache for table-level rows in RocksDB. If almost always point # lookups, enlarging row cache may improve read performance. Otherwise, # if we enlarge this value, we can lessen metadata/subkey block cache size. @@ -838,6 +849,7 @@ rocksdb.write_options.sync no # If yes, writes will not first go to the write ahead log, # and the write may get lost after a crash. +# You must keep wal enabled if you use replication. 
# # Default: no rocksdb.write_options.disable_wal no diff --git a/src/cluster/slot_migrate.cc b/src/cluster/slot_migrate.cc index 424d0b4ca5a..be54cc2e9e2 100644 --- a/src/cluster/slot_migrate.cc +++ b/src/cluster/slot_migrate.cc @@ -151,11 +151,10 @@ Status SlotMigrator::CreateMigrationThread() { void SlotMigrator::loop() { while (true) { - std::unique_lock ul(job_mutex_); - while (!isTerminated() && !migration_job_) { - job_cv_.wait(ul); + { + std::unique_lock ul(job_mutex_); + job_cv_.wait(ul, [&] { return isTerminated() || migration_job_; }); } - ul.unlock(); if (isTerminated()) { clean(); diff --git a/src/commands/cmd_function.cc b/src/commands/cmd_function.cc index 2123cc72ec7..2d7ce193e49 100644 --- a/src/commands/cmd_function.cc +++ b/src/commands/cmd_function.cc @@ -53,18 +53,18 @@ struct CommandFunction : Commander { with_code = true; } - return lua::FunctionList(srv, libname, with_code, output); + return lua::FunctionList(srv, conn, libname, with_code, output); } else if (parser.EatEqICase("listfunc")) { std::string funcname; if (parser.EatEqICase("funcname")) { funcname = GET_OR_RET(parser.TakeStr()); } - return lua::FunctionListFunc(srv, funcname, output); + return lua::FunctionListFunc(srv, conn, funcname, output); } else if (parser.EatEqICase("listlib")) { auto libname = GET_OR_RET(parser.TakeStr().Prefixed("expect a library name")); - return lua::FunctionListLib(srv, libname, output); + return lua::FunctionListLib(srv, conn, libname, output); } else if (parser.EatEqICase("delete")) { auto libname = GET_OR_RET(parser.TakeStr()); if (!lua::FunctionIsLibExist(conn, libname)) { diff --git a/src/commands/cmd_geo.cc b/src/commands/cmd_geo.cc index 0f4d98ebf25..3ed2237d8a8 100644 --- a/src/commands/cmd_geo.cc +++ b/src/commands/cmd_geo.cc @@ -150,7 +150,7 @@ class CommandGeoDist : public CommandGeoBase { if (s.IsNotFound()) { *output = conn->NilString(); } else { - *output = redis::BulkString(util::Float2String(GetDistanceByUnit(distance))); + *output = 
conn->Double(GetDistanceByUnit(distance)); } return Status::OK(); } @@ -215,8 +215,7 @@ class CommandGeoPos : public Commander { if (iter == geo_points.end()) { list.emplace_back(conn->NilString()); } else { - list.emplace_back(conn->MultiBulkString( - {util::Float2String(iter->second.longitude), util::Float2String(iter->second.latitude)})); + list.emplace_back(redis::Array({conn->Double(iter->second.longitude), conn->Double(iter->second.latitude)})); } } *output = redis::Array(list); @@ -331,14 +330,13 @@ class CommandGeoRadius : public CommandGeoBase { std::vector one; one.emplace_back(redis::BulkString(geo_point.member)); if (with_dist_) { - one.emplace_back(redis::BulkString(util::Float2String(GetDistanceByUnit(geo_point.dist)))); + one.emplace_back(conn->Double(GetDistanceByUnit(geo_point.dist))); } if (with_hash_) { - one.emplace_back(redis::BulkString(util::Float2String(geo_point.score))); + one.emplace_back(conn->Double(geo_point.score)); } if (with_coord_) { - one.emplace_back( - conn->MultiBulkString({util::Float2String(geo_point.longitude), util::Float2String(geo_point.latitude)})); + one.emplace_back(redis::Array({conn->Double(geo_point.longitude), conn->Double(geo_point.latitude)})); } list.emplace_back(redis::Array(one)); } @@ -346,6 +344,26 @@ class CommandGeoRadius : public CommandGeoBase { return redis::Array(list); } + static std::vector Range(const std::vector &args) { + int store_key = 0; + + // Check for the presence of the stored key in the command args. + for (size_t i = 6; i < args.size(); i++) { + // For the case when a user specifies both "store" and "storedist" options, + // the second key will override the first key. The behavior is kept the same + // as in ParseRadiusExtraOption method. 
+ if ((util::ToLower(args[i]) == "store" || util::ToLower(args[i]) == "storedist") && i + 1 < args.size()) { + store_key = (int)i + 1; + i++; + } + } + + if (store_key > 0) { + return {{1, 1, 1}, {store_key, store_key, 1}}; + } + return {{1, 1, 1}}; + } + protected: double radius_ = 0; bool with_coord_ = false; @@ -509,14 +527,13 @@ class CommandGeoSearch : public CommandGeoBase { std::vector one; one.emplace_back(redis::BulkString(geo_point.member)); if (with_dist_) { - one.emplace_back(redis::BulkString(util::Float2String(GetDistanceByUnit(geo_point.dist)))); + one.emplace_back(conn->Double(GetDistanceByUnit(geo_point.dist))); } if (with_hash_) { - one.emplace_back(redis::BulkString(util::Float2String(geo_point.score))); + one.emplace_back(conn->Double(geo_point.score)); } if (with_coord_) { - one.emplace_back( - conn->MultiBulkString({util::Float2String(geo_point.longitude), util::Float2String(geo_point.latitude)})); + one.emplace_back(redis::Array({conn->Double(geo_point.longitude), conn->Double(geo_point.latitude)})); } output.emplace_back(redis::Array(one)); } @@ -607,6 +624,8 @@ class CommandGeoSearchStore : public CommandGeoSearch { return Status::OK(); } + static std::vector Range(const std::vector &args) { return {{1, 1, 1}, {2, 2, 1}}; } + private: bool store_distance_ = false; std::string store_key_; @@ -649,6 +668,26 @@ class CommandGeoRadiusByMember : public CommandGeoRadius { return Status::OK(); } + + static std::vector Range(const std::vector &args) { + int store_key = 0; + + // Check for the presence of the stored key in the command args. + for (size_t i = 5; i < args.size(); i++) { + // For the case when a user specifies both "store" and "storedist" options, + // the second key will override the first key. The behavior is kept the same + // as in ParseRadiusExtraOption method. 
+ if ((util::ToLower(args[i]) == "store" || util::ToLower(args[i]) == "storedist") && i + 1 < args.size()) { + store_key = (int)i + 1; + i++; + } + } + + if (store_key > 0) { + return {{1, 1, 1}, {store_key, store_key, 1}}; + } + return {{1, 1, 1}}; + } }; class CommandGeoRadiusReadonly : public CommandGeoRadius { @@ -665,11 +704,12 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("geoadd", -5, "write", 1, 1, MakeCmdAttr("geodist", -4, "read-only", 1, 1, 1), MakeCmdAttr("geohash", -3, "read-only", 1, 1, 1), MakeCmdAttr("geopos", -3, "read-only", 1, 1, 1), - MakeCmdAttr("georadius", -6, "write", 1, 1, 1), - MakeCmdAttr("georadiusbymember", -5, "write", 1, 1, 1), + MakeCmdAttr("georadius", -6, "write", CommandGeoRadius::Range), + MakeCmdAttr("georadiusbymember", -5, "write", + CommandGeoRadiusByMember::Range), MakeCmdAttr("georadius_ro", -6, "read-only", 1, 1, 1), MakeCmdAttr("georadiusbymember_ro", -5, "read-only", 1, 1, 1), MakeCmdAttr("geosearch", -7, "read-only", 1, 1, 1), - MakeCmdAttr("geosearchstore", -8, "write", 1, 1, 1)) + MakeCmdAttr("geosearchstore", -8, "write", CommandGeoSearchStore::Range)) } // namespace redis diff --git a/src/commands/cmd_hash.cc b/src/commands/cmd_hash.cc index 8b9526f97fc..cc4c475ebad 100644 --- a/src/commands/cmd_hash.cc +++ b/src/commands/cmd_hash.cc @@ -284,7 +284,7 @@ class CommandHVals : public Commander { for (const auto &p : field_values) { values.emplace_back(p.value); } - *output = conn->MultiBulkString(values, false); + *output = ArrayOfBulkStrings(values); return Status::OK(); } @@ -306,7 +306,7 @@ class CommandHGetAll : public Commander { kv_pairs.emplace_back(p.field); kv_pairs.emplace_back(p.value); } - *output = conn->MultiBulkString(kv_pairs, false); + *output = conn->MapOfBulkStrings(kv_pairs); return Status::OK(); } @@ -350,7 +350,7 @@ class CommandHRangeByLex : public Commander { kv_pairs.emplace_back(p.field); kv_pairs.emplace_back(p.value); } - *output = conn->MultiBulkString(kv_pairs, false); + *output = 
ArrayOfBulkStrings(kv_pairs); return Status::OK(); } @@ -372,7 +372,14 @@ class CommandHScan : public CommandSubkeyScanBase { return {Status::RedisExecErr, s.ToString()}; } - *output = GenerateOutput(srv, conn, fields, values, CursorType::kTypeHash); + auto cursor = GetNextCursor(srv, fields, CursorType::kTypeHash); + std::vector entries; + entries.reserve(2 * fields.size()); + for (size_t i = 0; i < fields.size(); i++) { + entries.emplace_back(redis::BulkString(fields[i])); + entries.emplace_back(redis::BulkString(values[i])); + } + *output = redis::Array({redis::BulkString(cursor), redis::Array(entries)}); return Status::OK(); } }; @@ -417,7 +424,7 @@ class CommandHRandField : public Commander { if (no_parameters_) *output = s.IsNotFound() ? conn->NilString() : redis::BulkString(result_entries[0]); else - *output = conn->MultiBulkString(result_entries, false); + *output = ArrayOfBulkStrings(result_entries); return Status::OK(); } diff --git a/src/commands/cmd_json.cc b/src/commands/cmd_json.cc index 5377ed31654..8cd49c51e3f 100644 --- a/src/commands/cmd_json.cc +++ b/src/commands/cmd_json.cc @@ -226,7 +226,7 @@ class CommandJsonObjkeys : public Commander { *output = redis::MultiLen(results.size()); for (const auto &item : results) { if (item.has_value()) { - *output += conn->MultiBulkString(item.value(), false); + *output += ArrayOfBulkStrings(item.value()); } else { *output += conn->NilString(); } diff --git a/src/commands/cmd_key.cc b/src/commands/cmd_key.cc index 35bfeefd491..f94f87fecf5 100644 --- a/src/commands/cmd_key.cc +++ b/src/commands/cmd_key.cc @@ -307,6 +307,36 @@ class CommandDel : public Commander { } }; +class CommandRename : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::Database redis(srv->storage, conn->GetNamespace()); + bool ret = true; + + auto s = redis.Rename(args_[1], args_[2], false, &ret); + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + + *output = 
redis::SimpleString("OK"); + return Status::OK(); + } +}; + +class CommandRenameNX : public Commander { + public: + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::Database redis(srv->storage, conn->GetNamespace()); + bool ret = true; + auto s = redis.Rename(args_[1], args_[2], true, &ret); + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + if (ret) { + *output = redis::Integer(1); + } else { + *output = redis::Integer(0); + } + return Status::OK(); + } +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("ttl", 2, "read-only", 1, 1, 1), MakeCmdAttr("pttl", 2, "read-only", 1, 1, 1), MakeCmdAttr("type", 2, "read-only", 1, 1, 1), @@ -321,6 +351,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("ttl", 2, "read-only", 1, 1, 1), MakeCmdAttr("expiretime", 2, "read-only", 1, 1, 1), MakeCmdAttr("pexpiretime", 2, "read-only", 1, 1, 1), MakeCmdAttr("del", -2, "write", 1, -1, 1), - MakeCmdAttr("unlink", -2, "write", 1, -1, 1), ) + MakeCmdAttr("unlink", -2, "write", 1, -1, 1), + MakeCmdAttr("rename", 3, "write", 1, 2, 1), + MakeCmdAttr("renamenx", 3, "write", 1, 2, 1), ) } // namespace redis diff --git a/src/commands/cmd_list.cc b/src/commands/cmd_list.cc index 5470d444f4f..726d2a70889 100644 --- a/src/commands/cmd_list.cc +++ b/src/commands/cmd_list.cc @@ -536,7 +536,7 @@ class CommandLRange : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->MultiBulkString(elems, false); + *output = ArrayOfBulkStrings(elems); return Status::OK(); } @@ -839,7 +839,7 @@ class CommandLPos : public Commander { for (const auto &index : indexes) { values.emplace_back(std::to_string(index)); } - *output = conn->MultiBulkString(values, false); + *output = ArrayOfBulkStrings(values); } return Status::OK(); } diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index 71781cd5182..10cbc8cdc4e 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -105,7 +105,7 @@ class CommandNamespace : public 
Commander { } namespaces.emplace_back(kDefaultNamespace); namespaces.emplace_back(config->requirepass); - *output = conn->MultiBulkString(namespaces, false); + *output = ArrayOfBulkStrings(namespaces); } else { auto token = srv->GetNamespace()->Get(args_[2]); if (token.Is()) { @@ -252,7 +252,7 @@ class CommandConfig : public Commander { } else if (args_.size() == 3 && sub_command == "get") { std::vector values; config->Get(args_[2], &values); - *output = conn->MultiBulkString(values); + *output = conn->MapOfBulkStrings(values); } else if (args_.size() == 4 && sub_command == "set") { Status s = config->Set(srv, args_[2], args_[3]); if (!s.IsOK()) { @@ -607,11 +607,26 @@ class CommandDebug : public Commander { *output = redis::BulkString("Hello World"); } else if (protocol_type_ == "integer") { *output = redis::Integer(12345); + } else if (protocol_type_ == "double") { + *output = conn->Double(3.141); } else if (protocol_type_ == "array") { *output = redis::MultiLen(3); for (int i = 0; i < 3; i++) { *output += redis::Integer(i); } + } else if (protocol_type_ == "set") { + *output = conn->HeaderOfSet(3); + for (int i = 0; i < 3; i++) { + *output += redis::Integer(i); + } + } else if (protocol_type_ == "map") { + *output = conn->HeaderOfMap(3); + for (int i = 0; i < 3; i++) { + *output += redis::Integer(i); + *output += conn->Bool(i == 1); + } + } else if (protocol_type_ == "bignum") { + *output = conn->BigNumber("1234567999999999999999999999999999999"); } else if (protocol_type_ == "true") { *output = conn->Bool(true); } else if (protocol_type_ == "false") { @@ -619,8 +634,9 @@ class CommandDebug : public Commander { } else if (protocol_type_ == "null") { *output = conn->NilString(); } else { - *output = - redis::Error("Wrong protocol type name. Please use one of the following: string|int|array|true|false|null"); + *output = redis::Error( + "Wrong protocol type name. 
Please use one of the following: " + "string|integer|double|array|set|bignum|true|false|null"); } } else { return {Status::RedisInvalidCmd, "Unknown subcommand, should be DEBUG or PROTOCOL"}; @@ -778,7 +794,10 @@ class CommandHello final : public Commander { } else { output_list.push_back(redis::BulkString("standalone")); } - *output = redis::Array(output_list); + *output = conn->HeaderOfMap(output_list.size() / 2); + for (const auto &item : output_list) { + *output += item; + } return Status::OK(); } }; @@ -819,7 +838,7 @@ class CommandScan : public CommandScanBase { list.emplace_back(redis::BulkString("0")); } - list.emplace_back(conn->MultiBulkString(keys, false)); + list.emplace_back(ArrayOfBulkStrings(keys)); return redis::Array(list); } diff --git a/src/commands/cmd_set.cc b/src/commands/cmd_set.cc index 0b5a8bda5a4..ced252234b2 100644 --- a/src/commands/cmd_set.cc +++ b/src/commands/cmd_set.cc @@ -93,7 +93,7 @@ class CommandSMembers : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->MultiBulkString(members, false); + *output = conn->SetOfBulkStrings(members); return Status::OK(); } }; @@ -171,7 +171,7 @@ class CommandSPop : public Commander { } if (with_count_) { - *output = conn->MultiBulkString(members, false); + *output = conn->SetOfBulkStrings(members); } else { if (members.size() > 0) { *output = redis::BulkString(members.front()); @@ -211,7 +211,7 @@ class CommandSRandMember : public Commander { if (!s.ok()) { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->MultiBulkString(members, false); + *output = conn->SetOfBulkStrings(members); return Status::OK(); } @@ -249,7 +249,7 @@ class CommandSDiff : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->MultiBulkString(members, false); + *output = conn->SetOfBulkStrings(members); return Status::OK(); } }; @@ -269,7 +269,7 @@ class CommandSUnion : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = 
conn->MultiBulkString(members, false); + *output = conn->SetOfBulkStrings(members); return Status::OK(); } }; @@ -289,7 +289,7 @@ class CommandSInter : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = conn->MultiBulkString(members, false); + *output = conn->SetOfBulkStrings(members); return Status::OK(); } }; diff --git a/src/commands/cmd_stream.cc b/src/commands/cmd_stream.cc index f82497fce4d..7ba408859c1 100644 --- a/src/commands/cmd_stream.cc +++ b/src/commands/cmd_stream.cc @@ -445,9 +445,9 @@ class CommandXInfo : public Commander { } if (!full_) { - output->append(redis::MultiLen(14)); + output->append(conn->HeaderOfMap(7)); } else { - output->append(redis::MultiLen(12)); + output->append(conn->HeaderOfMap(6)); } output->append(redis::BulkString("length")); output->append(redis::Integer(info.size)); @@ -503,7 +503,7 @@ class CommandXInfo : public Commander { output->append(redis::MultiLen(result_vector.size())); for (auto const &it : result_vector) { - output->append(redis::MultiLen(12)); + output->append(conn->HeaderOfMap(6)); output->append(redis::BulkString("name")); output->append(redis::BulkString(it.first)); output->append(redis::BulkString("consumers")); @@ -545,7 +545,7 @@ class CommandXInfo : public Commander { output->append(redis::MultiLen(result_vector.size())); auto now = util::GetTimeStampMS(); for (auto const &it : result_vector) { - output->append(redis::MultiLen(8)); + output->append(conn->HeaderOfMap(4)); output->append(redis::BulkString("name")); output->append(redis::BulkString(it.first)); output->append(redis::BulkString("pending")); diff --git a/src/commands/cmd_zset.cc b/src/commands/cmd_zset.cc index acddad82ddf..d78fe54eed8 100644 --- a/src/commands/cmd_zset.cc +++ b/src/commands/cmd_zset.cc @@ -87,7 +87,7 @@ class CommandZAdd : public Commander { return Status::OK(); } - *output = redis::BulkString(util::Float2String(new_score)); + *output = conn->Double(new_score); } else { *output = redis::Integer(ret); 
} @@ -192,7 +192,7 @@ class CommandZIncrBy : public Commander { return {Status::RedisExecErr, s.ToString()}; } - *output = redis::BulkString(util::Float2String(score)); + *output = conn->Double(score); return Status::OK(); } @@ -258,7 +258,7 @@ class CommandZPop : public Commander { output->append(redis::MultiLen(member_scores.size() * 2)); for (const auto &ms : member_scores) { output->append(redis::BulkString(ms.member)); - output->append(redis::BulkString(util::Float2String(ms.score))); + output->append(conn->Double(ms.score)); } return Status::OK(); @@ -329,7 +329,7 @@ class CommandBZPop : public BlockingCommander { } if (!member_scores.empty()) { - SendMembersWithScores(member_scores, user_key); + SendMembersWithScores(conn, member_scores, user_key); return Status::OK(); } @@ -350,13 +350,14 @@ class CommandBZPop : public BlockingCommander { } } - void SendMembersWithScores(const std::vector &member_scores, const std::string &user_key) { + void SendMembersWithScores(const Connection *conn, const std::vector &member_scores, + const std::string &user_key) { std::string output; output.append(redis::MultiLen(member_scores.size() * 2 + 1)); output.append(redis::BulkString(user_key)); for (const auto &ms : member_scores) { output.append(redis::BulkString(ms.member)); - output.append(redis::BulkString(util::Float2String(ms.score))); + output.append(conn->Double(ms.score)); } conn_->Reply(output); } @@ -374,7 +375,7 @@ class CommandBZPop : public BlockingCommander { bool empty = member_scores.empty(); if (!empty) { - SendMembersWithScores(member_scores, user_key); + SendMembersWithScores(conn_, member_scores, user_key); } return !empty; @@ -405,7 +406,7 @@ static void SendMembersWithScoresForZMpop(Connection *conn, const std::string &u output.append(redis::MultiLen(member_scores.size() * 2)); for (const auto &ms : member_scores) { output.append(redis::BulkString(ms.member)); - output.append(redis::BulkString(util::Float2String(ms.score))); + 
output.append(conn->Double(ms.score)); } conn->Reply(output); } @@ -561,7 +562,7 @@ class CommandBZMPop : public BlockingCommander { static CommandKeyRange Range(const std::vector &args) { int num_key = *ParseInt(args[2], 10); - return {3, 1 + num_key, 1}; + return {3, 2 + num_key, 1}; } private: @@ -817,7 +818,7 @@ class CommandZRangeGeneric : public Commander { output->append(redis::MultiLen(member_scores.size() * (with_scores_ ? 2 : 1))); for (const auto &ms : member_scores) { output->append(redis::BulkString(ms.member)); - if (with_scores_) output->append(redis::BulkString(util::Float2String(ms.score))); + if (with_scores_) output->append(conn->Double(ms.score)); } return Status::OK(); } @@ -904,7 +905,7 @@ class CommandZRank : public Commander { if (with_score_) { output->append(redis::MultiLen(2)); output->append(redis::Integer(rank)); - output->append(redis::BulkString(util::Float2String(score))); + output->append(conn->Double(score)); } else { *output = redis::Integer(rank); } @@ -1047,7 +1048,7 @@ class CommandZScore : public Commander { if (s.IsNotFound()) { *output = conn->NilString(); } else { - *output = redis::BulkString(util::Float2String(score)); + *output = conn->Double(score); } return Status::OK(); } @@ -1074,9 +1075,9 @@ class CommandZMScore : public Commander { for (const auto &member : members) { auto iter = mscores.find(member.ToString()); if (iter == mscores.end()) { - values.emplace_back(""); + values.emplace_back(conn->NilString()); } else { - values.emplace_back(util::Float2String(iter->second)); + values.emplace_back(conn->Double(iter->second)); } } } @@ -1142,7 +1143,7 @@ class CommandZUnion : public Commander { output->append(redis::MultiLen(member_scores.size() * (with_scores_ ? 
2 : 1))); for (const auto &ms : member_scores) { output->append(redis::BulkString(ms.member)); - if (with_scores_) output->append(redis::BulkString(util::Float2String(ms.score))); + if (with_scores_) output->append(conn->Double(ms.score)); } return Status::OK(); } @@ -1222,9 +1223,9 @@ class CommandZUnionStore : public Commander { return Status::OK(); } - static CommandKeyRange Range(const std::vector &args) { - int num_key = *ParseInt(args[1], 10); - return {3, 2 + num_key, 1}; + static std::vector Range(const std::vector &args) { + int num_key = *ParseInt(args[2], 10); + return {{1, 1, 1}, {3, 2 + num_key, 1}}; } protected: @@ -1249,9 +1250,9 @@ class CommandZInterStore : public CommandZUnionStore { return Status::OK(); } - static CommandKeyRange Range(const std::vector &args) { - int num_key = *ParseInt(args[1], 10); - return {3, 2 + num_key, 1}; + static std::vector Range(const std::vector &args) { + int num_key = *ParseInt(args[2], 10); + return {{1, 1, 1}, {3, 2 + num_key, 1}}; } }; @@ -1276,7 +1277,7 @@ class CommandZInter : public CommandZUnion { output->append(redis::MultiLen(member_scores.size() * (with_scores_ ? 
2 : 1))); for (const auto &member_score : member_scores) { output->append(redis::BulkString(member_score.member)); - if (with_scores_) output->append(redis::BulkString(util::Float2String(member_score.score))); + if (with_scores_) output->append(conn->Double(member_score.score)); } return Status::OK(); } @@ -1350,12 +1351,14 @@ class CommandZScan : public CommandSubkeyScanBase { return {Status::RedisExecErr, s.ToString()}; } - std::vector score_strings; - score_strings.reserve(scores.size()); - for (const auto &score : scores) { - score_strings.emplace_back(util::Float2String(score)); + auto cursor = GetNextCursor(srv, members, CursorType::kTypeZSet); + std::vector entries; + entries.reserve(2 * members.size()); + for (size_t i = 0; i < members.size(); i++) { + entries.emplace_back(redis::BulkString(members[i])); + entries.emplace_back(conn->Double(scores[i])); } - *output = GenerateOutput(srv, conn, members, score_strings, CursorType::kTypeZSet); + *output = redis::Array({redis::BulkString(cursor), redis::Array(entries)}); return Status::OK(); } }; @@ -1402,14 +1405,14 @@ class CommandZRandMember : public Commander { result_entries.reserve(member_scores.size()); for (const auto &[member, score] : member_scores) { - result_entries.emplace_back(member); - if (with_scores_) result_entries.emplace_back(util::Float2String(score)); + result_entries.emplace_back(BulkString(member)); + if (with_scores_) result_entries.emplace_back(conn->Double(score)); } if (no_parameters_) *output = s.IsNotFound() ? 
conn->NilString() : redis::BulkString(result_entries[0]); else - *output = conn->MultiBulkString(result_entries, false); + *output = Array(result_entries); return Status::OK(); } @@ -1419,6 +1422,100 @@ class CommandZRandMember : public Commander { bool no_parameters_ = true; }; +class CommandZDiff : public Commander { + public: + Status Parse(const std::vector &args) override { + auto parse_result = ParseInt(args[1], 10); + if (!parse_result) return {Status::RedisParseErr, errValueNotInteger}; + + numkeys_ = *parse_result; + if (numkeys_ > args.size() - 2) return {Status::RedisParseErr, errInvalidSyntax}; + + size_t j = 0; + keys_.reserve(numkeys_); + while (j < numkeys_) { + keys_.emplace_back(args[j + 2]); + j++; + } + + if (auto i = 2 + numkeys_; i < args.size()) { + if (util::ToLower(args[i]) == "withscores") { + with_scores_ = true; + } + } + + return Commander::Parse(args); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::ZSet zset_db(srv->storage, conn->GetNamespace()); + + std::vector members_with_scores; + auto s = zset_db.Diff(keys_, &members_with_scores); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + + output->append(redis::MultiLen(members_with_scores.size() * (with_scores_ ? 
2 : 1))); + for (const auto &ms : members_with_scores) { + output->append(redis::BulkString(ms.member)); + if (with_scores_) output->append(conn->Double(ms.score)); + } + + return Status::OK(); + } + + static CommandKeyRange Range(const std::vector &args) { + int num_key = *ParseInt(args[1], 10); + return {2, 1 + num_key, 1}; + } + + protected: + size_t numkeys_ = 0; + std::vector keys_; + bool with_scores_ = false; +}; + +class CommandZDiffStore : public Commander { + public: + Status Parse(const std::vector &args) override { + auto parse_result = ParseInt(args[2], 10); + if (!parse_result) return {Status::RedisParseErr, errValueNotInteger}; + + numkeys_ = *parse_result; + if (numkeys_ > args.size() - 3) return {Status::RedisParseErr, errInvalidSyntax}; + + size_t j = 0; + while (j < numkeys_) { + keys_.emplace_back(args[j + 3]); + j++; + } + + return Commander::Parse(args); + } + + Status Execute(Server *srv, Connection *conn, std::string *output) override { + redis::ZSet zset_db(srv->storage, conn->GetNamespace()); + + uint64_t stored_count = 0; + auto s = zset_db.DiffStore(args_[1], keys_, &stored_count); + if (!s.ok()) { + return {Status::RedisExecErr, s.ToString()}; + } + *output = redis::Integer(stored_count); + return Status::OK(); + } + + static std::vector Range(const std::vector &args) { + int num_key = *ParseInt(args[2], 10); + return {{1, 1, 1}, {3, 2 + num_key, 1}}; + } + + protected: + size_t numkeys_ = 0; + std::vector keys_; +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zcard", 2, "read-only", 1, 1, 1), MakeCmdAttr("zcount", 4, "read-only", 1, 1, 1), @@ -1451,6 +1548,8 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("zadd", -4, "write", 1, 1, 1), MakeCmdAttr("zscan", -3, "read-only", 1, 1, 1), MakeCmdAttr("zunionstore", -4, "write", CommandZUnionStore::Range), MakeCmdAttr("zunion", -3, "read-only", CommandZUnion::Range), - MakeCmdAttr("zrandmember", -2, "read-only", 1, 1, 1)) + MakeCmdAttr("zrandmember", -2, 
"read-only", 1, 1, 1), + MakeCmdAttr("zdiff", -3, "read-only", CommandZDiff::Range), + MakeCmdAttr("zdiffstore", -3, "read-only", CommandZDiffStore::Range), ) } // namespace redis diff --git a/src/commands/command_parser.h b/src/commands/command_parser.h index 3aa31ae0931..c13d682bdba 100644 --- a/src/commands/command_parser.h +++ b/src/commands/command_parser.h @@ -24,10 +24,12 @@ #include #include #include +#include #include "parse_util.h" #include "status.h" #include "string_util.h" +#include "type_util.h" template struct MoveIterator : Iter { @@ -46,17 +48,29 @@ struct CommandParser { CommandParser(Iter begin, Iter end) : begin_(std::move(begin)), end_(std::move(end)) {} - template - explicit CommandParser(const Container& con, size_t skip_num = 0) : CommandParser(std::begin(con), std::end(con)) { + template && + !std::is_same_v, CommandParser>, + int> = 0> + explicit CommandParser(Container&& con, size_t skip_num = 0) : CommandParser(std::begin(con), std::end(con)) { std::advance(begin_, skip_num); } - template + template && + !std::is_same_v, CommandParser>, + int> = 0> explicit CommandParser(Container&& con, size_t skip_num = 0) : CommandParser(MoveIterator(std::begin(con)), MoveIterator(std::end(con))) { std::advance(begin_, skip_num); } + CommandParser(const CommandParser&) = default; + CommandParser(CommandParser&&) noexcept = default; + + CommandParser& operator=(const CommandParser&) = default; + CommandParser& operator=(CommandParser&&) noexcept = default; + + ~CommandParser() = default; + decltype(auto) RawPeek() const { return *begin_; } decltype(auto) operator[](size_t index) const { diff --git a/src/commands/scan_base.h b/src/commands/scan_base.h index bab86e218a7..2e11c989bd4 100644 --- a/src/commands/scan_base.h +++ b/src/commands/scan_base.h @@ -74,7 +74,7 @@ class CommandScanBase : public Commander { list.emplace_back(redis::BulkString("0")); } - list.emplace_back(conn->MultiBulkString(keys, false)); + 
list.emplace_back(ArrayOfBulkStrings(keys)); return redis::Array(list); } @@ -112,25 +112,11 @@ class CommandSubkeyScanBase : public CommandScanBase { return Commander::Parse(args); } - std::string GenerateOutput(Server *srv, const Connection *conn, const std::vector &fields, - const std::vector &values, CursorType cursor_type) { - std::vector list; - auto items_count = fields.size(); - if (items_count == static_cast(limit_)) { - auto end_cursor = srv->GenerateCursorFromKeyName(fields.back(), cursor_type); - list.emplace_back(redis::BulkString(end_cursor)); - } else { - list.emplace_back(redis::BulkString("0")); + std::string GetNextCursor(Server *srv, std::vector &fields, CursorType cursor_type) const { + if (fields.size() == static_cast(limit_)) { + return srv->GenerateCursorFromKeyName(fields.back(), cursor_type); } - std::vector fvs; - if (items_count > 0) { - for (size_t i = 0; i < items_count; i++) { - fvs.emplace_back(fields[i]); - fvs.emplace_back(values[i]); - } - } - list.emplace_back(conn->MultiBulkString(fvs, false)); - return redis::Array(list); + return "0"; } protected: diff --git a/src/common/bitfield_util.h b/src/common/bitfield_util.h index dbb44b1db76..63d27980d13 100644 --- a/src/common/bitfield_util.h +++ b/src/common/bitfield_util.h @@ -205,7 +205,7 @@ class ArrayBitfieldBitmap { void Reset() { memset(buf_, 0, sizeof(buf_)); } - Status Set(uint32_t byte_offset, uint32_t bytes, const uint8_t *src) { + [[nodiscard]] Status Set(uint32_t byte_offset, uint32_t bytes, const uint8_t *src) { Status bound_status(checkLegalBound(byte_offset, bytes)); if (!bound_status) { return bound_status; @@ -215,7 +215,7 @@ class ArrayBitfieldBitmap { return Status::OK(); } - Status Get(uint32_t byte_offset, uint32_t bytes, uint8_t *dst) const { + [[nodiscard]] Status Get(uint32_t byte_offset, uint32_t bytes, uint8_t *dst) const { Status bound_status(checkLegalBound(byte_offset, bytes)); if (!bound_status) { return bound_status; @@ -226,7 +226,7 @@ class 
ArrayBitfieldBitmap { } StatusOr GetUnsignedBitfield(uint64_t bit_offset, uint64_t bits) const { - Status bits_status(BitfieldEncoding::CheckSupportedBitLengths(BitfieldEncoding::Type::kUnsigned, bits)); + Status bits_status = BitfieldEncoding::CheckSupportedBitLengths(BitfieldEncoding::Type::kUnsigned, bits); if (!bits_status) { return bits_status; } @@ -234,7 +234,7 @@ class ArrayBitfieldBitmap { } StatusOr GetSignedBitfield(uint64_t bit_offset, uint64_t bits) const { - Status bits_status(BitfieldEncoding::CheckSupportedBitLengths(BitfieldEncoding::Type::kSigned, bits)); + Status bits_status = BitfieldEncoding::CheckSupportedBitLengths(BitfieldEncoding::Type::kSigned, bits); if (!bits_status) { return bits_status; } @@ -257,10 +257,10 @@ class ArrayBitfieldBitmap { return value; } - Status SetBitfield(uint32_t bit_offset, uint32_t bits, uint64_t value) { + [[nodiscard]] Status SetBitfield(uint32_t bit_offset, uint32_t bits, uint64_t value) { uint32_t first_byte = bit_offset / 8; uint32_t last_byte = (bit_offset + bits - 1) / 8 + 1; - Status bound_status(checkLegalBound(first_byte, last_byte - first_byte)); + Status bound_status = checkLegalBound(first_byte, last_byte - first_byte); if (!bound_status) { return bound_status; } diff --git a/src/config/config.cc b/src/config/config.cc index 5ea8c6173dc..7a944f6c4ad 100644 --- a/src/config/config.cc +++ b/src/config/config.cc @@ -74,6 +74,15 @@ const std::vector> compression_types{[] { return res; }()}; +const std::vector> cache_types{[] { + std::vector> res; + res.reserve(engine::CacheOptions.size()); + for (const auto &e : engine::CacheOptions) { + res.push_back({e.name, e.type}); + } + return res; +}()}; + std::string TrimRocksDbPrefix(std::string s) { if (strncasecmp(s.data(), "rocksdb.", 8) != 0) return s; return s.substr(8, s.size() - 8); @@ -191,6 +200,8 @@ Config::Config() { {"rocksdb.stats_dump_period_sec", false, new IntField(&rocks_db.stats_dump_period_sec, 0, 0, INT_MAX)}, 
{"rocksdb.cache_index_and_filter_blocks", true, new YesNoField(&rocks_db.cache_index_and_filter_blocks, true)}, {"rocksdb.block_cache_size", true, new IntField(&rocks_db.block_cache_size, 0, 0, INT_MAX)}, + {"rocksdb.block_cache_type", true, + new EnumField(&rocks_db.block_cache_type, cache_types, BlockCacheType::kCacheTypeLRU)}, {"rocksdb.subkey_block_cache_size", true, new IntField(&rocks_db.subkey_block_cache_size, 2048, 0, INT_MAX)}, {"rocksdb.metadata_block_cache_size", true, new IntField(&rocks_db.metadata_block_cache_size, 2048, 0, INT_MAX)}, {"rocksdb.share_metadata_and_subkey_block_cache", true, diff --git a/src/config/config.h b/src/config/config.h index 46e260bc541..e69ad5e580f 100644 --- a/src/config/config.h +++ b/src/config/config.h @@ -54,6 +54,8 @@ constexpr const uint32_t kDefaultPort = 6666; constexpr const char *kDefaultNamespace = "__namespace"; +enum class BlockCacheType { kCacheTypeLRU = 0, kCacheTypeHCC }; + struct CompactionCheckerRange { public: int start; @@ -169,6 +171,7 @@ struct Config { int block_size; bool cache_index_and_filter_blocks; int block_cache_size; + BlockCacheType block_cache_type; int metadata_block_cache_size; int subkey_block_cache_size; bool share_metadata_and_subkey_block_cache; diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index 87370ffa921..99d579fe136 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -137,11 +137,10 @@ std::string Connection::Bool(bool b) const { return Integer(b ? 
1 : 0); } -std::string Connection::MultiBulkString(const std::vector &values, - bool output_nil_for_empty_string) const { +std::string Connection::MultiBulkString(const std::vector &values) const { std::string result = "*" + std::to_string(values.size()) + CRLF; for (const auto &value : values) { - if (value.empty() && output_nil_for_empty_string) { + if (value.empty()) { result += NilString(); } else { result += BulkString(value); @@ -163,6 +162,26 @@ std::string Connection::MultiBulkString(const std::vector &values, return result; } +std::string Connection::SetOfBulkStrings(const std::vector &elems) const { + std::string result; + result += HeaderOfSet(elems.size()); + for (const auto &elem : elems) { + result += BulkString(elem); + } + return result; +} + +std::string Connection::MapOfBulkStrings(const std::vector &elems) const { + CHECK(elems.size() % 2 == 0); + + std::string result; + result += HeaderOfMap(elems.size() / 2); + for (const auto &elem : elems) { + result += BulkString(elem); + } + return result; +} + void Connection::SendFile(int fd) { // NOTE: we don't need to close the fd, the libevent will do that auto output = bufferevent_get_output(bev_); diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h index f35a3889c68..e51f66ab72e 100644 --- a/src/server/redis_connection.h +++ b/src/server/redis_connection.h @@ -65,11 +65,27 @@ class Connection : public EvbufCallbackBase { RESP GetProtocolVersion() const { return protocol_version_; } void SetProtocolVersion(RESP version) { protocol_version_ = version; } std::string Bool(bool b) const; + std::string BigNumber(const std::string &n) const { + return protocol_version_ == RESP::v3 ? "(" + n + CRLF : BulkString(n); + } + std::string Double(double d) const { + return protocol_version_ == RESP::v3 ? 
"," + util::Float2String(d) + CRLF : BulkString(util::Float2String(d)); + } std::string NilString() const { return redis::NilString(protocol_version_); } std::string NilArray() const { return protocol_version_ == RESP::v3 ? "_" CRLF : "*-1" CRLF; } - std::string MultiBulkString(const std::vector &values, bool output_nil_for_empty_string = true) const; + std::string MultiBulkString(const std::vector &values) const; std::string MultiBulkString(const std::vector &values, const std::vector &statuses) const; + template , int> = 0> + std::string HeaderOfSet(T len) const { + return protocol_version_ == RESP::v3 ? "~" + std::to_string(len) + CRLF : MultiLen(len); + } + std::string SetOfBulkStrings(const std::vector &elems) const; + template , int> = 0> + std::string HeaderOfMap(T len) const { + return protocol_version_ == RESP::v3 ? "%" + std::to_string(len) + CRLF : MultiLen(len * 2); + } + std::string MapOfBulkStrings(const std::vector &elems) const; using UnsubscribeCallback = std::function; void SubscribeChannel(const std::string &channel); diff --git a/src/storage/iterator.cc b/src/storage/iterator.cc index 58e283b165c..6514207b35d 100644 --- a/src/storage/iterator.cc +++ b/src/storage/iterator.cc @@ -25,7 +25,7 @@ #include "db_util.h" namespace engine { -DBIterator::DBIterator(Storage* storage, rocksdb::ReadOptions read_options, int slot) +DBIterator::DBIterator(Storage *storage, rocksdb::ReadOptions read_options, int slot) : storage_(storage), read_options_(std::move(read_options)), slot_(slot) { metadata_cf_handle_ = storage_->GetCFHandle(kMetadataColumnFamilyName); metadata_iter_ = util::UniqueIterator(storage_->NewIterator(read_options_, metadata_cf_handle_)); @@ -80,7 +80,7 @@ void DBIterator::Reset() { if (metadata_iter_) metadata_iter_.reset(); } -void DBIterator::Seek(const std::string& target) { +void DBIterator::Seek(const std::string &target) { if (!metadata_iter_) return; // Iterate with the slot id but storage didn't enable slot id encoding @@ -112,7 
+112,7 @@ std::unique_ptr DBIterator::GetSubKeyIterator() const { return std::make_unique(storage_, read_options_, type, std::move(prefix)); } -SubKeyIterator::SubKeyIterator(Storage* storage, rocksdb::ReadOptions read_options, RedisType type, std::string prefix) +SubKeyIterator::SubKeyIterator(Storage *storage, rocksdb::ReadOptions read_options, RedisType type, std::string prefix) : storage_(storage), read_options_(std::move(read_options)), type_(type), prefix_(std::move(prefix)) { if (type_ == kRedisStream) { cf_handle_ = storage_->GetCFHandle(kStreamColumnFamilyName); @@ -145,6 +145,8 @@ Slice SubKeyIterator::UserKey() const { return internal_key.GetSubKey(); } +rocksdb::ColumnFamilyHandle *SubKeyIterator::ColumnFamilyHandle() const { return Valid() ? this->cf_handle_ : nullptr; } + Slice SubKeyIterator::Value() const { return Valid() ? iter_->value() : Slice(); } void SubKeyIterator::Seek() { @@ -162,4 +164,125 @@ void SubKeyIterator::Reset() { if (iter_) iter_.reset(); } +rocksdb::Status WALBatchExtractor::PutCF(uint32_t column_family_id, const Slice &key, const Slice &value) { + if (slot_ != -1 && slot_ != ExtractSlotId(key)) { + return rocksdb::Status::OK(); + } + items_.emplace_back(WALItem::Type::kTypePut, column_family_id, key.ToString(), value.ToString()); + return rocksdb::Status::OK(); +} + +rocksdb::Status WALBatchExtractor::DeleteCF(uint32_t column_family_id, const rocksdb::Slice &key) { + if (slot_ != -1 && slot_ != ExtractSlotId(key)) { + return rocksdb::Status::OK(); + } + items_.emplace_back(WALItem::Type::kTypeDelete, column_family_id, key.ToString(), std::string{}); + return rocksdb::Status::OK(); +} + +rocksdb::Status WALBatchExtractor::DeleteRangeCF(uint32_t column_family_id, const rocksdb::Slice &begin_key, + const rocksdb::Slice &end_key) { + items_.emplace_back(WALItem::Type::kTypeDeleteRange, column_family_id, begin_key.ToString(), end_key.ToString()); + return rocksdb::Status::OK(); +} + +void WALBatchExtractor::LogData(const 
rocksdb::Slice &blob) { + items_.emplace_back(WALItem::Type::kTypeLogData, 0, blob.ToString(), std::string{}); +}; + +void WALBatchExtractor::Clear() { items_.clear(); } + +WALBatchExtractor::Iter WALBatchExtractor::GetIter() { return Iter(&items_); } + +bool WALBatchExtractor::Iter::Valid() { return items_ && cur_ < items_->size(); } + +void WALBatchExtractor::Iter::Next() { cur_++; } + +WALItem WALBatchExtractor::Iter::Value() { + if (!Valid()) { + return {}; + } + return (*items_)[cur_]; +} + +void WALIterator::Reset() { + if (iter_) { + iter_.reset(); + } + if (batch_iter_) { + batch_iter_.reset(); + } + extractor_.Clear(); + next_batch_seq_ = 0; +} + +bool WALIterator::Valid() const { return (batch_iter_ && batch_iter_->Valid()) || (iter_ && iter_->Valid()); } + +void WALIterator::nextBatch() { + if (!iter_ || !iter_->Valid()) { + Reset(); + return; + } + + auto batch = iter_->GetBatch(); + if (batch.sequence != next_batch_seq_ || !batch.writeBatchPtr) { + Reset(); + return; + } + + extractor_.Clear(); + + auto s = batch.writeBatchPtr->Iterate(&extractor_); + if (!s.ok()) { + Reset(); + return; + } + + next_batch_seq_ += batch.writeBatchPtr->Count(); + batch_iter_ = std::make_unique(extractor_.GetIter()); +} + +void WALIterator::Seek(rocksdb::SequenceNumber seq) { + if (slot_ != -1 && !storage_->IsSlotIdEncoded()) { + Reset(); + return; + } + + auto s = storage_->GetWALIter(seq, &iter_); + if (!s.IsOK()) { + Reset(); + return; + } + + next_batch_seq_ = seq; + + nextBatch(); +} + +WALItem WALIterator::Item() { + if (batch_iter_ && batch_iter_->Valid()) { + return batch_iter_->Value(); + } + return {}; +} + +rocksdb::SequenceNumber WALIterator::NextSequenceNumber() const { return next_batch_seq_; } + +void WALIterator::Next() { + if (!Valid()) { + Reset(); + return; + } + + if (batch_iter_ && batch_iter_->Valid()) { + batch_iter_->Next(); + if (batch_iter_->Valid()) { + return; + } + } + + iter_->Next(); + nextBatch(); +} + } // namespace engine diff --git 
a/src/storage/iterator.h b/src/storage/iterator.h index 40b93bc3799..2f123630c3e 100644 --- a/src/storage/iterator.h +++ b/src/storage/iterator.h @@ -37,6 +37,7 @@ class SubKeyIterator { Slice Key() const; // return the user key without prefix Slice UserKey() const; + rocksdb::ColumnFamilyHandle *ColumnFamilyHandle() const; Slice Value() const; void Reset(); @@ -79,4 +80,86 @@ class DBIterator { std::unique_ptr subkey_iter_; }; +struct WALItem { + enum class Type : uint8_t { + kTypeInvalid = 0, + kTypeLogData = 1, + kTypePut = 2, + kTypeDelete = 3, + kTypeDeleteRange = 4, + }; + + WALItem() = default; + WALItem(WALItem::Type t, uint32_t cf_id, std::string k, std::string v) + : type(t), column_family_id(cf_id), key(std::move(k)), value(std::move(v)) {} + + WALItem::Type type = WALItem::Type::kTypeInvalid; + uint32_t column_family_id = 0; + std::string key; + std::string value; +}; + +class WALBatchExtractor : public rocksdb::WriteBatch::Handler { + public: + // If set slot, storage must enable slot id encoding + explicit WALBatchExtractor(int slot = -1) : slot_(slot) {} + + rocksdb::Status PutCF(uint32_t column_family_id, const Slice &key, const Slice &value) override; + + rocksdb::Status DeleteCF(uint32_t column_family_id, const rocksdb::Slice &key) override; + + rocksdb::Status DeleteRangeCF(uint32_t column_family_id, const rocksdb::Slice &begin_key, + const rocksdb::Slice &end_key) override; + + void LogData(const rocksdb::Slice &blob) override; + + void Clear(); + + class Iter { + friend class WALBatchExtractor; + + public: + bool Valid(); + void Next(); + WALItem Value(); + + private: + explicit Iter(std::vector *items) : items_(items), cur_(0) {} + std::vector *items_; + size_t cur_; + }; + + WALBatchExtractor::Iter GetIter(); + + private: + std::vector items_; + int slot_; +}; + +class WALIterator { + public: + explicit WALIterator(engine::Storage *storage, int slot = -1) + : storage_(storage), slot_(slot), extractor_(slot), next_batch_seq_(0){}; + 
~WALIterator() = default; + + bool Valid() const; + void Seek(rocksdb::SequenceNumber seq); + void Next(); + WALItem Item(); + + rocksdb::SequenceNumber NextSequenceNumber() const; + void Reset(); + + private: + void nextBatch(); + + engine::Storage *storage_; + int slot_; + + std::unique_ptr iter_; + WALBatchExtractor extractor_; + std::unique_ptr batch_iter_; + rocksdb::SequenceNumber next_batch_seq_; +}; + } // namespace engine diff --git a/src/storage/redis_db.cc b/src/storage/redis_db.cc index bdbd9789490..63e8ef571cc 100644 --- a/src/storage/redis_db.cc +++ b/src/storage/redis_db.cc @@ -29,8 +29,11 @@ #include "db_util.h" #include "parse_util.h" #include "rocksdb/iterator.h" +#include "rocksdb/status.h" #include "server/server.h" +#include "storage/iterator.h" #include "storage/redis_metadata.h" +#include "storage/storage.h" #include "time_util.h" namespace redis { @@ -55,7 +58,8 @@ rocksdb::Status Database::ParseMetadata(RedisTypes types, Slice *bytes, Metadata }); auto s = metadata->Decode(bytes); - if (!s.ok()) return s; + // delay InvalidArgument error check after type match check + if (!s.ok() && !s.IsInvalidArgument()) return s; if (metadata->Expired()) { // error discarded here since it already failed @@ -69,6 +73,8 @@ rocksdb::Status Database::ParseMetadata(RedisTypes types, Slice *bytes, Metadata auto _ [[maybe_unused]] = metadata->Decode(old_metadata); return rocksdb::Status::InvalidArgument(kErrMsgWrongType); } + if (s.IsInvalidArgument()) return s; + if (metadata->size == 0 && !metadata->IsEmptyableType()) { // error discarded here since it already failed auto _ [[maybe_unused]] = metadata->Decode(old_metadata); @@ -182,7 +188,9 @@ rocksdb::Status Database::MDel(const std::vector &keys, uint64_t *deleted if (statuses[i].IsNotFound()) continue; Metadata metadata(kRedisNone, false); - auto s = metadata.Decode(pin_values[i]); + // Explicit construct a rocksdb::Slice to avoid the implicit conversion from + // PinnableSlice to Slice. 
+ auto s = metadata.Decode(rocksdb::Slice(pin_values[i].data(), pin_values[i].size())); if (!s.ok()) continue; if (metadata.Expired()) continue; @@ -690,4 +698,67 @@ Status WriteBatchLogData::Decode(const rocksdb::Slice &blob) { return Status::OK(); } + +rocksdb::Status Database::Rename(const std::string &key, const std::string &new_key, bool nx, bool *ret) { + *ret = true; + std::string ns_key = AppendNamespacePrefix(key); + std::string new_ns_key = AppendNamespacePrefix(new_key); + + std::vector lock_keys = {ns_key, new_ns_key}; + MultiLockGuard guard(storage_->GetLockManager(), lock_keys); + + RedisType type = kRedisNone; + auto s = Type(key, &type); + if (!s.ok()) return s; + if (type == kRedisNone) return rocksdb::Status::InvalidArgument("ERR no such key"); + + if (nx) { + int exist = 0; + if (s = Exists({new_key}, &exist), !s.ok()) return s; + if (exist > 0) { + *ret = false; + return rocksdb::Status::OK(); + } + } + + if (key == new_key) return rocksdb::Status::OK(); + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(type); + batch->PutLogData(log_data.Encode()); + + engine::DBIterator iter(storage_, rocksdb::ReadOptions()); + iter.Seek(ns_key); + + // copy metadata + batch->Delete(metadata_cf_handle_, ns_key); + batch->Put(metadata_cf_handle_, new_ns_key, iter.Value()); + + auto subkey_iter = iter.GetSubKeyIterator(); + + if (subkey_iter != nullptr) { + auto zset_score_cf = type == kRedisZSet ? 
storage_->GetCFHandle(engine::kZSetScoreColumnFamilyName) : nullptr; + + for (subkey_iter->Seek(); subkey_iter->Valid(); subkey_iter->Next()) { + InternalKey from_ikey(subkey_iter->Key(), storage_->IsSlotIdEncoded()); + std::string to_ikey = + InternalKey(new_ns_key, from_ikey.GetSubKey(), from_ikey.GetVersion(), storage_->IsSlotIdEncoded()).Encode(); + // copy sub key + batch->Put(subkey_iter->ColumnFamilyHandle(), to_ikey, subkey_iter->Value()); + + // The ZSET type stores an extra score and member field inside `zset_score` column family + // while compared to other composed data structures. The purpose is to allow to seek by score. + if (type == kRedisZSet) { + std::string score_bytes = subkey_iter->Value().ToString(); + score_bytes.append(from_ikey.GetSubKey().ToString()); + // copy score key + std::string score_key = + InternalKey(new_ns_key, score_bytes, from_ikey.GetVersion(), storage_->IsSlotIdEncoded()).Encode(); + batch->Put(zset_score_cf, score_key, Slice()); + } + } + } + + return storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); +} } // namespace redis diff --git a/src/storage/redis_db.h b/src/storage/redis_db.h index 658044414a1..0627b651702 100644 --- a/src/storage/redis_db.h +++ b/src/storage/redis_db.h @@ -61,6 +61,7 @@ class Database { rocksdb::ColumnFamilyHandle *cf_handle = nullptr); [[nodiscard]] rocksdb::Status ClearKeysOfSlot(const rocksdb::Slice &ns, int slot); [[nodiscard]] rocksdb::Status KeyExist(const std::string &key); + [[nodiscard]] rocksdb::Status Rename(const std::string &key, const std::string &new_key, bool nx, bool *ret); protected: engine::Storage *storage_; diff --git a/src/storage/redis_metadata.cc b/src/storage/redis_metadata.cc index 23afc765457..1bca93d77dc 100644 --- a/src/storage/redis_metadata.cc +++ b/src/storage/redis_metadata.cc @@ -101,6 +101,17 @@ bool InternalKey::operator==(const InternalKey &that) const { return version_ == that.version_; } +// Must slot encoded +uint16_t 
ExtractSlotId(Slice ns_key) { + uint8_t namespace_size = 0; + GetFixed8(&ns_key, &namespace_size); + ns_key.remove_prefix(namespace_size); + + uint16_t slot_id = HASH_SLOTS_SIZE; + GetFixed16(&ns_key, &slot_id); + return slot_id; +} + template std::tuple ExtractNamespaceKey(Slice ns_key, bool slot_id_encoded) { uint8_t namespace_size = 0; @@ -335,13 +346,13 @@ rocksdb::Status ListMetadata::Decode(Slice *input) { if (auto s = Metadata::Decode(input); !s.ok()) { return s; } - if (Type() == kRedisList) { - if (input->size() < 8 + 8) { - return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); - } - GetFixed64(input, &head); - GetFixed64(input, &tail); + + if (input->size() < 8 + 8) { + return rocksdb::Status::InvalidArgument(kErrMetadataTooShort); } + GetFixed64(input, &head); + GetFixed64(input, &tail); + return rocksdb::Status::OK(); } diff --git a/src/storage/redis_metadata.h b/src/storage/redis_metadata.h index 8fe52f3910e..ce1026443af 100644 --- a/src/storage/redis_metadata.h +++ b/src/storage/redis_metadata.h @@ -105,6 +105,7 @@ struct KeyNumStats { uint64_t avg_ttl = 0; }; +[[nodiscard]] uint16_t ExtractSlotId(Slice ns_key); template [[nodiscard]] std::tuple ExtractNamespaceKey(Slice ns_key, bool slot_id_encoded); [[nodiscard]] std::string ComposeNamespaceKey(const Slice &ns, const Slice &key, bool slot_id_encoded); diff --git a/src/storage/scripting.cc b/src/storage/scripting.cc index a6e73fa040d..8105d321464 100644 --- a/src/storage/scripting.cc +++ b/src/storage/scripting.cc @@ -425,7 +425,8 @@ Status FunctionCall(redis::Connection *conn, const std::string &name, const std: } // list all library names and their code (enabled via `with_code`) -Status FunctionList(Server *srv, const std::string &libname, bool with_code, std::string *output) { +Status FunctionList(Server *srv, const redis::Connection *conn, const std::string &libname, bool with_code, + std::string *output) { std::string start_key = engine::kLuaLibCodePrefix + libname; std::string end_key 
= start_key; end_key.back()++; @@ -445,12 +446,13 @@ Status FunctionList(Server *srv, const std::string &libname, bool with_code, std result.emplace_back(lib.ToString(), iter->value().ToString()); } - output->append(redis::MultiLen(result.size() * (with_code ? 4 : 2))); + output->append(redis::MultiLen(result.size())); for (const auto &[lib, code] : result) { - output->append(redis::SimpleString("library_name")); - output->append(redis::SimpleString(lib)); + output->append(conn->HeaderOfMap(with_code ? 2 : 1)); + output->append(redis::BulkString("library_name")); + output->append(redis::BulkString(lib)); if (with_code) { - output->append(redis::SimpleString("library_code")); + output->append(redis::BulkString("library_code")); output->append(redis::BulkString(code)); } } @@ -460,7 +462,7 @@ Status FunctionList(Server *srv, const std::string &libname, bool with_code, std // extension to Redis Function // list all function names and their corresponding library names -Status FunctionListFunc(Server *srv, const std::string &funcname, std::string *output) { +Status FunctionListFunc(Server *srv, const redis::Connection *conn, const std::string &funcname, std::string *output) { std::string start_key = engine::kLuaFuncLibPrefix + funcname; std::string end_key = start_key; end_key.back()++; @@ -480,12 +482,13 @@ Status FunctionListFunc(Server *srv, const std::string &funcname, std::string *o result.emplace_back(func.ToString(), iter->value().ToString()); } - output->append(redis::MultiLen(result.size() * 4)); + output->append(redis::MultiLen(result.size())); for (const auto &[func, lib] : result) { - output->append(redis::SimpleString("function_name")); - output->append(redis::SimpleString(func)); - output->append(redis::SimpleString("from_library")); - output->append(redis::SimpleString(lib)); + output->append(conn->HeaderOfMap(2)); + output->append(redis::BulkString("function_name")); + output->append(redis::BulkString(func)); + 
output->append(redis::BulkString("from_library")); + output->append(redis::BulkString(lib)); } return Status::OK(); @@ -495,7 +498,7 @@ Status FunctionListFunc(Server *srv, const std::string &funcname, std::string *o // list detailed informantion of a specific library // NOTE: it is required to load the library to lua runtime before listing (calling this function) // i.e. it will output nothing if the library is only in storage but not loaded -Status FunctionListLib(Server *srv, const std::string &libname, std::string *output) { +Status FunctionListLib(Server *srv, const redis::Connection *conn, const std::string &libname, std::string *output) { auto lua = srv->Lua(); lua_getglobal(lua, REDIS_FUNCTION_LIBRARIES); @@ -511,11 +514,11 @@ Status FunctionListLib(Server *srv, const std::string &libname, std::string *out return {Status::NotOK, "The library is not found or not loaded from storage"}; } - output->append(redis::MultiLen(6)); - output->append(redis::SimpleString("library_name")); - output->append(redis::SimpleString(libname)); - output->append(redis::SimpleString("engine")); - output->append(redis::SimpleString("lua")); + output->append(conn->HeaderOfMap(3)); + output->append(redis::BulkString("library_name")); + output->append(redis::BulkString(libname)); + output->append(redis::BulkString("engine")); + output->append(redis::BulkString("lua")); auto count = lua_objlen(lua, -1); output->append(redis::SimpleString("functions")); @@ -524,7 +527,7 @@ Status FunctionListLib(Server *srv, const std::string &libname, std::string *out for (size_t i = 1; i <= count; ++i) { lua_rawgeti(lua, -1, static_cast(i)); auto func = lua_tostring(lua, -1); - output->append(redis::SimpleString(func)); + output->append(redis::BulkString(func)); lua_pop(lua, 1); } @@ -1077,6 +1080,7 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) { std::string output; const char *obj_s = nullptr; size_t obj_len = 0; + int j = 0, mbulklen = 0; int t = lua_type(lua, -1); 
switch (t) { @@ -1110,6 +1114,7 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) { return output; } lua_pop(lua, 1); /* Discard field name pushed before. */ + /* Handle status reply. */ lua_pushstring(lua, "ok"); lua_gettable(lua, -2); @@ -1119,23 +1124,35 @@ std::string ReplyToRedisReply(redis::Connection *conn, lua_State *lua) { output = redis::BulkString(std::string(obj_s, obj_len)); lua_pop(lua, 1); return output; - } else { - int j = 1, mbulklen = 0; - lua_pop(lua, 1); /* Discard the 'ok' field value we popped */ - while (true) { - lua_pushnumber(lua, j++); - lua_gettable(lua, -2); - t = lua_type(lua, -1); - if (t == LUA_TNIL) { - lua_pop(lua, 1); - break; - } - mbulklen++; - output += ReplyToRedisReply(conn, lua); + } + lua_pop(lua, 1); /* Discard the 'ok' field value we pushed */ + + /* Handle big number reply. */ + lua_pushstring(lua, "big_number"); + lua_gettable(lua, -2); + t = lua_type(lua, -1); + if (t == LUA_TSTRING) { + obj_s = lua_tolstring(lua, -1, &obj_len); + output = conn->BigNumber(std::string(obj_s, obj_len)); + lua_pop(lua, 1); + return output; + } + lua_pop(lua, 1); /* Discard the 'big_number' field value we pushed */ + + j = 1, mbulklen = 0; + while (true) { + lua_pushnumber(lua, j++); + lua_gettable(lua, -2); + t = lua_type(lua, -1); + if (t == LUA_TNIL) { lua_pop(lua, 1); + break; } - output = redis::MultiLen(mbulklen) + output; + mbulklen++; + output += ReplyToRedisReply(conn, lua); + lua_pop(lua, 1); } + output = redis::MultiLen(mbulklen) + output; break; default: output = conn->NilString(); diff --git a/src/storage/scripting.h b/src/storage/scripting.h index 0d9ce46c316..a2c90b90ae0 100644 --- a/src/storage/scripting.h +++ b/src/storage/scripting.h @@ -66,9 +66,10 @@ Status FunctionLoad(redis::Connection *conn, const std::string &script, bool nee std::string *lib_name, bool read_only = false); Status FunctionCall(redis::Connection *conn, const std::string &name, const std::vector &keys, const std::vector &argv, 
std::string *output, bool read_only = false); -Status FunctionList(Server *srv, const std::string &libname, bool with_code, std::string *output); -Status FunctionListFunc(Server *srv, const std::string &funcname, std::string *output); -Status FunctionListLib(Server *srv, const std::string &libname, std::string *output); +Status FunctionList(Server *srv, const redis::Connection *conn, const std::string &libname, bool with_code, + std::string *output); +Status FunctionListFunc(Server *srv, const redis::Connection *conn, const std::string &funcname, std::string *output); +Status FunctionListLib(Server *srv, const redis::Connection *conn, const std::string &libname, std::string *output); Status FunctionDelete(Server *srv, const std::string &name); bool FunctionIsLibExist(redis::Connection *conn, const std::string &libname, bool need_check_storage = true, bool read_only = false); diff --git a/src/storage/storage.cc b/src/storage/storage.cc index a600813fae8..75f6fb6de86 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -42,6 +42,7 @@ #include "event_util.h" #include "redis_db.h" #include "redis_metadata.h" +#include "rocksdb/cache.h" #include "rocksdb_crc32c.h" #include "server/server.h" #include "table_properties_collector.h" @@ -52,6 +53,23 @@ namespace engine { constexpr const char *kReplicationIdKey = "replication_id_"; +// used in creating rocksdb::LRUCache, set `num_shard_bits` to -1 means let rocksdb choose a good default shard count +// based on the capacity and the implementation. +constexpr int kRocksdbLRUAutoAdjustShardBits = -1; + +// used as the default argument for `strict_capacity_limit` in creating rocksdb::Cache. +constexpr bool kRocksdbCacheStrictCapacityLimit = false; + +// used as the default argument for `high_pri_pool_ratio` in creating block cache. +constexpr double kRocksdbLRUBlockCacheHighPriPoolRatio = 0.75; + +// used as the default argument for `high_pri_pool_ratio` in creating row cache. 
+constexpr double kRocksdbLRURowCacheHighPriPoolRatio = 0.5; + +// used in creating rocksdb::HyperClockCache, set `estimated_entry_charge` to 0 means let rocksdb dynamically and +automatically adjust the table size for the cache. +constexpr size_t kRockdbHCCAutoAdjustCharge = 0; + const int64_t kIORateLimitMaxMb = 1024000; using rocksdb::Slice; @@ -152,9 +170,12 @@ rocksdb::Options Storage::InitRocksDBOptions() { options.compression_per_level[i] = config_->rocks_db.compression; } } + if (config_->rocks_db.row_cache_size) { - options.row_cache = rocksdb::NewLRUCache(config_->rocks_db.row_cache_size * MiB); + options.row_cache = rocksdb::NewLRUCache(config_->rocks_db.row_cache_size * MiB, kRocksdbLRUAutoAdjustShardBits, + kRocksdbCacheStrictCapacityLimit, kRocksdbLRURowCacheHighPriPoolRatio); } + options.enable_pipelined_write = config_->rocks_db.enable_pipelined_write; options.target_file_size_base = config_->rocks_db.target_file_size_base * MiB; options.max_manifest_file_size = 64 * MiB; @@ -256,7 +277,15 @@ Status Storage::Open(DBOpenMode mode) { } } - std::shared_ptr shared_block_cache = rocksdb::NewLRUCache(block_cache_size, -1, false, 0.75); + std::shared_ptr shared_block_cache; + + if (config_->rocks_db.block_cache_type == BlockCacheType::kCacheTypeLRU) { + shared_block_cache = rocksdb::NewLRUCache(block_cache_size, kRocksdbLRUAutoAdjustShardBits, + kRocksdbCacheStrictCapacityLimit, kRocksdbLRUBlockCacheHighPriPoolRatio); + } else { + rocksdb::HyperClockCacheOptions hcc_cache_options(block_cache_size, kRockdbHCCAutoAdjustCharge); + shared_block_cache = hcc_cache_options.MakeSharedCache(); + } rocksdb::BlockBasedTableOptions metadata_table_opts = InitTableOptions(); metadata_table_opts.block_cache = shared_block_cache; diff --git a/src/storage/storage.h b/src/storage/storage.h index f07b1b4821e..dfa60659494 100644 --- a/src/storage/storage.h +++ b/src/storage/storage.h @@ -96,6 +96,17 @@ inline const std::vector CompressionOptions = { {rocksdb::kZSTD, "zstd",
"kZSTD"}, }; +struct CacheOption { + BlockCacheType type; + const std::string name; + const std::string val; +}; + +inline const std::vector CacheOptions = { + {BlockCacheType::kCacheTypeLRU, "lru", "kCacheTypeLRU"}, + {BlockCacheType::kCacheTypeHCC, "hcc", "kCacheTypeHCC"}, +}; + enum class StatType { CompactionCount, FlushCount, diff --git a/src/types/redis_bitmap_string.cc b/src/types/redis_bitmap_string.cc index d9d77114985..9b17963010a 100644 --- a/src/types/redis_bitmap_string.cc +++ b/src/types/redis_bitmap_string.cc @@ -222,7 +222,10 @@ rocksdb::Status BitmapString::Bitfield(const Slice &ns_key, std::string *raw_val ArrayBitfieldBitmap bitfield(first_byte); auto str = reinterpret_cast(string_value.data() + first_byte); - auto s = bitfield.Set(first_byte, last_byte - first_byte, str); + auto s = bitfield.Set(/*byte_offset=*/first_byte, /*bytes=*/last_byte - first_byte, /*src=*/str); + if (!s.IsOK()) { + return rocksdb::Status::IOError(s.Msg()); + } uint64_t unsigned_old_value = 0; if (op.encoding.IsSigned()) { @@ -232,13 +235,19 @@ rocksdb::Status BitmapString::Bitfield(const Slice &ns_key, std::string *raw_val } uint64_t unsigned_new_value = 0; - auto &ret = rets->emplace_back(); - if (BitfieldOp(op, unsigned_old_value, &unsigned_new_value).GetValue()) { + std::optional &ret = rets->emplace_back(); + StatusOr bitfield_op = BitfieldOp(op, unsigned_old_value, &unsigned_new_value); + if (!bitfield_op.IsOK()) { + return rocksdb::Status::InvalidArgument(bitfield_op.Msg()); + } + if (bitfield_op.GetValue()) { if (op.type != BitfieldOperation::Type::kGet) { // never failed. 
s = bitfield.SetBitfield(op.offset, op.encoding.Bits(), unsigned_new_value); + CHECK(s.IsOK()); auto dst = reinterpret_cast(string_value.data()) + first_byte; s = bitfield.Get(first_byte, last_byte - first_byte, dst); + CHECK(s.IsOK()); } if (op.type == BitfieldOperation::Type::kSet) { @@ -276,6 +285,9 @@ rocksdb::Status BitmapString::BitfieldReadOnly(const Slice &ns_key, const std::s ArrayBitfieldBitmap bitfield(first_byte); auto s = bitfield.Set(first_byte, last_byte - first_byte, reinterpret_cast(string_value.data() + first_byte)); + if (!s.IsOK()) { + return rocksdb::Status::IOError(s.Msg()); + } if (op.encoding.IsSigned()) { int64_t value = bitfield.GetSignedBitfield(op.offset, op.encoding.Bits()).GetValue(); diff --git a/src/types/redis_hash.cc b/src/types/redis_hash.cc index 0ee9c16e3d2..d7a1ad8b7bc 100644 --- a/src/types/redis_hash.cc +++ b/src/types/redis_hash.cc @@ -405,8 +405,7 @@ rocksdb::Status Hash::RandField(const Slice &user_key, int64_t command_count, st field_values->reserve(std::min(size, count)); if (!unique || count == 1) { // Case 1: Negative count, randomly select elements or without parameter - std::random_device rd; - std::mt19937 gen(rd()); + std::mt19937 gen(std::random_device{}()); std::uniform_int_distribution dis(0, size - 1); for (uint64_t i = 0; i < count; i++) { uint64_t index = dis(gen); @@ -421,8 +420,8 @@ rocksdb::Status Hash::RandField(const Slice &user_key, int64_t command_count, st // Case 3: Requested count is less than the number of elements inside the hash std::vector indices(size); std::iota(indices.begin(), indices.end(), 0); - std::shuffle(indices.begin(), indices.end(), - std::random_device{}); // use Fisher-Yates shuffle algorithm to randomize the order + std::mt19937 gen(std::random_device{}()); + std::shuffle(indices.begin(), indices.end(), gen); // use Fisher-Yates shuffle algorithm to randomize the order for (uint64_t i = 0; i < count; i++) { uint64_t index = indices[i]; append_field_with_index(index); diff --git 
a/src/types/redis_set.cc b/src/types/redis_set.cc index 32c06ef9132..98677e27560 100644 --- a/src/types/redis_set.cc +++ b/src/types/redis_set.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include "db_util.h" @@ -222,15 +223,32 @@ rocksdb::Status Set::Take(const Slice &user_key, std::vector *membe rocksdb::Slice upper_bound(next_version_prefix); read_options.iterate_upper_bound = &upper_bound; + std::vector iter_keys; + iter_keys.reserve(count); + std::random_device rd; + std::mt19937 gen(rd()); auto iter = util::UniqueIterator(storage_, read_options); for (iter->Seek(prefix); iter->Valid() && iter->key().starts_with(prefix); iter->Next()) { - InternalKey ikey(iter->key(), storage_->IsSlotIdEncoded()); + ++n; + if (n <= count) { + iter_keys.push_back(iter->key().ToString()); + } else { // n > count + std::uniform_int_distribution<> distrib(0, n - 1); + int random = distrib(gen); // [0,n-1] + if (random < count) { + iter_keys[random] = iter->key().ToString(); + } + } + } + for (Slice key : iter_keys) { + InternalKey ikey(key, storage_->IsSlotIdEncoded()); members->emplace_back(ikey.GetSubKey().ToString()); - if (pop) batch->Delete(iter->key()); - if (++n >= count) break; + if (pop) { + batch->Delete(key); + } } - if (pop && n > 0) { - metadata.size -= n; + if (pop && !iter_keys.empty()) { + metadata.size -= iter_keys.size(); std::string bytes; metadata.Encode(&bytes); batch->Put(metadata_cf_handle_, ns_key, bytes); diff --git a/src/types/redis_zset.cc b/src/types/redis_zset.cc index 57c154acc8d..328182b5518 100644 --- a/src/types/redis_zset.cc +++ b/src/types/redis_zset.cc @@ -179,7 +179,7 @@ rocksdb::Status ZSet::Pop(const Slice &user_key, int count, bool min, MemberScor auto iter = util::UniqueIterator(storage_, read_options, score_cf_handle_); iter->Seek(start_key); - // see comment in rangebyscore() + // see comment in RangeByScore() if (!min && (!iter->Valid() || !iter->key().starts_with(prefix_key))) { iter->SeekForPrev(start_key); } @@ -249,7 
+249,7 @@ rocksdb::Status ZSet::RangeByRank(const Slice &user_key, const RangeRankSpec &sp auto batch = storage_->GetWriteBatchBase(); auto iter = util::UniqueIterator(storage_, read_options, score_cf_handle_); iter->Seek(start_key); - // see comment in rangebyscore() + // see comment in RangeByScore() if (spec.reversed && (!iter->Valid() || !iter->key().starts_with(prefix_key))) { iter->SeekForPrev(start_key); } @@ -583,7 +583,7 @@ rocksdb::Status ZSet::Rank(const Slice &user_key, const Slice &member, bool reve auto iter = util::UniqueIterator(storage_, read_options, score_cf_handle_); iter->Seek(start_key); - // see comment in rangebyscore() + // see comment in RangeByScore() if (reversed && (!iter->Valid() || !iter->key().starts_with(prefix_key))) { iter->SeekForPrev(start_key); } @@ -716,7 +716,7 @@ rocksdb::Status ZSet::InterCard(const std::vector &user_keys, uint6 MemberScores mscores; auto s = RangeByScore(user_key, spec, &mscores, nullptr); if (!s.ok() || mscores.empty()) return s; - mscores_list.emplace_back(mscores); + mscores_list.emplace_back(std::move(mscores)); } std::sort(mscores_list.begin(), mscores_list.end(), [](const MemberScores &v1, const MemberScores &v2) { return v1.size() < v2.size(); }); @@ -897,13 +897,14 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key, int64_t command_count, std::string ns_key = AppendNamespacePrefix(user_key); ZSetMetadata metadata(false); rocksdb::Status s = GetMetadata(ns_key, &metadata); - if (!s.ok() || metadata.size == 0) return s; + if (!s.ok()) return s.IsNotFound() ? 
rocksdb::Status::OK() : s; + if (metadata.size == 0) return rocksdb::Status::OK(); std::vector samples; s = GetAllMemberScores(user_key, &samples); if (!s.ok() || samples.empty()) return s; - auto size = static_cast(samples.size()); + uint64_t size = samples.size(); member_scores->reserve(std::min(size, count)); if (!unique || count == 1) { @@ -915,19 +916,58 @@ rocksdb::Status ZSet::RandMember(const Slice &user_key, int64_t command_count, } } else if (size <= count) { for (auto &sample : samples) { - member_scores->push_back(sample); + member_scores->push_back(std::move(sample)); } } else { // first shuffle the samples - std::shuffle(samples.begin(), samples.end(), std::random_device{}); - + std::mt19937 gen(std::random_device{}()); + std::shuffle(samples.begin(), samples.end(), gen); // then pick the first `count` ones. for (uint64_t i = 0; i < count; i++) { - member_scores->emplace_back(samples[i]); + member_scores->emplace_back(std::move(samples[i])); } } return rocksdb::Status::OK(); } +rocksdb::Status ZSet::Diff(const std::vector &keys, MemberScores *members) { + members->clear(); + MemberScores source_member_scores; + RangeScoreSpec spec; + uint64_t first_element_size = 0; + auto s = RangeByScore(keys[0], spec, &source_member_scores, &first_element_size); + if (!s.ok()) return s; + + if (first_element_size == 0) { + return rocksdb::Status::OK(); + } + + std::set exclude_members; + MemberScores target_member_scores; + for (size_t i = 1; i < keys.size(); i++) { + uint64_t size = 0; + s = RangeByScore(keys[i], spec, &target_member_scores, &size); + if (!s.ok()) return s; + for (auto &member_score : target_member_scores) { + exclude_members.emplace(std::move(member_score.member)); + } + target_member_scores.clear(); + } + for (const auto &member_score : source_member_scores) { + if (exclude_members.find(member_score.member) == exclude_members.end()) { + members->push_back(member_score); + } + } + return rocksdb::Status::OK(); +} + +rocksdb::Status 
ZSet::DiffStore(const Slice &dst, const std::vector &keys, uint64_t *stored_count) { + MemberScores mscores; + auto s = Diff(keys, &mscores); + if (!s.ok()) return s; + *stored_count = mscores.size(); + return Overwrite(dst, mscores); +} + } // namespace redis diff --git a/src/types/redis_zset.h b/src/types/redis_zset.h index 397ca10b126..d806d57e3cf 100644 --- a/src/types/redis_zset.h +++ b/src/types/redis_zset.h @@ -116,6 +116,8 @@ class ZSet : public SubKeyScanner { AggregateMethod aggregate_method, uint64_t *saved_cnt); rocksdb::Status Union(const std::vector &keys_weights, AggregateMethod aggregate_method, std::vector *members); + rocksdb::Status Diff(const std::vector &keys, MemberScores *members); + rocksdb::Status DiffStore(const Slice &dst, const std::vector &keys, uint64_t *stored_count); rocksdb::Status MGet(const Slice &user_key, const std::vector &members, std::map *scores); rocksdb::Status GetMetadata(const Slice &ns_key, ZSetMetadata *metadata); diff --git a/tests/cppunit/bitfield_util.cc b/tests/cppunit/bitfield_util.cc index 71976df4627..c6b9d36d1eb 100644 --- a/tests/cppunit/bitfield_util.cc +++ b/tests/cppunit/bitfield_util.cc @@ -35,7 +35,7 @@ TEST(BitfieldUtil, Get) { ArrayBitfieldBitmap bitfield(0); auto s = bitfield.Set(0, big_endian_bitmap.size(), big_endian_bitmap.data()); - for (int bits = 16; bits < 64; bits *= 2) { + for (uint8_t bits = 16; bits < 64; bits *= 2) { for (uint64_t offset = 0; bits + offset <= big_endian_bitmap.size() * 8; offset += bits) { uint64_t value = bitfield.GetUnsignedBitfield(offset, bits).GetValue(); if (IsBigEndian()) { diff --git a/tests/cppunit/iterator_test.cc b/tests/cppunit/iterator_test.cc index 4bbd24089ea..6a2437a126b 100644 --- a/tests/cppunit/iterator_test.cc +++ b/tests/cppunit/iterator_test.cc @@ -19,6 +19,7 @@ */ #include +#include #include #include #include @@ -32,10 +33,10 @@ #include "test_base.h" #include "types/redis_string.h" -class IteratorTest : public TestBase { +class DBIteratorTest : 
public TestBase { protected: - explicit IteratorTest() = default; - ~IteratorTest() override = default; + explicit DBIteratorTest() = default; + ~DBIteratorTest() override = default; void SetUp() override { { // string @@ -112,7 +113,7 @@ class IteratorTest : public TestBase { } }; -TEST_F(IteratorTest, AllKeys) { +TEST_F(DBIteratorTest, AllKeys) { engine::DBIterator iter(storage_, rocksdb::ReadOptions()); std::vector live_keys = {"a", "b", "d", "hash-1", "set-1", "zset-1", "list-1", "stream-1", "bitmap-1", "json-1", "json-2", "json-3", "sortedint-1"}; @@ -126,7 +127,7 @@ TEST_F(IteratorTest, AllKeys) { ASSERT_TRUE(live_keys.empty()); } -TEST_F(IteratorTest, BasicString) { +TEST_F(DBIteratorTest, BasicString) { engine::DBIterator iter(storage_, rocksdb::ReadOptions()); std::vector expected_keys = {"a", "b", "d"}; @@ -148,7 +149,7 @@ TEST_F(IteratorTest, BasicString) { ASSERT_TRUE(expected_keys.empty()); } -TEST_F(IteratorTest, BasicHash) { +TEST_F(DBIteratorTest, BasicHash) { engine::DBIterator iter(storage_, rocksdb::ReadOptions()); auto prefix = ComposeNamespaceKey("test_ns1", "", storage_->IsSlotIdEncoded()); for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { @@ -171,7 +172,7 @@ TEST_F(IteratorTest, BasicHash) { } } -TEST_F(IteratorTest, BasicSet) { +TEST_F(DBIteratorTest, BasicSet) { engine::DBIterator iter(storage_, rocksdb::ReadOptions()); auto prefix = ComposeNamespaceKey("test_ns2", "", storage_->IsSlotIdEncoded()); for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { @@ -194,7 +195,7 @@ TEST_F(IteratorTest, BasicSet) { } } -TEST_F(IteratorTest, BasicZSet) { +TEST_F(DBIteratorTest, BasicZSet) { engine::DBIterator iter(storage_, rocksdb::ReadOptions()); auto prefix = ComposeNamespaceKey("test_ns3", "", storage_->IsSlotIdEncoded()); for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { @@ -217,7 +218,7 @@ TEST_F(IteratorTest, BasicZSet) { } } 
-TEST_F(IteratorTest, BasicList) { +TEST_F(DBIteratorTest, BasicList) { engine::DBIterator iter(storage_, rocksdb::ReadOptions()); auto prefix = ComposeNamespaceKey("test_ns4", "", storage_->IsSlotIdEncoded()); for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { @@ -240,7 +241,7 @@ TEST_F(IteratorTest, BasicList) { } } -TEST_F(IteratorTest, BasicStream) { +TEST_F(DBIteratorTest, BasicStream) { engine::DBIterator iter(storage_, rocksdb::ReadOptions()); auto prefix = ComposeNamespaceKey("test_ns5", "", storage_->IsSlotIdEncoded()); for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { @@ -266,7 +267,7 @@ TEST_F(IteratorTest, BasicStream) { } } -TEST_F(IteratorTest, BasicBitmap) { +TEST_F(DBIteratorTest, BasicBitmap) { engine::DBIterator iter(storage_, rocksdb::ReadOptions()); auto prefix = ComposeNamespaceKey("test_ns6", "", storage_->IsSlotIdEncoded()); for (iter.Seek(prefix); iter.Valid() && iter.Key().starts_with(prefix); iter.Next()) { @@ -288,7 +289,7 @@ TEST_F(IteratorTest, BasicBitmap) { } } -TEST_F(IteratorTest, BasicJSON) { +TEST_F(DBIteratorTest, BasicJSON) { engine::DBIterator iter(storage_, rocksdb::ReadOptions()); std::vector expected_keys = {"json-1", "json-2", "json-3"}; @@ -310,7 +311,7 @@ TEST_F(IteratorTest, BasicJSON) { ASSERT_TRUE(expected_keys.empty()); } -TEST_F(IteratorTest, BasicSortedInt) { +TEST_F(DBIteratorTest, BasicSortedInt) { engine::DBIterator iter(storage_, rocksdb::ReadOptions()); auto prefix = ComposeNamespaceKey("test_ns8", "", storage_->IsSlotIdEncoded()); @@ -343,6 +344,7 @@ class SlotIteratorTest : public TestBase { TEST_F(SlotIteratorTest, LiveKeys) { redis::String string(storage_, kDefaultNamespace); + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); std::vector keys = {"{x}a", "{x}b", "{y}c", "{y}d", "{x}e"}; for (const auto &key : keys) { string.Set(key, "1"); @@ -363,4 +365,495 @@ TEST_F(SlotIteratorTest, LiveKeys) { count++; } 
ASSERT_EQ(count, same_slot_keys.size()); + + engine::WALIterator wal_iter(storage_, slot_id); + count = 0; + for (wal_iter.Seek(start_seq + 1); wal_iter.Valid(); wal_iter.Next()) { + auto item = wal_iter.Item(); + if (item.type == engine::WALItem::Type::kTypePut) { + auto [_, user_key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); + ASSERT_EQ(slot_id, GetSlotIdFromKey(user_key.ToString())) << user_key.ToString(); + count++; + } + } + ASSERT_EQ(count, same_slot_keys.size()); +} + +class WALIteratorTest : public TestBase { + protected: + explicit WALIteratorTest() = default; + ~WALIteratorTest() override = default; + void SetUp() override {} +}; + +TEST_F(WALIteratorTest, BasicString) { + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); + redis::String string(storage_, "test_ns0"); + string.Set("a", "1"); + string.MSet({{"b", "2"}, {"c", "3"}}); + ASSERT_TRUE(string.Del("b").ok()); + + std::vector put_keys, delete_keys; + auto expected_put_keys = {"a", "b", "c"}; + auto expected_delete_keys = {"b"}; + + engine::WALIterator iter(storage_); + + for (iter.Seek(start_seq + 1); iter.Valid(); iter.Next()) { + auto item = iter.Item(); + switch (item.type) { + case engine::WALItem::Type::kTypePut: { + auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); + ASSERT_EQ(ns.ToString(), "test_ns0"); + put_keys.emplace_back(key.ToString()); + break; + } + case engine::WALItem::Type::kTypeLogData: { + redis::WriteBatchLogData log_data; + ASSERT_TRUE(log_data.Decode(item.key).IsOK()); + ASSERT_EQ(log_data.GetRedisType(), kRedisString); + break; + } + case engine::WALItem::Type::kTypeDelete: { + auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); + ASSERT_EQ(ns.ToString(), "test_ns0"); + delete_keys.emplace_back(key.ToString()); + break; + } + default: + FAIL() << "Unexpected wal item type" << uint8_t(item.type); + } + } + ASSERT_EQ(expected_put_keys.size(), put_keys.size()); + 
ASSERT_EQ(expected_delete_keys.size(), delete_keys.size()); + ASSERT_TRUE(std::equal(expected_put_keys.begin(), expected_put_keys.end(), put_keys.begin())); + ASSERT_TRUE(std::equal(expected_delete_keys.begin(), expected_delete_keys.end(), delete_keys.begin())); +} + +TEST_F(WALIteratorTest, BasicHash) { + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); + redis::Hash hash(storage_, "test_ns1"); + uint64_t ret = 0; + hash.MSet("hash-1", {{"f0", "v0"}, {"f1", "v1"}, {"f2", "v2"}, {"f3", "v3"}}, false, &ret); + uint64_t deleted_cnt = 0; + hash.Delete("hash-1", {"f0"}, &deleted_cnt); + + // Delete will put meta key again + auto expected_put_keys = {"hash-1", "hash-1"}; + // Sub key will be put in reverse order + auto expected_put_fields = {"f3", "f2", "f1", "f0"}; + auto expected_delete_fields = {"f0"}; + std::vector put_keys, put_fields, delete_fields; + + engine::WALIterator iter(storage_); + + for (iter.Seek(start_seq + 1); iter.Valid(); iter.Next()) { + auto item = iter.Item(); + switch (item.type) { + case engine::WALItem::Type::kTypePut: { + if (item.column_family_id == kColumnFamilyIDDefault) { + InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); + put_fields.emplace_back(internal_key.GetSubKey().ToString()); + } else if (item.column_family_id == kColumnFamilyIDMetadata) { + auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); + ASSERT_EQ(ns.ToString(), "test_ns1"); + put_keys.emplace_back(key.ToString()); + } + break; + } + case engine::WALItem::Type::kTypeLogData: { + redis::WriteBatchLogData log_data; + ASSERT_TRUE(log_data.Decode(item.key).IsOK()); + ASSERT_EQ(log_data.GetRedisType(), kRedisHash); + break; + } + case engine::WALItem::Type::kTypeDelete: { + InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); + delete_fields.emplace_back(internal_key.GetSubKey().ToString()); + break; + } + default: + FAIL() << "Unexpected wal item type" << uint8_t(item.type); + } + } + 
ASSERT_EQ(expected_put_keys.size(), put_keys.size()); + ASSERT_EQ(expected_put_fields.size(), put_fields.size()); + ASSERT_EQ(expected_delete_fields.size(), delete_fields.size()); + ASSERT_TRUE(std::equal(expected_put_keys.begin(), expected_put_keys.end(), put_keys.begin())); + ASSERT_TRUE(std::equal(expected_put_fields.begin(), expected_put_fields.end(), put_fields.begin())); + ASSERT_TRUE(std::equal(expected_delete_fields.begin(), expected_delete_fields.end(), delete_fields.begin())); +} + +TEST_F(WALIteratorTest, BasicSet) { + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); + + uint64_t ret = 0; + redis::Set set(storage_, "test_ns2"); + set.Add("set-1", {"e0", "e1", "e2"}, &ret); + uint64_t removed_cnt = 0; + set.Remove("set-1", {"e0", "e1"}, &removed_cnt); + + auto expected_put_keys = {"set-1", "set-1"}; + auto expected_put_members = {"e0", "e1", "e2"}; + auto expected_delete_members = {"e0", "e1"}; + std::vector put_keys, put_members, delete_members; + + engine::WALIterator iter(storage_); + + for (iter.Seek(start_seq + 1); iter.Valid(); iter.Next()) { + auto item = iter.Item(); + switch (item.type) { + case engine::WALItem::Type::kTypePut: { + if (item.column_family_id == kColumnFamilyIDDefault) { + InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); + put_members.emplace_back(internal_key.GetSubKey().ToString()); + } else if (item.column_family_id == kColumnFamilyIDMetadata) { + auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); + ASSERT_EQ(ns.ToString(), "test_ns2"); + put_keys.emplace_back(key.ToString()); + } + break; + } + case engine::WALItem::Type::kTypeLogData: { + redis::WriteBatchLogData log_data; + ASSERT_TRUE(log_data.Decode(item.key).IsOK()); + ASSERT_EQ(log_data.GetRedisType(), kRedisSet); + break; + } + case engine::WALItem::Type::kTypeDelete: { + InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); + delete_members.emplace_back(internal_key.GetSubKey().ToString()); + break; + } 
+ default: + FAIL() << "Unexpected wal item type" << uint8_t(item.type); + } + } + + ASSERT_EQ(expected_put_keys.size(), put_keys.size()); + ASSERT_EQ(expected_put_members.size(), put_members.size()); + ASSERT_EQ(expected_delete_members.size(), delete_members.size()); + ASSERT_TRUE(std::equal(expected_put_keys.begin(), expected_put_keys.end(), put_keys.begin())); + ASSERT_TRUE(std::equal(expected_put_members.begin(), expected_put_members.end(), put_members.begin())); + ASSERT_TRUE(std::equal(expected_delete_members.begin(), expected_delete_members.end(), delete_members.begin())); +} + +TEST_F(WALIteratorTest, BasicZSet) { + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); + uint64_t ret = 0; + redis::ZSet zset(storage_, "test_ns3"); + auto mscores = std::vector{{"z0", 0}, {"z1", 1}, {"z2", 2}}; + zset.Add("zset-1", ZAddFlags(), &mscores, &ret); + uint64_t removed_cnt = 0; + zset.Remove("zset-1", {"z0"}, &removed_cnt); + + auto expected_put_keys = {"zset-1", "zset-1"}; + auto expected_put_members = {"z2", "z1", "z0"}; + // member and score + int expected_delete_count = 2, delete_count = 0; + std::vector put_keys, put_members; + + engine::WALIterator iter(storage_); + + for (iter.Seek(start_seq + 1); iter.Valid(); iter.Next()) { + auto item = iter.Item(); + switch (item.type) { + case engine::WALItem::Type::kTypePut: { + if (item.column_family_id == kColumnFamilyIDDefault) { + InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); + put_members.emplace_back(internal_key.GetSubKey().ToString()); + } else if (item.column_family_id == kColumnFamilyIDMetadata) { + auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); + ASSERT_EQ(ns.ToString(), "test_ns3"); + put_keys.emplace_back(key.ToString()); + } + break; + } + case engine::WALItem::Type::kTypeLogData: { + redis::WriteBatchLogData log_data; + ASSERT_TRUE(log_data.Decode(item.key).IsOK()); + ASSERT_EQ(log_data.GetRedisType(), kRedisZSet); + break; + } + case 
engine::WALItem::Type::kTypeDelete: { + delete_count++; + break; + } + default: + FAIL() << "Unexpected wal item type" << uint8_t(item.type); + } + } + + ASSERT_EQ(expected_put_keys.size(), put_keys.size()); + ASSERT_EQ(expected_put_members.size(), put_members.size()); + ASSERT_EQ(expected_delete_count, delete_count); + ASSERT_TRUE(std::equal(expected_put_keys.begin(), expected_put_keys.end(), put_keys.begin())); + ASSERT_TRUE(std::equal(expected_put_members.begin(), expected_put_members.end(), put_members.begin())); +} + +TEST_F(WALIteratorTest, BasicList) { + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); + uint64_t ret = 0; + redis::List list(storage_, "test_ns4"); + list.Push("list-1", {"l0", "l1", "l2", "l3", "l4"}, false, &ret); + ASSERT_TRUE(list.Trim("list-1", 2, 4).ok()); + + auto expected_put_keys = {"list-1", "list-1"}; + auto expected_put_values = {"l0", "l1", "l2", "l3", "l4"}; + auto expected_delete_count = 2, delete_count = 0; + std::vector put_keys, put_values; + + engine::WALIterator iter(storage_); + + for (iter.Seek(start_seq + 1); iter.Valid(); iter.Next()) { + auto item = iter.Item(); + switch (item.type) { + case engine::WALItem::Type::kTypePut: { + if (item.column_family_id == kColumnFamilyIDDefault) { + put_values.emplace_back(item.value); + } else if (item.column_family_id == kColumnFamilyIDMetadata) { + auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); + ASSERT_EQ(ns.ToString(), "test_ns4"); + put_keys.emplace_back(key.ToString()); + } + break; + } + case engine::WALItem::Type::kTypeLogData: { + redis::WriteBatchLogData log_data; + ASSERT_TRUE(log_data.Decode(item.key).IsOK()); + ASSERT_EQ(log_data.GetRedisType(), kRedisList); + break; + } + case engine::WALItem::Type::kTypeDelete: { + delete_count++; + break; + } + default: + FAIL() << "Unexpected wal item type" << uint8_t(item.type); + } + } + + ASSERT_EQ(expected_put_keys.size(), put_keys.size()); + ASSERT_EQ(expected_put_values.size(), 
put_values.size()); + ASSERT_EQ(expected_delete_count, delete_count); + ASSERT_TRUE(std::equal(expected_put_keys.begin(), expected_put_keys.end(), put_keys.begin())); + ASSERT_TRUE(std::equal(expected_put_values.begin(), expected_put_values.end(), put_values.begin())); +} + +TEST_F(WALIteratorTest, BasicStream) { + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); + redis::Stream stream(storage_, "test_ns5"); + redis::StreamEntryID ret; + redis::StreamAddOptions options; + options.next_id_strategy = std::make_unique(); + stream.Add("stream-1", options, {"x0"}, &ret); + stream.Add("stream-1", options, {"x1"}, &ret); + stream.Add("stream-1", options, {"x2"}, &ret); + uint64_t deleted = 0; + ASSERT_TRUE(stream.DeleteEntries("stream-1", {ret}, &deleted).ok()); + + auto expected_put_keys = {"stream-1", "stream-1", "stream-1", "stream-1"}; + auto expected_put_values = {"x0", "x1", "x2"}; + int delete_count = 0; + std::vector put_keys, put_values; + + engine::WALIterator iter(storage_); + + for (iter.Seek(start_seq + 1); iter.Valid(); iter.Next()) { + auto item = iter.Item(); + switch (item.type) { + case engine::WALItem::Type::kTypePut: { + if (item.column_family_id == kColumnFamilyIDStream) { + std::vector elems; + auto s = redis::DecodeRawStreamEntryValue(item.value, &elems); + ASSERT_TRUE(s.IsOK() && !elems.empty()); + put_values.emplace_back(elems[0]); + } else if (item.column_family_id == kColumnFamilyIDMetadata) { + auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); + ASSERT_EQ(ns.ToString(), "test_ns5"); + put_keys.emplace_back(key.ToString()); + } + break; + } + case engine::WALItem::Type::kTypeLogData: { + redis::WriteBatchLogData log_data; + ASSERT_TRUE(log_data.Decode(item.key).IsOK()); + ASSERT_EQ(log_data.GetRedisType(), kRedisStream); + break; + } + case engine::WALItem::Type::kTypeDelete: { + delete_count++; + break; + } + default: + FAIL() << "Unexpected wal item type" << uint8_t(item.type); + } + } + + 
ASSERT_EQ(expected_put_keys.size(), put_keys.size()); + ASSERT_EQ(expected_put_values.size(), put_values.size()); + ASSERT_EQ(deleted, delete_count); + ASSERT_TRUE(std::equal(expected_put_keys.begin(), expected_put_keys.end(), put_keys.begin())); + ASSERT_TRUE(std::equal(expected_put_values.begin(), expected_put_values.end(), put_values.begin())); +} + +TEST_F(WALIteratorTest, BasicBitmap) { + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); + + redis::Bitmap bitmap(storage_, "test_ns6"); + bool ret = false; + bitmap.SetBit("bitmap-1", 0, true, &ret); + bitmap.SetBit("bitmap-1", 8 * 1024, true, &ret); + bitmap.SetBit("bitmap-1", 2 * 8 * 1024, true, &ret); + + auto expected_put_values = {"\x1", "\x1", "\x1"}; + std::vector put_values; + + engine::WALIterator iter(storage_); + + for (iter.Seek(start_seq + 1); iter.Valid(); iter.Next()) { + auto item = iter.Item(); + switch (item.type) { + case engine::WALItem::Type::kTypePut: { + if (item.column_family_id == kColumnFamilyIDDefault) { + put_values.emplace_back(item.value); + } + break; + } + case engine::WALItem::Type::kTypeLogData: { + redis::WriteBatchLogData log_data; + ASSERT_TRUE(log_data.Decode(item.key).IsOK()); + ASSERT_EQ(log_data.GetRedisType(), kRedisBitmap); + break; + } + default: + FAIL() << "Unexpected wal item type" << uint8_t(item.type); + } + } + ASSERT_EQ(expected_put_values.size(), put_values.size()); + ASSERT_TRUE(std::equal(expected_put_values.begin(), expected_put_values.end(), put_values.begin())); +} + +TEST_F(WALIteratorTest, BasicJSON) { + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); + redis::Json json(storage_, "test_ns7"); + json.Set("json-1", "$", "{\"a\": 1, \"b\": 2}"); + json.Set("json-2", "$", "{\"a\": 1, \"b\": 2}"); + json.Set("json-3", "$", "{\"a\": 1, \"b\": 2}"); + + size_t result = 0; + ASSERT_TRUE(json.Del("json-3", "$", &result).ok()); + + auto expected_put_keys = {"json-1", "json-2", "json-3"}; + auto expected_delete_keys = {"json-3"}; + 
std::vector put_keys, delete_keys; + + engine::WALIterator iter(storage_); + + for (iter.Seek(start_seq + 1); iter.Valid(); iter.Next()) { + auto item = iter.Item(); + switch (item.type) { + case engine::WALItem::Type::kTypePut: { + ASSERT_EQ(item.column_family_id, kColumnFamilyIDMetadata); + auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); + ASSERT_EQ(ns.ToString(), "test_ns7"); + put_keys.emplace_back(key.ToString()); + break; + } + case engine::WALItem::Type::kTypeLogData: { + redis::WriteBatchLogData log_data; + ASSERT_TRUE(log_data.Decode(item.key).IsOK()); + ASSERT_EQ(log_data.GetRedisType(), kRedisJson); + break; + } + case engine::WALItem::Type::kTypeDelete: { + ASSERT_EQ(item.column_family_id, kColumnFamilyIDMetadata); + auto [ns, key] = ExtractNamespaceKey(item.key, storage_->IsSlotIdEncoded()); + ASSERT_EQ(ns.ToString(), "test_ns7"); + delete_keys.emplace_back(key.ToString()); + break; + } + default: + FAIL() << "Unexpected wal item type" << uint8_t(item.type); + } + } + + ASSERT_EQ(expected_put_keys.size(), put_keys.size()); + ASSERT_EQ(expected_delete_keys.size(), delete_keys.size()); + ASSERT_TRUE(std::equal(expected_put_keys.begin(), expected_put_keys.end(), put_keys.begin())); + ASSERT_TRUE(std::equal(expected_delete_keys.begin(), expected_delete_keys.end(), delete_keys.begin())); +} + +TEST_F(WALIteratorTest, BasicSortedInt) { + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); + redis::Sortedint sortedint(storage_, "test_ns8"); + uint64_t ret = 0; + sortedint.Add("sortedint-1", {1, 2, 3}, &ret); + uint64_t removed_cnt = 0; + sortedint.Remove("sortedint-1", {2}, &removed_cnt); + + std::vector expected_values = {1, 2, 3}, put_values; + std::vector expected_delete_values = {2}, delete_values; + + engine::WALIterator iter(storage_); + + for (iter.Seek(start_seq + 1); iter.Valid(); iter.Next()) { + auto item = iter.Item(); + switch (item.type) { + case engine::WALItem::Type::kTypePut: { + if 
(item.column_family_id == kColumnFamilyIDDefault) { + const InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); + auto value = DecodeFixed64(internal_key.GetSubKey().data()); + put_values.emplace_back(value); + } + break; + } + case engine::WALItem::Type::kTypeLogData: { + redis::WriteBatchLogData log_data; + ASSERT_TRUE(log_data.Decode(item.key).IsOK()); + ASSERT_EQ(log_data.GetRedisType(), kRedisSortedint); + break; + } + case engine::WALItem::Type::kTypeDelete: { + const InternalKey internal_key(item.key, storage_->IsSlotIdEncoded()); + auto value = DecodeFixed64(internal_key.GetSubKey().data()); + delete_values.emplace_back(value); + break; + } + default: + FAIL() << "Unexpected wal item type" << uint8_t(item.type); + } + } + ASSERT_EQ(expected_values.size(), put_values.size()); + ASSERT_EQ(expected_delete_values.size(), delete_values.size()); + ASSERT_TRUE(std::equal(expected_values.begin(), expected_values.end(), put_values.begin())); + ASSERT_TRUE(std::equal(expected_delete_values.begin(), expected_delete_values.end(), delete_values.begin())); +} + +TEST_F(WALIteratorTest, NextSequence) { + std::vector expected_next_sequences; + std::set next_sequences_set; + + auto start_seq = storage_->GetDB()->GetLatestSequenceNumber(); + uint64_t ret = 0; + redis::List list(storage_, "test_ns2"); + list.Push("list-1", {"l0", "l1", "l2", "l3", "l4"}, false, &ret); + expected_next_sequences.emplace_back(storage_->GetDB()->GetLatestSequenceNumber() + 1); + list.Push("list-2", {"l0", "l1", "l2"}, false, &ret); + expected_next_sequences.emplace_back(storage_->GetDB()->GetLatestSequenceNumber() + 1); + ASSERT_TRUE(list.Trim("list-1", 2, 4).ok()); + expected_next_sequences.emplace_back(storage_->GetDB()->GetLatestSequenceNumber() + 1); + + engine::WALIterator iter(storage_); + + ASSERT_EQ(iter.NextSequenceNumber(), 0); + + for (iter.Seek(start_seq + 1); iter.Valid(); iter.Next()) { + next_sequences_set.emplace(iter.NextSequenceNumber()); + } + + std::vector 
next_sequences(next_sequences_set.begin(), next_sequences_set.end()); + + ASSERT_EQ(expected_next_sequences.size(), next_sequences.size()); + ASSERT_TRUE(std::equal(expected_next_sequences.begin(), expected_next_sequences.end(), next_sequences.begin())); } diff --git a/tests/cppunit/types/zset_test.cc b/tests/cppunit/types/zset_test.cc index 34c71d78c2f..da2ce71469c 100644 --- a/tests/cppunit/types/zset_test.cc +++ b/tests/cppunit/types/zset_test.cc @@ -535,3 +535,81 @@ TEST_F(RedisZSetTest, RandMember) { auto s = zset_->Del(key_); EXPECT_TRUE(s.ok()); } + +TEST_F(RedisZSetTest, Diff) { + uint64_t ret = 0; + + std::string k1 = "key1"; + std::vector k1_mscores = {{"a", -100.1}, {"b", -100.1}, {"c", 0}, {"d", 1.234}}; + + std::string k2 = "key2"; + std::vector k2_mscores = {{"c", -150.1}}; + + std::string k3 = "key3"; + std::vector k3_mscores = {{"a", -1000.1}, {"c", -100.1}, {"e", 8000.9}}; + + auto s = zset_->Add(k1, ZAddFlags::Default(), &k1_mscores, &ret); + EXPECT_EQ(ret, 4); + zset_->Add(k2, ZAddFlags::Default(), &k2_mscores, &ret); + EXPECT_EQ(ret, 1); + zset_->Add(k3, ZAddFlags::Default(), &k3_mscores, &ret); + EXPECT_EQ(ret, 3); + + std::vector mscores; + zset_->Diff({k1, k2, k3}, &mscores); + + EXPECT_EQ(2, mscores.size()); + std::vector expected_mscores = {{"b", -100.1}, {"d", 1.234}}; + int index = 0; + for (const auto &mscore : expected_mscores) { + EXPECT_EQ(mscore.member, mscores[index].member); + EXPECT_EQ(mscore.score, mscores[index].score); + index++; + } + + s = zset_->Del(k1); + EXPECT_TRUE(s.ok()); + s = zset_->Del(k2); + EXPECT_TRUE(s.ok()); + s = zset_->Del(k3); + EXPECT_TRUE(s.ok()); +} + +TEST_F(RedisZSetTest, DiffStore) { + uint64_t ret = 0; + + std::string k1 = "key1"; + std::vector k1_mscores = {{"a", -100.1}, {"b", -100.1}, {"c", 0}, {"d", 1.234}}; + + std::string k2 = "key2"; + std::vector k2_mscores = {{"c", -150.1}}; + + auto s = zset_->Add(k1, ZAddFlags::Default(), &k1_mscores, &ret); + EXPECT_EQ(ret, 4); + zset_->Add(k2, 
ZAddFlags::Default(), &k2_mscores, &ret); + EXPECT_EQ(ret, 1); + + uint64_t stored_count = 0; + zset_->DiffStore("zsetdiff", {k1, k2}, &stored_count); + EXPECT_EQ(stored_count, 3); + + RangeScoreSpec spec; + std::vector mscores; + zset_->RangeByScore("zsetdiff", spec, &mscores, nullptr); + EXPECT_EQ(mscores.size(), 3); + + std::vector expected_mscores = {{"a", -100.1}, {"b", -100.1}, {"d", 1.234}}; + int index = 0; + for (const auto &mscore : expected_mscores) { + EXPECT_EQ(mscore.member, mscores[index].member); + EXPECT_EQ(mscore.score, mscores[index].score); + index++; + } + + s = zset_->Del(k1); + EXPECT_TRUE(s.ok()); + s = zset_->Del(k2); + EXPECT_TRUE(s.ok()); + s = zset_->Del("zsetdiff"); + EXPECT_TRUE(s.ok()); +} diff --git a/tests/gocase/unit/command/command_test.go b/tests/gocase/unit/command/command_test.go index 58b7589a30f..51be1347fd7 100644 --- a/tests/gocase/unit/command/command_test.go +++ b/tests/gocase/unit/command/command_test.go @@ -65,4 +65,190 @@ func TestCommand(t *testing.T) { require.Len(t, vs, 1) require.Equal(t, "test", vs[0]) }) + + t.Run("COMMAND GETKEYS SINTERCARD", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "SINTERCARD", "2", "key1", "key2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "key1", vs[0]) + require.Equal(t, "key2", vs[1]) + }) + + t.Run("COMMAND GETKEYS ZINTER", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZINTER", "2", "key1", "key2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "key1", vs[0]) + require.Equal(t, "key2", vs[1]) + }) + + t.Run("COMMAND GETKEYS ZINTERSTORE", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZINTERSTORE", "dst", "2", "src1", "src2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 3) + require.Equal(t, "dst", vs[0]) + require.Equal(t, "src1", vs[1]) + require.Equal(t, "src2", vs[2]) + }) + + t.Run("COMMAND GETKEYS ZINTERCARD", 
func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZINTERCARD", "2", "key1", "key2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "key1", vs[0]) + require.Equal(t, "key2", vs[1]) + }) + + t.Run("COMMAND GETKEYS ZUNION", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZUNION", "2", "key1", "key2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "key1", vs[0]) + require.Equal(t, "key2", vs[1]) + }) + + t.Run("COMMAND GETKEYS ZUNIONSTORE", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZUNIONSTORE", "dst", "2", "src1", "src2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 3) + require.Equal(t, "dst", vs[0]) + require.Equal(t, "src1", vs[1]) + require.Equal(t, "src2", vs[2]) + }) + + t.Run("COMMAND GETKEYS ZDIFF", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZDIFF", "2", "key1", "key2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "key1", vs[0]) + require.Equal(t, "key2", vs[1]) + }) + + t.Run("COMMAND GETKEYS ZDIFFSTORE", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZDIFFSTORE", "dst", "2", "src1", "src2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 3) + require.Equal(t, "dst", vs[0]) + require.Equal(t, "src1", vs[1]) + require.Equal(t, "src2", vs[2]) + }) + + t.Run("COMMAND GETKEYS ZMPOP", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "ZMPOP", "2", "key1", "key2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "key1", vs[0]) + require.Equal(t, "key2", vs[1]) + }) + + t.Run("COMMAND GETKEYS BZMPOP", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "BZMPOP", "0", "2", "key1", "key2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "key1", vs[0]) + require.Equal(t, "key2", vs[1]) + }) + + 
t.Run("COMMAND GETKEYS LMPOP", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "LMPOP", "2", "key1", "key2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "key1", vs[0]) + require.Equal(t, "key2", vs[1]) + }) + + t.Run("COMMAND GETKEYS BLMPOP", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "BLMPOP", "0", "2", "key1", "key2") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "key1", vs[0]) + require.Equal(t, "key2", vs[1]) + }) + + t.Run("COMMAND GETKEYS GEORADIUS", func(t *testing.T) { + // non-store + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUS", "src", "1", "1", "1", "km") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 1) + require.Equal(t, "src", vs[0]) + + // store + r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUS", "src", "1", "1", "1", "km", "store", "dst") + vs, err = r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "src", vs[0]) + require.Equal(t, "dst", vs[1]) + + // storedist + r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUS", "src", "1", "1", "1", "km", "storedist", "dst") + vs, err = r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "src", vs[0]) + require.Equal(t, "dst", vs[1]) + + // store + storedist + r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUS", "src", "1", "1", "1", "km", "store", "dst1", "storedist", "dst2") + vs, err = r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "src", vs[0]) + require.Equal(t, "dst2", vs[1]) + }) + + t.Run("COMMAND GETKEYS GEORADIUSBYMEMBER", func(t *testing.T) { + // non-store + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "radius", "m") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 1) + require.Equal(t, "src", vs[0]) + + // store + r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "radius", "m", 
"store", "dst") + vs, err = r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "src", vs[0]) + require.Equal(t, "dst", vs[1]) + + // storedist + r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "radius", "m", "storedist", "dst") + vs, err = r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "src", vs[0]) + require.Equal(t, "dst", vs[1]) + + // store + storedist + r = rdb.Do(ctx, "COMMAND", "GETKEYS", "GEORADIUSBYMEMBER", "src", "member", "radius", "m", "store", "dst1", "storedist", "dst2") + vs, err = r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "src", vs[0]) + require.Equal(t, "dst2", vs[1]) + }) + + t.Run("COMMAND GETKEYS GEOSEARCHSTORE", func(t *testing.T) { + r := rdb.Do(ctx, "COMMAND", "GETKEYS", "GEOSEARCHSTORE", "dst", "src", "frommember", "member", "byradius", "10", "m") + vs, err := r.Slice() + require.NoError(t, err) + require.Len(t, vs, 2) + require.Equal(t, "dst", vs[0]) + require.Equal(t, "src", vs[1]) + }) } diff --git a/tests/gocase/unit/config/config_test.go b/tests/gocase/unit/config/config_test.go index dd880367847..d2643c2cc71 100644 --- a/tests/gocase/unit/config/config_test.go +++ b/tests/gocase/unit/config/config_test.go @@ -133,6 +133,19 @@ func TestConfigSetCompression(t *testing.T) { require.ErrorContains(t, rdb.ConfigSet(ctx, configKey, "unsupported").Err(), "invalid enum option") } +func TestConfigGetRESP3(t *testing.T) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": "yes", + }) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + val := rdb.ConfigGet(ctx, "resp3-enabled").Val() + require.EqualValues(t, "yes", val["resp3-enabled"]) +} + func TestStartWithoutConfigurationFile(t *testing.T) { srv := util.StartServerWithCLIOptions(t, false, map[string]string{}, []string{}) defer srv.Close() diff --git 
a/tests/gocase/unit/debug/debug_test.go b/tests/gocase/unit/debug/debug_test.go index 6b65ad8c6b3..416e4a7d555 100644 --- a/tests/gocase/unit/debug/debug_test.go +++ b/tests/gocase/unit/debug/debug_test.go @@ -21,6 +21,7 @@ package debug import ( "context" + "math/big" "testing" "github.com/redis/go-redis/v9" @@ -43,7 +44,11 @@ func TestDebugProtocolV2(t *testing.T) { types := map[string]interface{}{ "string": "Hello World", "integer": int64(12345), + "double": "3.141", "array": []interface{}{int64(0), int64(1), int64(2)}, + "set": []interface{}{int64(0), int64(1), int64(2)}, + "map": []interface{}{int64(0), int64(0), int64(1), int64(1), int64(2), int64(0)}, + "bignum": "1234567999999999999999999999999999999", "true": int64(1), "false": int64(0), } @@ -81,10 +86,15 @@ func TestDebugProtocolV3(t *testing.T) { defer func() { require.NoError(t, rdb.Close()) }() t.Run("debug protocol type", func(t *testing.T) { + bignum, _ := big.NewInt(0).SetString("1234567999999999999999999999999999999", 10) types := map[string]interface{}{ "string": "Hello World", "integer": int64(12345), + "double": 3.141, "array": []interface{}{int64(0), int64(1), int64(2)}, + "set": []interface{}{int64(0), int64(1), int64(2)}, + "map": map[interface{}]interface{}{int64(0): false, int64(1): true, int64(2): false}, + "bignum": bignum, "true": true, "false": false, } diff --git a/tests/gocase/unit/geo/geo_test.go b/tests/gocase/unit/geo/geo_test.go index 6db8c222a41..ebbbcd323d9 100644 --- a/tests/gocase/unit/geo/geo_test.go +++ b/tests/gocase/unit/geo/geo_test.go @@ -86,8 +86,18 @@ func compareLists(list1, list2 []string) []string { return result } -func TestGeo(t *testing.T) { - srv := util.StartServer(t, map[string]string{}) +func TestGeoWithRESP2(t *testing.T) { + testGeo(t, "no") +} + +func TestGeoWithRESP3(t *testing.T) { + testGeo(t, "yes") +} + +var testGeo = func(t *testing.T, enabledRESP3 string) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": enabledRESP3, + }) defer 
srv.Close() ctx := context.Background() rdb := srv.NewClient() diff --git a/tests/gocase/unit/hello/hello_test.go b/tests/gocase/unit/hello/hello_test.go index d965b29c28e..36296b05f09 100644 --- a/tests/gocase/unit/hello/hello_test.go +++ b/tests/gocase/unit/hello/hello_test.go @@ -86,15 +86,16 @@ func TestEnableRESP3(t *testing.T) { rdb := srv.NewClient() defer func() { require.NoError(t, rdb.Close()) }() - r := rdb.Do(ctx, "HELLO", "2") - rList := r.Val().([]interface{}) + r, err := rdb.Do(ctx, "HELLO", "2").Result() + require.NoError(t, err) + rList := r.([]interface{}) require.EqualValues(t, rList[2], "proto") require.EqualValues(t, rList[3], 2) - r = rdb.Do(ctx, "HELLO", "3") - rList = r.Val().([]interface{}) - require.EqualValues(t, rList[2], "proto") - require.EqualValues(t, rList[3], 3) + r, err = rdb.Do(ctx, "HELLO", "3").Result() + require.NoError(t, err) + rMap := r.(map[interface{}]interface{}) + require.EqualValues(t, rMap["proto"], 3) } func TestHelloWithAuth(t *testing.T) { diff --git a/tests/gocase/unit/protocol/protocol_test.go b/tests/gocase/unit/protocol/protocol_test.go index 7896cf00548..61db7cf1bf7 100644 --- a/tests/gocase/unit/protocol/protocol_test.go +++ b/tests/gocase/unit/protocol/protocol_test.go @@ -153,7 +153,11 @@ func TestProtocolRESP2(t *testing.T) { types := map[string][]string{ "string": {"$11", "Hello World"}, "integer": {":12345"}, + "double": {"$5", "3.141"}, "array": {"*3", ":0", ":1", ":2"}, + "set": {"*3", ":0", ":1", ":2"}, + "map": {"*6", ":0", ":0", ":1", ":1", ":2", ":0"}, + "bignum": {"$37", "1234567999999999999999999999999999999"}, "true": {":1"}, "false": {":0"}, "null": {"$-1"}, @@ -197,7 +201,7 @@ func TestProtocolRESP3(t *testing.T) { t.Run("debug protocol string", func(t *testing.T) { require.NoError(t, c.WriteArgs("HELLO", "3")) - values := []string{"*6", "$6", "server", "$5", "redis", "$5", "proto", ":3", "$4", "mode", "$10", "standalone"} + values := []string{"%3", "$6", "server", "$5", "redis", "$5", 
"proto", ":3", "$4", "mode", "$10", "standalone"} for _, line := range values { c.MustRead(t, line) } @@ -205,7 +209,11 @@ func TestProtocolRESP3(t *testing.T) { types := map[string][]string{ "string": {"$11", "Hello World"}, "integer": {":12345"}, + "double": {",3.141"}, "array": {"*3", ":0", ":1", ":2"}, + "set": {"~3", ":0", ":1", ":2"}, + "map": {"%3", ":0", "#f", ":1", "#t", ":2", "#f"}, + "bignum": {"(1234567999999999999999999999999999999"}, "true": {"#t"}, "false": {"#f"}, "null": {"_"}, diff --git a/tests/gocase/unit/rename/rename_test.go b/tests/gocase/unit/rename/rename_test.go new file mode 100644 index 00000000000..7bbd4a5284d --- /dev/null +++ b/tests/gocase/unit/rename/rename_test.go @@ -0,0 +1,1044 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package rename + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestRename_String(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("Rename string", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.Set(ctx, "a1", "world", 0).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 10*time.Second).Err()) + require.NoError(t, rdb.Set(ctx, "a1", "world", 1000*time.Second).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "world").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", 
rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a").Err()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + + // rename*3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.Set(ctx, "a1", "world1", 0).Err()) + require.NoError(t, rdb.Set(ctx, "a2", "world2", 0).Err()) + require.NoError(t, rdb.Set(ctx, "a3", "world3", 0).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Rename(ctx, "a1", "a2").Err()) + require.NoError(t, rdb.Rename(ctx, "a2", "a3").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a2").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a3").Val()) + }) + + t.Run("RenameNX string", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.NoError(t, rdb.Set(ctx, "a1", "world", 0).Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a1").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "world", rdb.Get(ctx, "a1").Val()) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.EqualValues(t, true, rdb.RenameNX(ctx, "a", "a1").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "hello", rdb.Get(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a").Val()) + 
require.EqualValues(t, "hello", rdb.Get(ctx, "a").Val()) + }) + +} + +func TestRename_JSON(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + setCmd := "JSON.SET" + getCmd := "JSON.GET" + jsonA := `{"x":1,"y":2}` + jsonB := `{"x":1}` + + t.Run("Rename json", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, nil, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a1", "$", jsonB).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, nil, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a1", "$", jsonA).Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, nil, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + 
require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "world").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, nil, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a").Err()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + + // rename*3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a1", "$", jsonB).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a2", "$", jsonB).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a3", "$", jsonB).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Rename(ctx, "a1", "a2").Err()) + require.NoError(t, rdb.Rename(ctx, "a2", "a3").Err()) + require.EqualValues(t, nil, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, nil, rdb.Do(ctx, getCmd, "a1").Val()) + require.EqualValues(t, nil, rdb.Do(ctx, getCmd, "a2").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a3").Val()) + }) + + t.Run("RenameNX json", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a1", "$", jsonB).Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a1").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonB, rdb.Do(ctx, getCmd, "a1").Val()) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.EqualValues(t, true, rdb.RenameNX(ctx, "a", "a1").Val()) + 
require.EqualValues(t, nil, rdb.Do(ctx, getCmd, "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, setCmd, "a", "$", jsonA).Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a").Val()) + require.EqualValues(t, jsonA, rdb.Do(ctx, getCmd, "a").Val()) + }) + +} + +func TestRename_List(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualListValues := func(t *testing.T, key string, value []string) { + require.EqualValues(t, len(value), rdb.LLen(ctx, key).Val()) + for i := 0; i < len(value); i++ { + require.EqualValues(t, value[i], rdb.LIndex(ctx, key, int64(i)).Val()) + } + } + + t.Run("Rename string", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, 3, rdb.LLen(ctx, "a1").Val()) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + 
require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Set(ctx, "a1", "world", 0).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a").Err()) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + + // rename*3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "2").Err()) + require.NoError(t, rdb.LPush(ctx, "a2", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a3", "1").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Rename(ctx, "a1", "a2").Err()) + require.NoError(t, rdb.Rename(ctx, "a2", "a3").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a2").Val()) + EqualListValues(t, "a3", []string{"3", "2", "1"}) + }) + + t.Run("RenameNX string", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "3").Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualListValues(t, "a", []string{"3", "2", 
"1"}) + EqualListValues(t, "a1", []string{"3"}) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.EqualValues(t, true, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualListValues(t, "a1", []string{"3", "2", "1"}) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.LPush(ctx, "a", "1", "2", "3").Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a").Val()) + EqualListValues(t, "a", []string{"3", "2", "1"}) + }) + +} + +func TestRename_hash(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualListValues := func(t *testing.T, key string, value map[string]string) { + require.EqualValues(t, len(value), rdb.HLen(ctx, key).Val()) + for subKey := range value { + require.EqualValues(t, value[subKey], rdb.HGet(ctx, key, subKey).Val()) + } + } + + t.Run("Rename hash", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.HSet(ctx, "a1", "a", "1").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL 
+ require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.HSet(ctx, "a1", "a", "1").Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a", "1").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a").Err()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + + // rename*3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.HSet(ctx, "a1", "a", "1").Err()) + require.NoError(t, rdb.HSet(ctx, "a2", "a", "1").Err()) + require.NoError(t, rdb.HSet(ctx, "a3", "a", "1").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Rename(ctx, "a1", "a2").Err()) + require.NoError(t, rdb.Rename(ctx, "a2", "a3").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, "", 
rdb.Get(ctx, "a2").Val()) + EqualListValues(t, "a3", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + }) + + t.Run("RenameNX hash", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.NoError(t, rdb.HSet(ctx, "a1", "a", "1").Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + }) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.EqualValues(t, true, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualListValues(t, "a1", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.HSet(ctx, "a", "a", "1", "b", "2", "c", "3").Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a").Val()) + EqualListValues(t, "a", map[string]string{ + "a": "1", + "b": "2", + "c": "3", + }) + }) + +} + +func TestRename_set(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualSetValues := func(t *testing.T, key string, value []string) { + require.EqualValues(t, len(value), rdb.SCard(ctx, key).Val()) + for index := range value { + require.EqualValues(t, true, rdb.SIsMember(ctx, key, value[index]).Val()) + } + } + + t.Run("Rename set", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualSetValues(t, "a1", []string{"1", "2", 
"3"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.SAdd(ctx, "a1", "a", "1").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualSetValues(t, "a1", []string{"1", "2", "3"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.SAdd(ctx, "a1", "1").Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualSetValues(t, "a1", []string{"1", "2", "3"}) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "1").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualSetValues(t, "a1", []string{"1", "2", "3"}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a").Err()) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + + // rename*3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.SAdd(ctx, "a1", "1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a2", "a2", "1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a3", 
"a3", "1").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Rename(ctx, "a1", "a2").Err()) + require.NoError(t, rdb.Rename(ctx, "a2", "a3").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a2").Val()) + EqualSetValues(t, "a3", []string{"1", "2", "3"}) + }) + + t.Run("RenameNX set", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.NoError(t, rdb.SAdd(ctx, "a1", "1").Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + EqualSetValues(t, "a1", []string{"1"}) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.EqualValues(t, true, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualSetValues(t, "a1", []string{"1", "2", "3"}) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.SAdd(ctx, "a", "1", "2", "3").Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a").Val()) + EqualSetValues(t, "a", []string{"1", "2", "3"}) + + }) + +} + +func TestRename_zset(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualZSetValues := func(t *testing.T, key string, value map[string]int) { + require.EqualValues(t, len(value), rdb.ZCard(ctx, key).Val()) + for subKey := range value { + score := value[subKey] + require.EqualValues(t, []string{subKey}, rdb.ZRangeByScore(ctx, key, + &redis.ZRangeBy{Max: strconv.Itoa(score), Min: strconv.Itoa(score)}).Val()) + require.EqualValues(t, float64(score), rdb.ZScore(ctx, key, subKey).Val()) + } + } + + zMember := 
[]redis.Z{{Member: "a", Score: 1}, {Member: "b", Score: 2}, {Member: "c", Score: 3}} + zMember2 := []redis.Z{{Member: "a", Score: 2}} + + t.Run("Rename zset", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a1", zMember2...).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a1", zMember2...).Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.LPush(ctx, "a1", 1, 2, 3).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + require.EqualValues(t, 
-1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a").Err()) + EqualZSetValues(t, "a", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + // rename*3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a1", zMember2...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a2", zMember2...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a3", zMember2...).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Rename(ctx, "a1", "a2").Err()) + require.NoError(t, rdb.Rename(ctx, "a2", "a3").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a2").Val()) + EqualZSetValues(t, "a3", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + + }) + + t.Run("RenameNX zset", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.NoError(t, rdb.ZAdd(ctx, "a1", zMember2...).Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualZSetValues(t, "a", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + EqualZSetValues(t, "a1", map[string]int{ + "a": 2, + }) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.EqualValues(t, true, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualZSetValues(t, "a1", map[string]int{ + "a": 1, + "b": 2, + "c": 3, + }) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.ZAdd(ctx, "a", zMember...).Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a").Val()) + EqualZSetValues(t, "a", map[string]int{ + 
"a": 1, + "b": 2, + "c": 3, + }) + + }) + +} + +func TestRename_Bitmap(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualBitSetValues := func(t *testing.T, key string, value []int64) { + for i := 0; i < len(value); i++ { + require.EqualValues(t, int64(value[i]), rdb.Do(ctx, "BITPOS", key, 1, value[i]/8).Val()) + } + } + + SetBits := func(t *testing.T, key string, value []int64) { + for i := 0; i < len(value); i++ { + require.NoError(t, rdb.Do(ctx, "SETBIT", key, value[i], 1).Err()) + } + } + bitSetA := []int64{16, 1024 * 8 * 2, 1024 * 8 * 12} + bitSetB := []int64{1} + + t.Run("Rename Bitmap", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualBitSetValues(t, "a1", bitSetA) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // newkey has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + SetBits(t, "a1", bitSetB) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualBitSetValues(t, "a1", bitSetA) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // newkey has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualBitSetValues(t, "a1", bitSetA) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, 
"a").Err()) + SetBits(t, "a", bitSetA) + require.NoError(t, rdb.Rename(ctx, "a", "a").Err()) + EqualBitSetValues(t, "a", bitSetA) + + // rename*3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + SetBits(t, "a", bitSetA) + SetBits(t, "a1", bitSetB) + SetBits(t, "a2", bitSetB) + SetBits(t, "a3", bitSetB) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Rename(ctx, "a1", "a2").Err()) + require.NoError(t, rdb.Rename(ctx, "a2", "a3").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a2").Val()) + EqualBitSetValues(t, "a3", bitSetA) + }) + + t.Run("RenameNX Bitmap", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + SetBits(t, "a1", bitSetB) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualBitSetValues(t, "a", bitSetA) + EqualBitSetValues(t, "a1", bitSetB) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + require.EqualValues(t, true, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualBitSetValues(t, "a1", bitSetA) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + SetBits(t, "a", bitSetA) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a").Val()) + EqualBitSetValues(t, "a", bitSetA) + + }) + +} + +func TestRename_SInt(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + EqualSIntValues := func(t *testing.T, key string, value []int) { + require.EqualValues(t, len(value), rdb.Do(ctx, "SICARD", key).Val()) + for i := 0; i < len(value); i++ { + require.EqualValues(t, []interface{}{int64(1)}, rdb.Do(ctx, "SIEXISTS", key, value[i]).Val()) + } + } + + t.Run("Rename SInt", func(t *testing.T) { + 
require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualSIntValues(t, "a1", []int{3, 4, 5, 123, 245}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a1", 99).Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualSIntValues(t, "a1", []int{3, 4, 5, 123, 245}) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + EqualSIntValues(t, "a1", []int{3, 4, 5, 123, 245}) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a").Err()) + EqualSIntValues(t, "a", []int{3, 4, 5, 123, 245}) + + // rename*3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a1", 85).Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a2", 77, 0).Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a3", 111, 222, 333).Err()) + require.NoError(t, rdb.Rename(ctx, 
"a", "a1").Err()) + require.NoError(t, rdb.Rename(ctx, "a1", "a2").Err()) + require.NoError(t, rdb.Rename(ctx, "a2", "a3").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a2").Val()) + EqualSIntValues(t, "a3", []int{3, 4, 5, 123, 245}) + }) + + t.Run("RenameNX SInt", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a1", 99).Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualSIntValues(t, "a", []int{3, 4, 5, 123, 245}) + EqualSIntValues(t, "a1", []int{99}) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.EqualValues(t, true, rdb.RenameNX(ctx, "a", "a1").Val()) + EqualSIntValues(t, "a1", []int{3, 4, 5, 123, 245}) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, "SIADD", "a", 3, 4, 5, 123, 245).Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a").Val()) + EqualSIntValues(t, "a", []int{3, 4, 5, 123, 245}) + + }) + +} + +func TestRename_Bloom(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + bfAdd := "BF.ADD" + bfExists := "BF.EXISTS" + + t.Run("Rename Bloom", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "hello").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key 
has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.Expire(ctx, "a", 10*time.Second).Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a1", "world").Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "hello").Val()) + require.EqualValues(t, 0, rdb.Do(ctx, bfExists, "a1", "world").Val()) + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "hello").Val()) + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a").Err()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + + // rename*3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a1", "world1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a2", "world2").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a3", "world3").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Rename(ctx, "a1", "a2").Err()) + require.NoError(t, rdb.Rename(ctx, "a2", "a3").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, "", 
rdb.Get(ctx, "a2").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a3", "hello").Val()) + }) + + t.Run("RenameNX Bloom", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a1", "world").Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a1").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "world").Val()) + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.EqualValues(t, true, rdb.RenameNX(ctx, "a", "a1").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a1", "hello").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, bfAdd, "a", "hello").Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a").Val()) + require.EqualValues(t, 1, rdb.Do(ctx, bfExists, "a", "hello").Val()) + }) +} + +func TestRename_Stream(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + XADD := "XADD" + XREAD := "XREAD" + t.Run("Rename Stream", func(t *testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "hello") + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // to-key has value with TTL + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.Expire(ctx, "a", 
10*time.Second).Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a1", "*", "a", "world").Err()) + require.NoError(t, rdb.Expire(ctx, "a1", 1000*time.Second).Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "hello") + util.BetweenValues(t, rdb.TTL(ctx, "a1").Val(), time.Second, 10*time.Second) + + // to-key has value that not same type + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.LPush(ctx, "a1", "a").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "hello") + require.EqualValues(t, -1, rdb.TTL(ctx, "a1").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a").Err()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + + // rename*3 + require.NoError(t, rdb.Del(ctx, "a", "a1", "a2", "a3").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a1", "*", "a", "world1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a2", "*", "a", "world2").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a3", "*", "a", "world3").Err()) + require.NoError(t, rdb.Rename(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Rename(ctx, "a1", "a2").Err()) + require.NoError(t, rdb.Rename(ctx, "a2", "a3").Err()) + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a1").Val()) + require.EqualValues(t, "", rdb.Get(ctx, "a2").Val()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a3", "0").String(), "hello") + }) + + t.Run("RenameNX Stream", func(t 
*testing.T) { + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a1", "*", "a", "world").Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a1").Val()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "world") + + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.EqualValues(t, true, rdb.RenameNX(ctx, "a", "a1").Val()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a1", "0").String(), "hello") + require.EqualValues(t, "", rdb.Get(ctx, "a").Val()) + + // key == newkey + require.NoError(t, rdb.Del(ctx, "a", "a1").Err()) + require.NoError(t, rdb.Do(ctx, XADD, "a", "*", "a", "hello").Err()) + require.EqualValues(t, false, rdb.RenameNX(ctx, "a", "a").Val()) + require.Contains(t, rdb.Do(ctx, XREAD, "STREAMS", "a", "0").String(), "hello") + }) +} + +func TestRename_Error(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("Rename from empty key", func(t *testing.T) { + require.Error(t, rdb.Rename(ctx, ".empty", "a").Err()) + require.Error(t, rdb.RenameNX(ctx, ".empty", "a").Err()) + }) + +} diff --git a/tests/gocase/unit/scripting/function_test.go b/tests/gocase/unit/scripting/function_test.go index 3e262db218a..d0c12ec157a 100644 --- a/tests/gocase/unit/scripting/function_test.go +++ b/tests/gocase/unit/scripting/function_test.go @@ -25,6 +25,8 @@ import ( "strings" "testing" + "github.com/redis/go-redis/v9" + "github.com/apache/kvrocks/tests/gocase/util" "github.com/stretchr/testify/require" ) @@ -38,8 +40,74 @@ var luaMylib2 string //go:embed mylib3.lua var luaMylib3 string -func TestFunction(t *testing.T) { - 
srv := util.StartServer(t, map[string]string{}) +type ListFuncResult struct { + Name string + Library string +} + +func decodeListFuncResult(t *testing.T, v interface{}) ListFuncResult { + switch res := v.(type) { + case []interface{}: + require.EqualValues(t, 4, len(res)) + require.EqualValues(t, "function_name", res[0]) + require.EqualValues(t, "from_library", res[2]) + return ListFuncResult{ + Name: res[1].(string), + Library: res[3].(string), + } + case map[interface{}]interface{}: + require.EqualValues(t, 2, len(res)) + return ListFuncResult{ + Name: res["function_name"].(string), + Library: res["from_library"].(string), + } + } + require.Fail(t, "unexpected type") + return ListFuncResult{} +} + +type ListLibResult struct { + Name string + Engine string + Functions []interface{} +} + +func decodeListLibResult(t *testing.T, v interface{}) ListLibResult { + switch res := v.(type) { + case []interface{}: + require.EqualValues(t, 6, len(res)) + require.EqualValues(t, "library_name", res[0]) + require.EqualValues(t, "engine", res[2]) + require.EqualValues(t, "functions", res[4]) + return ListLibResult{ + Name: res[1].(string), + Engine: res[3].(string), + Functions: res[5].([]interface{}), + } + case map[interface{}]interface{}: + require.EqualValues(t, 3, len(res)) + return ListLibResult{ + Name: res["library_name"].(string), + Engine: res["engine"].(string), + Functions: res["functions"].([]interface{}), + } + } + require.Fail(t, "unexpected type") + return ListLibResult{} +} + +func TestFunctionsWithRESP3(t *testing.T) { + testFunctions(t, "yes") +} + +func TestFunctionsWithoutRESP2(t *testing.T) { + testFunctions(t, "no") +} + +var testFunctions = func(t *testing.T, enabledRESP3 string) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": enabledRESP3, + }) defer srv.Close() ctx := context.Background() @@ -65,17 +133,22 @@ func TestFunction(t *testing.T) { }) t.Run("FUNCTION LIST and FUNCTION LISTFUNC mylib1", func(t *testing.T) { - list := 
rdb.Do(ctx, "FUNCTION", "LIST", "WITHCODE").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[3].(string), luaMylib1) - require.Equal(t, len(list), 4) - - list = rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) - require.Equal(t, list[1].(string), "add") - require.Equal(t, list[3].(string), "mylib1") - require.Equal(t, list[5].(string), "inc") - require.Equal(t, list[7].(string), "mylib1") - require.Equal(t, len(list), 8) + libraries, err := rdb.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + }).Result() + require.NoError(t, err) + require.EqualValues(t, 1, len(libraries)) + require.Equal(t, "mylib1", libraries[0].Name) + require.Equal(t, luaMylib1, libraries[0].Code) + + list := rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) + require.EqualValues(t, 2, len(list)) + f1 := decodeListFuncResult(t, list[0]) + require.Equal(t, "add", f1.Name) + require.Equal(t, "mylib1", f1.Library) + f2 := decodeListFuncResult(t, list[1]) + require.Equal(t, "inc", f2.Name) + require.Equal(t, "mylib1", f2.Library) }) t.Run("FUNCTION LOAD and FCALL mylib2", func(t *testing.T) { @@ -87,23 +160,25 @@ func TestFunction(t *testing.T) { }) t.Run("FUNCTION LIST and FUNCTION LISTFUNC mylib2", func(t *testing.T) { - list := rdb.Do(ctx, "FUNCTION", "LIST", "WITHCODE").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[3].(string), luaMylib1) - require.Equal(t, list[5].(string), "mylib2") - require.Equal(t, list[7].(string), luaMylib2) - require.Equal(t, len(list), 8) - - list = rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) - require.Equal(t, list[1].(string), "add") - require.Equal(t, list[3].(string), "mylib1") - require.Equal(t, list[5].(string), "hello") - require.Equal(t, list[7].(string), "mylib2") - require.Equal(t, list[9].(string), "inc") - require.Equal(t, list[11].(string), "mylib1") - require.Equal(t, list[13].(string), "reverse") - require.Equal(t, 
list[15].(string), "mylib2") - require.Equal(t, len(list), 16) + libraries, err := rdb.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + }).Result() + require.NoError(t, err) + require.EqualValues(t, 2, len(libraries)) + + list := rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) + expected := []ListFuncResult{ + {Name: "add", Library: "mylib1"}, + {Name: "hello", Library: "mylib2"}, + {Name: "inc", Library: "mylib1"}, + {Name: "reverse", Library: "mylib2"}, + } + require.EqualValues(t, len(expected), len(list)) + for i, f := range expected { + actual := decodeListFuncResult(t, list[i]) + require.Equal(t, f.Name, actual.Name) + require.Equal(t, f.Library, actual.Library) + } }) t.Run("FUNCTION DELETE", func(t *testing.T) { @@ -113,17 +188,24 @@ func TestFunction(t *testing.T) { util.ErrorRegexp(t, rdb.Do(ctx, "FCALL", "reverse", 0, "x").Err(), ".*No such function name.*") require.Equal(t, rdb.Do(ctx, "FCALL", "inc", 0, 3).Val(), int64(4)) - list := rdb.Do(ctx, "FUNCTION", "LIST", "WITHCODE").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[3].(string), luaMylib1) - require.Equal(t, len(list), 4) - - list = rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) - require.Equal(t, list[1].(string), "add") - require.Equal(t, list[3].(string), "mylib1") - require.Equal(t, list[5].(string), "inc") - require.Equal(t, list[7].(string), "mylib1") - require.Equal(t, len(list), 8) + libraries, err := rdb.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + }).Result() + require.NoError(t, err) + require.EqualValues(t, 1, len(libraries)) + require.Equal(t, "mylib1", libraries[0].Name) + + list := rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) + expected := []ListFuncResult{ + {Name: "add", Library: "mylib1"}, + {Name: "inc", Library: "mylib1"}, + } + require.EqualValues(t, len(expected), len(list)) + for i, f := range expected { + actual := decodeListFuncResult(t, list[i]) + 
require.Equal(t, f.Name, actual.Name) + require.Equal(t, f.Library, actual.Library) + } }) t.Run("FUNCTION LOAD REPLACE", func(t *testing.T) { @@ -135,17 +217,24 @@ func TestFunction(t *testing.T) { require.Equal(t, rdb.Do(ctx, "FCALL", "reverse", 0, "xyz").Val(), "zyx") util.ErrorRegexp(t, rdb.Do(ctx, "FCALL", "inc", 0, 1).Err(), ".*No such function name.*") - list := rdb.Do(ctx, "FUNCTION", "LIST", "WITHCODE").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[3].(string), code) - require.Equal(t, len(list), 4) - - list = rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) - require.Equal(t, list[1].(string), "hello") - require.Equal(t, list[3].(string), "mylib1") - require.Equal(t, list[5].(string), "reverse") - require.Equal(t, list[7].(string), "mylib1") - require.Equal(t, len(list), 8) + libraries, err := rdb.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + }).Result() + require.NoError(t, err) + require.EqualValues(t, 1, len(libraries)) + require.Equal(t, "mylib1", libraries[0].Name) + + list := rdb.Do(ctx, "FUNCTION", "LISTFUNC").Val().([]interface{}) + expected := []ListFuncResult{ + {Name: "hello", Library: "mylib1"}, + {Name: "reverse", Library: "mylib1"}, + } + require.EqualValues(t, len(expected), len(list)) + for i, f := range expected { + actual := decodeListFuncResult(t, list[i]) + require.Equal(t, f.Name, actual.Name) + require.Equal(t, f.Library, actual.Library) + } }) t.Run("FCALL_RO", func(t *testing.T) { @@ -167,19 +256,24 @@ func TestFunction(t *testing.T) { require.Equal(t, rdb.Do(ctx, "FCALL", "myget", 1, "x").Val(), "2") require.Equal(t, rdb.Do(ctx, "FCALL", "hello", 0, "xxx").Val(), "Hello, xxx!") - list := rdb.Do(ctx, "FUNCTION", "LIST").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[3].(string), "mylib3") - require.Equal(t, len(list), 4) + libraries, err := rdb.FunctionList(ctx, redis.FunctionListQuery{ + WithCode: true, + 
}).Result() + require.NoError(t, err) + require.EqualValues(t, 2, len(libraries)) + require.Equal(t, libraries[0].Name, "mylib1") + require.Equal(t, libraries[1].Name, "mylib3") }) t.Run("FUNCTION LISTLIB", func(t *testing.T) { - list := rdb.Do(ctx, "FUNCTION", "LISTLIB", "mylib1").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib1") - require.Equal(t, list[5].([]interface{}), []interface{}{"hello", "reverse"}) - - list = rdb.Do(ctx, "FUNCTION", "LISTLIB", "mylib3").Val().([]interface{}) - require.Equal(t, list[1].(string), "mylib3") - require.Equal(t, list[5].([]interface{}), []interface{}{"myget", "myset"}) + r := rdb.Do(ctx, "FUNCTION", "LISTLIB", "mylib1").Val() + require.EqualValues(t, ListLibResult{ + Name: "mylib1", Engine: "lua", Functions: []interface{}{"hello", "reverse"}, + }, decodeListLibResult(t, r)) + + r = rdb.Do(ctx, "FUNCTION", "LISTLIB", "mylib3").Val() + require.EqualValues(t, ListLibResult{ + Name: "mylib3", Engine: "lua", Functions: []interface{}{"myget", "myset"}, + }, decodeListLibResult(t, r)) }) } diff --git a/tests/gocase/unit/type/hash/hash_test.go b/tests/gocase/unit/type/hash/hash_test.go index bf93d268860..38b0a576089 100644 --- a/tests/gocase/unit/type/hash/hash_test.go +++ b/tests/gocase/unit/type/hash/hash_test.go @@ -50,8 +50,18 @@ func getVals(hash map[string]string) []string { return r } -func TestHash(t *testing.T) { - srv := util.StartServer(t, map[string]string{}) +func TestHashWithRESP2(t *testing.T) { + testHash(t, "no") +} + +func TestHashWithRESP3(t *testing.T) { + testHash(t, "yes") +} + +var testHash = func(t *testing.T, enabledRESP3 string) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": enabledRESP3, + }) defer srv.Close() ctx := context.Background() rdb := srv.NewClient() @@ -359,29 +369,15 @@ func TestHash(t *testing.T) { }) t.Run("HGETALL - small hash}", func(t *testing.T) { - res := rdb.Do(ctx, "hgetall", "smallhash").Val().([]interface{}) - mid := make(map[string]string) - for 
i := 0; i < len(res); i += 2 { - if res[i+1] == nil { - mid[res[i].(string)] = "" - } else { - mid[res[i].(string)] = res[i+1].(string) - } - } - require.Equal(t, smallhash, mid) + gotHash, err := rdb.HGetAll(ctx, "smallhash").Result() + require.NoError(t, err) + require.Equal(t, smallhash, gotHash) }) t.Run("HGETALL - big hash}", func(t *testing.T) { - res := rdb.Do(ctx, "hgetall", "bighash").Val().([]interface{}) - mid := make(map[string]string) - for i := 0; i < len(res); i += 2 { - if res[i+1] == nil { - mid[res[i].(string)] = "" - } else { - mid[res[i].(string)] = res[i+1].(string) - } - } - require.Equal(t, bighash, mid) + gotHash, err := rdb.HGetAll(ctx, "bighash").Result() + require.NoError(t, err) + require.Equal(t, bighash, gotHash) }) t.Run("HGETALL - field with empty string as a value", func(t *testing.T) { @@ -835,6 +831,30 @@ func TestHash(t *testing.T) { } } +func TestHGetAllWithRESP3(t *testing.T) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": "yes", + }) + defer srv.Close() + + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + ctx := context.Background() + + testKey := "test-hash-1" + require.NoError(t, rdb.Del(ctx, testKey).Err()) + require.NoError(t, rdb.HSet(ctx, testKey, "key1", "value1", "key2", "value2", "key3", "value3").Err()) + result, err := rdb.HGetAll(ctx, testKey).Result() + require.NoError(t, err) + require.Len(t, result, 3) + require.EqualValues(t, map[string]string{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + }, result) +} + func TestHashWithAsyncIOEnabled(t *testing.T) { srv := util.StartServer(t, map[string]string{ "rocksdb.read_options.async_io": "yes", diff --git a/tests/gocase/unit/type/set/set_test.go b/tests/gocase/unit/type/set/set_test.go index 8a58c7b93eb..0d596620936 100644 --- a/tests/gocase/unit/type/set/set_test.go +++ b/tests/gocase/unit/type/set/set_test.go @@ -57,7 +57,17 @@ func GetArrayUnion(arrays ...[]string) []string { } func TestSet(t 
*testing.T) { - srv := util.StartServer(t, map[string]string{}) + setTests(t, "no") +} + +func TestSetWithRESP3(t *testing.T) { + setTests(t, "yes") +} + +var setTests = func(t *testing.T, enabledRESP3 string) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": enabledRESP3, + }) defer srv.Close() ctx := context.Background() rdb := srv.NewClient() diff --git a/tests/gocase/unit/type/stream/stream_test.go b/tests/gocase/unit/type/stream/stream_test.go index d3a1f8d273a..7dee10b6b3e 100644 --- a/tests/gocase/unit/type/stream/stream_test.go +++ b/tests/gocase/unit/type/stream/stream_test.go @@ -34,8 +34,18 @@ import ( "github.com/stretchr/testify/require" ) -func TestStream(t *testing.T) { - srv := util.StartServer(t, map[string]string{}) +func TestStreamWithRESP2(t *testing.T) { + streamTests(t, "no") +} + +func TestStreamWithRESP3(t *testing.T) { + streamTests(t, "yes") +} + +var streamTests = func(t *testing.T, enabledRESP3 string) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": enabledRESP3, + }) defer srv.Close() ctx := context.Background() rdb := srv.NewClient() diff --git a/tests/gocase/unit/type/types_test.go b/tests/gocase/unit/type/types_test.go new file mode 100644 index 00000000000..3a33c1d56c2 --- /dev/null +++ b/tests/gocase/unit/type/types_test.go @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package types + +import ( + "context" + "testing" + + "github.com/apache/kvrocks/tests/gocase/util" + "github.com/stretchr/testify/require" +) + +func TestTypesError(t *testing.T) { + srv := util.StartServer(t, map[string]string{}) + defer srv.Close() + ctx := context.Background() + rdb := srv.NewClient() + defer func() { require.NoError(t, rdb.Close()) }() + + t.Run("Operate with wrong type", func(t *testing.T) { + message := "ERR Invalid argument: WRONGTYPE Operation against a key holding the wrong kind of value" + require.NoError(t, rdb.Set(ctx, "a", "hello", 0).Err()) + require.EqualError(t, rdb.Do(ctx, "XADD", "a", "*", "a", "test").Err(), message) + require.EqualError(t, rdb.Do(ctx, "LPUSH", "a", 1).Err(), message) + require.EqualError(t, rdb.Do(ctx, "HSET", "a", "1", "2").Err(), message) + require.EqualError(t, rdb.Do(ctx, "SADD", "a", "1", "2").Err(), message) + require.EqualError(t, rdb.Do(ctx, "ZADD", "a", "1", "2").Err(), message) + require.EqualError(t, rdb.Do(ctx, "JSON.SET", "a", "$", "{}").Err(), message) + require.EqualError(t, rdb.Do(ctx, "BF.ADD", "a", "test").Err(), message) + require.EqualError(t, rdb.Do(ctx, "SADD", "a", 100).Err(), message) + + require.NoError(t, rdb.LPush(ctx, "a1", "hello", 0).Err()) + require.EqualError(t, rdb.Do(ctx, "SETBIT", "a1", 1, 1).Err(), message) + require.EqualError(t, rdb.Do(ctx, "GET", "a1").Err(), message) + + }) +} diff --git a/tests/gocase/unit/type/zset/zset_test.go b/tests/gocase/unit/type/zset/zset_test.go index 86adceda403..5f1cf80fab3 100644 --- 
a/tests/gocase/unit/type/zset/zset_test.go +++ b/tests/gocase/unit/type/zset/zset_test.go @@ -67,7 +67,8 @@ func createDefaultLexZset(rdb *redis.Client, ctx context.Context) { {0, "omega"}}) } -func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding string, srv *util.KvrocksServer) { +func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, enabledRESP3, encoding string, srv *util.KvrocksServer) { + isRESP3 := enabledRESP3 == "yes" t.Run(fmt.Sprintf("Check encoding - %s", encoding), func(t *testing.T) { rdb.Del(ctx, "ztmp") rdb.ZAdd(ctx, "ztmp", redis.Z{Score: 10, Member: "x"}) @@ -103,9 +104,15 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s t.Run(fmt.Sprintf("ZSET ZADD IncrMixedOtherOptions - %s", encoding), func(t *testing.T) { rdb.Del(ctx, "ztmp") - require.Equal(t, "1.5", rdb.Do(ctx, "zadd", "ztmp", "nx", "nx", "nx", "nx", "incr", "1.5", "abc").Val()) - require.Equal(t, redis.Nil, rdb.Do(ctx, "zadd", "ztmp", "nx", "nx", "nx", "nx", "incr", "1.5", "abc").Err()) - require.Equal(t, "3", rdb.Do(ctx, "zadd", "ztmp", "xx", "xx", "xx", "xx", "incr", "1.5", "abc").Val()) + if isRESP3 { + require.Equal(t, 1.5, rdb.Do(ctx, "zadd", "ztmp", "nx", "nx", "nx", "nx", "incr", "1.5", "abc").Val()) + require.Equal(t, redis.Nil, rdb.Do(ctx, "zadd", "ztmp", "nx", "nx", "nx", "nx", "incr", "1.5", "abc").Err()) + require.EqualValues(t, 3, rdb.Do(ctx, "zadd", "ztmp", "xx", "xx", "xx", "xx", "incr", "1.5", "abc").Val()) + } else { + require.Equal(t, "1.5", rdb.Do(ctx, "zadd", "ztmp", "nx", "nx", "nx", "nx", "incr", "1.5", "abc").Val()) + require.Equal(t, redis.Nil, rdb.Do(ctx, "zadd", "ztmp", "nx", "nx", "nx", "nx", "incr", "1.5", "abc").Err()) + require.Equal(t, "3", rdb.Do(ctx, "zadd", "ztmp", "xx", "xx", "xx", "xx", "incr", "1.5", "abc").Val()) + } rdb.Del(ctx, "ztmp") require.Equal(t, 1.5, rdb.ZAddArgsIncr(ctx, "ztmp", redis.ZAddArgs{NX: true, Members: []redis.Z{{Member: "abc", Score: 1.5}}}).Val()) @@ 
-684,14 +691,14 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s require.Equal(t, int64(0), rdb.ZRevRank(ctx, "zranktmp", "z").Val()) require.Equal(t, redis.Nil, rdb.ZRevRank(ctx, "zranktmp", "foo").Err()) - require.Equal(t, []interface{}{int64(0), "10"}, rdb.Do(ctx, "zrank", "zranktmp", "x", "withscore").Val()) - require.Equal(t, []interface{}{int64(1), "20"}, rdb.Do(ctx, "zrank", "zranktmp", "y", "withscore").Val()) - require.Equal(t, []interface{}{int64(2), "30"}, rdb.Do(ctx, "zrank", "zranktmp", "z", "withscore").Val()) - require.Equal(t, redis.Nil, rdb.Do(ctx, "zrank", "zranktmp", "foo", "withscore").Err()) - require.Equal(t, []interface{}{int64(2), "10"}, rdb.Do(ctx, "zrevrank", "zranktmp", "x", "withscore").Val()) - require.Equal(t, []interface{}{int64(1), "20"}, rdb.Do(ctx, "zrevrank", "zranktmp", "y", "withscore").Val()) - require.Equal(t, []interface{}{int64(0), "30"}, rdb.Do(ctx, "zrevrank", "zranktmp", "z", "withscore").Val()) - require.Equal(t, redis.Nil, rdb.Do(ctx, "zrevrank", "zranktmp", "foo", "withscore").Err()) + require.Equal(t, redis.RankScore{Rank: 0, Score: 10}, rdb.ZRankWithScore(ctx, "zranktmp", "x").Val()) + require.Equal(t, redis.RankScore{Rank: 1, Score: 20}, rdb.ZRankWithScore(ctx, "zranktmp", "y").Val()) + require.Equal(t, redis.RankScore{Rank: 2, Score: 30}, rdb.ZRankWithScore(ctx, "zranktmp", "z").Val()) + require.Equal(t, redis.Nil, rdb.ZRankWithScore(ctx, "zranktmp", "foo").Err()) + require.Equal(t, redis.RankScore{Rank: 2, Score: 10}, rdb.ZRevRankWithScore(ctx, "zranktmp", "x").Val()) + require.Equal(t, redis.RankScore{Rank: 1, Score: 20}, rdb.ZRevRankWithScore(ctx, "zranktmp", "y").Val()) + require.Equal(t, redis.RankScore{Rank: 0, Score: 30}, rdb.ZRevRankWithScore(ctx, "zranktmp", "z").Val()) + require.Equal(t, redis.Nil, rdb.ZRevRankWithScore(ctx, "zranktmp", "foo").Err()) }) t.Run(fmt.Sprintf("ZRANK/ZREVRANK - after deletion -%s", encoding), func(t *testing.T) { @@ -704,12 +711,12 @@ func 
basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s require.Equal(t, int64(0), rdb.ZRevRank(ctx, "zranktmp", "z").Val()) require.Equal(t, redis.Nil, rdb.ZRevRank(ctx, "zranktmp", "foo").Err()) - require.Equal(t, []interface{}{int64(0), "10"}, rdb.Do(ctx, "zrank", "zranktmp", "x", "withscore").Val()) - require.Equal(t, []interface{}{int64(1), "30"}, rdb.Do(ctx, "zrank", "zranktmp", "z", "withscore").Val()) - require.Equal(t, redis.Nil, rdb.Do(ctx, "zrank", "zranktmp", "foo", "withscore").Err()) - require.Equal(t, []interface{}{int64(1), "10"}, rdb.Do(ctx, "zrevrank", "zranktmp", "x", "withscore").Val()) - require.Equal(t, []interface{}{int64(0), "30"}, rdb.Do(ctx, "zrevrank", "zranktmp", "z", "withscore").Val()) - require.Equal(t, redis.Nil, rdb.Do(ctx, "zrevrank", "zranktmp", "foo", "withscore").Err()) + require.Equal(t, redis.RankScore{Rank: 0, Score: 10}, rdb.ZRankWithScore(ctx, "zranktmp", "x").Val()) + require.Equal(t, redis.RankScore{Rank: 1, Score: 30}, rdb.ZRankWithScore(ctx, "zranktmp", "z").Val()) + require.Equal(t, redis.Nil, rdb.ZRankWithScore(ctx, "zranktmp", "foo").Err()) + require.Equal(t, redis.RankScore{Rank: 1, Score: 10}, rdb.ZRevRankWithScore(ctx, "zranktmp", "x").Val()) + require.Equal(t, redis.RankScore{Rank: 0, Score: 30}, rdb.ZRevRankWithScore(ctx, "zranktmp", "z").Val()) + require.Equal(t, redis.Nil, rdb.ZRevRankWithScore(ctx, "zranktmp", "foo").Err()) }) t.Run(fmt.Sprintf("ZINCRBY - can create a new sorted set - %s", encoding), func(t *testing.T) { @@ -1463,6 +1470,167 @@ func basicTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding s ).Err(), ".*weight.*not.*double.*") }) } + + t.Run(fmt.Sprintf("ZDIFF with two sets - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: 
"b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + cmd := rdb.ZDiff(ctx, "zseta", "zsetb") + require.NoError(t, cmd.Err()) + sort.Strings(cmd.Val()) + require.EqualValues(t, []string{"a", "d", "e"}, cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with three sets - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 5, Member: "e"}, + }) + cmd := rdb.ZDiff(ctx, "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + sort.Strings(cmd.Val()) + require.EqualValues(t, []string{"a"}, cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with three sets with scores - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 4, Member: "c"}, + {Score: 5, Member: "e"}, + }) + cmd := rdb.ZDiffWithScores(ctx, "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with empty sets - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{}) + createZset(rdb, ctx, "zsetb", []redis.Z{}) + cmd := rdb.ZDiff(ctx, "zseta", "zsetb") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []string([]string{}), cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with non existing sets 
- %s", encoding), func(t *testing.T) { + rdb.Del(ctx, "zseta") + rdb.Del(ctx, "zsetb") + cmd := rdb.ZDiff(ctx, "zseta", "zsetb") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []string([]string{}), cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with missing set with scores - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + rdb.Del(ctx, "zsetc") + cmd := rdb.ZDiffWithScores(ctx, "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), cmd.Val()) + }) + + t.Run(fmt.Sprintf("ZDIFF with empty sets with scores - %s", encoding), func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{}) + createZset(rdb, ctx, "zsetb", []redis.Z{}) + cmd := rdb.ZDiffWithScores(ctx, "zseta", "zsetb") + require.NoError(t, cmd.Err()) + require.EqualValues(t, []redis.Z([]redis.Z{}), cmd.Val()) + }) + + t.Run("ZDIFFSTORE with three sets - ", func(t *testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + }) + createZset(rdb, ctx, "zsetc", []redis.Z{ + {Score: 4, Member: "c"}, + {Score: 5, Member: "e"}, + }) + cmd := rdb.ZDiffStore(ctx, "setres", "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, int64(2), cmd.Val()) + require.Equal(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), rdb.ZRangeWithScores(ctx, "setres", 0, -1).Val()) + }) + + t.Run("ZDIFFSTORE with missing sets - ", func(t 
*testing.T) { + createZset(rdb, ctx, "zseta", []redis.Z{ + {Score: 1, Member: "a"}, + {Score: 2, Member: "b"}, + {Score: 3, Member: "c"}, + {Score: 3, Member: "d"}, + {Score: 4, Member: "e"}, + }) + createZset(rdb, ctx, "zsetb", []redis.Z{ + {Score: 1, Member: "b"}, + {Score: 2, Member: "c"}, + {Score: 4, Member: "f"}, + {Score: 4, Member: "e"}, + }) + rdb.Del(ctx, "zsetc") + cmd := rdb.ZDiffStore(ctx, "setres", "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, int64(2), cmd.Val()) + require.Equal(t, []redis.Z([]redis.Z{{Score: 1, Member: "a"}, {Score: 3, Member: "d"}}), rdb.ZRangeWithScores(ctx, "setres", 0, -1).Val()) + }) + + t.Run("ZDIFFSTORE with missing sets - ", func(t *testing.T) { + rdb.Del(ctx, "zseta") + rdb.Del(ctx, "zsetb") + rdb.Del(ctx, "zsetc") + cmd := rdb.ZDiffStore(ctx, "setres", "zseta", "zsetb", "zsetc") + require.NoError(t, cmd.Err()) + require.EqualValues(t, int64(0), cmd.Val()) + require.Equal(t, []redis.Z([]redis.Z{}), rdb.ZRangeWithScores(ctx, "setres", 0, -1).Val()) + }) } func stressTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding string) { @@ -1732,14 +1900,24 @@ func stressTests(t *testing.T, rdb *redis.Client, ctx context.Context, encoding }) } -func TestZset(t *testing.T) { - srv := util.StartServer(t, map[string]string{}) +func TestZSetWithRESP2(t *testing.T) { + testZSet(t, "no") +} + +func TestZSetWithRESP3(t *testing.T) { + testZSet(t, "yes") +} + +var testZSet = func(t *testing.T, enabledRESP3 string) { + srv := util.StartServer(t, map[string]string{ + "resp3-enabled": enabledRESP3, + }) defer srv.Close() ctx := context.Background() rdb := srv.NewClient() defer func() { require.NoError(t, rdb.Close()) }() - basicTests(t, rdb, ctx, "skiplist", srv) + basicTests(t, rdb, ctx, enabledRESP3, "skiplist", srv) t.Run("ZUNIONSTORE regression, should not create NaN in scores", func(t *testing.T) { rdb.ZAdd(ctx, "z", redis.Z{Score: math.Inf(-1), Member: "neginf"})