|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from langchain_community.vectorstores import FAISS |
|
from langchain.schema import Document |
|
from typing import List |
|
|
|
|
|
class Retrieval:
    """Embed document chunks with a Hugging Face model and search them with FAISS.

    Usage: construct the object, call :meth:`create_vector_store` once with the
    chunks, then call :meth:`search` as often as needed.
    """

    def __init__(self, model_name: str, max_model_tokens: int = 384):
        """Load the embedding model.

        Args:
            model_name: Model id or path forwarded to ``HuggingFaceEmbeddings``.
            max_model_tokens: Upper bound on tokens per input text. Longer
                texts are truncated by the model's tokenizer. The model's own
                limit is never raised.
        """
        self.model_name = model_name
        self.max_model_tokens = max_model_tokens
        # Populated by create_vector_store(); search() checks for None.
        self.chunks = []
        self.vectorstore = None

        # NOTE: ``encode_kwargs`` are forwarded to SentenceTransformer.encode(),
        # which does not accept tokenizer options such as ``max_length`` /
        # ``truncation``. Older versions raise TypeError and newer ones ignore
        # them. Truncation length is governed by the model's ``max_seq_length``
        # attribute, so the limit is applied there instead.
        self.embeddings = HuggingFaceEmbeddings(model_name=model_name)
        self._cap_sequence_length(max_model_tokens)

    def _cap_sequence_length(self, max_tokens: int) -> None:
        """Lower the underlying model's ``max_seq_length`` to at most *max_tokens*."""
        client = getattr(self.embeddings, "client", None)
        if client is None or not hasattr(client, "max_seq_length"):
            # Backend does not expose a sequence-length knob; leave it as-is.
            return
        current = client.max_seq_length
        # Only ever lower the limit. Going past the model's configured length
        # could index positions the model was not built for.
        client.max_seq_length = (
            max_tokens if current is None else min(current, max_tokens)
        )

    def create_vector_store(self, chunks: "List[Document]") -> None:
        """Embed *chunks* and build a FAISS index over them.

        Any previously built index on this object is replaced.

        Args:
            chunks: Documents to index. They are kept on ``self.chunks``.

        Raises:
            ValueError: If *chunks* is empty. FAISS cannot build an index
                from zero documents.
        """
        if not chunks:
            raise ValueError("create_vector_store() requires at least one chunk")
        self.chunks = chunks
        self.vectorstore = FAISS.from_documents(self.chunks, self.embeddings)

    def search(self, query: str, k: int = 10) -> "List[Document]":
        """Return at most *k* indexed chunks most similar to *query*.

        Args:
            query: Free-text query to embed and search with.
            k: Maximum number of documents to return.

        Raises:
            RuntimeError: If called before :meth:`create_vector_store`.
        """
        store = getattr(self, "vectorstore", None)
        if store is None:
            raise RuntimeError("call create_vector_store() before search()")
        return store.similarity_search(query, k=k)
|
|