Skip to content

Commit 245ad6b

Browse files
authored
add rerank api support (#171)
* add rerank api support * fix tests
1 parent af7826b commit 245ad6b

File tree

10 files changed

+198
-6
lines changed

10 files changed

+198
-6
lines changed

src/together/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class Together:
1818
images: resources.Images
1919
models: resources.Models
2020
fine_tuning: resources.FineTuning
21+
rerank: resources.Rerank
2122

2223
# client options
2324
client: TogetherClient
@@ -77,6 +78,7 @@ def __init__(
7778
self.images = resources.Images(self.client)
7879
self.models = resources.Models(self.client)
7980
self.fine_tuning = resources.FineTuning(self.client)
81+
self.rerank = resources.Rerank(self.client)
8082

8183

8284
class AsyncTogether:
@@ -87,6 +89,7 @@ class AsyncTogether:
8789
images: resources.AsyncImages
8890
models: resources.AsyncModels
8991
fine_tuning: resources.AsyncFineTuning
92+
rerank: resources.AsyncRerank
9093

9194
# client options
9295
client: TogetherClient
@@ -146,6 +149,7 @@ def __init__(
146149
self.images = resources.AsyncImages(self.client)
147150
self.models = resources.AsyncModels(self.client)
148151
self.fine_tuning = resources.AsyncFineTuning(self.client)
152+
self.rerank = resources.AsyncRerank(self.client)
149153

150154

151155
Client = Together

src/together/resources/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from together.resources.finetune import AsyncFineTuning, FineTuning
66
from together.resources.images import AsyncImages, Images
77
from together.resources.models import AsyncModels, Models
8+
from together.resources.rerank import AsyncRerank, Rerank
89

910

1011
__all__ = [
@@ -22,4 +23,6 @@
2223
"Images",
2324
"AsyncModels",
2425
"Models",
26+
"AsyncRerank",
27+
"Rerank",
2528
]

src/together/resources/rerank.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
from __future__ import annotations
2+
3+
from typing import List, Dict, Any
4+
5+
from together.abstract import api_requestor
6+
from together.together_response import TogetherResponse
7+
from together.types import (
8+
RerankRequest,
9+
RerankResponse,
10+
TogetherClient,
11+
TogetherRequest,
12+
)
13+
14+
15+
class Rerank:
16+
def __init__(self, client: TogetherClient) -> None:
17+
self._client = client
18+
19+
def create(
20+
self,
21+
*,
22+
model: str,
23+
query: str,
24+
documents: List[str] | List[Dict[str, Any]],
25+
top_n: int | None = None,
26+
return_documents: bool = False,
27+
rank_fields: List[str] | None = None,
28+
) -> RerankResponse:
29+
"""
30+
Method to generate completions based on a given prompt using a specified model.
31+
32+
Args:
33+
model (str): The name of the model to query.
34+
query (str): The input query or list of queries to rerank.
35+
documents (List[str] | List[Dict[str, Any]]): List of documents to be reranked.
36+
top_n (int | None): Number of top results to return.
37+
return_documents (bool): Flag to indicate whether to return documents.
38+
rank_fields (List[str] | None): Fields to be used for ranking the documents.
39+
40+
Returns:
41+
RerankResponse: Object containing reranked scores and documents
42+
"""
43+
44+
requestor = api_requestor.APIRequestor(
45+
client=self._client,
46+
)
47+
48+
parameter_payload = RerankRequest(
49+
model=model,
50+
query=query,
51+
documents=documents,
52+
top_n=top_n,
53+
return_documents=return_documents,
54+
rank_fields=rank_fields,
55+
).model_dump(exclude_none=True)
56+
57+
response, _, _ = requestor.request(
58+
options=TogetherRequest(
59+
method="POST",
60+
url="rerank",
61+
params=parameter_payload,
62+
),
63+
stream=False,
64+
)
65+
66+
assert isinstance(response, TogetherResponse)
67+
68+
return RerankResponse(**response.data)
69+
70+
71+
class AsyncRerank:
72+
def __init__(self, client: TogetherClient) -> None:
73+
self._client = client
74+
75+
async def create(
76+
self,
77+
*,
78+
model: str,
79+
query: str,
80+
documents: List[str] | List[Dict[str, Any]],
81+
top_n: int | None = None,
82+
return_documents: bool = False,
83+
rank_fields: List[str] | None = None,
84+
) -> RerankResponse:
85+
"""
86+
Async method to generate completions based on a given prompt using a specified model.
87+
88+
Args:
89+
model (str): The name of the model to query.
90+
query (str): The input query or list of queries to rerank.
91+
documents (List[str] | List[Dict[str, Any]]): List of documents to be reranked.
92+
top_n (int | None): Number of top results to return.
93+
return_documents (bool): Flag to indicate whether to return documents.
94+
rank_fields (List[str] | None): Fields to be used for ranking the documents.
95+
96+
Returns:
97+
RerankResponse: Object containing reranked scores and documents
98+
"""
99+
100+
requestor = api_requestor.APIRequestor(
101+
client=self._client,
102+
)
103+
104+
parameter_payload = RerankRequest(
105+
model=model,
106+
query=query,
107+
documents=documents,
108+
top_n=top_n,
109+
return_documents=return_documents,
110+
rank_fields=rank_fields,
111+
).model_dump(exclude_none=True)
112+
113+
response, _, _ = await requestor.arequest(
114+
options=TogetherRequest(
115+
method="POST",
116+
url="rerank",
117+
params=parameter_payload,
118+
),
119+
stream=False,
120+
)
121+
122+
assert isinstance(response, TogetherResponse)
123+
124+
return RerankResponse(**response.data)

src/together/types/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535
ImageResponse,
3636
)
3737
from together.types.models import ModelObject
38-
38+
from together.types.rerank import (
39+
RerankRequest,
40+
RerankResponse,
41+
)
3942

4043
__all__ = [
4144
"TogetherClient",
@@ -66,4 +69,6 @@
6669
"TrainingType",
6770
"FullTrainingType",
6871
"LoRATrainingType",
72+
"RerankRequest",
73+
"RerankResponse",
6974
]

src/together/types/rerank.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from __future__ import annotations
2+
3+
from typing import List, Literal, Dict, Any
4+
5+
from together.types.abstract import BaseModel
6+
from together.types.common import UsageData
7+
8+
9+
class RerankRequest(BaseModel):
10+
# model to query
11+
model: str
12+
# input or list of inputs
13+
query: str
14+
# list of documents
15+
documents: List[str] | List[Dict[str, Any]]
16+
# return top_n results
17+
top_n: int | None = None
18+
# boolean to return documents
19+
return_documents: bool = False
20+
# field selector for documents
21+
rank_fields: List[str] | None = None
22+
23+
24+
class RerankChoicesData(BaseModel):
25+
# response index
26+
index: int
27+
# object type
28+
relevance_score: float
29+
# rerank response
30+
document: Dict[str, Any] | None = None
31+
32+
33+
class RerankResponse(BaseModel):
34+
# job id
35+
id: str | None = None
36+
# object type
37+
object: Literal["rerank"] | None = None
38+
# query model
39+
model: str | None = None
40+
# list of reranked results
41+
results: List[RerankChoicesData] | None = None
42+
# usage stats
43+
usage: UsageData | None = None

tests/integration/constants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
completion_test_model_list = [
22
"meta-llama/Llama-2-7b-hf",
3-
"togethercomputer/StripedHyena-Hessian-7B",
43
]
54
chat_test_model_list = []
65
embedding_test_model_list = []

tests/integration/resources/test_completion.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def test_create(
9595
assert isinstance(response.id, str)
9696
assert isinstance(response.created, int)
9797
assert isinstance(response.object, ObjectType)
98-
assert response.model == model
9998
assert isinstance(response.choices, list)
10099
assert isinstance(response.choices[0], CompletionChoicesData)
101100
assert isinstance(response.choices[0].text, str)
@@ -170,8 +169,6 @@ def test_model(
170169

171170
assert isinstance(response, CompletionResponse)
172171

173-
assert response.model == model
174-
175172
@pytest.mark.parametrize(
176173
"model,prompt",
177174
product(completion_test_model_list, completion_prompt_list),

tests/integration/resources/test_completion_stream.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def test_create(
6969
assert isinstance(chunk.id, str)
7070
assert isinstance(chunk.created, int)
7171
assert isinstance(chunk.object, ObjectType)
72-
assert chunk.model == model
7372
assert isinstance(chunk.choices[0], CompletionChoicesChunk)
7473
assert isinstance(chunk.choices[0].index, int)
7574
assert isinstance(chunk.choices[0].delta, DeltaContent)

tests/unit/test_async_client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,12 @@ def test_fine_tuning_initialized(self, async_together_instance):
113113
assert async_together_instance.fine_tuning is not None
114114

115115
assert isinstance(async_together_instance.fine_tuning._client, TogetherClient)
116+
117+
def test_rerank_initialized(self, async_together_instance):
118+
"""
119+
Test initializing rerank
120+
"""
121+
122+
assert async_together_instance.rerank is not None
123+
124+
assert isinstance(async_together_instance.rerank._client, TogetherClient)

tests/unit/test_client.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,12 @@ def test_fine_tuning_initialized(self, sync_together_instance):
114114
assert sync_together_instance.fine_tuning is not None
115115

116116
assert isinstance(sync_together_instance.fine_tuning._client, TogetherClient)
117+
118+
def test_rerank_initialized(self, sync_together_instance):
119+
"""
120+
Test initializing rerank
121+
"""
122+
123+
assert sync_together_instance.rerank is not None
124+
125+
assert isinstance(sync_together_instance.rerank._client, TogetherClient)

0 commit comments

Comments
 (0)