Skip to content

Commit

Permalink
Remove decoder.SkipField.
Browse files Browse the repository at this point in the history
This is only used in some narrow edge cases, and is a less-safe version of our unknown field decoders.  Switching to those reduces some duplication and improves error handling.

PiperOrigin-RevId: 704899254
  • Loading branch information
mkruskal-google authored and copybara-github committed Dec 11, 2024
1 parent 7fc92a8 commit 1b0b1c9
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 124 deletions.
133 changes: 25 additions & 108 deletions python/google/protobuf/internal/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,19 @@ def ReadTag(buffer, pos):
return tag_bytes, pos


def DecodeTag(tag_bytes):
  """Split an encoded tag into its field number and wire type.

  Args:
    tag_bytes: the raw bytes of the tag (a single varint).

  Returns:
    Tuple[int, int] of (field number, wire type), as produced by
    wire_format.UnpackTag.
  """
  # The tag is one varint starting at offset 0; the position returned after
  # it is of no interest here.
  decoded = _DecodeVarint(tag_bytes, 0)
  return wire_format.UnpackTag(decoded[0])


# --------------------------------------------------------------------


Expand Down Expand Up @@ -730,7 +743,6 @@ def MessageSetItemDecoder(descriptor):

local_ReadTag = ReadTag
local_DecodeVarint = _DecodeVarint
local_SkipField = SkipField

