Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove decoder.SkipField. #19599

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 25 additions & 108 deletions python/google/protobuf/internal/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,19 @@ def ReadTag(buffer, pos):
return tag_bytes, pos


def DecodeTag(tag_bytes):
"""Decode a tag from the bytes.

Args:
tag_bytes: the bytes of the tag

Returns:
Tuple[int, int] of the tag field number and wire type.
"""
(tag, _) = _DecodeVarint(tag_bytes, 0)
return wire_format.UnpackTag(tag)


# --------------------------------------------------------------------


Expand Down Expand Up @@ -730,7 +743,6 @@ def MessageSetItemDecoder(descriptor):

local_ReadTag = ReadTag
local_DecodeVarint = _DecodeVarint
local_SkipField = SkipField

def DecodeItem(buffer, pos, end, message, field_dict):
"""Decode serialized message set to its value and new position.
Expand Down Expand Up @@ -762,9 +774,10 @@ def DecodeItem(buffer, pos, end, message, field_dict):
elif tag_bytes == item_end_tag_bytes:
break
else:
pos = SkipField(buffer, pos, end, tag_bytes)
field_number, wire_type = DecodeTag(tag_bytes)
_, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
if pos == -1:
raise _DecodeError('Missing group end tag.')
raise _DecodeError('Unexpected end-group tag.')

if pos > end:
raise _DecodeError('Truncated message.')
Expand Down Expand Up @@ -822,9 +835,10 @@ def DecodeUnknownItem(buffer):
elif tag_bytes == item_end_tag_bytes:
break
else:
pos = SkipField(buffer, pos, end, tag_bytes)
field_number, wire_type = DecodeTag(tag_bytes)
_, pos = _DecodeUnknownField(buffer, pos, end, field_number, wire_type)
if pos == -1:
raise _DecodeError('Missing group end tag.')
raise _DecodeError('Unexpected end-group tag.')

if pos > end:
raise _DecodeError('Truncated message.')
Expand Down Expand Up @@ -882,56 +896,18 @@ def DecodeMap(buffer, pos, end, message, field_dict):

return DecodeMap

# --------------------------------------------------------------------
# Optimization is not as heavy here because calls to SkipField() are rare,
# except for handling end-group tags.

def _SkipVarint(buffer, pos, end):
"""Skip a varint value. Returns the new position."""
# Previously ord(buffer[pos]) raised IndexError when pos is out of range.
# With this code, ord(b'') raises TypeError. Both are handled in
# python_message.py to generate a 'Truncated message' error.
while ord(buffer[pos:pos+1].tobytes()) & 0x80:
pos += 1
pos += 1
if pos > end:
raise _DecodeError('Truncated message.')
return pos

def _SkipFixed64(buffer, pos, end):
"""Skip a fixed64 value. Returns the new position."""

pos += 8
if pos > end:
raise _DecodeError('Truncated message.')
return pos


def _DecodeFixed64(buffer, pos):
"""Decode a fixed64."""
new_pos = pos + 8
return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)


def _SkipLengthDelimited(buffer, pos, end):
"""Skip a length-delimited value. Returns the new position."""

(size, pos) = _DecodeVarint(buffer, pos)
pos += size
if pos > end:
raise _DecodeError('Truncated message.')
return pos


def _SkipGroup(buffer, pos, end):
"""Skip sub-group. Returns the new position."""
def _DecodeFixed32(buffer, pos):
"""Decode a fixed32."""

while 1:
(tag_bytes, pos) = ReadTag(buffer, pos)
new_pos = SkipField(buffer, pos, end, tag_bytes)
if new_pos == -1:
return pos
pos = new_pos
new_pos = pos + 4
return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)


def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
Expand Down Expand Up @@ -979,66 +955,7 @@ def _DecodeUnknownField(buffer, pos, end_pos, field_number, wire_type):
else:
raise _DecodeError('Wrong wire type in tag.')

return (data, pos)


def _EndGroup(buffer, pos, end):
"""Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""

return -1


def _SkipFixed32(buffer, pos, end):
"""Skip a fixed32 value. Returns the new position."""

pos += 4
if pos > end:
if pos > end_pos:
raise _DecodeError('Truncated message.')
return pos


def _DecodeFixed32(buffer, pos):
"""Decode a fixed32."""

new_pos = pos + 4
return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)


def _RaiseInvalidWireType(buffer, pos, end):
"""Skip function for unknown wire types. Raises an exception."""

raise _DecodeError('Tag had invalid wire type.')

