
Commit a528b9c

kashif and lewtun authored
[NashMD] fix the edge case where the model is a peft model (#3473)
Co-authored-by: lewtun <[email protected]>
1 parent e0dd525 commit a528b9c
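
In short: when an already-PEFT-wrapped model is passed to NashMDTrainer or XPOTrainer with ref_model=None, both trainers now fall back to the PEFT model's frozen base weights as the implicit reference model instead of misusing the adapted model (or failing). A minimal sketch of the scenario this commit covers — the model id, hyperparameters, and the reward model below are illustrative placeholders, not part of the commit; the real tests added here use trl-internal-testing checkpoints:

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import NashMDConfig, NashMDTrainer

# Illustrative small model; swap in any causal LM checkpoint.
model_id = "Qwen/Qwen2-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Wrap the policy in LoRA adapters *before* constructing the trainer.
peft_model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))

trainer = NashMDTrainer(
    model=peft_model,            # already a PeftModel
    ref_model=None,              # implicit reference: the PEFT model's base weights
    reward_model=reward_model,   # assumed: a reward model loaded elsewhere
    args=NashMDConfig(output_dir="nash_md_out", max_steps=2),
    processing_class=tokenizer,
    train_dataset=load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"],
)
trainer.train()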

File tree

4 files changed: +126 −23 lines changed


tests/test_nash_md_trainer.py

Lines changed: 32 additions & 0 deletions
@@ -160,6 +160,38 @@ def test_training_with_peft_model_and_peft_config(self):
         # Check if training loss is available
         self.assertIn("train_loss", trainer.state.log_history[-1])

+    @require_peft
+    def test_training_pre_pefted_model_implicit_ref_with_reward_model(self):
+        lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
+        # self.model from setUp is a base AutoModelForCausalLM
+        peft_model_instance = get_peft_model(self.model, lora_config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = NashMDConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=1,  # Keep small for quick test
+                max_steps=2,  # Few steps
+                learning_rate=5.0e-7,
+                eval_strategy="no",
+                report_to="none",
+                remove_unused_columns=False,  # Important for the dummy dataset
+            )
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"]
+
+            trainer = NashMDTrainer(
+                model=peft_model_instance,  # Pass the already PEFT model
+                ref_model=None,  # Implicit reference from peft_model_instance's base
+                reward_model=self.reward_model,  # To trigger GeometricMixtureWrapper path
+                args=training_args,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset,
+                # peft_config is not passed, as model is already PEFT
+            )
+
+            trainer.train()
+
+            self.assertIn("train_loss", trainer.state.log_history[-1])
+
     @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
     @require_llm_blender
     def test_nash_md_trainer_judge_training(self, config_name):

tests/test_xpo_trainer.py

Lines changed: 30 additions & 0 deletions
@@ -160,6 +160,36 @@ def test_training_with_peft_model_and_peft_config(self):
         # Check if training loss is available
         self.assertIn("train_loss", trainer.state.log_history[-1])

+    @require_peft
+    def test_training_pre_pefted_model_implicit_ref(self):
+        lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
+        peft_model_instance = get_peft_model(self.model, lora_config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = XPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=1,
+                max_steps=2,
+                learning_rate=5.0e-7,
+                eval_strategy="no",
+                report_to="none",
+                remove_unused_columns=False,
+            )
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"]
+
+            trainer = XPOTrainer(
+                model=peft_model_instance,
+                ref_model=None,
+                reward_model=self.reward_model,  # Using reward_model to ensure _generate_completions is used as expected
+                args=training_args,
+                processing_class=self.tokenizer,
+                train_dataset=dummy_dataset,
+            )
+
+            trainer.train()
+
+            self.assertIn("train_loss", trainer.state.log_history[-1])
+
     @require_llm_blender
     @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
     def test_xpo_trainer_judge_training(self, config_name):

trl/trainer/nash_md_trainer.py

Lines changed: 44 additions & 18 deletions
@@ -32,7 +32,7 @@
 )
 from transformers.trainer_utils import EvalPrediction
 from transformers.training_args import OptimizerNames
-from transformers.utils import is_apex_available
+from transformers.utils import is_apex_available, is_peft_available

 from ..data_utils import is_conversational, maybe_apply_chat_template
 from ..models.modeling_base import GeometricMixtureWrapper
@@ -59,6 +59,10 @@
     import wandb


+if is_peft_available():
+    from peft import PeftModel
+
+
 class NashMDTrainer(OnlineDPOTrainer):
     r"""
     Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`].
@@ -170,28 +174,50 @@ def mixture_coef(self):
         return self._mixture_coef

     def _generate_completions(self, model, prompts):
-        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
-            model_output = unwrapped_model.generate(
+        # Generate completions from the policy model.
+        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx:
+            model_output = unwrapped_policy_for_gen_ctx.generate(
                 input_ids=prompts["input_ids"],
                 attention_mask=prompts["attention_mask"],
                 generation_config=self.generation_config,
             )

-        ref_model = model if self.ref_model is None else self.ref_model
-        with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
-            mixture_model = GeometricMixtureWrapper(
-                model=unwrapped_model,
-                ref_model=unwrapped_ref_model,
-                generation_config=self.generation_config,
-                mixture_coef=self.mixture_coef,
-                device=self.accelerator.device,
-            )
-
-            mixture_output = mixture_model.generate(
-                input_ids=prompts["input_ids"],
-                attention_mask=prompts["attention_mask"],
-                generation_config=self.generation_config,
-            )
+        # Get the DDP/FSDP unwrapped version of the main model.
+        # This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used).
+        policy_model_for_gmw = self.accelerator.unwrap_model(model)
+
+        # Determine the correct reference model for GeometricMixtureWrapper.
+        # This also needs to be DDP/FSDP unwrapped.
+        ref_model_for_gmw: torch.nn.Module
+        if self.ref_model is None:
+            # No explicit ref_model is provided.
+            # Use the base of the main `model` if it's a PEFT model.
+            # policy_model_for_gmw is already DDP-unwrapped.
+            if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel):
+                ref_model_for_gmw = policy_model_for_gmw.get_base_model()
+            else:
+                # Not a PEFT model (or PEFT not available), or already a base model.
+                # Use the DDP-unwrapped policy model itself as the reference.
+                ref_model_for_gmw = policy_model_for_gmw
+        else:
+            # An explicit ref_model is provided. Unwrap it for DDP/FSDP.
+            ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model)
+
+        # Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped.
+        with torch.no_grad():  # Ensure no_grad context for mixture model generation
+            mixture_model = GeometricMixtureWrapper(
+                model=policy_model_for_gmw,
+                ref_model=ref_model_for_gmw,
+                generation_config=self.generation_config,
+                mixture_coef=self.mixture_coef,
+                device=self.accelerator.device,
+            )
+
+            mixture_output = mixture_model.generate(
+                input_ids=prompts["input_ids"],
+                attention_mask=prompts["attention_mask"],
+                generation_config=self.generation_config,
+            )

         return model_output, mixture_output
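
A note on the change above: GeometricMixtureWrapper runs the policy and the reference on the same inputs and blends their next-token distributions, which is why both are handed in as plain, DDP/FSDP-unwrapped modules on the same device. A rough sketch of the blending idea only (illustrative, not the exact trl implementation; the helper name is made up):

import torch.nn.functional as F

def geometric_mixture_logits(policy_logits, ref_logits, mixture_coef):
    # Geometric mixture of two token distributions: interpolate the logits
    # (equivalently, the log-probabilities up to normalization) and renormalize.
    mixed = (1 - mixture_coef) * policy_logits + mixture_coef * ref_logits
    return F.log_softmax(mixed, dim=-1)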

trl/trainer/xpo_trainer.py

Lines changed: 20 additions & 5 deletions
@@ -33,6 +33,7 @@
 )
 from transformers.trainer_utils import EvalPrediction
 from transformers.training_args import OptimizerNames
+from transformers.utils import is_peft_available

 from ..data_utils import is_conversational, maybe_apply_chat_template
 from ..models.utils import unwrap_model_for_generation
@@ -58,6 +59,10 @@
     import wandb


+if is_peft_available():
+    from peft import PeftModel
+
+
 class XPOTrainer(OnlineDPOTrainer):
     r"""
     Initialize XPOTrainer as a subclass of [`OnlineDPOConfig`].
@@ -174,16 +179,26 @@ def alpha(self):
         return self._alpha

     def _generate_completions(self, prompts, model):
-        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
-            model_output = unwrapped_model.generate(
+        with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_model_for_gen:
+            model_output = unwrapped_policy_model_for_gen.generate(
                 input_ids=prompts["input_ids"],
                 attention_mask=prompts["attention_mask"],
                 generation_config=self.generation_config,
             )

-        ref_model = model if self.ref_model is None else self.ref_model
-        with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
-            ref_output = unwrapped_ref_model.generate(
+        actual_model_for_ref_generation: torch.nn.Module
+        if self.ref_model is None:
+            unwrapped_main_model_for_ref_logic = self.accelerator.unwrap_model(model)
+
+            if is_peft_available() and isinstance(unwrapped_main_model_for_ref_logic, PeftModel):
+                actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic.get_base_model()
+            else:
+                actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic
+        else:
+            actual_model_for_ref_generation = self.accelerator.unwrap_model(self.ref_model)
+
+        with unwrap_model_for_generation(actual_model_for_ref_generation, self.accelerator) as final_ref_model_for_gen:
+            ref_output = final_ref_model_for_gen.generate(
                 input_ids=prompts["input_ids"],
                 attention_mask=prompts["attention_mask"],
                 generation_config=self.generation_config,
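
Both trainers now resolve the reference model with the same pattern. Distilled into a standalone sketch for readability (the helper name and signature are illustrative; in the commit the logic is inlined in each _generate_completions):

import torch
from transformers.utils import is_peft_available

if is_peft_available():
    from peft import PeftModel


def resolve_ref_model(unwrapped_policy: torch.nn.Module, explicit_ref=None, accelerator=None) -> torch.nn.Module:
    """Pick the module to use as the reference model for generation."""
    if explicit_ref is not None:
        # An explicit ref_model was passed; unwrap it from DDP/FSDP if an accelerator is available.
        return accelerator.unwrap_model(explicit_ref) if accelerator is not None else explicit_ref
    if is_peft_available() and isinstance(unwrapped_policy, PeftModel):
        # Pre-PEFTed policy: its frozen base weights act as the implicit reference.
        return unwrapped_policy.get_base_model()
    # Plain (non-PEFT) policy: fall back to the policy itself as the reference.
    return unwrapped_policy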
