From 298c49e9f858f8b03f2018b449336ca60ec77875 Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Tue, 25 Jun 2024 14:46:37 -0700 Subject: [PATCH] Enable unsharded QEBC benchmark (#2139) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2139 Unsharded QEBC is used in remote requst only split (running on CPU host) in several MRS models. Improve its performance is of critical importance. This diff add the benchmark for this unsharded module Reviewed By: IvanKobzarev Differential Revision: D58628879 --- .../benchmark/benchmark_inference.py | 53 +++++++- .../distributed/benchmark/benchmark_utils.py | 114 +++++++++++------- 2 files changed, 120 insertions(+), 47 deletions(-) diff --git a/torchrec/distributed/benchmark/benchmark_inference.py b/torchrec/distributed/benchmark/benchmark_inference.py index 4eb7878e4..ecfcf1c73 100644 --- a/torchrec/distributed/benchmark/benchmark_inference.py +++ b/torchrec/distributed/benchmark/benchmark_inference.py @@ -129,6 +129,43 @@ def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkR ) +def benchmark_qebc_unsharded( + args: argparse.Namespace, output_dir: str +) -> List[BenchmarkResult]: + tables = get_tables(TABLE_SIZES) + sharder = TestQuantEBCSharder( + sharding_type="", + kernel_type=EmbeddingComputeKernel.QUANT.value, + shardable_params=[table.name for table in tables], + ) + + module = QuantEmbeddingBagCollection( + # pyre-ignore [6] + tables=tables, + is_weighted=False, + device=torch.device("cpu"), + quant_state_dict_split_scale_bias=True, + ) + + args_kwargs = { + argname: getattr(args, argname) + for argname in dir(args) + # Don't include output_dir since output_dir was modified + if not argname.startswith("_") and argname not in IGNORE_ARGNAME + } + + return benchmark_module( + module=module, + sharder=sharder, + sharding_types=[], + compile_modes=BENCH_COMPILE_MODES, + tables=tables, + output_dir=output_dir, + benchmark_unsharded=True, # benchmark unsharded module + **args_kwargs, + ) + + def main() -> None: args: argparse.Namespace = init_argparse_and_args() @@ -143,14 +180,26 @@ def main() -> None: benchmark_results_per_module = [] write_report_funcs_per_module = [] - for module_name in ["QuantEmbeddingBagCollection", "QuantEmbeddingCollection"]: + module_names = [ + "QuantEmbeddingBagCollection", + "QuantEmbeddingCollection", + ] + + # Only do unsharded QEBC benchmark when using CPU device + if args.device_type == "cpu": + module_names.append("unshardedQuantEmbeddingBagCollection") + + for module_name in module_names: output_dir = args.output_dir + f"/run_{datetime_sfx}" if module_name == "QuantEmbeddingBagCollection": output_dir += "_qebc" benchmark_func = benchmark_qebc - else: + elif module_name == "QuantEmbeddingCollection": output_dir += "_qec" benchmark_func = benchmark_qec + else: + output_dir += "_uqebc" + benchmark_func = benchmark_qebc_unsharded if not os.path.exists(output_dir): # Place all outputs under the datetime folder diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py index b4f0fc656..ebdbdc680 100644 --- a/torchrec/distributed/benchmark/benchmark_utils.py +++ b/torchrec/distributed/benchmark/benchmark_utils.py @@ -430,6 +430,7 @@ def transform_module( world_size: int, batch_size: int, ctx: ContextManager, + benchmark_unsharded_module: bool = False, ) -> torch.nn.Module: def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module: eager_module(inputs[0]) @@ -441,52 +442,61 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module: set_propogate_device(True) - topology: Topology = Topology(world_size=world_size, compute_device=device.type) - planner = EmbeddingShardingPlanner( - topology=topology, - batch_size=batch_size, - enumerator=EmbeddingEnumerator( + sharded_module = None + + if not benchmark_unsharded_module: + topology: Topology = Topology(world_size=world_size, compute_device=device.type) + planner = EmbeddingShardingPlanner( topology=topology, batch_size=batch_size, - estimator=[ - EmbeddingPerfEstimator(topology=topology), - EmbeddingStorageEstimator(topology=topology), - ], - ), - ) - - # Don't want to modify the module outright - # Since module is on cpu, won't cause cuda oom. - copied_module = copy.deepcopy(module) - # pyre-ignore [6] - plan = planner.plan(copied_module, [sharder]) - - if isinstance(ctx, MultiProcessContext): - sharded_module = DistributedModelParallel( - copied_module, - # pyre-ignore[6] - env=ShardingEnv.from_process_group(ctx.pg), - plan=plan, - # pyre-ignore[6] - sharders=[sharder], - device=ctx.device, + enumerator=EmbeddingEnumerator( + topology=topology, + batch_size=batch_size, + estimator=[ + EmbeddingPerfEstimator(topology=topology), + EmbeddingStorageEstimator(topology=topology), + ], + ), ) - else: - env = ShardingEnv.from_local(world_size=topology.world_size, rank=0) - sharded_module = _shard_modules( - module=copied_module, - # pyre-ignore [6] - sharders=[sharder], - device=device, - plan=plan, - env=env, - ) + # Don't want to modify the module outright + # Since module is on cpu, won't cause cuda oom. + copied_module = copy.deepcopy(module) + # pyre-ignore [6] + plan = planner.plan(copied_module, [sharder]) + + if isinstance(ctx, MultiProcessContext): + sharded_module = DistributedModelParallel( + copied_module, + # pyre-ignore[6] + env=ShardingEnv.from_process_group(ctx.pg), + plan=plan, + # pyre-ignore[6] + sharders=[sharder], + device=ctx.device, + ) + else: + env = ShardingEnv.from_local(world_size=topology.world_size, rank=0) + + sharded_module = _shard_modules( + module=copied_module, + # pyre-ignore [6] + sharders=[sharder], + device=device, + plan=plan, + env=env, + ) if compile_mode == CompileMode.FX_SCRIPT: - return fx_script_module(sharded_module) + return fx_script_module( + # pyre-ignore [6] + sharded_module + if not benchmark_unsharded_module + else module + ) else: - return sharded_module + # pyre-ignore [7] + return sharded_module if not benchmark_unsharded_module else module def benchmark( @@ -504,6 +514,7 @@ def benchmark( rank: int, enable_logging: bool = True, device_type: str = "cuda", + benchmark_unsharded_module: bool = False, ) -> BenchmarkResult: max_mem_allocated: List[int] = [] if enable_logging: @@ -667,6 +678,7 @@ def init_module_and_run_benchmark( rank: int = -1, queue: Optional[mp.Queue] = None, pooling_configs: Optional[List[int]] = None, + benchmark_unsharded_module: bool = False, ) -> BenchmarkResult: """ There are a couple of caveats here as to why the module has to be initialized @@ -724,9 +736,13 @@ def init_module_and_run_benchmark( batch_size=batch_size, # pyre-ignore[6] ctx=ctx, + benchmark_unsharded_module=benchmark_unsharded_module, ) - name = benchmark_type_name(compile_mode, sharding_type) + if benchmark_unsharded_module: + name = "unsharded" + compile_mode.name + else: + name = benchmark_type_name(compile_mode, sharding_type) res = benchmark( name, @@ -741,6 +757,7 @@ def init_module_and_run_benchmark( benchmark_func_kwargs=benchmark_func_kwargs, rank=rank, device_type=device.type, + benchmark_unsharded_module=benchmark_unsharded_module, ) if queue is not None: @@ -825,6 +842,7 @@ def benchmark_module( world_size: int = 2, num_benchmarks: int = 5, output_dir: str = "", + benchmark_unsharded: bool = False, func_to_benchmark: Callable[..., None] = default_func_to_benchmark, benchmark_func_kwargs: Optional[Dict[str, Any]] = None, pooling_configs: Optional[List[int]] = None, @@ -896,13 +914,17 @@ def benchmark_module( ] prof_inputs = [rank_inputs[-prof_iters:] for rank_inputs in inputs] - for sharding_type in sharding_types: + for sharding_type in sharding_types if not benchmark_unsharded else ["Unsharded"]: for compile_mode in compile_modes: - # Test sharders should have a singular sharding_type - # pyre-ignore [16] - sharder._sharding_type = sharding_type.value + if not benchmark_unsharded: + # Test sharders should have a singular sharding_type + # pyre-ignore [16] + sharder._sharding_type = sharding_type.value + # pyre-ignore [6] + benchmark_type = benchmark_type_name(compile_mode, sharding_type) + else: + benchmark_type = "unsharded" + compile_mode.name - benchmark_type = benchmark_type_name(compile_mode, sharding_type) logging.info( f"\n\n###### Running Benchmark Type: {benchmark_type} ######\n" ) @@ -933,6 +955,7 @@ def benchmark_module( module=wrapped_module, sharder=sharder, device=torch.device(device_type), + # pyre-ignore sharding_type=sharding_type, compile_mode=compile_mode, world_size=world_size, @@ -946,6 +969,7 @@ def benchmark_module( func_to_benchmark=func_to_benchmark, benchmark_func_kwargs=benchmark_func_kwargs, pooling_configs=pooling_configs, + benchmark_unsharded_module=benchmark_unsharded, ) gc.collect()