diff --git a/conformance/failure_list_python.txt b/conformance/failure_list_python.txt index e788be17d5b6..e7b56042a784 100644 --- a/conformance/failure_list_python.txt +++ b/conformance/failure_list_python.txt @@ -1,4 +1 @@ Required.*.JsonInput.Int32FieldQuotedExponentialValue.* # Failed to parse input or produce output. -Required.*.ProtobufInput.MismatchedNestedGroupTags # Should have failed to parse, but didn't. -Required.Editions_Proto3.ProtobufInput.MismatchedGroupTags # Should have failed to parse, but didn't. -Required.Proto3.ProtobufInput.MismatchedGroupTags # Should have failed to parse, but didn't. diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index dcde1d9420c9..e26770e1d5fc 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -944,14 +944,16 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None): field_number, wire_type = wire_format.UnpackTag(tag) if wire_type == wire_format.WIRETYPE_END_GROUP: break - (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) + (data, pos) = _DecodeUnknownField( + buffer, pos, end_pos, field_number, wire_type + ) # pylint: disable=protected-access unknown_field_set._add(field_number, wire_type, data) return (unknown_field_set, pos) -def _DecodeUnknownField(buffer, pos, wire_type): +def _DecodeUnknownField(buffer, pos, end_pos, field_number, wire_type): """Decode a unknown field. Returns the UnknownField and new position.""" if wire_type == wire_format.WIRETYPE_VARINT: @@ -965,7 +967,13 @@ def _DecodeUnknownField(buffer, pos, wire_type): data = buffer[pos:pos+size].tobytes() pos += size elif wire_type == wire_format.WIRETYPE_START_GROUP: - (data, pos) = _DecodeUnknownFieldSet(buffer, pos) + end_tag_bytes = encoder.TagBytes( + field_number, wire_format.WIRETYPE_END_GROUP + ) + data, pos = _DecodeUnknownFieldSet(buffer, pos, end_pos) + # Check end tag. + if buffer[pos - len(end_tag_bytes) : pos] != end_tag_bytes: + raise _DecodeError('Missing group end tag.') elif wire_type == wire_format.WIRETYPE_END_GROUP: return (0, -1) else: diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py index f801b6e76fd8..8737e117d9e2 100644 --- a/python/google/protobuf/internal/decoder_test.py +++ b/python/google/protobuf/internal/decoder_test.py @@ -11,8 +11,10 @@ import io import unittest +from google.protobuf import message from google.protobuf.internal import decoder from google.protobuf.internal import testing_refleaks +from google.protobuf.internal import wire_format _INPUT_BYTES = b'\x84r\x12' @@ -33,7 +35,7 @@ def test_decode_varint_bytes(self): def test_decode_varint_bytes_empty(self): with self.assertRaises(IndexError) as context: - (size, pos) = decoder._DecodeVarint(b'', 0) + decoder._DecodeVarint(b'', 0) self.assertIn('index out of range', str(context.exception)) def test_decode_varint_bytesio(self): @@ -50,7 +52,57 @@ def test_decode_varint_bytesio(self): def test_decode_varint_bytesio_empty(self): input_io = io.BytesIO(b'') size = decoder._DecodeVarint(input_io) - self.assertEqual(size, None) + self.assertIsNone(size) + + def test_decode_unknown_group_field(self): + data = memoryview(b'\013\020\003\014\040\005') + parsed, pos = decoder._DecodeUnknownField( + data, 1, len(data), 1, wire_format.WIRETYPE_START_GROUP + ) + + self.assertEqual(pos, 4) + self.assertEqual(len(parsed), 1) + self.assertEqual(parsed[0].field_number, 2) + self.assertEqual(parsed[0].data, 3) + + def test_decode_unknown_group_field_nested(self): + data = memoryview(b'\013\023\013\030\004\014\024\014\050\006') + parsed, pos = decoder._DecodeUnknownField( + data, 1, len(data), 1, wire_format.WIRETYPE_START_GROUP + ) + + self.assertEqual(pos, 8) + self.assertEqual(len(parsed), 1) + self.assertEqual(parsed[0].field_number, 2) + self.assertEqual(len(parsed[0].data), 1) + self.assertEqual(parsed[0].data[0].field_number, 1) + self.assertEqual(len(parsed[0].data[0].data), 1) + self.assertEqual(parsed[0].data[0].data[0].field_number, 3) + self.assertEqual(parsed[0].data[0].data[0].data, 4) + + def test_decode_unknown_mismatched_end_group(self): + self.assertRaisesRegex( + message.DecodeError, + 'Missing group end tag.*', + decoder._DecodeUnknownField, + memoryview(b'\013\024'), + 1, + 2, + 1, + wire_format.WIRETYPE_START_GROUP, + ) + + def test_decode_unknown_mismatched_end_group_nested(self): + self.assertRaisesRegex( + message.DecodeError, + 'Missing group end tag.*', + decoder._DecodeUnknownField, + memoryview(b'\013\023\034\024\014'), + 1, + 5, + 1, + wire_format.WIRETYPE_START_GROUP, + ) if __name__ == '__main__': diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index f1541f14bdb9..7212639db684 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -69,6 +69,8 @@ def testParseErrors(self, message_module): msg = message_module.TestAllTypes() self.assertRaises(TypeError, msg.FromString, 0) self.assertRaises(Exception, msg.FromString, '0') + + # Unexpected end group tag. end_tag = encoder.TagBytes(1, 4) with self.assertRaises(message.DecodeError) as context: msg.FromString(end_tag) @@ -76,6 +78,15 @@ def testParseErrors(self, message_module): # upb raises a less specific exception. self.assertRegex(str(context.exception), 'Unexpected end-group tag.*') + # Unmatched start group tag. + start_tag = encoder.TagBytes(2, 3) + with self.assertRaises(message.DecodeError): + msg.FromString(start_tag) + + # Mismatched end group tag. + with self.assertRaises(message.DecodeError): + msg.FromString(start_tag + end_tag) + # Field number 0 is illegal. self.assertRaises(message.DecodeError, msg.FromString, b'\3\4') @@ -1265,16 +1276,16 @@ def testReturningType(self, message_module): self.assertEqual(True, m.repeated_bool[0]) def testDir(self, message_module): - m = message_module.TestAllTypes() - attributes = dir(m) - self.assertGreaterEqual(len(attributes), 55) - self.assertIn('DESCRIPTOR', attributes) - - class_attributes = dir(type(m)) - attribute_set = set(attributes) - for attr in class_attributes: - if attr != 'Extensions': - self.assertIn(attr, attribute_set) + m = message_module.TestAllTypes() + attributes = dir(m) + self.assertGreaterEqual(len(attributes), 55) + self.assertIn('DESCRIPTOR', attributes) + + class_attributes = dir(type(m)) + attribute_set = set(attributes) + for attr in class_attributes: + if attr != 'Extensions': + self.assertIn(attr, attribute_set) def testAllAttributeFromDirAccessible(self, message_module): m = message_module.TestAllTypes() diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 1f9016880944..22d02c4d6e31 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -1234,7 +1234,8 @@ def InternalParse(self, buffer, pos, end): # TODO: remove old_pos. old_pos = new_pos (data, new_pos) = decoder._DecodeUnknownField( - buffer, new_pos, wire_type) # pylint: disable=protected-access + buffer, new_pos, end, field_number, wire_type + ) # pylint: disable=protected-access if new_pos == -1: return pos # TODO: remove _unknown_fields. diff --git a/python/google/protobuf/unknown_fields.py b/python/google/protobuf/unknown_fields.py index 9b1e54932430..1a4537e3f051 100644 --- a/python/google/protobuf/unknown_fields.py +++ b/python/google/protobuf/unknown_fields.py @@ -78,7 +78,8 @@ def InternalAdd(field_number, wire_type, data): if field_number == 0: raise RuntimeError('Field number 0 is illegal.') (data, _) = decoder._DecodeUnknownField( - memoryview(buffer), 0, wire_type) + memoryview(buffer), 0, len(buffer), field_number, wire_type + ) InternalAdd(field_number, wire_type, data) def __getitem__(self, index):