import streamlit as st
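
# Optional: caching the heavy setup with st.cache_resource keeps the GPTQ model and
# the Chroma index in memory across Streamlit reruns instead of rebuilding them on
# every interaction (available since Streamlit 1.18).
@st.cache_resource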
def load_resources():
    import torch
    from auto_gptq import AutoGPTQForCausalLM
    from langchain import HuggingFacePipeline, PromptTemplate
    from langchain.chains import RetrievalQA
    from langchain.document_loaders import PyPDFDirectoryLoader
    from langchain.embeddings import HuggingFaceBgeEmbeddings
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.vectorstores import Chroma
    from pdf2image import convert_from_path
    from transformers import AutoTokenizer, TextStreamer, pipeline

    DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Load every PDF from the "pdfs" directory and split it into overlapping chunks.
    loader = PyPDFDirectoryLoader("pdfs")
    docs = loader.load()

    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en", model_kwargs={"device": DEVICE}
    )
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    texts = text_splitter.split_documents(docs)

    # Build (and persist) the Chroma vector store used as the retriever.
    db = Chroma.from_documents(texts, embeddings, persist_directory="db")
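
    # Optional: if the "db" directory already exists from an earlier run, the persisted
    # index can be reloaded instead of re-embedding every PDF, e.g.:
    #   db = Chroma(persist_directory="db", embedding_function=embeddings)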

    # Load the 4-bit GPTQ-quantized Llama-2-13B chat model and its tokenizer.
    model_name_or_path = "TheBloke/Llama-2-13B-chat-GPTQ"
    # model_basename = "gptq_model-4bit-128g"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        revision="main",
        # model_basename=model_basename,
        use_safetensors=True,
        trust_remote_code=True,
        inject_fused_attention=False,
        device=DEVICE,
        quantize_config=None,
    )

    # Stream generated tokens and wrap the Hugging Face pipeline for LangChain.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    text_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=0,
        top_p=0.95,
        repetition_penalty=1.15,
        streamer=streamer,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0})

    SYSTEM_PROMPT = (
        "Use the following pieces of context to answer the question at the end. "
        "If you don't know the answer, just say that you don't know, "
        "don't try to make up an answer."
    )

    def generate_prompt(prompt: str, system_prompt: str = SYSTEM_PROMPT) -> str:
        # Wrap the user prompt in the Llama-2 chat template.
        return f"""
[INST] <<SYS>>
{system_prompt}
<</SYS>>

{prompt} [/INST]
""".strip()

    template = generate_prompt(
        """
{context}

Question: {question}
""",
        system_prompt=SYSTEM_PROMPT,
    )
    # The retriever fills {context}; the user supplies {question}.
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])

    # "stuff" chain: the k retrieved chunks are stuffed into the prompt as {context}.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
        verbose=True,
    )
    return qa_chain


st.title("Please ask your question on Lithuanian rules for foreigners.")

qa_chain = load_resources()

question = st.text_input("Enter your question:")

if question:
    # Perform question answering: RetrievalQA fetches the context from the
    # vector store itself, so only the question is passed to the chain.
    result = qa_chain(question)

    # Display the answer
    st.header("Answer:")
    st.write(result["result"])
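
# To try the app locally (assuming this file is saved as app.py and the packages
# imported above are installed): streamlit run app.py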