@@ -554,15 +554,8 @@ def __init__(
         self._stash_gradients = stash_gradients
 
         # use two data streams to support two concurrent batches
-        self._embedding_odd_stream: Optional[torch.cuda.streams.Stream] = (
-            (torch.cuda.Stream(priority=0)) if device.type == "cuda" else None
-        )
-        self._embedding_even_stream: Optional[torch.cuda.streams.Stream] = (
-            (torch.cuda.Stream(priority=0)) if device.type == "cuda" else None
-        )
-        self._overarch_stream: Optional[torch.cuda.streams.Stream] = (
-            (torch.cuda.Stream(priority=-1)) if device.type == "cuda" else None
-        )
+        self._embedding_odd_streams: List[Optional[torch.cuda.streams.Stream]] = []
+        self._embedding_even_streams: List[Optional[torch.cuda.streams.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}
 
     def _grad_swap(self) -> None:
@@ -572,6 +565,25 @@ def _grad_swap(self) -> None:
                 self._gradients[name] = param.grad.clone()
             param.grad = grad
 
+    def _init_embedding_streams(self) -> None:
+
+        for _ in self._pipelined_modules:
+            self._embedding_odd_streams.append(
+                torch.cuda.Stream(priority=0) if self._device.type == "cuda" else None
+            )
+            self._embedding_even_streams.append(
+                torch.cuda.Stream(priority=0) if self._device.type == "cuda" else None
+            )
+
+    def _validate_optimizer(self) -> None:
+        for pipelined_module in self._pipelined_modules:
+            pipelined_params = set(pipelined_module.parameters())
+            for group in self._optimizer.param_groups:
+                if not set(group["params"]).isdisjoint(pipelined_params):
+                    logger.warning(
+                        f"SemiSync pipelined {type(pipelined_module)} and MLP optimizer share parameters"
+                    )
+
     def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         # pipeline is already filled
         if len(self.batches) >= 3:
@@ -591,7 +603,9 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             # pyre-ignore [6]
             EmbeddingPipelinedForward,
         )
+        self._init_embedding_streams()
         self.wait_sparse_data_dist(self.contexts[0])
+        self._validate_optimizer()
         # pyre-ignore [6]
         self.start_embedding_lookup(self.batches[0], self.contexts[0])
 
@@ -645,25 +659,25 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             self.wait_sparse_data_dist(self.contexts[2])
 
         if self._model.training:
-            # backward would put an implicit sync point in stream called from, ideally
-            # this would different from optimizer so it could start earilier, but currently not safe to do so.
-            with torch.cuda.stream(self._overarch_stream):
-                with record_function(f"## backward {self.contexts[0].index} ##"):
-                    torch.sum(losses, dim=0).backward()
-
-                with record_function(
-                    f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
-                ):
-                    if self.is_semi_sync() and self._stash_gradients:
-                        self._grad_swap()
-                    self._mlp_optimizer_step()
+            with record_function(f"## backward {self.contexts[0].index} ##"):
+                torch.sum(losses, dim=0).backward()
+                # pyre-ignore [6]
+                self.embedding_backward(self.contexts[0])
 
-                with record_function(
-                    f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
-                ):
-                    self._optimizer.zero_grad()
+            with record_function(
+                f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
+            ):
+                if self.is_semi_sync() and self._stash_gradients:
+                    self._grad_swap()
+                self._mlp_optimizer_step()
+
+            with record_function(
+                f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
+            ):
+                self._optimizer.zero_grad()
 
         if len(self.batches) >= 2 and not self.is_semi_sync():
+            torch.cuda.synchronize()  # needed to avoid race condition
             # pyre-ignore [6]
             self.start_embedding_lookup(self.batches[1], self.contexts[1])
 
@@ -674,10 +688,26 @@ def _mlp_forward(
         self, batch: In, context: TrainPipelineContext
     ) -> Tuple[torch.Tensor, Out]:
         with record_function(f"## forward {context.index} ##"):
-            with torch.cuda.stream(self._overarch_stream):
-                _wait_for_event(batch, context.event)
-                context.event = None
-                return cast(Tuple[torch.Tensor, Out], self._model(batch))
+            _wait_for_event(batch, context.event)
+            context.event = None
+            return cast(Tuple[torch.Tensor, Out], self._model(batch))
+
+    def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
+        default_stream = torch.cuda.current_stream()
+        streams = (
+            self._embedding_even_streams
+            if cast(int, context.index) % 2 == 0
+            else self._embedding_odd_streams
+        )
+        for stream, emb_tensor, grad_tensor in zip(
+            streams,
+            context.embedding_tensors,
+            context.detached_embedding_tensors,
+        ):
+            with torch.cuda.stream(stream):
+                # pyre-ignore
+                stream.wait_stream(default_stream)
+                torch.autograd.backward(emb_tensor, grad_tensor.grad)
 
     def copy_batch_to_gpu(
         self,
@@ -722,15 +752,19 @@ def start_embedding_lookup(
         if batch is None:
             return
         with record_function(f"## start_embedding_lookup {context.index} ##"):
-            with torch.cuda.stream(
-                self._embedding_even_stream
-                if cast(int, context.index) % 2 == 0
-                else self._embedding_odd_stream
-            ):
-                _wait_for_event(batch, context.event)
-                _start_embedding_lookup(self._pipelined_modules, batch, context)
-                context.event = torch.cuda.Event()
-                context.event.record()
+            _wait_for_event(batch, context.event)
+            context.event = []
+            for i, module in enumerate(self._pipelined_modules):
+                with torch.cuda.stream(
+                    self._embedding_even_streams[i]
+                    if cast(int, context.index) % 2 == 0
+                    else self._embedding_odd_streams[i]
+                ):
+                    _start_embedding_lookup([module], batch, context)
+                    event = torch.cuda.Event()
+                    event.record()
+                    # pyre-ignore [16]
+                    context.event.append(event)
 
 
 class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
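
For context, a minimal standalone sketch (not part of the commit, with illustrative names) of the synchronization pattern behind the new embedding_backward, assuming a CUDA device: each pipelined module's embedding backward runs on its own side stream, and each side stream first waits on the default stream so that the gradients written by the main MLP backward are visible before torch.autograd.backward consumes them.

from typing import List, Optional

import torch


def embedding_backward_sketch(
    streams: List[Optional[torch.cuda.Stream]],  # one side stream per pipelined module
    embedding_tensors: List[torch.Tensor],       # attached embedding lookup outputs
    detached_tensors: List[torch.Tensor],        # detached copies whose .grad was filled by the MLP backward
) -> None:
    default_stream = torch.cuda.current_stream()
    for stream, emb, detached in zip(streams, embedding_tensors, detached_tensors):
        # run this module's embedding backward on its own stream
        with torch.cuda.stream(stream):
            if stream is not None:
                # do not read detached.grad before the default-stream backward
                # has finished writing it
                stream.wait_stream(default_stream)
            torch.autograd.backward(emb, detached.grad)

Splitting the backward across per-module streams lets the embedding-table gradient work overlap with the dense optimizer step on the default stream, which appears to be the motivation for the per-module even/odd stream lists introduced above.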