Skip to content

Commit d31100c

Browse files
committed
Manually backport recursion limit enforcement to 25.x
1 parent 88a3b90 commit d31100c

File tree

6 files changed

+176
-27
lines changed

6 files changed

+176
-27
lines changed

python/build_targets.bzl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,12 @@ def build_targets(name):
335335
data = ["//src/google/protobuf:testdata"],
336336
)
337337

338+
internal_py_test(
339+
name = "decoder_test",
340+
srcs = ["google/protobuf/internal/decoder_test.py"],
341+
data = ["//src/google/protobuf:testdata"],
342+
)
343+
338344
internal_py_test(
339345
name = "proto_builder_test",
340346
srcs = ["google/protobuf/internal/proto_builder_test.py"],

python/google/protobuf/internal/decoder.py

Lines changed: 86 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,10 @@ def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
172172
clear_if_default=False):
173173
if is_packed:
174174
local_DecodeVarint = _DecodeVarint
175-
def DecodePackedField(buffer, pos, end, message, field_dict):
175+
def DecodePackedField(
176+
buffer, pos, end, message, field_dict, current_depth=0
177+
):
178+
del current_depth # unused
176179
value = field_dict.get(key)
177180
if value is None:
178181
value = field_dict.setdefault(key, new_default(message))
@@ -191,7 +194,10 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
191194
elif is_repeated:
192195
tag_bytes = encoder.TagBytes(field_number, wire_type)
193196
tag_len = len(tag_bytes)
194-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
197+
def DecodeRepeatedField(
198+
buffer, pos, end, message, field_dict, current_depth=0
199+
):
200+
del current_depth # unused
195201
value = field_dict.get(key)
196202
if value is None:
197203
value = field_dict.setdefault(key, new_default(message))
@@ -208,7 +214,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
208214
return new_pos
209215
return DecodeRepeatedField
210216
else:
211-
def DecodeField(buffer, pos, end, message, field_dict):
217+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
218+
del current_depth # unused
212219
(new_value, pos) = decode_value(buffer, pos)
213220
if pos > end:
214221
raise _DecodeError('Truncated message.')
@@ -352,7 +359,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
352359
enum_type = key.enum_type
353360
if is_packed:
354361
local_DecodeVarint = _DecodeVarint
355-
def DecodePackedField(buffer, pos, end, message, field_dict):
362+
def DecodePackedField(
363+
buffer, pos, end, message, field_dict, current_depth=0
364+
):
356365
"""Decode serialized packed enum to its value and a new position.
357366
358367
Args:
@@ -365,6 +374,7 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
365374
Returns:
366375
int, new position in serialized data.
367376
"""
377+
del current_depth # unused
368378
value = field_dict.get(key)
369379
if value is None:
370380
value = field_dict.setdefault(key, new_default(message))
@@ -405,7 +415,9 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
405415
elif is_repeated:
406416
tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
407417
tag_len = len(tag_bytes)
408-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
418+
def DecodeRepeatedField(
419+
buffer, pos, end, message, field_dict, current_depth=0
420+
):
409421
"""Decode serialized repeated enum to its value and a new position.
410422
411423
Args:
@@ -418,6 +430,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
418430
Returns:
419431
int, new position in serialized data.
420432
"""
433+
del current_depth # unused
421434
value = field_dict.get(key)
422435
if value is None:
423436
value = field_dict.setdefault(key, new_default(message))
@@ -446,7 +459,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
446459
return new_pos
447460
return DecodeRepeatedField
448461
else:
449-
def DecodeField(buffer, pos, end, message, field_dict):
462+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
450463
"""Decode serialized repeated enum to its value and a new position.
451464
452465
Args:
@@ -459,6 +472,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
459472
Returns:
460473
int, new position in serialized data.
461474
"""
475+
del current_depth # unused
462476
value_start_pos = pos
463477
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
464478
if pos > end:
@@ -540,7 +554,10 @@ def _ConvertToUnicode(memview):
540554
tag_bytes = encoder.TagBytes(field_number,
541555
wire_format.WIRETYPE_LENGTH_DELIMITED)
542556
tag_len = len(tag_bytes)
543-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
557+
def DecodeRepeatedField(
558+
buffer, pos, end, message, field_dict, current_depth=0
559+
):
560+
del current_depth # unused
544561
value = field_dict.get(key)
545562
if value is None:
546563
value = field_dict.setdefault(key, new_default(message))
@@ -557,7 +574,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
557574
return new_pos
558575
return DecodeRepeatedField
559576
else:
560-
def DecodeField(buffer, pos, end, message, field_dict):
577+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
578+
del current_depth # unused
561579
(size, pos) = local_DecodeVarint(buffer, pos)
562580
new_pos = pos + size
563581
if new_pos > end:
@@ -581,7 +599,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
581599
tag_bytes = encoder.TagBytes(field_number,
582600
wire_format.WIRETYPE_LENGTH_DELIMITED)
583601
tag_len = len(tag_bytes)
584-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
602+
def DecodeRepeatedField(
603+
buffer, pos, end, message, field_dict, current_depth=0
604+
):
605+
del current_depth # unused
585606
value = field_dict.get(key)
586607
if value is None:
587608
value = field_dict.setdefault(key, new_default(message))
@@ -598,7 +619,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
598619
return new_pos
599620
return DecodeRepeatedField
600621
else:
601-
def DecodeField(buffer, pos, end, message, field_dict):
622+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
623+
del current_depth # unused
602624
(size, pos) = local_DecodeVarint(buffer, pos)
603625
new_pos = pos + size
604626
if new_pos > end:
@@ -623,7 +645,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
623645
tag_bytes = encoder.TagBytes(field_number,
624646
wire_format.WIRETYPE_START_GROUP)
625647
tag_len = len(tag_bytes)
626-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
648+
def DecodeRepeatedField(
649+
buffer, pos, end, message, field_dict, current_depth=0
650+
):
627651
value = field_dict.get(key)
628652
if value is None:
629653
value = field_dict.setdefault(key, new_default(message))
@@ -632,7 +656,13 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
632656
if value is None:
633657
value = field_dict.setdefault(key, new_default(message))
634658
# Read sub-message.
635-
pos = value.add()._InternalParse(buffer, pos, end)
659+
current_depth += 1
660+
if current_depth > _recursion_limit:
661+
raise _DecodeError(
662+
'Error parsing message: too many levels of nesting.'
663+
)
664+
pos = value.add()._InternalParse(buffer, pos, end, current_depth)
665+
current_depth -= 1
636666
# Read end tag.
637667
new_pos = pos+end_tag_len
638668
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -644,12 +674,16 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
644674
return new_pos
645675
return DecodeRepeatedField
646676
else:
647-
def DecodeField(buffer, pos, end, message, field_dict):
677+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
648678
value = field_dict.get(key)
649679
if value is None:
650680
value = field_dict.setdefault(key, new_default(message))
651681
# Read sub-message.
652-
pos = value._InternalParse(buffer, pos, end)
682+
current_depth += 1
683+
if current_depth > _recursion_limit:
684+
raise _DecodeError('Error parsing message: too many levels of nesting.')
685+
pos = value._InternalParse(buffer, pos, end, current_depth)
686+
current_depth -= 1
653687
# Read end tag.
654688
new_pos = pos+end_tag_len
655689
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -668,7 +702,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
668702
tag_bytes = encoder.TagBytes(field_number,
669703
wire_format.WIRETYPE_LENGTH_DELIMITED)
670704
tag_len = len(tag_bytes)
671-
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
705+
def DecodeRepeatedField(
706+
buffer, pos, end, message, field_dict, current_depth=0
707+
):
672708
value = field_dict.get(key)
673709
if value is None:
674710
value = field_dict.setdefault(key, new_default(message))
@@ -679,18 +715,27 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
679715
if new_pos > end:
680716
raise _DecodeError('Truncated message.')
681717
# Read sub-message.
682-
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
718+
current_depth += 1
719+
if current_depth > _recursion_limit:
720+
raise _DecodeError(
721+
'Error parsing message: too many levels of nesting.'
722+
)
723+
if (
724+
value.add()._InternalParse(buffer, pos, new_pos, current_depth)
725+
!= new_pos
726+
):
683727
# The only reason _InternalParse would return early is if it
684728
# encountered an end-group tag.
685729
raise _DecodeError('Unexpected end-group tag.')
686730
# Predict that the next tag is another copy of the same repeated field.
731+
current_depth -= 1
687732
pos = new_pos + tag_len
688733
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
689734
# Prediction failed. Return.
690735
return new_pos
691736
return DecodeRepeatedField
692737
else:
693-
def DecodeField(buffer, pos, end, message, field_dict):
738+
def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
694739
value = field_dict.get(key)
695740
if value is None:
696741
value = field_dict.setdefault(key, new_default(message))
@@ -699,11 +744,14 @@ def DecodeField(buffer, pos, end, message, field_dict):
699744
new_pos = pos + size
700745
if new_pos > end:
701746
raise _DecodeError('Truncated message.')
702-
# Read sub-message.
703-
if value._InternalParse(buffer, pos, new_pos) != new_pos:
747+
current_depth += 1
748+
if current_depth > _recursion_limit:
749+
raise _DecodeError('Error parsing message: too many levels of nesting.')
750+
if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
704751
# The only reason _InternalParse would return early is if it encountered
705752
# an end-group tag.
706753
raise _DecodeError('Unexpected end-group tag.')
754+
current_depth -= 1
707755
return new_pos
708756
return DecodeField
709757

