|
|
|
"""main.py |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1DPJ6tc2bCveBZyHSX02h_fbBS0fzzMrC |
|
""" |
|
|
|
|
|
import os

import torch
import streamlit as st

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationTokenBufferMemory
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    NotebookLoader,
    PyPDFLoader,
    PythonLoader,
    TextLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
|
|
|
def load_model(
    model_path="vilsonrodrigues/falcon-7b-instruct-sharded",
):
    """Load a 4-bit quantized causal LM and wrap it as a LangChain LLM."""

    # model_path may be a local directory or, as with the default value, a
    # Hugging Face Hub repo id that is downloaded on first use, so no local
    # existence check is performed here.

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    model_4bit = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        quantization_config=quantization_config,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Use a distinct name so the imported `pipeline` factory is not shadowed.
    text_generation_pipeline = pipeline(
        "text-generation",
        model=model_4bit,
        tokenizer=tokenizer,
        use_cache=True,
        device_map="auto",
        max_length=700,
        do_sample=True,
        top_k=5,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
    return llm
|
|
|
def create_vector_database():
    """
    Creates a vector database using document loaders and embeddings.

    This function loads data from PDF, markdown, text and other supported files
    in the 'data/' directory, splits the loaded documents into chunks, transforms
    them into embeddings using HuggingFace, and finally persists the embeddings
    into a Chroma vector database.
    """
    # Persist the Chroma database in a "db" sub-directory next to this file.
    ABS_PATH = os.path.dirname(os.path.abspath(__file__))
    DB_DIR = os.path.join(ABS_PATH, "db")

    pdf_loader = DirectoryLoader("data/", glob="**/*.pdf", loader_cls=PyPDFLoader)
    markdown_loader = DirectoryLoader("data/", glob="**/*.md", loader_cls=UnstructuredMarkdownLoader)
    text_loader = DirectoryLoader("data/", glob="**/*.txt", loader_cls=TextLoader)
    csv_loader = DirectoryLoader("data/", glob="**/*.csv", loader_cls=CSVLoader)
    python_loader = DirectoryLoader("data/", glob="**/*.py", loader_cls=PythonLoader)
    epub_loader = DirectoryLoader("data/", glob="**/*.epub", loader_cls=UnstructuredEPubLoader)
    html_loader = DirectoryLoader("data/", glob="**/*.html", loader_cls=UnstructuredHTMLLoader)
    ppt_loader = DirectoryLoader("data/", glob="**/*.ppt", loader_cls=UnstructuredPowerPointLoader)
    pptx_loader = DirectoryLoader("data/", glob="**/*.pptx", loader_cls=UnstructuredPowerPointLoader)
    doc_loader = DirectoryLoader("data/", glob="**/*.doc", loader_cls=UnstructuredWordDocumentLoader)
    docx_loader = DirectoryLoader("data/", glob="**/*.docx", loader_cls=UnstructuredWordDocumentLoader)
    odt_loader = DirectoryLoader("data/", glob="**/*.odt", loader_cls=UnstructuredODTLoader)
    notebook_loader = DirectoryLoader("data/", glob="**/*.ipynb", loader_cls=NotebookLoader)

    all_loaders = [
        pdf_loader,
        markdown_loader,
        text_loader,
        csv_loader,
        python_loader,
        epub_loader,
        html_loader,
        ppt_loader,
        pptx_loader,
        doc_loader,
        docx_loader,
        odt_loader,
        notebook_loader,
    ]

    # Load every document found by each loader into a single list.
    loaded_documents = []
    for loader in all_loaders:
        loaded_documents.extend(loader.load())

    # Split the documents into overlapping chunks that fit the model's context.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    chunked_documents = text_splitter.split_documents(loaded_documents)

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    # Embed the chunks and persist them to the on-disk Chroma store.
    db = Chroma.from_documents(
        documents=chunked_documents,
        embedding=embeddings,
        persist_directory=DB_DIR,
    )
    db.persist()
    return db
|
|
|
def set_custom_prompt_condense():
    _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
    CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
    return CONDENSE_QUESTION_PROMPT
|
|
|
def set_custom_prompt():
    """
    Prompt template for retrieval for each vectorstore
    """

    prompt_template = """<Instructions>
Important:
Answer with the facts listed in the list of sources below. If there isn't enough information below, say you don't know.
If asking a clarifying question to the user would help, ask the question.
ALWAYS return a "SOURCES" part in your answer, except for small-talk conversations.

Question: {question}

{context}

Question: {question}
Helpful Answer:

---------------------------
---------------------------
Sources:
"""

    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    return prompt
|
|
|
def create_chain(llm, prompt, CONDENSE_QUESTION_PROMPT, db):
    """
    Creates a Retrieval Question-Answering (QA) chain using a given language model, prompt, and database.

    This function initializes a ConversationalRetrievalChain object with a specific chain type and
    configurations, and returns this chain. The retriever is set up to return the top 3 results (k=3).

    Args:
        llm (any): The language model to be used in the chain.
        prompt (PromptTemplate): The prompt used to combine the retrieved documents into an answer.
        CONDENSE_QUESTION_PROMPT (PromptTemplate): The prompt used to condense the chat history and
            follow-up question into a standalone question.
        db (any): The vector database to be used as the retriever.

    Returns:
        ConversationalRetrievalChain: The initialized conversational chain.
    """
    # output_key="answer" tells the memory which chain output to store, which is
    # required because return_source_documents=True adds a second output key.
    memory = ConversationTokenBufferMemory(
        llm=llm,
        memory_key="chat_history",
        return_messages=True,
        input_key="question",
        output_key="answer",
        max_token_limit=1000,
    )
    chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": prompt},
        condense_question_prompt=CONDENSE_QUESTION_PROMPT,
        memory=memory,
    )
    return chain
|
|
|
def create_retrieval_qa_bot():
    # The documents directory must exist before the pipeline is built; the
    # Chroma database itself is (re)created by create_vector_database().
    data_dir = "data/"
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"No directory found at {data_dir}")

    try:
        llm = load_model()
    except Exception as e:
        raise Exception(f"Failed to load model: {str(e)}")

    try:
        prompt = set_custom_prompt()
    except Exception as e:
        raise Exception(f"Failed to get prompt: {str(e)}")

    try:
        CONDENSE_QUESTION_PROMPT = set_custom_prompt_condense()
    except Exception as e:
        raise Exception(f"Failed to get condense prompt: {str(e)}")

    try:
        db = create_vector_database()
    except Exception as e:
        raise Exception(f"Failed to get database: {str(e)}")

    try:
        qa = create_chain(
            llm=llm, prompt=prompt, CONDENSE_QUESTION_PROMPT=CONDENSE_QUESTION_PROMPT, db=db
        )
    except Exception as e:
        raise Exception(f"Failed to create retrieval QA chain: {str(e)}")

    return qa
|
|
|
def retrieve_bot_answer(query):
    """
    Retrieves the answer to a given query using a QA bot.

    This function creates an instance of a QA bot, passes the query to it,
    and returns the bot's response.

    Args:
        query (str): The question to be answered by the QA bot.

    Returns:
        dict: The QA bot's response, typically a dictionary with response details.
    """
    qa_bot_instance = create_retrieval_qa_bot()
    # ConversationalRetrievalChain expects its input under the "question" key.
    bot_response = qa_bot_instance({"question": query})
    return bot_response
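

# A minimal sketch of calling the bot directly, outside Streamlit. It assumes
# documents have already been placed in the "data/" directory; the example
# question is purely illustrative.
#
#     result = retrieve_bot_answer("What does the uploaded report conclude?")
#     print(result["answer"])
#     for doc in result["source_documents"]:
#         print(doc.metadata)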
|
|
|
|
|
|
def main():
    st.title("Docuverse")

    uploaded_files = st.file_uploader(
        "Upload your documents",
        type=["pdf", "md", "txt", "csv", "py", "epub", "html", "ppt", "pptx", "doc", "docx", "odt", "ipynb"],
        accept_multiple_files=True,
    )

    if uploaded_files:
        # Save the uploads into the "data/" directory so that
        # create_vector_database() can index them.
        os.makedirs("data", exist_ok=True)
        for uploaded_file in uploaded_files:
            with open(os.path.join("data", uploaded_file.name), "wb") as f:
                f.write(uploaded_file.getbuffer())
            st.write(f"Uploaded: {uploaded_file.name}")

        st.write("Chat with the Document:")
        query = st.text_input("Ask a question:")

        if st.button("Get Answer"):
            if query:
                try:
                    # retrieve_bot_answer() builds the whole pipeline
                    # (model, prompts, vector database) and runs the query.
                    response = retrieve_bot_answer(query)

                    st.write("Bot Response:")
                    st.write(response["answer"])
                except Exception as e:
                    st.error(f"An error occurred: {str(e)}")
            else:
                st.warning("Please enter a question.")


if __name__ == "__main__":
    main()