Skip to content

Commit 9b6bfff

Browse files
authored
fix(drivers-vector-marqo): fix upsert failing due to inability to upsert_vectors (#1803)
1 parent fe27585 commit 9b6bfff

File tree

2 files changed

+22
-48
lines changed

2 files changed

+22
-48
lines changed

griptape/drivers/vector/base_vector_store_driver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def upsert_text_artifacts(
4242
**kwargs,
4343
) -> list[str] | dict[str, list[str]]:
4444
warnings.warn(
45-
"`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.upsert_collection` is a drop-in replacement.",
45+
"`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert_collection` is a drop-in replacement.",
4646
DeprecationWarning,
4747
stacklevel=2,
4848
)
@@ -58,7 +58,7 @@ def upsert_text_artifact(
5858
**kwargs,
5959
) -> str:
6060
warnings.warn(
61-
"`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.upsert` is a drop-in replacement.",
61+
"`BaseVectorStoreDriver.upsert_text_artifact` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
6262
DeprecationWarning,
6363
stacklevel=2,
6464
)
@@ -74,7 +74,7 @@ def upsert_text(
7474
**kwargs,
7575
) -> str:
7676
warnings.warn(
77-
"`BaseVectorStoreDriver.upsert_text` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.upsert` is a drop-in replacement.",
77+
"`BaseVectorStoreDriver.upsert_text` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
7878
DeprecationWarning,
7979
stacklevel=2,
8080
)

griptape/drivers/vector/marqo_vector_store_driver.py

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
from attrs import define, field
66

77
from griptape import utils
8+
from griptape.artifacts import ImageArtifact, TextArtifact
89
from griptape.drivers.vector import BaseVectorStoreDriver
910
from griptape.utils import import_optional_dependency
1011
from griptape.utils.decorators import lazy_property
1112

1213
if TYPE_CHECKING:
1314
import marqo
1415

15-
from griptape.artifacts import ImageArtifact, TextArtifact
16-
1716

1817
@define
1918
class MarqoVectorStoreDriver(BaseVectorStoreDriver):
@@ -37,28 +36,40 @@ class MarqoVectorStoreDriver(BaseVectorStoreDriver):
3736
def client(self) -> marqo.Client:
3837
return import_optional_dependency("marqo").Client(self.url, api_key=self.api_key)
3938

40-
def upsert_text(
39+
def upsert(
4140
self,
42-
string: str,
41+
value: str | TextArtifact | ImageArtifact,
4342
*,
44-
vector_id: Optional[str] = None,
4543
namespace: Optional[str] = None,
4644
meta: Optional[dict] = None,
45+
vector_id: Optional[str] = None,
4746
**kwargs: Any,
4847
) -> str:
4948
"""Upsert a text document into the Marqo index.
5049
5150
Args:
52-
string: The string to be indexed.
53-
vector_id: The ID for the vector. If None, Marqo will generate an ID.
51+
value: The value to be indexed.
5452
namespace: An optional namespace for the document.
5553
meta: An optional dictionary of metadata for the document.
54+
vector_id: The ID for the vector. If None, Marqo will generate an ID.
5655
kwargs: Additional keyword arguments to pass to the Marqo client.
5756
5857
Returns:
5958
str: The ID of the document that was added.
6059
"""
61-
doc = {"_id": vector_id, "Description": string} # Description will be treated as tensor field
60+
if isinstance(value, TextArtifact):
61+
artifact_json = value.to_json()
62+
vector_id = utils.str_to_hash(value.value) if vector_id is None else vector_id
63+
64+
doc = {
65+
"_id": vector_id,
66+
"Description": value.value,
67+
"artifact": str(artifact_json),
68+
}
69+
elif isinstance(value, ImageArtifact):
70+
raise NotImplementedError("`MarqoVectorStoreDriver` does not support upserting Image Artifacts.")
71+
else:
72+
doc = {"_id": vector_id, "Description": value}
6273

6374
# Non-tensor fields
6475
if meta:
@@ -72,43 +83,6 @@ def upsert_text(
7283
else:
7384
raise ValueError(f"Failed to upsert text: {response}")
7485

75-
def upsert_text_artifact(
76-
self,
77-
artifact: TextArtifact,
78-
*,
79-
namespace: Optional[str] = None,
80-
meta: Optional[dict] = None,
81-
vector_id: Optional[str] = None,
82-
**kwargs: Any,
83-
) -> str:
84-
"""Upsert a text artifact into the Marqo index.
85-
86-
Args:
87-
artifact: The text artifact to be indexed.
88-
namespace: An optional namespace for the artifact.
89-
meta: An optional dictionary of metadata for the artifact.
90-
vector_id: An optional explicit vector_id.
91-
kwargs: Additional keyword arguments to pass to the Marqo client.
92-
93-
Returns:
94-
str: The ID of the artifact that was added.
95-
"""
96-
artifact_json = artifact.to_json()
97-
vector_id = utils.str_to_hash(artifact.value) if vector_id is None else vector_id
98-
99-
doc = {
100-
"_id": vector_id,
101-
"Description": artifact.value, # Description will be treated as tensor field
102-
"artifact": str(artifact_json),
103-
"namespace": namespace,
104-
}
105-
106-
response = self.client.index(self.index).add_documents([doc], tensor_fields=["Description", "artifact"])
107-
if isinstance(response, dict) and "items" in response and response["items"]:
108-
return response["items"][0]["_id"]
109-
else:
110-
raise ValueError(f"Failed to upsert text: {response}")
111-
11286
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
11387
"""Load a document entry from the Marqo index.
11488

0 commit comments

Comments
 (0)