
ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32 #61

Open

Description

@daje0601

I've been testing https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/scripts/run_fsdp_qlora.py

The models I used are Llama-3-70B and Llama-3.1-70B.
With Llama-3-70B, FSDP QLoRA works very well.
However, it fails with Llama-3.1-70B.

I suspect a library dependency issue; I have spent the last three days testing dependency combinations, but I can't find one that works.
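To narrow it down, here is a minimal diagnostic sketch (assuming `model` is the model loaded in run_fsdp_qlora.py, run before `accelerator.prepare` / FSDP wrapping) that lists which parameters would break FSDP's uniform-dtype requirement:

```python
import torch
from collections import Counter

def report_param_dtypes(model: torch.nn.Module) -> None:
    """Print the dtype distribution and name every non-bf16 parameter.

    FSDP's FlatParamHandle flattens all managed parameters into a single
    buffer, so any torch.float32 stragglers here explain the ValueError below.
    """
    counts = Counter(p.dtype for p in model.parameters())
    print("parameter dtypes:", dict(counts))
    for name, param in model.named_parameters():
        if param.dtype != torch.bfloat16:
            print(f"{name}: {param.dtype}")

# report_param_dtypes(model)  # call right after from_pretrained()
```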

Torch versions experimented with

  • torch==2.2.0
  • torch==2.2.2
  • torch==2.3.0

Library versions experimented with

"transformers==4.40.0" up to the latest version
"datasets==2.18.0" up to the latest version
"accelerate==0.29.3" up to the latest version
"evaluate==0.4.1" up to the latest version
"bitsandbytes==0.43.1" up to the latest version
"huggingface_hub==0.22.2" up to the latest version
"trl==0.8.6" up to the latest version
"peft==0.10.0" up to the latest version

Error log

trainable params: 671,088,640 || all params: 8,701,349,888 || trainable%: 7.7125
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/user/iitp-Data/./script/fsdp_qlora.py", line 179, in <module>
[rank0]:     training_function(script_args, training_args)
[rank0]:   File "/data/user/iitp-Data/./script/fsdp_qlora.py", line 158, in training_function
[rank0]:     trainer.train(resume_from_checkpoint=checkpoint)
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/trl/trainer/sft_trainer.py", line 450, in train
[rank0]:     output = super().train(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/transformers/trainer.py", line 1938, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/transformers/trainer.py", line 2085, in _inner_training_loop
[rank0]:     self.model = self.accelerator.prepare(self.model)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/accelerate/accelerator.py", line 1326, in prepare
[rank0]:     result = tuple(
[rank0]:              ^^^^^^
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/accelerate/accelerator.py", line 1327, in <genexpr>
[rank0]:     self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/accelerate/accelerator.py", line 1200, in _prepare_one
[rank0]:     return self.prepare_model(obj, device_placement=device_placement)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/accelerate/accelerator.py", line 1484, in prepare_model
[rank0]:     model = FSDP(model, **kwargs)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 511, in __init__
[rank0]:     _init_param_handle_from_module(
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/torch/distributed/fsdp/_init_utils.py", line 598, in _init_param_handle_from_module
[rank0]:     _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/torch/distributed/fsdp/_init_utils.py", line 610, in _init_param_handle_from_params
[rank0]:     handle = FlatParamHandle(
[rank0]:              ^^^^^^^^^^^^^^^^
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/torch/distributed/fsdp/_flat_param.py", line 582, in __init__
[rank0]:     self._init_flat_param_and_metadata(
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/torch/distributed/fsdp/_flat_param.py", line 632, in _init_flat_param_and_metadata
[rank0]:     ) = self._validate_tensors_to_flatten(params)
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/commpany/anaconda3/envs/fsdp/lib/python3.11/site-packages/torch/distributed/fsdp/_flat_param.py", line 770, in _validate_tensors_to_flatten
[rank0]:     raise ValueError(
[rank0]: ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
[rank1]: Traceback (most recent call last):
[rank1]:   (identical traceback to rank 0)
[rank1]: ValueError: Must flatten tensors with uniform dtype but got torch.bfloat16 and torch.float32
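If I read the traceback correctly, FSDP refuses to wrap the model because some parameters are still torch.float32 while the rest are torch.bfloat16. My understanding (an assumption on my part, not a confirmed fix for Llama-3.1) is that for FSDP + QLoRA every dtype has to line up: `torch_dtype`, the 4-bit compute dtype, and the 4-bit quant storage dtype. A sketch of what I mean, using transformers' `BitsAndBytesConfig` (the model id is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# All non-quantized weights, the 4-bit storage tensors, and the compute dtype
# must agree, otherwise FSDP's FlatParamHandle sees mixed dtypes and raises
# the ValueError above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,  # key for FSDP: storage == torch_dtype
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-70B",  # placeholder model id
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
```

As far as I know, `bnb_4bit_quant_storage` is available from transformers 4.39 onward, so the versions pinned above should already support it.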
