Description
System Info
- GPU: B200
- TensorRT-LLM: v0.20.0
- Driver Version: 570.158.01
- CUDA Version: 12.8
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Run the server:

```shell
trtllm-serve nvidia/DeepSeek-R1-0528-FP4 --max_batch_size 128 --max_num_tokens 163840 --max_seq_len 163840 --kv_cache_free_gpu_memory_fraction 0.6 --port 8000 --trust_remote_code true --backend pytorch --tp_size 8 --pp_size 1 --ep_size 8 --extra_llm_api_options extra-llm-api-config.yaml
```
Config file `extra-llm-api-config.yaml` (indentation restored from the flattened paste):

```yaml
pytorch_backend_config:
  use_cuda_graph: true
  cuda_graph_batch_sizes: [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,100,128,256]
  cuda_graph_padding_enabled: true
  kv_cache_dtype: fp8
enable_attention_dp: false
enable_chunked_prefill: true
speculative_config:
  decoding_type: MTP
  num_nextn_predict_layers: 3
  use_relaxed_acceptance_for_thinking: true
  relaxed_topk: 10
  relaxed_delta: 0.6
```
Then send requests at 120 RPM.
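For reference, the load pattern can be reproduced with a small client script. This is a sketch: the endpoint path and payload assume the OpenAI-compatible `/v1/completions` API that `trtllm-serve` exposes, and the prompt text is a placeholder.

```python
import json
import time
import urllib.request

RPM = 120
INTERVAL_S = 60.0 / RPM  # 120 requests per minute -> one request every 0.5 s


def send_one(url: str, prompt: str, max_tokens: int = 256) -> dict:
    """POST a single OpenAI-style completion request and return the JSON reply."""
    body = json.dumps({
        "model": "nvidia/DeepSeek-R1-0528-FP4",
        "prompt": prompt,
        "max_tokens": max_tokens,
    }).encode()
    req = urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)


def run(url: str, n_requests: int) -> None:
    """Fire n_requests sequentially at a fixed pacing of RPM requests/minute."""
    for i in range(n_requests):
        send_one(url, f"request {i}")
        time.sleep(INTERVAL_S)


# Example: run("http://localhost:8000/v1/completions", 600)
```

A real load test would send requests concurrently rather than sequentially; the pacing is what matters for reproducing the report.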
Expected behavior
No errors.
Actual behavior
Got error (the paste contained several verbatim copies of the same frames; deduplicated here):

```
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor.py", line 1615, in _forward_step
    outputs = forward(scheduled_requests, self.resource_manager,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvtx/nvtx.py", line 122, in inner
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/py_executor.py", line 1610, in forward
    return self.model_engine.forward(scheduled_requests,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/utils.py", line 66, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 1913, in forward
    inputs, gather_ids = self._prepare_inputs(scheduled_requests,
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/nvtx/nvtx.py", line 122, in inner
    result = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 1866, in _prepare_inputs
    return self._prepare_tp_inputs(scheduled_requests, kv_cache_manager,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/pyexecutor/model_engine.py", line 1352, in _prepare_tp_inputs
    spec_metadata.prepare()
  File "/usr/local/lib/python3.12/dist-packages/tensorrt_llm/_torch/speculative/mtp.py", line 213, in prepare
    mtp_slot_ids = torch.tensor(mtp_slot_ids,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'NoneType' object cannot be interpreted as an integer
```
Additional notes
I added a debug log in `tensorrt_llm/_torch/speculative/mtp.py`:

```python
if self.mtp_hidden_states_manager is not None:  # MTP vanilla or use relaxed acceptance
    mtp_slot_ids = []
    for rid in self.request_ids:
        slot_id = self.mtp_hidden_states_manager.slot_manager.get_slot(
            rid)
        print(f'{slot_id=}, {rid=}')  # added debug log
        mtp_slot_ids.append(slot_id)
```
And I got:

```
slot_id=99, rid=486
slot_id=101, rid=487
slot_id=20, rid=488
slot_id=102, rid=489
slot_id=None, rid=18446744073709551598
slot_id=None, rid=18446744073709551599
slot_id=None, rid=18446744073709551600
slot_id=None, rid=18446744073709551601
```
It seems some invalid uint64 request ids (such as 0xffffffffffffffee, i.e. 2**64 - 18) end up in `MTPSpecMetadata.request_ids`, and `slot_manager.get_slot()` returns `None` for them, which makes `torch.tensor(mtp_slot_ids, ...)` raise the `TypeError` above.
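For what it's worth, the logged values are exactly what small negative integers look like when their bit patterns are read back as unsigned 64-bit: 18446744073709551598 through 18446744073709551601 are -18 through -15 wrapped around. The sketch below demonstrates that reinterpretation, plus a hypothetical guard (not TensorRT-LLM code) that would have surfaced the bad ids before `torch.tensor` tripped over the `None`:

```python
import struct


def as_uint64(value: int) -> int:
    """Reinterpret a signed 64-bit integer's bit pattern as unsigned."""
    return struct.unpack("<Q", struct.pack("<q", value))[0]


# The suspicious request ids from the log are -18..-15 wrapped to uint64:
assert as_uint64(-18) == 18446744073709551598
assert as_uint64(-15) == 18446744073709551601


def collect_slot_ids(request_ids, get_slot):
    """Hypothetical guard: fail fast on request ids with no slot, instead of
    letting torch.tensor() raise a TypeError on a None much later."""
    slot_ids = []
    for rid in request_ids:
        slot_id = get_slot(rid)
        if slot_id is None:
            raise RuntimeError(f"no MTP slot for request id {rid} (0x{rid:x})")
        slot_ids.append(slot_id)
    return slot_ids
```

With such a check, the failure would name the offending id (e.g. `0xffffffffffffffee`) at the point where the bad state is first observed, instead of surfacing as an opaque `TypeError`.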