feat: Add feast rag retriever functionality #5405
base: master
pyproject.toml
@@ -107,6 +107,11 @@ postgres = ["psycopg[binary,pool]==3.2.5"]
 postgres-c = ["psycopg[c,pool]==3.2.5"]
 pytorch = ["torch==2.2.2", "torchvision>=0.17.2"]
 qdrant = ["qdrant-client>=1.12.0"]
+rag = [
+    "transformers>=4.36.0",
+    "sentence-transformers>=2.5.0",
+    "datasets>=3.6.0",
+]
 redis = [
     "redis>=4.2.2,<5",
     "hiredis>=2.0.0,<3",

Review comment (on the rag = [ line): rag can also be included in the nlp group: https://github.com/feast-dev/feast/blob/master/pyproject.toml#L168
@@ -163,7 +168,10 @@ ci = [
     "types-setuptools",
     "types-tabulate",
     "virtualenv<20.24.2",
-    "feast[aws, azure, cassandra, clickhouse, couchbase, delta, docling, duckdb, elasticsearch, faiss, gcp, ge, go, grpcio, hazelcast, hbase, ibis, ikv, k8s, mcp, milvus, mssql, mysql, opentelemetry, spark, trino, postgres, pytorch, qdrant, redis, singlestore, snowflake, sqlite_vec]"
+    "transformers>=4.36.0",
+    "sentence-transformers>=2.5.0",
+    "datasets>=3.6.0",
+    "feast[aws, azure, cassandra, clickhouse, couchbase, delta, docling, duckdb, elasticsearch, faiss, gcp, ge, go, grpcio, hazelcast, hbase, ibis, ikv, k8s, mcp, milvus, mssql, mysql, opentelemetry, spark, trino, postgres, pytorch, qdrant, rag, redis, singlestore, snowflake, sqlite_vec]"
 ]
 nlp = ["feast[docling, milvus, pytorch]"]
 dev = ["feast[ci]"]

Review comment (on the added transformers / sentence-transformers / datasets lines): These can be removed from here; adding rag to the feast[...] extras list below already covers them.
New file (271 lines):

@@ -0,0 +1,271 @@
# Copyright 2019 The Feast Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, Union, Any, Tuple, TYPE_CHECKING
# import subprocess

import numpy as np

# try:
#     from transformers import RagRetriever
# except ImportError:
#     print("Installing transformers...")
#     subprocess.check_call(["pip", "install", "transformers"])
from transformers import RagRetriever

from feast import FeatureStore
from feast.vector_store import VectorStore

Review comment: Since VectorStore is now added to the init module (sdk/python/feast/__init__.py) you can access it directly. Suggested change: from feast import VectorStore. Or, if you are planning to use it this way, let's remove it from the init module.
from feast.torch_wrapper import get_torch

# try:
#     from sentence_transformers import SentenceTransformer
# except ImportError:
#     print("Installing sentence_transformers...")
#     subprocess.check_call(["pip", "install", "sentence-transformers"])
from sentence_transformers import SentenceTransformer
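The commented-out auto-install blocks above suggest these dependencies are meant to be optional. A minimal sketch of the more usual extras-gated import pattern, assuming the feast[rag] extra added in this PR (the error message is illustrative):

```python
try:
    from transformers import RagRetriever
    from sentence_transformers import SentenceTransformer
except ImportError as err:  # fail fast instead of pip-installing at import time
    raise ImportError(
        "FeastRAGRetriever requires the rag extras; "
        "install them with: pip install 'feast[rag]'"
    ) from err
```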
class FeastIndex:
    """Dummy index required by HuggingFace's RagRetriever."""

    def __init__(self, vector_store: VectorStore):
        """Initialize the Feast index.

        Args:
            vector_store: Vector store instance to use for retrieval
        """
        self.vector_store = vector_store

    def get_top_docs(self, query_vectors: np.ndarray, n_docs: int = 5):
        """Get top documents (not implemented).

        This method is required by the RagRetriever interface but is not used
        as we override the retrieve method in FeastRAGRetriever.
        """
        raise NotImplementedError("get_top_docs is not yet implemented.")

Review comment (on lines +53 to +54): If we are supposed to override it and use Feast's capabilities for getting the top docs, we should remove this method.

    def get_doc_dicts(self, doc_ids: List[str]):
        """Get document dictionaries (not implemented).

        This method is required by the RagRetriever interface but is not used
        as we override the retrieve method in FeastRAGRetriever.
        """
        raise NotImplementedError("get_doc_dicts is not yet implemented.")

Review comment: Same as above.
class FeastRAGRetriever(RagRetriever):
    """RAG retriever implementation that uses Feast as a backend."""

    VALID_SEARCH_TYPES = {"text", "vector", "hybrid"}

    def __init__(
        self,
        question_encoder_tokenizer,
        question_encoder,
        generator_tokenizer,
        generator_model,
        feast_repo_path: str,
        vector_store: VectorStore,
        search_type: str,
        config: Dict[str, Any],
        index: FeastIndex,
        format_document: Optional[Callable[[Dict[str, Any]], str]] = None,
        id_field: str = "",
        query_encoder_model: Union[str, SentenceTransformer] = "all-MiniLM-L6-v2",
        **kwargs,
    ):
        """Initialize the Feast RAG retriever.

        Args:
            question_encoder_tokenizer: Tokenizer for encoding questions
            question_encoder: Model for encoding questions
            generator_tokenizer: Tokenizer for the generator model
            generator_model: The generator model
            feast_repo_path: Path to the Feast repository
            vector_store: Vector store instance to use for retrieval
            search_type: Type of search to perform (text, vector, or hybrid)
            config: Configuration for the retriever
            index: Index instance (must be FeastIndex)
            format_document: Optional function to format retrieved documents
            id_field: Field to use as document ID
            query_encoder_model: Model to use for encoding queries
            **kwargs: Additional arguments passed to RagRetriever
        """

Review comment (on lines +74 to +80): Good to add static type annotations for all the parameters that are missing them, since this class would be at the forefront of user / app interaction. This would also help data scientists know the parameter types. For example:
        from sentence_transformers import SentenceTransformer

Review comment: Remove the duplicate imports (SentenceTransformer is already imported at module level).
        if search_type.lower() not in self.VALID_SEARCH_TYPES:
            raise ValueError(
                f"Unsupported search_type {search_type}. "
                f"Must be one of: {self.VALID_SEARCH_TYPES}"
            )
        super().__init__(
            config=config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
            init_retrieval=False,
            **kwargs,
        )
        self.question_encoder = question_encoder
        self.generator_model = generator_model
        self.generator_tokenizer = generator_tokenizer
        self.feast = FeatureStore(repo_path=feast_repo_path)
        self.vector_store = vector_store
        self.search_type = search_type.lower()
        self.format_document = format_document or self._default_format_document
        self.id_field = id_field

        if isinstance(query_encoder_model, str):
            self.query_encoder = SentenceTransformer(query_encoder_model)
        else:
            self.query_encoder = query_encoder_model
    def retrieve(
        self,
        question_hidden_states: np.ndarray,
        n_docs: int,
        **kwargs
    ) -> Tuple[np.ndarray, List[str], List[Dict[str, str]]]:

Review comment (on the retrieve signature): Suggested change; being explicit helps. For example:
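One explicit variant (an inference; the reviewer's actual suggestion was not captured in this view, and promoting query out of **kwargs is an assumption):

```python
def retrieve(
    self,
    question_hidden_states: np.ndarray,
    n_docs: int,
    query: Optional[str] = None,  # assumption: make the kwargs-only query explicit
    **kwargs: Any,
) -> Tuple[np.ndarray, List[str], List[Dict[str, str]]]:
    ...
```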
""" | ||||||
Retrieve relevant documents using Feast as a backend and return results | ||||||
in a format compatible with Hugging Face's RagRetriever. | ||||||
|
||||||
Args: | ||||||
question_hidden_states (np.ndarray): | ||||||
Hidden state representation of the question from the encoder. | ||||||
Expected shape is (1, seq_len, hidden_dim). | ||||||
n_docs (int): | ||||||
Number of top documents to retrieve. | ||||||
query (Optional[str]): | ||||||
Optional raw query string. If not provided and search_type is "text" or "hybrid", | ||||||
it will be decoded from question_hidden_states. | ||||||
**kwargs: | ||||||
- query (Optional[str]): raw text query. If not provided and search_type is | ||||||
"text" or "hybrid", it will be decoded from question_hidden_states. | ||||||
|
||||||
Returns: | ||||||
Tuple containing: | ||||||
- retrieved_doc_embeds (np.ndarray): | ||||||
Embeddings of the retrieved documents with shape (1, n_docs, embed_dim). | ||||||
- doc_ids (List[str]): | ||||||
List of document IDs or passage identifiers. | ||||||
- doc_dicts (List[Dict[str, str]]): | ||||||
List of dictionaries containing document text fields. | ||||||
""" | ||||||
        torch = get_torch()

        # Convert numpy hidden states to torch tensor if needed
        if isinstance(question_hidden_states, np.ndarray):
            question_hidden_states = torch.from_numpy(question_hidden_states)

        # Average pooling across the sequence dimension to get a fixed-size query vector
        query_vector = torch.mean(question_hidden_states, dim=1).squeeze().detach().cpu().numpy()

        query: Optional[str] = kwargs.get("query", None)

Review comment: Suggested change: query: Optional[str] = kwargs.get("query") (an inference; dict.get already defaults to None).
||||||
# If no query string is provided and search is text/hybrid, decode from token ids | ||||||
if query is None and self.search_type in ("text", "hybrid"): | ||||||
query = self.question_encoder_tokenizer.decode( | ||||||
question_hidden_states.argmax(axis=-1), | ||||||
skip_special_tokens=True | ||||||
) | ||||||
|
||||||
# Perform search using the configured search type | ||||||
if self.search_type == "text": | ||||||
results = self.vector_store.query(query_string=query, top_k=n_docs) | ||||||
elif self.search_type == "vector": | ||||||
results = self.vector_store.query(query_vector=query_vector, top_k=n_docs) | ||||||
elif self.search_type == "hybrid": | ||||||
results = self.vector_store.query( | ||||||
query_string=query, | ||||||
query_vector=query_vector, | ||||||
top_k=n_docs | ||||||
) | ||||||
else: | ||||||
raise ValueError(f"Unsupported search type: {self.search_type}") | ||||||
|
||||||
# Extract embeddings, IDs, and document text for each result | ||||||
doc_embeddings = np.array([doc["embedding"] for doc in results]) | ||||||
doc_ids = [str(doc.get(self.id_field, f"id_{i}")) for i, doc in enumerate(results)] | ||||||
doc_dicts = [{"text": doc["text"]} for doc in results] | ||||||
Review comment (on lines +199 to +200): You can choose to loop once for efficiency, for example:
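A sketch of the single-pass version (the reviewer's snippet was not captured in this view; the "embedding"/"text" keys are the ones the PR already uses):

```python
# Build all three collections in one pass over the results.
doc_embeddings_list, doc_ids, doc_dicts = [], [], []
for i, doc in enumerate(results):
    doc_embeddings_list.append(doc["embedding"])
    doc_ids.append(str(doc.get(self.id_field, f"id_{i}")))
    doc_dicts.append({"text": doc["text"]})
doc_embeddings = np.array(doc_embeddings_list)
```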
        # Add batch dimension to embeddings to match expected RAG format: (1, n_docs, embed_dim)
        retrieved_doc_embeds = np.expand_dims(doc_embeddings, axis=0)

        return retrieved_doc_embeds, doc_ids, doc_dicts
    def generate_answer(
        self, query: str, top_k: int = 5, max_new_tokens: int = 100
    ) -> str:
        """Generate an answer for a query using retrieved context.

        Args:
            query: The query to answer
            top_k: Number of documents to retrieve
            max_new_tokens: Maximum number of tokens to generate

        Returns:
            Generated answer string
        """
        # Convert query to hidden states format expected by retrieve
        inputs = self.question_encoder_tokenizer(
            query, return_tensors="pt", padding=True, truncation=True
        )
        question_hidden_states = self.question_encoder(**inputs).last_hidden_state

        # Get documents using retrieve method
        _, _, doc_dicts = self.retrieve(question_hidden_states, n_docs=top_k)

        # Format context from retrieved documents
        contexts = [doc["text"] for doc in doc_dicts]
        context = "\n\n".join(contexts)

        prompt = (
            f"Use the following context to answer the question. Context:\n{context}\n\n"
            f"Question: {query}\nAnswer:"
        )

        self.generator_tokenizer.pad_token = self.generator_tokenizer.eos_token
        inputs = self.generator_tokenizer(
            prompt, return_tensors="pt", padding=True, truncation=True
        )

Review comment: Can we generate inputs from another function for simplicity / modularization? Currently input generation and answering are mixed in this function; let's keep this function for generating the answer only. Optional!
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        output_ids = self.generator_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.generator_tokenizer.pad_token_id,
        )
        return self.generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    def _default_format_document(self, doc: Dict[str, Any]) -> str:
        """Default document formatting function.

        Args:
            doc: Document dictionary to format

        Returns:
            Formatted document string
        """
        lines = []
        for key, value in doc.items():
            # Skip vectors by checking for long float lists
            if (
                isinstance(value, list)
                and len(value) > 10
                and all(isinstance(x, (float, int)) for x in value)
            ):
                continue
            lines.append(f"{key.replace('_', ' ').capitalize()}: {value}")
        return "\n".join(lines)
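To close, a hypothetical end-to-end usage sketch; the model names, repo path, id_field, and the vector store / config objects are all placeholders rather than anything this PR prescribes:

```python
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer

# Placeholder models; any question encoder / causal LM pair with compatible
# tokenizers could stand in here.
q_tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
q_enc = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
g_tok = AutoTokenizer.from_pretrained("gpt2")
g_model = AutoModelForCausalLM.from_pretrained("gpt2")

vector_store = ...  # an already-populated Feast vector store (assumption)
rag_config = ...    # a RagConfig-compatible object expected by RagRetriever (assumption)

retriever = FeastRAGRetriever(
    question_encoder_tokenizer=q_tok,
    question_encoder=q_enc,
    generator_tokenizer=g_tok,
    generator_model=g_model,
    feast_repo_path="feature_repo/",  # placeholder path
    vector_store=vector_store,
    search_type="vector",
    config=rag_config,
    index=FeastIndex(vector_store),
    id_field="passage_id",  # placeholder field name
)

print(retriever.generate_answer("What is Feast?", top_k=3))
```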