Skip to content

Commit

Permalink
Use generic DeleteNode to reduce code size of erase in Map and to…
Browse files Browse the repository at this point in the history
… simplify the parsing logic in `MpMap`.

PiperOrigin-RevId: 704832360
  • Loading branch information
protobuf-github-bot authored and copybara-github committed Dec 10, 2024
1 parent 4b397e5 commit 828716e
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 158 deletions.
40 changes: 11 additions & 29 deletions rust/cpp_kernel/map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <utility>

#include "absl/functional/overload.h"
#include "absl/log/absl_log.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message_lite.h"
Expand Down Expand Up @@ -48,6 +49,13 @@ template <typename Key>
using KeyMap = internal::KeyMapBase<
internal::KeyForBase<typename FromViewType<Key>::type>>;

template <typename T>
T AsViewType(T t) {
return t;
}

absl::string_view AsViewType(PtrAndLen key) { return key.AsStringView(); }

void InitializeMessageValue(void* raw_ptr, MessageLite* msg) {
MessageLite* new_msg = internal::RustMapHelper::PlacementNew(msg, raw_ptr);
auto* full_msg = DynamicCastMessage<Message>(new_msg);
Expand Down Expand Up @@ -86,28 +94,8 @@ bool Insert(internal::UntypedMapBase* m, Key key, MapValue value) {
},
});

node = internal::RustMapHelper::InsertOrReplaceNode(
return internal::RustMapHelper::InsertOrReplaceNode(
static_cast<KeyMap<Key>*>(m), node);
if (node == nullptr) {
return true;
}
internal::RustMapHelper::DeleteNode(m, node);
return false;
}

template <typename Map, typename Key,
typename = typename std::enable_if<
!std::is_same<Key, google::protobuf::rust::PtrAndLen>::value>::type>
internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, Key key) {
return internal::RustMapHelper::FindHelper(
m, static_cast<internal::KeyForBase<Key>>(key));
}

template <typename Map>
internal::RustMapHelper::NodeAndBucket FindHelper(Map* m,
google::protobuf::rust::PtrAndLen key) {
return internal::RustMapHelper::FindHelper(
m, absl::string_view(key.ptr, key.len));
}

