diff --git a/torchrec/distributed/benchmark/benchmark_inference.py b/torchrec/distributed/benchmark/benchmark_inference.py
index 4eb7878e4..ecfcf1c73 100644
--- a/torchrec/distributed/benchmark/benchmark_inference.py
+++ b/torchrec/distributed/benchmark/benchmark_inference.py
@@ -129,6 +129,43 @@ def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkR
     )
 
 
+def benchmark_qebc_unsharded(
+    args: argparse.Namespace, output_dir: str
+) -> List[BenchmarkResult]:
+    tables = get_tables(TABLE_SIZES)
+    sharder = TestQuantEBCSharder(
+        sharding_type="",
+        kernel_type=EmbeddingComputeKernel.QUANT.value,
+        shardable_params=[table.name for table in tables],
+    )
+
+    module = QuantEmbeddingBagCollection(
+        # pyre-ignore [6]
+        tables=tables,
+        is_weighted=False,
+        device=torch.device("cpu"),
+        quant_state_dict_split_scale_bias=True,
+    )
+
+    args_kwargs = {
+        argname: getattr(args, argname)
+        for argname in dir(args)
+        # Don't include output_dir since output_dir was modified
+        if not argname.startswith("_") and argname not in IGNORE_ARGNAME
+    }
+
+    return benchmark_module(
+        module=module,
+        sharder=sharder,
+        sharding_types=[],
+        compile_modes=BENCH_COMPILE_MODES,
+        tables=tables,
+        output_dir=output_dir,
+        benchmark_unsharded=True,  # benchmark unsharded module
+        **args_kwargs,
+    )
+
+
 def main() -> None:
     args: argparse.Namespace = init_argparse_and_args()
 
@@ -143,14 +180,26 @@ def main() -> None:
     benchmark_results_per_module = []
     write_report_funcs_per_module = []
 
-    for module_name in ["QuantEmbeddingBagCollection", "QuantEmbeddingCollection"]:
+    module_names = [
+        "QuantEmbeddingBagCollection",
+        "QuantEmbeddingCollection",
+    ]
+
+    # Only do unsharded QEBC benchmark when using CPU device
+    if args.device_type == "cpu":
+        module_names.append("unshardedQuantEmbeddingBagCollection")
+
+    for module_name in module_names:
         output_dir = args.output_dir + f"/run_{datetime_sfx}"
         if module_name == "QuantEmbeddingBagCollection":
             output_dir += "_qebc"
             benchmark_func = benchmark_qebc
-        else:
+        elif module_name == "QuantEmbeddingCollection":
             output_dir += "_qec"
             benchmark_func = benchmark_qec
+        else:
+            output_dir += "_uqebc"
+            benchmark_func = benchmark_qebc_unsharded
 
         if not os.path.exists(output_dir):
             # Place all outputs under the datetime folder
diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index b4f0fc656..ebdbdc680 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -430,6 +430,7 @@ def transform_module(
     world_size: int,
     batch_size: int,
     ctx: ContextManager,
+    benchmark_unsharded_module: bool = False,
 ) -> torch.nn.Module:
     def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
         eager_module(inputs[0])
@@ -441,52 +442,61 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
 
     set_propogate_device(True)
 
-    topology: Topology = Topology(world_size=world_size, compute_device=device.type)
-    planner = EmbeddingShardingPlanner(
-        topology=topology,
-        batch_size=batch_size,
-        enumerator=EmbeddingEnumerator(
+    sharded_module = None
+
+    if not benchmark_unsharded_module:
+        topology: Topology = Topology(world_size=world_size, compute_device=device.type)
+        planner = EmbeddingShardingPlanner(
             topology=topology,
             batch_size=batch_size,
-            estimator=[
-                EmbeddingPerfEstimator(topology=topology),
-                EmbeddingStorageEstimator(topology=topology),
-            ],
-        ),
-    )
-
-    # Don't want to modify the module outright
-    # Since module is on cpu, won't cause cuda oom.
-    copied_module = copy.deepcopy(module)
-    # pyre-ignore [6]
-    plan = planner.plan(copied_module, [sharder])
-
-    if isinstance(ctx, MultiProcessContext):
-        sharded_module = DistributedModelParallel(
-            copied_module,
-            # pyre-ignore[6]
-            env=ShardingEnv.from_process_group(ctx.pg),
-            plan=plan,
-            # pyre-ignore[6]
-            sharders=[sharder],
-            device=ctx.device,
+            enumerator=EmbeddingEnumerator(
+                topology=topology,
+                batch_size=batch_size,
+                estimator=[
+                    EmbeddingPerfEstimator(topology=topology),
+                    EmbeddingStorageEstimator(topology=topology),
+                ],
+            ),
         )
-    else:
-        env = ShardingEnv.from_local(world_size=topology.world_size, rank=0)
 
-        sharded_module = _shard_modules(
-            module=copied_module,
-            # pyre-ignore [6]
-            sharders=[sharder],
-            device=device,
-            plan=plan,
-            env=env,
-        )
+        # Don't want to modify the module outright
+        # Since module is on cpu, won't cause cuda oom.
+        copied_module = copy.deepcopy(module)
+        # pyre-ignore [6]
+        plan = planner.plan(copied_module, [sharder])
+
+        if isinstance(ctx, MultiProcessContext):
+            sharded_module = DistributedModelParallel(
+                copied_module,
+                # pyre-ignore[6]
+                env=ShardingEnv.from_process_group(ctx.pg),
+                plan=plan,
+                # pyre-ignore[6]
+                sharders=[sharder],
+                device=ctx.device,
+            )
+        else:
+            env = ShardingEnv.from_local(world_size=topology.world_size, rank=0)
+
+            sharded_module = _shard_modules(
+                module=copied_module,
+                # pyre-ignore [6]
+                sharders=[sharder],
+                device=device,
+                plan=plan,
+                env=env,
+            )
 
     if compile_mode == CompileMode.FX_SCRIPT:
-        return fx_script_module(sharded_module)
+        return fx_script_module(
+            # pyre-ignore [6]
+            sharded_module
+            if not benchmark_unsharded_module
+            else module
+        )
     else:
-        return sharded_module
+        # pyre-ignore [7]
+        return sharded_module if not benchmark_unsharded_module else module
 
 
 def benchmark(
@@ -504,6 +514,7 @@
     rank: int,
     enable_logging: bool = True,
     device_type: str = "cuda",
+    benchmark_unsharded_module: bool = False,
 ) -> BenchmarkResult:
     max_mem_allocated: List[int] = []
     if enable_logging:
@@ -667,6 +678,7 @@
     rank: int = -1,
     queue: Optional[mp.Queue] = None,
     pooling_configs: Optional[List[int]] = None,
+    benchmark_unsharded_module: bool = False,
 ) -> BenchmarkResult:
     """
     There are a couple of caveats here as to why the module has to be initialized
@@ -724,9 +736,13 @@
         batch_size=batch_size,
         # pyre-ignore[6]
         ctx=ctx,
+        benchmark_unsharded_module=benchmark_unsharded_module,
     )
 
-    name = benchmark_type_name(compile_mode, sharding_type)
+    if benchmark_unsharded_module:
+        name = "unsharded" + compile_mode.name
+    else:
+        name = benchmark_type_name(compile_mode, sharding_type)
 
     res = benchmark(
         name,
@@ -741,6 +757,7 @@
         benchmark_func_kwargs=benchmark_func_kwargs,
         rank=rank,
         device_type=device.type,
+        benchmark_unsharded_module=benchmark_unsharded_module,
    )
 
     if queue is not None:
@@ -825,6 +842,7 @@ def benchmark_module(
     world_size: int = 2,
     num_benchmarks: int = 5,
     output_dir: str = "",
+    benchmark_unsharded: bool = False,
     func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
     benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
     pooling_configs: Optional[List[int]] = None,
@@ -896,13 +914,17 @@ def benchmark_module(
     ]
     prof_inputs = [rank_inputs[-prof_iters:] for rank_inputs in inputs]
 
-    for sharding_type in sharding_types:
+    for sharding_type in sharding_types if not benchmark_unsharded else ["Unsharded"]:
         for compile_mode in compile_modes:
-            # Test sharders should have a singular sharding_type
-            # pyre-ignore [16]
-            sharder._sharding_type = sharding_type.value
+            if not benchmark_unsharded:
+                # Test sharders should have a singular sharding_type
+                # pyre-ignore [16]
+                sharder._sharding_type = sharding_type.value
+                # pyre-ignore [6]
+                benchmark_type = benchmark_type_name(compile_mode, sharding_type)
+            else:
+                benchmark_type = "unsharded" + compile_mode.name
 
-            benchmark_type = benchmark_type_name(compile_mode, sharding_type)
             logging.info(
                 f"\n\n###### Running Benchmark Type: {benchmark_type} ######\n"
             )
@@ -933,6 +955,7 @@ def benchmark_module(
                     module=wrapped_module,
                     sharder=sharder,
                     device=torch.device(device_type),
+                    # pyre-ignore
                     sharding_type=sharding_type,
                     compile_mode=compile_mode,
                     world_size=world_size,
@@ -946,6 +969,7 @@
                     func_to_benchmark=func_to_benchmark,
                     benchmark_func_kwargs=benchmark_func_kwargs,
                     pooling_configs=pooling_configs,
+                    benchmark_unsharded_module=benchmark_unsharded,
                 )
 
             gc.collect()
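
Usage note (reviewer aid, not part of the patch): with this change, passing benchmark_unsharded=True to benchmark_module makes transform_module skip the planner/sharding step and benchmark the eager (or FX-scripted) module directly, and main() in benchmark_inference.py only wires the unsharded QEBC run up when --device_type is cpu. A minimal sketch of driving the new path from a standalone script, assuming the functions touched by this patch (benchmark_qebc_unsharded, init_argparse_and_args) and an illustrative output path; everything else uses the existing CLI defaults:

    # sketch_unsharded_qebc.py -- illustrative only; the output_dir suffix and
    # the explicit device_type override are assumptions, not part of the patch.
    import os

    from torchrec.distributed.benchmark.benchmark_inference import (
        benchmark_qebc_unsharded,
        init_argparse_and_args,
    )


    def run() -> None:
        args = init_argparse_and_args()
        # main() only exercises the unsharded path on CPU; mirror that here.
        args.device_type = "cpu"
        output_dir = args.output_dir + "/run_manual_uqebc"  # hypothetical path
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # Returns one BenchmarkResult per compile mode in BENCH_COMPILE_MODES.
        for result in benchmark_qebc_unsharded(args, output_dir):
            print(result)


    if __name__ == "__main__":
        run()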