import logging
import os
import sys

import streamlit as st
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig

from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
    set_global_tokenizer,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.llms.llama_cpp import LlamaCPP
from llama_index.llms.llama_cpp.llama_utils import (
    messages_to_prompt,
    completion_to_prompt,
)

# Quantize the chat model to 4-bit NF4 (with double quantization) so the 13B
# weights fit on a single GPU, while computation runs in bfloat16.
default_bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

# Use the chat model's own tokenizer for token counting so LlamaIndex's
# context-window accounting matches the LLM configured in getLLM().
set_global_tokenizer(
    AutoTokenizer.from_pretrained("NousResearch/Llama-2-13b-chat-hf").encode
)

# Load every document found under doc_path.
def getDocs(doc_path="./data/"):
    documents = SimpleDirectoryReader(doc_path).load_data()
    return documents

def getVectorIndex(docs):
    Settings.chunk_size = 512

    if os.path.isdir("./storage/book_data"):
        # A persisted index already exists: reload it instead of re-embedding.
        storage_context = StorageContext.from_defaults(persist_dir="./storage/book_data")
        cur_index = load_index_from_storage(
            storage_context, embed_model=getEmbedModel()
        )
    else:
        # First run: embed the documents, then persist the index to disk.
        storage_context = StorageContext.from_defaults()
        cur_index = VectorStoreIndex.from_documents(
            docs, embed_model=getEmbedModel(), storage_context=storage_context
        )
        storage_context.persist(persist_dir="./storage/book_data")
    return cur_index
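
# Note: the persisted index is never invalidated automatically. If the files in
# ./data/ change, delete ./storage/book_data so the index is rebuilt on the next run.
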
def getLLM():
    model_path = "NousResearch/Llama-2-13b-chat-hf"
    # model_path = "NousResearch/Llama-2-7b-chat-hf"

    llm = HuggingFaceLLM(
        context_window=3900,
        max_new_tokens=256,
        # generate_kwargs={"temperature": 0.25, "do_sample": False},
        tokenizer_name=model_path,
        model_name=model_path,
        device_map=0,
        tokenizer_kwargs={"max_length": 2048},
        # 4-bit quantization (default_bnb_config) plus float16 keeps GPU memory usage down
        model_kwargs={
            "torch_dtype": torch.float16,
            "quantization_config": default_bnb_config,
        },
    )
    return llm
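
# Optional sketch: the LlamaCPP / llama_utils imports above are not used by the
# Streamlit app itself. If you would rather serve a local GGUF model through
# llama.cpp than through transformers, a drop-in alternative to getLLM() could
# look like the commented-out function below. The model_url is an assumption
# (any Llama-2-chat GGUF file, or a local model_path, should work); adjust
# n_gpu_layers to your hardware.
# def getLlamaCppLLM():
#     llm = LlamaCPP(
#         model_url="https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q4_K_M.gguf",
#         model_path=None,          # or point this at an already-downloaded .gguf file
#         temperature=0.25,
#         max_new_tokens=256,
#         context_window=3900,
#         model_kwargs={"n_gpu_layers": -1},       # -1 offloads all layers to the GPU
#         messages_to_prompt=messages_to_prompt,   # Llama-2 chat prompt formatting
#         completion_to_prompt=completion_to_prompt,
#         verbose=True,
#     )
#     return llm
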
def getQueryEngine(index):
    # Chat engine used by the commented-out CLI loop at the bottom of this file;
    # the Streamlit app builds its own chat engine further down.
    query_engine = index.as_chat_engine(llm=getLLM())
    return query_engine

def getEmbedModel():
    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")
    return embed_model
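
# Quick sanity check (hypothetical usage, not called by the app): embed a single
# string and inspect the vector size; bge-small-en-v1.5 returns 384-dimensional embeddings.
# emb = getEmbedModel().get_text_embedding("a picture book about dinosaurs")
# print(len(emb))  # -> 384
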
st.set_page_config(
    page_title="Project BookWorm: Your own Librarian!",
    page_icon="🦙",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items=None,
)
st.title("Project BookWorm: Your own Librarian!")
st.info("Use this app to get recommendations for books that your kids will love!", icon="📚")

if "messages" not in st.session_state.keys():  # Initialize the chat messages history
    st.session_state.messages = [
        {"role": "assistant", "content": "Ask me a question about children's books or movies!"}
    ]

@st.cache_resource(show_spinner=False)
def load_data():
    # Build (or reload) the vector index once and cache it across Streamlit reruns.
    index = getVectorIndex(getDocs())
    return index


index = load_data()

if "chat_engine" not in st.session_state.keys(): # Initialize the chat engine |
|
st.session_state.chat_engine = index.as_chat_engine(llm=getLLM(),chat_mode="condense_question", verbose=True) |
|
|
|
if prompt := st.chat_input("Your question"):  # Prompt for user input and save to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

for message in st.session_state.messages:  # Display the prior chat messages
    with st.chat_message(message["role"]):
        st.write(message["content"])

# If last message is not from assistant, generate a new response
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = st.session_state.chat_engine.chat(prompt)
            st.write(response.response)
            message = {"role": "assistant", "content": response.response}
            st.session_state.messages.append(message)  # Add response to message history

# if __name__ == "__main__":
#     index = getVectorIndex(getDocs())
#     query_engine = getQueryEngine(index)
#     while True:
#         your_request = input("Your comment: ")
#         response = query_engine.chat(your_request)
#         print(response)
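
# To use the app, run this file with Streamlit (the filename below is an
# assumption; substitute the actual script name):
#     streamlit run app.py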