Skip to content

Commit 3c49cbe

Browse files
lizhouyufacebook-github-bot
authored and committed
OSS internal MPZCH Module in TorchRec (#3017)
Summary: Pull Request resolved: #3017 ### Major changes - Copy the following files from `fb` to corresponding location in the `torchrec` repository - `fb/distributed/hash_mc_embedding.py → torchrec/distributed/hash_mc_embedding.py` - `fb/modules/hash_mc_evictions.py → torchrec/modules/hash_mc_evictions.py` - `fb/modules/hash_mc_metrics.py → torchrec/modules/hash_mc_metrics.py` - `fb/modules/hash_mc_modules.py → torchrec/modules/hash_mc_modules.py` - Create a `test_hash_zch_mc.py` file in `torchrec/distributed/tests` folder following the `test_quant_mc_embedding.py` in `torchrec/fb/distributed/tests`. - trimmed quantization and inference codes, and only kept the training part. - rewire the related packages from `torchrec.fb` to `torchrec` - Update `BUCK` files in related folders - Update the affected repos to use `torchrec` modules instead of the modules in `torchrec.fb` Differential Revision: D75559591
1 parent 515b97b commit 3c49cbe

File tree

5 files changed

+1385
-0
lines changed

5 files changed

+1385
-0
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#!/usr/bin/env python3
2+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
3+
4+
# pyre-strict
5+
6+
import logging as logger
7+
from collections import defaultdict
8+
from typing import Dict, List
9+
10+
import torch
11+
from torchrec.distributed.quant_state import WeightSpec
12+
from torchrec.distributed.types import ShardingType
13+
from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
14+
15+
16+
def sharded_zchs_buffers_spec(
    sharded_model: torch.nn.Module,
) -> Dict[str, WeightSpec]:
    """Map every sharded ZCH identities buffer to the spec of its unsharded buffer.

    The model is scanned for modules whose class name contains
    ``ShardedQuantManagedCollisionEmbeddingCollection``. Inside each of them,
    every table held by a submodule whose class name contains
    ``ShardedMCCRemapper`` contributes one entry.

    Example entries (taken from the original notes in this function):
        "main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.0._mcc_remapper.zchs.viewer_rid_duplicate._hash_zch_identities"
            -> offsets [0, 0], sizes [500, 1]
        "main_module.module.ec_in_task_arch_hash._decoupled_embedding_collection._mcec_lookup.0.1._mcc_remapper.zchs.viewer_rid_duplicate._hash_zch_identities"
            -> offsets [500, 0], sizes [1000, 1]

    Args:
        sharded_model: module tree to scan via ``named_modules()``.

    Returns:
        Dict keyed by the sharded buffer FQN
        (``<module_fqn>.<subfqn>.zchs.<table>.<IDENTITY_BUFFER>``). Each value is
        a ``WeightSpec`` carrying the unsharded FQN, the shard offsets/sizes and
        the row-wise sharding type. Empty when no matching module is found.
    """

    def _get_unsharded_fqn_identities(
        sharded_module: torch.nn.Module,
        fqn: str,
        table_name: str,
    ) -> str:
        """Return the unsharded identities-buffer FQN for ``table_name``.

        Searches ``sharded_module`` for a module whose class name contains
        ``ManagedCollisionCollection`` and whose ``_table_to_features`` has the
        table. Returns ``""`` (after logging) when none is found.
        """
        for module_fqn, module in sharded_module.named_modules():
            if "ManagedCollisionCollection" not in type(module).__name__:
                continue
            if table_name in module._table_to_features:
                return f"{fqn}.{module_fqn}._managed_collision_modules.{table_name}.{HashZchManagedCollisionModule.IDENTITY_BUFFER}"
        logger.info("did not find table %s in module %s", table_name, fqn)
        return ""

    specs: Dict[str, WeightSpec] = {}
    for module_fqn, module in sharded_model.named_modules():
        if (
            "ShardedQuantManagedCollisionEmbeddingCollection"
            not in type(module).__name__
        ):
            continue

        sharding_type = ShardingType.ROW_WISE.value
        # Cache per table: the unsharded FQN is the same for every shard of a table.
        unsharded_fqn_by_table: Dict[str, str] = {}

        for subfqn, submodule in module.named_modules():
            if "ShardedMCCRemapper" not in type(submodule).__name__:
                continue

            for table_name in submodule.zchs.keys():
                # Element 0 is used as the row offset and element 1 as the row
                # count; the identities tensor has only one column.
                shard_metadata = submodule._shard_metadata[table_name]
                shard_offsets: List[int] = [shard_metadata[0], 0]
                shard_sizes: List[int] = [shard_metadata[1], 1]

                if table_name not in unsharded_fqn_by_table:
                    unsharded_fqn_by_table[table_name] = (
                        _get_unsharded_fqn_identities(module, module_fqn, table_name)
                    )

                # subfqn contains the index of sharding, so no need to add it
                # specifically here.
                sharded_fqn_identities: str = (
                    f"{module_fqn}.{subfqn}.zchs.{table_name}.{HashZchManagedCollisionModule.IDENTITY_BUFFER}"
                )
                # NOTE(review): when the lookup above fails, the spec is still
                # recorded with an empty unsharded FQN.
                specs[sharded_fqn_identities] = WeightSpec(
                    fqn=unsharded_fqn_by_table[table_name],
                    shard_offsets=shard_offsets,
                    shard_sizes=shard_sizes,
                    sharding_type=sharding_type,
                )
    return specs
Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
#!/usr/bin/env python3
2+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
3+
4+
#!/usr/bin/env python3
5+
6+
# pyre-strict
7+
8+
import copy
9+
import multiprocessing
10+
import unittest
11+
from typing import Any, Dict, List, Optional, Tuple, Union
12+
13+
import torch
14+
from hypothesis import settings
15+
from libfb.py.pyre import none_throws
16+
from torch import nn
17+
from torchrec import (
18+
EmbeddingCollection,
19+
EmbeddingConfig,
20+
JaggedTensor,
21+
KeyedJaggedTensor,
22+
KeyedTensor,
23+
)
24+
from torchrec.distributed import ModuleSharder, ShardingEnv
25+
from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder
26+
27+
from torchrec.distributed.shard import _shard_modules
28+
from torchrec.distributed.sharding_plan import (
29+
construct_module_sharding_plan,
30+
EmbeddingCollectionSharder,
31+
ManagedCollisionEmbeddingCollectionSharder,
32+
row_wise,
33+
)
34+
from torchrec.distributed.test_utils.multi_process import (
35+
MultiProcessContext,
36+
MultiProcessTestBase,
37+
)
38+
from torchrec.distributed.types import ShardingPlan
39+
from torchrec.modules.hash_mc_evictions import (
40+
HashZchEvictionConfig,
41+
HashZchEvictionPolicyName,
42+
)
43+
from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
44+
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
45+
from torchrec.modules.mc_modules import ManagedCollisionCollection
46+
47+
# Class names of modules, presumably meant to be treated as leaf modules
# (e.g. when tracing) — TODO confirm. Not referenced in the visible part of this file.
BASE_LEAF_MODULES = [
48+
"IntNBitTableBatchedEmbeddingBagsCodegen",
49+
"HashZchManagedCollisionModule",
50+
]
51+
52+
53+
class SparseArch(nn.Module):
    """Test model: an embedding collection fronted by HashZch managed-collision modules.

    Builds one ``HashZchManagedCollisionModule`` for each of the first two
    entries of ``tables`` (keyed ``"table_0"`` / ``"table_1"`` and bound to
    ``"feature_0"`` / ``"feature_1"``) and wraps them, together with an
    ``EmbeddingCollection`` over all ``tables``, in a
    ``ManagedCollisionEmbeddingCollection``.

    Args:
        tables: embedding configs; entries 0 and 1 supply ``num_embeddings``
            used as the ZCH size of the two collision modules.
        device: device passed to the collision modules and the embedding collection.
        buckets: forwarded as ``total_num_buckets``.
        return_remapped: forwarded as ``return_remapped_features``.
        input_hash_size: forwarded to each collision module.
        is_inference: forwarded to each collision module.
    """

    def __init__(
        self,
        tables: List[EmbeddingConfig],
        device: torch.device,
        buckets: int,
        return_remapped: bool = False,
        input_hash_size: int = 4000,
        is_inference: bool = False,
    ) -> None:
        super().__init__()
        self._return_remapped = return_remapped

        # Both tables share the same eviction setup (single TTL of 1); only the
        # table key, the ZCH size and the bound feature differ.
        mc_modules = {}
        for table_idx, (table_key, feature_name) in enumerate(
            (("table_0", "feature_0"), ("table_1", "feature_1"))
        ):
            mc_modules[table_key] = HashZchManagedCollisionModule(
                is_inference=is_inference,
                zch_size=tables[table_idx].num_embeddings,
                input_hash_size=input_hash_size,
                device=device,
                total_num_buckets=buckets,
                eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
                eviction_config=HashZchEvictionConfig(
                    features=[feature_name],
                    single_ttl=1,
                ),
            )

        embedding_collection = EmbeddingCollection(tables=tables, device=device)
        collision_collection = ManagedCollisionCollection(
            managed_collision_modules=mc_modules,
            embedding_configs=tables,
        )
        self._mc_ec: ManagedCollisionEmbeddingCollection = (
            ManagedCollisionEmbeddingCollection(
                embedding_collection,
                collision_collection,
                return_remapped_features=self._return_remapped,
            )
        )

    def forward(
        self, kjt: KeyedJaggedTensor
    ) -> Tuple[
        Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
    ]:
        """Run ``kjt`` through the managed-collision embedding collection and return its output."""
        outputs = self._mc_ec(kjt)
        return outputs
113+
114+
115+
class TestHashZchMcEmbedding(MultiProcessTestBase):
    """Multi-process GPU test for a row-wise sharded HashZch managed-collision embedding collection."""

    # No hypothesis ``@settings`` decorator here: the test is not driven by ``@given``.
    @unittest.skipIf(torch.cuda.device_count() <= 1, "Not enough GPUs, skipping")
    def test_hash_zch_mc_ec(self) -> None:
        """Shard a two-table model across two ranks and run one forward pass per rank.

        Each rank runs ``_train_model`` with its own input; the worker stores
        the resulting state tensors in a manager-backed dict shared across
        processes.
        """
        WORLD_SIZE = 2

        embedding_config = [
            EmbeddingConfig(
                name="table_0",
                feature_names=["feature_0"],
                embedding_dim=8,
                num_embeddings=16,
            ),
            EmbeddingConfig(
                name="table_1",
                feature_names=["feature_1"],
                embedding_dim=8,
                num_embeddings=32,
            ),
        ]

        # 16 lengths over 2 keys: eight entries of length 1 followed by eight of
        # length 2. The value count is derived from the lengths so the two can
        # never disagree (previously 25 values were supplied for 24 slots).
        lengths = [1] * 8 + [2] * 8
        num_values = sum(lengths)
        train_input_per_rank = [
            KeyedJaggedTensor.from_lengths_sync(
                keys=["feature_0", "feature_1"],
                values=torch.LongTensor(
                    list(range(first_id, first_id + num_values)),
                ),
                lengths=torch.LongTensor(lengths),
                weights=None,
            )
            # One disjoint id range per rank.
            for first_id in (1000, 25000)
        ]
        train_state_dict = multiprocessing.Manager().dict()

        # Train Model with ZCH on GPU
        # NOTE(review): train_state_dict is filled by the workers but never
        # asserted on in this test.
        self._run_multi_process_test(
            callable=_train_model,
            world_size=WORLD_SIZE,
            tables=embedding_config,
            num_buckets=2,
            kjt_input_per_rank=train_input_per_rank,
            sharder=ManagedCollisionEmbeddingCollectionSharder(
                EmbeddingCollectionSharder(),
                ManagedCollisionCollectionSharder(),
            ),
            return_dict=train_state_dict,
            backend="nccl",
        )
172+
173+
174+
def _train_model(
    tables: List[EmbeddingConfig],
    num_buckets: int,
    rank: int,
    world_size: int,
    kjt_input_per_rank: List[KeyedJaggedTensor],
    sharder: ModuleSharder[nn.Module],
    backend: str,
    return_dict: Dict[str, Any],
    local_size: Optional[int] = None,
) -> None:
    """Per-rank worker: build, shard and run the model once, then export its state.

    Args:
        tables: embedding configs used to build ``SparseArch``.
        num_buckets: forwarded to ``SparseArch`` as ``buckets``.
        rank: rank of this process; selects the input from ``kjt_input_per_rank``.
        world_size: number of participating processes.
        kjt_input_per_rank: one input per rank.
        sharder: sharder used both to build the plan and to shard the module.
        backend: process-group backend passed to ``MultiProcessContext``.
        return_dict: output mapping. Receives ``mc_<key>_<rank>`` for every
            managed-collision state entry and ``ec_<key>_<rank>`` for every
            embedding-collection state entry, all moved to CPU.
        local_size: optional local world size, forwarded to the context and
            the sharding-plan construction.
    """
    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
        kjt_input = kjt_input_per_rank[rank].to(ctx.device)

        train_model = SparseArch(
            tables=tables,
            device=torch.device("cuda"),
            input_hash_size=0,
            return_remapped=True,
            buckets=num_buckets,
        )
        train_sharding_plan = construct_module_sharding_plan(
            train_model._mc_ec,
            per_param_sharding={"table_0": row_wise(), "table_1": row_wise()},
            local_size=local_size,
            world_size=world_size,
            device_type="cuda",
            sharder=sharder,
        )
        print(f"train_sharding_plan: {train_sharding_plan}")
        sharded_train_model = _shard_modules(
            module=copy.deepcopy(train_model),
            plan=ShardingPlan({"_mc_ec": train_sharding_plan}),
            env=ShardingEnv.from_process_group(none_throws(ctx.pg)),
            sharders=[sharder],
            device=ctx.device,
        )
        # train: a single forward pass (no backward/optimizer step is run here).
        # The input was already moved to ctx.device above.
        sharded_train_model(kjt_input)

        mc_state = (
            # pyre-ignore
            sharded_train_model._mc_ec._managed_collision_collection._managed_collision_modules.state_dict()
        )
        for key, value in mc_state.items():
            return_dict[f"mc_{key}_{rank}"] = value.cpu()

        # pyre-ignore
        ec_state = sharded_train_model._mc_ec._embedding_collection.state_dict()
        for key, value in ec_state.items():
            # Concatenate this rank's local shards row-wise into one CPU tensor.
            local_tensors = [shard.tensor.cpu() for shard in value.local_shards()]
            return_dict[f"ec_{key}_{rank}"] = torch.cat(local_tensors, dim=0)

0 commit comments

Comments
 (0)