def DecodeItem(buffer, pos, end, message, field_dict):
"""Decode serialized message set to its value and new position.
Expand Down Expand Up @@ -762,9 +774,10 @@ def DecodeItem(buffer, pos, end, message, field_dict):
elif tag_bytes == item_end_tag_bytes:
break
else:
pos = SkipField(buffer, pos, end, tag_bytes)
field_number, wire_type = DecodeTag(tag_bytes)
_, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
if pos == -1:
raise _DecodeError('Missing group end tag.')
raise _DecodeError('Unexpected end-group tag.')

if pos > end:
raise _DecodeError('Truncated message.')
Expand Down Expand Up @@ -822,9 +835,10 @@ def DecodeUnknownItem(buffer):
elif tag_bytes == item_end_tag_bytes:
break
else:
pos = SkipField(buffer, pos, end, tag_bytes)
field_number, wire_type = DecodeTag(tag_bytes)
_, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
if pos == -1:
raise _DecodeError('Missing group end tag.')
raise _DecodeError('Unexpected end-group tag.')

if pos > end:
raise _DecodeError('Truncated message.')
Expand Down Expand Up @@ -882,56 +896,18 @@ def DecodeMap(buffer, pos, end, message, field_dict):

return DecodeMap

# --------------------------------------------------------------------
# Optimization is not as heavy here because calls to SkipField() are rare,
# except for handling end-group tags.

def _SkipVarint(buffer, pos, end):
  """Advance past one varint and return the position just after it.

  Args:
    buffer: a sliceable buffer whose slices support .tobytes().
    pos: index of the first byte of the varint.
    end: index one past the last valid byte.

  Returns:
    The position immediately following the varint.

  Raises:
    _DecodeError: if the varint extends beyond end.
  """
  # A one-byte slice is used instead of indexing: reading past the buffer
  # then gives ord(b''), i.e. a TypeError, where indexing used to give an
  # IndexError. Both are handled in python_message.py to generate a
  # 'Truncated message' error.
  while True:
    current = ord(buffer[pos:pos+1].tobytes())
    pos += 1
    # A clear high bit marks the final byte of the varint.
    if not current & 0x80:
      break
  if pos > end:
    raise _DecodeError('Truncated message.')
  return pos

def _SkipFixed64(buffer, pos, end):
  """Advance past an 8-byte fixed64 value.

  Args:
    buffer: unused; present so all skippers share one signature.
    pos: index of the first byte of the value.
    end: index one past the last valid byte.

  Returns:
    The position immediately following the value.

  Raises:
    _DecodeError: if fewer than 8 bytes remain before end.
  """
  after = pos + 8
  if after > end:
    raise _DecodeError('Truncated message.')
  return after


def _DecodeFixed64(buffer, pos):
  """Read a little-endian unsigned 64-bit integer starting at pos.

  Args:
    buffer: a sliceable bytes-like object.
    pos: index of the first byte of the value.

  Returns:
    A (value, new_pos) tuple, where new_pos is pos + 8.
  """
  stop = pos + 8
  # struct.unpack always returns a tuple; '<Q' yields exactly one element.
  (value,) = struct.unpack('<Q', buffer[pos:stop])
  return (value, stop)


def _SkipLengthDelimited(buffer, pos, end):
  """Advance past a length-delimited value (varint length, then payload).

  Args:
    buffer: the buffer holding the encoded data.
    pos: index of the first byte of the length varint.
    end: index one past the last valid byte.

  Returns:
    The position immediately following the payload.

  Raises:
    _DecodeError: if the declared payload runs beyond end.
  """
  (length, payload_start) = _DecodeVarint(buffer, pos)
  payload_end = payload_start + length
  if payload_end > end:
    raise _DecodeError('Truncated message.')
  return payload_end


def _SkipGroup(buffer, pos, end):
"""Skip sub-group. Returns the new position."""
def _DecodeFixed32(buffer, pos):
"""Decode a fixed32."""

while 1:
(tag_bytes, pos) = ReadTag(buffer, pos)
new_pos = SkipField(buffer, pos, end, tag_bytes)
if new_pos == -1:
return pos
pos = new_pos
new_pos = pos + 4
return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)


def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
Expand Down Expand Up @@ -979,66 +955,7 @@ def _DecodeUnknownField(buffer, pos, end_pos, field_number, wire_type):
else:
raise _DecodeError('Wrong wire type in tag.')

return (data, pos)


def _EndGroup(buffer, pos, end):
  """Skipper for an END_GROUP tag.

  All three arguments are ignored; they exist only so this function shares
  the signature of the other skippers.

  Returns:
    Always -1, the sentinel that tells the parent loop to break.
  """
  end_group_sentinel = -1
  return end_group_sentinel


def _SkipFixed32(buffer, pos, end):
"""Skip a fixed32 value. Returns the new position."""

pos += 4
if pos > end:
if pos > end_pos:
raise _DecodeError('Truncated message.')
return pos


def _DecodeFixed32(buffer, pos):
  """Read a little-endian unsigned 32-bit integer starting at pos.

  Args:
    buffer: a sliceable bytes-like object.
    pos: index of the first byte of the value.

  Returns:
    A (value, new_pos) tuple, where new_pos is pos + 4.
  """
  stop = pos + 4
  # struct.unpack always returns a tuple; '<I' yields exactly one element.
  (value,) = struct.unpack('<I', buffer[pos:stop])
  return (value, stop)


def _RaiseInvalidWireType(buffer, pos, end):
  """Skipper installed for wire types that have no valid meaning.

  The arguments are ignored; they exist only so this function shares the
  signature of the other skippers.

  Raises:
    _DecodeError: always.
  """
  error = _DecodeError('Tag had invalid wire type.')
  raise error

def _FieldSkipper():
  """Build and return the SkipField function.

  The returned closure dispatches on the wire type of a tag through a table
  indexed by wire-type number.
  """

  # Position in this table is the wire-type number; the last two entries
  # reject wire types that are not valid.
  skippers_by_wire_type = (
      _SkipVarint,
      _SkipFixed64,
      _SkipLengthDelimited,
      _SkipGroup,
      _EndGroup,
      _SkipFixed32,
      _RaiseInvalidWireType,
      _RaiseInvalidWireType,
  )

  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skips a field with the specified tag.

    |pos| should point to the byte immediately after the tag.

    Returns:
      The new position (after the tag value), or -1 if the tag is an
      end-group tag (in which case the calling loop should break).
    """

    # Varints are little-endian, so the wire-type bits always sit in the
    # first byte of the tag.
    leading_byte = ord(tag_bytes[0:1])
    skipper = skippers_by_wire_type[leading_byte & type_mask]
    return skipper(buffer, pos, end)

  return SkipField

