Skip to content

Commit b2a900a

Browse files
authored
feat(drivers-rerank-nvidia-nim): add NvidiaNimRerankDriver (#1926)
* feat(drivers-rerank-nvidia-nim): add `NvidiaNimRerankDriver` * skipped test * clean up logic
1 parent 314c943 commit b2a900a

File tree

6 files changed

+144
-0
lines changed

6 files changed

+144
-0
lines changed

docs/griptape-framework/drivers/rerank-drivers.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,13 @@ The [CohereRerankDriver](../../reference/griptape/drivers/rerank/cohere_rerank_d
4444
```text
4545
--8<-- "docs/griptape-framework/drivers/logs/cohere_rerank_driver.txt"
4646
```
47+
48+
### Nvidia NIM
49+
50+
The [NvidiaNimRerankDriver](../../reference/griptape/drivers/rerank/nvidia_nim_rerank_driver.md) uses the [Nvidia NIM Reranking API](https://docs.nvidia.com/nim/nemo-retriever/text-reranking/latest/index.html).
51+
52+
=== "Code"
53+
54+
```python
55+
--8<-- "docs/griptape-framework/drivers/src/nvidia_nim_rerank_driver.py"
56+
```
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from griptape.artifacts import TextArtifact
2+
from griptape.drivers.rerank.nvidia_nim import NvidiaNimRerankDriver
3+
4+
rerank_driver = NvidiaNimRerankDriver(
5+
model="nvidia/bert-base-uncased",
6+
base_url="http://localhost:8000",
7+
)
8+
9+
artifacts = rerank_driver.run(
10+
"Where is NYC located?",
11+
[
12+
TextArtifact("NYC Media"),
13+
TextArtifact("New York City Police Department"),
14+
TextArtifact("New York City"),
15+
TextArtifact("New York City Subway"),
16+
],
17+
)
18+
for artifact in artifacts:
19+
print(artifact.value)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from griptape.drivers.rerank.nvidia_nim_rerank_driver import NvidiaNimRerankDriver
2+
3+
__all__ = ["NvidiaNimRerankDriver"]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Literal
4+
5+
import requests
6+
from attrs import define, field
7+
8+
from griptape.drivers.rerank.base_rerank_driver import BaseRerankDriver
9+
10+
if TYPE_CHECKING:
11+
from griptape.artifacts import TextArtifact
12+
13+
14+
@define(kw_only=True)
15+
class NvidiaNimRerankDriver(BaseRerankDriver):
16+
"""Nvidia Rerank Driver."""
17+
18+
model: str = field()
19+
base_url: str = field()
20+
truncate: Literal["NONE", "END"] = field(default="NONE")
21+
headers: dict = field(factory=dict)
22+
23+
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
24+
if not artifacts:
25+
return []
26+
27+
response = requests.post(
28+
url=f"{self.base_url.rstrip('/')}/v1/ranking",
29+
json=self._get_body(query, artifacts),
30+
headers=self.headers,
31+
)
32+
33+
response.raise_for_status()
34+
35+
ranked_artifacts = []
36+
for ranking in response.json()["rankings"]:
37+
artifact = artifacts[ranking["index"]]
38+
artifact.meta.update({"logit": ranking["logit"], "usage": ranking.get("usage")})
39+
ranked_artifacts.append(artifact)
40+
41+
return ranked_artifacts
42+
43+
def _get_body(self, query: str, artifacts: list[TextArtifact]) -> dict:
44+
return {
45+
"model": self.model,
46+
"query": {"text": query},
47+
"passages": [{"text": artifact.value} for artifact in artifacts],
48+
"truncate": self.truncate,
49+
}

tests/integration/test_code_blocks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"docs/griptape-framework/drivers/src/prompt_drivers_14.py",
2020
"docs/griptape-framework/drivers/src/observability_drivers_1.py",
2121
"docs/griptape-framework/drivers/src/observability_drivers_2.py",
22+
"docs/griptape-framework/drivers/src/nvidia_nim_rerank_driver.py",
2223
"docs/griptape-framework/structures/src/observability_1.py",
2324
"docs/griptape-framework/structures/src/observability_2.py",
2425
"docs/griptape-framework/data/src/loaders_9.py",
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import pytest
2+
import requests
3+
4+
from griptape.artifacts import TextArtifact
5+
from griptape.drivers.rerank.nvidia_nim import NvidiaNimRerankDriver
6+
7+
8+
class TestNvidiaNimRerankDriver:
9+
@pytest.fixture()
10+
def mock_client(self, mocker):
11+
def mock_post(*args, **kwargs):
12+
return mocker.Mock(
13+
status_code=200,
14+
json=lambda: {
15+
"rankings": [
16+
{"index": 0, "logit": 0.1, "usage": {"prompt_tokens": 10, "total_tokens": 20}},
17+
{"index": 1, "logit": 0.2, "usage": {"prompt_tokens": 10, "total_tokens": 20}},
18+
]
19+
},
20+
)
21+
22+
mocker.patch("griptape.drivers.rerank.nvidia_nim_rerank_driver.requests.post", side_effect=mock_post)
23+
24+
@pytest.fixture()
25+
def mock_empty_client(self, mocker):
26+
def mock_post(*args, **kwargs):
27+
return mocker.Mock(
28+
status_code=200,
29+
json=lambda: {"rankings": []},
30+
)
31+
32+
mocker.patch("griptape.drivers.rerank.nvidia_nim_rerank_driver.requests.post", side_effect=mock_post)
33+
34+
def test_run(self, mock_client):
35+
driver = NvidiaNimRerankDriver(model="model-name", base_url="http://localhost:8000")
36+
result = driver.run("hello", artifacts=[TextArtifact("foo"), TextArtifact("bar")])
37+
38+
assert len(result) == 2
39+
40+
def test_run_empty_artifacts(self, mock_empty_client):
41+
driver = NvidiaNimRerankDriver(model="model-name", base_url="http://localhost:8000")
42+
result = driver.run("hello", artifacts=[TextArtifact(""), TextArtifact(" ")])
43+
44+
assert len(result) == 0
45+
46+
result = driver.run("hello", artifacts=[])
47+
assert len(result) == 0
48+
49+
def test_run_error(self, mocker):
50+
mocker.patch(
51+
"griptape.drivers.rerank.nvidia_nim_rerank_driver.requests.post",
52+
return_value=mocker.Mock(
53+
status_code=500,
54+
text="Internal Server Error",
55+
raise_for_status=lambda: (_ for _ in ()).throw(requests.exceptions.HTTPError()),
56+
),
57+
)
58+
59+
driver = NvidiaNimRerankDriver(model="model-name", base_url="http://localhost:8000")
60+
61+
with pytest.raises(requests.exceptions.HTTPError):
62+
driver.run("hello", artifacts=[TextArtifact("foo"), TextArtifact("bar")])

0 commit comments

Comments
 (0)