|
from langchain.vectorstores import FAISS |
|
from typing import Any, Callable, List, Optional, Tuple, Dict |
|
from langchain.docstore.document import Document |
|
from langchain.docstore.base import Docstore |
|
|
|
from langchain.vectorstores.utils import maximal_marginal_relevance |
|
from langchain.embeddings.base import Embeddings |
|
import uuid |
|
from langchain.docstore.in_memory import InMemoryDocstore |
|
|
|
import numpy as np |
|
|
|
def dependable_faiss_import() -> Any:
    """Import and return the ``faiss`` module, or raise a helpful error.

    Returns:
        The imported ``faiss`` module object.

    Raises:
        ValueError: If the faiss package is not installed. The original
            ImportError is chained (``raise ... from err``) so the real
            cause stays visible in tracebacks.
    """
    try:
        import faiss
    except ImportError as err:
        raise ValueError(
            "Could not import faiss python package. "
            "Please install it with `pip install faiss` "
            "or `pip install faiss-cpu` (depending on Python version)."
        ) from err
    return faiss
|
|
|
class FAISSVS(FAISS):
    """FAISS vector store whose MMR searches return ``(Document, score)`` pairs.

    Differs from the base ``FAISS`` class in that the maximal-marginal-relevance
    methods return similarity scores alongside the documents, and construction
    uses an inner-product index (``IndexFlatIP``) rather than L2.
    """

    def __init__(self,
                 embedding_function: Callable[..., Any],
                 index: Any,
                 docstore: Docstore,
                 index_to_docstore_id: Dict[int, str]):
        """Initialize by delegating directly to the ``FAISS`` base class.

        Args:
            embedding_function: Callable mapping a query string to an embedding.
            index: A FAISS index holding the document vectors.
            docstore: Storage backend mapping ids to ``Document`` objects.
            index_to_docstore_id: Map from FAISS row position to docstore id.
        """
        super().__init__(embedding_function, index, docstore, index_to_docstore_id)

    def max_marginal_relevance_search_by_vector(
        self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.

        Returns:
            List of (Document, score) tuples selected by maximal marginal
            relevance.

        Raises:
            ValueError: If a selected id has no ``Document`` in the docstore.
        """
        scores, indices = self.index.search(
            np.array([embedding], dtype=np.float32), fetch_k
        )
        # FAISS pads the result rows with -1 when the index holds fewer than
        # fetch_k vectors. Drop the padding *before* running MMR so that the
        # positions returned by maximal_marginal_relevance line up with both
        # the candidate indices and their scores. (Filtering only the
        # embeddings, as a naive implementation might, would misalign the
        # MMR positions against the raw indices/scores arrays.)
        candidates = [
            (int(i), float(s))
            for i, s in zip(indices[0], scores[0])
            if i != -1
        ]
        if not candidates:
            return []
        candidate_indices = [i for i, _ in candidates]
        candidate_scores = [s for _, s in candidates]
        embeddings = [self.index.reconstruct(i) for i in candidate_indices]
        mmr_selected = maximal_marginal_relevance(
            np.array([embedding], dtype=np.float32), embeddings, k=k
        )
        docs = []
        for pos in mmr_selected:
            _id = self.index_to_docstore_id[candidate_indices[pos]]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            docs.append((doc, candidate_scores[pos]))
        return docs

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.

        Returns:
            List of (Document, score) tuples selected by maximal marginal
            relevance.
        """
        embedding = self.embedding_function(query)
        # Forward **kwargs so options accepted by the by-vector variant are
        # not silently dropped.
        return self.max_marginal_relevance_search_by_vector(
            embedding, k, fetch_k, **kwargs
        )

    @classmethod
    def __from(
        cls,
        texts: List[str],
        embeddings: List[List[float]],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> FAISS:
        """Build a store from pre-computed embeddings and their texts.

        Args:
            texts: Raw document texts, parallel to ``embeddings``.
            embeddings: One embedding vector per text.
            embedding: Embeddings object whose ``embed_query`` is kept for
                future queries.
            metadatas: Optional per-text metadata dicts, parallel to ``texts``.

        Returns:
            A populated vector store backed by an inner-product
            (``IndexFlatIP``) FAISS index.

        Raises:
            ValueError: If ``embeddings`` is empty (the index dimensionality
                cannot be inferred).
        """
        faiss = dependable_faiss_import()
        if not embeddings:
            raise ValueError("Cannot build index from an empty embeddings list.")
        # Inner-product index: with normalized vectors this ranks by cosine
        # similarity.
        index = faiss.IndexFlatIP(len(embeddings[0]))
        index.add(np.array(embeddings, dtype=np.float32))
        documents = []
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            documents.append(Document(page_content=text, metadata=metadata))
        # Random UUIDs decouple docstore keys from FAISS row positions.
        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
        docstore = InMemoryDocstore(
            {index_to_id[i]: doc for i, doc in enumerate(documents)}
        )
        return cls(embedding.embed_query, index, docstore, index_to_id)
|
|
|
|