Skip to content

Commit a3921fb

Browse files
protobuf-github-bot authored and shaod2 committed
Add recursion depth limits to pure python
PiperOrigin-RevId: 758382549
1 parent ac94456 commit a3921fb

File tree

4 files changed

+101
-4
lines changed

4 files changed

+101
-4
lines changed

python/google/protobuf/internal/decoder.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -703,7 +703,13 @@ def DecodeRepeatedField(
703703
if value is None:
704704
value = field_dict.setdefault(key, new_default(message))
705705
# Read sub-message.
706+
current_depth += 1
707+
if current_depth > _recursion_limit:
708+
raise _DecodeError(
709+
'Error parsing message: too many levels of nesting.'
710+
)
706711
pos = value.add()._InternalParse(buffer, pos, end, current_depth)
712+
current_depth -= 1
707713
# Read end tag.
708714
new_pos = pos+end_tag_len
709715
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -722,7 +728,11 @@ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
722728
if value is None:
723729
value = field_dict.setdefault(key, new_default(message))
724730
# Read sub-message.
731+
current_depth += 1
732+
if current_depth > _recursion_limit:
733+
raise _DecodeError('Error parsing message: too many levels of nesting.')
725734
pos = value._InternalParse(buffer, pos, end, current_depth)
735+
current_depth -= 1
726736
# Read end tag.
727737
new_pos = pos+end_tag_len
728738
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -755,13 +765,19 @@ def DecodeRepeatedField(
755765
if new_pos > end:
756766
raise _DecodeError('Truncated message.')
757767
# Read sub-message.
768+
current_depth += 1
769+
if current_depth > _recursion_limit:
770+
raise _DecodeError(
771+
'Error parsing message: too many levels of nesting.'
772+
)
758773
if (
759774
value.add()._InternalParse(buffer, pos, new_pos, current_depth)
760775
!= new_pos
761776
):
762777
# The only reason _InternalParse would return early is if it
763778
# encountered an end-group tag.
764779
raise _DecodeError('Unexpected end-group tag.')
780+
current_depth -= 1
765781
# Predict that the next tag is another copy of the same repeated field.
766782
pos = new_pos + tag_len
767783
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -781,10 +797,14 @@ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
781797
if new_pos > end:
782798
raise _DecodeError('Truncated message.')
783799
# Read sub-message.
800+
current_depth += 1
801+
if current_depth > _recursion_limit:
802+
raise _DecodeError('Error parsing message: too many levels of nesting.')
784803
if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
785804
# The only reason _InternalParse would return early is if it encountered
786805
# an end-group tag.
787806
raise _DecodeError('Unexpected end-group tag.')
807+
current_depth -= 1
788808
return new_pos
789809

790810
return DecodeField
@@ -980,6 +1000,13 @@ def _DecodeFixed32(buffer, pos):
9801000

9811001
new_pos = pos + 4
9821002
return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
1003+
DEFAULT_RECURSION_LIMIT = 100
1004+
_recursion_limit = DEFAULT_RECURSION_LIMIT
1005+
1006+
1007+
def SetRecursionLimit(new_limit):
1008+
global _recursion_limit
1009+
_recursion_limit = new_limit
9831010

9841011

9851012
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
@@ -1020,7 +1047,11 @@ def _DecodeUnknownField(
10201047
end_tag_bytes = encoder.TagBytes(
10211048
field_number, wire_format.WIRETYPE_END_GROUP
10221049
)
1050+
current_depth += 1
1051+
if current_depth >= _recursion_limit:
1052+
raise _DecodeError('Error parsing message: too many levels of nesting.')
10231053
data, pos = _DecodeUnknownFieldSet(buffer, pos, end_pos, current_depth)
1054+
current_depth -= 1
10241055
# Check end tag.
10251056
if buffer[pos - len(end_tag_bytes) : pos] != end_tag_bytes:
10261057
raise _DecodeError('Missing group end tag.')

python/google/protobuf/internal/decoder_test.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,19 @@ def test_decode_unknown_group_field_nested(self):
8484
self.assertEqual(parsed[0].data[0].data[0].field_number, 3)
8585
self.assertEqual(parsed[0].data[0].data[0].data, 4)
8686

87+
def test_decode_unknown_group_field_too_many_levels(self):
88+
data = memoryview(b'\023' * 5_000_000)
89+
self.assertRaisesRegex(
90+
message.DecodeError,
91+
'Error parsing message',
92+
decoder._DecodeUnknownField,
93+
data,
94+
1,
95+
len(data),
96+
1,
97+
wire_format.WIRETYPE_START_GROUP,
98+
)
99+
87100
def test_decode_unknown_mismatched_end_group(self):
88101
self.assertRaisesRegex(
89102
message.DecodeError,

python/google/protobuf/internal/message_test.py

Lines changed: 56 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
from google.protobuf.internal import more_extensions_pb2
3737
from google.protobuf.internal import more_messages_pb2
3838
from google.protobuf.internal import packed_field_test_pb2
39+
from google.protobuf.internal import self_recursive_pb2
3940
from google.protobuf.internal import test_proto3_optional_pb2
4041
from google.protobuf.internal import test_util
4142
from google.protobuf.internal import testing_refleaks
@@ -1431,6 +1432,52 @@ def testMessageClassName(self, message_module):
14311432
)
14321433

14331434

1435+
@testing_refleaks.TestCase
1436+
class TestRecursiveGroup(unittest.TestCase):
1437+
1438+
def _MakeRecursiveGroupMessage(self, n):
1439+
msg = self_recursive_pb2.SelfRecursive()
1440+
sub = msg
1441+
for _ in range(n):
1442+
sub = sub.sub_group
1443+
sub.i = 1
1444+
return msg.SerializeToString()
1445+
1446+
def testRecursiveGroups(self):
1447+
recurse_msg = self_recursive_pb2.SelfRecursive()
1448+
data = self._MakeRecursiveGroupMessage(100)
1449+
recurse_msg.ParseFromString(data)
1450+
self.assertTrue(recurse_msg.HasField('sub_group'))
1451+
1452+
def testRecursiveGroupsException(self):
1453+
if api_implementation.Type() != 'python':
1454+
api_implementation._c_module.SetAllowOversizeProtos(False)
1455+
recurse_msg = self_recursive_pb2.SelfRecursive()
1456+
data = self._MakeRecursiveGroupMessage(300)
1457+
with self.assertRaises(message.DecodeError) as context:
1458+
recurse_msg.ParseFromString(data)
1459+
self.assertIn('Error parsing message', str(context.exception))
1460+
if api_implementation.Type() == 'python':
1461+
self.assertIn('too many levels of nesting', str(context.exception))
1462+
1463+
def testRecursiveGroupsUnknownFields(self):
1464+
if api_implementation.Type() != 'python':
1465+
api_implementation._c_module.SetAllowOversizeProtos(False)
1466+
test_msg = unittest_pb2.TestAllTypes()
1467+
data = self._MakeRecursiveGroupMessage(300) # unknown to test_msg
1468+
with self.assertRaises(message.DecodeError) as context:
1469+
test_msg.ParseFromString(data)
1470+
self.assertIn(
1471+
'Error parsing message',
1472+
str(context.exception),
1473+
)
1474+
if api_implementation.Type() == 'python':
1475+
self.assertIn('too many levels of nesting', str(context.exception))
1476+
decoder.SetRecursionLimit(310)
1477+
test_msg.ParseFromString(data)
1478+
decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
1479+
1480+
14341481
# Class to test proto2-only features (required, extensions, etc.)
14351482
@testing_refleaks.TestCase
14361483
class Proto2Test(unittest.TestCase):
@@ -2859,8 +2906,6 @@ def testUnpackedFields(self):
28592906
self.assertEqual(golden_data, message.SerializeToString())
28602907

28612908

2862-
@unittest.skipIf(api_implementation.Type() == 'python',
2863-
'explicit tests of the C++ implementation')
28642909
@testing_refleaks.TestCase
28652910
class OversizeProtosTest(unittest.TestCase):
28662911

@@ -2877,16 +2922,23 @@ def testSucceedOkSizedProto(self):
28772922
msg.ParseFromString(self.GenerateNestedProto(100))
28782923

28792924
def testAssertOversizeProto(self):
2880-
api_implementation._c_module.SetAllowOversizeProtos(False)
2925+
if api_implementation.Type() != 'python':
2926+
api_implementation._c_module.SetAllowOversizeProtos(False)
28812927
msg = unittest_pb2.TestRecursiveMessage()
28822928
with self.assertRaises(message.DecodeError) as context:
28832929
msg.ParseFromString(self.GenerateNestedProto(101))
28842930
self.assertIn('Error parsing message', str(context.exception))
28852931

28862932
def testSucceedOversizeProto(self):
2887-
api_implementation._c_module.SetAllowOversizeProtos(True)
2933+
2934+
if api_implementation.Type() == 'python':
2935+
decoder.SetRecursionLimit(310)
2936+
else:
2937+
api_implementation._c_module.SetAllowOversizeProtos(True)
2938+
28882939
msg = unittest_pb2.TestRecursiveMessage()
28892940
msg.ParseFromString(self.GenerateNestedProto(101))
2941+
decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
28902942

28912943

28922944
if __name__ == '__main__':

python/google/protobuf/internal/self_recursive.proto

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ package google.protobuf.python.internal;
1212
message SelfRecursive {
1313
SelfRecursive sub = 1;
1414
int32 i = 2;
15+
SelfRecursive sub_group = 3 [features.message_encoding = DELIMITED];
1516
}
1617

1718
message IndirectRecursive {

0 commit comments

Comments
 (0)