import os
import chainlit as cl
from getpass import getpass
from dotenv import load_dotenv
from huggingface_hub import login
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import Cohere, CTransformers, HuggingFaceHub, LlamaCpp, Ollama
from langchain_community.vectorstores import FAISS
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler

load_dotenv()
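# COHERE_API_KEY is read from the local .env file loaded above.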
COHERE_API_KEY = os.getenv('COHERE_API_KEY')


# HUGGINGFACEHUB_API_TOKEN = getpass()
# os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN
# load_dotenv()

# HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
# print(HUGGINGFACE_TOKEN)
# login(token = HUGGINGFACE_TOKEN)


# Alternative: use sentence-transformers directly instead of the LangChain wrapper.
# embeddings_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Embeddings must come from the same model that built the FAISS index below.
embeddings_model = HuggingFaceEmbeddings(
    model_name="mixedbread-ai/mxbai-embed-large-v1",
    model_kwargs={'device': 'cpu'},
)

# Load the local FAISS index and expose it as a retriever.
# allow_dangerous_deserialization is required because load_local unpickles
# index metadata; only set it for indexes you built yourself.
db = FAISS.load_local("mxbai_faiss_index_v2", embeddings_model, allow_dangerous_deserialization=True)
retriever = db.as_retriever()

# Use FlashRank as the reranking engine.
compressor = FlashrankRerank()

# Pass the reranker as base_compressor and the FAISS retriever as
# base_retriever to ContextualCompressionRetriever.
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)
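# At query time the base retriever fetches candidate chunks, then FlashRank
# re-scores them against the question and keeps only the top-ranked passages.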

# Stream tokens to stdout (used by the local llama.cpp experiment below).
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])


#* Round 2
# llm = HuggingFaceHub(
#     huggingfacehub_api_token=HUGGINGFACE_TOKEN, 
#     repo_id=model_id, 
#     model_kwargs={
#         "temperature": 0.5
#         }
#     )

#* Round 3
# llm = CTransformers(model=model_id)
# llm = CTransformers(model='IlyaGusev/saiga_llama3_8b_gguf', model_file='model-q4_K.gguf', model_type="llama")

# llm = CTransformers(model='../../data_test/Meta-Llama-3-8B.Q4_K_M.gguf', model_type='llama')

#* Round 4  
# n_gpu_layers = 15
# n_batch = 128
# llm = LlamaCpp(
#     model_path="../../data_test/Meta-Llama-3-8B.Q4_K_M.gguf",
#     # n_ctx = 1024,
#     n_gpu_layers=n_gpu_layers,
#     n_batch=n_batch,
#     f16_kv=True,
#     callback_manager=callback_manager,
#     verbose=True,
# )

# llm = Ollama(model="llama3", temperature=0.2)

# Current backend: hosted Cohere model, low temperature to stay close to the context.
llm = Cohere(temperature=0.2, cohere_api_key=COHERE_API_KEY)
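# Any of the commented-out backends above can be swapped in here; the chain
# below only needs a LangChain LLM instance.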

@cl.on_chat_start
async def on_chat_start():

    message_history = ChatMessageHistory()

    # Buffer the whole conversation so the chain can resolve follow-up questions.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    # "stuff" chain type: all retrieved (reranked) chunks go into a single prompt.
    chain = ConversationalRetrievalChain.from_llm(
        llm,
        chain_type="stuff",
        retriever=compression_retriever,
        memory=memory,
        return_source_documents=True,
    )

    cl.user_session.set("chain", chain)
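    # Storing the chain in the user session keeps each visitor's
    # conversation history isolated.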

#TODO: Stream response (see the ChainlitStreamHandler sketch at the bottom of this file)
@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")
    # Chainlit's handler surfaces intermediate chain steps in the UI.
    cb = cl.AsyncLangchainCallbackHandler()

    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["answer"]
    source_documents = res["source_documents"]  

    text_elements = [] 

    #* Returning sources. Note: "No sources found" must hang off the outer
    #* check; when source_documents is non-empty, source_names is never empty.
    if source_documents:
        for source_idx, source_doc in enumerate(source_documents):
            source_name = f"source_{source_idx + 1}"
            text_elements.append(
                cl.Text(content=source_doc.page_content, name=source_name)
            )
        source_names = [text_el.name for text_el in text_elements]
        answer += f"\nSources: {', '.join(source_names)}"
    else:
        answer += "\nNo sources found"

    await cl.Message(content=answer, elements=text_elements, author="Brocxi").send()
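
# --- Sketch for the streaming TODO above ---
# Hedged and not wired in: one plausible approach is a callback handler that
# forwards each generated token to an open Chainlit message.
# `ChainlitStreamHandler` is a name invented here; `AsyncCallbackHandler`
# (langchain_core) and `cl.Message.stream_token` are the assumed APIs, and the
# chosen LLM backend must itself support token streaming.
from langchain_core.callbacks import AsyncCallbackHandler


class ChainlitStreamHandler(AsyncCallbackHandler):
    """Forward each new LLM token to an open Chainlit message."""

    def __init__(self, msg: cl.Message):
        self.msg = msg

    async def on_llm_new_token(self, token: str, **kwargs) -> None:
        await self.msg.stream_token(token)

# Usage inside main() would look roughly like:
#   msg = cl.Message(content="", author="Brocxi")
#   res = await chain.acall(message.content, callbacks=[cb, ChainlitStreamHandler(msg)])
#   await msg.update()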