Fy/sllm checkpoint #2

Open · wants to merge 11 commits into base: main
Changes from 1 commit
vllm loader new API
Leyang Xue authored and future-xy committed Sep 10, 2024
commit 2b9fe5b24b8f4fbd399bc4e4559ecddd21bd2e11
166 changes: 114 additions & 52 deletions vllm/model_executor/model_loader/loader.py
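
For orientation before the diff: this commit swaps the blocking SllmStoreClient round trips for the non-blocking serverless_llm_store helpers and adopts the returned GPU tensors directly into the model's parameters. The sketch below is a minimal reconstruction of that flow; the helper names, argument order, and return values are taken from the call sites in the diff rather than from any documented API, and rank, model_name, storage_path, and the stand-in model are placeholder values, so treat all of them as assumptions.

# Minimal sketch of the loading flow introduced in this commit
# (signatures assumed from the call sites in the diff below).
import torch.nn as nn
from serverless_llm_store import (
    load_into_cpu_non_blocking,
    load_into_gpu_non_blocking,
    wait_dict_loaded,
)

rank = 0                               # normally get_tensor_model_parallel_rank()
model_name = "opt-1.3b/rank_0"         # hypothetical; the diff derives this from the local path
storage_path = "/data/models"          # hypothetical storage root
device_map = {"": rank}                # place everything on this tensor-parallel rank, as in the diff

# Start pulling the checkpoint into host memory without blocking.
load_into_cpu_non_blocking(model_name, device_map, storage_path)
# Start the host-to-device copy; returns a replica id and a dict of GPU tensors.
replica_uuid, sllm_state_dict, device_map = load_into_gpu_non_blocking(
    model_name, device_map, storage_path)

# Stand-in for the vLLM model built by _initialize_model() while the copy is in flight.
model = nn.Linear(16, 16)

# Block until every tensor for this replica is resident on the GPU, then adopt
# the preloaded tensors in place instead of copying into freshly allocated parameters.
wait_dict_loaded(model_name, replica_uuid)
for key, param in model.named_parameters(recurse=True):
    if key in sllm_state_dict:
        param.data = sllm_state_dict[key]
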
@@ -632,90 +632,152 @@ def load_model(self, *, model_config: ModelConfig,
cache_config: CacheConfig) -> nn.Module:
print("Loading model")
from serverless_llm_store.client import SllmStoreClient
from serverless_llm_store import load_into_cpu_non_blocking, load_into_gpu_non_blocking, wait_dict_loaded
from serverless_llm_store._C import (
get_cuda_memory_handles,
get_device_uuid_map,
get_device_ptrs_from_mem_handles,
allocate_cuda_memory,
)

from vllm.distributed import get_tensor_model_parallel_rank
from accelerate import dispatch_model, init_empty_weights
import uuid

assert os.path.isdir(model_config.model)

client = SllmStoreClient("localhost:8073")
# client = SllmStoreClient("localhost:8073")
rank = get_tensor_model_parallel_rank()

local_model_path = model_config.model
local_model_path = os.path.join(local_model_path, f"rank_{rank}")
model_name = "/".join(local_model_path.split("/")[-2:])

ret = client.load_into_cpu(model_name)
if not ret or ret == False:
raise ValueError(f"Failed to load model {model_name} into CPU")
# model name is everything after models
model_name = local_model_path.split("models/")[1]
storage_path = local_model_path.split("models/")[0]
if storage_path.endswith("/"):
storage_path = os.path.join(storage_path, "models")
else:
storage_path = storage_path + "models"
device_map = {"": rank}

load_into_cpu_non_blocking(model_name, device_map, storage_path)
replica_uuid, sllm_state_dict, device_map = load_into_gpu_non_blocking(model_name, device_map, storage_path)

tensor_index_path = os.path.join(local_model_path, "tensor_index.json")
with open(tensor_index_path, "r") as f:
tensor_index = json.load(f)
# tensor_index_path = os.path.join(local_model_path, "tensor_index.json")
# with open(tensor_index_path, "r") as f:
# tensor_index = json.load(f)

device_uuid_map = get_device_uuid_map()
device_uuid = device_uuid_map[rank]
replica_uuid = str(uuid.uuid4())
# device_uuid_map = get_device_uuid_map()
# device_uuid = device_uuid_map[rank]
# replica_uuid = str(uuid.uuid4())

# sllm_state_dict = load_dict(os.path.join(local_model_path, f"rank_{rank}", device_config.device))

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
cache_config)
lora_config, vision_language_config,
cache_config)
model = model.eval()

# set all parameters to meta device
state_dict = self._filter_subtensors(model.state_dict())
key_list = list(state_dict.keys())

memory_ptrs = {rank: []}
tensor_copy_chunks = {rank: []}
for key, param in model.named_parameters(recurse=True):
if key in key_list:
param.data = torch.empty(1, device=torch.device("cuda"))

# idx = 0
# for name, param in model.named_parameters(recurse=True):
# if not name in state_dict:
# continue
# data_ptr = param.data_ptr()
# memory_ptrs[rank].append(data_ptr)
# model = dispatch_model(model, {"": torch.device("meta")})
torch.cuda.empty_cache()

wait_dict_loaded(model_name, replica_uuid)

for key, param in model.named_parameters(recurse=True):
if key in key_list:
tensor = sllm_state_dict[key]
# param_data = param.data
# param_shape = param.shape
# print(f"{param_shape=}, {param.device=}")
# for dim, size in enumerate(tensor.shape):
# if size < param_shape[dim]:
# param_data = param_data.narrow(dim, 0, size)
# if tensor.shape != param_shape:
# logger.warning(
# "loading tensor of shape %s into "
# "parameter '%s' of shape %s", tensor.shape, key, param_shape)
# param_data.copy_(tensor)
param.data = tensor
state_dict.pop(key)
if state_dict:
raise ValueError(
f"Missing keys {tuple(state_dict)} in loaded state!")



# tensor_meta_index = {}
# tensor_data_index = {}
# for name, (offset, size, shape, stride, dtype) in tensor_index.items():
# tensor_meta_index[name] = (shape, stride, dtype)
# tensor_data_index[name] = (offset, size)

# total_memory_size = 0
# tensor_offsets = []
# tensor_chunks = []
# for name in state_dict.keys():
# cpu_offsets, memory_size = tensor_index[name]
# tensor_chunks.append((cpu_offsets, size, total_memory_size, 0))
# tensor_offsets.append(total_memory_size)
# total_memory_size += memory_size

# cuda_memory_ptrs = allocate_cuda_memory(total_memory_size)
# cuda_memory_handles = get_cuda_memory_handles(cuda_memory_ptrs)

# memory_ptrs = {rank: []}
# tensor_copy_chunks = {rank: []}

# # idx = 0
# # for name, param in model.named_parameters(recurse=True):
# # if not name in state_dict:
# # continue
# # data_ptr = param.data.untyped_storage().data_ptr()
# # memory_ptrs[rank].append(data_ptr)

# # offset, size, _, _, _ = tensor_index[name]
# # tensor_copy_chunks[rank].append((offset, size, 0, idx))
# # idx += 1
# # print(f"Loading tensor {name} with offset {offset} and size {size}, device {param.device}, {hex(data_ptr)}, {idx}")

# for idx, (name, param) in enumerate(state_dict.items()):
# # data_ptr = param.untyped_storage().data_ptr()
# data_ptr = param.view(-1)[-1].data_ptr()
# memory_ptrs[rank].append(data_ptr)
# offset, size, _, _, _ = tensor_index[name]
# # every tensor has its own base address, so GPU offset is always 0
# tensor_copy_chunks[rank].append((offset, size, 0, idx))
# idx += 1
# print(f"Loading tensor {name} with offset {offset} and size {size}, device {param.device}, {hex(data_ptr)}, {idx}")

for idx, (name, param) in enumerate(state_dict.items()):
data_ptr = param.untyped_storage().data_ptr()
memory_ptrs[rank].append(data_ptr)
offset, size, _, _, _ = tensor_index[name]
# every tensor has its own base address, so GPU offset is always 0
tensor_copy_chunks[rank].append((offset, size, 0, idx))
print(f"Loading tensor {name} with offset {offset} and size {size}, device {param.device}, {hex(data_ptr)}, {idx}")

cuda_memory_handles = get_cuda_memory_handles(memory_ptrs)
# device_ptrs = get_device_ptrs_from_mem_handles(cuda_memory_handles)
# cuda_memory_handles = get_cuda_memory_handles(memory_ptrs)

# for k, ptr in enumerate(device_ptrs[rank]):
# assert hex(ptr) == hex(memory_ptrs[rank][k]), f"Memory ptrs do not match: {hex(ptr)} != {hex(memory_ptrs[rank][k])}"
# cuda_memory_handles = {
# rank: [
# get_cuda_memory_handles({rank: ptr})[rank]
# for ptr in memory_ptrs[rank]
# ]
# }
# # for k, ptr in enumerate(device_ptrs[rank]):
# # assert hex(ptr) == hex(memory_ptrs[rank][k]), f"Memory ptrs do not match: {hex(ptr)} != {hex(memory_ptrs[rank][k])}"
# # cuda_memory_handles = {
# # rank: [
# # get_cuda_memory_handles({rank: ptr})[rank]
# # for ptr in memory_ptrs[rank]
# # ]
# # }

ret = client.load_into_gpu(
model_name,
replica_uuid,
{device_uuid: tensor_copy_chunks[rank]},
{device_uuid: cuda_memory_handles[rank]}
)
if not ret or ret == False:
raise ValueError(f"Failed to load model {model_name} into GPU")
client.confirm_model_loaded(model_name, replica_uuid)
return model.eval()
# ret = client.load_into_gpu(
# model_name,
# replica_uuid,
# {device_uuid: tensor_copy_chunks[rank]},
# {device_uuid: cuda_memory_handles[rank]}
# )
# if not ret or ret == False:
# raise ValueError(f"Failed to load model {model_name} into GPU")
# client.confirm_model_loaded(model_name, replica_uuid)
return model

@staticmethod
def save_model(
@@ -732,7 +794,7 @@ def save_model(

# move all tensors to CPU
for key, tensor in state_dict.items():
state_dict[key] = tensor.cpu()
state_dict[key] = tensor.cpu().contiguous()

save_dict(state_dict, os.path.join(path, f"rank_{rank}"))
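
The only change in this second hunk is appending .contiguous() before saving. vLLM parameters are frequently non-contiguous views of larger fused weights, and making each tensor contiguous on the CPU guarantees a dense, self-owned buffer whose bytes can be written as one flat range, which is presumably what save_dict's offset-based format expects. A small illustration with a hypothetical tensor (not taken from the diff):

import torch

fused = torch.randn(4, 8)          # stand-in for a fused weight
view = fused[:, :4]                # a non-contiguous slice of it
assert not view.is_contiguous()
dense = view.cpu().contiguous()    # what the updated save path stores
assert dense.is_contiguous() and dense.shape == view.shape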
