import gradio as gr
from datasets import load_dataset
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
import faiss
import fitz  # PyMuPDF
# Get the Hugging Face token from the environment
token = os.environ.get("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token is missing. Please set it in your environment variables.")

# Load the embedding model
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
# Extract text from a PDF
def extract_text_from_pdf(pdf_path):
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    return text

# Path to the legal-document PDF, and text extraction
pdf_path = "./pdfs/law.pdf"  # Replace with the actual PDF path.
law_text = extract_text_from_pdf(pdf_path)
# Split the legal text into sentences and embed them
# (blank lines are dropped so they do not produce meaningless embeddings)
law_sentences = [s for s in law_text.split('\n') if s.strip()]
law_embeddings = ST.encode(law_sentences)

# Create a FAISS index and add the embeddings
index = faiss.IndexFlatL2(law_embeddings.shape[1])
index.add(law_embeddings)
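# Optional sketch (not in the original code): persist the index so the PDF is not
# re-embedded on every restart. faiss.write_index/read_index are standard FAISS I/O
# calls; "law.index" is a hypothetical filename.
# faiss.write_index(index, "law.index")
# index = faiss.read_index("law.index")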
# Load the legal-consultation dataset from Hugging Face
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]

# Embed the question column and add it as a new column
data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
data.add_faiss_index(column="question_embedding")
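# Optional sketch (assumption, not in the original code): the dataset's FAISS index can
# be cached between runs with datasets' built-in helpers; "qa.index" is a hypothetical path.
# data.save_faiss_index("question_embedding", "qa.index")
# data.load_faiss_index("question_embedding", "qa.index")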
# LLaMA model setup (4-bit NF4 quantization so the 8B model fits on a single GPU)
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=token
)
SYS_PROMPT = """You are an assistant for answering legal questions.
You are given the extracted parts of legal documents and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer."""
# Search the legal document for the k sentences nearest to the query
def search_law(query, k=5):
    query_embedding = ST.encode([query])
    D, I = index.search(query_embedding, k)
    return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
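# Usage note: search_law returns k (sentence, L2 distance) pairs, nearest first;
# smaller distances mean closer matches.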
# Retrieve the answers to the k most similar previously asked questions
def search_qa(query, k=3):
    scores, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    return [retrieved_examples["answer"][i] for i in range(k)]
# Build the final prompt from the question and the retrieved context
def format_prompt(prompt, law_docs, qa_docs):
    PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
    for doc in law_docs:
        PROMPT += f"{doc[0]}\n"
    PROMPT += "\nLegal QA:\n"
    for doc in qa_docs:
        PROMPT += f"{doc}\n"
    return PROMPT
# Chatbot response function
def talk(prompt, history):
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)
    retrieved_law_docs = [result[0] for result in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # Truncate the prompt to avoid running out of GPU memory
    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
    # Ask the model to generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    # Llama-3-Instruct ends assistant turns with <|eot_id|>, so stop on it as well as
    # the default EOS token (as in the model card); otherwise generation tends to run on.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.75,
        eos_token_id=terminators,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
# Gradio interface setup
TITLE = "Legal RAG Chatbot"
DESCRIPTION = """
A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers.
"""
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",  # built-in theme names are lowercase
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch the Gradio demo
demo.launch(debug=True)
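# To run locally (assumption: this file is saved as app.py and a valid token is exported):
#   HF_TOKEN=hf_xxx python app.py
# On a Hugging Face Space, set HF_TOKEN as a repository secret instead.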