Skip to content

Commit

Permalink
fix: provide resp3 option to CapturingReplyBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
BorysTheDev committed Jan 5, 2025
1 parent 6e9409c commit e7eeffd
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 37 deletions.
7 changes: 3 additions & 4 deletions src/facade/reply_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ void MCReplyBuilder::SendRaw(std::string_view str) {

void RedisReplyBuilderBase::SendNull() {
ReplyScope scope(this);
resp3_ ? WritePieces(kNullStringR3) : WritePieces(kNullStringR2);
IsResp3() ? WritePieces(kNullStringR3) : WritePieces(kNullStringR2);
}

void RedisReplyBuilderBase::SendSimpleString(std::string_view str) {
Expand Down Expand Up @@ -323,7 +323,7 @@ void RedisReplyBuilderBase::SendDouble(double val) {
static_assert(ABSL_ARRAYSIZE(buf) < kMaxInlineSize, "Write temporary string from buf inline");
string_view val_str = FormatDouble(val, buf, ABSL_ARRAYSIZE(buf));

if (!resp3_)
if (!IsResp3())
return SendBulkString(val_str);

ReplyScope scope(this);
Expand Down Expand Up @@ -422,7 +422,7 @@ void RedisReplyBuilder::SendScoredArray(ScoredArray arr, bool with_scores) {

void RedisReplyBuilder::SendLabeledScoredArray(std::string_view arr_label, ScoredArray arr) {
ReplyScope scope(this);

StartArray(2);

SendBulkString(arr_label);
Expand All @@ -432,7 +432,6 @@ void RedisReplyBuilder::SendLabeledScoredArray(std::string_view arr_label, Score
SendBulkString(str);
SendDouble(score);
}

}

void RedisReplyBuilder::SendStored() {
Expand Down
14 changes: 10 additions & 4 deletions src/facade/reply_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ enum class ReplyMode {
FULL // All replies are recorded
};

enum class RespVersion { kResp2, kResp3 };

// Base class for all reply builders. Offer a simple high level interface for controlling output
// modes and sending basic response types.
class SinkReplyBuilder {
Expand Down Expand Up @@ -258,15 +260,19 @@ class RedisReplyBuilderBase : public SinkReplyBuilder {
static std::string SerializeCommand(std::string_view command);

bool IsResp3() const {
return resp3_;
return resp_ == RespVersion::kResp3;
}

void SetRespVersion(RespVersion resp_version) {
resp_ = resp_version;
}

void SetResp3(bool resp3) {
resp3_ = resp3;
RespVersion GetRespVersion() {
return resp_;
}

private:
bool resp3_ = false;
RespVersion resp_ = RespVersion::kResp2;
};

// Non essential redis reply builder functions implemented on top of the base resp protocol
Expand Down
36 changes: 18 additions & 18 deletions src/facade/reply_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -699,14 +699,14 @@ TEST_F(RedisReplyBuilderTest, BatchMode) {
}

TEST_F(RedisReplyBuilderTest, Resp3Double) {
builder_->SetResp3(true);
builder_->SetRespVersion(RespVersion::kResp3);
builder_->SendDouble(5.5);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(str(), ",5.5\r\n");
}

TEST_F(RedisReplyBuilderTest, Resp3NullString) {
builder_->SetResp3(true);
builder_->SetRespVersion(RespVersion::kResp3);
builder_->SendNull();
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(), "_\r\n");
Expand All @@ -715,13 +715,13 @@ TEST_F(RedisReplyBuilderTest, Resp3NullString) {
TEST_F(RedisReplyBuilderTest, SendStringArrayAsMap) {
const std::vector<std::string> map_array{"k1", "v1", "k2", "v2"};

builder_->SetResp3(false);
builder_->SetRespVersion(RespVersion::kResp2);
builder_->SendBulkStrArr(map_array, builder_->MAP);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(), "*4\r\n$2\r\nk1\r\n$2\r\nv1\r\n$2\r\nk2\r\n$2\r\nv2\r\n")
<< "SendStringArrayAsMap Resp2 Failed.";

builder_->SetResp3(true);
builder_->SetRespVersion(RespVersion::kResp3);
builder_->SendBulkStrArr(map_array, builder_->MAP);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(), "%2\r\n$2\r\nk1\r\n$2\r\nv1\r\n$2\r\nk2\r\n$2\r\nv2\r\n")
Expand All @@ -731,13 +731,13 @@ TEST_F(RedisReplyBuilderTest, SendStringArrayAsMap) {
TEST_F(RedisReplyBuilderTest, SendStringArrayAsSet) {
const std::vector<std::string> set_array{"e1", "e2", "e3"};

builder_->SetResp3(false);
builder_->SetRespVersion(RespVersion::kResp2);
builder_->SendBulkStrArr(set_array, builder_->SET);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(), "*3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
<< "SendStringArrayAsSet Resp2 Failed.";

builder_->SetResp3(true);
builder_->SetRespVersion(RespVersion::kResp3);
builder_->SendBulkStrArr(set_array, builder_->SET);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(), "~3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
Expand All @@ -748,26 +748,26 @@ TEST_F(RedisReplyBuilderTest, SendScoredArray) {
const std::vector<std::pair<std::string, double>> scored_array{
{"e1", 1.1}, {"e2", 2.2}, {"e3", 3.3}};

builder_->SetResp3(false);
builder_->SetRespVersion(RespVersion::kResp2);
builder_->SendScoredArray(scored_array, false);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(), "*3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
<< "Resp2 WITHOUT scores failed.";

builder_->SetResp3(true);
builder_->SetRespVersion(RespVersion::kResp3);
builder_->SendScoredArray(scored_array, false);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(), "*3\r\n$2\r\ne1\r\n$2\r\ne2\r\n$2\r\ne3\r\n")
<< "Resp3 WITHOUT scores failed.";

builder_->SetResp3(false);
builder_->SetRespVersion(RespVersion::kResp2);
builder_->SendScoredArray(scored_array, true);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(),
"*6\r\n$2\r\ne1\r\n$3\r\n1.1\r\n$2\r\ne2\r\n$3\r\n2.2\r\n$2\r\ne3\r\n$3\r\n3.3\r\n")
<< "Resp3 WITHSCORES failed.";

builder_->SetResp3(true);
builder_->SetRespVersion(RespVersion::kResp3);
builder_->SendScoredArray(scored_array, true);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(),
Expand All @@ -779,15 +779,15 @@ TEST_F(RedisReplyBuilderTest, SendLabeledScoredArray) {
const std::vector<std::pair<std::string, double>> scored_array{
{"e1", 1.1}, {"e2", 2.2}, {"e3", 3.3}};

builder_->SetResp3(false);
builder_->SetRespVersion(RespVersion::kResp2);
builder_->SendLabeledScoredArray("foobar", scored_array);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(),
"*2\r\n$6\r\nfoobar\r\n*3\r\n*2\r\n$2\r\ne1\r\n$3\r\n1.1\r\n*2\r\n$2\r\ne2\r\n$3\r\n2."
"2\r\n*2\r\n$2\r\ne3\r\n$3\r\n3.3\r\n")
<< "Resp3 failed.\n";

builder_->SetResp3(true);
builder_->SetRespVersion(RespVersion::kResp3);
builder_->SendLabeledScoredArray("foobar", scored_array);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(),
Expand Down Expand Up @@ -850,8 +850,8 @@ TEST_F(RedisReplyBuilderTest, BasicCapture) {
big_arr_cb,
};

crb.SetResp3(true);
builder_->SetResp3(true);
crb.SetRespVersion(RespVersion::kResp3);
builder_->SetRespVersion(RespVersion::kResp3);

// Run generator functions on both a regular redis builder
// and the capturing builder with its capture applied.
Expand All @@ -864,7 +864,7 @@ TEST_F(RedisReplyBuilderTest, BasicCapture) {
EXPECT_EQ(expected, actual);
}

builder_->SetResp3(false);
builder_->SetRespVersion(RespVersion::kResp2);
}

TEST_F(RedisReplyBuilderTest, FormatDouble) {
Expand All @@ -889,17 +889,17 @@ TEST_F(RedisReplyBuilderTest, VerbatimString) {
// test resp3
std::string str = "A simple string!";

builder_->SetResp3(true);
builder_->SetRespVersion(RespVersion::kResp3);
builder_->SendVerbatimString(str, RedisReplyBuilder::VerbatimFormat::TXT);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(), "=20\r\ntxt:A simple string!\r\n") << "Resp3 VerbatimString TXT failed.";

builder_->SetResp3(true);
builder_->SetRespVersion(RespVersion::kResp3);
builder_->SendVerbatimString(str, RedisReplyBuilder::VerbatimFormat::MARKDOWN);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(), "=20\r\nmkd:A simple string!\r\n") << "Resp3 VerbatimString TXT failed.";

builder_->SetResp3(false);
builder_->SetRespVersion(RespVersion::kResp2);
builder_->SendVerbatimString(str);
ASSERT_TRUE(NoErrors());
ASSERT_EQ(TakePayload(), "$16\r\nA simple string!\r\n") << "Resp3 VerbatimString TXT failed.";
Expand Down
3 changes: 2 additions & 1 deletion src/facade/reply_capture.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
struct SimpleString : public std::string {}; // SendSimpleString
struct BulkString : public std::string {}; // SendBulkString

CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL)
CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL, RespVersion resp_v = RespVersion::kResp2)
: RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} {
SetRespVersion(resp_v);
}

using Payload = std::variant<std::monostate, Null, Error, long, double, SimpleString, BulkString,
Expand Down
17 changes: 10 additions & 7 deletions src/server/multi_command_squasher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,13 @@ bool MultiCommandSquasher::ExecuteStandalone(facade::RedisReplyBuilder* rb, Stor
return true;
}

OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard* es) {
OpStatus MultiCommandSquasher::SquashedHopCb(Transaction* parent_tx, EngineShard* es,
RespVersion resp_v) {
auto& sinfo = sharded_[es->shard_id()];
DCHECK(!sinfo.cmds.empty());

auto* local_tx = sinfo.local_tx.get();
facade::CapturingReplyBuilder crb;
facade::CapturingReplyBuilder crb(ReplyMode::FULL, resp_v);
ConnectionContext local_cntx{cntx_, local_tx};
if (cntx_->conn()) {
local_cntx.skip_acl_validation = cntx_->conn()->IsPrivileged();
Expand Down Expand Up @@ -244,14 +245,15 @@ bool MultiCommandSquasher::ExecuteSquashed(facade::RedisReplyBuilder* rb) {
cntx_->cid = base_cid_;
auto cb = [this](ShardId sid) { return !sharded_[sid].cmds.empty(); };
tx->PrepareSquashedMultiHop(base_cid_, cb);
tx->ScheduleSingleHop([this](auto* tx, auto* es) { return SquashedHopCb(tx, es); });
tx->ScheduleSingleHop(
[this, rb](auto* tx, auto* es) { return SquashedHopCb(tx, es, rb->GetRespVersion()); });
} else {
#if 1
fb2::BlockingCounter bc(num_shards);
DVLOG(1) << "Squashing " << num_shards << " " << tx->DebugId();

auto cb = [this, tx, bc]() mutable {
this->SquashedHopCb(tx, EngineShard::tlocal());
auto cb = [this, tx, bc, rb]() mutable {
this->SquashedHopCb(tx, EngineShard::tlocal(), rb->GetRespVersion());
bc->Dec();
};

Expand All @@ -261,8 +263,9 @@ bool MultiCommandSquasher::ExecuteSquashed(facade::RedisReplyBuilder* rb) {
}
bc->Wait();
#else
shard_set->RunBlockingInParallel([this, tx](auto* es) { SquashedHopCb(tx, es); },
[this](auto sid) { return !sharded_[sid].cmds.empty(); });
shard_set->RunBlockingInParallel(
[this, tx, rb](auto* es) { SquashedHopCb(tx, es, rb->GetRespVersion()); },
[this](auto sid) { return !sharded_[sid].cmds.empty(); });
#endif
}

Expand Down
3 changes: 2 additions & 1 deletion src/server/multi_command_squasher.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ class MultiCommandSquasher {
bool ExecuteStandalone(facade::RedisReplyBuilder* rb, StoredCmd* cmd);

// Callback that runs on shards during squashed hop.
facade::OpStatus SquashedHopCb(Transaction* parent_tx, EngineShard* es);
facade::OpStatus SquashedHopCb(Transaction* parent_tx, EngineShard* es,
facade::RespVersion resp_v);

// Execute all currently squashed commands. Return false if aborting on error.
bool ExecuteSquashed(facade::RedisReplyBuilder* rb);
Expand Down
4 changes: 2 additions & 2 deletions src/server/server_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2696,10 +2696,10 @@ void ServerFamily::Hello(CmdArgList args, const CommandContext& cmd_cntx) {
int proto_version = 2;
if (is_resp3) {
proto_version = 3;
rb->SetResp3(true);
rb->SetRespVersion(RespVersion::kResp3);
} else {
// Issuing hello 2 again is valid and should switch back to RESP2
rb->SetResp3(false);
rb->SetRespVersion(RespVersion::kResp2);
}

SinkReplyBuilder::ReplyAggregator agg(rb);
Expand Down

0 comments on commit e7eeffd

Please sign in to comment.