Skip to content

Commit cd44c6e

Browse files
check device before sync (#796)
* check device before sync
* cleanup
1 parent 26961ce commit cd44c6e

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

gptqmodel/models/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
 from transformers.modeling_utils import no_init_weights
 from transformers.utils.generic import ContextManagers
 
-from ._const import DEVICE, SUPPORTED_MODELS, get_best_device, is_torch_support_xpu
 from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
 from ..nn_modules.qlinear.ipex import IPEXQuantLinear, ipex_dtype
 from ..quantization import QuantizeConfig
@@ -25,6 +24,7 @@
 from ..utils.model import (auto_dtype_from_config, convert_gptq_v1_to_v2_format, find_layers,
                            get_checkpoints, get_moe_layer_modules, gptqmodel_post_init, make_quant,
                            simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
+from ._const import DEVICE, SUPPORTED_MODELS, get_best_device, is_torch_support_xpu
 
 logger = setup_logger()

gptqmodel/quantization/gptq.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,11 @@ def fasterquant(
         logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
         logger.debug(torch.sum(Losses))
 
-        if torch.cuda.is_available():
+        if self.dev.type == "cuda":
             torch.cuda.synchronize()
-        if hasattr(torch, "xpu") and torch.xpu.is_available():
+        elif self.dev.type == "xpu":
             torch.xpu.synchronize()
+
         duration = time.time() - tick
         avg_loss = torch.sum(Losses).item() / self.nsamples

gptqmodel/utils/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@
 from transformers import AutoConfig, PretrainedConfig
 from transformers.utils.hub import cached_file
 
-from .backend import BACKEND
-from .importer import select_quant_linear
-from .logger import setup_logger
-from .progress import ProgressBar
 from ..models._const import CPU, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS
 from ..nn_modules.qlinear import BaseQuantLinear
 from ..nn_modules.qlinear.exllama import ExllamaQuantLinear
 from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear
 from ..nn_modules.qlinear.ipex import IPEXQuantLinear
 from ..nn_modules.qlinear.torch import TorchQuantLinear
 from ..quantization import FORMAT, QuantizeConfig
+from .backend import BACKEND
+from .importer import select_quant_linear
+from .logger import setup_logger
+from .progress import ProgressBar
 
 logger = setup_logger()

0 commit comments

Comments
 (0)