
Commit 3514b5c

Add 70B models support for fine tuning job submission (#28)
* Revert "Remove 70B models (not yet supported) (#26)". This reverts commit 6b7ff20.
* Add a check on n_checkpoints to limit the number of checkpoints to 1 for 70B models.
* Change the default batch-size behavior for 70B models, and only allow a fixed batch size.
* Format changes only.
* Address PR review.
* Remove duplicate code.
* Black fixes.
1 parent 2af9f1e commit 3514b5c

File tree

3 files changed: +35 -7 lines changed


src/together/commands/finetune.py

Lines changed: 11 additions & 1 deletion
@@ -77,7 +77,7 @@ def _add_create(parser: argparse._SubParsersAction[argparse.ArgumentParser]) ->
         "--batch-size",
         "-b",
         metavar="BATCH_SIZE",
-        default=32,
+        default=None,
         help="The batch size to use for training. Default=32",
         type=int,
     )
@@ -281,6 +281,16 @@ def _add_checkpoints(
 def _run_create(args: argparse.Namespace) -> None:
     finetune = Finetune()
 
+    # Set default batch size based on model
+    if args.batch_size is None:
+        if args.model in [
+            "togethercomputer/llama-2-70b",
+            "togethercomputer/llama-2-70b-chat",
+        ]:
+            args.batch_size = 144
+        else:
+            args.batch_size = 32
+
     response = finetune.create(
         training_file=args.training_file,  # training file_id
         # validation_file=args.validation_file,  # validation file_id
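
For context, a minimal sketch of how the new default resolution behaves now that the flag defaults to None. The model names and the 144/32 values come from the diff above; the small namespace used here is purely illustrative and not part of the change:

import argparse

# Hypothetical stand-in for the parsed CLI args (illustrative only).
args = argparse.Namespace(model="togethercomputer/llama-2-70b", batch_size=None)

# Mirrors the default-selection logic added to _run_create above.
if args.batch_size is None:
    if args.model in [
        "togethercomputer/llama-2-70b",
        "togethercomputer/llama-2-70b-chat",
    ]:
        args.batch_size = 144  # fixed batch size required for the 70B models
    else:
        args.batch_size = 32  # previous default for all other models

print(args.batch_size)  # -> 144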

src/together/config.py

Lines changed: 4 additions & 4 deletions
@@ -14,8 +14,8 @@
     "togethercomputer/RedPajama-INCITE-Instruct-3B-v1",
     "togethercomputer/Pythia-Chat-Base-7B",
     "togethercomputer/Llama-2-7B-32K-Instruct",
-    # "togethercomputer/llama-2-70b",
-    # "togethercomputer/llama-2-70b-chat",
+    "togethercomputer/llama-2-70b",
+    "togethercomputer/llama-2-70b-chat",
 ]
 
 # List of models we support and their particular behavior, ie special tokens,
@@ -73,8 +73,8 @@
     "togethercomputer/falcon-7b": {},
     "togethercomputer/llama-2-13b-chat": {"bos_token": "<s>", "eos_token": "</s>"},
     "togethercomputer/llama-2-13b": {"bos_token": "<s>", "eos_token": "</s>"},
-    # "togethercomputer/llama-2-70b-chat": {"bos_token": "<s>", "eos_token": "</s>"},
-    # "togethercomputer/llama-2-70b": {"bos_token": "<s>", "eos_token": "</s>"},
+    "togethercomputer/llama-2-70b-chat": {"bos_token": "<s>", "eos_token": "</s>"},
+    "togethercomputer/llama-2-70b": {"bos_token": "<s>", "eos_token": "</s>"},
     "togethercomputer/llama-2-7b-chat": {"bos_token": "<s>", "eos_token": "</s>"},
     "togethercomputer/llama-2-7b": {"bos_token": "<s>", "eos_token": "</s>"},
     "togethercomputer/mpt-30b-chat": {},

src/together/finetune.py

Lines changed: 20 additions & 2 deletions
@@ -34,8 +34,8 @@ def model_param_count(name: str) -> int:
         "togethercomputer/CodeLlama-13b": 13016028160,
         "togethercomputer/CodeLlama-13b-Python": 13016028160,
         "togethercomputer/CodeLlama-13b-Instruct": 13016028160,
-        # "togethercomputer/llama-2-70b": 68976648192,
-        # "togethercomputer/llama-2-70b-chat": 68976648192,
+        "togethercomputer/llama-2-70b": 68976648192,
+        "togethercomputer/llama-2-70b-chat": 68976648192,
     }
     try:
         return pcount[name]
@@ -89,6 +89,24 @@ def create(
                 f"The number of checkpoints must be < the number of epochs, setting to {n_checkpoints}"
             )
 
+        if (
+            model
+            in ["togethercomputer/llama-2-70b", "togethercomputer/llama-2-70b-chat"]
+            and batch_size != 144
+        ):
+            raise ValueError(
+                f"Batch size must be 144 for {model} model. Please set batch size to 144"
+            )
+
+        # TODO: REMOVE THIS CHECK WHEN WE HAVE CHECKPOINTING WORKING FOR 70B models
+        if n_checkpoints > 1 and model in [
+            "togethercomputer/llama-2-70b",
+            "togethercomputer/llama-2-70b-chat",
+        ]:
+            raise ValueError(
+                "Saving checkpoints during training currently not supported for {model}. Please set the number of checkpoints to 1"
+            )
+
         parameter_payload = {
             "training_file": training_file,
             # "validation_file": validation_file,

0 commit comments
