@@ -72,6 +72,9 @@ def create_finetune_request(
     train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
+    dpo_normalize_logratios_by_length: bool = False,
+    rpo_alpha: float | None = None,
+    simpo_gamma: float | None = None,
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
     if model is not None and from_checkpoint is not None:
@@ -182,6 +185,21 @@ def create_finetune_request(
 
     if dpo_beta is not None and training_method != "dpo":
         raise ValueError("dpo_beta is only supported for DPO training")
+    if dpo_normalize_logratios_by_length and training_method != "dpo":
+        raise ValueError(
+            "dpo_normalize_logratios_by_length=True is only supported for DPO training"
+        )
+    if rpo_alpha is not None:
+        if training_method != "dpo":
+            raise ValueError("rpo_alpha is only supported for DPO training")
+        if not rpo_alpha >= 0.0:
+            raise ValueError(f"rpo_alpha should be non-negative (got {rpo_alpha})")
+
+    if simpo_gamma is not None:
+        if training_method != "dpo":
+            raise ValueError("simpo_gamma is only supported for DPO training")
+        if not simpo_gamma >= 0.0:
+            raise ValueError(f"simpo_gamma should be non-negative (got {simpo_gamma})")
 
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
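For illustration only, and not part of this diff: a standalone sketch of the new guard clauses, showing which parameter/method combinations are rejected. The helper name validate_dpo_args is hypothetical; the checks mirror the ones added above.

def validate_dpo_args(
    training_method: str = "sft",
    dpo_normalize_logratios_by_length: bool = False,
    rpo_alpha: float | None = None,
    simpo_gamma: float | None = None,
) -> None:
    # DPO-only flags are rejected for any other training method.
    if dpo_normalize_logratios_by_length and training_method != "dpo":
        raise ValueError(
            "dpo_normalize_logratios_by_length=True is only supported for DPO training"
        )
    # rpo_alpha and simpo_gamma must be used with DPO and must be non-negative.
    for name, value in (("rpo_alpha", rpo_alpha), ("simpo_gamma", simpo_gamma)):
        if value is not None:
            if training_method != "dpo":
                raise ValueError(f"{name} is only supported for DPO training")
            if not value >= 0.0:
                raise ValueError(f"{name} should be non-negative (got {value})")


validate_dpo_args(training_method="dpo", rpo_alpha=0.5, simpo_gamma=0.3)  # passes
try:
    validate_dpo_args(training_method="sft", rpo_alpha=0.5)
except ValueError as err:
    print(err)  # rpo_alpha is only supported for DPO training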
@@ -204,7 +222,24 @@ def create_finetune_request(
     if training_method == "sft":
         training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
     elif training_method == "dpo":
-        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+        if simpo_gamma is not None and simpo_gamma > 0:
+            dpo_reference_free = True
+            dpo_normalize_logratios_by_length = True
+            rprint(
+                f"Parameter simpo_gamma was set to {simpo_gamma}. "
+                "SimPO training detected. Reference logits will not be used "
+                "and length normalization of log-probabilities will be enabled."
+            )
+        else:
+            dpo_reference_free = False
+
+        training_method_cls = TrainingMethodDPO(
+            dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            dpo_reference_free=dpo_reference_free,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
+        )
 
     finetune_request = FinetuneRequest(
         model=model,
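For illustration only, and not part of this diff: a self-contained sketch of the SimPO switch above. When simpo_gamma is set and positive, the request is built reference-free and log-ratios are length-normalized regardless of what the caller passed for dpo_normalize_logratios_by_length. DPOSettings and resolve_dpo_settings are hypothetical stand-ins for TrainingMethodDPO and the surrounding branch.

from dataclasses import dataclass


@dataclass
class DPOSettings:
    dpo_beta: float | None = None
    dpo_normalize_logratios_by_length: bool = False
    dpo_reference_free: bool = False
    rpo_alpha: float | None = None
    simpo_gamma: float | None = None


def resolve_dpo_settings(
    dpo_beta: float | None = None,
    dpo_normalize_logratios_by_length: bool = False,
    rpo_alpha: float | None = None,
    simpo_gamma: float | None = None,
) -> DPOSettings:
    # A positive simpo_gamma switches DPO into SimPO mode: no reference model,
    # length-normalized log-ratios.
    reference_free = simpo_gamma is not None and simpo_gamma > 0
    return DPOSettings(
        dpo_beta=dpo_beta,
        dpo_normalize_logratios_by_length=(
            True if reference_free else dpo_normalize_logratios_by_length
        ),
        dpo_reference_free=reference_free,
        rpo_alpha=rpo_alpha,
        simpo_gamma=simpo_gamma,
    )


assert resolve_dpo_settings(simpo_gamma=0.5).dpo_reference_free is True
assert resolve_dpo_settings(simpo_gamma=0.5).dpo_normalize_logratios_by_length is True
assert resolve_dpo_settings(dpo_beta=0.1).dpo_reference_free is False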
@@ -302,6 +337,9 @@ def create(
         train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        dpo_normalize_logratios_by_length: bool = False,
+        rpo_alpha: float | None = None,
+        simpo_gamma: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -353,6 +391,9 @@ def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            dpo_normalize_logratios_by_length (bool): Whether or not to normalize logratios by sample length. Defaults to False.
+            rpo_alpha (float, optional): RPO alpha parameter for DPO training; adds a weighted NLL term to the loss. Defaults to None.
+            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
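A hedged usage sketch of the new parameters through the public client; it assumes a valid TOGETHER_API_KEY in the environment, and the file ID, model name, and numeric values are placeholders rather than recommendations from this PR.

from together import Together

client = Together()  # reads TOGETHER_API_KEY from the environment

job = client.fine_tuning.create(
    training_file="file-xxxxxxxx",  # placeholder ID of an uploaded preference dataset
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder model name
    training_method="dpo",
    dpo_beta=0.1,
    rpo_alpha=0.5,    # adds a weighted NLL term to the DPO loss
    simpo_gamma=0.3,  # a positive value enables reference-free, length-normalized SimPO
)
print(job.id)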
@@ -405,6 +446,9 @@ def create(
             train_on_inputs=train_on_inputs,
             training_method=training_method,
             dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
             from_checkpoint=from_checkpoint,
         )
 
@@ -714,6 +758,9 @@ async def create(
         train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        dpo_normalize_logratios_by_length: bool = False,
+        rpo_alpha: float | None = None,
+        simpo_gamma: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -765,6 +812,9 @@ async def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            dpo_normalize_logratios_by_length (bool): Whether or not to normalize logratios by sample length. Defaults to False.
+            rpo_alpha (float, optional): RPO alpha parameter for DPO training; adds a weighted NLL term to the loss. Defaults to None.
+            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
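The async client mirrors the same parameters. A minimal sketch, assuming AsyncTogether is available and configured via TOGETHER_API_KEY; identifiers below are placeholders.

import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()
    job = await client.fine_tuning.create(
        training_file="file-xxxxxxxx",  # placeholder
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # placeholder
        training_method="dpo",
        dpo_beta=0.1,
        simpo_gamma=0.3,  # SimPO: reference-free, length-normalized log-ratios
    )
    print(job.id)


asyncio.run(main())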
@@ -817,6 +867,9 @@ async def create(
             train_on_inputs=train_on_inputs,
             training_method=training_method,
             dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
             from_checkpoint=from_checkpoint,
         )
 