|
4 | 4 | import torch
|
5 | 5 |
|
6 | 6 | from ..._ops import register_kernel
|
7 |
| -from ..utils import ipex_xpu |
| 7 | +from ..utils import ipex_xpu, triton_available |
8 | 8 |
|
9 | 9 | # With default torch, error:
|
10 | 10 | # NotImplementedError: The operator 'aten::_int_mm' for XPU
|
@@ -52,23 +52,16 @@ def _(
|
52 | 52 | raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
|
53 | 53 |
|
54 | 54 | return out.reshape(shape)
|
55 |
| -else: |
| 55 | +elif triton_available: |
56 | 56 | # IPEX should be faster for XPU, so its availability is checked first.
|
57 |
| - try: |
58 |
| - from ..triton import ops as triton_ops |
59 |
| - |
60 |
| - triton_available = True |
61 |
| - except ImportError as e: |
62 |
| - print("Import error:", e) |
63 |
| - triton_available = False |
| 57 | + from ..triton import ops as triton_ops |
64 | 58 |
|
65 |
| - if triton_available: |
66 |
| - register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise) |
67 |
| - register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")(triton_ops.dequantize_blockwise_inplace) |
68 |
| - register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(triton_ops.dequantize_blockwise) |
69 |
| - register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit) |
70 |
| - register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace) |
71 |
| - register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit) |
72 |
| - register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) |
73 |
| - else: |
74 |
| - warnings.warn("XPU available, but trtion package is missing.") |
| 59 | + register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise) |
| 60 | + register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")(triton_ops.dequantize_blockwise_inplace) |
| 61 | + register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(triton_ops.dequantize_blockwise) |
| 62 | + register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit) |
| 63 | + register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace) |
| 64 | + register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit) |
| 65 | + register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit) |
| 66 | +else: |
| 67 | + warnings.warn("XPU available, but nor ipex or trtion package is found.") |
0 commit comments