Skip to content

Commit

Permalink
Make Pure Python reject unmatched end-group tag in unknown fields
Browse files Browse the repository at this point in the history
This brings it into conformance with our spec and other languages.

PiperOrigin-RevId: 704518974
  • Loading branch information
mkruskal-google authored and copybara-github committed Dec 10, 2024
1 parent 75581bf commit f69ea1c
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 20 deletions.
3 changes: 0 additions & 3 deletions conformance/failure_list_python.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1 @@
Required.*.JsonInput.Int32FieldQuotedExponentialValue.* # Failed to parse input or produce output.
Required.*.ProtobufInput.MismatchedNestedGroupTags # Should have failed to parse, but didn't.
Required.Editions_Proto3.ProtobufInput.MismatchedGroupTags # Should have failed to parse, but didn't.
Required.Proto3.ProtobufInput.MismatchedGroupTags # Should have failed to parse, but didn't.
14 changes: 11 additions & 3 deletions python/google/protobuf/internal/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,14 +944,16 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
field_number, wire_type = wire_format.UnpackTag(tag)
if wire_type == wire_format.WIRETYPE_END_GROUP:
break
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
(data, pos) = _DecodeUnknownField(
buffer, pos, end_pos, field_number, wire_type
)
# pylint: disable=protected-access
unknown_field_set._add(field_number, wire_type, data)

return (unknown_field_set, pos)


def _DecodeUnknownField(buffer, pos, wire_type):
def _DecodeUnknownField(buffer, pos, end_pos, field_number, wire_type):
"""Decode a unknown field. Returns the UnknownField and new position."""

if wire_type == wire_format.WIRETYPE_VARINT:
Expand All @@ -965,7 +967,13 @@ def _DecodeUnknownField(buffer, pos, wire_type):
data = buffer[pos:pos+size].tobytes()
pos += size
elif wire_type == wire_format.WIRETYPE_START_GROUP:
(data, pos) = _DecodeUnknownFieldSet(buffer, pos)
end_tag_bytes = encoder.TagBytes(
field_number, wire_format.WIRETYPE_END_GROUP
)
data, pos = _DecodeUnknownFieldSet(buffer, pos, end_pos)
# Check end tag.
if buffer[pos - len(end_tag_bytes) : pos] != end_tag_bytes:
raise _DecodeError('Missing group end tag.')
elif wire_type == wire_format.WIRETYPE_END_GROUP:
return (0, -1)
else:
Expand Down
56 changes: 54 additions & 2 deletions python/google/protobuf/internal/decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
import io
import unittest

from google.protobuf import message
from google.protobuf.internal import decoder
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import wire_format


_INPUT_BYTES = b'\x84r\x12'
Expand All @@ -33,7 +35,7 @@ def test_decode_varint_bytes(self):

def test_decode_varint_bytes_empty(self):
with self.assertRaises(IndexError) as context:
(size, pos) = decoder._DecodeVarint(b'', 0)
decoder._DecodeVarint(b'', 0)
self.assertIn('index out of range', str(context.exception))

def test_decode_varint_bytesio(self):
Expand All @@ -50,7 +52,57 @@ def test_decode_varint_bytesio(self):
def test_decode_varint_bytesio_empty(self):
input_io = io.BytesIO(b'')
size = decoder._DecodeVarint(input_io)
self.assertEqual(size, None)
self.assertIsNone(size)

def test_decode_unknown_group_field(self):
data = memoryview(b'\013\020\003\014\040\005')
parsed, pos = decoder._DecodeUnknownField(
data, 1, len(data), 1, wire_format.WIRETYPE_START_GROUP
)

self.assertEqual(pos, 4)
self.assertEqual(len(parsed), 1)
self.assertEqual(parsed[0].field_number, 2)
self.assertEqual(parsed[0].data, 3)

def test_decode_unknown_group_field_nested(self):
data = memoryview(b'\013\023\013\030\004\014\024\014\050\006')
parsed, pos = decoder._DecodeUnknownField(
data, 1, len(data), 1, wire_format.WIRETYPE_START_GROUP
)