@@ -859,7 +907,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
859907
# Can't read _concrete_class yet; might not be initialized.
860908
message_type = field_descriptor.message_type
861909

862-
def DecodeMap(buffer, pos, end, message, field_dict):
910+
def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
911+
del current_depth # unused
863912
submsg = message_type._concrete_class()
864913
value = field_dict.get(key)
865914
if value is None:
@@ -941,8 +990,16 @@ def _SkipGroup(buffer, pos, end):
941990
return pos
942991
pos = new_pos
943992

993+
DEFAULT_RECURSION_LIMIT = 100
994+
_recursion_limit = DEFAULT_RECURSION_LIMIT
995+
996+
997+
def SetRecursionLimit(new_limit):
998+
global _recursion_limit
999+
_recursion_limit = new_limit
1000+
9441001

945-
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
1002+
def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
9461003
"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
9471004

9481005
unknown_field_set = containers.UnknownFieldSet()
@@ -952,14 +1009,14 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
9521009
field_number, wire_type = wire_format.UnpackTag(tag)
9531010
if wire_type == wire_format.WIRETYPE_END_GROUP:
9541011
break
955-
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
1012+
(data, pos) = _DecodeUnknownField(buffer, pos, wire_type, current_depth)
9561013
# pylint: disable=protected-access
9571014
unknown_field_set._add(field_number, wire_type, data)
9581015

