@@ -430,6 +430,7 @@ def transform_module(
     world_size: int,
     batch_size: int,
     ctx: ContextManager,
+    benchmark_unsharded_module: bool = False,
 ) -> torch.nn.Module:
     def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
         eager_module(inputs[0])
@@ -441,52 +442,61 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:

     set_propogate_device(True)

-    topology: Topology = Topology(world_size=world_size, compute_device=device.type)
-    planner = EmbeddingShardingPlanner(
-        topology=topology,
-        batch_size=batch_size,
-        enumerator=EmbeddingEnumerator(
+    sharded_module = None
+
+    if not benchmark_unsharded_module:
+        topology: Topology = Topology(world_size=world_size, compute_device=device.type)
+        planner = EmbeddingShardingPlanner(
             topology=topology,
             batch_size=batch_size,
-            estimator=[
-                EmbeddingPerfEstimator(topology=topology),
-                EmbeddingStorageEstimator(topology=topology),
-            ],
-        ),
-    )
-
-    # Don't want to modify the module outright
-    # Since module is on cpu, won't cause cuda oom.
-    copied_module = copy.deepcopy(module)
-    # pyre-ignore [6]
-    plan = planner.plan(copied_module, [sharder])
-
-    if isinstance(ctx, MultiProcessContext):
-        sharded_module = DistributedModelParallel(
-            copied_module,
-            # pyre-ignore[6]
-            env=ShardingEnv.from_process_group(ctx.pg),
-            plan=plan,
-            # pyre-ignore[6]
-            sharders=[sharder],
-            device=ctx.device,
+            enumerator=EmbeddingEnumerator(
+                topology=topology,
+                batch_size=batch_size,
+                estimator=[
+                    EmbeddingPerfEstimator(topology=topology),
+                    EmbeddingStorageEstimator(topology=topology),
+                ],
+            ),
         )
-    else:
-        env = ShardingEnv.from_local(world_size=topology.world_size, rank=0)

-        sharded_module = _shard_modules(
-            module=copied_module,
-            # pyre-ignore [6]
-            sharders=[sharder],
-            device=device,
-            plan=plan,
-            env=env,
-        )
+        # Don't want to modify the module outright
+        # Since module is on cpu, won't cause cuda oom.
+        copied_module = copy.deepcopy(module)
+        # pyre-ignore [6]
+        plan = planner.plan(copied_module, [sharder])
+
+        if isinstance(ctx, MultiProcessContext):
+            sharded_module = DistributedModelParallel(
+                copied_module,
+                # pyre-ignore[6]
+                env=ShardingEnv.from_process_group(ctx.pg),
+                plan=plan,
+                # pyre-ignore[6]
+                sharders=[sharder],
+                device=ctx.device,
+            )
+        else:
+            env = ShardingEnv.from_local(world_size=topology.world_size, rank=0)
+
+            sharded_module = _shard_modules(
+                module=copied_module,
+                # pyre-ignore [6]
+                sharders=[sharder],
+                device=device,
+                plan=plan,
+                env=env,
+            )

     if compile_mode == CompileMode.FX_SCRIPT:
-        return fx_script_module(sharded_module)
+        return fx_script_module(
+            # pyre-ignore [6]
+            sharded_module
+            if not benchmark_unsharded_module
+            else module
+        )
     else:
-        return sharded_module
+        # pyre-ignore [7]
+        return sharded_module if not benchmark_unsharded_module else module


 def benchmark(
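
Note (not part of the patch): the net effect of the hunk above is that benchmark_unsharded_module=True bypasses planning, deepcopy, and sharding and hands the original eager module to the benchmark. A minimal stand-alone sketch of that selection logic, with illustrative names only (pick_module and post_process are placeholders, not patch code):

# Illustrative sketch only -- an assumed simplification of transform_module's
# new control flow; not code from the patch.
from typing import Callable, Optional
import torch

def pick_module(
    module: torch.nn.Module,
    sharded_module: Optional[torch.nn.Module],
    benchmark_unsharded_module: bool,
    post_process: Callable[[torch.nn.Module], torch.nn.Module] = lambda m: m,
) -> torch.nn.Module:
    # Unsharded benchmarking skips DistributedModelParallel / _shard_modules entirely.
    chosen = module if benchmark_unsharded_module else sharded_module
    assert chosen is not None, "sharded_module must be built when benchmarking sharded"
    return post_process(chosen)  # e.g. FX-script the chosen module when requested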
@@ -504,6 +514,7 @@ def benchmark(
     rank: int,
     enable_logging: bool = True,
     device_type: str = "cuda",
+    benchmark_unsharded_module: bool = False,
 ) -> BenchmarkResult:
     max_mem_allocated: List[int] = []
     if enable_logging:
@@ -667,6 +678,7 @@ def init_module_and_run_benchmark(
     rank: int = -1,
     queue: Optional[mp.Queue] = None,
     pooling_configs: Optional[List[int]] = None,
+    benchmark_unsharded_module: bool = False,
 ) -> BenchmarkResult:
     """
     There are a couple of caveats here as to why the module has to be initialized
@@ -724,9 +736,13 @@ def init_module_and_run_benchmark(
         batch_size=batch_size,
         # pyre-ignore[6]
         ctx=ctx,
+        benchmark_unsharded_module=benchmark_unsharded_module,
     )

-    name = benchmark_type_name(compile_mode, sharding_type)
+    if benchmark_unsharded_module:
+        name = "unsharded" + compile_mode.name
+    else:
+        name = benchmark_type_name(compile_mode, sharding_type)

     res = benchmark(
         name,
@@ -741,6 +757,7 @@ def init_module_and_run_benchmark(
         benchmark_func_kwargs=benchmark_func_kwargs,
         rank=rank,
         device_type=device.type,
+        benchmark_unsharded_module=benchmark_unsharded_module,
     )

     if queue is not None:
@@ -825,6 +842,7 @@ def benchmark_module(
     world_size: int = 2,
     num_benchmarks: int = 5,
     output_dir: str = "",
+    benchmark_unsharded: bool = False,
     func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
     benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
     pooling_configs: Optional[List[int]] = None,
@@ -896,13 +914,17 @@ def benchmark_module(
     ]
     prof_inputs = [rank_inputs[-prof_iters:] for rank_inputs in inputs]

-    for sharding_type in sharding_types:
+    for sharding_type in sharding_types if not benchmark_unsharded else ["Unsharded"]:
         for compile_mode in compile_modes:
-            # Test sharders should have a singular sharding_type
-            # pyre-ignore [16]
-            sharder._sharding_type = sharding_type.value
+            if not benchmark_unsharded:
+                # Test sharders should have a singular sharding_type
+                # pyre-ignore [16]
+                sharder._sharding_type = sharding_type.value
+                # pyre-ignore [6]
+                benchmark_type = benchmark_type_name(compile_mode, sharding_type)
+            else:
+                benchmark_type = "unsharded" + compile_mode.name

-            benchmark_type = benchmark_type_name(compile_mode, sharding_type)
             logging.info(
                 f"\n\n###### Running Benchmark Type: {benchmark_type} ######\n"
             )
@@ -933,6 +955,7 @@ def benchmark_module(
                module=wrapped_module,
                sharder=sharder,
                device=torch.device(device_type),
+                # pyre-ignore
                sharding_type=sharding_type,
                compile_mode=compile_mode,
                world_size=world_size,
@@ -946,6 +969,7 @@ def benchmark_module(
                func_to_benchmark=func_to_benchmark,
                benchmark_func_kwargs=benchmark_func_kwargs,
                pooling_configs=pooling_configs,
+                benchmark_unsharded_module=benchmark_unsharded,
            )

            gc.collect()
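
Note (not part of the patch): with the new flag, benchmark labels fall back to "unsharded" plus the compile mode name instead of a sharding-type label. A small hedged sketch of that naming branch; the CompileMode enum and the sharded label format below are stand-ins for the real ones:

# Illustrative sketch only -- stand-in enum and label format, mirroring the
# naming branch added in init_module_and_run_benchmark / benchmark_module.
from enum import Enum

class CompileMode(Enum):  # placeholder for torchrec's CompileMode
    EAGER = "eager"
    FX_SCRIPT = "fx_script"

def benchmark_label(compile_mode: CompileMode, sharding_type_name: str, unsharded: bool) -> str:
    if unsharded:
        return "unsharded" + compile_mode.name
    # placeholder for benchmark_type_name(compile_mode, sharding_type)
    return f"{sharding_type_name}-{compile_mode.name}"

print(benchmark_label(CompileMode.EAGER, "TABLE_WISE", unsharded=True))       # unshardedEAGER
print(benchmark_label(CompileMode.FX_SCRIPT, "TABLE_WISE", unsharded=False))  # TABLE_WISE-FX_SCRIPT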