Skip to content

Commit c09481b

Browse files
authored
Add logic to support max_batch_size_dpo. (#305)
* Add logic for max_batch_size_dpo, update version * Fix tests * Use default value to support old API
1 parent f647f3e commit c09481b

File tree

5 files changed

+73
-11
lines changed

5 files changed

+73
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.5.6"
15+
version = "1.5.7"
1616
authors = ["Together AI <[email protected]>"]
1717
description = "Python client for Together's Cloud Platform!"
1818
readme = "README.md"

src/together/cli/api/finetune.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,13 @@ def create(
258258
raise click.BadParameter(
259259
f"LoRA fine-tuning is not supported for the model `{model}`"
260260
)
261-
261+
if training_method == "dpo":
262+
default_batch_size = model_limits.lora_training.max_batch_size_dpo
263+
else:
264+
default_batch_size = model_limits.lora_training.max_batch_size
262265
default_values = {
263266
"lora_r": model_limits.lora_training.max_rank,
264-
"batch_size": model_limits.lora_training.max_batch_size,
267+
"batch_size": default_batch_size,
265268
"learning_rate": 1e-3,
266269
}
267270

@@ -288,7 +291,12 @@ def create(
288291

289292
batch_size_source = ctx.get_parameter_source("batch_size") # type: ignore[attr-defined]
290293
if batch_size_source == ParameterSource.DEFAULT:
291-
training_args["batch_size"] = model_limits.full_training.max_batch_size
294+
if training_method == "dpo":
295+
training_args["batch_size"] = (
296+
model_limits.full_training.max_batch_size_dpo
297+
)
298+
else:
299+
training_args["batch_size"] = model_limits.full_training.max_batch_size
292300

293301
if n_evals <= 0 and validation_file:
294302
log_warn(

src/together/resources/finetune.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def create_finetune_request(
102102

103103
training_type: TrainingType = FullTrainingType()
104104
max_batch_size: int = 0
105+
max_batch_size_dpo: int = 0
105106
min_batch_size: int = 0
106107
if lora:
107108
if model_limits.lora_training is None:
@@ -119,7 +120,7 @@ def create_finetune_request(
119120

120121
max_batch_size = model_limits.lora_training.max_batch_size
121122
min_batch_size = model_limits.lora_training.min_batch_size
122-
123+
max_batch_size_dpo = model_limits.lora_training.max_batch_size_dpo
123124
else:
124125
if model_limits.full_training is None:
125126
raise ValueError(
@@ -128,13 +129,24 @@ def create_finetune_request(
128129

129130
max_batch_size = model_limits.full_training.max_batch_size
130131
min_batch_size = model_limits.full_training.min_batch_size
132+
max_batch_size_dpo = model_limits.full_training.max_batch_size_dpo
131133

132-
batch_size = batch_size if batch_size != "max" else max_batch_size
134+
if batch_size == "max":
135+
if training_method == "dpo":
136+
batch_size = max_batch_size_dpo
137+
else:
138+
batch_size = max_batch_size
133139

134-
if batch_size > max_batch_size:
135-
raise ValueError(
136-
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
137-
)
140+
if training_method == "sft":
141+
if batch_size > max_batch_size:
142+
raise ValueError(
143+
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}."
144+
)
145+
elif training_method == "dpo":
146+
if batch_size > max_batch_size_dpo:
147+
raise ValueError(
148+
f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size_dpo}."
149+
)
138150

139151
if batch_size < min_batch_size:
140152
raise ValueError(

src/together/types/finetune.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from enum import Enum
4-
from typing import List, Literal
4+
from typing import List, Literal, Any
55

66
from pydantic import StrictBool, Field, field_validator
77

@@ -329,8 +329,16 @@ class FinetuneDownloadResult(BaseModel):
329329

330330
class FinetuneFullTrainingLimits(BaseModel):
331331
max_batch_size: int
332+
max_batch_size_dpo: int = -1
332333
min_batch_size: int
333334

335+
def __init__(self, **data: Any) -> None:
336+
super().__init__(**data)
337+
if self.max_batch_size_dpo == -1:
338+
half_max = self.max_batch_size // 2
339+
rounded_half_max = (half_max // 8) * 8
340+
self.max_batch_size_dpo = max(self.min_batch_size, rounded_half_max)
341+
334342

335343
class FinetuneLoraTrainingLimits(FinetuneFullTrainingLimits):
336344
max_rank: int

tests/unit/test_finetune_resources.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
min_learning_rate=1e-6,
1919
full_training=FinetuneFullTrainingLimits(
2020
max_batch_size=96,
21+
max_batch_size_dpo=48,
2122
min_batch_size=8,
2223
),
2324
lora_training=FinetuneLoraTrainingLimits(
2425
max_batch_size=128,
26+
max_batch_size_dpo=64,
2527
min_batch_size=8,
2628
max_rank=64,
2729
target_modules=["q", "k", "v", "o", "mlp"],
@@ -83,6 +85,36 @@ def test_lora_request():
8385
assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size
8486

8587

88+
def test_dpo_request_lora():
89+
request = create_finetune_request(
90+
model_limits=_MODEL_LIMITS,
91+
model=_MODEL_NAME,
92+
training_file=_TRAINING_FILE,
93+
training_method="dpo",
94+
lora=True,
95+
)
96+
97+
assert request.training_type.type == "Lora"
98+
assert request.training_type.lora_r == _MODEL_LIMITS.lora_training.max_rank
99+
assert request.training_type.lora_alpha == _MODEL_LIMITS.lora_training.max_rank * 2
100+
assert request.training_type.lora_dropout == 0.0
101+
assert request.training_type.lora_trainable_modules == "all-linear"
102+
assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size_dpo
103+
104+
105+
def test_dpo_request():
106+
request = create_finetune_request(
107+
model_limits=_MODEL_LIMITS,
108+
model=_MODEL_NAME,
109+
training_file=_TRAINING_FILE,
110+
training_method="dpo",
111+
lora=False,
112+
)
113+
114+
assert request.training_type.type == "Full"
115+
assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size_dpo
116+
117+
86118
def test_from_checkpoint_request():
87119
request = create_finetune_request(
88120
model_limits=_MODEL_LIMITS,
@@ -160,6 +192,7 @@ def test_non_lora_model():
160192
min_learning_rate=1e-6,
161193
full_training=FinetuneFullTrainingLimits(
162194
max_batch_size=96,
195+
max_batch_size_dpo=48,
163196
min_batch_size=8,
164197
),
165198
lora_training=None,
@@ -181,6 +214,7 @@ def test_non_full_model():
181214
min_learning_rate=1e-6,
182215
lora_training=FinetuneLoraTrainingLimits(
183216
max_batch_size=96,
217+
max_batch_size_dpo=48,
184218
min_batch_size=8,
185219
max_rank=64,
186220
target_modules=["q", "k", "v", "o", "mlp"],

0 commit comments

Comments
 (0)