Reproduction

I'm trying to fine-tune Gemma 3 on a vision dataset, based on https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py. However, I get an out-of-memory error even when I train on only 2-3 samples. Below is the code I used:
```python
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Train Gemma-3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).

accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/sft_vlm_gemma3.py \
    --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
    --model_name_or_path google/gemma-3-4b-it \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \
    --bf16 \
    --torch_dtype bfloat16 \
    --use_peft \
    --lora_target_modules all-linear \
    --attn_implementation eager

Train Gemma-3 on the FanqingM/MMIU-Benchmark dataset (multi-image).

accelerate launch \
    --config_file /home/ubuntu/gemma-finetune/examples/accelerate_configs/deepspeed_zero3.yaml \
    /home/ubuntu/gemma-finetune/scripts/sft_vlm_gemma3.py \
    --dataset_name FanqingM/MMIU-Benchmark \
    --dataset_train_split test \
    --model_name_or_path google/gemma-3-4b-it \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
    --bf16 \
    --torch_dtype bfloat16 \
    --use_peft \
    --lora_target_modules all-linear \
    --attn_implementation eager

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

accelerate launch --num_processes 4 \
    --config_file /home/ubuntu/gemma-finetune/examples/accelerate_configs/deepspeed_zero3.yaml \
    /home/ubuntu/gemma-finetune/scripts/sft_vlm_gemma3.py \
    --dataset_name FanqingM/MMIU-Benchmark \
    --dataset_train_split test \
    --model_name_or_path google/gemma-3-4b-it \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --bf16 \
    --torch_dtype bfloat16 \
    --use_peft \
    --lora_target_modules q_proj,v_proj,k_proj,o_proj \
    --attn_implementation eager

accelerate launch --num_processes 4 \
    --config_file /home/ubuntu/gemma-finetune/examples/accelerate_configs/deepspeed_zero3.yaml \
    /home/ubuntu/gemma-finetune/scripts/sft_vlm_gemma3.py \
    --dataset_name FanqingM/MMIU-Benchmark \
    --dataset_train_split test \
    --model_name_or_path google/gemma-3-4b-it \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
    --bf16 \
    --torch_dtype bfloat16 \
    --use_peft \
    --lora_target_modules all-linear \
    --attn_implementation eager
"""

import io
import json
import os
import zipfile

import torch
from accelerate import Accelerator
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import hf_hub_download, list_repo_files
from peft import LoraConfig, PeftConfig
from PIL import Image
from torch.utils.data.distributed import DistributedSampler  # used by check_dataloader
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig

from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)

accelerator = Accelerator()
print(f"Running on rank: {accelerator.local_process_index}, world size: {accelerator.num_processes}")
rank = accelerator.local_process_index


# For multi-image example
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        for element in content:
            if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                if image is not None:
                    image = Image.open(io.BytesIO(image["bytes"]))
                    image_inputs.append(image.convert("RGB"))
    return image_inputs


def format_data(samples: dict[str, any]) -> dict[str, list]:
    print(f"Rank {rank}: Processing samples: {samples['question']}")
    formatted_samples = {"messages": []}
    for cont in range(len(samples["question"])):
        images = []
        max_images = 10
        for img_path in samples["input_image_path"][cont][:max_images]:
            try:
                with open(img_path, "rb") as f:
                    img_bytes = f.read()
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                images.append({"type": "image", "image": image})
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")
                continue
        formatted_samples["messages"].append(
            [
                {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
                {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
                {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
            ]
        )
    print(formatted_samples)
    return formatted_samples


# For multi-image example
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
    all_files = list_repo_files(dataset_name, repo_type="dataset")
    zip_files = [f for f in all_files if f.endswith(".zip")]
    for zip_filename in zip_files:
        zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
        extract_folder = zip_filename.replace(".zip", "")
        if os.path.exists(extract_folder):
            continue
        else:
            os.makedirs(extract_folder, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_folder)
    dataset["test"] = dataset["test"].select(range(1))
    print(dataset)
    dataset = dataset.map(format_data, batched=True, batch_size=1)
    # dataset.save_to_disk("cached_dataset")
    return dataset


def prepare_custom_dataset(path):
    with open(path, "r") as f:
        data = json.load(f)
    # If you want to create a train/test split manually,
    # for example split 80/20:
    train_size = 0  # int(0.8 * len(data))
    train_data = data[:train_size]
    test_data = data[train_size:]
    # Create Hugging Face Dataset objects
    train_dataset = Dataset.from_list(train_data)
    test_dataset = Dataset.from_list(test_data)
    # Wrap into DatasetDict
    dataset = DatasetDict(
        {
            # "train": train_dataset,
            "test": test_dataset,
        }
    )
    print(dataset)
    dataset = dataset.map(format_data, batched=True, batch_size=1)
    return dataset


def main():
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
    # START: Added checkpoint saving configuration
    training_args.save_strategy = "epoch"
    training_args.save_total_limit = 3
    # END: Added checkpoint saving configuration

    ################
    # Model, Tokenizer & Processor
    ################
    print("Loading model, tokenizer and processor...")
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    # print(quantization_config)
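    # Added note: the BitsAndBytesConfig below overrides the quantization
    # config returned by get_quantization_config() above (which ends up unused).
    # It loads the base weights in 4-bit NF4 (QLoRA-style); double quantization
    # also compresses the quantization constants, and matmuls run in the
    # bfloat16 compute dtype.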
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,  # or False for 8-bit
        bnb_4bit_quant_type="nf4",  # or "fp4"
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype="bfloat16",  # or torch.float16
    )
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # device_map="auto",
        quantization_config=bnb_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )
    processor.tokenizer.padding_side = "right"
    model = AutoModelForImageTextToText.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
    )
    print("Model, tokenizer and processor loaded successfully!")

    def collate_fn(examples):
        texts = [
            processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip()
            for example in examples
        ]
        if "images" in examples[0]:  # single-image
            images = [[img.convert("RGB") for img in example["images"]] for example in examples]
        else:  # multi-image
            images = [process_vision_info(example["messages"]) for example in examples]

        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()  # Clone input IDs for labels
        # Mask image tokens
        image_token_id = [
            processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
        ]
        # Mask tokens for not being used in the loss computation
        labels[labels == processor.tokenizer.pad_token_id] = -100
        labels[labels == image_token_id] = -100
        labels[labels == 262144] = -100

        batch["labels"] = labels
        return batch  # Return the prepared batch

    ################
    # Dataset
    ################
    print("Loading dataset...")
    ## DEFAULT DATASET
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
        dataset = prepare_dataset(dataset, script_args.dataset_name, script_args.dataset_train_split)
    ## BARC SAMPLE DATASET
    # dataset = prepare_custom_dataset("/home/ubuntu/barc_client_data/Training/gemma_sample_dataset_colors/Chunk_colors/gemma_dataset.json")

    def check_dataloader(trainer):
        dataloader = trainer.get_train_dataloader()
        sampler = dataloader.sampler
        print(f"Rank {accelerator.local_process_index}: Using sampler {type(sampler).__name__}")
        if isinstance(sampler, DistributedSampler):
            indices = list(sampler)
            print(
                f"Rank {accelerator.local_process_index}: Dataset size: {len(sampler.dataset)}, "
                f"Samples per replica: {sampler.num_samples}, Replicas: {sampler.num_replicas}"
            )
            print(f"Rank {accelerator.local_process_index}: Assigned sample indices: {indices}")
            # Log sample details for this rank
            for idx in indices:
                sample = sampler.dataset[idx]
                sample_id = sample.get("id", idx)
                question = sample["messages"][1]["content"][-1]["text"] if "messages" in sample else "Unknown"
                print(f"Rank {accelerator.local_process_index}: Sample index {idx}, ID: {sample_id}, Question: {question}")

    ################
    # Training
    ################
    print("Starting training...")
    model_args.lora_r = 2  # Reduce LoRA rank (default 8) to save memory
    model_args.lora_target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
    peft_config = LoraConfig(
        r=model_args.lora_r,
        target_modules=model_args.lora_target_modules,
        lora_alpha=8,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
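    # Added note: with r=2 and lora_alpha=8 the LoRA update is scaled by
    # alpha/r = 4, and only the attention projections carry adapters, so the
    # trainable-parameter count (and optimizer state) should stay small.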
    # eval_accumulation_steps is a TrainingArguments/SFTConfig field,
    # not an SFTTrainer constructor argument
    training_args.eval_accumulation_steps = 1
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=processor.tokenizer,
        peft_config=peft_config,
    )
    # check_dataloader(trainer)
    print("Training started...")
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
        if trainer.accelerator.is_main_process:
            processor.push_to_hub(training_args.hub_model_id)


if __name__ == "__main__":
    main()
```
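One possibly relevant detail in the script above: `gradient_checkpointing_kwargs` is set, but gradient checkpointing itself is never enabled, so those kwargs are never consulted and all activations are kept in memory. If checkpointing was intended, the standard `transformers` flag would also need to be set, either via `--gradient_checkpointing` on the command line or in code (sketch, assuming the stock `TrainingArguments` field):

```python
# Sketch: enable activation checkpointing; gradient_checkpointing_kwargs
# is only read when this flag is on.
training_args.gradient_checkpointing = True
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
```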
Is this an internal SFTTrainer issue or an accelerate issue?
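To help narrow it down, peak allocated memory per rank can be logged around `trainer.train()` with standard `torch.cuda` calls (diagnostic sketch, not part of the repro above):

```python
import torch

# Reset the peak-memory counter so the reading reflects training only,
# then report it even if training dies with an OOM.
torch.cuda.reset_peak_memory_stats()
try:
    trainer.train()
finally:
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"Rank {rank}: peak CUDA memory allocated: {peak_gib:.2f} GiB")
```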
System Info
[2025-05-22 14:44:09,766] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)