
Commit 4981be5

PaulZhang12 authored and facebook-github-bot committed
VBE training benchmarks (Manual) (#1855)
Summary:
Pull Request resolved: #1855

Set TorchRec's distributed training benchmarks to include VBE (variable batch embeddings).

Reviewed By: joshuadeng

Differential Revision: D55882022

fbshipit-source-id: 59640c74157c4c734f23286bad394f3c0c3d3145
1 parent 8a6547d commit 4981be5
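
For orientation, here is a minimal sketch of how the new input-generation path might be exercised, assuming the positional call signature of get_inputs shown in the benchmark_module diff below; the table configuration and the concrete argument values are illustrative assumptions, not part of this commit.

# Sketch only: mirrors the positional call to get_inputs added in this commit.
# Table config and argument values are illustrative assumptions.
from torchrec.distributed.benchmark.benchmark_utils import get_inputs
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        num_embeddings=10_000,
        embedding_dim=128,
        name="table_0",
        feature_names=["feature_0"],
    )
]

inputs = get_inputs(
    tables,
    512,   # batch size; treated as the average batch size when VBE is enabled
    2,     # world size
    4,     # number of input batches to generate
    True,  # train=True is required for VBE (see the RuntimeError guard in the diff)
    None,  # pooling_configs
    True,  # variable_batch_embeddings
)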

File tree

2 files changed: +55 −10 lines changed


torchrec/distributed/benchmark/benchmark_train.py

Lines changed: 25 additions & 0 deletions
@@ -71,6 +71,7 @@ def benchmark_ebc(
     args: argparse.Namespace,
     output_dir: str,
     pooling_configs: Optional[List[int]] = None,
+    variable_batch_embeddings: bool = False,
 ) -> List[BenchmarkResult]:
     table_configs = get_tables(tables, data_type=DataType.FP32)
     sharder = TestEBCSharder(
@@ -104,6 +105,8 @@ def benchmark_ebc(
     if pooling_configs:
         args_kwargs["pooling_configs"] = pooling_configs
 
+    args_kwargs["variable_batch_embeddings"] = variable_batch_embeddings
+
     return benchmark_module(
         module=module,
         sharder=sharder,
@@ -153,6 +156,7 @@ def main() -> None:
         mb = int(float(num * dim) / 1024 / 1024) * 4
         tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] {mb:6}Mb"
 
+    ### Benchmark no VBE
     report: str = (
         f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
     )
@@ -176,6 +180,27 @@ def main() -> None:
         )
     )
 
+    ### Benchmark with VBE
+    report: str = (
+        f"REPORT BENCHMARK (VBE) {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+    )
+    report += f"Module: {module_name} (VBE)\n"
+    report += tables_info
+    report += "\n"
+    report_file = f"{output_dir}/run_vbe.report"
+
+    benchmark_results_per_module.append(
+        benchmark_func(shrunk_table_sizes, args, output_dir, pooling_configs, True)
+    )
+    write_report_funcs_per_module.append(
+        partial(
+            write_report,
+            report_file=report_file,
+            report_str=report,
+            num_requests=num_requests,
+        )
+    )
+
     for i, write_report_func in enumerate(write_report_funcs_per_module):
         write_report_func(benchmark_results_per_module[i])
 
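
The main() change above adds a second benchmark pass with VBE enabled, writing its results to a separate {output_dir}/run_vbe.report. A hypothetical sketch of that call pattern, where benchmark_func stands for the benchmark entry point selected in main() (benchmark_ebc in this file) and the other names come from the surrounding code:

# Sketch of the two-pass pattern in main(); the trailing positional True enables VBE.
results_no_vbe = benchmark_func(shrunk_table_sizes, args, output_dir, pooling_configs)
results_vbe = benchmark_func(
    shrunk_table_sizes,
    args,
    output_dir,
    pooling_configs,
    True,  # variable_batch_embeddings
)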

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 30 additions & 10 deletions
@@ -269,19 +269,32 @@ def get_inputs(
     num_inputs: int,
     train: bool,
     pooling_configs: Optional[List[int]] = None,
+    variable_batch_embeddings: bool = False,
 ) -> List[List[KeyedJaggedTensor]]:
     inputs_batch: List[List[KeyedJaggedTensor]] = []
 
+    if variable_batch_embeddings and not train:
+        raise RuntimeError("Variable batch size is only supported in training mode")
+
     for _ in range(num_inputs):
-        _, model_input_by_rank = ModelInput.generate(
-            batch_size=batch_size,
-            world_size=world_size,
-            num_float_features=0,
-            tables=tables,
-            weighted_tables=[],
-            long_indices=False,
-            tables_pooling=pooling_configs,
-        )
+        if variable_batch_embeddings:
+            _, model_input_by_rank = ModelInput.generate_variable_batch_input(
+                average_batch_size=batch_size,
+                world_size=world_size,
+                num_float_features=0,
+                # pyre-ignore
+                tables=tables,
+            )
+        else:
+            _, model_input_by_rank = ModelInput.generate(
+                batch_size=batch_size,
+                world_size=world_size,
+                num_float_features=0,
+                tables=tables,
+                weighted_tables=[],
+                long_indices=False,
+                tables_pooling=pooling_configs,
+            )
 
         if train:
             sparse_features_by_rank = [
@@ -770,6 +783,7 @@ def benchmark_module(
     func_to_benchmark: Callable[..., None] = default_func_to_benchmark,
     benchmark_func_kwargs: Optional[Dict[str, Any]] = None,
     pooling_configs: Optional[List[int]] = None,
+    variable_batch_embeddings: bool = False,
 ) -> List[BenchmarkResult]:
     """
     Args:
@@ -820,7 +834,13 @@ def benchmark_module(
 
     num_inputs_to_gen: int = warmup_iters + bench_iters + prof_iters
     inputs = get_inputs(
-        tables, batch_size, world_size, num_inputs_to_gen, train, pooling_configs
+        tables,
+        batch_size,
+        world_size,
+        num_inputs_to_gen,
+        train,
+        pooling_configs,
+        variable_batch_embeddings,
     )
 
     warmup_inputs = [rank_inputs[:warmup_iters] for rank_inputs in inputs]
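
For reference, a hedged sketch of calling the VBE generator directly, mirroring the branch added to get_inputs above; the ModelInput import path and the table configuration are assumptions, while the keyword arguments match those used in the diff:

# Sketch only: direct use of ModelInput.generate_variable_batch_input, as in the
# VBE branch of get_inputs above. Table config is an illustrative assumption.
from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig

tables = [
    EmbeddingBagConfig(
        num_embeddings=1_000,
        embedding_dim=64,
        name="table_0",
        feature_names=["feature_0"],
    )
]

# Under VBE the batch size is an average: per-rank/per-feature batch sizes vary
# around it, unlike ModelInput.generate, which uses one fixed batch size.
_, model_input_by_rank = ModelInput.generate_variable_batch_input(
    average_batch_size=512,
    world_size=2,
    num_float_features=0,
    tables=tables,
)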
