diff --git a/rust/cpp_kernel/map.cc b/rust/cpp_kernel/map.cc index 4e4f78ca042c..66265e3140af 100644 --- a/rust/cpp_kernel/map.cc +++ b/rust/cpp_kernel/map.cc @@ -7,6 +7,7 @@ #include #include "absl/functional/overload.h" +#include "absl/log/absl_log.h" #include "absl/strings/string_view.h" #include "google/protobuf/message.h" #include "google/protobuf/message_lite.h" @@ -48,6 +49,13 @@ template using KeyMap = internal::KeyMapBase< internal::KeyForBase::type>>; +template +T AsViewType(T t) { + return t; +} + +absl::string_view AsViewType(PtrAndLen key) { return key.AsStringView(); } + void InitializeMessageValue(void* raw_ptr, MessageLite* msg) { MessageLite* new_msg = internal::RustMapHelper::PlacementNew(msg, raw_ptr); auto* full_msg = DynamicCastMessage(new_msg); @@ -86,28 +94,8 @@ bool Insert(internal::UntypedMapBase* m, Key key, MapValue value) { }, }); - node = internal::RustMapHelper::InsertOrReplaceNode( + return internal::RustMapHelper::InsertOrReplaceNode( static_cast*>(m), node); - if (node == nullptr) { - return true; - } - internal::RustMapHelper::DeleteNode(m, node); - return false; -} - -template ::value>::type> -internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, Key key) { - return internal::RustMapHelper::FindHelper( - m, static_cast>(key)); -} - -template -internal::RustMapHelper::NodeAndBucket FindHelper(Map* m, - google::protobuf::rust::PtrAndLen key) { - return internal::RustMapHelper::FindHelper( - m, absl::string_view(key.ptr, key.len)); } void PopulateMapValue(const internal::UntypedMapBase& map, @@ -147,7 +135,7 @@ void PopulateMapValue(const internal::UntypedMapBase& map, template bool Get(internal::UntypedMapBase* m, Key key, MapValue* value) { auto* map_base = static_cast*>(m); - internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); + auto result = internal::RustMapHelper::FindHelper(map_base, AsViewType(key)); if (result.node == nullptr) { return false; } @@ -158,13 +146,7 @@ bool Get(internal::UntypedMapBase* m, Key key, MapValue* value) { template bool Remove(internal::UntypedMapBase* m, Key key) { auto* map_base = static_cast*>(m); - internal::RustMapHelper::NodeAndBucket result = FindHelper(map_base, key); - if (result.node == nullptr) { - return false; - } - internal::RustMapHelper::EraseNoDestroy(map_base, result.bucket, result.node); - internal::RustMapHelper::DeleteNode(m, result.node); - return true; + return internal::RustMapHelper::EraseImpl(map_base, AsViewType(key)); } template diff --git a/src/google/protobuf/dynamic_message.cc b/src/google/protobuf/dynamic_message.cc index d1d8ec6db7a6..ba7cfaad2392 100644 --- a/src/google/protobuf/dynamic_message.cc +++ b/src/google/protobuf/dynamic_message.cc @@ -269,7 +269,7 @@ bool DynamicMapField::DeleteMapValueImpl(MapFieldBase& base, if (self.arena() == nullptr) { it->second.DeleteData(); } - self.map_.erase(it); + self.map_.EraseDynamic(it); return true; } diff --git a/src/google/protobuf/generated_message_tctable_impl.h b/src/google/protobuf/generated_message_tctable_impl.h index 45ec731938d4..721a60cf1b71 100644 --- a/src/google/protobuf/generated_message_tctable_impl.h +++ b/src/google/protobuf/generated_message_tctable_impl.h @@ -1007,8 +1007,8 @@ class PROTOBUF_EXPORT TcParser final { static void WriteMapEntryAsUnknown(MessageLite* msg, const TcParseTableBase* table, - uint32_t tag, NodeBase* node, - MapAuxInfo map_info); + UntypedMapBase& map, uint32_t tag, + NodeBase* node, MapAuxInfo map_info); static const char* ParseOneMapEntry(NodeBase* node, const char* ptr, ParseContext* ctx, diff --git a/src/google/protobuf/generated_message_tctable_lite.cc b/src/google/protobuf/generated_message_tctable_lite.cc index 207c4b8f34e8..5e7d40f6d196 100644 --- a/src/google/protobuf/generated_message_tctable_lite.cc +++ b/src/google/protobuf/generated_message_tctable_lite.cc @@ -2600,63 +2600,49 @@ const char* TcParser::MpRepeatedMessageOrGroup(PROTOBUF_TC_PARAM_DECL) { PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); } -static void SerializeMapKey(const NodeBase* node, MapTypeCard type_card, +static void SerializeMapKey(UntypedMapBase& map, NodeBase* node, + MapTypeCard type_card, io::CodedOutputStream& coded_output) { switch (type_card.wiretype()) { case WireFormatLite::WIRETYPE_VARINT: - switch (type_card.cpp_type()) { - case MapTypeCard::kBool: - WireFormatLite::WriteBool( - 1, static_cast*>(node)->key(), &coded_output); - break; - case MapTypeCard::k32: - if (type_card.is_zigzag()) { - WireFormatLite::WriteSInt32( - 1, static_cast*>(node)->key(), - &coded_output); - } else if (type_card.is_signed()) { - WireFormatLite::WriteInt32( - 1, static_cast*>(node)->key(), - &coded_output); - } else { - WireFormatLite::WriteUInt32( - 1, static_cast*>(node)->key(), - &coded_output); - } - break; - case MapTypeCard::k64: - if (type_card.is_zigzag()) { - WireFormatLite::WriteSInt64( - 1, static_cast*>(node)->key(), - &coded_output); - } else if (type_card.is_signed()) { - WireFormatLite::WriteInt64( - 1, static_cast*>(node)->key(), - &coded_output); - } else { - WireFormatLite::WriteUInt64( - 1, static_cast*>(node)->key(), - &coded_output); - } - break; - default: - Unreachable(); - } + map.VisitKey(node, // + absl::Overload{ + [&](const bool* v) { + WireFormatLite::WriteBool(1, *v, &coded_output); + }, + [&](const uint32_t* v) { + if (type_card.is_zigzag()) { + WireFormatLite::WriteSInt32(1, *v, &coded_output); + } else if (type_card.is_signed()) { + WireFormatLite::WriteInt32(1, *v, &coded_output); + } else { + WireFormatLite::WriteUInt32(1, *v, &coded_output); + } + }, + [&](const uint64_t* v) { + if (type_card.is_zigzag()) { + WireFormatLite::WriteSInt64(1, *v, &coded_output); + } else if (type_card.is_signed()) { + WireFormatLite::WriteInt64(1, *v, &coded_output); + } else { + WireFormatLite::WriteUInt64(1, *v, &coded_output); + } + }, + [](const void*) { Unreachable(); }, + }); break; case WireFormatLite::WIRETYPE_FIXED32: - WireFormatLite::WriteFixed32( - 1, static_cast*>(node)->key(), &coded_output); + WireFormatLite::WriteFixed32(1, *map.GetKey(node), + &coded_output); break; case WireFormatLite::WIRETYPE_FIXED64: - WireFormatLite::WriteFixed64( - 1, static_cast*>(node)->key(), &coded_output); + WireFormatLite::WriteFixed64(1, *map.GetKey(node), + &coded_output); break; case WireFormatLite::WIRETYPE_LENGTH_DELIMITED: // We should never have a message here. They can only be values maps. - ABSL_DCHECK_EQ(+type_card.cpp_type(), +MapTypeCard::kString); - WireFormatLite::WriteString( - 1, static_cast*>(node)->key(), - &coded_output); + WireFormatLite::WriteString(1, *map.GetKey(node), + &coded_output); break; default: Unreachable(); @@ -2665,21 +2651,22 @@ static void SerializeMapKey(const NodeBase* node, MapTypeCard type_card, void TcParser::WriteMapEntryAsUnknown(MessageLite* msg, const TcParseTableBase* table, - uint32_t tag, NodeBase* node, - MapAuxInfo map_info) { + UntypedMapBase& map, uint32_t tag, + NodeBase* node, MapAuxInfo map_info) { std::string serialized; { io::StringOutputStream string_output(&serialized); io::CodedOutputStream coded_output(&string_output); - SerializeMapKey(node, map_info.key_type_card, coded_output); + SerializeMapKey(map, node, map_info.key_type_card, coded_output); // The mapped_type is always an enum here. ABSL_DCHECK(map_info.value_is_validated_enum); - WireFormatLite::WriteInt32(2, - *reinterpret_cast( - node->GetVoidValue(map_info.node_size_info)), - &coded_output); + WireFormatLite::WriteInt32(2, *map.GetValue(node), &coded_output); } GetUnknownFieldOps(table).write_length_delimited(msg, tag >> 3, serialized); + + if (map.arena() == nullptr) { + map.DeleteNode(node); + } } template @@ -2865,51 +2852,41 @@ PROTOBUF_NOINLINE const char* TcParser::MpMap(PROTOBUF_TC_PARAM_DECL) { return ParseOneMapEntry(node, ptr, ctx, aux, table, entry, map.arena()); }); - if (ABSL_PREDICT_TRUE(ptr != nullptr)) { - if (ABSL_PREDICT_FALSE(map_info.value_is_validated_enum && - !internal::ValidateEnumInlined( - *static_cast(node->GetVoidValue( - map_info.node_size_info)), - aux[1].enum_data))) { - WriteMapEntryAsUnknown(msg, table, saved_tag, node, map_info); - } else { - // Done parsing the node, try to insert it. - // If it overwrites something we get old node back to destroy it. - switch (map_info.key_type_card.cpp_type()) { - case MapTypeCard::kBool: - node = static_cast&>(map).InsertOrReplaceNode( - static_cast::KeyNode*>(node)); - break; - case MapTypeCard::k32: - node = static_cast&>(map).InsertOrReplaceNode( - static_cast::KeyNode*>(node)); - break; - case MapTypeCard::k64: - node = static_cast&>(map).InsertOrReplaceNode( - static_cast::KeyNode*>(node)); - break; - case MapTypeCard::kString: - node = - static_cast&>(map).InsertOrReplaceNode( - static_cast::KeyNode*>(node)); - break; - default: - Unreachable(); - } - } - } - - // Destroy the node if we have it. - // It could be because we failed to parse, or because insertion returned - // an overwritten node. - if (ABSL_PREDICT_FALSE(node != nullptr && map.arena() == nullptr)) { - map.DeleteNode(node); - } - if (ABSL_PREDICT_FALSE(ptr == nullptr)) { + // Parsing failed. Delete the node that we didn't insert. + if (map.arena() == nullptr) map.DeleteNode(node); PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); } + if (ABSL_PREDICT_FALSE( + map_info.value_is_validated_enum && + !internal::ValidateEnumInlined(*map.GetValue(node), + aux[1].enum_data))) { + WriteMapEntryAsUnknown(msg, table, map, saved_tag, node, map_info); + } else { + // Done parsing the node, insert it. + switch (map_info.key_type_card.cpp_type()) { + case MapTypeCard::kBool: + static_cast&>(map).InsertOrReplaceNode( + static_cast::KeyNode*>(node)); + break; + case MapTypeCard::k32: + static_cast&>(map).InsertOrReplaceNode( + static_cast::KeyNode*>(node)); + break; + case MapTypeCard::k64: + static_cast&>(map).InsertOrReplaceNode( + static_cast::KeyNode*>(node)); + break; + case MapTypeCard::kString: + static_cast&>(map).InsertOrReplaceNode( + static_cast::KeyNode*>(node)); + break; + default: + Unreachable(); + } + } + if (ABSL_PREDICT_FALSE(!ctx->DataAvailable(ptr))) { PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS); } diff --git a/src/google/protobuf/map.h b/src/google/protobuf/map.h index c0916c9a607c..5f16a4c76f98 100644 --- a/src/google/protobuf/map.h +++ b/src/google/protobuf/map.h @@ -680,7 +680,8 @@ class KeyMapBase : public UntypedMapBase { friend class RustMapHelper; friend class v2::TableDriven; - PROTOBUF_NOINLINE void erase_no_destroy(map_index_t b, KeyNode* node) { + PROTOBUF_NOINLINE size_type EraseImpl(map_index_t b, KeyNode* node, + bool do_destroy) { // Force bucket_index to be in range. b &= (num_buckets_ - 1); @@ -708,6 +709,20 @@ class KeyMapBase : public UntypedMapBase { ++index_of_first_non_null_; } } + + if (arena() == nullptr && do_destroy) { + DeleteNode(node); + } + + // To allow for the other overload of EraseImpl to do a tail call. + return 1; + } + + PROTOBUF_NOINLINE size_type EraseImpl(typename TS::ViewType k) { + if (auto result = FindHelper(k); result.node != nullptr) { + return EraseImpl(result.bucket, static_cast(result.node), true); + } + return 0; } NodeAndBucket FindHelper(typename TS::ViewType k) const { @@ -721,22 +736,20 @@ class KeyMapBase : public UntypedMapBase { } // Insert the given node. - // If the key is a duplicate, it inserts the new node and returns the old one. - // Gives ownership to the caller. - // If the key is unique, it returns `nullptr`. - KeyNode* InsertOrReplaceNode(KeyNode* node) { - KeyNode* to_erase = nullptr; + // If the key is a duplicate, it inserts the new node and deletes the old one. + bool InsertOrReplaceNode(KeyNode* node) { + bool is_new = true; auto p = this->FindHelper(node->key()); map_index_t b = p.bucket; - if (p.node != nullptr) { - erase_no_destroy(p.bucket, static_cast(p.node)); - to_erase = static_cast(p.node); + if (ABSL_PREDICT_FALSE(p.node != nullptr)) { + EraseImpl(p.bucket, static_cast(p.node), true); + is_new = false; } else if (ResizeIfLoadIsOutOfRange(num_elements_ + 1)) { b = BucketNumber(node->key()); // bucket_number } InsertUnique(b, node); ++num_elements_; - return to_erase; + return is_new; } // Insert the given Node in bucket b. If that would make bucket b too big, @@ -876,13 +889,13 @@ class RustMapHelper { } template - static typename Map::KeyNode* InsertOrReplaceNode(Map* m, NodeBase* node) { + static bool InsertOrReplaceNode(Map* m, NodeBase* node) { return m->InsertOrReplaceNode(static_cast(node)); } - template - static void EraseNoDestroy(Map* m, map_index_t bucket, NodeBase* node) { - m->erase_no_destroy(bucket, static_cast(node)); + template + static bool EraseImpl(Map* m, const Key& key) { + return m->EraseImpl(key); } static google::protobuf::MessageLite* PlacementNew(const MessageLite* prototype, @@ -1296,21 +1309,13 @@ class Map : private internal::KeyMapBase> { // Erase and clear template size_type erase(const key_arg& key) { - iterator it = find(key); - if (it == end()) { - return 0; - } else { - erase(it); - return 1; - } + return this->EraseImpl(TS::ToView(key)); } iterator erase(iterator pos) ABSL_ATTRIBUTE_LIFETIME_BOUND { auto next = std::next(pos); ABSL_DCHECK_EQ(pos.m_, static_cast(this)); - auto* node = static_cast(pos.node_); - this->erase_no_destroy(pos.bucket_index_, node); - DeleteNode(node); + this->EraseImpl(pos.bucket_index_, static_cast(pos.node_), true); return next; } @@ -1429,6 +1434,15 @@ class Map : private internal::KeyMapBase> { true); } + // For DynamicMapField, which needs a special destructor. + void EraseDynamic(iterator it) { + this->EraseImpl(it.bucket_index_, + static_cast(it.node_), false); + if (this->arena() == nullptr) { + delete static_cast(it.node_); + } + } + using Base::arena; friend class Arena; @@ -1438,6 +1452,7 @@ class Map : private internal::KeyMapBase> { using DestructorSkippable_ = void; template friend class internal::MapFieldLite; + friend class internal::DynamicMapField; friend class internal::TcParser; friend struct internal::MapTestPeer; friend struct internal::MapBenchmarkPeer; diff --git a/src/google/protobuf/map_test.inc b/src/google/protobuf/map_test.inc index 3105cdfbd34e..cbc21d17b285 100644 --- a/src/google/protobuf/map_test.inc +++ b/src/google/protobuf/map_test.inc @@ -183,13 +183,7 @@ struct MapTestPeer { using Node = typename T::Node; auto* node = static_cast(map.AllocNode(sizeof(Node))); ::new (static_cast(&node->kv)) typename T::value_type{key, value}; - node = static_cast(GetKeyMapBase(map).InsertOrReplaceNode(node)); - if (node) { - node->~Node(); - GetKeyMapBase(map).DeallocNode(node, sizeof(Node)); - return false; - } - return true; + return GetKeyMapBase(map).InsertOrReplaceNode(node); } template