SkipField = _FieldSkipper()
return (data, pos)
24 changes: 24 additions & 0 deletions python/google/protobuf/internal/decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import unittest

from google.protobuf import message
from google.protobuf.internal import api_implementation
from google.protobuf.internal import decoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import wire_format

Expand Down Expand Up @@ -104,6 +106,28 @@ def test_decode_unknown_mismatched_end_group_nested(self):
wire_format.WIRETYPE_START_GROUP,
)

def test_decode_message_set_unknown_mismatched_end_group(self):
  """Parsing a MessageSet with a mismatched end-group tag must fail."""
  proto = message_set_extensions_pb2.TestMessageSet()
  # Only the pure-Python implementation is required to produce this exact
  # message; any DecodeError text is accepted from other implementations.
  if api_implementation.Type() == 'python':
    expected_message = 'Unexpected end-group tag.'
  else:
    expected_message = '.*'
  with self.assertRaisesRegex(message.DecodeError, expected_message):
    proto.ParseFromString(b'\013\054\014')

def test_unknown_message_set_decoder_mismatched_end_group(self):
  """The unknown MessageSet item decoder rejects a stray end-group tag."""
  # This behavior isn't actually reachable in practice, but it's good to
  # test anyway.
  decode = decoder.UnknownMessageSetItemDecoder()
  payload = memoryview(b'\054\014')
  with self.assertRaisesRegex(
      message.DecodeError, 'Unexpected end-group tag.'
  ):
    decode(payload)


if __name__ == '__main__':
unittest.main()
17 changes: 4 additions & 13 deletions python/google/protobuf/internal/python_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,8 +1192,6 @@ def MergeFromString(self, serialized):
return length # Return this for legacy reasons.
cls.MergeFromString = MergeFromString

local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
fields_by_tag = cls._fields_by_tag
message_set_decoders_by_tag = cls._message_set_decoders_by_tag

Expand All @@ -1215,7 +1213,7 @@ def InternalParse(self, buffer, pos, end):
self._Modified()
field_dict = self._fields
while pos != end:
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
(tag_bytes, new_pos) = decoder.ReadTag(buffer, pos)
field_decoder, field_des = message_set_decoders_by_tag.get(
tag_bytes, (None, None)
)
Expand All @@ -1226,24 +1224,17 @@ def InternalParse(self, buffer, pos, end):
if field_des is None:
if not self._unknown_fields: # pylint: disable=protected-access
self._unknown_fields = [] # pylint: disable=protected-access
# pylint: disable=protected-access
(tag, _) = decoder._DecodeVarint(tag_bytes, 0)
field_number, wire_type = wire_format.UnpackTag(tag)
field_number, wire_type = decoder.DecodeTag(tag_bytes)
if field_number == 0:
raise message_mod.DecodeError('Field number 0 is illegal.')
# TODO: remove old_pos.
old_pos = new_pos
(data, new_pos) = decoder._DecodeUnknownField(
buffer, new_pos, end, field_number, wire_type
) # pylint: disable=protected-access
if new_pos == -1:
return pos
# TODO: remove _unknown_fields.
new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
if new_pos == -1:
return pos
self._unknown_fields.append(
(tag_bytes, buffer[old_pos:new_pos].tobytes()))
(tag_bytes, buffer[pos + len(tag_bytes) : new_pos].tobytes())
)
pos = new_pos
else:
_MaybeAddDecoder(cls, field_des)
Expand Down
4 changes: 1 addition & 3 deletions python/google/protobuf/unknown_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def InternalAdd(field_number, wire_type, data):
InternalAdd(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED, data)
else:
for tag_bytes, buffer in unknown_fields:
# pylint: disable=protected-access
(tag, _) = decoder._DecodeVarint(tag_bytes, 0)
field_number, wire_type = wire_format.UnpackTag(tag)
field_number, wire_type = decoder.DecodeTag(tag_bytes)
if field_number == 0:
raise RuntimeError('Field number 0 is illegal.')
(data, _) = decoder._DecodeUnknownField(
Expand Down

0 comments on commit 1b0b1c9

Please sign in to comment.