Skip to content

Commit fb5b85b

Browse files
authored
Merge pull request #53 from togethercomputer/orangetin/handle-errors-fix
Handle errors gracefully
2 parents 474dcc5 + ea33e79 commit fb5b85b

File tree

14 files changed

+239
-165
lines changed

14 files changed

+239
-165
lines changed

src/together/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
default_image_model = "runwayml/stable-diffusion-v1-5"
2222
log_level = "WARNING"
2323

24+
MISSING_API_KEY_MESSAGE = """TOGETHER_API_KEY not found.
25+
Please set it as an environment variable or set it as together.api_key
26+
Find your TOGETHER_API_KEY at https://api.together.xyz/settings/api-keys"""
27+
2428
MAX_CONNECTION_RETRIES = 2
2529
BACKOFF_FACTOR = 0.2
2630

@@ -49,6 +53,7 @@
4953
"Finetune",
5054
"Image",
5155
"MAX_CONNECTION_RETRIES",
56+
"MISSING_API_KEY_MESSAGE",
5257
"BACKOFF_FACTOR",
5358
"min_samples",
5459
]

src/together/commands/chat.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
import together
77
import together.tools.conversation as convo
88
from together import Complete
9+
from together.utils import get_logger
10+
11+
12+
logger = get_logger(str(__name__))
913

1014

1115
def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
@@ -99,18 +103,22 @@ def precmd(self, line: str) -> str:
99103
def do_say(self, arg: str) -> None:
100104
self._convo.push_human_turn(arg)
101105
output = ""
102-
for token in self.infer.create_streaming(
103-
prompt=self._convo.get_raw_prompt(),
104-
model=self.args.model,
105-
max_tokens=self.args.max_tokens,
106-
stop=self.args.stop,
107-
temperature=self.args.temperature,
108-
top_p=self.args.top_p,
109-
top_k=self.args.top_k,
110-
repetition_penalty=self.args.repetition_penalty,
111-
):
112-
print(token, end="", flush=True)
113-
output += token
106+
try:
107+
for token in self.infer.create_streaming(
108+
prompt=self._convo.get_raw_prompt(),
109+
model=self.args.model,
110+
max_tokens=self.args.max_tokens,
111+
stop=self.args.stop,
112+
temperature=self.args.temperature,
113+
top_p=self.args.top_p,
114+
top_k=self.args.top_k,
115+
repetition_penalty=self.args.repetition_penalty,
116+
):
117+
print(token, end="", flush=True)
118+
output += token
119+
except together.AuthenticationError:
120+
logger.critical(together.MISSING_API_KEY_MESSAGE)
121+
exit(0)
114122
print("\n")
115123
self._convo.push_model_response(output)
116124

src/together/commands/complete.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -131,33 +131,41 @@ def _run_complete(args: argparse.Namespace) -> None:
131131
complete = Complete()
132132

133133
if args.no_stream:
134-
response = complete.create(
135-
prompt=args.prompt,
136-
model=args.model,
137-
max_tokens=args.max_tokens,
138-
stop=args.stop,
139-
temperature=args.temperature,
140-
top_p=args.top_p,
141-
top_k=args.top_k,
142-
repetition_penalty=args.repetition_penalty,
143-
logprobs=args.logprobs,
144-
)
134+
try:
135+
response = complete.create(
136+
prompt=args.prompt,
137+
model=args.model,
138+
max_tokens=args.max_tokens,
139+
stop=args.stop,
140+
temperature=args.temperature,
141+
top_p=args.top_p,
142+
top_k=args.top_k,
143+
repetition_penalty=args.repetition_penalty,
144+
logprobs=args.logprobs,
145+
)
146+
except together.AuthenticationError:
147+
logger.critical(together.MISSING_API_KEY_MESSAGE)
148+
exit(0)
145149
no_streamer(args, response)
146150
else:
147-
for text in complete.create_streaming(
148-
prompt=args.prompt,
149-
model=args.model,
150-
max_tokens=args.max_tokens,
151-
stop=args.stop,
152-
temperature=args.temperature,
153-
top_p=args.top_p,
154-
top_k=args.top_k,
155-
repetition_penalty=args.repetition_penalty,
156-
raw=args.raw,
157-
):
158-
if not args.raw:
159-
print(text, end="", flush=True)
160-
else:
161-
print(text)
151+
try:
152+
for text in complete.create_streaming(
153+
prompt=args.prompt,
154+
model=args.model,
155+
max_tokens=args.max_tokens,
156+
stop=args.stop,
157+
temperature=args.temperature,
158+
top_p=args.top_p,
159+
top_k=args.top_k,
160+
repetition_penalty=args.repetition_penalty,
161+
raw=args.raw,
162+
):
163+
if not args.raw:
164+
print(text, end="", flush=True)
165+
else:
166+
print(text)
167+
except together.AuthenticationError:
168+
logger.critical(together.MISSING_API_KEY_MESSAGE)
169+
exit(0)
162170
if not args.raw:
163171
print("\n")

