Enable unsharded QEBC benchmark #2139

Status: Closed · wants to merge 1 commit
53 changes: 51 additions & 2 deletions torchrec/distributed/benchmark/benchmark_inference.py
@@ -129,6 +129,43 @@ def benchmark_qebc(args: argparse.Namespace, output_dir: str) -> List[BenchmarkResult]:
    )


def benchmark_qebc_unsharded(
    args: argparse.Namespace, output_dir: str
) -> List[BenchmarkResult]:
    tables = get_tables(TABLE_SIZES)
    sharder = TestQuantEBCSharder(
        sharding_type="",
        kernel_type=EmbeddingComputeKernel.QUANT.value,
        shardable_params=[table.name for table in tables],
    )

    module = QuantEmbeddingBagCollection(
        # pyre-ignore [6]
        tables=tables,
        is_weighted=False,
        device=torch.device("cpu"),
        quant_state_dict_split_scale_bias=True,
    )

    args_kwargs = {
        argname: getattr(args, argname)
        for argname in dir(args)
        # Don't include output_dir since output_dir was modified
        if not argname.startswith("_") and argname not in IGNORE_ARGNAME
    }

    return benchmark_module(
        module=module,
        sharder=sharder,
        sharding_types=[],
        compile_modes=BENCH_COMPILE_MODES,
        tables=tables,
        output_dir=output_dir,
        benchmark_unsharded=True,  # benchmark unsharded module
        **args_kwargs,
    )


def main() -> None:
    args: argparse.Namespace = init_argparse_and_args()

@@ -143,14 +180,26 @@ def main() -> None:
    benchmark_results_per_module = []
    write_report_funcs_per_module = []

    module_names = [
        "QuantEmbeddingBagCollection",
        "QuantEmbeddingCollection",
    ]

    # Only run the unsharded QEBC benchmark when using the CPU device
    if args.device_type == "cpu":
        module_names.append("unshardedQuantEmbeddingBagCollection")

    for module_name in module_names:
        output_dir = args.output_dir + f"/run_{datetime_sfx}"
        if module_name == "QuantEmbeddingBagCollection":
            output_dir += "_qebc"
            benchmark_func = benchmark_qebc
        elif module_name == "QuantEmbeddingCollection":
            output_dir += "_qec"
            benchmark_func = benchmark_qec
        else:
            output_dir += "_uqebc"
            benchmark_func = benchmark_qebc_unsharded

        if not os.path.exists(output_dir):
            # Place all outputs under the datetime folder
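The args_kwargs construction in benchmark_qebc_unsharded forwards every public argparse attribute not listed in IGNORE_ARGNAME on to benchmark_module. Below is a minimal, self-contained sketch of that filtering pattern; the flag names and the contents of IGNORE_ARGNAME are illustrative assumptions, not taken from this PR.

import argparse

IGNORE_ARGNAME = {"output_dir"}  # hypothetical contents, for illustration only

parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", default="/tmp/bench")    # excluded below
parser.add_argument("--num_benchmarks", type=int, default=5)
args = parser.parse_args([])

# Same comprehension shape as in benchmark_qebc_unsharded: keep public attrs only
args_kwargs = {
    name: getattr(args, name)
    for name in dir(args)
    if not name.startswith("_") and name not in IGNORE_ARGNAME
}
print(args_kwargs)  # {'num_benchmarks': 5}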
114 changes: 69 additions & 45 deletions torchrec/distributed/benchmark/benchmark_utils.py
@@ -430,6 +430,7 @@ def transform_module(
    world_size: int,
    batch_size: int,
    ctx: ContextManager,
    benchmark_unsharded_module: bool = False,
) -> torch.nn.Module:
    def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
        eager_module(inputs[0])
@@ -441,52 +442,61 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:

    set_propogate_device(True)

    sharded_module = None

    if not benchmark_unsharded_module:
        topology: Topology = Topology(
            world_size=world_size, compute_device=device.type
        )
        planner = EmbeddingShardingPlanner(
            topology=topology,
            batch_size=batch_size,
            enumerator=EmbeddingEnumerator(
                topology=topology,
                batch_size=batch_size,
                estimator=[
                    EmbeddingPerfEstimator(topology=topology),
                    EmbeddingStorageEstimator(topology=topology),
                ],
            ),
        )

        # Don't want to modify the module outright
        # Since the module is on CPU, deepcopy won't cause CUDA OOM
        copied_module = copy.deepcopy(module)
        # pyre-ignore [6]
        plan = planner.plan(copied_module, [sharder])

        if isinstance(ctx, MultiProcessContext):
            sharded_module = DistributedModelParallel(
                copied_module,
                # pyre-ignore[6]
                env=ShardingEnv.from_process_group(ctx.pg),
                plan=plan,
                # pyre-ignore[6]
                sharders=[sharder],
                device=ctx.device,
            )
        else:
            env = ShardingEnv.from_local(world_size=topology.world_size, rank=0)

            sharded_module = _shard_modules(
                module=copied_module,
                # pyre-ignore [6]
                sharders=[sharder],
                device=device,
                plan=plan,
                env=env,
            )

    if compile_mode == CompileMode.FX_SCRIPT:
        return fx_script_module(
            # pyre-ignore [6]
            sharded_module if not benchmark_unsharded_module else module
        )
    else:
        # pyre-ignore [7]
        return sharded_module if not benchmark_unsharded_module else module
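For the CompileMode.FX_SCRIPT branch, fx_script_module warms the module with one input batch, traces it, and TorchScript-compiles the result. TorchRec may use its own FX tracer internally; the sketch below shows the same trace-then-script pattern using stock torch.fx on a hypothetical stand-in module, purely to illustrate the flow.

import torch
import torch.fx

class TinyModule(torch.nn.Module):  # hypothetical stand-in
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

eager = TinyModule()
eager(torch.ones(2))                           # warm-up call, as above
graph_module = torch.fx.symbolic_trace(eager)  # capture an FX graph
scripted = torch.jit.script(graph_module)      # compile to TorchScript
print(scripted(torch.ones(2)))                 # tensor([2., 2.])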


def benchmark(
@@ -504,6 +514,7 @@ def benchmark(
    rank: int,
    enable_logging: bool = True,
    device_type: str = "cuda",
    benchmark_unsharded_module: bool = False,
) -> BenchmarkResult:
    max_mem_allocated: List[int] = []
    if enable_logging:
@@ -667,6 +678,7 @@ def init_module_and_run_benchmark(
    rank: int = -1,
    queue: Optional[mp.Queue] = None,
    pooling_configs: Optional[List[int]] = None,
    benchmark_unsharded_module: bool = False,
) -> BenchmarkResult:
    """
    There are a couple of caveats here as to why the module has to be initialized
@@ -724,9 +736,13 @@
            batch_size=batch_size,
            # pyre-ignore[6]
            ctx=ctx,
            benchmark_unsharded_module=benchmark_unsharded_module,
        )

        if benchmark_unsharded_module:
            name = "unsharded" + compile_mode.name
        else:
            name = benchmark_type_name(compile_mode, sharding_type)

        res = benchmark(
            name,
@@ -741,6 +757,7 @@
            benchmark_func_kwargs=benchmark_func_kwargs,
            rank=rank,
            device_type=device.type,
            benchmark_unsharded_module=benchmark_unsharded_module,
        )

        if queue is not None:
@@ -825,6 +842,7 @@ def benchmark_module(
    world_size: int = 2,
    num_benchmarks: int = 5,
    output_dir: str = "",
    benchmark_unsharded: bool = False,
    func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
    benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
    pooling_configs: Optional[List[int]] = None,
@@ -896,13 +914,17 @@
    ]
    prof_inputs = [rank_inputs[-prof_iters:] for rank_inputs in inputs]
    for sharding_type in sharding_types if not benchmark_unsharded else ["Unsharded"]:
        for compile_mode in compile_modes:
            if not benchmark_unsharded:
                # Test sharders should have a singular sharding_type
                # pyre-ignore [16]
                sharder._sharding_type = sharding_type.value
                # pyre-ignore [6]
                benchmark_type = benchmark_type_name(compile_mode, sharding_type)
            else:
                benchmark_type = "unsharded" + compile_mode.name

            logging.info(
                f"\n\n###### Running Benchmark Type: {benchmark_type} ######\n"
            )
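On the unsharded path the benchmark label is built by concatenating "unsharded" with the CompileMode member name, bypassing benchmark_type_name. A tiny illustration follows; it assumes CompileMode is an Enum with members such as EAGER and FX_SCRIPT (only FX_SCRIPT is confirmed by this diff).

from enum import Enum

class CompileMode(Enum):  # stand-in; exact member set is an assumption
    EAGER = "eager"
    FX_SCRIPT = "fx_script"

for compile_mode in CompileMode:
    print("unsharded" + compile_mode.name)
# unshardedEAGER
# unshardedFX_SCRIPT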
@@ -933,6 +955,7 @@
                    module=wrapped_module,
                    sharder=sharder,
                    device=torch.device(device_type),
                    # pyre-ignore
                    sharding_type=sharding_type,
                    compile_mode=compile_mode,
                    world_size=world_size,
@@ -946,6 +969,7 @@
                    func_to_benchmark=func_to_benchmark,
                    benchmark_func_kwargs=benchmark_func_kwargs,
                    pooling_configs=pooling_configs,
                    benchmark_unsharded_module=benchmark_unsharded,
                )

            gc.collect()