@@ -172,7 +172,10 @@ def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default,
172
172
clear_if_default = False ):
173
173
if is_packed :
174
174
local_DecodeVarint = _DecodeVarint
175
- def DecodePackedField (buffer , pos , end , message , field_dict ):
175
+ def DecodePackedField (
176
+ buffer , pos , end , message , field_dict , current_depth = 0
177
+ ):
178
+ del current_depth # unused
176
179
value = field_dict .get (key )
177
180
if value is None :
178
181
value = field_dict .setdefault (key , new_default (message ))
@@ -191,7 +194,10 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
191
194
elif is_repeated :
192
195
tag_bytes = encoder .TagBytes (field_number , wire_type )
193
196
tag_len = len (tag_bytes )
194
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
197
+ def DecodeRepeatedField (
198
+ buffer , pos , end , message , field_dict , current_depth = 0
199
+ ):
200
+ del current_depth # unused
195
201
value = field_dict .get (key )
196
202
if value is None :
197
203
value = field_dict .setdefault (key , new_default (message ))
@@ -208,7 +214,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
208
214
return new_pos
209
215
return DecodeRepeatedField
210
216
else :
211
- def DecodeField (buffer , pos , end , message , field_dict ):
217
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
218
+ del current_depth # unused
212
219
(new_value , pos ) = decode_value (buffer , pos )
213
220
if pos > end :
214
221
raise _DecodeError ('Truncated message.' )
@@ -352,7 +359,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default,
352
359
enum_type = key .enum_type
353
360
if is_packed :
354
361
local_DecodeVarint = _DecodeVarint
355
- def DecodePackedField (buffer , pos , end , message , field_dict ):
362
+ def DecodePackedField (
363
+ buffer , pos , end , message , field_dict , current_depth = 0
364
+ ):
356
365
"""Decode serialized packed enum to its value and a new position.
357
366
358
367
Args:
@@ -365,6 +374,7 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
365
374
Returns:
366
375
int, new position in serialized data.
367
376
"""
377
+ del current_depth # unused
368
378
value = field_dict .get (key )
369
379
if value is None :
370
380
value = field_dict .setdefault (key , new_default (message ))
@@ -405,7 +415,9 @@ def DecodePackedField(buffer, pos, end, message, field_dict):
405
415
elif is_repeated :
406
416
tag_bytes = encoder .TagBytes (field_number , wire_format .WIRETYPE_VARINT )
407
417
tag_len = len (tag_bytes )
408
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
418
+ def DecodeRepeatedField (
419
+ buffer , pos , end , message , field_dict , current_depth = 0
420
+ ):
409
421
"""Decode serialized repeated enum to its value and a new position.
410
422
411
423
Args:
@@ -418,6 +430,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
418
430
Returns:
419
431
int, new position in serialized data.
420
432
"""
433
+ del current_depth # unused
421
434
value = field_dict .get (key )
422
435
if value is None :
423
436
value = field_dict .setdefault (key , new_default (message ))
@@ -446,7 +459,7 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
446
459
return new_pos
447
460
return DecodeRepeatedField
448
461
else :
449
- def DecodeField (buffer , pos , end , message , field_dict ):
462
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
450
463
"""Decode serialized repeated enum to its value and a new position.
451
464
452
465
Args:
@@ -459,6 +472,7 @@ def DecodeField(buffer, pos, end, message, field_dict):
459
472
Returns:
460
473
int, new position in serialized data.
461
474
"""
475
+ del current_depth # unused
462
476
value_start_pos = pos
463
477
(enum_value , pos ) = _DecodeSignedVarint32 (buffer , pos )
464
478
if pos > end :
@@ -540,7 +554,10 @@ def _ConvertToUnicode(memview):
540
554
tag_bytes = encoder .TagBytes (field_number ,
541
555
wire_format .WIRETYPE_LENGTH_DELIMITED )
542
556
tag_len = len (tag_bytes )
543
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
557
+ def DecodeRepeatedField (
558
+ buffer , pos , end , message , field_dict , current_depth = 0
559
+ ):
560
+ del current_depth # unused
544
561
value = field_dict .get (key )
545
562
if value is None :
546
563
value = field_dict .setdefault (key , new_default (message ))
@@ -557,7 +574,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
557
574
return new_pos
558
575
return DecodeRepeatedField
559
576
else :
560
- def DecodeField (buffer , pos , end , message , field_dict ):
577
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
578
+ del current_depth # unused
561
579
(size , pos ) = local_DecodeVarint (buffer , pos )
562
580
new_pos = pos + size
563
581
if new_pos > end :
@@ -581,7 +599,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default,
581
599
tag_bytes = encoder .TagBytes (field_number ,
582
600
wire_format .WIRETYPE_LENGTH_DELIMITED )
583
601
tag_len = len (tag_bytes )
584
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
602
+ def DecodeRepeatedField (
603
+ buffer , pos , end , message , field_dict , current_depth = 0
604
+ ):
605
+ del current_depth # unused
585
606
value = field_dict .get (key )
586
607
if value is None :
587
608
value = field_dict .setdefault (key , new_default (message ))
@@ -598,7 +619,8 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
598
619
return new_pos
599
620
return DecodeRepeatedField
600
621
else :
601
- def DecodeField (buffer , pos , end , message , field_dict ):
622
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
623
+ del current_depth # unused
602
624
(size , pos ) = local_DecodeVarint (buffer , pos )
603
625
new_pos = pos + size
604
626
if new_pos > end :
@@ -623,7 +645,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
623
645
tag_bytes = encoder .TagBytes (field_number ,
624
646
wire_format .WIRETYPE_START_GROUP )
625
647
tag_len = len (tag_bytes )
626
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
648
+ def DecodeRepeatedField (
649
+ buffer , pos , end , message , field_dict , current_depth = 0
650
+ ):
627
651
value = field_dict .get (key )
628
652
if value is None :
629
653
value = field_dict .setdefault (key , new_default (message ))
@@ -632,7 +656,13 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
632
656
if value is None :
633
657
value = field_dict .setdefault (key , new_default (message ))
634
658
# Read sub-message.
635
- pos = value .add ()._InternalParse (buffer , pos , end )
659
+ current_depth += 1
660
+ if current_depth > _recursion_limit :
661
+ raise _DecodeError (
662
+ 'Error parsing message: too many levels of nesting.'
663
+ )
664
+ pos = value .add ()._InternalParse (buffer , pos , end , current_depth )
665
+ current_depth -= 1
636
666
# Read end tag.
637
667
new_pos = pos + end_tag_len
638
668
if buffer [pos :new_pos ] != end_tag_bytes or new_pos > end :
@@ -644,12 +674,16 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
644
674
return new_pos
645
675
return DecodeRepeatedField
646
676
else :
647
- def DecodeField (buffer , pos , end , message , field_dict ):
677
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
648
678
value = field_dict .get (key )
649
679
if value is None :
650
680
value = field_dict .setdefault (key , new_default (message ))
651
681
# Read sub-message.
652
- pos = value ._InternalParse (buffer , pos , end )
682
+ current_depth += 1
683
+ if current_depth > _recursion_limit :
684
+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
685
+ pos = value ._InternalParse (buffer , pos , end , current_depth )
686
+ current_depth -= 1
653
687
# Read end tag.
654
688
new_pos = pos + end_tag_len
655
689
if buffer [pos :new_pos ] != end_tag_bytes or new_pos > end :
@@ -668,7 +702,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
668
702
tag_bytes = encoder .TagBytes (field_number ,
669
703
wire_format .WIRETYPE_LENGTH_DELIMITED )
670
704
tag_len = len (tag_bytes )
671
- def DecodeRepeatedField (buffer , pos , end , message , field_dict ):
705
+ def DecodeRepeatedField (
706
+ buffer , pos , end , message , field_dict , current_depth = 0
707
+ ):
672
708
value = field_dict .get (key )
673
709
if value is None :
674
710
value = field_dict .setdefault (key , new_default (message ))
@@ -679,18 +715,27 @@ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
679
715
if new_pos > end :
680
716
raise _DecodeError ('Truncated message.' )
681
717
# Read sub-message.
682
- if value .add ()._InternalParse (buffer , pos , new_pos ) != new_pos :
718
+ current_depth += 1
719
+ if current_depth > _recursion_limit :
720
+ raise _DecodeError (
721
+ 'Error parsing message: too many levels of nesting.'
722
+ )
723
+ if (
724
+ value .add ()._InternalParse (buffer , pos , new_pos , current_depth )
725
+ != new_pos
726
+ ):
683
727
# The only reason _InternalParse would return early is if it
684
728
# encountered an end-group tag.
685
729
raise _DecodeError ('Unexpected end-group tag.' )
686
730
# Predict that the next tag is another copy of the same repeated field.
731
+ current_depth -= 1
687
732
pos = new_pos + tag_len
688
733
if buffer [new_pos :pos ] != tag_bytes or new_pos == end :
689
734
# Prediction failed. Return.
690
735
return new_pos
691
736
return DecodeRepeatedField
692
737
else :
693
- def DecodeField (buffer , pos , end , message , field_dict ):
738
+ def DecodeField (buffer , pos , end , message , field_dict , current_depth = 0 ):
694
739
value = field_dict .get (key )
695
740
if value is None :
696
741
value = field_dict .setdefault (key , new_default (message ))
@@ -699,11 +744,14 @@ def DecodeField(buffer, pos, end, message, field_dict):
699
744
new_pos = pos + size
700
745
if new_pos > end :
701
746
raise _DecodeError ('Truncated message.' )
702
- # Read sub-message.
703
- if value ._InternalParse (buffer , pos , new_pos ) != new_pos :
747
+ current_depth += 1
748
+ if current_depth > _recursion_limit :
749
+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
750
+ if value ._InternalParse (buffer , pos , new_pos , current_depth ) != new_pos :
704
751
# The only reason _InternalParse would return early is if it encountered
705
752
# an end-group tag.
706
753
raise _DecodeError ('Unexpected end-group tag.' )
754
+ current_depth -= 1
707
755
return new_pos
708
756
return DecodeField
709
757
@@ -859,7 +907,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
859
907
# Can't read _concrete_class yet; might not be initialized.
860
908
message_type = field_descriptor .message_type
861
909
862
- def DecodeMap (buffer , pos , end , message , field_dict ):
910
+ def DecodeMap (buffer , pos , end , message , field_dict , current_depth = 0 ):
911
+ del current_depth # unused
863
912
submsg = message_type ._concrete_class ()
864
913
value = field_dict .get (key )
865
914
if value is None :
@@ -941,8 +990,16 @@ def _SkipGroup(buffer, pos, end):
941
990
return pos
942
991
pos = new_pos
943
992
993
+ DEFAULT_RECURSION_LIMIT = 100
994
+ _recursion_limit = DEFAULT_RECURSION_LIMIT
995
+
996
+
997
+ def SetRecursionLimit (new_limit ):
998
+ global _recursion_limit
999
+ _recursion_limit = new_limit
1000
+
944
1001
945
- def _DecodeUnknownFieldSet (buffer , pos , end_pos = None ):
1002
+ def _DecodeUnknownFieldSet (buffer , pos , end_pos = None , current_depth = 0 ):
946
1003
"""Decode UnknownFieldSet. Returns the UnknownFieldSet and new position."""
947
1004
948
1005
unknown_field_set = containers .UnknownFieldSet ()
@@ -952,14 +1009,14 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
952
1009
field_number , wire_type = wire_format .UnpackTag (tag )
953
1010
if wire_type == wire_format .WIRETYPE_END_GROUP :
954
1011
break
955
- (data , pos ) = _DecodeUnknownField (buffer , pos , wire_type )
1012
+ (data , pos ) = _DecodeUnknownField (buffer , pos , wire_type , current_depth )
956
1013
# pylint: disable=protected-access
957
1014
unknown_field_set ._add (field_number , wire_type , data )
958
1015
959
1016
return (unknown_field_set , pos )
960
1017
961
1018
962
- def _DecodeUnknownField (buffer , pos , wire_type ):
1019
+ def _DecodeUnknownField (buffer , pos , wire_type , current_depth = 0 ):
963
1020
"""Decode a unknown field. Returns the UnknownField and new position."""
964
1021
965
1022
if wire_type == wire_format .WIRETYPE_VARINT :
@@ -973,7 +1030,12 @@ def _DecodeUnknownField(buffer, pos, wire_type):
973
1030
data = buffer [pos :pos + size ].tobytes ()
974
1031
pos += size
975
1032
elif wire_type == wire_format .WIRETYPE_START_GROUP :
976
- (data , pos ) = _DecodeUnknownFieldSet (buffer , pos )
1033
+ print ("MMP " + str (current_depth ))
1034
+ current_depth += 1
1035
+ if current_depth >= _recursion_limit :
1036
+ raise _DecodeError ('Error parsing message: too many levels of nesting.' )
1037
+ (data , pos ) = _DecodeUnknownFieldSet (buffer , pos , None , current_depth )
1038
+ current_depth -= 1
977
1039
elif wire_type == wire_format .WIRETYPE_END_GROUP :
978
1040
return (0 , - 1 )
979
1041
else :
0 commit comments