Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In this CL, we update generated_message_tctable_lite to support both length-prefixed and delimited encodings when parsing submessages. #16247

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 16 additions & 29 deletions src/google/protobuf/generated_message_tctable_lite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2297,44 +2297,35 @@ PROTOBUF_NOINLINE const char* TcParser::MpMessage(PROTOBUF_TC_PARAM_DECL) {
const uint16_t type_card = entry.type_card;
const uint16_t card = type_card & field_layout::kFcMask;

const uint32_t decoded_tag = data.tag();
const uint32_t decoded_wiretype = decoded_tag & 7;

// Check for repeated parsing:
if (card == field_layout::kFcRepeated) {
const uint16_t rep = type_card & field_layout::kRepMask;
switch (rep) {
case field_layout::kRepMessage:
switch (decoded_wiretype) {
case WireFormatLite::WIRETYPE_LENGTH_DELIMITED:
PROTOBUF_MUSTTAIL return MpRepeatedMessageOrGroup<is_split, false>(
PROTOBUF_TC_PARAM_PASS);
case field_layout::kRepGroup:
case WireFormatLite::WIRETYPE_START_GROUP:
PROTOBUF_MUSTTAIL return MpRepeatedMessageOrGroup<is_split, true>(
PROTOBUF_TC_PARAM_PASS);
default:
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
}
}

const uint32_t decoded_tag = data.tag();
const uint32_t decoded_wiretype = decoded_tag & 7;
const uint16_t rep = type_card & field_layout::kRepMask;
const bool is_group = rep == field_layout::kRepGroup;

// Validate wiretype:
switch (rep) {
case field_layout::kRepMessage:
if (decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
goto fallback;
}
break;
case field_layout::kRepGroup:
if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP) {
goto fallback;
}
break;
default: {
fallback:
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
}
// note that we solely rely on wiretype for parsing messages (schema ignored)
const bool is_group =
decoded_wiretype == WireFormatLite::WIRETYPE_START_GROUP;

// If we don't see a wiretype of START_GROUP or DELIM even though we're in the
// entry point for MpMessage, something is wrong. Bail out!
if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP &&
decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
}


const bool is_oneof = card == field_layout::kFcOneof;
bool need_init = false;
if (card == field_layout::kFcOptional) {
Expand Down Expand Up @@ -2386,14 +2377,10 @@ const char* TcParser::MpRepeatedMessageOrGroup(PROTOBUF_TC_PARAM_DECL) {

// Validate wiretype:
if (!is_group) {
ABSL_DCHECK_EQ(type_card & field_layout::kRepMask,
static_cast<uint16_t>(field_layout::kRepMessage));
if (decoded_wiretype != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
}
} else {
ABSL_DCHECK_EQ(type_card & field_layout::kRepMask,
static_cast<uint16_t>(field_layout::kRepGroup));
if (decoded_wiretype != WireFormatLite::WIRETYPE_START_GROUP) {
PROTOBUF_MUSTTAIL return table->fallback(PROTOBUF_TC_PARAM_PASS);
}
Expand Down
20 changes: 20 additions & 0 deletions src/google/protobuf/lite_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,26 @@ TYPED_TEST(LiteTest, CorrectEnding) {
}
}

TYPED_TEST(LiteTest, MessageEncoding) {
  // Feeds two 4-byte payloads for field 1 -- one length-prefixed, one
  // group-delimited -- through the parser and checks that each is accepted
  // and consumed to the end. Both payloads are merged into the same message
  // object, one after the other.
  protobuf_unittest::TestAllTypesLite parsed;

  {
    // Length-prefixed encoding for submessages must be supported.
    static const char kLengthPrefixed[] = "\n\002\010\003";  // 1: {1: 3}
    const uint8_t* const bytes =
        reinterpret_cast<const uint8_t*>(kLengthPrefixed);
    // The payload length excludes the literal's trailing NUL terminator.
    io::CodedInputStream input(bytes,
                               static_cast<int>(sizeof(kLengthPrefixed) - 1));
    const bool merged = parsed.MergePartialFromCodedStream(&input);
    EXPECT_TRUE(merged);
    EXPECT_TRUE(input.ConsumedEntireMessage());
    EXPECT_TRUE(input.LastTagWas(0));
  }

  {
    // Delimited (start-group / end-group) encoding for submessages must be
    // supported as well.
    static const char kDelimited[] = "\013\010\003\014";  // 1: !{1: 3}
    const uint8_t* const bytes = reinterpret_cast<const uint8_t*>(kDelimited);
    // The payload length excludes the literal's trailing NUL terminator.
    io::CodedInputStream input(bytes,
                               static_cast<int>(sizeof(kDelimited) - 1));
    const bool merged = parsed.MergePartialFromCodedStream(&input);
    EXPECT_TRUE(merged);
    EXPECT_TRUE(input.ConsumedEntireMessage());
    EXPECT_TRUE(input.LastTagWas(0));
  }
}

TYPED_TEST(LiteTest, DebugString) {
protobuf_unittest::TestAllTypesLite message1, message2;
EXPECT_TRUE(absl::StartsWith(message1.DebugString(), "MessageLite at 0x"));
Expand Down
213 changes: 190 additions & 23 deletions src/google/protobuf/wire_format.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "google/protobuf/message_lite.h"
#include "google/protobuf/parse_context.h"
#include "google/protobuf/unknown_field_set.h"
#include "google/protobuf/wire_format_lite.h"


// Must be included last.
Expand Down Expand Up @@ -578,31 +579,59 @@ bool WireFormat::ParseAndMergeField(

case FieldDescriptor::TYPE_GROUP: {
Message* sub_message;
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
if (WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP) {
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}

if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag),
input, sub_message))
return false;
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}

if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag),
input, sub_message))
return false;
if (!WireFormatLite::ReadMessage(input, sub_message)) return false;
}
break;
}

