# Bookworm / app.py.bkp
from pathlib import Path
from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex,
    StorageContext,
    Settings,
    set_global_tokenizer,
    load_index_from_storage,
)
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.llms.llama_cpp.llama_utils import (
    messages_to_prompt,
    completion_to_prompt,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from transformers import AutoTokenizer, BitsAndBytesConfig
from llama_index.llms.huggingface import HuggingFaceLLM
import torch
import logging
import sys
import streamlit as st
import os
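# 4-bit NF4 quantization config so the 13B chat model fits in limited GPU memory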
default_bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
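# Use the Llama-2 chat tokenizer for llama_index's global token counting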
set_global_tokenizer(
    AutoTokenizer.from_pretrained("NousResearch/Llama-2-13b-chat-hf").encode
)
def getDocs(doc_path="./data/"):
    documents = SimpleDirectoryReader(doc_path).load_data()
    return documents
def getVectorIndex(docs):
    Settings.chunk_size = 512
    persist_dir = "./storage/book_data"
    if os.path.isdir(persist_dir):
        # Reuse the persisted index instead of re-embedding the documents
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        cur_index = load_index_from_storage(storage_context, embed_model=getEmbedModel())
    else:
        # Build the index from scratch and persist it for later runs
        storage_context = StorageContext.from_defaults()
        cur_index = VectorStoreIndex.from_documents(
            docs, embed_model=getEmbedModel(), storage_context=storage_context
        )
        storage_context.persist(persist_dir=persist_dir)
    return cur_index
def getLLM():
    model_path = "NousResearch/Llama-2-13b-chat-hf"
    # model_path = "NousResearch/Llama-2-7b-chat-hf"
    llm = HuggingFaceLLM(
        context_window=3900,
        max_new_tokens=256,
        # generate_kwargs={"temperature": 0.25, "do_sample": False},
        tokenizer_name=model_path,
        model_name=model_path,
        device_map=0,
        tokenizer_kwargs={"max_length": 2048},
        # fp16 weights plus 4-bit quantization to reduce GPU memory usage
        model_kwargs={
            "torch_dtype": torch.float16,
            "quantization_config": default_bnb_config,
        },
    )
    return llm
def getQueryEngine(index):
    # Despite the name, this wraps the index in a chat engine so follow-ups keep context
    query_engine = index.as_chat_engine(llm=getLLM())
    return query_engine
def getEmbedModel():
    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
    return embed_model
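# Streamlit UI: page setup, chat history, and the chat loop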
st.set_page_config(
    page_title="Project BookWorm: Your own Librarian!",
    page_icon="πŸ¦™",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)
st.title("Project BookWorm: Your own Librarian!")
st.info("Use this app to get recommendations for books that your kids will love!", icon="πŸ“ƒ")
if "messages" not in st.session_state.keys(): # Initialize the chat messages history
st.session_state.messages = [
{"role": "assistant", "content": "Ask me a question about children's books or movies!"}
]
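# Build the vector index once; st.cache_resource keeps it across Streamlit reruns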
@st.cache_resource(show_spinner=False)
def load_data():
    index = getVectorIndex(getDocs())
    return index

index = load_data()
if "chat_engine" not in st.session_state.keys(): # Initialize the chat engine
st.session_state.chat_engine = index.as_chat_engine(llm=getLLM(),chat_mode="condense_question", verbose=True)
if prompt := st.chat_input("Your question"):  # Prompt for user input and save to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

for message in st.session_state.messages:  # Display the prior chat messages
    with st.chat_message(message["role"]):
        st.write(message["content"])
# If the last message is not from the assistant, generate a new response
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = st.session_state.chat_engine.chat(prompt)
            st.write(response.response)
            message = {"role": "assistant", "content": response.response}
            st.session_state.messages.append(message)  # Add response to message history
# if __name__ == "__main__":
#     index = getVectorIndex(getDocs())
#     query_engine = getQueryEngine(index)
#     while True:
#         your_request = input("Your comment: ")
#         response = query_engine.chat(your_request)
#         print(response)
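# To launch the app (assuming this file is restored to app.py):
#   streamlit run app.py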