@@ -312,7 +312,7 @@ def __init__(
312
312
context_type : Type [TrainPipelineContext ] = TrainPipelineContext ,
313
313
pipeline_preproc : bool = False ,
314
314
custom_model_fwd : Optional [
315
- Callable [[In ], Tuple [torch .Tensor , List [ torch . Tensor ] ]]
315
+ Callable [[Optional [ In ]] , Tuple [torch .Tensor , Out ]]
316
316
] = None ,
317
317
) -> None :
318
318
self ._model = model
@@ -366,6 +366,10 @@ def __init__(
366
366
self ._dataloader_exhausted : bool = False
367
367
self ._context_type : Type [TrainPipelineContext ] = context_type
368
368
369
+ self ._model_fwd : Callable [[Optional [In ]], Tuple [torch .Tensor , Out ]] = (
370
+ custom_model_fwd if custom_model_fwd else model
371
+ )
372
+
369
373
# DEPRECATED FIELDS
370
374
self ._batch_i : Optional [In ] = None
371
375
self ._batch_ip1 : Optional [In ] = None
@@ -483,9 +487,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
483
487
484
488
# forward
485
489
with record_function ("## forward ##" ):
486
- losses , output = cast (
487
- Tuple [torch .Tensor , Out ], self ._model (self .batches [0 ])
488
- )
490
+ losses , output = self ._model_fwd (self .batches [0 ])
489
491
490
492
if len (self .batches ) >= 2 :
491
493
self .wait_sparse_data_dist (self .contexts [1 ])
@@ -718,7 +720,7 @@ def __init__(
718
720
stash_gradients : bool = False ,
719
721
pipeline_preproc : bool = False ,
720
722
custom_model_fwd : Optional [
721
- Callable [[In ], Tuple [torch .Tensor , List [ torch . Tensor ] ]]
723
+ Callable [[Optional [ In ]] , Tuple [torch .Tensor , Out ]]
722
724
] = None ,
723
725
) -> None :
724
726
super ().__init__ (
@@ -729,6 +731,7 @@ def __init__(
729
731
apply_jit = apply_jit ,
730
732
context_type = EmbeddingTrainPipelineContext ,
731
733
pipeline_preproc = pipeline_preproc ,
734
+ custom_model_fwd = custom_model_fwd ,
732
735
)
733
736
self ._start_batch = start_batch
734
737
self ._stash_gradients = stash_gradients
@@ -752,9 +755,6 @@ def __init__(
752
755
self ._embedding_odd_streams : List [Optional [torch .Stream ]] = []
753
756
self ._embedding_even_streams : List [Optional [torch .Stream ]] = []
754
757
self ._gradients : Dict [str , torch .Tensor ] = {}
755
- self ._model_fwd : Union [
756
- torch .nn .Module , Callable [[In ], Tuple [torch .Tensor , List [torch .Tensor ]]]
757
- ] = (custom_model_fwd if custom_model_fwd is not None else model )
758
758
759
759
def _grad_swap (self ) -> None :
760
760
for name , param in self ._model .named_parameters ():
@@ -893,7 +893,7 @@ def _mlp_forward(
893
893
_wait_for_events (
894
894
batch , context , torch .get_device_module (self ._device ).current_stream ()
895
895
)
896
- return cast ( Tuple [ torch . Tensor , Out ], self ._model_fwd (batch ) )
896
+ return self ._model_fwd (batch )
897
897
898
898
def embedding_backward (self , context : EmbeddingTrainPipelineContext ) -> None :
899
899
default_stream = torch .get_device_module (self ._device ).current_stream ()
@@ -1020,6 +1020,10 @@ def __init__(
1020
1020
device : torch .device ,
1021
1021
execute_all_batches : bool = True ,
1022
1022
apply_jit : bool = False ,
1023
+ pipeline_preproc : bool = False ,
1024
+ custom_model_fwd : Optional [
1025
+ Callable [[Optional [In ]], Tuple [torch .Tensor , Out ]]
1026
+ ] = None ,
1023
1027
) -> None :
1024
1028
super ().__init__ (
1025
1029
model = model ,
@@ -1028,6 +1032,8 @@ def __init__(
1028
1032
execute_all_batches = execute_all_batches ,
1029
1033
apply_jit = apply_jit ,
1030
1034
context_type = PrefetchTrainPipelineContext ,
1035
+ pipeline_preproc = pipeline_preproc ,
1036
+ custom_model_fwd = custom_model_fwd ,
1031
1037
)
1032
1038
self ._context = PrefetchTrainPipelineContext (version = 0 )
1033
1039
self ._prefetch_stream : Optional [torch .Stream ] = (
@@ -1084,7 +1090,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
1084
1090
self ._wait_sparse_data_dist ()
1085
1091
# forward
1086
1092
with record_function ("## forward ##" ):
1087
- losses , output = cast ( Tuple [ torch . Tensor , Out ], self ._model (self ._batch_i ) )
1093
+ losses , output = self ._model_fwd (self ._batch_i )
1088
1094
1089
1095
self ._prefetch (self ._batch_ip1 )
1090
1096
0 commit comments