import os
import pickle
from typing import Any, Optional, Dict, List

from huggingface_hub import InferenceClient
from pydantic import BaseModel, Field
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import EnsembleRetriever, BM25Retriever, ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.text_splitter import CharacterTextSplitter

import config
from memory import memory3
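# Remote models served through the Hugging Face Inference API. The HF token is
# read from the `apiToken` environment variable (e.g. set as a Space secret).
# `kwargs` drives the main chat model; `reform_kwargs` uses a shorter,
# lower-temperature budget because it only rewrites the user question.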
chat_model_name = "HuggingFaceH4/zephyr-7b-alpha"
reform_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
hf_token = os.getenv("apiToken")
kwargs = {"max_new_tokens": 500, "temperature": 0.9, "top_p": 0.95, "repetition_penalty": 1.0, "do_sample": True}
reform_kwargs = {"max_new_tokens": 50, "temperature": 0.5, "top_p": 0.9, "repetition_penalty": 1.0, "do_sample": True}
class KwArgsModel(BaseModel):
    """Mixin holding arbitrary text-generation keyword arguments."""
    kwargs: Dict[str, Any] = Field(default_factory=dict)


class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain LLM wrapper around a Hugging Face InferenceClient endpoint."""
    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            hf_token=hf_token,
            kwargs=kwargs,
            inference_client=inference_client
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream tokens from the endpoint and join them into a single string.
        response_gen = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        response = ''.join(response_gen)
        return response

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        return {"model_name": self.model_name}
chat_llm = CustomInferenceClient(model_name=chat_model_name, hf_token=hf_token, kwargs=kwargs)
reform_llm = CustomInferenceClient(model_name=reform_model_name, hf_token=hf_token, kwargs=reform_kwargs)
prompt_template = config.DEFAULT_CHAT_TEMPLATE
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question", "chat_history"]
)
chain_type_kwargs = {"prompt": PROMPT}
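# Dense retrieval: embed queries with the default HuggingFaceEmbeddings model
# and search the prebuilt FAISS index ("cima_faiss_index") for the 5 most
# similar chunks.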
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("cima_faiss_index", embeddings)
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
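# Contextual compression: retrieved documents are re-split into ~300-character
# pieces, near-duplicate pieces are dropped, and only pieces whose embedding
# similarity to the query exceeds 0.5 are kept.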
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.5)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)
compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=retriever)
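# Sparse retrieval: a BM25 keyword retriever built from the pickled raw texts,
# wrapped in the same compression pipeline as the dense retriever.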
with open("docs_data.pkl", "rb") as file: | |
docs = pickle.load(file) | |
bm25_retriever = BM25Retriever.from_texts(docs) | |
bm25_retriever.k = 2 | |
bm25_compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=bm25_retriever) | |
ensemble_retriever = EnsembleRetriever(retrievers=[compression_retriever, bm25_compression_retriever], weights=[0.5, 0.5]) | |
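# The ensemble merges dense (FAISS) and sparse (BM25) results with equal weight.
# Illustrative only:
#   ensemble_retriever.get_relevant_documents("example query")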
# Prompt for the condense-question step: each follow-up message is rewritten
# into a standalone query before retrieval. The [/INST] tag is the closing
# marker of the Mistral-Instruct format used by reform_llm.
custom_template = """Given the following conversation and a follow-up message, rephrase the follow-up user message to be a standalone message. If the follow-up message is not a question, keep it unchanged.[/INST]
Chat History:
{chat_history}
Follow-up user message: {question}
Rewritten user message:"""
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(custom_template)
chat_chain = ConversationalRetrievalChain.from_llm(
    llm=chat_llm,
    chain_type="stuff",
    retriever=ensemble_retriever,
    combine_docs_chain_kwargs=chain_type_kwargs,
    return_source_documents=True,
    get_chat_history=lambda h: h,
    condense_question_prompt=CUSTOM_QUESTION_PROMPT,
    memory=memory3,
    condense_question_llm=reform_llm,
)
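
# Minimal usage sketch (illustrative; assumes `memory3` is configured with
# output_key="answer" so source documents can be returned alongside the answer,
# and that "cima_faiss_index" and "docs_data.pkl" are present). Not run when
# this module is imported by the app.
if __name__ == "__main__":
    result = chat_chain({"question": "What does the CIMA qualification cover?"})
    print(result["answer"])
    for doc in result["source_documents"]:
        print(doc.page_content[:100])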