@@ -643,6 +643,127 @@ def trace_handler(prof) -> None:
    )


+def benchmark_func(
+    name: str,
+    bench_inputs: List[Dict[str, Any]],
+    prof_inputs: List[Dict[str, Any]],
+    world_size: int,
+    profile_dir: str,
+    num_benchmarks: int,
+    num_profiles: int,
+    # pyre-ignore[2]
+    func_to_benchmark: Any,
+    benchmark_func_kwargs: Optional[Dict[str, Any]],
+    rank: int,
+    device_type: str = "cuda",
+) -> BenchmarkResult:
+    max_mem_allocated: List[int] = []
+    if device_type == "cuda":
+        if rank == -1:
+            # Reset memory for measurement, no process per rank so do all
+            for di in range(world_size):
+                torch.cuda.reset_peak_memory_stats(di)
+        else:
+            torch.cuda.reset_peak_memory_stats(rank)
+
+    start = []
+    end = []
+    if device_type == "cuda":
+        # Measure time taken for batches in bench_inputs
+        start = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
+        end = [torch.cuda.Event(enable_timing=True) for _ in range(num_benchmarks)]
+
+    if benchmark_func_kwargs is None:
+        # Need this to unwrap
+        benchmark_func_kwargs = {}
+
+    times = []
+    if device_type == "cuda":
+        for i in range(num_benchmarks):
+            start[i].record()
+            func_to_benchmark(bench_inputs, **benchmark_func_kwargs)
+            end[i].record()
+    elif device_type == "cpu":
+        times = timeit.repeat(
+            lambda: func_to_benchmark(bench_inputs, **benchmark_func_kwargs),
+            number=1,
+            repeat=num_benchmarks,
+        )
+
+    if device_type == "cuda":
+        if rank == -1:
+            for di in range(world_size):
+                torch.cuda.synchronize(di)
+        else:
+            torch.cuda.synchronize(rank)
+
+    # TODO: First Benchmark Run for Eager Mode produces outlier
+    # Start counting after first as workaround for standard deviation
+    if device_type == "cuda":
+        elapsed_time = torch.tensor(
+            [si.elapsed_time(ei) for si, ei in zip(start[1:], end[1:])]
+        )
+    else:
+        elapsed_time = torch.tensor(times) * 1e3
+
+    if device_type == "cuda":
+        if rank == -1:
+            # Add up all memory allocated in inference mode
+            for di in range(world_size):
+                b = torch.cuda.max_memory_allocated(di)
+                max_mem_allocated.append(b // 1024 // 1024)
+        else:
+            # Only add up memory allocated for current rank in training mode
+            b = torch.cuda.max_memory_allocated(rank)
+            max_mem_allocated.append(b // 1024 // 1024)
+
+    if profile_dir != "":
+        # Only do profiling if output_dir is set
+
+        # pyre-ignore[2]
+        def trace_handler(prof) -> None:
+            total_average = prof.profiler.total_average()
+            logger.info(f" TOTAL_AVERAGE:\n{name}\n{total_average}")
+            dir_path: str = profile_dir
+            if rank == 0:
+                trace_file: str = f"{dir_path}/trace-{name}.json"
+            else:
+                trace_file: str = f"{dir_path}/trace-{name}-{rank}.json"
+                return  # only 1 rank should output in pg case, rank = 0
+            logger.info(f" PROFILE[{name}].chrome_trace:{trace_file}")
+            prof.export_chrome_trace(trace_file)
+
+        if device_type == "cuda":
+            with torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                record_shapes=True,
+                profile_memory=True,
+                with_flops=True,
+                with_modules=True,
+                on_trace_ready=trace_handler,
+            ) as p:
+                for i in range(num_profiles):
+                    with record_function(f"## profile {i} ##"):
+                        func_to_benchmark(prof_inputs, **benchmark_func_kwargs)
+                        p.step()
+
+            if rank == -1:
+                for di in range(torch.cuda.device_count()):
+                    torch.cuda.synchronize(torch.device(f"cuda:{di}"))
+            else:
+                torch.cuda.synchronize()
+
+    return BenchmarkResult(
+        short_name=name,
+        elapsed_time=elapsed_time,
+        max_mem_allocated=max_mem_allocated,
+        rank=rank,
+    )
+
+
def benchmark_type_name(compile_mode: CompileMode, sharding_type: ShardingType) -> str:
    if sharding_type == ShardingType.TABLE_WISE:
        name = "tw-sharded"