9591016
return (unknown_field_set, pos)
9601017

9611018

962-
def _DecodeUnknownField(buffer, pos, wire_type):
1019+
def _DecodeUnknownField(buffer, pos, wire_type, current_depth=0):
9631020
"""Decode a unknown field. Returns the UnknownField and new position."""
9641021

9651022
if wire_type == wire_format.WIRETYPE_VARINT:
@@ -973,7 +1030,12 @@ def _DecodeUnknownField(buffer, pos, wire_type):
9731030
data = buffer[pos:pos+size].tobytes()
9741031
pos += size
9751032
elif wire_type == wire_format.WIRETYPE_START_GROUP:
976-
(data, pos) = _DecodeUnknownFieldSet(buffer, pos)
1033+
print("MMP " + str(current_depth))
1034+
current_depth += 1
1035+
if current_depth >= _recursion_limit:
1036+
raise _DecodeError('Error parsing message: too many levels of nesting.')
1037+
(data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
1038+
current_depth -= 1
9771039
elif wire_type == wire_format.WIRETYPE_END_GROUP:
9781040
return (0, -1)
9791041
else:
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# -*- coding: utf-8 -*-
2+
# Protocol Buffers - Google's data interchange format
3+
# Copyright 2008 Google Inc. All rights reserved.
4+
#
5+
# Use of this source code is governed by a BSD-style
6+
# license that can be found in the LICENSE file or at
7+
# https://developers.google.com/open-source/licenses/bsd
8+
9+
"""Test decoder."""
10+
11+
import unittest
12+
13+
from google.protobuf import message
14+
from google.protobuf.internal import decoder
15+
from google.protobuf.internal import testing_refleaks
16+
from google.protobuf.internal import wire_format
17+
18+
@testing_refleaks.TestCase
19+
class DecoderTest(unittest.TestCase):
20+
def test_decode_unknown_group_field_too_many_levels(self):
21+
data = memoryview(b'\023' * 5_000_000)
22+
self.assertRaisesRegex(
23+
message.DecodeError,
24+
'Error parsing message',
25+
decoder._DecodeUnknownField,
26+
data,
27+
1,
28+
wire_format.WIRETYPE_START_GROUP,
29+
1
30+
)
31+
32+
if __name__ == '__main__':
33+
unittest.main()

python/google/protobuf/internal/message_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030

3131
from google.protobuf.internal import api_implementation # pylint: disable=g-import-not-at-top
3232
from google.protobuf.internal import encoder
33+
from google.protobuf.internal import decoder
3334
from google.protobuf.internal import more_extensions_pb2
3435
from google.protobuf.internal import more_messages_pb2
3536
from google.protobuf.internal import packed_field_test_pb2
37+
from google.protobuf.internal import self_recursive_pb2
3638
from google.protobuf.internal import test_proto3_optional_pb2
3739
from google.protobuf.internal import test_util
3840
from google.protobuf.internal import testing_refleaks
@@ -1261,6 +1263,35 @@ def __eq__(self, other):
12611263
self.assertNotEqual(ComparesWithFoo(), m)
12621264

12631265

1266+
@testing_refleaks.TestCase
1267+
class TestRecursiveGroup(unittest.TestCase):
1268+
1269+
def _MakeRecursiveGroupMessage(self, n):
1270+
msg = self_recursive_pb2.SelfRecursive.RecursiveGroup()
1271+
sub = msg
1272+
for _ in range(n):
1273+
sub = sub.sub_group
1274+
sub.i = 1
1275+
return msg.SerializeToString()
1276+
1277+
def testRecursiveGroups(self):
1278+
recurse_msg = self_recursive_pb2.SelfRecursive.RecursiveGroup()
1279+
data = self._MakeRecursiveGroupMessage(100)
1280+
recurse_msg.ParseFromString(data)
1281+
self.assertTrue(recurse_msg.HasField('sub_group'))
1282+
1283+
def testRecursiveGroupsException(self):
1284+
if api_implementation.Type() != 'python':
1285+
api_implementation._c_module.SetAllowOversizeProtos(False)
1286+
recurse_msg = self_recursive_pb2.SelfRecursive.RecursiveGroup()
1287+
data = self._MakeRecursiveGroupMessage(300)
1288+
with self.assertRaises(message.DecodeError) as context:
1289+
recurse_msg.ParseFromString(data)
1290+
self.assertIn('Error parsing message', str(context.exception))
1291+
if api_implementation.Type() == 'python':
1292+
self.assertIn('too many levels of nesting', str(context.exception))
1293+
1294+
12641295
# Class to test proto2-only features (required, extensions, etc.)
12651296
@testing_refleaks.TestCase
12661297
class Proto2Test(unittest.TestCase):

0 commit comments

Comments
 (0)