New options for preference tuning: rpo alpha, logprobs normalization, reference-free, simpo gamma #327
Conversation
@@ -204,7 +215,24 @@ def create_finetune_request(
    if training_method == "sft":
        training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
    elif training_method == "dpo":
        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
        if simpo_gamma is not None and simpo_gamma > 0:
By the way, should we raise a ValueError if it's <=0?
Added, and also added the same check for rpo_alpha (I can't imagine a use case for negative values for these parameters)
    if rpo_alpha is not None:
        if training_method != "dpo":
            raise ValueError("rpo_alpha is only supported for DPO training")
        if not rpo_alpha >= 0.0:
Maybe it's wise to put an upper limit too
Not sure what the limit should be here; let's say 10? Wdyt?
I'm not sure we should be enforcing any particular limit on this value, although it might be helpful. The problem is that this limit would apply only when users submit jobs via together-python.
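To make the thread concrete, here is a minimal sketch of the bounds checks being discussed; the helper name, error messages, and the decision not to enforce an upper limit on rpo_alpha are illustrative assumptions, not the merged implementation.

```python
from typing import Optional

# Illustrative sketch of the validation discussed above; the function name and
# error messages are assumptions, not the code that was merged.
def validate_preference_params(
    training_method: str,
    rpo_alpha: Optional[float] = None,
    simpo_gamma: Optional[float] = None,
) -> None:
    if rpo_alpha is not None:
        if training_method != "dpo":
            raise ValueError("rpo_alpha is only supported for DPO training")
        if rpo_alpha < 0.0:
            # No upper bound enforced here; see the discussion above.
            raise ValueError("rpo_alpha must be non-negative")
    if simpo_gamma is not None:
        if training_method != "dpo":
            raise ValueError("simpo_gamma is only supported for DPO training")
        if simpo_gamma <= 0.0:
            raise ValueError("simpo_gamma must be positive")
```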
        raise ValueError(
            "dpo_normalize_logratios_by_length=True is only supported for DPO training"
        )
    if rpo_alpha is not None:
this could simply be `if rpo_alpha`
Not quite, PEP 8 explicitly advises against it: truthiness checks conflate None with falsy values like 0.0, so comparisons to None should use `is not None`.
Also, a bit below I want to notify the user that rpo_alpha == 0.0 throws an error, which a plain truthiness check would silently skip.
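A tiny runnable example of the behaviour this thread is about: with a truthiness check, rpo_alpha == 0.0 is silently skipped, while the explicit None check still lets the validation below reject it.

```python
rpo_alpha = 0.0

if rpo_alpha:
    print("truthiness check: branch taken")     # not printed, 0.0 is falsy
if rpo_alpha is not None:
    print("explicit None check: branch taken")  # printed, so 0.0 can still be validated
```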
src/together/resources/finetune.py (outdated)
@@ -765,6 +812,9 @@ async def create(
        training_method (str, optional): Training method. Defaults to "sft".
            Supported methods: "sft", "dpo".
        dpo_beta (float, optional): DPO beta parameter. Defaults to None.
        dpo_normalize_logratios_by_length (bool): Whether or not normalize logratios by sample lenght. Defaults to False,
length* (sorry for being nit-picky)
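For readers following the conversation, here is a hedged sketch of how the new options added in this PR might be passed from client code. The parameter names come from the diff above, but the model name, file ID, and values are placeholders, and the exact released signature of client.fine_tuning.create may differ.

```python
from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

# Placeholder model/file values; parameter names follow this PR's diff and the
# released signature may differ.
job = client.fine_tuning.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    training_file="file-abc123",
    training_method="dpo",
    dpo_beta=0.1,
    rpo_alpha=0.5,                            # must be >= 0 per the checks above
    dpo_normalize_logratios_by_length=True,   # normalize logratios by sample length
    simpo_gamma=0.3,                          # > 0 enables the SimPO-style objective
)
print(job.id)
```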
Have you read the Contributing Guidelines?
Issue #
Describe your changes
Clearly and concisely describe what's in this pull request. Include screenshots, if necessary.