# app.py — last commit 14b5f79 "Update app.py" by EnverLee (verified).
import gradio as gr
from datasets import load_dataset
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
import faiss
import fitz # PyMuPDF
# Hugging Face access token, read from the environment. It is passed to the
# tokenizer and model downloads further below.
token = os.getenv("HF_TOKEN")
if token in (None, ""):
    raise ValueError("Hugging Face token is missing. Please set it in your environment variables.")
# Load the sentence-embedding model. It is shared by both retrieval paths
# below: the law-sentence FAISS index and the QA question embeddings.
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# PDF์—์„œ ํ…์ŠคํŠธ ์ถ”์ถœ
def extract_text_from_pdf(pdf_path):
doc = fitz.open(pdf_path)
text = ""
for page in doc:
text += page.get_text()
return text
# Legal-document PDF: choose the path and extract its text.
# Set LAW_PDF_PATH to use a different file; defaults to "laws.pdf".
pdf_path = os.environ.get("LAW_PDF_PATH", "laws.pdf")
law_text = extract_text_from_pdf(pdf_path)
# Split the text into line-level "sentences" for retrieval. Blank and
# whitespace-only lines are dropped: they carry no content but would still be
# embedded and could be returned by search_law as empty context.
law_sentences = [
    stripped
    for stripped in (line.strip() for line in law_text.split("\n"))
    if stripped
]
if not law_sentences:
    # Fail with a clear message instead of an obscure shape error below.
    raise ValueError(f"No text could be extracted from {pdf_path!r}.")
law_embeddings = ST.encode(law_sentences)
# Exact (brute-force) L2 FAISS index over the sentence embeddings; row i of
# the index corresponds to law_sentences[i].
index = faiss.IndexFlatL2(law_embeddings.shape[1])
index.add(law_embeddings)
# Legal-consultation QA dataset from the Hugging Face Hub; only the "train"
# split is used.
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]


def _embed_question_batch(batch):
    """Embed one batch of questions for ``Dataset.map(..., batched=True)``."""
    question_vectors = ST.encode(batch["question"])
    return {"question_embedding": question_vectors}


# Add the question embeddings as a new column, then build a FAISS index on
# that column so search_qa can find the nearest stored questions.
data = data.map(_embed_question_batch, batched=True)
data.add_faiss_index(column="question_embedding")
# Llama 3 8B Instruct model and tokenizer setup.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# bitsandbytes 4-bit NF4 weights with double quantization; computation runs
# in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=token,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# System prompt placed ahead of the user message in talk().
SYS_PROMPT = """You are an assistant for answering legal questions.
You are given the extracted parts of legal documents and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer."""
# Legal-document search.
def search_law(query, k=5):
    """Return up to *k* law sentences nearest to *query*.

    Args:
        query: Free-text question to embed and look up.
        k: Maximum number of results to return.

    Returns:
        A list of ``(sentence, distance)`` tuples in the order FAISS returns
        them. ``distance`` is the value reported by the L2 index (FAISS
        reports squared L2 for flat L2 indexes). The list may be shorter
        than *k*.
    """
    query_embedding = ST.encode([query])
    distances, ids = index.search(query_embedding, k)
    # FAISS pads missing results with id -1 (e.g. when k exceeds the number
    # of indexed vectors). law_sentences[-1] would silently return the last
    # sentence, so those slots are skipped.
    return [
        (law_sentences[i], dist)
        for i, dist in zip(ids[0], distances[0])
        if i >= 0
    ]
# Legal-consultation QA search.
def search_qa(query, k=3):
    """Return the answers of up to *k* stored QA pairs nearest to *query*.

    Args:
        query: Free-text question; it is embedded and compared against the
            indexed ``question_embedding`` column.
        k: Maximum number of answers to return.

    Returns:
        A list of answer values from the dataset's ``answer`` column. It may
        hold fewer than *k* entries.
    """
    _scores, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    # Fewer than k rows can come back, so take what was returned instead of
    # indexing range(k), which would raise IndexError.
    return list(retrieved_examples["answer"])
# Build the final user prompt from the question and retrieved context.
def format_prompt(prompt, law_docs, qa_docs):
    """Assemble the RAG user prompt.

    Args:
        prompt: The user's question.
        law_docs: Retrieved law passages. Each item is either a plain string
            or a sequence whose first element is the passage text, such as
            the ``(sentence, distance)`` tuples produced by ``search_law``.
        qa_docs: Retrieved QA answers; each is rendered with ``str()``.

    Returns:
        The prompt text: the question, a "Legal Context" section with one
        passage per line, and a "Legal QA" section with one answer per line.
    """

    def _law_text(doc):
        # A bare string must be used whole. doc[0] on a string would keep
        # only its first character.
        return doc if isinstance(doc, str) else doc[0]

    parts = [f"Question: {prompt}\n\nLegal Context:\n"]
    parts.extend(f"{_law_text(doc)}\n" for doc in law_docs)
    parts.append("\nLegal QA:\n")
    parts.extend(f"{doc}\n" for doc in qa_docs)
    return "".join(parts)
# Chatbot response handler.
def talk(prompt, history):
    """Answer *prompt* with retrieval-augmented generation.

    Args:
        prompt: The user's latest message.
        history: Prior chat turns supplied by ``gr.ChatInterface``. They are
            not used; each question is answered from fresh retrieval.

    Returns:
        The model's reply text, or an ``"Error: ..."`` string if prompt
        encoding or generation fails.
    """
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)
    # Pass search_law's (sentence, distance) tuples through unchanged:
    # format_prompt reads the passage text from doc[0]. Pre-extracting the
    # strings would make doc[0] the first *character* of each sentence.
    formatted_prompt = format_prompt(prompt, law_results, qa_results)
    # Cap the prompt length (in characters) to avoid running out of GPU memory.
    formatted_prompt = formatted_prompt[:2000]
    messages = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": formatted_prompt},
    ]
    try:
        # A plain tokenizer call cannot encode a list of message dicts. The
        # chat template renders them in the model's turn format and appends
        # the assistant header so generation starts with the answer.
        input_ids = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        # Llama 3 Instruct ends an assistant turn with <|eot_id|>, which is
        # not the tokenizer's default eos token, so stop on either.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        outputs = model.generate(
            input_ids=input_ids,
            # Single unpadded sequence: every position is a real token.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.95,
            temperature=0.75,
            eos_token_id=terminators,
            pad_token_id=tokenizer.eos_token_id,
        )
        # generate() returns prompt + completion. Decode only the new tokens
        # so the reply does not echo the system prompt and retrieved context.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing
        # the chat handler.
        response = f"Error: {str(e)}"
    return response
# Gradio interface configuration.
TITLE = "Legal RAG Chatbot"
DESCRIPTION = """
A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers.
"""

# Chat display widget: bubble layout with share/copy buttons and like
# reactions enabled.
chatbot_component = gr.Chatbot(
    show_label=True,
    show_share_button=True,
    show_copy_button=True,
    likeable=True,
    layout="bubble",
    bubble_full_width=False,
)

# Every user message is routed to talk().
demo = gr.ChatInterface(
    fn=talk,
    chatbot=chatbot_component,
    theme="Soft",
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch the Gradio app in debug mode.
demo.launch(debug=True)