Skip to content

Commit 6a0c2ba

Browse files
sarckk authored and
facebook-github-bot committed
Add custom model fwd in train pipelines (#2324)
Summary: Pull Request resolved: #2324. Add the missing pipeline_preproc and custom_model_fwd args. Reviewed By: chrisxcai. Differential Revision: D61564467.
1 parent f7e444d commit 6a0c2ba

File tree

2 files changed

+64
-10
lines changed

2 files changed

+64
-10
lines changed

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,54 @@ def test_multi_dataloader_pipelining(self) -> None:
827827
)
828828
)
829829

830+
# pyre-ignore
831+
@unittest.skipIf(
832+
not torch.cuda.is_available(),
833+
"Not enough GPUs, this test requires at least one GPU",
834+
)
835+
def test_custom_fwd(
836+
self,
837+
) -> None:
838+
data = self._generate_data(
839+
num_batches=4,
840+
batch_size=32,
841+
)
842+
dataloader = iter(data)
843+
844+
fused_params_pipelined = {}
845+
sharding_type = ShardingType.ROW_WISE.value
846+
kernel_type = EmbeddingComputeKernel.FUSED.value
847+
sharded_model_pipelined: torch.nn.Module
848+
849+
model = self._setup_model()
850+
851+
(
852+
sharded_model_pipelined,
853+
optim_pipelined,
854+
) = self._generate_sharded_model_and_optimizer(
855+
model, sharding_type, kernel_type, fused_params_pipelined
856+
)
857+
858+
def custom_model_fwd(
859+
input: Optional[ModelInput],
860+
) -> Tuple[torch.Tensor, torch.Tensor]:
861+
loss, pred = sharded_model_pipelined(input)
862+
batch_size = pred.size(0)
863+
return loss, pred.expand(batch_size * 2, -1)
864+
865+
pipeline = TrainPipelineSparseDist(
866+
model=sharded_model_pipelined,
867+
optimizer=optim_pipelined,
868+
device=self.device,
869+
execute_all_batches=True,
870+
custom_model_fwd=custom_model_fwd,
871+
)
872+
873+
for _ in data:
874+
# Forward + backward w/ pipelining
875+
pred_pipeline = pipeline.progress(dataloader)
876+
self.assertEqual(pred_pipeline.size(0), 64)
877+
830878

831879
class TrainPipelinePreprocTest(TrainPipelineSparseDistTestBase):
832880
def setUp(self) -> None:

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def __init__(
312312
context_type: Type[TrainPipelineContext] = TrainPipelineContext,
313313
pipeline_preproc: bool = False,
314314
custom_model_fwd: Optional[
315-
Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
315+
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
316316
] = None,
317317
) -> None:
318318
self._model = model
@@ -366,6 +366,10 @@ def __init__(
366366
self._dataloader_exhausted: bool = False
367367
self._context_type: Type[TrainPipelineContext] = context_type
368368

369+
self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = (
370+
custom_model_fwd if custom_model_fwd else model
371+
)
372+
369373
# DEPRECATED FIELDS
370374
self._batch_i: Optional[In] = None
371375
self._batch_ip1: Optional[In] = None
@@ -483,9 +487,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
483487

484488
# forward
485489
with record_function("## forward ##"):
486-
losses, output = cast(
487-
Tuple[torch.Tensor, Out], self._model(self.batches[0])
488-
)
490+
losses, output = self._model_fwd(self.batches[0])
489491

490492
if len(self.batches) >= 2:
491493
self.wait_sparse_data_dist(self.contexts[1])
@@ -718,7 +720,7 @@ def __init__(
718720
stash_gradients: bool = False,
719721
pipeline_preproc: bool = False,
720722
custom_model_fwd: Optional[
721-
Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
723+
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
722724
] = None,
723725
) -> None:
724726
super().__init__(
@@ -729,6 +731,7 @@ def __init__(
729731
apply_jit=apply_jit,
730732
context_type=EmbeddingTrainPipelineContext,
731733
pipeline_preproc=pipeline_preproc,
734+
custom_model_fwd=custom_model_fwd,
732735
)
733736
self._start_batch = start_batch
734737
self._stash_gradients = stash_gradients
@@ -752,9 +755,6 @@ def __init__(
752755
self._embedding_odd_streams: List[Optional[torch.Stream]] = []
753756
self._embedding_even_streams: List[Optional[torch.Stream]] = []
754757
self._gradients: Dict[str, torch.Tensor] = {}
755-
self._model_fwd: Union[
756-
torch.nn.Module, Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
757-
] = (custom_model_fwd if custom_model_fwd is not None else model)
758758

759759
def _grad_swap(self) -> None:
760760
for name, param in self._model.named_parameters():
@@ -893,7 +893,7 @@ def _mlp_forward(
893893
_wait_for_events(
894894
batch, context, torch.get_device_module(self._device).current_stream()
895895
)
896-
return cast(Tuple[torch.Tensor, Out], self._model_fwd(batch))
896+
return self._model_fwd(batch)
897897

898898
def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
899899
default_stream = torch.get_device_module(self._device).current_stream()
@@ -1020,6 +1020,10 @@ def __init__(
10201020
device: torch.device,
10211021
execute_all_batches: bool = True,
10221022
apply_jit: bool = False,
1023+
pipeline_preproc: bool = False,
1024+
custom_model_fwd: Optional[
1025+
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
1026+
] = None,
10231027
) -> None:
10241028
super().__init__(
10251029
model=model,
@@ -1028,6 +1032,8 @@ def __init__(
10281032
execute_all_batches=execute_all_batches,
10291033
apply_jit=apply_jit,
10301034
context_type=PrefetchTrainPipelineContext,
1035+
pipeline_preproc=pipeline_preproc,
1036+
custom_model_fwd=custom_model_fwd,
10311037
)
10321038
self._context = PrefetchTrainPipelineContext(version=0)
10331039
self._prefetch_stream: Optional[torch.Stream] = (
@@ -1084,7 +1090,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
10841090
self._wait_sparse_data_dist()
10851091
# forward
10861092
with record_function("## forward ##"):
1087-
losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i))
1093+
losses, output = self._model_fwd(self._batch_i)
10881094

10891095
self._prefetch(self._batch_ip1)
10901096

0 commit comments

Comments
 (0)