Skip to content

Commit 6811596

Browse files
authored
pgai knowledge base vector store driver init (#1889)
* pgai knowledge base vector store driver init * cleanup * more format * fix docs for now * fix unit test * remove debug log * lint * pgai extra * address comments * format * camelcase
1 parent 9dc0601 commit 6811596

File tree

9 files changed

+217
-1
lines changed

9 files changed

+217
-1
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import os
2+
3+
from griptape.drivers.vector.pgai import PgAiKnowledgeBaseVectorStoreDriver
4+
5+
# PG.AI connection parameters
6+
connection_string = os.environ["PGAI_CONNECTION_STRING"]
7+
knowledge_base_name = os.environ["PGAI_KNOWLEDGE_BASE_NAME"]
8+
9+
vector_store_driver = PgAiKnowledgeBaseVectorStoreDriver(
10+
connection_string=connection_string,
11+
knowledge_base_name=knowledge_base_name, # optional
12+
)
13+
14+
results = vector_store_driver.query(query="What is griptape?")
15+
16+
values = [r.to_artifact().value for r in results]
17+
18+
print("\n\n".join(values))

docs/griptape-framework/drivers/vector-store-drivers.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,3 +294,25 @@ The following example shows how to store vector entries and query the informatio
294294
```text
295295
--8<-- "docs/griptape-framework/drivers/logs/vector_store_drivers_11.txt"
296296
```
297+
298+
### PG.AI
299+
300+
!!! info
301+
302+
This Driver requires the `drivers-vector-pgai` [extra](../index.md#extras).
303+
304+
The PgAiKnowledgeBaseVectorStoreDriver integrates with PG.AI, a managed postgres platform from [EnterpriseDB](https://www.enterprisedb.com/).
305+
306+
Here is an example of how the Driver can be used to load and query information in a PG.AI Knowledge Base:
307+
308+
=== "Code"
309+
310+
```python
311+
--8<-- "docs/griptape-framework/drivers/src/vector_store_drivers_12.py"
312+
```
313+
314+
=== "Logs"
315+
316+
```text
317+
--8<-- "docs/griptape-framework/drivers/logs/vector_store_drivers_12.txt"
318+
```

griptape/drivers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from .vector.qdrant import QdrantVectorStoreDriver
5050
from .vector.astradb import AstraDbVectorStoreDriver
5151
from .vector.griptape_cloud import GriptapeCloudVectorStoreDriver
52+
from .vector.pgai import PgAiKnowledgeBaseVectorStoreDriver
5253

5354
from .sql import BaseSqlDriver
5455
from .sql.sql_driver import SqlDriver
@@ -226,6 +227,7 @@
226227
"OpenTelemetryObservabilityDriver",
227228
"PerplexityPromptDriver",
228229
"PerplexityWebSearchDriver",
230+
"PgAiKnowledgeBaseVectorStoreDriver",
229231
"PgVectorVectorStoreDriver",
230232
"PineconeVectorStoreDriver",
231233
"ProxyWebScraperDriver",
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from griptape.drivers.vector.pgai_knowledge_base_vector_store_driver import PgAiKnowledgeBaseVectorStoreDriver
2+
3+
__all__ = ["PgAiKnowledgeBaseVectorStoreDriver"]
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from typing import TYPE_CHECKING, NoReturn, Optional
5+
6+
from attrs import Factory, define, field
7+
8+
from griptape.artifacts import ImageArtifact, ListArtifact, TextArtifact
9+
from griptape.drivers.embedding.dummy import DummyEmbeddingDriver
10+
from griptape.drivers.vector import BaseVectorStoreDriver
11+
from griptape.utils import import_optional_dependency
12+
from griptape.utils.decorators import lazy_property
13+
14+
if TYPE_CHECKING:
15+
import sqlalchemy
16+
17+
from griptape.drivers.embedding import BaseEmbeddingDriver
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
@define
23+
class PgAiKnowledgeBaseVectorStoreDriver(BaseVectorStoreDriver):
24+
connection_string: str = field(kw_only=True, metadata={"serializable": True})
25+
knowledge_base_name: str = field(kw_only=True, metadata={"serializable": True})
26+
embedding_driver: BaseEmbeddingDriver = field(
27+
default=Factory(lambda: DummyEmbeddingDriver()),
28+
metadata={"serializable": True},
29+
kw_only=True,
30+
init=False,
31+
)
32+
_engine: sqlalchemy.Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})
33+
34+
@lazy_property()
35+
def engine(self) -> sqlalchemy.Engine:
36+
return import_optional_dependency("sqlalchemy").create_engine(self.connection_string)
37+
38+
def query(
39+
self,
40+
query: str | TextArtifact | ImageArtifact,
41+
*,
42+
count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
43+
**kwargs,
44+
) -> list[BaseVectorStoreDriver.Entry]:
45+
if isinstance(query, ImageArtifact):
46+
raise ValueError(f"{self.__class__.__name__} does not support querying with Image Artifacts.")
47+
48+
sqlalchemy = import_optional_dependency("sqlalchemy")
49+
50+
with sqlalchemy.orm.Session(self.engine) as session:
51+
rows = session.query(sqlalchemy.func.aidb.retrieve_text(self.knowledge_base_name, query, count)).all()
52+
53+
entries = []
54+
for (row,) in rows:
55+
# Remove the first and last parentheses from the row and list by commas
56+
# Example: '(foo,bar)' -> ['foo', 'bar']
57+
row_list = "".join(row.replace("(", "", 1).rsplit(")", 1)).split(",")
58+
entries.append(
59+
BaseVectorStoreDriver.Entry(
60+
id=row_list[0],
61+
score=float(row_list[2]),
62+
meta={"artifact": TextArtifact(row_list[1]).to_json()},
63+
)
64+
)
65+
66+
return entries
67+
68+
def upsert_vector(
69+
self,
70+
vector: list[float],
71+
vector_id: Optional[str] = None,
72+
namespace: Optional[str] = None,
73+
meta: Optional[dict] = None,
74+
**kwargs,
75+
) -> str:
76+
raise NotImplementedError(f"{self.__class__.__name__} does not support vector upsert.")
77+
78+
def upsert_text_artifact(
79+
self,
80+
artifact: TextArtifact,
81+
namespace: Optional[str] = None,
82+
meta: Optional[dict] = None,
83+
vector_id: Optional[str] = None,
84+
**kwargs,
85+
) -> str:
86+
raise NotImplementedError(f"{self.__class__.__name__} does not support text artifact upsert.")
87+
88+
def upsert_text(
89+
self,
90+
string: str,
91+
vector_id: Optional[str] = None,
92+
namespace: Optional[str] = None,
93+
meta: Optional[dict] = None,
94+
**kwargs,
95+
) -> str:
96+
raise NotImplementedError(f"{self.__class__.__name__} does not support text upsert.")
97+
98+
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
99+
raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")
100+
101+
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
102+
raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")
103+
104+
def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
105+
raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.")
106+
107+
def delete_vector(self, vector_id: str) -> NoReturn:
108+
raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ drivers-vector-pgvector = [
6060
]
6161
drivers-vector-qdrant = ["qdrant-client>=1.10.1"]
6262
drivers-vector-astra-db = ["astrapy>=2.0"]
63+
drivers-vector-pgai = [
64+
"sqlalchemy>=2.0.31"
65+
]
6366
drivers-embedding-amazon-bedrock = ["boto3>=1.34.119"]
6467
drivers-embedding-amazon-sagemaker = ["boto3>=1.34.119"]
6568
drivers-embedding-huggingface = [

tests/integration/test_code_blocks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"docs/griptape-framework/structures/src/observability_2.py",
2424
"docs/griptape-framework/data/src/loaders_9.py",
2525
"docs/recipes/src/talk_to_an_audio_2.py",
26+
"docs/griptape-framework/drivers/src/vector_store_drivers_12.py",
2627
]
2728

2829

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import json
2+
from unittest.mock import MagicMock
3+
4+
import pytest
5+
6+
from griptape.drivers.vector.pgai import PgAiKnowledgeBaseVectorStoreDriver
7+
8+
9+
class TestPgAiKnowledgeBaseVectorStoreDriver:
10+
connection_string = "postgresql://postgres:postgres@localhost:5432/postgres"
11+
knowledge_base_name = "example_knowledge_base"
12+
13+
@pytest.fixture()
14+
def mock_engine(self):
15+
return MagicMock()
16+
17+
@pytest.fixture()
18+
def mock_session(self, mocker):
19+
session = MagicMock()
20+
mock_session_manager = MagicMock()
21+
mock_session_manager.__enter__.return_value = session
22+
mocker.patch("sqlalchemy.orm.Session", return_value=mock_session_manager)
23+
24+
return session
25+
26+
def test_initialize(self):
27+
PgAiKnowledgeBaseVectorStoreDriver(
28+
connection_string=self.connection_string, knowledge_base_name=self.knowledge_base_name
29+
)
30+
31+
def test_query(self, mock_engine, mock_session):
32+
test_ids = [17, 23]
33+
test_values = ['"foo"', "bar"]
34+
test_scores = [0.4, 0.6]
35+
mock_query = MagicMock()
36+
mock_query.all.return_value = [
37+
(f"{test_ids[0]},{test_values[0]},{test_scores[0]}",),
38+
(f"{test_ids[1]},{test_values[1]},{test_scores[1]}",),
39+
]
40+
mock_session.query.return_value = mock_query
41+
42+
driver = PgAiKnowledgeBaseVectorStoreDriver(
43+
engine=mock_engine, connection_string=self.connection_string, knowledge_base_name=self.knowledge_base_name
44+
)
45+
46+
result = driver.query("some query")
47+
48+
assert result[0].id == str(test_ids[0])
49+
assert result[1].id == str(test_ids[1])
50+
assert result[0].meta
51+
assert json.loads(result[0].meta["artifact"])["value"] == test_values[0]
52+
assert result[1].meta
53+
assert json.loads(result[1].meta["artifact"])["value"] == test_values[1]
54+
assert result[0].score == test_scores[0]
55+
assert result[1].score == test_scores[1]

uv.lock

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)