Skip to content

Commit d6b90bf

Browse files
Add TypeInfo type UntypedMapBase for generic operation support.
The data is not currently used. This change is what we need to hook it up from Rust. In a future change we will start using the data to simplify the interface. No expected performance change. The data inserted via static typing is constant evaluated, and fits in the existing padding (in 64-bit builds). Also, fix Rust bindings to take float/double into account now that the enum lists them. PiperOrigin-RevId: 698780360
1 parent 61662a5 commit d6b90bf

File tree

6 files changed

+238
-29
lines changed

6 files changed

+238
-29
lines changed

rust/cpp.rs

+33-15
Original file line numberDiff line numberDiff line change
@@ -790,23 +790,28 @@ impl UntypedMapIterator {
790790
#[doc(hidden)]
791791
#[repr(u8)]
792792
#[derive(Debug, PartialEq)]
793+
// Copy of UntypedMapBase::TypeKind
793794
pub enum MapValueTag {
794795
Bool,
795796
U32,
796797
U64,
798+
F32,
799+
F64,
797800
String,
798801
Message,
802+
Unknown,
799803
}
800-
801-
// For the purposes of FFI, we treat all numeric types of a given size the same
802-
// way. For example, u32, i32, and f32 values are all represented as a u32.
803-
// Likewise, u64, i64, and f64 values are all stored in a u64.
804+
// For the purposes of FFI, we treat all integral types of a given size the same
805+
// way. For example, u32 and i32 values are all represented as a u32.
806+
// Likewise, u64 and i64 values are all stored in a u64.
804807
#[doc(hidden)]
805808
#[repr(C)]
806809
pub union MapValueUnion {
807810
pub b: bool,
808811
pub u: u32,
809812
pub uu: u64,
813+
pub f: f32,
814+
pub ff: f64,
810815
// Generally speaking, if s is set then it should not be None. However, we
811816
// do set it to None in the special case where the MapValue is just a
812817
// "prototype" (see below). In that scenario, we just want to indicate the
@@ -838,6 +843,14 @@ impl MapValue {
838843
MapValue { tag: MapValueTag::U64, val: MapValueUnion { uu } }
839844
}
840845

846+
pub fn make_f32(f: f32) -> Self {
847+
MapValue { tag: MapValueTag::F32, val: MapValueUnion { f } }
848+
}
849+
850+
fn make_f64(ff: f64) -> Self {
851+
MapValue { tag: MapValueTag::F64, val: MapValueUnion { ff } }
852+
}
853+
841854
fn make_string(s: CppStdString) -> Self {
842855
MapValue { tag: MapValueTag::String, val: MapValueUnion { s: Some(s) } }
843856
}
@@ -921,27 +934,27 @@ impl CppMapTypeConversions for i64 {
921934

922935
impl CppMapTypeConversions for f32 {
923936
fn get_prototype() -> MapValue {
924-
MapValue::make_u32(0)
937+
MapValue::make_f32(0f32)
925938
}
926939
fn to_map_value(self) -> MapValue {
927-
MapValue::make_u32(self.to_bits())
940+
MapValue::make_f32(self)
928941
}
929942
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
930-
debug_assert_eq!(value.tag, MapValueTag::U32);
931-
unsafe { Self::from_bits(value.val.u) }
943+
debug_assert_eq!(value.tag, MapValueTag::F32);
944+
unsafe { value.val.f }
932945
}
933946
}
934947

935948
impl CppMapTypeConversions for f64 {
936949
fn get_prototype() -> MapValue {
937-
MapValue::make_u64(0)
950+
MapValue::make_f64(0.0)
938951
}
939952
fn to_map_value(self) -> MapValue {
940-
MapValue::make_u64(self.to_bits())
953+
MapValue::make_f64(self)
941954
}
942955
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
943-
debug_assert_eq!(value.tag, MapValueTag::U64);
944-
unsafe { Self::from_bits(value.val.uu) }
956+
debug_assert_eq!(value.tag, MapValueTag::F64);
957+
unsafe { value.val.ff }
945958
}
946959
}
947960

@@ -1099,11 +1112,16 @@ generate_map_key_impl!(
10991112

11001113
impl<Key, Value> ProxiedInMapValue<Key> for Value
11011114
where
1102-
Key: Proxied + MapKey,
1115+
Key: Proxied + MapKey + CppMapTypeConversions,
11031116
Value: Proxied + CppMapTypeConversions,
11041117
{
11051118
fn map_new(_private: Private) -> Map<Key, Self> {
1106-
unsafe { Map::from_inner(Private, InnerMap::new(proto2_rust_map_new())) }
1119+
unsafe {
1120+
Map::from_inner(
1121+
Private,
1122+
InnerMap::new(proto2_rust_map_new(Key::get_prototype(), Value::get_prototype())),
1123+
)
1124+
}
11071125
}
11081126

11091127
unsafe fn map_free(_private: Private, map: &mut Map<Key, Self>) {
@@ -1241,7 +1259,7 @@ impl_map_primitives!(
12411259
extern "C" {
12421260
fn proto2_rust_thunk_UntypedMapIterator_increment(iter: &mut UntypedMapIterator);
12431261

1244-
pub fn proto2_rust_map_new() -> RawMap;
1262+
pub fn proto2_rust_map_new(key_prototype: MapValue, value_prototype: MapValue) -> RawMap;
12451263
pub fn proto2_rust_map_size(m: RawMap) -> usize;
12461264
pub fn proto2_rust_map_iter(m: RawMap) -> UntypedMapIterator;
12471265
}

rust/cpp_kernel/map.cc

+34-10
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,17 @@ namespace protobuf {
1717
namespace rust {
1818
namespace {
1919

20-
// LINT.IfChange(map_ffi)
21-
enum class MapValueTag : uint8_t {
22-
kBool,
23-
kU32,
24-
kU64,
25-
kString,
26-
kMessage,
27-
};
20+
using MapValueTag = internal::UntypedMapBase::TypeKind;
2821

22+
// LINT.IfChange(map_ffi)
2923
struct MapValue {
3024
MapValueTag tag;
3125
union {
3226
bool b;
3327
uint32_t u32;
3428
uint64_t u64;
29+
float f32;
30+
double f64;
3531
std::string* s;
3632
google::protobuf::MessageLite* message;
3733
};
@@ -66,6 +62,14 @@ void GetSizeAndAlignment(MapValue value, uint16_t* size, uint8_t* alignment) {
6662
*size = sizeof(uint64_t);
6763
*alignment = alignof(uint64_t);
6864
break;
65+
case MapValueTag::kFloat:
66+
*size = sizeof(float);
67+
*alignment = alignof(float);
68+
break;
69+
case MapValueTag::kDouble:
70+
*size = sizeof(double);
71+
*alignment = alignof(double);
72+
break;
6973
case MapValueTag::kString:
7074
*size = sizeof(std::string);
7175
*alignment = alignof(std::string);
@@ -149,6 +153,12 @@ bool Insert(internal::UntypedMapBase* m, Key key, MapValue value) {
149153
case MapValueTag::kU64:
150154
*static_cast<uint64_t*>(value_ptr) = value.u64;
151155
break;
156+
case MapValueTag::kFloat:
157+
*static_cast<float*>(value_ptr) = value.f32;
158+
break;
159+
case MapValueTag::kDouble:
160+
*static_cast<double*>(value_ptr) = value.f64;
161+
break;
152162
case MapValueTag::kString:
153163
new (value_ptr) std::string(std::move(*value.s));
154164
delete value.s;
@@ -196,6 +206,12 @@ void PopulateMapValue(MapValueTag tag, void* data, MapValue& output) {
196206
case MapValueTag::kU64:
197207
output.u64 = *static_cast<const uint64_t*>(data);
198208
break;
209+
case MapValueTag::kFloat:
210+
output.f32 = *static_cast<const float*>(data);
211+
break;
212+
case MapValueTag::kDouble:
213+
output.f64 = *static_cast<const double*>(data);
214+
break;
199215
case MapValueTag::kString:
200216
output.s = static_cast<std::string*>(data);
201217
break;
@@ -294,8 +310,16 @@ void proto2_rust_thunk_UntypedMapIterator_increment(
294310
iter->PlusPlus();
295311
}
296312

297-
google::protobuf::internal::UntypedMapBase* proto2_rust_map_new() {
298-
return new google::protobuf::internal::UntypedMapBase(/* arena = */ nullptr);
313+
google::protobuf::internal::UntypedMapBase* proto2_rust_map_new(
314+
google::protobuf::rust::MapValue key_prototype,
315+
google::protobuf::rust::MapValue value_prototype) {
316+
return new google::protobuf::internal::UntypedMapBase(
317+
/* arena = */ nullptr,
318+
google::protobuf::internal::UntypedMapBase::GetTypeInfoDynamic(
319+
key_prototype.tag, value_prototype.tag,
320+
value_prototype.tag == google::protobuf::rust::MapValueTag::kMessage
321+
? value_prototype.message
322+
: nullptr));
299323
}
300324

301325
size_t proto2_rust_map_size(google::protobuf::internal::UntypedMapBase* m) {

src/google/protobuf/map.cc

+61
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,67 @@ size_t UntypedMapBase::SpaceUsedInTable(size_t sizeof_node) const {
105105
return size;
106106
}
107107

108+
static size_t AlignTo(size_t v, size_t alignment, size_t& max_align) {
109+
max_align = std::max<size_t>(max_align, alignment);
110+
return (v + alignment - 1) / alignment * alignment;
111+
}
112+
113+
struct Offsets {
114+
size_t start;
115+
size_t end;
116+
};
117+
118+
template <typename T>
119+
static Offsets AlignAndAddSize(size_t v, size_t& max_align) {
120+
v = AlignTo(v, alignof(T), max_align);
121+
return {v, v + sizeof(T)};
122+
}
123+
124+
static Offsets AlignAndAddSizeDynamic(
125+
size_t v, UntypedMapBase::TypeKind kind,
126+
const MessageLite* value_prototype_if_message, size_t& max_align) {
127+
switch (kind) {
128+
case UntypedMapBase::TypeKind::kBool:
129+
return AlignAndAddSize<bool>(v, max_align);
130+
case UntypedMapBase::TypeKind::kU32:
131+
return AlignAndAddSize<int32_t>(v, max_align);
132+
case UntypedMapBase::TypeKind::kU64:
133+
return AlignAndAddSize<int64_t>(v, max_align);
134+
case UntypedMapBase::TypeKind::kFloat:
135+
return AlignAndAddSize<float>(v, max_align);
136+
case UntypedMapBase::TypeKind::kDouble:
137+
return AlignAndAddSize<double>(v, max_align);
138+
case UntypedMapBase::TypeKind::kString:
139+
return AlignAndAddSize<std::string>(v, max_align);
140+
case UntypedMapBase::TypeKind::kMessage: {
141+
auto* class_data = GetClassData(*value_prototype_if_message);
142+
v = AlignTo(v, class_data->alignment(), max_align);
143+
return {v, v + class_data->allocation_size()};
144+
}
145+
default:
146+
Unreachable();
147+
}
148+
}
149+
150+
template <typename T, typename U>
151+
T Narrow(U value) {
152+
ABSL_CHECK_EQ(value, static_cast<T>(value));
153+
return static_cast<T>(value);
154+
}
155+
156+
UntypedMapBase::TypeInfo UntypedMapBase::GetTypeInfoDynamic(
157+
TypeKind key_type, TypeKind value_type,
158+
const MessageLite* value_prototype_if_message) {
159+
size_t max_align = alignof(NodeBase);
160+
const auto key_offsets =
161+
AlignAndAddSizeDynamic(sizeof(NodeBase), key_type, nullptr, max_align);
162+
const auto value_offsets = AlignAndAddSizeDynamic(
163+
key_offsets.end, value_type, value_prototype_if_message, max_align);
164+
return TypeInfo{
165+
Narrow<uint16_t>(AlignTo(value_offsets.end, max_align, max_align)),
166+
Narrow<uint8_t>(value_offsets.start), key_type, value_type};
167+
}
168+
108169
} // namespace internal
109170
} // namespace protobuf
110171
} // namespace google

src/google/protobuf/map.h

+69-4
Original file line numberDiff line numberDiff line change
@@ -393,10 +393,62 @@ class PROTOBUF_EXPORT UntypedMapBase {
393393
public:
394394
using size_type = size_t;
395395

396-
explicit constexpr UntypedMapBase(Arena* arena)
396+
// Possible types that a key/value can take.
397+
// LINT.IfChange(map_ffi)
398+
enum class TypeKind : uint8_t {
399+
kBool, // bool
400+
kU32, // int32_t, uint32_t, enums
401+
kU64, // int64_t, uint64_t
402+
kFloat, // float
403+
kDouble, // double
404+
kString, // std::string
405+
kMessage, // Derived from MessageLite
406+
kUnknown, // For DynamicMapField for now
407+
};
408+
// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_ffi)
409+
410+
template <typename T>
411+
static constexpr TypeKind StaticTypeKind() {
412+
if constexpr (std::is_same_v<T, bool>) {
413+
return TypeKind::kBool;
414+
} else if constexpr (std::is_same_v<T, int32_t> ||
415+
std::is_same_v<T, uint32_t>) {
416+
return TypeKind::kU32;
417+
} else if constexpr (std::is_same_v<T, int64_t> ||
418+
std::is_same_v<T, uint64_t>) {
419+
return TypeKind::kU64;
420+
} else if constexpr (std::is_same_v<T, float>) {
421+
return TypeKind::kFloat;
422+
} else if constexpr (std::is_same_v<T, double>) {
423+
return TypeKind::kDouble;
424+
} else if constexpr (std::is_same_v<T, std::string>) {
425+
return TypeKind::kString;
426+
} else if constexpr (std::is_base_of_v<MessageLite, T>) {
427+
return TypeKind::kMessage;
428+
} else {
429+
return TypeKind::kUnknown;
430+
}
431+
}
432+
433+
struct TypeInfo {
434+
// Equivalent to `sizeof(Node)` in the derived type.
435+
uint16_t node_size;
436+
// Equivalent to `offsetof(Node, kv.second)` in the derived type.
437+
uint8_t value_offset;
438+
TypeKind key_type : 4;
439+
TypeKind value_type : 4;
440+
};
441+
static_assert(sizeof(TypeInfo) == 4);
442+
443+
static TypeInfo GetTypeInfoDynamic(
444+
TypeKind key_type, TypeKind value_type,
445+
const MessageLite* value_prototype_if_message);
446+
447+
explicit constexpr UntypedMapBase(Arena* arena, TypeInfo type_info)
397448
: num_elements_(0),
398449
num_buckets_(internal::kGlobalEmptyTableSize),
399450
index_of_first_non_null_(internal::kGlobalEmptyTableSize),
451+
type_info_(type_info),
400452
table_(const_cast<NodeBase**>(internal::kGlobalEmptyTable)),
401453
alloc_(arena) {}
402454

@@ -414,6 +466,7 @@ class PROTOBUF_EXPORT UntypedMapBase {
414466
std::swap(num_elements_, other->num_elements_);
415467
std::swap(num_buckets_, other->num_buckets_);
416468
std::swap(index_of_first_non_null_, other->index_of_first_non_null_);
469+
std::swap(type_info_, other->type_info_);
417470
std::swap(table_, other->table_);
418471
std::swap(alloc_, other->alloc_);
419472
}
@@ -557,6 +610,7 @@ class PROTOBUF_EXPORT UntypedMapBase {
557610
map_index_t num_elements_;
558611
map_index_t num_buckets_;
559612
map_index_t index_of_first_non_null_;
613+
TypeInfo type_info_;
560614
NodeBase** table_; // an array with num_buckets_ entries
561615
Allocator alloc_;
562616
};
@@ -952,12 +1006,14 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
9521006
using size_type = size_t;
9531007
using hasher = absl::Hash<typename TS::ViewType>;
9541008

955-
constexpr Map() : Base(nullptr) { StaticValidityCheck(); }
1009+
constexpr Map() : Base(nullptr, GetTypeInfo()) { StaticValidityCheck(); }
9561010
Map(const Map& other) : Map(nullptr, other) {}
9571011

9581012
// Internal Arena constructors: do not use!
9591013
// TODO: remove non internal ctors
960-
explicit Map(Arena* arena) : Base(arena) { StaticValidityCheck(); }
1014+
explicit Map(Arena* arena) : Base(arena, GetTypeInfo()) {
1015+
StaticValidityCheck();
1016+
}
9611017
Map(internal::InternalVisibility, Arena* arena) : Map(arena) {}
9621018
Map(internal::InternalVisibility, Arena* arena, const Map& other)
9631019
: Map(arena, other) {}
@@ -997,7 +1053,7 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
9971053
}
9981054

9991055
private:
1000-
Map(Arena* arena, const Map& other) : Base(arena) {
1056+
Map(Arena* arena, const Map& other) : Map(arena) {
10011057
StaticValidityCheck();
10021058
insert(other.begin(), other.end());
10031059
}
@@ -1386,6 +1442,15 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
13861442
value_type kv;
13871443
};
13881444

1445+
static constexpr auto GetTypeInfo() {
1446+
return internal::UntypedMapBase::TypeInfo{
1447+
sizeof(Node),
1448+
PROTOBUF_FIELD_OFFSET(Node, kv.second),
1449+
internal::UntypedMapBase::StaticTypeKind<Key>(),
1450+
internal::UntypedMapBase::StaticTypeKind<T>(),
1451+
};
1452+
}
1453+
13891454
void DestroyNode(Node* node) {
13901455
if (this->alloc_.arena() == nullptr) {
13911456
node->kv.first.~key_type();

0 commit comments

Comments
 (0)