Commit 298c49e

gnahzg authored and facebook-github-bot committed
Enable unsharded QEBC benchmark (#2139)
Summary:
Pull Request resolved: #2139

Unsharded QEBC is used in the remote request only split (running on a CPU host) in several MRS models, so improving its performance is of critical importance. This diff adds a benchmark for the unsharded module.

Reviewed By: IvanKobzarev

Differential Revision: D58628879
1 parent 71ca217 commit 298c49e

File tree

2 files changed: +120 -47 lines

torchrec/distributed/benchmark/benchmark_inference.py

Lines changed: 51 additions & 2 deletions
```diff
@@ -129,6 +129,43 @@ def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkR
     )
 
 
+def benchmark_qebc_unsharded(
+    args: argparse.Namespace, output_dir: str
+) -> List[BenchmarkResult]:
+    tables = get_tables(TABLE_SIZES)
+    sharder = TestQuantEBCSharder(
+        sharding_type="",
+        kernel_type=EmbeddingComputeKernel.QUANT.value,
+        shardable_params=[table.name for table in tables],
+    )
+
+    module = QuantEmbeddingBagCollection(
+        # pyre-ignore [6]
+        tables=tables,
+        is_weighted=False,
+        device=torch.device("cpu"),
+        quant_state_dict_split_scale_bias=True,
+    )
+
+    args_kwargs = {
+        argname: getattr(args, argname)
+        for argname in dir(args)
+        # Don't include output_dir since output_dir was modified
+        if not argname.startswith("_") and argname not in IGNORE_ARGNAME
+    }
+
+    return benchmark_module(
+        module=module,
+        sharder=sharder,
+        sharding_types=[],
+        compile_modes=BENCH_COMPILE_MODES,
+        tables=tables,
+        output_dir=output_dir,
+        benchmark_unsharded=True,  # benchmark unsharded module
+        **args_kwargs,
+    )
+
+
 def main() -> None:
     args: argparse.Namespace = init_argparse_and_args()
 
@@ -143,14 +180,26 @@ def main() -> None:
     benchmark_results_per_module = []
     write_report_funcs_per_module = []
 
-    for module_name in ["QuantEmbeddingBagCollection", "QuantEmbeddingCollection"]:
+    module_names = [
+        "QuantEmbeddingBagCollection",
+        "QuantEmbeddingCollection",
+    ]
+
+    # Only do unsharded QEBC benchmark when using CPU device
+    if args.device_type == "cpu":
+        module_names.append("unshardedQuantEmbeddingBagCollection")
+
+    for module_name in module_names:
         output_dir = args.output_dir + f"/run_{datetime_sfx}"
         if module_name == "QuantEmbeddingBagCollection":
             output_dir += "_qebc"
             benchmark_func = benchmark_qebc
-        else:
+        elif module_name == "QuantEmbeddingCollection":
            output_dir += "_qec"
            benchmark_func = benchmark_qec
+        else:
+            output_dir += "_uqebc"
+            benchmark_func = benchmark_qebc_unsharded
 
         if not os.path.exists(output_dir):
             # Place all outputs under the datetime folder
```
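
The new `benchmark_qebc_unsharded` function benchmarks the quantized module exactly as constructed, with no sharding plan applied. For context, here is a minimal sketch, not part of this commit, of the workload that path exercises: a CPU `QuantEmbeddingBagCollection` called directly on a `KeyedJaggedTensor`. The table configuration and input ids are illustrative; the real benchmark builds its tables via `get_tables(TABLE_SIZES)` and drives the module through `benchmark_module`.

```python
import torch
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.quant.embedding_modules import QuantEmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# One small INT8 table; sizes here are illustrative only.
tables = [
    EmbeddingBagConfig(
        num_embeddings=1000,
        embedding_dim=64,
        name="table_0",
        feature_names=["f_0"],
        data_type=DataType.INT8,
    )
]

qebc = QuantEmbeddingBagCollection(
    tables=tables,
    is_weighted=False,
    device=torch.device("cpu"),
    quant_state_dict_split_scale_bias=True,
)

# A two-sample batch for feature "f_0": sample 0 has one id, sample 1 has two.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f_0"],
    values=torch.tensor([1, 5, 7]),
    lengths=torch.tensor([1, 2]),
)

pooled = qebc(features)  # KeyedTensor: one pooled embedding per feature
print(pooled.to_dict()["f_0"].shape)  # torch.Size([2, 64])
```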

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 69 additions & 45 deletions
```diff
@@ -430,6 +430,7 @@ def transform_module(
     world_size: int,
     batch_size: int,
     ctx: ContextManager,
+    benchmark_unsharded_module: bool = False,
 ) -> torch.nn.Module:
     def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
         eager_module(inputs[0])
@@ -441,52 +442,61 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
 
     set_propogate_device(True)
 
-    topology: Topology = Topology(world_size=world_size, compute_device=device.type)
-    planner = EmbeddingShardingPlanner(
-        topology=topology,
-        batch_size=batch_size,
-        enumerator=EmbeddingEnumerator(
+    sharded_module = None
+
+    if not benchmark_unsharded_module:
+        topology: Topology = Topology(world_size=world_size, compute_device=device.type)
+        planner = EmbeddingShardingPlanner(
             topology=topology,
             batch_size=batch_size,
-            estimator=[
-                EmbeddingPerfEstimator(topology=topology),
-                EmbeddingStorageEstimator(topology=topology),
-            ],
-        ),
-    )
-
-    # Don't want to modify the module outright
-    # Since module is on cpu, won't cause cuda oom.
-    copied_module = copy.deepcopy(module)
-    # pyre-ignore [6]
-    plan = planner.plan(copied_module, [sharder])
-
-    if isinstance(ctx, MultiProcessContext):
-        sharded_module = DistributedModelParallel(
-            copied_module,
-            # pyre-ignore[6]
-            env=ShardingEnv.from_process_group(ctx.pg),
-            plan=plan,
-            # pyre-ignore[6]
-            sharders=[sharder],
-            device=ctx.device,
+            enumerator=EmbeddingEnumerator(
+                topology=topology,
+                batch_size=batch_size,
+                estimator=[
+                    EmbeddingPerfEstimator(topology=topology),
+                    EmbeddingStorageEstimator(topology=topology),
+                ],
+            ),
         )
-    else:
-        env = ShardingEnv.from_local(world_size=topology.world_size, rank=0)
 
-        sharded_module = _shard_modules(
-            module=copied_module,
-            # pyre-ignore [6]
-            sharders=[sharder],
-            device=device,
-            plan=plan,
-            env=env,
-        )
+        # Don't want to modify the module outright
+        # Since module is on cpu, won't cause cuda oom.
+        copied_module = copy.deepcopy(module)
+        # pyre-ignore [6]
+        plan = planner.plan(copied_module, [sharder])
+
+        if isinstance(ctx, MultiProcessContext):
+            sharded_module = DistributedModelParallel(
+                copied_module,
+                # pyre-ignore[6]
+                env=ShardingEnv.from_process_group(ctx.pg),
+                plan=plan,
+                # pyre-ignore[6]
+                sharders=[sharder],
+                device=ctx.device,
+            )
+        else:
+            env = ShardingEnv.from_local(world_size=topology.world_size, rank=0)
+
+            sharded_module = _shard_modules(
+                module=copied_module,
+                # pyre-ignore [6]
+                sharders=[sharder],
+                device=device,
+                plan=plan,
+                env=env,
+            )
 
     if compile_mode == CompileMode.FX_SCRIPT:
-        return fx_script_module(sharded_module)
+        return fx_script_module(
+            # pyre-ignore [6]
+            sharded_module
+            if not benchmark_unsharded_module
+            else module
+        )
     else:
-        return sharded_module
+        # pyre-ignore [7]
+        return sharded_module if not benchmark_unsharded_module else module
 
 
 def benchmark(
@@ -504,6 +514,7 @@ def benchmark(
     rank: int,
     enable_logging: bool = True,
     device_type: str = "cuda",
+    benchmark_unsharded_module: bool = False,
 ) -> BenchmarkResult:
     max_mem_allocated: List[int] = []
     if enable_logging:
@@ -667,6 +678,7 @@ def init_module_and_run_benchmark(
     rank: int = -1,
     queue: Optional[mp.Queue] = None,
     pooling_configs: Optional[List[int]] = None,
+    benchmark_unsharded_module: bool = False,
 ) -> BenchmarkResult:
     """
     There are a couple of caveats here as to why the module has to be initialized
@@ -724,9 +736,13 @@ def init_module_and_run_benchmark(
         batch_size=batch_size,
         # pyre-ignore[6]
         ctx=ctx,
+        benchmark_unsharded_module=benchmark_unsharded_module,
     )
 
-    name = benchmark_type_name(compile_mode, sharding_type)
+    if benchmark_unsharded_module:
+        name = "unsharded" + compile_mode.name
+    else:
+        name = benchmark_type_name(compile_mode, sharding_type)
 
     res = benchmark(
         name,
@@ -741,6 +757,7 @@
         benchmark_func_kwargs=benchmark_func_kwargs,
         rank=rank,
         device_type=device.type,
+        benchmark_unsharded_module=benchmark_unsharded_module,
     )
 
     if queue is not None:
@@ -825,6 +842,7 @@ def benchmark_module(
     world_size: int = 2,
     num_benchmarks: int = 5,
     output_dir: str = "",
+    benchmark_unsharded: bool = False,
     func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
     benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
     pooling_configs: Optional[List[int]] = None,
@@ -896,13 +914,17 @@
     ]
     prof_inputs = [rank_inputs[-prof_iters:] for rank_inputs in inputs]
 
-    for sharding_type in sharding_types:
+    for sharding_type in sharding_types if not benchmark_unsharded else ["Unsharded"]:
         for compile_mode in compile_modes:
-            # Test sharders should have a singular sharding_type
-            # pyre-ignore [16]
-            sharder._sharding_type = sharding_type.value
+            if not benchmark_unsharded:
+                # Test sharders should have a singular sharding_type
+                # pyre-ignore [16]
+                sharder._sharding_type = sharding_type.value
+                # pyre-ignore [6]
+                benchmark_type = benchmark_type_name(compile_mode, sharding_type)
+            else:
+                benchmark_type = "unsharded" + compile_mode.name
 
-            benchmark_type = benchmark_type_name(compile_mode, sharding_type)
             logging.info(
                 f"\n\n###### Running Benchmark Type: {benchmark_type} ######\n"
             )
@@ -933,6 +955,7 @@
                 module=wrapped_module,
                 sharder=sharder,
                 device=torch.device(device_type),
+                # pyre-ignore
                 sharding_type=sharding_type,
                 compile_mode=compile_mode,
                 world_size=world_size,
@@ -946,6 +969,7 @@
                 func_to_benchmark=func_to_benchmark,
                 benchmark_func_kwargs=benchmark_func_kwargs,
                 pooling_configs=pooling_configs,
+                benchmark_unsharded_module=benchmark_unsharded,
            )
 
            gc.collect()
```
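
With `benchmark_unsharded_module=True`, `transform_module` skips planning and sharding entirely and returns the module itself (fx-scripted when `compile_mode == CompileMode.FX_SCRIPT`), so `benchmark` ends up timing plain forward calls on the CPU host. As a rough illustration of that kind of measurement, here is a minimal wall-clock timing loop; `time_module` and its defaults are hypothetical, not the repo's `benchmark` implementation, which additionally handles memory stats and profiling.

```python
import time
from typing import Callable, List


def time_module(
    fn: Callable[[], object],
    num_iters: int = 20,
    warmup_iters: int = 5,
) -> List[float]:
    # Warm up outside the timed region (lazy init, caches, JIT profiling).
    for _ in range(warmup_iters):
        fn()
    elapsed_ms: List[float] = []
    for _ in range(num_iters):
        start = time.perf_counter()
        fn()
        elapsed_ms.append((time.perf_counter() - start) * 1e3)
    return elapsed_ms


# e.g. with the qebc/features sketch above:
# latencies = time_module(lambda: qebc(features))
```

On CPU, `time.perf_counter` around the call is sufficient; a CUDA path would need event-based timing instead, since kernel launches are asynchronous.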
