Commit f7b8ef9

PaulZhang12 authored and facebook-github-bot committed
Enable TW/CW pruning with no index remapping needed, only number of rows post pruning (#2343)
Summary:
Pull Request resolved: #2343

Enable TW pruning for TorchRec inference modules. We switch from pruning_indices_remapping to num_embeddings_post_pruning because the indices remapping is not calculated until after the physical transformations. Logical transformations (optimizing the model itself) should not depend on the rest of the physical transformations; the actual index remapping is set later, during model loading.

Reviewed By: dstaay-fb

Differential Revision: D61879687
1 parent c2d6457 commit f7b8ef9
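To make the new knob concrete, here is a minimal sketch of declaring a pruned table on the inference side. It assumes num_embeddings_post_pruning is an optional field on the table config, which is how the hunks below read it (config.num_embeddings_post_pruning, info.embedding_config.num_embeddings_post_pruning); the table name, sizes, and the idea of setting the field directly are illustrative, not taken from this commit.

# Sketch only: illustrates the intent of num_embeddings_post_pruning.
# The exact field location on EmbeddingBagConfig is an assumption.
from torchrec.modules.embedding_configs import EmbeddingBagConfig

table = EmbeddingBagConfig(
    name="t_user_id",             # hypothetical table
    num_embeddings=1_000_000,     # original (pre-pruning) row count
    embedding_dim=128,
    feature_names=["user_id"],
)
# Only the post-pruning row count is attached at this (logical) stage;
# the concrete index remapping is installed later, during model loading.
table.num_embeddings_post_pruning = 250_000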

File tree

13 files changed: +326 -276 lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 2 additions & 2 deletions
@@ -287,7 +287,7 @@ def create_sharding_infos_by_sharding(
                         embedding_names=embedding_names,
                         weight_init_max=config.weight_init_max,
                         weight_init_min=config.weight_init_min,
-                        pruning_indices_remapping=config.pruning_indices_remapping,
+                        num_embeddings_post_pruning=config.num_embeddings_post_pruning,
                     ),
                     param_sharding=parameter_sharding,
                     param=param,

@@ -402,7 +402,7 @@ def create_sharding_infos_by_sharding_device_group(
                         embedding_names=embedding_names,
                         weight_init_max=config.weight_init_max,
                         weight_init_min=config.weight_init_min,
-                        pruning_indices_remapping=config.pruning_indices_remapping,
+                        num_embeddings_post_pruning=config.num_embeddings_post_pruning,
                     ),
                     param_sharding=parameter_sharding,
                     param=param,

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 0 additions & 7 deletions
@@ -214,11 +214,6 @@ def __init__(
             fused_params
         )

-        index_remapping = [
-            table.pruning_indices_remapping for table in config.embedding_tables
-        ]
-        if all(v is None for v in index_remapping):
-            index_remapping = None
         self._runtime_device: torch.device = _get_runtime_device(device, config)
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1

@@ -244,8 +239,6 @@ def __init__(
                     )
                 ],
                 device=device,
-                # pyre-ignore
-                index_remapping=index_remapping,
                 pooling_mode=self._pooling,
                 feature_table_map=self._feature_table_map,
                 row_alignment=self._tbe_row_alignment,

torchrec/distributed/quant_state.py

Lines changed: 45 additions & 107 deletions
@@ -9,6 +9,7 @@

 import copy
 from dataclasses import dataclass
+from functools import partial
 from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union

 import torch

@@ -46,6 +47,47 @@ def _append_table_shard(
     d[table_name].append(shard)


+def post_state_dict_hook(
+    # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"]
+    # pyre-ignore [24]
+    module: ShardedEmbeddingModule,
+    destination: Dict[str, torch.Tensor],
+    prefix: str,
+    _local_metadata: Dict[str, Any],
+    tables_weights_prefix: str,  # "embedding_bags" or "embeddings"
+) -> None:
+    for (
+        table_name,
+        sharded_t,
+    ) in module._table_name_to_sharded_tensor.items():
+        destination[f"{prefix}{tables_weights_prefix}.{table_name}.weight"] = sharded_t
+
+    for sfx, dict_sharded_t, dict_t_list in [
+        (
+            "weight_qscale",
+            module._table_name_to_sharded_tensor_qscale,
+            module._table_name_to_tensors_list_qscale,
+        ),
+        (
+            "weight_qbias",
+            module._table_name_to_sharded_tensor_qbias,
+            module._table_name_to_tensors_list_qbias,
+        ),
+    ]:
+        for (
+            table_name,
+            sharded_t,
+        ) in dict_sharded_t.items():
+            destination[f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"] = (
+                sharded_t
+            )
+        for (
+            table_name,
+            t_list,
+        ) in dict_t_list.items():
+            destination[f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"] = t_list
+
+
 class ShardedQuantEmbeddingModuleState(
     ShardedEmbeddingModule[CompIn, DistOut, Out, ShrdCtx]
 ):

@@ -82,17 +124,6 @@ def _initialize_torch_state(  # noqa: C901
         ] = {}
         self._table_name_to_tensors_list_qbias: Dict[str, List[torch.Tensor]] = {}

-        # pruning_index_remappings
-        self._table_name_to_local_shards_pruning_index_remappings: Dict[
-            str, List[Shard]
-        ] = {}
-        self._table_name_to_sharded_tensor_pruning_index_remappings: Dict[
-            str, Union[torch.Tensor, ShardedTensorBase]
-        ] = {}
-        self._table_name_to_tensors_list_pruning_index_remappings: Dict[
-            str, List[torch.Tensor]
-        ] = {}
-
         for tbe, config in tbes.items():
             for (tbe_split_w, tbe_split_qscale, tbe_split_qbias), table in zip(
                 tbe.split_embedding_weights_with_scale_bias(split_scale_bias_mode=2),

@@ -184,43 +215,6 @@ def _initialize_torch_state(  # noqa: C901
                     Shard(tensor=tbe_split_qparam, metadata=qmetadata),
                 )
                 # end of weight_qscale & weight_qbias section
-                if table.pruning_indices_remapping is not None:
-                    for (
-                        qparam,
-                        table_name_to_local_shards,
-                        _,
-                    ) in [
-                        (
-                            table.pruning_indices_remapping,
-                            self._table_name_to_local_shards_pruning_index_remappings,
-                            self._table_name_to_tensors_list_pruning_index_remappings,
-                        )
-                    ]:
-                        parameter_sharding: ParameterSharding = (
-                            table_name_to_parameter_sharding[table.name]
-                        )
-                        sharding_type: str = parameter_sharding.sharding_type
-
-                        assert sharding_type in [
-                            ShardingType.TABLE_WISE.value,
-                            ShardingType.COLUMN_WISE.value,
-                        ]
-
-                        qmetadata = ShardMetadata(
-                            shard_offsets=[0],
-                            shard_sizes=[
-                                qparam.shape[0],
-                            ],
-                            placement=table.local_metadata.placement,
-                        )
-                        # TODO(ivankobzarev): "meta" sharding support: cleanup when copy to "meta" moves all tensors to "meta"
-                        if qmetadata.placement.device != qparam.device:
-                            qmetadata.placement = _remote_device(qparam.device)
-                        _append_table_shard(
-                            table_name_to_local_shards,
-                            table.name,
-                            Shard(tensor=qparam, metadata=qmetadata),
-                        )

         for table_name_to_local_shards, table_name_to_sharded_tensor in [
             (self._table_name_to_local_shards, self._table_name_to_sharded_tensor),

@@ -263,65 +257,9 @@ def _initialize_torch_state(  # noqa: C901
                 )
             )

-        for table_name_to_local_shards, table_name_to_sharded_tensor in [
-            (
-                self._table_name_to_local_shards_pruning_index_remappings,
-                self._table_name_to_sharded_tensor_pruning_index_remappings,
-            ),
-        ]:
-            for table_name, local_shards in table_name_to_local_shards.items():
-                # Single Tensor per table (TW sharding)
-                table_name_to_sharded_tensor[table_name] = local_shards[0].tensor
-                continue
-
-        def post_state_dict_hook(
-            # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"]
-            module: ShardedQuantEmbeddingModuleState[CompIn, DistOut, Out, ShrdCtx],
-            destination: Dict[str, torch.Tensor],
-            prefix: str,
-            _local_metadata: Dict[str, Any],
-        ) -> None:
-            for (
-                table_name,
-                sharded_t,
-            ) in module._table_name_to_sharded_tensor.items():
-                destination[f"{prefix}{tables_weights_prefix}.{table_name}.weight"] = (
-                    sharded_t
-                )
-
-            for sfx, dict_sharded_t, dict_t_list in [
-                (
-                    "weight_qscale",
-                    module._table_name_to_sharded_tensor_qscale,
-                    module._table_name_to_tensors_list_qscale,
-                ),
-                (
-                    "weight_qbias",
-                    module._table_name_to_sharded_tensor_qbias,
-                    module._table_name_to_tensors_list_qbias,
-                ),
-                (
-                    "index_remappings_array",
-                    module._table_name_to_sharded_tensor_pruning_index_remappings,
-                    module._table_name_to_tensors_list_pruning_index_remappings,
-                ),
-            ]:
-                for (
-                    table_name,
-                    sharded_t,
-                ) in dict_sharded_t.items():
-                    destination[
-                        f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"
-                    ] = sharded_t
-                for (
-                    table_name,
-                    t_list,
-                ) in dict_t_list.items():
-                    destination[
-                        f"{prefix}{tables_weights_prefix}.{table_name}.{sfx}"
-                    ] = t_list
-
-        self._register_state_dict_hook(post_state_dict_hook)
+        self._register_state_dict_hook(
+            partial(post_state_dict_hook, tables_weights_prefix=tables_weights_prefix)
+        )

     def _load_from_state_dict(
         # Union["ShardedQuantEmbeddingBagCollection", "ShardedQuantEmbeddingCollection"]

torchrec/distributed/sharding/cw_sharding.py

Lines changed: 22 additions & 4 deletions
@@ -48,6 +48,7 @@
     ShardingEnv,
     ShardMetadata,
 )
+from torchrec.distributed.utils import none_throws
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Multistreamable

@@ -157,7 +158,12 @@ def _shard(
                 shards_metadata=shards,
                 size=torch.Size(
                     [
-                        info.embedding_config.num_embeddings,
+                        (
+                            info.embedding_config.num_embeddings_post_pruning
+                            if info.embedding_config.num_embeddings_post_pruning
+                            is not None
+                            else info.embedding_config.num_embeddings
+                        ),
                         info.embedding_config.embedding_dim,
                     ]
                 ),

@@ -169,7 +175,12 @@ def _shard(
                 mesh=self._env.device_mesh,
                 placements=(Shard(1),),
                 size=(
-                    info.embedding_config.num_embeddings,
+                    (
+                        info.embedding_config.num_embeddings_post_pruning
+                        if info.embedding_config.num_embeddings_post_pruning
+                        is not None
+                        else info.embedding_config.num_embeddings
+                    ),
                     info.embedding_config.embedding_dim,
                 ),
                 stride=info.param.stride(),

@@ -190,7 +201,14 @@ def _shard(
                     pooling=info.embedding_config.pooling,
                     is_weighted=info.embedding_config.is_weighted,
                     has_feature_processor=info.embedding_config.has_feature_processor,
-                    local_rows=info.embedding_config.num_embeddings,
+                    local_rows=(
+                        none_throws(
+                            info.embedding_config.num_embeddings_post_pruning
+                        )
+                        if info.embedding_config.num_embeddings_post_pruning
+                        is not None
+                        else info.embedding_config.num_embeddings
+                    ),
                     local_cols=shards[i].shard_sizes[1],
                     compute_kernel=EmbeddingComputeKernel(
                         info.param_sharding.compute_kernel

@@ -201,7 +219,7 @@ def _shard(
                     fused_params=info.fused_params,
                     weight_init_max=info.embedding_config.weight_init_max,
                     weight_init_min=info.embedding_config.weight_init_min,
-                    pruning_indices_remapping=info.embedding_config.pruning_indices_remapping,
+                    num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning,
                 )
             )


torchrec/distributed/sharding/rw_sharding.py

Lines changed: 13 additions & 2 deletions
@@ -158,7 +158,12 @@ def _shard(
                 shards_metadata=shards,
                 size=torch.Size(
                     [
-                        info.embedding_config.num_embeddings,
+                        (
+                            info.embedding_config.num_embeddings_post_pruning
+                            if info.embedding_config.num_embeddings_post_pruning
+                            is not None
+                            else info.embedding_config.num_embeddings
+                        ),
                         info.embedding_config.embedding_dim,
                     ]
                 ),

@@ -170,7 +175,12 @@ def _shard(
                 mesh=self._env.device_mesh,
                 placements=(Shard(0),),
                 size=(
-                    info.embedding_config.num_embeddings,
+                    (
+                        info.embedding_config.num_embeddings_post_pruning
+                        if info.embedding_config.num_embeddings_post_pruning
+                        is not None
+                        else info.embedding_config.num_embeddings
+                    ),
                     info.embedding_config.embedding_dim,
                 ),
                 stride=info.param.stride(),

@@ -201,6 +211,7 @@ def _shard(
                     weight_init_max=info.embedding_config.weight_init_max,
                     weight_init_min=info.embedding_config.weight_init_min,
                     fused_params=info.fused_params,
+                    num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning,
                 )
             )
         return tables_per_rank

torchrec/distributed/sharding/tw_sharding.py

Lines changed: 14 additions & 3 deletions
@@ -49,6 +49,7 @@
     ShardingEnv,
     ShardMetadata,
 )
+from torchrec.distributed.utils import none_throws
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Multistreamable

@@ -103,11 +104,17 @@ def _shard(
         # pyre-fixme [16]
         shards = info.param_sharding.sharding_spec.shards
         # construct the global sharded_tensor_metadata
+
         global_metadata = ShardedTensorMetadata(
             shards_metadata=shards,
             size=torch.Size(
                 [
-                    info.embedding_config.num_embeddings,
+                    (
+                        info.embedding_config.num_embeddings_post_pruning
+                        if info.embedding_config.num_embeddings_post_pruning
+                        is not None
+                        else info.embedding_config.num_embeddings
+                    ),
                     info.embedding_config.embedding_dim,
                 ]
             ),

@@ -139,7 +146,11 @@ def _shard(
                     pooling=info.embedding_config.pooling,
                     is_weighted=info.embedding_config.is_weighted,
                     has_feature_processor=info.embedding_config.has_feature_processor,
-                    local_rows=info.embedding_config.num_embeddings,
+                    local_rows=(
+                        none_throws(info.embedding_config.num_embeddings_post_pruning)
+                        if info.embedding_config.num_embeddings_post_pruning is not None
+                        else info.embedding_config.num_embeddings
+                    ),
                     local_cols=info.embedding_config.embedding_dim,
                     compute_kernel=EmbeddingComputeKernel(
                         info.param_sharding.compute_kernel

@@ -150,7 +161,7 @@ def _shard(
                     weight_init_max=info.embedding_config.weight_init_max,
                     weight_init_min=info.embedding_config.weight_init_min,
                     fused_params=info.fused_params,
-                    pruning_indices_remapping=info.embedding_config.pruning_indices_remapping,
+                    num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning,
                 )
             )
         return tables_per_rank
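The cw/rw/tw _shard hunks above all repeat the same fallback: use the post-pruning row count when it is set, otherwise the original num_embeddings, with none_throws guarding the Optional for the type checker. A standalone sketch of that selection logic follows; effective_num_embeddings is a hypothetical helper (the diff inlines the expression at each call site), and none_throws here is a simplified stand-in for torchrec.distributed.utils.none_throws.

from typing import Optional, TypeVar

T = TypeVar("T")


def none_throws(value: Optional[T]) -> T:
    # Simplified stand-in for torchrec.distributed.utils.none_throws.
    if value is None:
        raise AssertionError("Unexpected None")
    return value


def effective_num_embeddings(
    num_embeddings: int,
    num_embeddings_post_pruning: Optional[int],
) -> int:
    # Prefer the pruned row count when pruning metadata is attached;
    # otherwise fall back to the table's full size.
    return (
        none_throws(num_embeddings_post_pruning)
        if num_embeddings_post_pruning is not None
        else num_embeddings
    )


print(effective_num_embeddings(1_000_000, None))     # 1000000 (unpruned table)
print(effective_num_embeddings(1_000_000, 250_000))  # 250000 (TW/CW-pruned table)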
