Skip to content

Commit 05ba1a8

Browse files
protobuf-github-botshaod2
authored andcommitted
Add recursion depth limits to pure python
PiperOrigin-RevId: 758382549
1 parent 1ef3f01 commit 05ba1a8

File tree

4 files changed

+105
-5
lines changed

4 files changed

+105
-5
lines changed

python/google/protobuf/internal/decoder.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,13 @@ def DecodeRepeatedField(
668668
if value is None:
669669
value = field_dict.setdefault(key, new_default(message))
670670
# Read sub-message.
671+
current_depth += 1
672+
if current_depth > _recursion_limit:
673+
raise _DecodeError(
674+
'Error parsing message: too many levels of nesting.'
675+
)
671676
pos = value.add()._InternalParse(buffer, pos, end, current_depth)
677+
current_depth -= 1
672678
# Read end tag.
673679
new_pos = pos+end_tag_len
674680
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -687,7 +693,11 @@ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
687693
if value is None:
688694
value = field_dict.setdefault(key, new_default(message))
689695
# Read sub-message.
696+
current_depth += 1
697+
if current_depth > _recursion_limit:
698+
raise _DecodeError('Error parsing message: too many levels of nesting.')
690699
pos = value._InternalParse(buffer, pos, end, current_depth)
700+
current_depth -= 1
691701
# Read end tag.
692702
new_pos = pos+end_tag_len
693703
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -720,13 +730,19 @@ def DecodeRepeatedField(
720730
if new_pos > end:
721731
raise _DecodeError('Truncated message.')
722732
# Read sub-message.
733+
current_depth += 1
734+
if current_depth > _recursion_limit:
735+
raise _DecodeError(
736+
'Error parsing message: too many levels of nesting.'
737+
)
723738
if (
724739
value.add()._InternalParse(buffer, pos, new_pos, current_depth)
725740
!= new_pos
726741
):
727742
# The only reason _InternalParse would return early is if it
728743
# encountered an end-group tag.
729744
raise _DecodeError('Unexpected end-group tag.')
745+
current_depth -= 1
730746
# Predict that the next tag is another copy of the same repeated field.
731747
pos = new_pos + tag_len
732748
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -746,10 +762,14 @@ def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
746762
if new_pos > end:
747763
raise _DecodeError('Truncated message.')
748764
# Read sub-message.
765+
current_depth += 1
766+
if current_depth > _recursion_limit:
767+
raise _DecodeError('Error parsing message: too many levels of nesting.')
749768
if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
750769
# The only reason _InternalParse would return early is if it encountered
751770
# an end-group tag.
752771
raise _DecodeError('Unexpected end-group tag.')
772+
current_depth -= 1
753773
return new_pos
754774

755775
return DecodeField
@@ -984,6 +1004,15 @@ def _SkipGroup(buffer, pos, end):
9841004
pos = new_pos
9851005

9861006

1007+
DEFAULT_RECURSION_LIMIT = 100
1008+
_recursion_limit = DEFAULT_RECURSION_LIMIT
1009+
1010+
1011+
def SetRecursionLimit(new_limit):
1012+
global _recursion_limit
1013+
_recursion_limit = new_limit
1014+
1015+
9871016
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
9881017
"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
9891018

@@ -1017,7 +1046,11 @@ def _DecodeUnknownField(
10171046
data = buffer[pos:pos+size].tobytes()
10181047
pos += size
10191048
elif wire_type == wire_format.WIRETYPE_START_GROUP:
1020-
(data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
1049+
current_depth += 1
1050+
if current_depth >= _recursion_limit:
1051+
raise _DecodeError('Error parsing message: too many levels of nesting.')
1052+
data, pos = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
1053+
current_depth -= 1
10211054
elif wire_type == wire_format.WIRETYPE_END_GROUP:
10221055
return (0, -1)
10231056
else:

python/google/protobuf/internal/decoder_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
import io
1212
import unittest
1313

14+
from google.protobuf import message
1415
from google.protobuf.internal import decoder
1516
from google.protobuf.internal import testing_refleaks
17+
from google.protobuf.internal import wire_format
1618

1719

1820
_INPUT_BYTES = b'\x84r\x12'
@@ -52,6 +54,18 @@ def test_decode_varint_bytesio_empty(self):
5254
size = decoder._DecodeVarint(input_io)
5355
self.assertEqual(size, None)
5456

57+
def test_decode_unknown_group_field_too_many_levels(self):
58+
data = memoryview(b'\023' * 5_000_000)
59+
self.assertRaisesRegex(
60+
message.DecodeError,
61+
'Error parsing message',
62+
decoder._DecodeUnknownField,
63+
data,
64+
1,
65+
wire_format.WIRETYPE_START_GROUP,
66+
1
67+
)
68+
5569

5670
if __name__ == '__main__':
5771
unittest.main()

python/google/protobuf/internal/message_test.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from google.protobuf.internal import more_extensions_pb2
3737
from google.protobuf.internal import more_messages_pb2
3838
from google.protobuf.internal import packed_field_test_pb2
39+
from google.protobuf.internal import self_recursive_pb2
3940
from google.protobuf.internal import test_proto3_optional_pb2
4041
from google.protobuf.internal import test_util
4142
from google.protobuf.internal import testing_refleaks
@@ -1339,6 +1340,52 @@ def testIn(self, message_module):
13391340
self.assertNotIn('oneof_string', m)
13401341

13411342

1343+
@testing_refleaks.TestCase
1344+
class TestRecursiveGroup(unittest.TestCase):
1345+
1346+
def _MakeRecursiveGroupMessage(self, n):
1347+
msg = self_recursive_pb2.SelfRecursive()
1348+
sub = msg
1349+
for _ in range(n):
1350+
sub = sub.sub_group
1351+
sub.i = 1
1352+
return msg.SerializeToString()
1353+
1354+
def testRecursiveGroups(self):
1355+
recurse_msg = self_recursive_pb2.SelfRecursive()
1356+
data = self._MakeRecursiveGroupMessage(100)
1357+
recurse_msg.ParseFromString(data)
1358+
self.assertTrue(recurse_msg.HasField('sub_group'))
1359+
1360+
def testRecursiveGroupsException(self):
1361+
if api_implementation.Type() != 'python':
1362+
api_implementation._c_module.SetAllowOversizeProtos(False)
1363+
recurse_msg = self_recursive_pb2.SelfRecursive()
1364+
data = self._MakeRecursiveGroupMessage(300)
1365+
with self.assertRaises(message.DecodeError) as context:
1366+
recurse_msg.ParseFromString(data)
1367+
self.assertIn('Error parsing message', str(context.exception))
1368+
if api_implementation.Type() == 'python':
1369+
self.assertIn('too many levels of nesting', str(context.exception))
1370+
1371+
def testRecursiveGroupsUnknownFields(self):
1372+
if api_implementation.Type() != 'python':
1373+
api_implementation._c_module.SetAllowOversizeProtos(False)
1374+
test_msg = unittest_pb2.TestAllTypes()
1375+
data = self._MakeRecursiveGroupMessage(300) # unknown to test_msg
1376+
with self.assertRaises(message.DecodeError) as context:
1377+
test_msg.ParseFromString(data)
1378+
self.assertIn(
1379+
'Error parsing message',
1380+
str(context.exception),
1381+
)
1382+
if api_implementation.Type() == 'python':
1383+
self.assertIn('too many levels of nesting', str(context.exception))
1384+
decoder.SetRecursionLimit(310)
1385+
test_msg.ParseFromString(data)
1386+
decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
1387+
1388+
13421389
# Class to test proto2-only features (required, extensions, etc.)
13431390
@testing_refleaks.TestCase
13441391
class Proto2Test(unittest.TestCase):
@@ -2728,8 +2775,6 @@ def testUnpackedFields(self):
27282775
self.assertEqual(golden_data, message.SerializeToString())
27292776

27302777

2731-
@unittest.skipIf(api_implementation.Type() == 'python',
2732-
'explicit tests of the C++ implementation')
27332778
@testing_refleaks.TestCase
27342779
class OversizeProtosTest(unittest.TestCase):
27352780

@@ -2746,16 +2791,23 @@ def testSucceedOkSizedProto(self):
27462791
msg.ParseFromString(self.GenerateNestedProto(100))
27472792

27482793
def testAssertOversizeProto(self):
2749-
api_implementation._c_module.SetAllowOversizeProtos(False)
2794+
if api_implementation.Type() != 'python':
2795+
api_implementation._c_module.SetAllowOversizeProtos(False)
27502796
msg = unittest_pb2.TestRecursiveMessage()
27512797
with self.assertRaises(message.DecodeError) as context:
27522798
msg.ParseFromString(self.GenerateNestedProto(101))
27532799
self.assertIn('Error parsing message', str(context.exception))
27542800

27552801
def testSucceedOversizeProto(self):
2756-
api_implementation._c_module.SetAllowOversizeProtos(True)
2802+
2803+
if api_implementation.Type() == 'python':
2804+
decoder.SetRecursionLimit(310)
2805+
else:
2806+
api_implementation._c_module.SetAllowOversizeProtos(True)
2807+
27572808
msg = unittest_pb2.TestRecursiveMessage()
27582809
msg.ParseFromString(self.GenerateNestedProto(101))
2810+
decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
27592811

27602812

27612813
if __name__ == '__main__':

python/google/protobuf/internal/self_recursive.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ package google.protobuf.python.internal;
1212
message SelfRecursive {
1313
SelfRecursive sub = 1;
1414
int32 i = 2;
15+
SelfRecursive sub_group = 3 [features.message_encoding = DELIMITED];
1516
}
1617

1718
message IndirectRecursive {

0 commit comments

Comments
 (0)