Spaces:
Runtime error
Runtime error
from langchain.vectorstores import FAISS | |
import math | |
import os | |
import pickle | |
import uuid | |
from pathlib import Path | |
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple | |
import numpy as np | |
from langchain.docstore.base import AddableMixin, Docstore | |
from langchain.docstore.document import Document | |
from langchain.docstore.in_memory import InMemoryDocstore | |
from langchain.embeddings.base import Embeddings | |
from langchain.vectorstores.base import VectorStore | |
from langchain.vectorstores.utils import maximal_marginal_relevance | |
class MyFAISS(FAISS):
    """FAISS vector store whose MMR search accepts a word-based metadata filter.

    Both public methods keep the signatures of the methods they override.
    When ``filter`` is supplied, a candidate document is kept if, for at
    least one filter key, any word of that key's filter value occurs in the
    document's metadata value for the same key.
    """

    @staticmethod
    def _mmr_log(message: str, *args: Any) -> None:
        """Emit a DEBUG-level diagnostic message for the MMR search."""
        # Imported locally so this class does not depend on a module-level
        # ``import logging`` being present in the file.
        import logging

        logging.getLogger(__name__).debug(message, *args)

    @staticmethod
    def _mmr_filter_words(value: Any) -> List[str]:
        """Return the words of one filter value that are searched for.

        Strings are split on whitespace. Lists, tuples and sets contribute
        each element (as a string). Any other value is used as a single word.
        """
        if isinstance(value, str):
            return value.split()
        if isinstance(value, (list, tuple, set, frozenset)):
            return [str(item) for item in value]
        return [str(value)]

    @classmethod
    def _mmr_metadata_matches(
        cls, metadata: Dict[str, Any], criteria: Dict[str, Any]
    ) -> bool:
        """Tell whether ``metadata`` satisfies at least one filter entry.

        An entry matches when any of its words is contained in the metadata
        value stored under the same key (substring test for strings,
        membership test for containers). Missing or ``None`` metadata values
        never match. Criteria with no keys, or whose values contain no
        words, match nothing.
        """
        for key, value in criteria.items():
            field = metadata.get(key, "")
            if field is None:
                continue
            # Scalars such as ints do not support ``in``; compare against
            # their textual form instead of raising TypeError.
            if not isinstance(field, (str, list, tuple, set, frozenset, dict)):
                field = str(field)
            if any(word in field for word in cls._mmr_filter_words(value)):
                return True
        return False

    def _mmr_fetch_document(self, index: int) -> Document:
        """Resolve a FAISS row index to its stored Document.

        Raises:
            ValueError: If the docstore does not return a Document for the
                id mapped to ``index``.
        """
        _id = self.index_to_docstore_id[index]
        doc = self.docstore.search(_id)
        if not isinstance(doc, Document):
            raise ValueError(f"Could not find document for id {_id}, got {doc}")
        return doc

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                pass to MMR algorithm. Twice this many are fetched when a
                filter is given.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Optional mapping of metadata key to filter value. A
                document is kept when any word of any filter value is
                contained in its metadata value for that key. A mapping
                with no keys or only blank values keeps no documents.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            List of Documents selected by maximal marginal relevance.

        Raises:
            ValueError: If a candidate index has no Document in the docstore.
        """
        query_vector = np.array([embedding], dtype=np.float32)
        # Over-fetch when filtering so that candidates remain after the
        # filter has discarded some of them.
        _, indices = self.index.search(
            query_vector,
            fetch_k if filter is None else fetch_k * 2,
        )

        # -1 happens when not enough docs are returned. Dropping those slots
        # once keeps MMR positions aligned with ``candidates`` below.
        candidates = [int(i) for i in indices[0] if i != -1]

        # Documents already loaded while filtering, reused when building the
        # result so each one is fetched from the docstore only once.
        loaded: Dict[int, Document] = {}
        if filter is not None:
            self._mmr_log("filter: %s", filter)
            kept = []
            for index in candidates:
                doc = self._mmr_fetch_document(index)
                self._mmr_log("metadata: %s", doc.metadata)
                if self._mmr_metadata_matches(doc.metadata, filter):
                    loaded[index] = doc
                    kept.append(index)
            candidates = kept

        vectors = [self.index.reconstruct(index) for index in candidates]
        mmr_selected = maximal_marginal_relevance(
            query_vector,
            vectors,
            k=k,
            lambda_mult=lambda_mult,
        )

        docs = []
        for position in mmr_selected:
            index = candidates[position]
            if index in loaded:
                docs.append(loaded[index])
            else:
                docs.append(self._mmr_fetch_document(index))
        return docs

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND
        diversity among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering (if
                needed) to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Optional metadata filter, forwarded unchanged to
                ``max_marginal_relevance_search_by_vector``.
            **kwargs: Forwarded unchanged to
                ``max_marginal_relevance_search_by_vector``.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        self._mmr_log("MMR search")
        embedding = self.embedding_function(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            **kwargs,
        )