def _FieldSkipper():
"""Constructs the SkipField function."""

WIRETYPE_TO_SKIPPER = [
_SkipVarint,
_SkipFixed64,
_SkipLengthDelimited,
_SkipGroup,
_EndGroup,
_SkipFixed32,
_RaiseInvalidWireType,
_RaiseInvalidWireType,
]

wiretype_mask = wire_format.TAG_TYPE_MASK

def SkipField(buffer, pos, end, tag_bytes):
"""Skips a field with the specified tag.

|pos| should point to the byte immediately after the tag.

Returns:
The new position (after the tag value), or -1 if the tag is an end-group
tag (in which case the calling loop should break).
"""

# The wire type is always in the first byte since varints are little-endian.
wire_type = ord(tag_bytes[0:1]) & wiretype_mask
return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)

return SkipField

SkipField = _FieldSkipper()
return (data, pos)
24 changes: 24 additions & 0 deletions python/google/protobuf/internal/decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import unittest

from google.protobuf import message
from google.protobuf.internal import api_implementation
from google.protobuf.internal import decoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import wire_format

Expand Down Expand Up @@ -104,6 +106,28 @@ def test_decode_unknown_mismatched_end_group_nested(self):
wire_format.WIRETYPE_START_GROUP,
)

def test_decode_message_set_unknown_mismatched_end_group(self):
proto = message_set_extensions_pb2.TestMessageSet()
self.assertRaisesRegex(
message.DecodeError,
'Unexpected end-group tag.'
if api_implementation.Type() == 'python'
else '.*',
proto.ParseFromString,
b'\013\054\014',
)

def test_unknown_message_set_decoder_mismatched_end_group(self):
# This behavior isn't actually reachable in practice, but it's good to
# test anyway.
decode = decoder.UnknownMessageSetItemDecoder()
self.assertRaisesRegex(
message.DecodeError,
'Unexpected end-group tag.',
decode,
memoryview(b'\054\014'),
)


if __name__ == '__main__':
unittest.main()
17 changes: 4 additions & 13 deletions python/google/protobuf/internal/python_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,8 +1192,6 @@ def MergeFromString(self, serialized):
return length # Return this for legacy reasons.
cls.MergeFromString = MergeFromString

local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
fields_by_tag = cls._fields_by_tag
message_set_decoders_by_tag = cls._message_set_decoders_by_tag

Expand All @@ -1215,7 +1213,7 @@ def InternalParse(self, buffer, pos, end):
self._Modified()
field_dict = self._fields
while pos != end:
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
(tag_bytes, new_pos) = decoder.ReadTag(buffer, pos)
field_decoder, field_des = message_set_decoders_by_tag.get(
tag_bytes, (None, None)
)
Expand All @@ -1226,24 +1224,17 @@ def InternalParse(self, buffer, pos, end):
if field_des is None:
if not self._unknown_fields: # pylint: disable=protected-access
self._unknown_fields = [] # pylint: disable=protected-access
# pylint: disable=protected-access
(tag, _) = decoder._DecodeVarint(tag_bytes, 0)
field_number, wire_type = wire_format.UnpackTag(tag)
field_number, wire_type = decoder.DecodeTag(tag_bytes)
if field_number == 0:
raise message_mod.DecodeError('Field number 0 is illegal.')
# TODO: remove old_pos.
old_pos = new_pos
(data, new_pos) = decoder._DecodeUnknownField(
buffer, new_pos, end, field_number, wire_type
) # pylint: disable=protected-access
if new_pos == -1:
return pos
# TODO: remove _unknown_fields.
new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
if new_pos == -1:
return pos
self._unknown_fields.append(
(tag_bytes, buffer[old_pos:new_pos].tobytes()))
(tag_bytes, buffer[pos + len(tag_bytes) : new_pos].tobytes())
)
pos = new_pos
else:
_MaybeAddDecoder(cls, field_des)
Expand Down
4 changes: 1 addition & 3 deletions python/google/protobuf/unknown_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def InternalAdd(field_number, wire_type, data):
InternalAdd(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED, data)
else:
for tag_bytes, buffer in unknown_fields:
# pylint: disable=protected-access
(tag, _) = decoder._DecodeVarint(tag_bytes, 0)
field_number, wire_type = wire_format.UnpackTag(tag)
field_number, wire_type = decoder.DecodeTag(tag_bytes)
if field_number == 0:
raise RuntimeError('Field number 0 is illegal.')
(data, _) = decoder._DecodeUnknownField(
Expand Down
Loading