Skip to content

Commit c733d5b

Browse files
authored
Add support for specifying seed (#139)
* Add support for specifying seed
* update test models
* fix model name
* update model list
* bump version to 1.2.12
1 parent 28bf248 commit c733d5b

File tree

6 files changed

+44
-2
lines changed

6 files changed

+44
-2
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.2.11"
15+
version = "1.2.12"
1616
authors = [
1717
"Together AI <[email protected]>"
1818
]

src/together/resources/chat/completions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def create(
3232
frequency_penalty: float | None = None,
3333
min_p: float | None = None,
3434
logit_bias: Dict[str, float] | None = None,
35+
seed: int | None = None,
3536
stream: bool = False,
3637
logprobs: int | None = None,
3738
echo: bool | None = None,
@@ -79,6 +80,7 @@ def create(
7980
logit_bias (Dict[str, float], optional): A dictionary of tokens and their bias values that modify the
8081
likelihood of specific tokens being sampled. Bias values must be in the range [-100, 100].
8182
Defaults to None.
83+
seed (int, optional): A seed value to use for reproducibility.
8284
stream (bool, optional): Flag indicating whether to stream the generated completions.
8385
Defaults to False.
8486
logprobs (int, optional): Number of top-k logprobs to return
@@ -124,6 +126,7 @@ def create(
124126
frequency_penalty=frequency_penalty,
125127
min_p=min_p,
126128
logit_bias=logit_bias,
129+
seed=seed,
127130
stream=stream,
128131
logprobs=logprobs,
129132
echo=echo,
@@ -171,6 +174,7 @@ async def create(
171174
frequency_penalty: float | None = None,
172175
min_p: float | None = None,
173176
logit_bias: Dict[str, float] | None = None,
177+
seed: int | None = None,
174178
stream: bool = False,
175179
logprobs: int | None = None,
176180
echo: bool | None = None,
@@ -218,6 +222,7 @@ async def create(
218222
logit_bias (Dict[str, float], optional): A dictionary of tokens and their bias values that modify the
219223
likelihood of specific tokens being sampled. Bias values must be in the range [-100, 100].
220224
Defaults to None.
225+
seed (int, optional): A seed value to use for reproducibility.
221226
stream (bool, optional): Flag indicating whether to stream the generated completions.
222227
Defaults to False.
223228
logprobs (int, optional): Number of top-k logprobs to return
@@ -263,6 +268,7 @@ async def create(
263268
frequency_penalty=frequency_penalty,
264269
min_p=min_p,
265270
logit_bias=logit_bias,
271+
seed=seed,
266272
stream=stream,
267273
logprobs=logprobs,
268274
echo=echo,

src/together/resources/completions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def create(
3232
frequency_penalty: float | None = None,
3333
min_p: float | None = None,
3434
logit_bias: Dict[str, float] | None = None,
35+
seed: int | None = None,
3536
stream: bool = False,
3637
logprobs: int | None = None,
3738
echo: bool | None = None,
@@ -75,6 +76,7 @@ def create(
7576
logit_bias (Dict[str, float], optional): A dictionary of tokens and their bias values that modify the
7677
likelihood of specific tokens being sampled. Bias values must be in the range [-100, 100].
7778
Defaults to None.
79+
seed (int, optional): Seed value for reproducibility.
7880
stream (bool, optional): Flag indicating whether to stream the generated completions.
7981
Defaults to False.
8082
logprobs (int, optional): Number of top-k logprobs to return
@@ -107,6 +109,7 @@ def create(
107109
repetition_penalty=repetition_penalty,
108110
presence_penalty=presence_penalty,
109111
frequency_penalty=frequency_penalty,
112+
seed=seed,
110113
min_p=min_p,
111114
logit_bias=logit_bias,
112115
stream=stream,
@@ -153,6 +156,7 @@ async def create(
153156
frequency_penalty: float | None = None,
154157
min_p: float | None = None,
155158
logit_bias: Dict[str, float] | None = None,
159+
seed: int | None = None,
156160
stream: bool = False,
157161
logprobs: int | None = None,
158162
echo: bool | None = None,
@@ -196,6 +200,7 @@ async def create(
196200
logit_bias (Dict[str, float], optional): A dictionary of tokens and their bias values that modify the
197201
likelihood of specific tokens being sampled. Bias values must be in the range [-100, 100].
198202
Defaults to None.
203+
seed (int, optional): Seed value for reproducibility.
199204
stream (bool, optional): Flag indicating whether to stream the generated completions.
200205
Defaults to False.
201206
logprobs (int, optional): Number of top-k logprobs to return
@@ -230,6 +235,7 @@ async def create(
230235
frequency_penalty=frequency_penalty,
231236
min_p=min_p,
232237
logit_bias=logit_bias,
238+
seed=seed,
233239
stream=stream,
234240
logprobs=logprobs,
235241
echo=echo,

src/together/types/chat_completions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class ChatCompletionRequest(BaseModel):
9696
frequency_penalty: float | None = None
9797
min_p: float | None = None
9898
logit_bias: Dict[str, float] | None = None
99+
seed: int | None = None
99100
# stream SSE token chunks
100101
stream: bool = False
101102
# return logprobs
@@ -126,6 +127,7 @@ def verify_parameters(self) -> Self:
126127
class ChatCompletionChoicesData(BaseModel):
127128
index: int | None = None
128129
logprobs: LogprobsPart | None = None
130+
seed: int | None = None
129131
finish_reason: FinishReason | None = None
130132
message: ChatCompletionMessage | None = None
131133

@@ -150,6 +152,7 @@ class ChatCompletionResponse(BaseModel):
150152
class ChatCompletionChoicesChunk(BaseModel):
151153
index: int | None = None
152154
logprobs: float | None = None
155+
seed: int | None = None
153156
finish_reason: FinishReason | None = None
154157
delta: DeltaContent | None = None
155158

src/together/types/completions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class CompletionRequest(BaseModel):
3535
frequency_penalty: float | None = None
3636
min_p: float | None = None
3737
logit_bias: Dict[str, float] | None = None
38+
seed: int | None = None
3839
# stream SSE token chunks
3940
stream: bool = False
4041
# return logprobs
@@ -61,13 +62,15 @@ def verify_parameters(self) -> Self:
6162
class CompletionChoicesData(BaseModel):
6263
index: int
6364
logprobs: LogprobsPart | None = None
64-
finish_reason: FinishReason | None = None
65+
seed: int | None = None
66+
finish_reason: FinishReason
6567
text: str
6668

6769

6870
class CompletionChoicesChunk(BaseModel):
6971
index: int | None = None
7072
logprobs: float | None = None
73+
seed: int | None = None
7174
finish_reason: FinishReason | None = None
7275
delta: DeltaContent | None = None
7376

tests/integration/resources/test_completion.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,3 +530,27 @@ def test_logit_bias(
530530
)
531531

532532
assert isinstance(response, CompletionResponse)
533+
534+
@pytest.mark.parametrize(
535+
"model,prompt",
536+
product(
537+
completion_test_model_list,
538+
completion_prompt_list,
539+
),
540+
)
541+
def test_seed(
542+
self,
543+
model,
544+
prompt,
545+
sync_together_client,
546+
):
547+
response = sync_together_client.completions.create(
548+
prompt=prompt,
549+
model=model,
550+
stop=STOP,
551+
max_tokens=1,
552+
seed=4242,
553+
)
554+
555+
assert isinstance(response, CompletionResponse)
556+
assert response.choices[0].seed == 4242

0 commit comments

Comments (0)