Vision Fine-Tuning Gemma 3 Takes Impossibly High VRAM (OOM Error on 8xH200) #3481

Open
5 tasks done
amanmehra89 opened this issue May 22, 2025 · 0 comments
Labels
⚡accelerate Related to accelerate 🐛 bug Something isn't working ⚡ PEFT Related to PEFT

Reproduction

I'm trying to fine-tune Gemma 3 on a vision dataset, based on https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py

However, I get an out-of-memory error even when I train on only 2-3 samples.
Below is the code I used:

# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Train Gemma-3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).

accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/sft_vlm_gemma3.py \
    --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
    --model_name_or_path google/gemma-3-4b-it \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \
    --bf16 \
    --torch_dtype bfloat16 \
    --use_peft \
    --lora_target_modules all-linear \
    --attn_implementation eager

Train Gemma-3 on the FanqingM/MMIU-Benchmark dataset (multi-image).

accelerate launch \
    --config_file /home/ubuntu/gemma-finetune/examples/accelerate_configs/deepspeed_zero3.yaml \
    /home/ubuntu/gemma-finetune/scripts/sft_vlm_gemma3.py \
    --dataset_name FanqingM/MMIU-Benchmark \
    --dataset_train_split test \
    --model_name_or_path google/gemma-3-4b-it \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
    --bf16 \
    --torch_dtype bfloat16 \
    --use_peft \
    --lora_target_modules all-linear \
    --attn_implementation eager


    
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
accelerate launch --num_processes 4 --config_file /home/ubuntu/gemma-finetune/examples/accelerate_configs/deepspeed_zero3.yaml /home/ubuntu/gemma-finetune/scripts/sft_vlm_gemma3.py --dataset_name FanqingM/MMIU-Benchmark --dataset_train_split test --model_name_or_path google/gemma-3-4b-it --per_device_train_batch_size 1 --gradient_accumulation_steps 4 --bf16 --torch_dtype bfloat16 --use_peft --lora_target_modules q_proj,v_proj,k_proj,o_proj --attn_implementation eager

    accelerate launch   --num_processes 4   --config_file /home/ubuntu/gemma-finetune/examples/accelerate_configs/deepspeed_zero3.yaml   /home/ubuntu/gemma-finetune/scripts/sft_vlm_gemma3.py   --dataset_name FanqingM/MMIU-Benchmark   --dataset_train_split test   --model_name_or_path google/gemma-3-4b-it   --per_device_train_batch_size 1   --gradient_accumulation_steps 1   --output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark   --bf16   --torch_dtype bfloat16   --use_peft   --lora_target_modules all-linear   --attn_implementation eager
"""

import io
import os
import zipfile

import torch
from datasets import DatasetDict, load_dataset
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
from peft import LoraConfig
import json
from datasets import Dataset
from torch.utils.data import DistributedSampler  # needed by check_dataloader below
from accelerate import Accelerator
accelerator = Accelerator()
print(f"Running on rank: {accelerator.local_process_index}, world size: {accelerator.num_processes}")
rank = accelerator.local_process_index

from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


# For multi-image example
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                if image is not None:
                    image = Image.open(io.BytesIO(image["bytes"]))
                    image_inputs.append(image.convert("RGB"))
    return image_inputs


def format_data(samples: dict[str, list]) -> dict[str, list]:
    print(f"Rank {rank}: Processing samples: {samples['question']}")
    formatted_samples = {"messages": []}
    for cont in range(len(samples["question"])):
        images = []
        max_images = 10
        for img_path in samples["input_image_path"][cont][:max_images]:
            try:
                with open(img_path, "rb") as f:
                    img_bytes = f.read()
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                images.append({"type": "image", "image": image})
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")
                continue

        formatted_samples["messages"].append(
            [
                {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
                {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
                {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
            ]
        )
        print(formatted_samples)
    return formatted_samples


# For multi-image example
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
    all_files = list_repo_files(dataset_name, repo_type="dataset")
    zip_files = [f for f in all_files if f.endswith(".zip")]

    for zip_filename in zip_files:
        zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
        extract_folder = zip_filename.replace(".zip", "")
        if os.path.exists(extract_folder):
            continue
        else:
            os.makedirs(extract_folder, exist_ok=True)

            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_folder)
    dataset['test'] = dataset['test'].select(range(1))
    print(dataset)
    dataset = dataset.map(format_data, batched=True, batch_size=1)
    # dataset.save_to_disk("cached_dataset")
    return dataset

def prepare_custom_dataset(path):
    with open(path, "r") as f:
        data = json.load(f)

    # The manual 80/20 split is disabled: with train_size = 0 every record
    # ends up in test_data (use int(0.8 * len(data)) to re-enable the split).
    train_size = 0
    train_data = data[:train_size]
    test_data = data[train_size:]

    # Create Hugging Face Dataset objects
    train_dataset = Dataset.from_list(train_data)
    test_dataset = Dataset.from_list(test_data)

    # Wrap into DatasetDict
    dataset = DatasetDict({
        #"train": train_dataset,
        "test": test_dataset,
    })
    print(dataset)
    dataset = dataset.map(format_data, batched=True, batch_size=1)
    return dataset


def main():
 
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}


    # START: Added checkpoint saving configuration
    training_args.save_strategy = "epoch"
    training_args.save_total_limit = 3
    # END: Added checkpoint saving configuration

    ################
    # Model, Tokenizer & Processor
    ################
    print("Loading model, tokenizer and processor...")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    #print(quantization_config)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,  # or False for 8-bit
        bnb_4bit_quant_type="nf4",  # or "fp4"
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype="bfloat16",  # or torch.float16
    )
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        #device_map='auto',
        quantization_config=bnb_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )
    processor.tokenizer.padding_side = "right"
    model = AutoModelForImageTextToText.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
    )
    print("Model, tokenizer and processor loaded successfully!")
    def collate_fn(examples):
        texts = [
            processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip()
            for example in examples
        ]
        if "images" in examples[0]:  # single-image
            images = [[img.convert("RGB") for img in example["images"]] for example in examples]
        else:  # multi-image
            images = [process_vision_info(example["messages"]) for example in examples]

        # Tokenize the texts and process the images
        batch = processor(
            text=texts, images=images, return_tensors="pt", padding=True
        )  # Encode texts and images into tensors

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()  # Clone input IDs for labels
        # Mask image tokens
        image_token_id = [
            processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
        ]
        # Mask tokens for not being used in the loss computation
        labels[labels == processor.tokenizer.pad_token_id] = -100
        labels[labels == image_token_id] = -100
        # 262144 is presumably Gemma 3's image soft-token id (image_token_index in the model config)
        labels[labels == 262144] = -100

        batch["labels"] = labels
        return batch  # Return the prepared batch

    ################
    # Dataset
    ################

    print("Loading dataset...")
    ## DEFAULT DATASET
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
        dataset = prepare_dataset(dataset, script_args.dataset_name, script_args.dataset_train_split)
    
    ## BARC SAMPLE DATASET
    # dataset = prepare_custom_dataset('/home/ubuntu/barc_client_data/Training/gemma_sample_dataset_colors/Chunk_colors/gemma_dataset.json')

    def check_dataloader(trainer):
        dataloader = trainer.get_train_dataloader()
        sampler = dataloader.sampler
        print(f"Rank {accelerator.local_process_index}: Using sampler {type(sampler).__name__}")
        if isinstance(sampler, DistributedSampler):
            indices = list(sampler)
            print(f"Rank {accelerator.local_process_index}: Dataset size: {len(sampler.dataset)}, Samples per replica: {sampler.num_samples}, Replicas: {sampler.num_replicas}")
            print(f"Rank {accelerator.local_process_index}: Assigned sample indices: {indices}")
            # Log sample details for this rank
            for idx in indices:
                sample = sampler.dataset[idx]
                sample_id = sample.get('id', idx)
                question = sample['messages'][1]['content'][-1]['text'] if 'messages' in sample else 'Unknown'
                print(f"Rank {accelerator.local_process_index}: Sample index {idx}, ID: {sample_id}, Question: {question}")
    ################
    # Training
    ################
    
    print("Starting training...")
    model_args.lora_r = 2  # Reduce the LoRA rank to 2
    model_args.lora_target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
    peft_config = LoraConfig(
        r=model_args.lora_r,
        target_modules=model_args.lora_target_modules,
        lora_alpha=8,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )


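    # eval_accumulation_steps is a TrainingArguments field, not an SFTTrainer argument,
    # so it is set on training_args instead of being passed to the constructor.
    training_args.eval_accumulation_steps = 1
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=processor.tokenizer,
        peft_config=peft_config,
    )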
    #check_dataloader(trainer)
    print("Training started...")
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
        if trainer.accelerator.is_main_process:
            processor.push_to_hub(training_args.hub_model_id)


if __name__ == "__main__":
    main()
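
For scale, here is a back-of-the-envelope sketch of the sequence length that `format_data` can produce with up to `max_images = 10` images per sample. It rests on assumptions that are not in the script: 256 tokens per image (the figure documented for Gemma 3 without pan-and-scan), 8 attention heads for the 4B model, and a hypothetical 512 text tokens. The second helper estimates the size of one layer's attention score matrix under `--attn_implementation eager`, which materializes the full matrix.

```python
# Assumption: each image contributes a fixed 256 tokens (Gemma 3, pan-and-scan off).
# Verify against the processor that is actually loaded; boi/eoi markers are ignored here.
TOKENS_PER_IMAGE = 256


def estimate_sequence_length(num_images: int, num_text_tokens: int,
                             tokens_per_image: int = TOKENS_PER_IMAGE) -> int:
    """Approximate token count of one sample: image tokens plus text tokens."""
    return num_images * tokens_per_image + num_text_tokens


def eager_attention_scores_bytes(seq_len: int, num_heads: int,
                                 batch_size: int = 1, bytes_per_element: int = 2) -> int:
    """Size of one layer's (batch, heads, seq, seq) score matrix; bf16 = 2 bytes."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element


if __name__ == "__main__":
    # 10 images as in format_data; 512 text tokens is a hypothetical value.
    seq_len = estimate_sequence_length(num_images=10, num_text_tokens=512)
    # num_heads=8 is an assumption about gemma-3-4b-it, not read from its config.
    per_layer = eager_attention_scores_bytes(seq_len, num_heads=8)
    print(f"sequence length ~ {seq_len} tokens")
    print(f"attention scores per layer ~ {per_layer / 2**20:.1f} MiB")
```

Under these assumptions a single sample stays in the low thousands of tokens, so sequence length alone would not obviously account for running out of memory on an H200.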

Is this an internal SFTTrainer issue or an accelerate issue?
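
One way to narrow this down would be to log the peak allocated memory per rank at a few points in the script (after model loading, after building the trainer, around `trainer.train()`). The helper below is a minimal sketch, not part of the original script; `peak_gpu_memory_gib` and `report` are hypothetical names, and the only library call used is `torch.cuda.max_memory_allocated`.

```python
def peak_gpu_memory_gib(device_index: int = 0):
    """Return peak allocated CUDA memory in GiB, or None if torch/CUDA is unavailable."""
    try:
        import torch  # imported lazily so the helper also loads on CPU-only machines
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    # Peak bytes allocated by tensors on this device since the start of the process.
    return torch.cuda.max_memory_allocated(device_index) / 2**30


def report(stage: str, device_index: int = 0) -> str:
    """Format a one-line memory report for the given stage label."""
    peak = peak_gpu_memory_gib(device_index)
    value = "n/a" if peak is None else f"{peak:.2f} GiB"
    return f"[{stage}] peak allocated on cuda:{device_index}: {value}"


if __name__ == "__main__":
    print(report("startup"))
```

In the script above this could be called as `print(report("after model load", rank))`, using the `rank` variable defined near the imports, to see which stage drives the peak.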

System Info

[2025-05-22 14:44:09,766] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)

Copy-paste the following information when reporting an issue:

  • Platform: Linux-6.8.0-1021-aws-x86_64-with-glibc2.35
  • Python version: 3.12.8
  • TRL version: 0.17.0
  • PyTorch version: 2.7.0
  • CUDA device(s): NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200, NVIDIA H200
  • Transformers version: 4.52.2
  • Accelerate version: 1.7.0
  • Accelerate config:
    • compute_environment: LOCAL_MACHINE
    • distributed_type: DEEPSPEED
    • mixed_precision: bf16
    • use_cpu: False
    • debug: True
    • num_processes: 8
    • machine_rank: 0
    • num_machines: 1
    • rdzv_backend: static
    • same_network: True
    • main_training_function: main
    • enable_cpu_affinity: False
    • deepspeed_config: {'deepspeed_moe_layer_cls_names': '', 'gradient_accumulation_steps': 1, 'gradient_clipping': 8.0, 'offload_optimizer_device': 'cpu', 'offload_param_device': 'none', 'zero3_init_flag': True, 'zero3_save_16bit_model': True, 'zero_stage': 3}
    • downcast_bf16: no
    • tpu_use_cluster: False
    • tpu_use_sudo: False
    • tpu_env: []
    • dynamo_config: {'dynamo_backend': 'INDUCTOR'}
  • Datasets version: 3.6.0
  • HF Hub version: 0.31.4
  • bitsandbytes version: 0.45.5
  • DeepSpeed version: 0.16.8
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: not installed
  • PEFT version: 0.15.2
  • vLLM version: not installed
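
As a rough sanity check on the configuration above (`zero_stage: 3`, bf16), the per-GPU share of the model weights can be estimated as follows. This is a sketch under stated assumptions: the parameter count of roughly 4e9 for `google/gemma-3-4b-it` is approximate, and activations, gradients, optimizer states, and any quantization are ignored. ZeRO stage 3 partitions parameters across ranks, so the weights are divided by the number of processes.

```python
def zero3_weight_gib_per_gpu(num_params: float, bytes_per_param: int, num_gpus: int) -> float:
    """Per-GPU share of the weights when parameters are partitioned evenly (ZeRO-3)."""
    return num_params * bytes_per_param / num_gpus / 2**30


if __name__ == "__main__":
    approx_params = 4e9  # assumption: approximate size of gemma-3-4b-it
    for gpus in (4, 8):  # --num_processes 4 in the launch command, 8 in the accelerate config
        share = zero3_weight_gib_per_gpu(approx_params, bytes_per_param=2, num_gpus=gpus)
        print(f"{gpus} GPUs: ~{share:.2f} GiB of bf16 weights per GPU")
```

Either way, the weight shards come to a small fraction of the 141 GB on each H200, which suggests looking at activations or at how the model is loaded, not at the parameter count itself.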

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete
@github-actions github-actions bot added ⚡ PEFT Related to PEFT ⚡accelerate Related to accelerate 🐛 bug Something isn't working labels May 22, 2025