Skip to content

Commit

Permalink
Add TypeInfo type to UntypedMapBase for generic operation support.
Browse files Browse the repository at this point in the history
The data is not currently used. This change is what we need to hook it up from Rust. In a future change we will start using the data to simplify the interface.

No performance change is expected. The data inserted via static typing is constant-evaluated and fits in the existing padding (in 64-bit builds).

Also, fix Rust bindings to take float/double into account now that the enum lists them.

PiperOrigin-RevId: 698780360
  • Loading branch information
protobuf-github-bot authored and copybara-github committed Nov 21, 2024
1 parent 61662a5 commit d6b90bf
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 29 deletions.
48 changes: 33 additions & 15 deletions rust/cpp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -790,23 +790,28 @@ impl UntypedMapIterator {
#[doc(hidden)]
#[repr(u8)]
#[derive(Debug, PartialEq)]
// Copy of UntypedMapBase::TypeKind
//
// Discriminant stored in `MapValue::tag`; it says which field of the
// accompanying `MapValueUnion` is active. The C++ kernel aliases its own
// `MapValueTag` to `UntypedMapBase::TypeKind`, so the variants below must stay
// in the same order as that enum: with `#[repr(u8)]` the position of a variant
// is the value that crosses the FFI boundary.
pub enum MapValueTag {
// `bool` payload.
Bool,
// 32-bit integral payload (u32 and i32 share this tag).
U32,
// 64-bit integral payload (u64 and i64 share this tag).
U64,
// `f32` payload.
F32,
// `f64` payload.
F64,
// String payload.
String,
// Message payload.
Message,
// Mirrors `TypeKind::kUnknown` on the C++ side.
Unknown,
}

// For the purposes of FFI, we treat all numeric types of a given size the same
// way. For example, u32, i32, and f32 values are all represented as a u32.
// Likewise, u64, i64, and f64 values are all stored in a u64.
// For the purposes of FFI, we treat all integral types of a given size the same
// way. For example, u32 and i32 values are all represented as a u32.
// Likewise, u64 and i64 values are all stored in a u64.
#[doc(hidden)]
#[repr(C)]
pub union MapValueUnion {
pub b: bool,
pub u: u32,
pub uu: u64,
pub f: f32,
pub ff: f64,
// Generally speaking, if s is set then it should not be None. However, we
// do set it to None in the special case where the MapValue is just a
// "prototype" (see below). In that scenario, we just want to indicate the
Expand Down Expand Up @@ -838,6 +843,14 @@ impl MapValue {
MapValue { tag: MapValueTag::U64, val: MapValueUnion { uu } }
}

/// Builds a `MapValue` tagged `F32` whose union holds `f`.
pub fn make_f32(f: f32) -> Self {
    let val = MapValueUnion { f };
    MapValue { tag: MapValueTag::F32, val }
}

/// Builds a `MapValue` tagged `F64` whose union holds `ff`.
fn make_f64(ff: f64) -> Self {
    let val = MapValueUnion { ff };
    MapValue { tag: MapValueTag::F64, val }
}

/// Builds a `MapValue` tagged `String` that carries `s`.
///
/// The union's string slot is optional; a real value is always stored as
/// `Some`.
fn make_string(s: CppStdString) -> Self {
    let val = MapValueUnion { s: Some(s) };
    MapValue { tag: MapValueTag::String, val }
}
Expand Down Expand Up @@ -921,27 +934,27 @@ impl CppMapTypeConversions for i64 {

impl CppMapTypeConversions for f32 {
fn get_prototype() -> MapValue {
MapValue::make_u32(0)
MapValue::make_f32(0f32)
}
fn to_map_value(self) -> MapValue {
MapValue::make_u32(self.to_bits())
MapValue::make_f32(self)
}
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
debug_assert_eq!(value.tag, MapValueTag::U32);
unsafe { Self::from_bits(value.val.u) }
debug_assert_eq!(value.tag, MapValueTag::F32);
unsafe { value.val.f }
}
}

impl CppMapTypeConversions for f64 {
fn get_prototype() -> MapValue {
MapValue::make_u64(0)
MapValue::make_f64(0.0)
}
fn to_map_value(self) -> MapValue {
MapValue::make_u64(self.to_bits())
MapValue::make_f64(self)
}
unsafe fn from_map_value<'a>(value: MapValue) -> View<'a, Self> {
debug_assert_eq!(value.tag, MapValueTag::U64);
unsafe { Self::from_bits(value.val.uu) }
debug_assert_eq!(value.tag, MapValueTag::F64);
unsafe { value.val.ff }
}
}

Expand Down Expand Up @@ -1099,11 +1112,16 @@ generate_map_key_impl!(

impl<Key, Value> ProxiedInMapValue<Key> for Value
where
Key: Proxied + MapKey,
Key: Proxied + MapKey + CppMapTypeConversions,
Value: Proxied + CppMapTypeConversions,
{
fn map_new(_private: Private) -> Map<Key, Self> {
unsafe { Map::from_inner(Private, InnerMap::new(proto2_rust_map_new())) }
unsafe {
Map::from_inner(
Private,
InnerMap::new(proto2_rust_map_new(Key::get_prototype(), Value::get_prototype())),
)
}
}

unsafe fn map_free(_private: Private, map: &mut Map<Key, Self>) {
Expand Down Expand Up @@ -1241,7 +1259,7 @@ impl_map_primitives!(
extern "C" {
fn proto2_rust_thunk_UntypedMapIterator_increment(iter: &mut UntypedMapIterator);

pub fn proto2_rust_map_new() -> RawMap;
pub fn proto2_rust_map_new(key_prototype: MapValue, value_prototype: MapValue) -> RawMap;
pub fn proto2_rust_map_size(m: RawMap) -> usize;
pub fn proto2_rust_map_iter(m: RawMap) -> UntypedMapIterator;
}
Expand Down
44 changes: 34 additions & 10 deletions rust/cpp_kernel/map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,17 @@ namespace protobuf {
namespace rust {
namespace {

// LINT.IfChange(map_ffi)
enum class MapValueTag : uint8_t {
kBool,
kU32,
kU64,
kString,
kMessage,
};
using MapValueTag = internal::UntypedMapBase::TypeKind;

// LINT.IfChange(map_ffi)
struct MapValue {
MapValueTag tag;
union {
bool b;
uint32_t u32;
uint64_t u64;
float f32;
double f64;
std::string* s;
google::protobuf::MessageLite* message;
};
Expand Down Expand Up @@ -66,6 +62,14 @@ void GetSizeAndAlignment(MapValue value, uint16_t* size, uint8_t* alignment) {
*size = sizeof(uint64_t);
*alignment = alignof(uint64_t);
break;
case MapValueTag::kFloat:
*size = sizeof(float);
*alignment = alignof(float);
break;
case MapValueTag::kDouble:
*size = sizeof(double);
*alignment = alignof(double);
break;
case MapValueTag::kString:
*size = sizeof(std::string);
*alignment = alignof(std::string);
Expand Down Expand Up @@ -149,6 +153,12 @@ bool Insert(internal::UntypedMapBase* m, Key key, MapValue value) {
case MapValueTag::kU64:
*static_cast<uint64_t*>(value_ptr) = value.u64;
break;
case MapValueTag::kFloat:
*static_cast<float*>(value_ptr) = value.f32;
break;
case MapValueTag::kDouble:
*static_cast<double*>(value_ptr) = value.f64;
break;
case MapValueTag::kString:
new (value_ptr) std::string(std::move(*value.s));
delete value.s;
Expand Down Expand Up @@ -196,6 +206,12 @@ void PopulateMapValue(MapValueTag tag, void* data, MapValue& output) {
case MapValueTag::kU64:
output.u64 = *static_cast<const uint64_t*>(data);
break;
case MapValueTag::kFloat:
output.f32 = *static_cast<const float*>(data);
break;
case MapValueTag::kDouble:
output.f64 = *static_cast<const double*>(data);
break;
case MapValueTag::kString:
output.s = static_cast<std::string*>(data);
break;
Expand Down Expand Up @@ -294,8 +310,16 @@ void proto2_rust_thunk_UntypedMapIterator_increment(
iter->PlusPlus();
}

google::protobuf::internal::UntypedMapBase* proto2_rust_map_new() {
return new google::protobuf::internal::UntypedMapBase(/* arena = */ nullptr);
// Allocates a new, non-arena UntypedMapBase for the Rust bindings.
//
// Only the tags of the two prototypes are used to describe the key and value
// kinds; the resulting TypeInfo is computed at runtime through
// GetTypeInfoDynamic. The value prototype's message pointer is forwarded only
// when the value is tagged kMessage; for every other kind nullptr is passed.
//
// The map is created with `new`, so the caller owns the returned pointer.
// NOTE(review): presumably released through the Rust-side map_free path --
// confirm against the matching free thunk.
google::protobuf::internal::UntypedMapBase* proto2_rust_map_new(
google::protobuf::rust::MapValue key_prototype,
google::protobuf::rust::MapValue value_prototype) {
return new google::protobuf::internal::UntypedMapBase(
/* arena = */ nullptr,
google::protobuf::internal::UntypedMapBase::GetTypeInfoDynamic(
key_prototype.tag, value_prototype.tag,
// The message prototype is only meaningful for message-typed values.
value_prototype.tag == google::protobuf::rust::MapValueTag::kMessage
? value_prototype.message
: nullptr));
}

size_t proto2_rust_map_size(google::protobuf::internal::UntypedMapBase* m) {
Expand Down
61 changes: 61 additions & 0 deletions src/google/protobuf/map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,67 @@ size_t UntypedMapBase::SpaceUsedInTable(size_t sizeof_node) const {
return size;
}

// Rounds `v` up to the next multiple of `alignment` and returns it.
// `max_align` is raised to `alignment` when `alignment` is larger, so the
// caller ends up with the strictest alignment seen across a series of calls.
static size_t AlignTo(size_t v, size_t alignment, size_t& max_align) {
  if (max_align < alignment) {
    max_align = alignment;
  }
  const size_t whole_units = (v + alignment - 1) / alignment;
  return whole_units * alignment;
}

// Byte range occupied by one field inside a map node, as produced by the
// AlignAndAddSize* helpers.
struct Offsets {
// Offset of the field's first byte, already rounded up to its alignment.
size_t start;
// Offset one past the field's last byte (`start` plus the field's size).
size_t end;
};

// Places a `T` at the first suitably aligned offset at or after `v` and
// returns the byte range it occupies. `max_align` is updated with
// `alignof(T)` via AlignTo.
template <typename T>
static Offsets AlignAndAddSize(size_t v, size_t& max_align) {
  const size_t begin = AlignTo(v, alignof(T), max_align);
  const size_t finish = begin + sizeof(T);
  return Offsets{begin, finish};
}

// Runtime counterpart of AlignAndAddSize<T>: places a field of the given
// `kind` at the first suitably aligned offset at or after `v` and returns the
// byte range it occupies, updating `max_align` along the way.
//
// `value_prototype_if_message` is only read for kMessage, where the size and
// alignment come from the prototype's class data.
// NOTE(review): the kMessage branch dereferences the prototype without a null
// check, so callers must pass a non-null prototype whenever `kind` is
// kMessage -- confirm that keys can never be message-typed.
//
// Any kind without an explicit case (kUnknown in the current enum) reaches
// Unreachable().
static Offsets AlignAndAddSizeDynamic(
size_t v, UntypedMapBase::TypeKind kind,
const MessageLite* value_prototype_if_message, size_t& max_align) {
switch (kind) {
case UntypedMapBase::TypeKind::kBool:
return AlignAndAddSize<bool>(v, max_align);
// kU32 / kU64 are laid out using the signed 32/64-bit integer types.
case UntypedMapBase::TypeKind::kU32:
return AlignAndAddSize<int32_t>(v, max_align);
case UntypedMapBase::TypeKind::kU64:
return AlignAndAddSize<int64_t>(v, max_align);
case UntypedMapBase::TypeKind::kFloat:
return AlignAndAddSize<float>(v, max_align);
case UntypedMapBase::TypeKind::kDouble:
return AlignAndAddSize<double>(v, max_align);
case UntypedMapBase::TypeKind::kString:
return AlignAndAddSize<std::string>(v, max_align);
case UntypedMapBase::TypeKind::kMessage: {
// Message layout is not known statically; ask the prototype's class data.
auto* class_data = GetClassData(*value_prototype_if_message);
v = AlignTo(v, class_data->alignment(), max_align);
return {v, v + class_data->allocation_size()};
}
default:
Unreachable();
}
}

// Converts `value` to the narrower type `T`, CHECK-failing if the conversion
// does not round-trip to an equal value.
template <typename T, typename U>
T Narrow(U value) {
  const T narrowed = static_cast<T>(value);
  ABSL_CHECK_EQ(value, narrowed);
  return narrowed;
}

// Computes the node layout for a map whose key and value kinds are only known
// at runtime. The key is placed after the NodeBase header and the value after
// the key, each at its own alignment; the total node size is then rounded up
// to the strictest alignment encountered. The results are CHECK-narrowed into
// the compact TypeInfo fields.
UntypedMapBase::TypeInfo UntypedMapBase::GetTypeInfoDynamic(
    TypeKind key_type, TypeKind value_type,
    const MessageLite* value_prototype_if_message) {
  size_t strictest_align = alignof(NodeBase);

  // No message prototype is supplied for the key.
  const Offsets key_range = AlignAndAddSizeDynamic(
      sizeof(NodeBase), key_type, nullptr, strictest_align);
  const Offsets value_range = AlignAndAddSizeDynamic(
      key_range.end, value_type, value_prototype_if_message, strictest_align);

  const size_t padded_node_size =
      AlignTo(value_range.end, strictest_align, strictest_align);

  return TypeInfo{Narrow<uint16_t>(padded_node_size),
                  Narrow<uint8_t>(value_range.start), key_type, value_type};
}

} // namespace internal
} // namespace protobuf
} // namespace google
Expand Down
73 changes: 69 additions & 4 deletions src/google/protobuf/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,62 @@ class PROTOBUF_EXPORT UntypedMapBase {
public:
using size_type = size_t;

explicit constexpr UntypedMapBase(Arena* arena)
// Possible types that a key/value can take.
//
// The underlying values are shared with the Rust bindings (see the LINT
// markers below), so enumerators must not be reordered without updating the
// Rust copy. TypeInfo stores each kind in a 4-bit bit-field, so the number of
// enumerators must stay at or below 16.
// LINT.IfChange(map_ffi)
enum class TypeKind : uint8_t {
kBool, // bool
kU32, // int32_t, uint32_t, enums
kU64, // int64_t, uint64_t
kFloat, // float
kDouble, // double
kString, // std::string
kMessage, // Derived from MessageLite
kUnknown, // For DynamicMapField for now
};
// LINT.ThenChange(//depot/google3/third_party/protobuf/rust/cpp.rs:map_ffi)

// Maps a statically known key/value type `T` to its TypeKind.
//
// Signed and unsigned integers of the same width share a kind. Enum types
// whose storage is 32 bits wide are classified as kU32, matching the
// documented meaning of that enumerator; enums of any other width, and any
// type not listed here, yield kUnknown.
template <typename T>
static constexpr TypeKind StaticTypeKind() {
  if constexpr (std::is_same_v<T, bool>) {
    return TypeKind::kBool;
  } else if constexpr (std::is_same_v<T, int32_t> ||
                       std::is_same_v<T, uint32_t>) {
    return TypeKind::kU32;
  } else if constexpr (std::is_enum_v<T>) {
    // Only enums that occupy exactly 32 bits can be treated as kU32; checked
    // in a nested branch so sizeof(T) is only evaluated for enum types.
    if constexpr (sizeof(T) == sizeof(uint32_t)) {
      return TypeKind::kU32;
    } else {
      return TypeKind::kUnknown;
    }
  } else if constexpr (std::is_same_v<T, int64_t> ||
                       std::is_same_v<T, uint64_t>) {
    return TypeKind::kU64;
  } else if constexpr (std::is_same_v<T, float>) {
    return TypeKind::kFloat;
  } else if constexpr (std::is_same_v<T, double>) {
    return TypeKind::kDouble;
  } else if constexpr (std::is_same_v<T, std::string>) {
    return TypeKind::kString;
  } else if constexpr (std::is_base_of_v<MessageLite, T>) {
    return TypeKind::kMessage;
  } else {
    return TypeKind::kUnknown;
  }
}

// Compact description of a map's node layout and key/value kinds.
struct TypeInfo {
// Equivalent to `sizeof(Node)` in the derived type.
uint16_t node_size;
// Equivalent to `offsetof(Node, kv.second)` in the derived type.
uint8_t value_offset;
// The two kinds are packed into 4-bit bit-fields; TypeKind currently has 8
// enumerators, which fits in 4 bits.
TypeKind key_type : 4;
TypeKind value_type : 4;
};
// Pins the packed size so a layout change is caught at compile time.
static_assert(sizeof(TypeInfo) == 4);

// Builds a TypeInfo for a map whose key and value kinds are only known at
// runtime. `value_prototype_if_message` supplies the size and alignment of the
// value when `value_type` is kMessage; the Rust FFI caller passes nullptr for
// every other value kind.
static TypeInfo GetTypeInfoDynamic(
TypeKind key_type, TypeKind value_type,
const MessageLite* value_prototype_if_message);

explicit constexpr UntypedMapBase(Arena* arena, TypeInfo type_info)
: num_elements_(0),
num_buckets_(internal::kGlobalEmptyTableSize),
index_of_first_non_null_(internal::kGlobalEmptyTableSize),
type_info_(type_info),
table_(const_cast<NodeBase**>(internal::kGlobalEmptyTable)),
alloc_(arena) {}

Expand All @@ -414,6 +466,7 @@ class PROTOBUF_EXPORT UntypedMapBase {
std::swap(num_elements_, other->num_elements_);
std::swap(num_buckets_, other->num_buckets_);
std::swap(index_of_first_non_null_, other->index_of_first_non_null_);
std::swap(type_info_, other->type_info_);
std::swap(table_, other->table_);
std::swap(alloc_, other->alloc_);
}
Expand Down Expand Up @@ -557,6 +610,7 @@ class PROTOBUF_EXPORT UntypedMapBase {
map_index_t num_elements_;
map_index_t num_buckets_;
map_index_t index_of_first_non_null_;
TypeInfo type_info_;
NodeBase** table_; // an array with num_buckets_ entries
Allocator alloc_;
};
Expand Down Expand Up @@ -952,12 +1006,14 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
using size_type = size_t;
using hasher = absl::Hash<typename TS::ViewType>;

constexpr Map() : Base(nullptr) { StaticValidityCheck(); }
constexpr Map() : Base(nullptr, GetTypeInfo()) { StaticValidityCheck(); }
Map(const Map& other) : Map(nullptr, other) {}

// Internal Arena constructors: do not use!
// TODO: remove non internal ctors
explicit Map(Arena* arena) : Base(arena) { StaticValidityCheck(); }
explicit Map(Arena* arena) : Base(arena, GetTypeInfo()) {
StaticValidityCheck();
}
Map(internal::InternalVisibility, Arena* arena) : Map(arena) {}
Map(internal::InternalVisibility, Arena* arena, const Map& other)
: Map(arena, other) {}
Expand Down Expand Up @@ -997,7 +1053,7 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
}

private:
Map(Arena* arena, const Map& other) : Base(arena) {
Map(Arena* arena, const Map& other) : Map(arena) {
StaticValidityCheck();
insert(other.begin(), other.end());
}
Expand Down Expand Up @@ -1386,6 +1442,15 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
value_type kv;
};

static constexpr auto GetTypeInfo() {
return internal::UntypedMapBase::TypeInfo{
sizeof(Node),
PROTOBUF_FIELD_OFFSET(Node, kv.second),
internal::UntypedMapBase::StaticTypeKind<Key>(),
internal::UntypedMapBase::StaticTypeKind<T>(),
};
}

void DestroyNode(Node* node) {
if (this->alloc_.arena() == nullptr) {
node->kv.first.~key_type();
Expand Down
Loading

0 comments on commit d6b90bf

Please sign in to comment.