import os

import chainlit as cl
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import Cohere
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler

# Imports used only by the commented-out experiments below:
# from getpass import getpass
# from huggingface_hub import login
# from langchain_community.llms import CTransformers, HuggingFaceHub, LlamaCpp, Ollama
# Load environment variables; COHERE_API_KEY must be set in .env for the active LLM.
load_dotenv()
COHERE_API_KEY = os.getenv('COHERE_API_KEY')
# Earlier attempt: Hugging Face login + SentenceTransformer embeddings (kept for reference)
# HUGGINGFACEHUB_API_TOKEN = getpass()
# os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN
# HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
# login(token=HUGGINGFACE_TOKEN)
# embeddings_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
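# Query-time embeddings; this must be the same model that was used to build
# the FAISS index loaded below.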
embeddings_model = HuggingFaceEmbeddings(
    model_name="mixedbread-ai/mxbai-embed-large-v1",
    model_kwargs={'device': 'cpu'},
)
# Load the local FAISS index and expose it as a retriever
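# allow_dangerous_deserialization is required because the index metadata is
# pickled; only enable it for an index you built yourself.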
db = FAISS.load_local("mxbai_faiss_index_v2", embeddings_model, allow_dangerous_deserialization=True)
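# The default retriever does plain similarity search (top-4 chunks unless
# overridden via search_kwargs).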
retriever = db.as_retriever()
# Use Flashrank as the rerank engine: pass the reranker as base_compressor
# and the FAISS retriever as base_retriever to ContextualCompressionRetriever.
compressor = FlashrankRerank()
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)
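# At query time the base retriever fetches candidate chunks and FlashrankRerank
# reorders them, keeping only the highest-scoring ones.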
# Stream generated tokens to stdout (useful when debugging local LLMs)
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
#* Round 2: hosted inference via HuggingFaceHub
# llm = HuggingFaceHub(
#     huggingfacehub_api_token=HUGGINGFACE_TOKEN,
#     repo_id=model_id,
#     model_kwargs={"temperature": 0.5},
# )
#* Round 3: local GGUF models via CTransformers
# llm = CTransformers(model=model_id)
# llm = CTransformers(model='IlyaGusev/saiga_llama3_8b_gguf', model_file='model-q4_K.gguf', model_type="llama")
# llm = CTransformers(model='../../data_test/Meta-Llama-3-8B.Q4_K_M.gguf', model_type='llama')
#* Round 4: llama.cpp with partial GPU offload
# n_gpu_layers = 15
# n_batch = 128
# llm = LlamaCpp(
#     model_path="../../data_test/Meta-Llama-3-8B.Q4_K_M.gguf",
#     # n_ctx=1024,
#     n_gpu_layers=n_gpu_layers,
#     n_batch=n_batch,
#     f16_kv=True,
#     callback_manager=callback_manager,
#     verbose=True,
# )
# llm = Ollama(model="llama3", temperature=0.2)
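# Active LLM: Cohere's hosted API; COHERE_API_KEY was loaded from .env above.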
llm = Cohere(temperature=0.2, cohere_api_key=COHERE_API_KEY)
@cl.on_chat_start
async def on_chat_start():
    # Per-session conversation memory so follow-up questions keep context
    message_history = ChatMessageHistory()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )
    # "stuff" chain: all reranked chunks are stuffed into a single prompt
    chain = ConversationalRetrievalChain.from_llm(
        llm,
        chain_type="stuff",
        retriever=compression_retriever,
        memory=memory,
        return_source_documents=True,
    )
    cl.user_session.set("chain", chain)
#TODO: Stream response
@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler()
    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["answer"]
    source_documents = res["source_documents"]
    text_elements = []
    #* Returning sources: attach each retrieved chunk as a named text element
    if source_documents:
        for source_idx, source_doc in enumerate(source_documents):
            source_name = f"source_{source_idx + 1}"
            text_elements.append(
                cl.Text(content=source_doc.page_content, name=source_name)
            )
        source_names = [text_el.name for text_el in text_elements]
        if source_names:
            answer += f"\nSources: {', '.join(source_names)}"
        else:
            answer += "\nNo sources found"
    await cl.Message(content=answer, elements=text_elements, author="Brocxi").send()