Skip to content

Commit e35019a

Browse files
committed
use copies for trimming
1 parent 8b1b60e commit e35019a

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

guardrails/utils/reask_utils.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from copy import deepcopy
12
from typing import Any, Dict, List, Optional, Tuple, Union
23

34
import pydantic
@@ -44,41 +45,42 @@ def gather_reasks(
4445
reasks = []
4546

4647
def _gather_reasks_in_dict(
47-
output: Dict, path: Optional[List[Union[str, int]]] = None
48+
original: Dict, output: Dict, path: Optional[List[Union[str, int]]] = None
4849
) -> None:
4950
if path is None:
5051
path = []
51-
for field, value in output.items():
52+
for field, value in original.items():
5253
if isinstance(value, FieldReAsk):
5354
value.path = path + [field]
5455
reasks.append(value)
5556
del output[field]
5657

5758
if isinstance(value, dict):
58-
_gather_reasks_in_dict(value, path + [field])
59+
_gather_reasks_in_dict(value, output[field], path + [field])
5960

6061
if isinstance(value, list):
61-
_gather_reasks_in_list(value, path + [field])
62+
_gather_reasks_in_list(value, output[field], path + [field])
6263
return
6364

6465
def _gather_reasks_in_list(
65-
output: List, path: Optional[List[Union[str, int]]] = None
66+
original: List, output: List, path: Optional[List[Union[str, int]]] = None
6667
) -> None:
6768
if path is None:
6869
path = []
69-
for idx, item in enumerate(output):
70+
for idx, item in enumerate(original):
7071
if isinstance(item, FieldReAsk):
7172
item.path = path + [idx]
7273
reasks.append(item)
7374
del output[idx]
7475
elif isinstance(item, dict):
75-
_gather_reasks_in_dict(item, path + [idx])
76+
_gather_reasks_in_dict(item, output[idx], path + [idx])
7677
elif isinstance(item, list):
77-
_gather_reasks_in_list(item, path + [idx])
78+
_gather_reasks_in_list(item, output[idx], path + [idx])
7879
return
7980

80-
_gather_reasks_in_dict(validated_output)
81-
return reasks, validated_output
81+
output_copy = deepcopy(validated_output)
82+
_gather_reasks_in_dict(validated_output, output_copy)
83+
return reasks, output_copy
8284

8385

8486
def get_pruned_tree(

0 commit comments

Comments
 (0)