Describe the bug
I was trying to run DeepSpeed-Inference on Llama-4-Scout-Instruct for text generation. The process failed when it started to load the model.
I am running on a single node with 8 GPUs, each with 80GB of memory (8*80GB total), and using float16.
Here is the error message:
AutoTP: [(<class 'transformers.models.llama4.modeling_llama4.Llama4TextDecoderLayer'>, ['shared_expert.down_proj', 'self_attn.o_proj'])]
Loading 0 checkpoint shards: 0it [00:00, ?it/s][rank0]: Traceback (most recent call last):
[rank0]: ds_engine = deepspeed.init_inference(model, config=ds_inference_config)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/__init__.py", line 364, in init_inference
[rank0]: engine = InferenceEngine(model, config=ds_inference_config)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/inference/engine.py", line 164, in __init__
[rank0]: self._apply_injection_policy(config, client_module)
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/inference/engine.py", line 388, in _apply_injection_policy
[rank0]: replace_transformer_layer(client_module, self.module, checkpoint, config, self.config)
[rank0]: File "/usr/local/lib/python3.12/site-packages/deepspeed/module_inject/replace_module.py", line 397, in replace_transformer_layer
[rank0]: if 'Yuan' in str(replaced_module):
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: UnboundLocalError: cannot access local variable 'replaced_module' where it is not associated with a value
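From the traceback, replaced_module appears to be assigned only while checkpoint shards are being loaded, and the log shows "Loading 0 checkpoint shards", so that assignment may never run. Here is a minimal, self-contained sketch of the failure pattern I suspect (illustrative only, not the actual DeepSpeed source; the function name and loop body are made up):

    def replace_layers(checkpoint_shards):
        # replaced_module is bound only inside the loop body; if the shard list
        # is empty the loop never runs and the read below raises UnboundLocalError
        for shard in checkpoint_shards:
            replaced_module = shard  # stands in for the real layer-replacement work
        return 'Yuan' in str(replaced_module)

    replace_layers([])  # raises UnboundLocalError, mirroring the traceback above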
To Reproduce
Steps to reproduce the behavior:
- Simple inference script to reproduce
Here is a code snippet to reproduce the error:
import glob
import os

import deepspeed
import torch
from transformers import Llama4TextConfig, AutoModelForCausalLM

model_path = "./model_to_evaluate"
kwargs = {"torch_dtype": torch.float16, "attn_implementation": "sdpa"}
print(f"kwargs: {kwargs}")
model_config = Llama4TextConfig.from_pretrained(model_path, **kwargs)

# load the model on the meta device so no weights are materialized yet
with deepspeed.OnDevice(dtype=kwargs["torch_dtype"], device="meta", enabled=True):
    model = AutoModelForCausalLM.from_config(model_config, **kwargs)
print(f"model device: {next(model.parameters()).device}")

# set up the DeepSpeed inference config
ds_inference_config = {
    "dtype": kwargs["torch_dtype"],
    # meta device is not compatible with kernel injection
    "replace_with_kernel_inject": False,
    # tp_size equals the global number of GPUs
    "tensor_parallel": {
        "tp_size": int(os.getenv("WORLD_SIZE", "1"))
    },
    # specify where the model files are
    "checkpoint": {
        "checkpoints": glob.glob(os.path.join(model_path, "**", "*.safetensors"), recursive=False),
        "type": "DS_MODEL",
        "version": 1.0
    }
}
print(f"deepspeed inference config: {ds_inference_config}")

ds_engine = deepspeed.init_inference(model, config=ds_inference_config)
model = ds_engine.module
model.eval()
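As a side note, the "Loading 0 checkpoint shards: 0it" line makes me suspect the glob above returns an empty list: with recursive=False the "**" component only matches a single directory level, so safetensors files sitting directly in model_to_evaluate would be missed. A quick sanity check I would run first (assuming the shards live directly under the model directory):

    import glob
    import os

    model_path = "./model_to_evaluate"
    # recursive=True lets "**" also match zero directories, i.e. files directly in model_path
    shards = glob.glob(os.path.join(model_path, "**", "*.safetensors"), recursive=True)
    print(f"found {len(shards)} safetensors shards")
    assert shards, "no checkpoint shards found -- check model_path and the glob pattern"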
- Required packages and their versions:
transformers==4.51.3
accelerate==1.6.0
deepspeed==0.16.7
flash-attn==2.7.3 (this is optional since I ran with sdpa)
torch==2.6.0+cu126
- How to run the script
Step 1: Download meta-llama/Llama-4-Scout-17B-16E-Instruct to the local directory "model_to_evaluate".
Step 2: Put the above code into a file named "llama4_dsi.py".
Step 3: Run: accelerate launch llama4_dsi.py
...
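For completeness: tp_size is read from WORLD_SIZE, so the launcher needs to start one process per GPU. On this 8-GPU node that means something equivalent to the following (exact flags depend on the accelerate config in use):

    accelerate launch --num_processes 8 llama4_dsi.py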
Expected behavior
The model should load successfully and be distributed across the 8 GPUs.
ds_report output
Please run ds_report to give us details about your setup.
Screenshots
If applicable, add screenshots to help explain your problem.
System info (please complete the following information):
- OS: [e.g. Ubuntu 18.04]
- GPU count and types [e.g. two machines with x8 A100s each]
- (if applicable) what DeepSpeed-MII version are you using
- (if applicable) Hugging Face Transformers/Accelerate/etc. versions
- Python version
- Any other relevant info about your setup
Docker context
Are you using a specific docker image that you can share?
Additional context
Add any other context about the problem here.