Skip to content

Commit 78694e4

Browse files
authored
Handle new linear modules in DeepSpeed v0.16.5 (#3622)
* Handle fused_LinearLayer in DeepSpeed v0.16.5
* Handle GateUpPack_LinearLayer in DeepSpeed v0.16.5
1 parent 6e17fbb commit 78694e4

File tree

3 files changed

+26
-14
lines changed

3 files changed

+26
-14
lines changed

intel_extension_for_pytorch/nn/utils/_weight_prepack.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -101,18 +101,22 @@ def may_import_deepspeed_modules():
101101
try:
102102
# import deepspeed in a global space will raise circular import error
103103
# intel-extension-for-deepspeed imports both IPEX and deepspeed
104-
from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer
105-
106-
ds_layers = [LinearAllreduce, LinearLayer]
107-
108-
# TODO: remove this logic once deepspeed LmHeadLinearAllreduce change has been upstream-ed.
109-
try:
110-
from deepspeed.module_inject.layers import LmHeadLinearAllreduce
104+
from deepspeed.module_inject.layers import (
105+
LinearAllreduce,
106+
LinearLayer,
107+
LmHeadLinearAllreduce,
108+
fused_LinearLayer,
109+
GateUpPack_LinearLayer,
110+
)
111111

112-
ds_layers.append(LmHeadLinearAllreduce)
113-
return ds_layers
114-
except ImportError:
115-
return ds_layers
112+
ds_layers = [
113+
LinearAllreduce,
114+
LinearLayer,
115+
LmHeadLinearAllreduce,
116+
fused_LinearLayer,
117+
GateUpPack_LinearLayer,
118+
]
119+
return ds_layers
116120
except ImportError:
117121
return None
118122

intel_extension_for_pytorch/utils/weight_only_quantization.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -292,11 +292,19 @@ def _convert_woq_with_low_precision_checkpoint(
292292

293293
deepspeed_modules = may_import_deepspeed_modules()
294294
if deepspeed_modules is not None:
295-
LinearAllreduce, LinearLayer, LmHeadLinearAllreduce = deepspeed_modules[:]
295+
(
296+
LinearAllreduce,
297+
LinearLayer,
298+
LmHeadLinearAllreduce,
299+
fused_LinearLayer,
300+
GateUpPack_LinearLayer,
301+
) = deepspeed_modules
296302
q_op_map.update(
297303
{
298304
LinearAllreduce: IpexWoqLinearAllreduce,
299305
LinearLayer: WeightOnlyQuantizedLinear,
306+
fused_LinearLayer: WeightOnlyQuantizedLinear,
307+
GateUpPack_LinearLayer: WeightOnlyQuantizedLinear,
300308
}
301309
)
302310

tests/cpu/test_deepspeed.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,7 @@ def _get_ds_model(self, m_linear):
197197
def test_ipex_optimize(self):
198198
deepspeed_modules = may_import_deepspeed_modules()
199199
if deepspeed_modules is not None:
200-
LinearAllreduce, LinearLayer, LmHeadLinearAllreduce = deepspeed_modules
200+
LinearAllreduce, LinearLayer, LmHeadLinearAllreduce = deepspeed_modules[:3]
201201

202202
x = torch.randn(2, 3, 64)
203203
m_linear = DeepSpeedTestM(MyLmHeadModel).eval()
@@ -241,7 +241,7 @@ def _test_quantization(
241241
):
242242
deepspeed_modules = may_import_deepspeed_modules()
243243
if deepspeed_modules is not None:
244-
LinearAllreduce, LinearLayer, LmHeadLinearAllreduce = deepspeed_modules
244+
LinearAllreduce, LinearLayer, LmHeadLinearAllreduce = deepspeed_modules[:3]
245245

246246
x = torch.randn(2, 3, 64)
247247
m_linear = DeepSpeedTestM(MyLmHeadModel).eval()

0 commit comments

Comments
 (0)