     _start_embedding_lookup,
     _to_device,
     _wait_for_batch,
-    _wait_for_event,
+    _wait_for_events,
     DataLoadingThread,
     EmbeddingPipelinedForward,
     EmbeddingTrainPipelineContext,
@@ -590,8 +590,6 @@ def start_sparse_data_dist(
             return
         with record_function(f"## start_sparse_data_dist {context.index} ##"):
             with self._stream_context(self._data_dist_stream):
-                if context.event is not None:
-                    context.event.wait()
                 _wait_for_batch(batch, self._memcpy_stream)
 
                 original_contexts = [p.get_context() for p in self._pipelined_preprocs]
@@ -737,11 +735,8 @@ def __init__(
             if device.type in ["cuda", "mtia"]
             else None
         )
-        self._bwd_sync_stream: Optional[torch.Stream] = (
-            (torch.get_device_module(self._device).Stream(priority=0))
-            if device.type in ["cuda", "mtia"]
-            else None
-        )
+        self._embedding_odd_streams: List[Optional[torch.Stream]] = []
+        self._embedding_even_streams: List[Optional[torch.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}
 
     def _grad_swap(self) -> None:
@@ -751,6 +746,29 @@ def _grad_swap(self) -> None:
             self._gradients[name] = param.grad.clone()
             param.grad = grad
 
+    def _init_embedding_streams(self) -> None:
+
+        for _ in self._pipelined_modules:
+            self._embedding_odd_streams.append(
+                (torch.get_device_module(self._device).Stream(priority=0))
+                if self._device.type in ["cuda", "mtia"]
+                else None
+            )
+            self._embedding_even_streams.append(
+                (torch.get_device_module(self._device).Stream(priority=0))
+                if self._device.type in ["cuda", "mtia"]
+                else None
+            )
+
+    def _validate_optimizer(self) -> None:
+        for pipelined_module in self._pipelined_modules:
+            pipelined_params = set(pipelined_module.parameters())
+            for group in self._optimizer.param_groups:
+                if not set(group["params"]).isdisjoint(pipelined_params):
+                    logger.warning(
+                        f"SemiSync pipelined {type(pipelined_module)} and optimizer share parameters"
+                    )
+
     def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         # pipeline is already filled
         if len(self.batches) >= 3:
@@ -770,7 +788,9 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             # pyre-ignore [6]
             EmbeddingPipelinedForward,
         )
+        self._init_embedding_streams()
         self.wait_sparse_data_dist(self.contexts[0])
+        self._validate_optimizer()
         # pyre-ignore [6]
         self.start_embedding_lookup(self.batches[0], self.contexts[0])
 
@@ -824,26 +844,25 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         self.wait_sparse_data_dist(self.contexts[2])
 
         if self._model.training:
-            # backward would put an implicit sync point in stream called from, ideally
-            # this would different from optimizer so it could start earilier, but currently not safe to do so.
-            with self._stream_context(self._overarch_stream):
-                with record_function(f"## backward {self.contexts[0].index} ##"):
-                    torch.sum(losses, dim=0).backward()
-
-            with self._stream_context(self._overarch_stream):
-                with record_function(
-                    f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
-                ):
-                    if self.is_semi_sync() and self._stash_gradients:
-                        self._grad_swap()
-                    self._mlp_optimizer_step()
-
-                with record_function(
-                    f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
-                ):
-                    self._optimizer.zero_grad()
+            with record_function(f"## backward {self.contexts[0].index} ##"):
+                torch.sum(losses, dim=0).backward()
+                # pyre-ignore [6]
+                self.embedding_backward(self.contexts[0])
+
+            with record_function(
+                f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
+            ):
+                if self.is_semi_sync() and self._stash_gradients:
+                    self._grad_swap()
+                self._mlp_optimizer_step()
+
+            with record_function(
+                f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
+            ):
+                self._optimizer.zero_grad()
 
         if len(self.batches) >= 2 and not self.is_semi_sync():
+            torch.cuda.synchronize()  # needed to avoid race condition
             # pyre-ignore [6]
             self.start_embedding_lookup(self.batches[1], self.contexts[1])
 
@@ -854,10 +873,29 @@ def _mlp_forward(
         self, batch: In, context: TrainPipelineContext
     ) -> Tuple[torch.Tensor, Out]:
         with record_function(f"## forward {context.index} ##"):
-            with self._stream_context(self._overarch_stream):
-                _wait_for_event(batch, self._device, context.event)
-                context.event = None
-                return cast(Tuple[torch.Tensor, Out], self._model(batch))
+            _wait_for_events(
+                batch, context, torch.get_device_module(self._device).current_stream()
+            )
+            return cast(Tuple[torch.Tensor, Out], self._model(batch))
+
+    def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
+        default_stream = torch.get_device_module(self._device).current_stream()
+        streams = (
+            self._embedding_even_streams
+            if cast(int, context.index) % 2 == 0
+            else self._embedding_odd_streams
+        )
+        for stream, emb_tensors, detached_emb_tensors in zip(
+            streams,
+            context.embedding_tensors,
+            context.detached_embedding_tensors,
+        ):
+            with self._stream_context(stream):
+                grads = [tensor.grad for tensor in detached_emb_tensors]
+                if stream:
+                    stream.wait_stream(default_stream)
+                # pyre-ignore
+                torch.autograd.backward(emb_tensors, grads)
 
     def copy_batch_to_gpu(
         self,
@@ -870,8 +908,9 @@ def copy_batch_to_gpu(
                 if batch is not None:
                     batch = _to_device(batch, self._device, non_blocking=True)
                 context = self._create_context()
-                context.event = torch.get_device_module(self._device).Event()
-                context.event.record()
+                event = torch.get_device_module(self._device).Event()
+                event.record()
+                context.events.append(event)
                 return batch, context
 
     def start_sparse_data_dist(
@@ -882,9 +921,25 @@ def start_sparse_data_dist(
         """
         Waits for batch to finish getting copied to GPU, then starts the input dist. This is Event based version.
         """
-        super().start_sparse_data_dist(batch, context)
-        context.event = torch.get_device_module(self._device).Event()
-        context.event.record()
+        if batch is None:
+            return
+
+        # Temporarily set context for next iter to populate cache
+        original_contexts = [p.get_context() for p in self._pipelined_preprocs]
+        for preproc_mod in self._pipelined_preprocs:
+            preproc_mod.set_context(context)
+
+        with record_function(f"## start_sparse_data_dist {context.index} ##"):
+            with self._stream_context(self._data_dist_stream):
+                _wait_for_events(batch, context, self._data_dist_stream)
+                _start_data_dist(self._pipelined_modules, batch, context)
+                event = torch.get_device_module(self._device).Event()
+                event.record()
+                context.events.append(event)
+
+        # Restore context for model forward
+        for module, context in zip(self._pipelined_preprocs, original_contexts):
+            module.set_context(context)
 
     def start_embedding_lookup(
         self,
@@ -897,17 +952,20 @@ def start_embedding_lookup(
         if batch is None:
             return
         with record_function(f"## start_embedding_lookup {context.index} ##"):
-            with self._stream_context(
-                self._embedding_even_stream
-                if cast(int, context.index) % 2 == 0
-                else self._embedding_odd_stream
-            ):
-                _wait_for_event(batch, self._device, context.event)
-                _start_embedding_lookup(
-                    self._pipelined_modules, batch, context, self._device
+            _wait_for_events(
+                batch, context, torch.get_device_module(self._device).current_stream()
+            )
+            for i, module in enumerate(self._pipelined_modules):
+                stream = (
+                    self._embedding_even_streams[i]
+                    if cast(int, context.index) % 2 == 0
+                    else self._embedding_odd_streams[i]
                 )
-                context.event = torch.get_device_module(self._device).Event()
-                context.event.record()
+                with self._stream_context(stream):
+                    _start_embedding_lookup(module, batch, context, stream)
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)
 
 
 class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
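
Note on the change above: the diff replaces the single context.event with a context.events list and runs each pipelined embedding module on alternating per-index "even"/"odd" streams, with recorded events gating downstream work. The following is a minimal, self-contained sketch of that stream/event pattern using plain torch.cuda APIs only; the function name, sizes, and the matmul stand-in are hypothetical illustrations, not torchrec's actual helpers.

import torch


def demo_alternating_streams(num_iters: int = 4) -> None:
    # Sketch: alternate work between two side streams across iterations and
    # use recorded events so the default stream only consumes results once ready.
    if not torch.cuda.is_available():
        print("CUDA not available; skipping stream demo")
        return

    device = torch.device("cuda")
    even_stream = torch.cuda.Stream(priority=0)
    odd_stream = torch.cuda.Stream(priority=0)
    default_stream = torch.cuda.current_stream()

    x = torch.randn(1024, 1024, device=device)
    for i in range(num_iters):
        stream = even_stream if i % 2 == 0 else odd_stream
        events = []
        with torch.cuda.stream(stream):
            # The side stream waits for work already queued on the default stream.
            stream.wait_stream(default_stream)
            y = x @ x  # stand-in for an embedding lookup / backward
            event = torch.cuda.Event()
            event.record()  # recorded on the current (side) stream
            events.append(event)
        for event in events:
            # The default stream waits on the recorded event before using y.
            default_stream.wait_event(event)
        x = y / y.norm()


if __name__ == "__main__":
    demo_alternating_streams()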