
Commit 1fe5c05

add intel xpu support for TGI
Signed-off-by: xiaolil1 <[email protected]>
Signed-off-by: ganyi <[email protected]>
Signed-off-by: Wang, Yi A <[email protected]>
1 parent da27fbd commit 1fe5c05


13 files changed, +259 −74 lines


Dockerfile_intel

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+# Rust builder
+FROM lukemathwalker/cargo-chef:latest-rust-1.71 AS chef
+WORKDIR /usr/src
+
+ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
+
+FROM chef as planner
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY launcher launcher
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+    rm -f $PROTOC_ZIP
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+RUN cargo chef cook --release --recipe-path recipe.json
+
+COPY Cargo.toml Cargo.toml
+COPY rust-toolchain.toml rust-toolchain.toml
+COPY proto proto
+COPY benchmark benchmark
+COPY router router
+COPY launcher launcher
+RUN cargo build --release
+
+# Text Generation Inference base image for Intel
+FROM intel/intel-extension-for-pytorch:2.1.10-xpu as base
+
+USER root
+# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
+RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
+    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
+
+# Text Generation Inference base env
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    HF_HUB_ENABLE_HF_TRANSFER=1 \
+    PORT=80
+
+
+# Install server
+COPY proto proto
+COPY server server
+COPY server/Makefile server/Makefile
+RUN cd server && \
+    make gen-server && \
+    pip install -r requirements_common.txt && \
+    pip install ".[accelerate, peft]" --no-cache-dir
+
+# Install benchmarker
+COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
+# Install router
+COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
+# Install launcher
+COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher
+
+# Final image
+FROM base
+
+ENTRYPOINT ["text-generation-launcher"]
+CMD ["--json-output"]
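
All of the Python diffs below import IS_XPU_SYSTEM from text_generation_server.utils.import_utils, a module touched elsewhere in this commit but not reproduced in this excerpt. A rough, hypothetical sketch of how such flags could be derived (not the actual contents of import_utils.py), assuming Intel Extension for PyTorch exposes the torch.xpu device module:

# Hypothetical sketch only -- the real import_utils.py is part of this commit
# but is not reproduced in this excerpt.
import torch

# ROCm builds of PyTorch set torch.version.hip; CUDA and ROCm both use torch.cuda.
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.cuda.is_available() and not IS_ROCM_SYSTEM

# Intel Extension for PyTorch registers the torch.xpu module for Intel GPUs.
IS_XPU_SYSTEM = hasattr(torch, "xpu") and torch.xpu.is_available()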

server/text_generation_server/models/cache_manager.py

Lines changed: 5 additions & 1 deletion
@@ -2,6 +2,7 @@
 import torch
 
 from typing import Optional, List, Tuple
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
 BLOCK_SIZE: int = 16
 # Will be set in warmup
@@ -24,7 +25,10 @@ def __init__(
         self.repeat_slots = repeat_slots
 
         element_size = torch.tensor([], dtype=dtype).element_size()
-        x = self.block_size // element_size
+        if IS_XPU_SYSTEM:
+            x = 1
+        else:
+            x = self.block_size // element_size
 
         self.kv_cache = [
             (
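
For context on the x computed above: it is the packing factor of the paged KV-cache layout borrowed from vLLM, and the kv_cache construction it feeds into is truncated in this excerpt. A sketch under the assumption that keys use a [num_blocks, num_heads, head_size // x, BLOCK_SIZE, x] layout, while values stay unpacked:

# Sketch only: the assumed key-cache layout is not shown in this excerpt.
import torch

BLOCK_SIZE = 16
dtype = torch.float16
element_size = torch.tensor([], dtype=dtype).element_size()  # 2 bytes for fp16

x_cuda = BLOCK_SIZE // element_size  # 8: packs 16 bytes of keys into the last dim
x_xpu = 1                            # XPU kernels consume an unpacked layout

num_blocks, num_heads, head_size = 128, 32, 128
key_cache_cuda = torch.empty(num_blocks, num_heads, head_size // x_cuda, BLOCK_SIZE, x_cuda)
key_cache_xpu = torch.empty(num_blocks, num_heads, head_size // x_xpu, BLOCK_SIZE, x_xpu)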

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 20 additions & 8 deletions
@@ -30,7 +30,7 @@
 from text_generation_server.utils.dist import MEMORY_FRACTION
 
 tracer = trace.get_tracer(__name__)
-
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM
 
 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -679,7 +679,10 @@ def batch_type(self) -> Type[FlashCausalLMBatch]:
         return FlashCausalLMBatch
 
     def warmup(self, batch: FlashCausalLMBatch):
-        torch.cuda.empty_cache()
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.empty_cache()
+        elif IS_XPU_SYSTEM:
+            torch.xpu.empty_cache()
         try:
             cache_manager = set_cache_manager(
                 batch.blocks,
@@ -697,20 +700,29 @@ def warmup(self, batch: FlashCausalLMBatch):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e
 
-        torch.cuda.synchronize(self.device)
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            torch.cuda.synchronize(self.device)
+        elif IS_XPU_SYSTEM:
+            torch.xpu.synchronize(self.device)
 
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
        # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.dtype).element_size()
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
-        total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-        total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
+        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
+            total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        free_memory = max(
-            0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
-        )
+            free_memory = max(
+                0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
+            )
+        elif IS_XPU_SYSTEM:
+            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
+            free_memory = int(total_gpu_memory * 0.5)
+        else:
+            raise NotImplementedError("FlashModel is only available on GPU")
 
         num_blocks = (
             int(free_memory // total_cache_size)
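
Since the CUDA/ROCm-versus-XPU dispatch above is written inline at each call site, here is a hedged sketch of how the same branching could be factored into small helpers; empty_cache and synchronize below are illustrative names, not functions added by this commit:

# Illustrative only: this helper module is not part of the commit; it restates
# the dispatch pattern used inline in warmup() above.
import torch

from text_generation_server.utils.import_utils import (
    IS_CUDA_SYSTEM,
    IS_ROCM_SYSTEM,
    IS_XPU_SYSTEM,
)


def empty_cache() -> None:
    # Release cached allocator blocks on whichever backend is active.
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        torch.cuda.empty_cache()
    elif IS_XPU_SYSTEM:
        torch.xpu.empty_cache()


def synchronize(device: torch.device) -> None:
    # Block until all queued kernels on `device` have finished.
    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
        torch.cuda.synchronize(device)
    elif IS_XPU_SYSTEM:
        torch.xpu.synchronize(device)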

server/text_generation_server/models/flash_llama.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 tracer = trace.get_tracer(__name__)
 
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
 class FlashLlama(FlashCausalLM):
     def __init__(
@@ -34,6 +35,9 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
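
The device/dtype branch added here is repeated verbatim in the FlashMistral, FlashNeoX, FlashRW, and FlashSantacoderSharded diffs that follow. A sketch of the shared pattern, with resolve_device_and_dtype as a purely illustrative helper that does not exist in the repository:

# Illustrative helper, not part of this commit: the model classes below all
# inline this logic in their __init__ methods.
from typing import Optional, Tuple

import torch

from text_generation_server.utils.import_utils import IS_XPU_SYSTEM


def resolve_device_and_dtype(
    rank: int, dtype: Optional[torch.dtype], model_name: str
) -> Tuple[torch.device, torch.dtype]:
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
    elif IS_XPU_SYSTEM:
        device = torch.device(f"xpu:{rank}")
    else:
        raise NotImplementedError(f"{model_name} is only available on GPU")
    return device, torch.float16 if dtype is None else dtype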

server/text_generation_server/models/flash_mistral.py

Lines changed: 5 additions & 1 deletion
@@ -34,6 +34,7 @@
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 
 
 # Adds windowing logic to FlashCausalLMBatch
@@ -302,8 +303,11 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
-            raise NotImplementedError("FlashLlama is only available on GPU")
+            raise NotImplementedError("FlashMistral is only available on GPU")
 
         tokenizer = LlamaTokenizerFast.from_pretrained(
             model_id,

server/text_generation_server/models/flash_neox.py

Lines changed: 4 additions & 1 deletion
@@ -14,7 +14,7 @@
     weight_files,
     Weights,
 )
-
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 tracer = trace.get_tracer(__name__)
 
 
@@ -31,6 +31,9 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
 

server/text_generation_server/models/flash_rw.py

Lines changed: 4 additions & 1 deletion
@@ -15,7 +15,7 @@
     weight_files,
     Weights,
 )
-
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 tracer = trace.get_tracer(__name__)
 
 
@@ -32,6 +32,9 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
 

server/text_generation_server/models/flash_santacoder.py

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,7 @@
     Weights,
 )
 
+from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
 tracer = trace.get_tracer(__name__)
 
 
@@ -34,6 +35,9 @@ def __init__(
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif IS_XPU_SYSTEM:
+            device = torch.device(f"xpu:{rank}")
+            dtype = torch.float16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
 

server/text_generation_server/utils/dist.py

Lines changed: 8 additions & 1 deletion
@@ -57,7 +57,14 @@ def initialize_torch_distributed():
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
     else:
-        backend = "gloo"
+        try:
+            import oneccl_bindings_for_pytorch
+
+            backend = "ccl"
+            if os.getenv("CCL_WORKER_COUNT", None) is None:
+                os.environ["CCL_WORKER_COUNT"] = str(1)
+        except ImportError:
+            backend = "gloo"
         options = None
 
     if WORLD_SIZE == 1:
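
For reference, when oneccl_bindings_for_pytorch is importable, the "ccl" backend selected above is what later feeds torch.distributed.init_process_group. A hedged sketch of that call under standard torchrun-style environment variables; the remainder of initialize_torch_distributed is not shown here, so the exact arguments may differ:

# Sketch only: approximates how the "ccl" backend chosen above would be used.
# MASTER_ADDR and MASTER_PORT must be set for the default env:// rendezvous.
import os
from datetime import timedelta

import torch.distributed

import oneccl_bindings_for_pytorch  # noqa: F401  # registers the "ccl" backend

os.environ.setdefault("CCL_WORKER_COUNT", "1")

torch.distributed.init_process_group(
    backend="ccl",
    world_size=int(os.getenv("WORLD_SIZE", "1")),
    rank=int(os.getenv("RANK", "0")),
    timeout=timedelta(seconds=60),
)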
