Skip to content

Commit 8ef8769

Browse files
committed
tests
1 parent 2cbd4c0 commit 8ef8769

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/unit/test_finetune_resources.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,32 @@ def test_bad_training_method():
281281
training_file=_TRAINING_FILE,
282282
training_method="NON_SFT",
283283
)
284+
285+
286+
@pytest.mark.parametrize("train_on_inputs", [True, False, "auto", None])
287+
def test_train_on_inputs_for_sft(train_on_inputs):
288+
request = create_finetune_request(
289+
model_limits=_MODEL_LIMITS,
290+
model=_MODEL_NAME,
291+
training_file=_TRAINING_FILE,
292+
training_method="sft",
293+
train_on_inputs=train_on_inputs,
294+
)
295+
assert request.training_method.method == "sft"
296+
if isinstance(train_on_inputs, bool):
297+
assert request.training_method.train_on_inputs is train_on_inputs
298+
else:
299+
assert request.training_method.train_on_inputs == "auto"
300+
301+
302+
def test_train_on_inputs_not_supported_for_dpo():
303+
with pytest.raises(
304+
ValueError, match="train_on_inputs is only supported for SFT training"
305+
):
306+
_ = create_finetune_request(
307+
model_limits=_MODEL_LIMITS,
308+
model=_MODEL_NAME,
309+
training_file=_TRAINING_FILE,
310+
training_method="dpo",
311+
train_on_inputs=True,
312+
)

0 commit comments

Comments
 (0)