src/together/commands/files.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55

66
from tabulate import tabulate
77

8+
import together
89
from together import Files
9-
from together.utils import bytes_to_human_readable, extract_time
10+
from together.utils import bytes_to_human_readable, extract_time, get_logger
11+
12+
13+
logger = get_logger(str(__name__))
1014

1115

1216
def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
@@ -123,7 +127,11 @@ def _add_retrieve_content(
123127

124128

125129
def _run_list(args: argparse.Namespace) -> None:
126-
response = Files.list()
130+
try:
131+
response = Files.list()
132+
except together.AuthenticationError:
133+
logger.critical(together.MISSING_API_KEY_MESSAGE)
134+
exit(0)
127135
response["data"].sort(key=extract_time)
128136
if args.raw:
129137
print(json.dumps(response, indent=4))
@@ -146,22 +154,40 @@ def _run_list(args: argparse.Namespace) -> None:
146154

147155

148156
def _run_check(args: argparse.Namespace) -> None:
149-
response = Files.check(args.file)
157+
try:
158+
response = Files.check(args.file)
159+
except together.AuthenticationError:
160+
logger.critical(together.MISSING_API_KEY_MESSAGE)
161+
exit(0)
150162
print(json.dumps(response, indent=4))
151163

152164

153165
def _run_upload(args: argparse.Namespace) -> None:
154-
response = Files.upload(file=args.file, check=not args.no_check, model=args.model)
166+
try:
167+
response = Files.upload(
168+
file=args.file, check=not args.no_check, model=args.model
169+
)
170+
except together.AuthenticationError:
171+
logger.critical(together.MISSING_API_KEY_MESSAGE)
172+
exit(0)
155173
print(json.dumps(response, indent=4))
156174

157175

158176
def _run_delete(args: argparse.Namespace) -> None:
159-
response = Files.delete(args.file_id)
177+
try:
178+
response = Files.delete(args.file_id)
179+
except together.AuthenticationError:
180+
logger.critical(together.MISSING_API_KEY_MESSAGE)
181+
exit(0)
160182
print(json.dumps(response, indent=4))
161183

162184

163185
def _run_retrieve(args: argparse.Namespace) -> None:
164-
response = Files.retrieve(args.file_id)
186+
try:
187+
response = Files.retrieve(args.file_id)
188+
except together.AuthenticationError:
189+
logger.critical(together.MISSING_API_KEY_MESSAGE)
190+
exit(0)
165191
if args.raw:
166192
print(json.dumps(response, indent=4))
167193
else:
@@ -171,5 +197,9 @@ def _run_retrieve(args: argparse.Namespace) -> None:
171197

172198

173199
def _run_retrieve_content(args: argparse.Namespace) -> None:
174-
output = Files.retrieve_content(args.file_id, args.output)
200+
try:
201+
output = Files.retrieve_content(args.file_id, args.output)
202+
except together.AuthenticationError:
203+
logger.critical(together.MISSING_API_KEY_MESSAGE)
204+
exit(0)
175205
print(output)

src/together/commands/finetune.py

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66

77
from tabulate import tabulate
88

9+
import together
910
from together import Finetune
10-
from together.utils import finetune_price_to_dollars, parse_timestamp
11+
from together.utils import finetune_price_to_dollars, get_logger, parse_timestamp
12+
13+
14+
logger = get_logger(str(__name__))
1115

1216

1317
def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
@@ -24,7 +28,6 @@ def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser])
2428
_add_download(child_parsers)
2529
_add_status(child_parsers)
2630
_add_checkpoints(child_parsers)
27-
# _add_delete_model(child_parsers)
2831

2932

3033
def _add_create(parser: argparse._SubParsersAction[argparse.ArgumentParser]) -> None:
@@ -252,25 +255,32 @@ def _run_create(args: argparse.Namespace) -> None:
252255
args.batch_size = 144
253256
else:
254257
args.batch_size = 32
255-
256-
response = finetune.create(
257-
training_file=args.training_file, # training file_id
258-
model=args.model,
259-
n_epochs=args.n_epochs,
260-
n_checkpoints=args.n_checkpoints,
261-
batch_size=args.batch_size,
262-
learning_rate=args.learning_rate,
263-
suffix=args.suffix,
264-
estimate_price=args.estimate_price,
265-
wandb_api_key=args.wandb_api_key if not args.no_wandb_api_key else None,
266-
confirm_inputs=not args.quiet,
267-
)
258+
try:
259+
response = finetune.create(
260+
training_file=args.training_file, # training file_id
261+
model=args.model,
262+
n_epochs=args.n_epochs,
263+
n_checkpoints=args.n_checkpoints,
264+
batch_size=args.batch_size,
265+
learning_rate=args.learning_rate,
266+
suffix=args.suffix,
267+
estimate_price=args.estimate_price,
268+
wandb_api_key=args.wandb_api_key if not args.no_wandb_api_key else None,
269+
confirm_inputs=not args.quiet,
270+
)
271+
except together.AuthenticationError:
272+
logger.critical(together.MISSING_API_KEY_MESSAGE)
273+
exit(0)
268274