void PopulateMapValue(const internal::UntypedMapBase& map,
Expand Down Expand Up @@ -147,7 +135,7 @@ void PopulateMapValue(const internal::UntypedMapBase& map,
template <typename Key>
bool Get(internal::UntypedMapBase* m, Key key, MapValue* value) {
auto* map_base = static_cast<KeyMap<Key>*>(m);
internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key);
auto result = internal::RustMapHelper::FindHelper(map_base, AsViewType(key));
if (result.node == nullptr) {
return false;
}
Expand All @@ -158,13 +146,7 @@ bool Get(internal::UntypedMapBase* m, Key key, MapValue* value) {
template <typename Key>
bool Remove(internal::UntypedMapBase* m, Key key) {
auto* map_base = static_cast<KeyMap<Key>*>(m);
internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key);
if (result.node == nullptr) {
return false;
}
internal::RustMapHelper::EraseNoDestroy(map_base, result.bucket, result.node);
internal::RustMapHelper::DeleteNode(m, result.node);
return true;
return internal::RustMapHelper::EraseImpl(map_base, AsViewType(key));
}

template <typename Key>
Expand Down
2 changes: 1 addition & 1 deletion src/google/protobuf/dynamic_message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ bool DynamicMapField::DeleteMapValueImpl(MapFieldBase& base,
if (self.arena() == nullptr) {
it->second.DeleteData();
}
self.map_.erase(it);
self.map_.EraseDynamic(it);
return true;
}

Expand Down
4 changes: 2 additions & 2 deletions src/google/protobuf/generated_message_tctable_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1007,8 +1007,8 @@ class PROTOBUF_EXPORT TcParser final {

static void WriteMapEntryAsUnknown(MessageLite* msg,
const TcParseTableBase* table,
uint32_t tag, NodeBase* node,
MapAuxInfo map_info);
UntypedMapBase& map, uint32_t tag,
NodeBase* node, MapAuxInfo map_info);

static const char* ParseOneMapEntry(NodeBase* node, const char* ptr,
ParseContext* ctx,
Expand Down
167 changes: 72 additions & 95 deletions src/google/protobuf/generated_message_tctable_lite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2600,63 +2600,49 @@ const char* TcParser::MpRepeatedMessageOrGroup(PROTOBUF_TC_PARAM_DECL) {
PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}

static void SerializeMapKey(const NodeBase* node, MapTypeCard type_card,
static void SerializeMapKey(UntypedMapBase& map, NodeBase* node,
MapTypeCard type_card,
io::CodedOutputStream& coded_output) {
switch (type_card.wiretype()) {
case WireFormatLite::WIRETYPE_VARINT:
switch (type_card.cpp_type()) {
case MapTypeCard::kBool:
WireFormatLite::WriteBool(
1, static_cast<const KeyNode<bool>*>(node)->key(), &coded_output);
break;
case MapTypeCard::k32:
if (type_card.is_zigzag()) {
WireFormatLite::WriteSInt32(
1, static_cast<const KeyNode<uint32_t>*>(node)->key(),
&coded_output);
} else if (type_card.is_signed()) {
WireFormatLite::WriteInt32(
1, static_cast<const KeyNode<uint32_t>*>(node)->key(),
&coded_output);
} else {
WireFormatLite::WriteUInt32(
1, static_cast<const KeyNode<uint32_t>*>(node)->key(),
&coded_output);
}
break;
case MapTypeCard::k64:
if (type_card.is_zigzag()) {
WireFormatLite::WriteSInt64(
1, static_cast<const KeyNode<uint64_t>*>(node)->key(),
&coded_output);
} else if (type_card.is_signed()) {
WireFormatLite::WriteInt64(
1, static_cast<const KeyNode<uint64_t>*>(node)->key(),
&coded_output);
} else {
WireFormatLite::WriteUInt64(
1, static_cast<const KeyNode<uint64_t>*>(node)->key(),
&coded_output);
}
break;
default:
Unreachable();
}
map.VisitKey(node, //
absl::Overload{
[&](const bool* v) {
WireFormatLite::WriteBool(1, *v, &coded_output);
},
[&](const uint32_t* v) {
if (type_card.is_zigzag()) {
WireFormatLite::WriteSInt32(1, *v, &coded_output);
} else if (type_card.is_signed()) {
WireFormatLite::WriteInt32(1, *v, &coded_output);
} else {
WireFormatLite::WriteUInt32(1, *v, &coded_output);
}
},
[&](const uint64_t* v) {
if (type_card.is_zigzag()) {
WireFormatLite::WriteSInt64(1, *v, &coded_output);
} else if (type_card.is_signed()) {
WireFormatLite::WriteInt64(1, *v, &coded_output);
} else {
WireFormatLite::WriteUInt64(1, *v, &coded_output);
}
},
[](const void*) { Unreachable(); },
});
break;
case WireFormatLite::WIRETYPE_FIXED32:
WireFormatLite::WriteFixed32(
1, static_cast<const KeyNode<uint32_t>*>(node)->key(), &coded_output);
WireFormatLite::WriteFixed32(1, *map.GetKey<uint32_t>(node),
&coded_output);
break;
case WireFormatLite::WIRETYPE_FIXED64:
WireFormatLite::WriteFixed64(
1, static_cast<const KeyNode<uint64_t>*>(node)->key(), &coded_output);
WireFormatLite::WriteFixed64(1, *map.GetKey<uint64_t>(node),
&coded_output);
break;
case WireFormatLite::WIRETYPE_LENGTH_DELIMITED:
// We should never have a message here. They can only be values maps.
ABSL_DCHECK_EQ(+type_card.cpp_type(), +MapTypeCard::kString);
WireFormatLite::WriteString(
1, static_cast<const KeyNode<std::string>*>(node)->key(),
&coded_output);
WireFormatLite::WriteString(1, *map.GetKey<std::string>(node),
&coded_output);
break;
default:
Unreachable();
Expand All @@ -2665,21 +2651,22 @@ static void SerializeMapKey(const NodeBase* node, MapTypeCard type_card,

void TcParser::WriteMapEntryAsUnknown(MessageLite* msg,
const TcParseTableBase* table,
uint32_t tag, NodeBase* node,
MapAuxInfo map_info) {
UntypedMapBase& map, uint32_t tag,
NodeBase* node, MapAuxInfo map_info) {
std::string serialized;
{
io::StringOutputStream string_output(&serialized);
io::CodedOutputStream coded_output(&string_output);
SerializeMapKey(node, map_info.key_type_card, coded_output);
SerializeMapKey(map, node, map_info.key_type_card, coded_output);
// The mapped_type is always an enum here.
ABSL_DCHECK(map_info.value_is_validated_enum);
WireFormatLite::WriteInt32(2,
*reinterpret_cast<int32_t*>(
node->GetVoidValue(map_info.node_size_info)),
&coded_output);
WireFormatLite::WriteInt32(2, *map.GetValue<int32_t>(node), &coded_output);
}
GetUnknownFieldOps(table).write_length_delimited(msg, tag >> 3, serialized);

if (map.arena() == nullptr) {
map.DeleteNode(node);
}
}

template <typename T>
Expand Down Expand Up @@ -2865,51 +2852,41 @@ PROTOBUF_NOINLINE const char* TcParser::MpMap(PROTOBUF_TC_PARAM_DECL) {
return ParseOneMapEntry(node, ptr, ctx, aux, table, entry, map.arena());
});

if (ABSL_PREDICT_TRUE(ptr != nullptr)) {
if (ABSL_PREDICT_FALSE(map_info.value_is_validated_enum &&
!internal::ValidateEnumInlined(
*static_cast<int32_t*>(node->GetVoidValue(
map_info.node_size_info)),
aux[1].enum_data))) {
WriteMapEntryAsUnknown(msg, table, saved_tag, node, map_info);
} else {
// Done parsing the node, try to insert it.
// If it overwrites something we get old node back to destroy it.
switch (map_info.key_type_card.cpp_type()) {
case MapTypeCard::kBool:
node = static_cast<KeyMapBase<bool>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<bool>::KeyNode*>(node));
break;
case MapTypeCard::k32:
node = static_cast<KeyMapBase<uint32_t>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<uint32_t>::KeyNode*>(node));
break;
case MapTypeCard::k64:
node = static_cast<KeyMapBase<uint64_t>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<uint64_t>::KeyNode*>(node));
break;
case MapTypeCard::kString:
node =
static_cast<KeyMapBase<std::string>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<std::string>::KeyNode*>(node));
break;
default:
Unreachable();
}
}
}

// Destroy the node if we have it.
// It could be because we failed to parse, or because insertion returned
// an overwritten node.
if (ABSL_PREDICT_FALSE(node != nullptr && map.arena() == nullptr)) {
map.DeleteNode(node);
}

if (ABSL_PREDICT_FALSE(ptr == nullptr)) {
// Parsing failed. Delete the node that we didn't insert.
if (map.arena() == nullptr) map.DeleteNode(node);
PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}

if (ABSL_PREDICT_FALSE(
map_info.value_is_validated_enum &&
!internal::ValidateEnumInlined(*map.GetValue<int32_t>(node),
aux[1].enum_data))) {
WriteMapEntryAsUnknown(msg, table, map, saved_tag, node, map_info);
} else {
// Done parsing the node, insert it.
switch (map_info.key_type_card.cpp_type()) {
case MapTypeCard::kBool:
static_cast<KeyMapBase<bool>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<bool>::KeyNode*>(node));
break;
case MapTypeCard::k32:
static_cast<KeyMapBase<uint32_t>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<uint32_t>::KeyNode*>(node));
break;
case MapTypeCard::k64:
static_cast<KeyMapBase<uint64_t>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<uint64_t>::KeyNode*>(node));
break;
case MapTypeCard::kString:
static_cast<KeyMapBase<std::string>&>(map).InsertOrReplaceNode(
static_cast<KeyMapBase<std::string>::KeyNode*>(node));
break;
default:
Unreachable();
}
}

if (ABSL_PREDICT_FALSE(!ctx->DataAvailable(ptr))) {
PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS);
}
Expand Down
Loading

0 comments on commit 828716e

Please sign in to comment.