case FieldDescriptor::TYPE_MESSAGE: {
Message* sub_message;
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
if (WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP) {
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}

if (!WireFormatLite::ReadGroup(WireFormatLite::GetTagFieldNumber(tag),
input, sub_message))
return false;
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}
if (field->is_repeated()) {
sub_message = message_reflection->AddMessage(
message, field, input->GetExtensionFactory());
} else {
sub_message = message_reflection->MutableMessage(
message, field, input->GetExtensionFactory());
}

if (!WireFormatLite::ReadMessage(input, sub_message)) return false;
if (!WireFormatLite::ReadMessage(input, sub_message)) return false;
}
break;
}
}
Expand Down Expand Up @@ -871,6 +900,104 @@ const char* WireFormat::_InternalParseAndMergeField(
ABSL_LOG(FATAL) << "Can't reach";
return nullptr;
}
} else if (WireFormatLite::GetTagWireType(tag) !=
WireTypeForFieldType(field->type()) &&
(WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP ||
WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_LENGTH_DELIMITED)) {
switch (field->type()) {
case FieldDescriptor::TYPE_GROUP: {
Message* sub_message;

if (WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP) {
if (field->is_repeated()) {
sub_message =
reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}

return ctx->ParseGroup(sub_message, ptr, tag);
} else {
if (field->is_repeated()) {
sub_message =
reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}
ptr = ctx->ParseMessage(sub_message, ptr);

// For map entries, if the value is an unknown enum we have to push
// it into the unknown field set and remove it from the list.
if (ptr != nullptr && field->is_map()) {
auto* value_field = field->message_type()->map_value();
auto* enum_type = value_field->enum_type();
if (enum_type != nullptr &&
!internal::cpp::HasPreservingUnknownEnumSemantics(
value_field) &&
enum_type->FindValueByNumber(
sub_message->GetReflection()->GetEnumValue(
*sub_message, value_field)) == nullptr) {
reflection->MutableUnknownFields(msg)->AddLengthDelimited(
field->number(), sub_message->SerializeAsString());
reflection->RemoveLast(msg, field);
}
}

return ptr;
}
}

case FieldDescriptor::TYPE_MESSAGE: {
Message* sub_message;
if (WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP) {
if (field->is_repeated()) {
sub_message =
reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}

return ctx->ParseGroup(sub_message, ptr, tag);
} else if (field->is_repeated()) {
sub_message =
reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}
ptr = ctx->ParseMessage(sub_message, ptr);

// For map entries, if the value is an unknown enum we have to push it
// into the unknown field set and remove it from the list.
if (ptr != nullptr && field->is_map()) {
auto* value_field = field->message_type()->map_value();
auto* enum_type = value_field->enum_type();
if (enum_type != nullptr &&
!internal::cpp::HasPreservingUnknownEnumSemantics(
value_field) &&
enum_type->FindValueByNumber(
sub_message->GetReflection()->GetEnumValue(
*sub_message, value_field)) == nullptr) {
reflection->MutableUnknownFields(msg)->AddLengthDelimited(
field->number(), sub_message->SerializeAsString());
reflection->RemoveLast(msg, field);
}
}

return ptr;
}
default: {
return internal::UnknownFieldParse(
tag, reflection->MutableUnknownFields(msg), ptr, ctx);
}
}
} else {
// mismatched wiretype;
return internal::UnknownFieldParse(
Expand Down Expand Up @@ -997,19 +1124,59 @@ const char* WireFormat::_InternalParseAndMergeField(

case FieldDescriptor::TYPE_GROUP: {
Message* sub_message;
if (field->is_repeated()) {
sub_message = reflection->AddMessage(msg, field, ctx->data().factory);

if (WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP) {
if (field->is_repeated()) {
sub_message = reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}

return ctx->ParseGroup(sub_message, ptr, tag);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}
if (field->is_repeated()) {
sub_message = reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}
ptr = ctx->ParseMessage(sub_message, ptr);

// For map entries, if the value is an unknown enum we have to push it
// into the unknown field set and remove it from the list.
if (ptr != nullptr && field->is_map()) {
auto* value_field = field->message_type()->map_value();
auto* enum_type = value_field->enum_type();
if (enum_type != nullptr &&
!internal::cpp::HasPreservingUnknownEnumSemantics(value_field) &&
enum_type->FindValueByNumber(
sub_message->GetReflection()->GetEnumValue(
*sub_message, value_field)) == nullptr) {
reflection->MutableUnknownFields(msg)->AddLengthDelimited(
field->number(), sub_message->SerializeAsString());
reflection->RemoveLast(msg, field);
}
}

return ctx->ParseGroup(sub_message, ptr, tag);
return ptr;
}
}

case FieldDescriptor::TYPE_MESSAGE: {
Message* sub_message;
if (field->is_repeated()) {
if (WireFormatLite::GetTagWireType(tag) ==
WireFormatLite::WIRETYPE_START_GROUP) {
if (field->is_repeated()) {
sub_message = reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
reflection->MutableMessage(msg, field, ctx->data().factory);
}

return ctx->ParseGroup(sub_message, ptr, tag);
} else if (field->is_repeated()) {
sub_message = reflection->AddMessage(msg, field, ctx->data().factory);
} else {
sub_message =
Expand Down
Loading