
Commit 7ae8f1a

lizhouyu authored and facebook-github-bot committed
OSS internal MPZCH Module in TorchRec (#3017)
Summary: Pull Request resolved: #3017

### Major changes
- Copy the following files from `fb` to the corresponding locations in the `torchrec` repository:
  - `fb/distributed/hash_mc_embedding.py → torchrec/distributed/hash_mc_embedding.py`
  - `fb/modules/hash_mc_evictions.py → torchrec/modules/hash_mc_evictions.py`
  - `fb/modules/hash_mc_metrics.py → torchrec/modules/hash_mc_metrics.py`
  - `fb/modules/hash_mc_modules.py → torchrec/modules/hash_mc_modules.py`
- Create a `test_hash_zch_mc.py` file in the `torchrec/distributed/tests` folder, following `test_quant_mc_embedding.py` in `torchrec/fb/distributed/tests`:
  - Trim the quantization and inference code, keeping only the training part.
  - Rewire the related packages from `torchrec.fb` to `torchrec`.
- Update the `BUCK` files in the related folders.
- Update the affected repos to use `torchrec` modules instead of the modules in `torchrec.fb`.
- Update `torchrec/modules/hash_mc_metrics.py`: replace the tensorboard module with a local file logger.

Differential Revision: D75559591
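For downstream code, the migration is an import-path change only. A minimal sketch of the rewiring (module names taken from the file list above):

```python
# Before (internal-only path):
# from torchrec.fb.modules.hash_mc_modules import HashZchManagedCollisionModule

# After this change (OSS path):
from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
from torchrec.modules.hash_mc_evictions import (
    HashZchEvictionConfig,
    HashZchEvictionPolicyName,
)
```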
1 parent 515b97b commit 7ae8f1a

File tree

6 files changed: +1363 −15 lines changed

.github/workflows/unittest_ci_cpu.yml

Lines changed: 15 additions & 15 deletions
```diff
@@ -20,21 +20,21 @@ on:
 jobs:
   build_test:
     strategy:
-      fail-fast: false
-      matrix:
-        include:
-          - os: linux.2xlarge
-            python-version: 3.9
-            python-tag: "py39"
-          - os: linux.2xlarge
-            python-version: '3.10'
-            python-tag: "py310"
-          - os: linux.2xlarge
-            python-version: '3.11'
-            python-tag: "py311"
-          - os: linux.2xlarge
-            python-version: '3.12'
-            python-tag: "py312"
+      fail-fast: false
+      matrix:
+        include:
+          - os: linux.2xlarge
+            python-version: '3.9'
+            python-tag: "py39"
+          - os: linux.2xlarge
+            python-version: '3.10'
+            python-tag: "py310"
+          - os: linux.2xlarge
+            python-version: '3.11'
+            python-tag: "py311"
+          - os: linux.2xlarge
+            python-version: '3.12'
+            python-tag: "py312"
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     permissions:
       id-token: write
```
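The substantive change above is quoting `'3.9'`, making it consistent with its neighbors. Quoting matters for version strings in YAML because unquoted numeric scalars parse as floats and drop trailing zeros, so an unquoted `3.10` would be read as `3.1`. A quick illustration, assuming PyYAML is available:

```python
import yaml  # assumes PyYAML; used here only to illustrate YAML scalar parsing

print(yaml.safe_load("python-version: 3.9"))     # {'python-version': 3.9}    (float)
print(yaml.safe_load("python-version: 3.10"))    # {'python-version': 3.1}    (trailing zero lost)
print(yaml.safe_load("python-version: '3.10'"))  # {'python-version': '3.10'} (string, as intended)
```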
torchrec/distributed/hash_mc_embedding.py

Lines changed: 88 additions & 0 deletions
```python
#!/usr/bin/env python3
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

import logging as logger
from collections import defaultdict
from typing import Dict, List

import torch
from torchrec.distributed.quant_state import WeightSpec
from torchrec.distributed.types import ShardingType
from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule


def sharded_zchs_buffers_spec(
    sharded_model: torch.nn.Module,
) -> Dict[str, WeightSpec]:
    # OUTPUT:
    # Example:
    # "main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.0._mcc_remapper.zchs.viewer_rid_duplicate._hash_zch_identities", [0, 0], [500, 1])
    # "main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.1._mcc_remapper.zchs.viewer_rid_duplicate._hash_zch_identities", [500, 0], [1000, 1])

    # 'main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.0._mcc_remapper.zchs.viewer_rid_duplicate._hash_zch_identities': WeightSpec(fqn='main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._managed_collision_collection.viewer_rid_duplicate._hash_zch_identities'
    def _get_table_names(
        sharded_module: torch.nn.Module,
    ) -> List[str]:
        table_names: List[str] = []
        for _, module in sharded_module.named_modules():
            type_name: str = type(module).__name__
            if "ShardedMCCRemapper" in type_name:
                for table_name in module._tables:
                    if table_name not in table_names:
                        table_names.append(table_name)
        return table_names

    def _get_unsharded_fqn_identities(
        sharded_module: torch.nn.Module,
        fqn: str,
        table_name: str,
    ) -> str:
        for module_fqn, module in sharded_module.named_modules():
            type_name: str = type(module).__name__
            if "ManagedCollisionCollection" in type_name:
                if table_name in module._table_to_features:
                    return f"{fqn}.{module_fqn}._managed_collision_modules.{table_name}.{HashZchManagedCollisionModule.IDENTITY_BUFFER}"
        logger.info(f"did not find table {table_name} in module {fqn}")
        return ""

    ret: Dict[str, WeightSpec] = defaultdict()
    for module_fqn, module in sharded_model.named_modules():
        type_name: str = type(module).__name__
        if "ShardedQuantManagedCollisionEmbeddingCollection" in type_name:
            sharding_type = ShardingType.ROW_WISE.value
            table_name_to_unsharded_fqn_identities: Dict[str, str] = {}
            for subfqn, submodule in module.named_modules():
                type_name: str = type(submodule).__name__
                if "ShardedMCCRemapper" in type_name:
                    for table_name in submodule.zchs.keys():
                        # identities tensor has only one column
                        shard_offsets: List[int] = [
                            submodule._shard_metadata[table_name][0],
                            0,
                        ]
                        shard_sizes: List[int] = [
                            submodule._shard_metadata[table_name][1],
                            1,
                        ]
                        if table_name not in table_name_to_unsharded_fqn_identities:
                            table_name_to_unsharded_fqn_identities[table_name] = (
                                _get_unsharded_fqn_identities(
                                    module, module_fqn, table_name
                                )
                            )
                        unsharded_fqn_identities: str = (
                            table_name_to_unsharded_fqn_identities[table_name]
                        )
                        # subfqn contains the index of sharding, so no need to add it specifically here
                        sharded_fqn_identities: str = (
                            f"{module_fqn}.{subfqn}.zchs.{table_name}.{HashZchManagedCollisionModule.IDENTITY_BUFFER}"
                        )
                        ret[sharded_fqn_identities] = WeightSpec(
                            fqn=unsharded_fqn_identities,
                            shard_offsets=shard_offsets,
                            shard_sizes=shard_sizes,
                            sharding_type=sharding_type,
                        )
    return ret
```
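For orientation, a minimal usage sketch: `dump_zch_identity_specs` below is hypothetical (only `sharded_zchs_buffers_spec` and `WeightSpec` come from this commit), and it assumes a model that has already been sharded so that `ShardedQuantManagedCollisionEmbeddingCollection` submodules exist:

```python
import torch

from torchrec.distributed.hash_mc_embedding import sharded_zchs_buffers_spec


def dump_zch_identity_specs(sharded_model: torch.nn.Module) -> None:
    # Each entry maps a sharded `_hash_zch_identities` buffer to the FQN of its
    # unsharded counterpart, plus [row, col] offsets/sizes in the full tensor.
    for sharded_fqn, spec in sharded_zchs_buffers_spec(sharded_model).items():
        print(
            f"{sharded_fqn} -> {spec.fqn} "
            f"offsets={spec.shard_offsets} sizes={spec.shard_sizes}"
        )
```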
torchrec/distributed/tests/test_hash_zch_mc.py

Lines changed: 228 additions & 0 deletions
```python
#!/usr/bin/env python3
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

import copy
import multiprocessing
import unittest
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from pyre_extensions import none_throws
from torch import nn
from torchrec import (
    EmbeddingCollection,
    EmbeddingConfig,
    JaggedTensor,
    KeyedJaggedTensor,
    KeyedTensor,
)
from torchrec.distributed import ModuleSharder, ShardingEnv
from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder

from torchrec.distributed.shard import _shard_modules
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    EmbeddingCollectionSharder,
    ManagedCollisionEmbeddingCollectionSharder,
    row_wise,
)
from torchrec.distributed.test_utils.multi_process import (
    MultiProcessContext,
    MultiProcessTestBase,
)
from torchrec.distributed.types import ShardingPlan
from torchrec.modules.hash_mc_evictions import (
    HashZchEvictionConfig,
    HashZchEvictionPolicyName,
)
from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
from torchrec.modules.mc_modules import ManagedCollisionCollection

BASE_LEAF_MODULES = [
    "IntNBitTableBatchedEmbeddingBagsCodegen",
    "HashZchManagedCollisionModule",
]


class SparseArch(nn.Module):
    def __init__(
        self,
        tables: List[EmbeddingConfig],
        device: torch.device,
        buckets: int,
        return_remapped: bool = False,
        input_hash_size: int = 4000,
        is_inference: bool = False,
    ) -> None:
        super().__init__()
        self._return_remapped = return_remapped

        mc_modules = {}
        mc_modules["table_0"] = HashZchManagedCollisionModule(
            is_inference=is_inference,
            zch_size=(tables[0].num_embeddings),
            input_hash_size=input_hash_size,
            device=device,
            total_num_buckets=buckets,
            eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
            eviction_config=HashZchEvictionConfig(
                features=["feature_0"],
                single_ttl=1,
            ),
        )

        mc_modules["table_1"] = HashZchManagedCollisionModule(
            is_inference=is_inference,
            zch_size=(tables[1].num_embeddings),
            device=device,
            input_hash_size=input_hash_size,
            total_num_buckets=buckets,
            eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
            eviction_config=HashZchEvictionConfig(
                features=["feature_1"],
                single_ttl=1,
            ),
        )

        self._mc_ec: ManagedCollisionEmbeddingCollection = (
            ManagedCollisionEmbeddingCollection(
                EmbeddingCollection(
                    tables=tables,
                    device=device,
                ),
                ManagedCollisionCollection(
                    managed_collision_modules=mc_modules,
                    embedding_configs=tables,
                ),
                return_remapped_features=self._return_remapped,
            )
        )

    def forward(
        self, kjt: KeyedJaggedTensor
    ) -> Tuple[
        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
    ]:
        return self._mc_ec(kjt)


class TestHashZchMcEmbedding(MultiProcessTestBase):
    @unittest.skipIf(torch.cuda.device_count() <= 1, "Not enough GPUs, skipping")
    def test_hash_zch_mc_ec(self) -> None:
        WORLD_SIZE = 2

        embedding_config = [
            EmbeddingConfig(
                name="table_0",
                feature_names=["feature_0"],
                embedding_dim=8,
                num_embeddings=16,
            ),
            EmbeddingConfig(
                name="table_1",
                feature_names=["feature_1"],
                embedding_dim=8,
                num_embeddings=32,
            ),
        ]

        train_input_per_rank = [
            KeyedJaggedTensor.from_lengths_sync(
                keys=["feature_0", "feature_1"],
                values=torch.LongTensor(
                    list(range(1000, 1025)),
                ),
                lengths=torch.LongTensor([1] * 8 + [2] * 8),
                weights=None,
            ),
            KeyedJaggedTensor.from_lengths_sync(
                keys=["feature_0", "feature_1"],
                values=torch.LongTensor(
                    list(range(25000, 25025)),
                ),
                lengths=torch.LongTensor([1] * 8 + [2] * 8),
                weights=None,
            ),
        ]
        train_state_dict = multiprocessing.Manager().dict()

        # Train Model with ZCH on GPU
        self._run_multi_process_test(
            callable=_train_model,
            world_size=WORLD_SIZE,
            tables=embedding_config,
            num_buckets=2,
            kjt_input_per_rank=train_input_per_rank,
            sharder=ManagedCollisionEmbeddingCollectionSharder(
                EmbeddingCollectionSharder(),
                ManagedCollisionCollectionSharder(),
            ),
            return_dict=train_state_dict,
            backend="nccl",
        )


def _train_model(
    tables: List[EmbeddingConfig],
    num_buckets: int,
    rank: int,
    world_size: int,
    kjt_input_per_rank: List[KeyedJaggedTensor],
    sharder: ModuleSharder[nn.Module],
    backend: str,
    return_dict: Dict[str, Any],
    local_size: Optional[int] = None,
) -> None:
    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
        kjt_input = kjt_input_per_rank[rank].to(ctx.device)

        train_model = SparseArch(
            tables=tables,
            device=torch.device("cuda"),
            input_hash_size=0,
            return_remapped=True,
            buckets=num_buckets,
        )
        train_sharding_plan = construct_module_sharding_plan(
            train_model._mc_ec,
            per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
            local_size=local_size,
            world_size=world_size,
            device_type="cuda",
            sharder=sharder,
        )
        print(f"train_sharding_plan: {train_sharding_plan}")
        sharded_train_model = _shard_modules(
            module=copy.deepcopy(train_model),
            plan=ShardingPlan({"_mc_ec": train_sharding_plan}),
            env=ShardingEnv.from_process_group(none_throws(ctx.pg)),
            sharders=[sharder],
            device=ctx.device,
        )
        # train
        sharded_train_model(kjt_input.to(ctx.device))

        for (
            key,
            value,
        ) in (
            # pyre-ignore
            sharded_train_model._mc_ec._managed_collision_collection._managed_collision_modules.state_dict().items()
        ):
            return_dict[f"mc_{key}_{rank}"] = value.cpu()
        for (
            key,
            value,
            # pyre-ignore
        ) in sharded_train_model._mc_ec._embedding_collection.state_dict().items():
            tensors = []
            for i in range(len(value.local_shards())):
                tensors.append(value.local_shards()[i].tensor.cpu())
            return_dict[f"ec_{key}_{rank}"] = torch.cat(tensors, dim=0)
```
