
FSDP2 Unwrap model, still has dtensors as weights #3487

Closed
@kiansierra

Description


System Info

- `Accelerate` version: 1.6.0
- Platform: Linux-6.11.0-21-generic-x86_64-with-glibc2.39
- `accelerate` bash location: /home/kian/coding/kaggle/kaggle-drawing-with-llms/code-drawing-llms-kaggle/.venv/bin/accelerate
- Python version: 3.12.3
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- PyTorch SDAA available: False
- PyTorch MUSA available: False
- System RAM: 125.71 GB
- GPU type: NVIDIA GeForce RTX 3090
- `Accelerate` config passed:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: FSDP
        - mixed_precision: bf16
        - use_cpu: False
        - debug: False
        - num_processes: 2
        - machine_rank: 0
        - num_machines: 1
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - enable_cpu_affinity: False
        - fsdp_config: {'fsdp_activation_checkpointing': True, 'fsdp_auto_wrap_policy': 'TRANSFORMER_BASED_WRAP', 'fsdp_cpu_ram_efficient_loading': True, 'fsdp_offload_params': False, 'fsdp_reshard_after_forward': True, 'fsdp_state_dict_type': 'SHARDED_STATE_DICT', 'fsdp_version': 2}
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Running this minimal example, I run into an issue when saving the model.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
from accelerate import Accelerator


def main():
    dataset = load_dataset("winglian/alpaca-cleaned-chat", split="train")

    model_name = "HuggingFaceTB/SmolLM2-135M"
    processor_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
    accelerator = Accelerator()

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map={"": accelerator.device},
        attn_implementation="flash_attention_2",
    )
    processing_class = AutoTokenizer.from_pretrained(processor_name)

    training_args = SFTConfig(
        max_length=2048,
        output_dir="/tmp",
        report_to=None,
        bf16=True,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        num_train_epochs=1,
        logging_steps=1,
        max_steps=5,
    )
    trainer = SFTTrainer(
        model,
        processing_class=processing_class,
        train_dataset=dataset,
        args=training_args,
    )
    trainer.train()
    model = accelerator.unwrap_model(model)
    model.save_pretrained(training_args.output_dir)

if __name__ == "__main__":
    main()

Part of the issue is that after the unwrap, the model is still an FSDPModule and its weights are still distributed tensors (DTensors); a quick way to confirm both is sketched after the snippet below.
Running the following removes the DTensors, and the weights go back to being regular tensors (for some of them, not all), but the model is still wrapped in the FSDPModule:

handle = model.unshard(async_op=True)
handle.wait()
accelerator.wait_for_everyone()
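
To confirm both claims, here is a minimal check (a sketch, assuming the public FSDP2 imports in PyTorch 2.6, torch.distributed.fsdp.FSDPModule and torch.distributed.tensor.DTensor):

from torch.distributed.fsdp import FSDPModule
from torch.distributed.tensor import DTensor

unwrapped = accelerator.unwrap_model(model)
# True: unwrap_model does not remove the FSDP2 wrapper
print(isinstance(unwrapped, FSDPModule))
# True before unshard(): the parameters are still DTensors
print(any(isinstance(p, DTensor) for p in unwrapped.parameters()))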

This seems to be linked to safetensors specifically, since the code below does generate a pytorch_model.bin:

model = accelerator.unwrap_model(model)
model.save_pretrained(training_args.output_dir, safe_serialization=False)
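
A possible workaround (a sketch, not a confirmed fix) is to gather a full, non-DTensor state dict with torch.distributed.checkpoint's state-dict helpers, which are documented to handle sharded modules, and pass it to save_pretrained explicitly:

from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
)

# Gather the full (unsharded) parameters as plain CPU tensors
state_dict = get_model_state_dict(
    model,
    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
if accelerator.is_main_process:
    model.save_pretrained(training_args.output_dir, state_dict=state_dict)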

My debugger traces the safetensors error to
https://github.com/huggingface/safetensors/blob/main/bindings/python/py_src/safetensors/torch.py#L17
before it goes into torch.distributed:

Exception has occurred: RuntimeError
Attempted to access the data pointer on an invalid python storage.
  File "/home/kian/coding/kaggle/kaggle-drawing-with-llms/code-drawing-llms-kaggle/.venv/lib/python3.12/site-packages/safetensors/torch.py", line 13, in storage_ptr
    return tensor.untyped_storage().data_ptr()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Attempted to access the data pointer on an invalid python storage.
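
The failure can be reproduced without save_pretrained at all (a sketch, run on the still-wrapped model before any unshard() call): safetensors calls untyped_storage().data_ptr() on every tensor to detect shared storage, and that call raises on DTensor parameters.

param = next(model.parameters())
print(type(param))  # a DTensor, not a plain torch.Tensor
param.untyped_storage().data_ptr()  # raises the same RuntimeError as above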

Expected behavior

save_pretrained should generate a model.safetensors file when safe_serialization=True (the default).


Labels

bug (Something isn't working)
