Skip to content

Commit 6fba1d0

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
modify the placeholder attr parser to handle dict/list types (#2181)
Summary: Pull Request resolved: #2181 # context * to handle complex ph_key for the placeholder like the following: ``` (Pdb) arg.op 'placeholder' (Pdb) arg.ph_key 'event_id_list_features_seqs[marketplace]' ``` * original workaround is to modify the `arg_info` in the `_start_data_dist` ``` (Pdb) forward.args [ArgInfo(input_attrs=['event_id_list_features_seqs[user_conv_ads_event]'], is_getitems=[False], name=None)] (Pdb) attr 'event_id_list_features_seqs[user_conv_ads_event]' ``` * according to the ph_key generation, it could be something like `A[key][idx]`. Differential Revision: D59074268
1 parent b8a1c40 commit 6fba1d0

File tree

1 file changed

+11
-4
lines changed
  • torchrec/distributed/train_pipeline

1 file changed

+11
-4
lines changed

torchrec/distributed/train_pipeline/utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -678,9 +678,16 @@ def _get_node_args_helper(
678678
if child_node.op == "placeholder":
679679
if hasattr(child_node, "ph_key"):
680680
# pyre-ignore[16]
681-
arg_info.input_attrs.insert(0, child_node.ph_key)
682-
arg_info.is_getitems.insert(0, False)
683-
arg_info.preproc_modules.insert(0, None)
681+
ph_key: str = child_node.ph_key
682+
# example: ph_key = 'event_id_list_features_seqs[marketplace]'
683+
ph_keys = ph_key.split("[")
684+
for key in ph_keys:
685+
if "]" in key:
686+
arg_info.input_attrs.append(key[:-1])
687+
arg_info.is_getitems.append(True)
688+
else:
689+
arg_info.input_attrs.append(key)
690+
arg_info.is_getitems.append(False)
684691
else:
685692
# no-op
686693
arg_info.input_attrs.insert(0, "")
@@ -1038,7 +1045,7 @@ def _rewrite_model( # noqa C901
10381045
)
10391046

10401047
if num_found == total_num_args:
1041-
logger.info(f"Module '{node.target}'' will be pipelined")
1048+
logger.info(f"Module '{node.target}' will be pipelined")
10421049
child = sharded_modules[node.target]
10431050
original_forwards.append(child.forward)
10441051
child.forward = pipelined_forward(

0 commit comments

Comments
 (0)