269275
print(json.dumps(response, indent=4))
270276

271277

272278
def _run_list(args: argparse.Namespace) -> None:
273-
response = Finetune.list()
279+
try:
280+
response = Finetune.list()
281+
except together.AuthenticationError:
282+
logger.critical(together.MISSING_API_KEY_MESSAGE)
283+
exit(0)
274284
response["data"].sort(key=lambda x: parse_timestamp(x["created_at"]))
275285
if args.raw:
276286
print(json.dumps(response, indent=4))
@@ -293,7 +303,11 @@ def _run_list(args: argparse.Namespace) -> None:
293303

294304

295305
def _run_retrieve(args: argparse.Namespace) -> None:
296-
response = Finetune.retrieve(args.fine_tune_id)
306+
try:
307+
response = Finetune.retrieve(args.fine_tune_id)
308+
except together.AuthenticationError:
309+
logger.critical(together.MISSING_API_KEY_MESSAGE)
310+
exit(0)
297311
if args.raw:
298312
print(json.dumps(response, indent=4))
299313
else:
@@ -307,12 +321,20 @@ def _run_retrieve(args: argparse.Namespace) -> None:
307321

308322

309323
def _run_cancel(args: argparse.Namespace) -> None:
310-
response = Finetune.cancel(args.fine_tune_id)
324+
try:
325+
response = Finetune.cancel(args.fine_tune_id)
326+
except together.AuthenticationError:
327+
logger.critical(together.MISSING_API_KEY_MESSAGE)
328+
exit(0)
311329
print(json.dumps(response, indent=4))
312330

313331

314332
def _run_list_events(args: argparse.Namespace) -> None:
315-
response = Finetune.list_events(args.fine_tune_id)
333+
try:
334+
response = Finetune.list_events(args.fine_tune_id)
335+
except together.AuthenticationError:
336+
logger.critical(together.MISSING_API_KEY_MESSAGE)
337+
exit(0)
316338
if args.raw:
317339
print(json.dumps(response, indent=4))
318340
else:
@@ -330,16 +352,30 @@ def _run_list_events(args: argparse.Namespace) -> None:
330352

331353

332354
def _run_download(args: argparse.Namespace) -> None:
333-
response = Finetune.download(args.fine_tune_id, args.output, args.checkpoint_step)
355+
try:
356+
response = Finetune.download(
357+
args.fine_tune_id, args.output, args.checkpoint_step
358+
)
359+
except together.AuthenticationError:
360+
logger.critical(together.MISSING_API_KEY_MESSAGE)
361+
exit(0)
334362
print(response)
335363

336364

337365
def _run_status(args: argparse.Namespace) -> None:
338-
response = Finetune.get_job_status(args.fine_tune_id)
366+
try:
367+
response = Finetune.get_job_status(args.fine_tune_id)
368+
except together.AuthenticationError:
369+
logger.critical(together.MISSING_API_KEY_MESSAGE)
370+
exit(0)
339371
print(response)
340372

341373

342374
def _run_checkpoint(args: argparse.Namespace) -> None:
343-
checkpoints = Finetune.get_checkpoints(args.fine_tune_id)
375+
try:
376+
checkpoints = Finetune.get_checkpoints(args.fine_tune_id)
377+
except together.AuthenticationError:
378+
logger.critical(together.MISSING_API_KEY_MESSAGE)
379+
exit(0)
344380
print(json.dumps(checkpoints, indent=4))
345381
print(f"\n{len(checkpoints)} checkpoints found")

src/together/commands/image.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,19 @@ def _save_image(args: argparse.Namespace, response: Dict[str, Any]) -> None:
124124

125125
def _run_complete(args: argparse.Namespace) -> None:
126126
complete = Image()
127-
128-
response = complete.create(
129-
prompt=args.prompt,
130-
model=args.model,
131-
steps=args.steps,
132-
seed=args.seed,
133-
results=args.results,
134-
height=args.height,
135-
width=args.width,
136-
negative_prompt=args.negative_prompt,
137-
)
127+
try:
128+
response = complete.create(
129+
prompt=args.prompt,
130+
model=args.model,
131+
steps=args.steps,
132+
seed=args.seed,
133+
results=args.results,
134+
height=args.height,
135+
width=args.width,
136+
negative_prompt=args.negative_prompt,
137+
)
138+
except together.AuthenticationError:
139+
logger.critical(together.MISSING_API_KEY_MESSAGE)
140+
exit(0)
138141

139142
_save_image(args, response)

0 commit comments

Comments
 (0)