
Commit 5151fd5

New options for preference tuning: rpo alpha, logprobs normalization, reference-free, simpo gamma (#327)
* Add dpo improvements arguments
* Version bump (tmp, dev)
* Implicit setting of `reference_free` in case if simpo_gamma is set
* Fix unbound variable
* Fix
* Force normalization for simpo
* Version bump
* Formatting
* Version fix
* Remove reference-free from dpo
* Review fixes
* Formatting
* Fixes
1 parent ecd68a4 commit 5151fd5

File tree (4 files changed: +89 −2 lines)

* pyproject.toml
* src/together/cli/api/finetune.py
* src/together/resources/finetune.py
* src/together/types/finetune.py

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 
 [tool.poetry]
 name = "together"
-version = "1.5.13"
+version = "1.5.14"
 authors = ["Together AI <[email protected]>"]
 description = "Python client for Together's Cloud Platform!"
 readme = "README.md"

src/together/cli/api/finetune.py

Lines changed: 30 additions & 0 deletions
@@ -142,6 +142,30 @@ def fine_tuning(ctx: click.Context) -> None:
     default=0.1,
     help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')",
 )
+@click.option(
+    "--dpo-normalize-logratios-by-length",
+    type=bool,
+    default=False,
+    help=(
+        "Whether to normalize logratios by sample length "
+        "(only used when '--training-method' is 'dpo')"
+    ),
+)
+@click.option(
+    "--rpo-alpha",
+    type=float,
+    default=0.0,
+    help=(
+        "RPO alpha parameter of DPO training to include NLL in the loss "
+        "(only used when '--training-method' is 'dpo')"
+    ),
+)
+@click.option(
+    "--simpo-gamma",
+    type=float,
+    default=0.1,
+    help="SimPO gamma parameter (only used when '--training-method' is 'dpo')",
+)
 @click.option(
     "--suffix",
     "-s",
@@ -206,6 +230,9 @@ def create(
     train_on_inputs: bool | Literal["auto"],
     training_method: str,
     dpo_beta: float,
+    dpo_normalize_logratios_by_length: bool,
+    rpo_alpha: float,
+    simpo_gamma: float,
     from_checkpoint: str,
 ) -> None:
     """Start fine-tuning"""
@@ -239,6 +266,9 @@ def create(
         train_on_inputs=train_on_inputs,
         training_method=training_method,
         dpo_beta=dpo_beta,
+        dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+        rpo_alpha=rpo_alpha,
+        simpo_gamma=simpo_gamma,
         from_checkpoint=from_checkpoint,
     )
 
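With these options in place, a SimPO-style preference-tuning run could be launched from the CLI roughly as follows (a sketch: the file ID and model name are placeholders, while the flags are the ones defined above):

    together fine-tuning create \
      --training-file file-abc123 \
      --model meta-llama/Meta-Llama-3.1-8B-Instruct-Reference \
      --training-method dpo \
      --dpo-beta 0.1 \
      --rpo-alpha 0.1 \
      --simpo-gamma 0.5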

src/together/resources/finetune.py

Lines changed: 54 additions & 1 deletion
@@ -72,6 +72,9 @@ def create_finetune_request(
     train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
+    dpo_normalize_logratios_by_length: bool = False,
+    rpo_alpha: float | None = None,
+    simpo_gamma: float | None = None,
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
     if model is not None and from_checkpoint is not None:
@@ -182,6 +185,21 @@ def create_finetune_request(
 
     if dpo_beta is not None and training_method != "dpo":
         raise ValueError("dpo_beta is only supported for DPO training")
+    if dpo_normalize_logratios_by_length and training_method != "dpo":
+        raise ValueError(
+            "dpo_normalize_logratios_by_length=True is only supported for DPO training"
+        )
+    if rpo_alpha is not None:
+        if training_method != "dpo":
+            raise ValueError("rpo_alpha is only supported for DPO training")
+        if not rpo_alpha >= 0.0:
+            raise ValueError(f"rpo_alpha should be non-negative (got {rpo_alpha})")
+
+    if simpo_gamma is not None:
+        if training_method != "dpo":
+            raise ValueError("simpo_gamma is only supported for DPO training")
+        if not simpo_gamma >= 0.0:
+            raise ValueError(f"simpo_gamma should be non-negative (got {simpo_gamma})")
 
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
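These guards make the new knobs DPO-only; a hedged sketch of how a bad combination surfaces through the client (placeholder file ID and model name, API key assumed to be configured):

    from together import Together

    client = Together()

    # create_finetune_request() above rejects DPO-only options for SFT jobs:
    # ValueError("rpo_alpha is only supported for DPO training")
    client.fine_tuning.create(
        training_file="file-abc123",  # placeholder file ID
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model
        training_method="sft",
        rpo_alpha=0.1,
    )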
@@ -204,7 +222,24 @@ def create_finetune_request(
     if training_method == "sft":
         training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
     elif training_method == "dpo":
-        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+        if simpo_gamma is not None and simpo_gamma > 0:
+            dpo_reference_free = True
+            dpo_normalize_logratios_by_length = True
+            rprint(
+                f"Parameter simpo_gamma was set to {simpo_gamma}. "
+                "SimPO training detected. Reference logits will not be used "
+                "and length normalization of log-probabilities will be enabled."
+            )
+        else:
+            dpo_reference_free = False
+
+        training_method_cls = TrainingMethodDPO(
+            dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            dpo_reference_free=dpo_reference_free,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
+        )
 
     finetune_request = FinetuneRequest(
         model=model,
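The net effect: as soon as simpo_gamma is positive, the job is built reference-free with length-normalized log-probabilities, whatever was passed for dpo_normalize_logratios_by_length. A minimal sketch of the equivalent training-method object, using the TrainingMethodDPO model this commit extends in together.types.finetune (illustrative values):

    from together.types.finetune import TrainingMethodDPO

    # What the simpo_gamma > 0 branch above ends up constructing:
    training_method_cls = TrainingMethodDPO(
        dpo_beta=0.1,
        dpo_normalize_logratios_by_length=True,  # forced on when simpo_gamma > 0
        dpo_reference_free=True,                 # forced on when simpo_gamma > 0
        rpo_alpha=None,
        simpo_gamma=0.5,
    )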
@@ -302,6 +337,9 @@ def create(
         train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        dpo_normalize_logratios_by_length: bool = False,
+        rpo_alpha: float | None = None,
+        simpo_gamma: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -353,6 +391,9 @@ def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            dpo_normalize_logratios_by_length (bool): Whether or not to normalize logratios by sample length. Defaults to False.
+            rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
+            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -405,6 +446,9 @@ def create(
             train_on_inputs=train_on_inputs,
             training_method=training_method,
             dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
             from_checkpoint=from_checkpoint,
         )
 
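Putting the new arguments together, a DPO job with the RPO and SimPO extensions could be submitted through the Python client roughly like this (a sketch: file ID and model name are placeholders, TOGETHER_API_KEY is assumed to be set):

    from together import Together

    client = Together()

    job = client.fine_tuning.create(
        training_file="file-abc123",  # placeholder: an uploaded preference dataset
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model
        training_method="dpo",
        dpo_beta=0.1,
        rpo_alpha=0.1,    # adds an NLL term to the DPO loss
        simpo_gamma=0.5,  # > 0 switches the job to reference-free, length-normalized SimPO
    )
    print(job.id)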

@@ -714,6 +758,9 @@ async def create(
         train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        dpo_normalize_logratios_by_length: bool = False,
+        rpo_alpha: float | None = None,
+        simpo_gamma: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -765,6 +812,9 @@ async def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            dpo_normalize_logratios_by_length (bool): Whether or not to normalize logratios by sample length. Defaults to False.
+            rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
+            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -817,6 +867,9 @@ async def create(
             train_on_inputs=train_on_inputs,
             training_method=training_method,
             dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
             from_checkpoint=from_checkpoint,
         )
 
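The async client mirrors the same signature; a short sketch under the same placeholder assumptions:

    import asyncio

    from together import AsyncTogether


    async def main() -> None:
        client = AsyncTogether()
        job = await client.fine_tuning.create(
            training_file="file-abc123",  # placeholder file ID
            model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model
            training_method="dpo",
            dpo_beta=0.1,
            rpo_alpha=0.1,
            simpo_gamma=0.5,
        )
        print(job.id)


    asyncio.run(main())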

src/together/types/finetune.py

Lines changed: 4 additions & 0 deletions
@@ -159,6 +159,10 @@ class TrainingMethodDPO(TrainingMethod):
 
     method: Literal["dpo"] = "dpo"
     dpo_beta: float | None = None
+    dpo_normalize_logratios_by_length: bool = False
+    dpo_reference_free: bool = False
+    rpo_alpha: float | None = None
+    simpo_gamma: float | None = None
 
 
 class FinetuneRequest(BaseModel):
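The defaults keep plain DPO unchanged: a request that only sets dpo_beta gets no length normalization, no reference-free mode, and no RPO or SimPO terms. A small sketch (import path taken from this file):

    from together.types.finetune import TrainingMethodDPO

    plain_dpo = TrainingMethodDPO(dpo_beta=0.1)
    assert plain_dpo.method == "dpo"
    assert plain_dpo.dpo_normalize_logratios_by_length is False
    assert plain_dpo.dpo_reference_free is False
    assert plain_dpo.rpo_alpha is None and plain_dpo.simpo_gamma is None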

0 commit comments