self.assertEqual(pos, 8)
self.assertEqual(len(parsed), 1)
self.assertEqual(parsed[0].field_number, 2)
self.assertEqual(len(parsed[0].data), 1)
self.assertEqual(parsed[0].data[0].field_number, 1)
self.assertEqual(len(parsed[0].data[0].data), 1)
self.assertEqual(parsed[0].data[0].data[0].field_number, 3)
self.assertEqual(parsed[0].data[0].data[0].data, 4)

def test_decode_unknown_mismatched_end_group(self):
self.assertRaisesRegex(
message.DecodeError,
'Missing group end tag.*',
decoder._DecodeUnknownField,
memoryview(b'\013\024'),
1,
2,
1,
wire_format.WIRETYPE_START_GROUP,
)

def test_decode_unknown_mismatched_end_group_nested(self):
self.assertRaisesRegex(
message.DecodeError,
'Missing group end tag.*',
decoder._DecodeUnknownField,
memoryview(b'\013\023\034\024\014'),
1,
5,
1,
wire_format.WIRETYPE_START_GROUP,
)


if __name__ == '__main__':
Expand Down
31 changes: 21 additions & 10 deletions python/google/protobuf/internal/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,24 @@ def testParseErrors(self, message_module):
msg = message_module.TestAllTypes()
self.assertRaises(TypeError, msg.FromString, 0)
self.assertRaises(Exception, msg.FromString, '0')

# Unexpected end group tag.
end_tag = encoder.TagBytes(1, 4)
with self.assertRaises(message.DecodeError) as context:
msg.FromString(end_tag)
if api_implementation.Type() != 'upb':
# upb raises a less specific exception.
self.assertRegex(str(context.exception), 'Unexpected end-group tag.*')

# Unmatched start group tag.
start_tag = encoder.TagBytes(2, 3)
with self.assertRaises(message.DecodeError):
msg.FromString(start_tag)

# Mismatched end group tag.
with self.assertRaises(message.DecodeError):
msg.FromString(start_tag + end_tag)

# Field number 0 is illegal.
self.assertRaises(message.DecodeError, msg.FromString, b'\3\4')

Expand Down Expand Up @@ -1265,16 +1276,16 @@ def testReturningType(self, message_module):
self.assertEqual(True, m.repeated_bool[0])

def testDir(self, message_module):
m = message_module.TestAllTypes()
attributes = dir(m)
self.assertGreaterEqual(len(attributes), 55)
self.assertIn('DESCRIPTOR', attributes)

class_attributes = dir(type(m))
attribute_set = set(attributes)
for attr in class_attributes:
if attr != 'Extensions':
self.assertIn(attr, attribute_set)
m = message_module.TestAllTypes()
attributes = dir(m)
self.assertGreaterEqual(len(attributes), 55)
self.assertIn('DESCRIPTOR', attributes)

class_attributes = dir(type(m))
attribute_set = set(attributes)
for attr in class_attributes:
if attr != 'Extensions':
self.assertIn(attr, attribute_set)

def testAllAttributeFromDirAccessible(self, message_module):
m = message_module.TestAllTypes()
Expand Down
3 changes: 2 additions & 1 deletion python/google/protobuf/internal/python_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,8 @@ def InternalParse(self, buffer, pos, end):
# TODO: remove old_pos.
old_pos = new_pos
(data, new_pos) = decoder._DecodeUnknownField(
buffer, new_pos, wire_type) # pylint: disable=protected-access
buffer, new_pos, end, field_number, wire_type
) # pylint: disable=protected-access
if new_pos == -1:
return pos
# TODO: remove _unknown_fields.
Expand Down
3 changes: 2 additions & 1 deletion python/google/protobuf/unknown_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def InternalAdd(field_number, wire_type, data):
if field_number == 0:
raise RuntimeError('Field number 0 is illegal.')
(data, _) = decoder._DecodeUnknownField(
memoryview(buffer), 0, wire_type)
memoryview(buffer), 0, len(buffer), field_number, wire_type
)
InternalAdd(field_number, wire_type, data)

def __getitem__(self, index):
Expand Down

0 comments on commit f69ea1c

Please sign in to comment.