# cima-free-chat / query_data.py
import os
import pickle
from typing import Any, Dict, List, Optional

from huggingface_hub import InferenceClient
from langchain.chains import ConversationalRetrievalChain
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.retrievers import BM25Retriever, ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import DocumentCompressorPipeline, EmbeddingsFilter
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from pydantic import BaseModel, Field

import config
from memory import memory3
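
# Hugging Face Inference API model IDs: Zephyr generates the chat answers,
# Mistral-Instruct rewrites follow-up messages into standalone questions.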
chat_model_name = "HuggingFaceH4/zephyr-7b-alpha"
reform_model_name = "mistralai/Mistral-7B-Instruct-v0.1"
hf_token = os.getenv("apiToken")
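
# Sampling presets: longer, higher-temperature generations for chat; short,
# more deterministic generations for question reformulation.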
kwargs = {"max_new_tokens": 500, "temperature": 0.9, "top_p": 0.95, "repetition_penalty": 1.0, "do_sample": True}
reform_kwargs = {"max_new_tokens": 50, "temperature": 0.5, "top_p": 0.9, "repetition_penalty": 1.0, "do_sample": True}

class KwArgsModel(BaseModel):
    """Pydantic mixin that carries arbitrary text-generation kwargs."""
    kwargs: Dict[str, Any] = Field(default_factory=dict)

class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain-compatible LLM wrapper around a huggingface_hub InferenceClient."""

    model_name: str
    inference_client: InferenceClient

    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
        # hf_token is consumed here to build the client; it is not a pydantic
        # field, so it is not forwarded to super().__init__().
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            kwargs=kwargs or {},  # guard against None: the field expects a dict
            inference_client=inference_client,
        )

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # Stream tokens from the Inference API and join them into one response string.
        response_gen = self.inference_client.text_generation(
            prompt, **self.kwargs, stream=True, return_full_text=False
        )
        return "".join(response_gen)

    @property
    def _llm_type(self) -> str:
        return "custom"

    @property
    def _identifying_params(self) -> dict:
        return {"model_name": self.model_name}
chat_llm = CustomInferenceClient(model_name=chat_model_name, hf_token=hf_token, kwargs=kwargs)
reform_llm = CustomInferenceClient(model_name=reform_model_name, hf_token=hf_token, kwargs=reform_kwargs)
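
# The main chat prompt lives in config; it must expose the {context},
# {question}, and {chat_history} placeholders used below.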
prompt_template = config.DEFAULT_CHAT_TEMPLATE
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question", "chat_history"]
)
chain_type_kwargs = {"prompt": PROMPT}

# Dense retrieval: HuggingFace sentence embeddings over a prebuilt FAISS index.
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.load_local("cima_faiss_index", embeddings)
retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

# Contextual compression: split retrieved docs into sentence-level chunks, drop
# near-duplicates, then keep only chunks similar enough to the query (>= 0.5).
splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0, separator=". ")
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.5)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[splitter, redundant_filter, relevant_filter]
)
compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=retriever)
with open("docs_data.pkl", "rb") as file:
docs = pickle.load(file)
bm25_retriever = BM25Retriever.from_texts(docs)
bm25_retriever.k = 2
bm25_compression_retriever = ContextualCompressionRetriever(base_compressor=pipeline_compressor, base_retriever=bm25_retriever)
ensemble_retriever = EnsembleRetriever(retrievers=[compression_retriever, bm25_compression_retriever], weights=[0.5, 0.5])
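
# Standalone-question prompt for the condense step. The [/INST] tag matches the
# instruction format expected by Mistral-7B-Instruct, used as the condense LLM below.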
custom_template = """Given the following conversation and a follow-up message, rephrase the follow-up user message to be a standalone message. If the follow-up message is not a question, keep it unchanged[/INST].
Chat History:
{chat_history}
Follow-up user message: {question}
Rewritten user message:"""
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(custom_template)

# Conversational RAG chain: reform_llm condenses the follow-up into a standalone
# question, ensemble_retriever fetches context, and chat_llm writes the answer,
# with memory3 supplying the chat history.
chat_chain = ConversationalRetrievalChain.from_llm(
    llm=chat_llm,
    chain_type="stuff",
    retriever=ensemble_retriever,
    combine_docs_chain_kwargs=chain_type_kwargs,
    return_source_documents=True,
    get_chat_history=lambda h: h,
    condense_question_prompt=CUSTOM_QUESTION_PROMPT,
    memory=memory3,
    condense_question_llm=reform_llm,
)
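
# Minimal usage sketch (an assumption, not part of the original app): requires
# the "apiToken" env var to be set, cima_faiss_index/ and docs_data.pkl to
# exist, and memory3 to be configured with memory_key="chat_history" and
# output_key="answer". The example question is hypothetical.
if __name__ == "__main__":
    result = chat_chain({"question": "What does this service offer?"})
    print(result["answer"])
    for doc in result["source_documents"]:
        print("source:", doc.page_content[:80])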