🤧 LD-DPO support #3458

Merged 12 commits on May 27, 2025
4 changes: 4 additions & 0 deletions docs/source/dpo_trainer.md
@@ -168,6 +168,10 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ

The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
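
For illustration only (this snippet is not part of the diff), enabling WPO reweighting comes down to a single flag on the config; the `output_dir` value below is a placeholder:

```python
from trl import DPOConfig

# Only `use_weighting=True` is WPO-specific; everything else keeps its defaults.
training_args = DPOConfig(output_dir="dpo-wpo", use_weighting=True)
```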

### LD-DPO loss

The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient \\( \alpha \\). To use this method, set the `ld_alpha` parameter in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`.
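
As a minimal sketch of how `ld_alpha` might be wired up end to end (illustrative only, not part of the diff; the tiny model and dataset are the test assets used elsewhere in this PR, and `output_dir` is a placeholder):

```python
from datasets import load_dataset
from transformers import AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# ld_alpha < 1.0 down-weights the log-probabilities of the "verbose" tokens: the part of each
# response that extends beyond the length of the shorter response in the preference pair.
training_args = DPOConfig(output_dir="dpo-ld", ld_alpha=0.5)

trainer = DPOTrainer(
    model=model_id,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
```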

### For Mixture of Experts Models: Enabling the auxiliary loss

MOEs are the most efficient if the load is about equally distributed between experts.
31 changes: 31 additions & 0 deletions tests/test_dpo_trainer.py
@@ -1258,6 +1258,37 @@ def dummy_compute_metrics(*args, **kwargs):

        self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)

    def test_train_with_length_desensitization(self):
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                learning_rate=9e-1,
                ld_alpha=0.5,
                report_to="none",
            )
            trainer = DPOTrainer(
                model=model_id,
                args=training_args,
                processing_class=tokenizer,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the parameters have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if param.sum() != 0:  # ignore 0 biases
                    self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))


@require_vision
class DPOVisionTrainerTester(unittest.TestCase):
19 changes: 16 additions & 3 deletions trl/trainer/dpo_config.py
@@ -131,14 +131,19 @@ class DPOConfig(TrainingArguments):
            Whether to ignore the provided reference model and implicitly use a reference model that assigns equal
            probability to all responses.
        label_smoothing (`float`, *optional*, defaults to `0.0`):
            Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and
            Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and
            [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`.
        use_weighting (`bool`, *optional*, defaults to `False`):
            Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
            Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827).
        rpo_alpha (`float`, *optional*, defaults to `None`):
            α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the
            α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the
            weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
            DPO loss. The paper recommends `rpo_alpha=1.0`.
        ld_alpha (`float` or `None`, *optional*, defaults to `None`):
            α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting
            of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose
            part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between
            `0.0` and `1.0`.
        discopop_tau (`float`, *optional*, defaults to `0.05`):
            τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls
            the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`.
@@ -346,6 +351,14 @@ class DPOConfig(TrainingArguments):
            "`rpo_alpha=1.0`."
        },
    )
    ld_alpha: Optional[float] = field(
        default=None,
        metadata={
            "help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token "
            "log-probabilities in responses. If `None`, no weighting is applied to the verbose part, and the loss is "
            "equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between `0.0` and `1.0`.",
        },
    )
    discopop_tau: float = field(
        default=0.05,
        metadata={
42 changes: 38 additions & 4 deletions trl/trainer/dpo_trainer.py
@@ -804,9 +804,9 @@ def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict:
        with torch.no_grad(), compte_ref_context_manager:
            if self.ref_model is None:
                with self.null_ref_context():
                    ref_model_output = self.concatenated_forward(self.model, batch)
                    ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True)
            else:
                ref_model_output = self.concatenated_forward(self.ref_model, batch)
                ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True)
        return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"]

    @staticmethod
@@ -1066,10 +1066,22 @@ def dpo_loss(

        return losses, chosen_rewards, rejected_rewards

    def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
    def concatenated_forward(
        self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False
    ):
        """
        Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.

        Args:
            model:
                Model to run the forward pass on.
            batch:
                Batch of input data.
            is_ref_model:
                Whether this method is being called for the reference model. If `True`, length desensitization is not
                applied.
        """
        num_examples = batch["prompt_input_ids"].shape[0]

@@ -1218,6 +1230,28 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
        if self.loss_type == "ipo":
            all_logps = all_logps / loss_mask.sum(-1)

        if self.args.ld_alpha is not None and not is_ref_model:
            # Compute response lengths based on loss_mask
            completion_lengths = loss_mask.sum(dim=1)

            chosen_lengths = completion_lengths[:num_examples]
            rejected_lengths = completion_lengths[num_examples:]
            public_lengths = torch.min(chosen_lengths, rejected_lengths)  # l_p in the paper
            public_lengths = torch.cat([public_lengths, public_lengths], dim=0)

            seq_len = per_token_logps.size(1)
            position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)

            ld_mask = position_ids < public_lengths.unsqueeze(1)  # tokens within the shared (public) length
            mask = position_ids < completion_lengths.unsqueeze(1)  # non-padded completion tokens

            front_mask = (ld_mask & mask).float()
            rear_mask = (~ld_mask & mask).float()
            front_logps = (per_token_logps * front_mask).sum(dim=1)
            rear_logps = (per_token_logps * rear_mask).sum(dim=1)

            all_logps = front_logps + self.args.ld_alpha * rear_logps

        output["chosen_logps"] = all_logps[:num_examples]
        output["rejected_logps"] = all_logps[num_examples:]

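To make the masking in `concatenated_forward` above concrete, here is a self-contained toy walkthrough (made-up numbers, independent of the trainer) of how the length-desensitized log-probabilities are assembled:

```python
import torch

# Toy per-token log-probs for one chosen (3 tokens) and one rejected (5 tokens) completion,
# padded to the same sequence length; padded positions are masked out below.
per_token_logps = torch.tensor([
    [-0.5, -0.4, -0.3, 0.0, 0.0],
    [-0.6, -0.7, -0.8, -0.9, -1.0],
])
loss_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1],
])
ld_alpha = 0.5
num_examples = 1  # one preference pair

completion_lengths = loss_mask.sum(dim=1)  # tensor([3, 5])
public_lengths = torch.min(completion_lengths[:num_examples], completion_lengths[num_examples:])  # l_p = 3
public_lengths = torch.cat([public_lengths, public_lengths], dim=0)

position_ids = torch.arange(per_token_logps.size(1)).expand_as(per_token_logps)
ld_mask = position_ids < public_lengths.unsqueeze(1)   # shared (public) prefix
mask = position_ids < completion_lengths.unsqueeze(1)  # real, non-padded tokens

front_logps = (per_token_logps * (ld_mask & mask).float()).sum(dim=1)   # tensor([-1.2, -2.1])
rear_logps = (per_token_logps * (~ld_mask & mask).float()).sum(dim=1)   # tensor([0.0, -1.9])
all_logps = front_logps + ld_alpha * rear_logps                         # tensor([-1.2, -3.05])
```

With `ld_alpha=0.5`, only the rejected response has a verbose tail here, and its contribution is halved; with `ld_alpha=1.0` the result reduces to the standard DPO per-sequence sums.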