From 1b0b1c9ba005e85292cf597de6bcb64c69eeb23a Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Tue, 10 Dec 2024 17:13:31 -0800 Subject: [PATCH] Remove decoder.SkipField. This is only used in some narrow edge cases, and is a less-safe version of our unknown field decoders. Switching to those reduces some duplication and improves error handling. PiperOrigin-RevId: 704899254 --- python/google/protobuf/internal/decoder.py | 133 ++++-------------- .../google/protobuf/internal/decoder_test.py | 24 ++++ .../protobuf/internal/python_message.py | 17 +-- python/google/protobuf/unknown_fields.py | 4 +- 4 files changed, 54 insertions(+), 124 deletions(-) diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index e26770e1d5fc..f9e45c53f11b 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -168,6 +168,19 @@ def ReadTag(buffer, pos): return tag_bytes, pos +def DecodeTag(tag_bytes): + """Decode a tag from the bytes. + + Args: + tag_bytes: the bytes of the tag + + Returns: + Tuple[int, int] of the tag field number and wire type. + """ + (tag, _) = _DecodeVarint(tag_bytes, 0) + return wire_format.UnpackTag(tag) + + # -------------------------------------------------------------------- @@ -730,7 +743,6 @@ def MessageSetItemDecoder(descriptor): local_ReadTag = ReadTag local_DecodeVarint = _DecodeVarint - local_SkipField = SkipField def DecodeItem(buffer, pos, end, message, field_dict): """Decode serialized message set to its value and new position. @@ -762,9 +774,10 @@ def DecodeItem(buffer, pos, end, message, field_dict): elif tag_bytes == item_end_tag_bytes: break else: - pos = SkipField(buffer, pos, end, tag_bytes) + field_number, wire_type = DecodeTag(tag_bytes) + _, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type) if pos == -1: - raise _DecodeError('Missing group end tag.') + raise _DecodeError('Unexpected end-group tag.') if pos > end: raise _DecodeError('Truncated message.') @@ -822,9 +835,10 @@ def DecodeUnknownItem(buffer): elif tag_bytes == item_end_tag_bytes: break else: - pos = SkipField(buffer, pos, end, tag_bytes) + field_number, wire_type = DecodeTag(tag_bytes) + _, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type) if pos == -1: - raise _DecodeError('Missing group end tag.') + raise _DecodeError('Unexpected end-group tag.') if pos > end: raise _DecodeError('Truncated message.') @@ -882,30 +896,6 @@ def DecodeMap(buffer, pos, end, message, field_dict): return DecodeMap -# -------------------------------------------------------------------- -# Optimization is not as heavy here because calls to SkipField() are rare, -# except for handling end-group tags. - -def _SkipVarint(buffer, pos, end): - """Skip a varint value. Returns the new position.""" - # Previously ord(buffer[pos]) raised IndexError when pos is out of range. - # With this code, ord(b'') raises TypeError. Both are handled in - # python_message.py to generate a 'Truncated message' error. - while ord(buffer[pos:pos+1].tobytes()) & 0x80: - pos += 1 - pos += 1 - if pos > end: - raise _DecodeError('Truncated message.') - return pos - -def _SkipFixed64(buffer, pos, end): - """Skip a fixed64 value. Returns the new position.""" - - pos += 8 - if pos > end: - raise _DecodeError('Truncated message.') - return pos - def _DecodeFixed64(buffer, pos): """Decode a fixed64.""" @@ -913,25 +903,11 @@ def _DecodeFixed64(buffer, pos): return (struct.unpack(' end: - raise _DecodeError('Truncated message.') - return pos - - -def _SkipGroup(buffer, pos, end): - """Skip sub-group. Returns the new position.""" +def _DecodeFixed32(buffer, pos): + """Decode a fixed32.""" - while 1: - (tag_bytes, pos) = ReadTag(buffer, pos) - new_pos = SkipField(buffer, pos, end, tag_bytes) - if new_pos == -1: - return pos - pos = new_pos + new_pos = pos + 4 + return (struct.unpack(' end: + if pos > end_pos: raise _DecodeError('Truncated message.') - return pos - - -def _DecodeFixed32(buffer, pos): - """Decode a fixed32.""" - - new_pos = pos + 4 - return (struct.unpack('