File size: 4,907 Bytes
07ffad3
42df98c
07ffad3
31d7c4a
07ffad3
 
eaca477
8ccd1df
 
 
 
 
 
 
07ffad3
8ccd1df
31d7c4a
 
8ccd1df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31d7c4a
8ccd1df
 
 
 
 
 
31d7c4a
 
8ccd1df
 
 
31d7c4a
8ccd1df
31d7c4a
 
 
 
8ccd1df
1b7e4b0
31d7c4a
 
 
 
 
1b7e4b0
07ffad3
8ccd1df
 
31d7c4a
1b7e4b0
8ccd1df
 
 
 
 
1b7e4b0
8ccd1df
 
 
 
eaca477
8ccd1df
 
 
 
 
 
 
 
 
 
31d7c4a
07ffad3
8ccd1df
 
 
 
 
 
 
 
 
07ffad3
8ccd1df
31d7c4a
8ccd1df
 
 
31d7c4a
8ccd1df
07ffad3
8ccd1df
 
 
07ffad3
8ccd1df
07ffad3
 
 
 
 
8ccd1df
07ffad3
 
 
31d7c4a
 
 
 
 
07ffad3
8ccd1df
 
07ffad3
 
8ccd1df
 
07ffad3
 
ef4a283
 
 
 
 
 
 
 
 
 
 
8ccd1df
ef4a283
 
 
8ccd1df
 
18b530b
8ccd1df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import gradio as gr
from datasets import load_dataset
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
import faiss
import fitz  # PyMuPDF

# Read the Hugging Face token from an environment variable
token = os.environ.get("HF_TOKEN")
if not token:
    # Stop at startup: the token is passed to the tokenizer and model loaders.
    raise ValueError("Hugging Face token is missing. Please set it in your environment variables.")

# Load the embedding model
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Extract text from a PDF
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of a PDF, concatenated in page order.

    Args:
        pdf_path: Path of the PDF file, passed straight to ``fitz.open``.

    Returns:
        One string made of each page's ``get_text()`` output, in page order.
    """
    # Open the document as a context manager so it is closed even when
    # extraction fails part-way through.
    with fitz.open(pdf_path) as doc:
        # Join once instead of growing a string with ``+=`` for every page.
        return "".join(page.get_text() for page in doc)

# Set the legal document PDF path and extract its text
pdf_path = "./pdfs/law.pdf"  # Set this to the actual PDF path.
law_text = extract_text_from_pdf(pdf_path)

# Split the legal text into line-level passages. Blank and whitespace-only
# lines are dropped: they carry no content, yet would otherwise be embedded
# and could be returned as search hits.
law_sentences = [line.strip() for line in law_text.split('\n') if line.strip()]
if not law_sentences:
    # Without any passage there is no embedding dimension to build the index from.
    raise ValueError(f"No text could be extracted from {pdf_path!r}.")
law_embeddings = ST.encode(law_sentences)

# Build the FAISS L2 index and add the embeddings; row i of the index
# corresponds to law_sentences[i].
index = faiss.IndexFlatL2(law_embeddings.shape[1])
index.add(law_embeddings)

# Load the legal-consultation QA dataset from Hugging Face
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]

# Embed the "question" column and store the result in a new
# "question_embedding" column
data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
# Index that column with FAISS so it can be queried by nearest-neighbour search.
data.add_faiss_index(column="question_embedding")

# LLaMA model configuration
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)
# The Hugging Face token is passed to both the tokenizer and the model loader.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=token
)

# System message sent with every request; it tells the model to answer from
# the supplied legal extracts and to admit when it does not know.
SYS_PROMPT = """You are an assistant for answering legal questions.
You are given the extracted parts of legal documents and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer."""

# Legal document search function
def search_law(query, k=5):
    """Return up to ``k`` indexed legal passages nearest to ``query``.

    Args:
        query: Question text to embed and look up in the FAISS index.
        k: Maximum number of passages to return.

    Returns:
        A list of ``(passage, distance)`` tuples in the order FAISS returned
        them. The list is shorter than ``k`` when the index holds fewer than
        ``k`` passages.
    """
    query_embedding = ST.encode([query])
    distances, indices = index.search(query_embedding, k)
    # FAISS pads the result with index -1 when it finds fewer than k
    # neighbours; without this filter -1 would silently pick the last passage.
    return [
        (law_sentences[i], dist)
        for i, dist in zip(indices[0], distances[0])
        if i >= 0
    ]

# Legal consultation (QA) data search function
def search_qa(query, k=3):
    """Return the answers of up to ``k`` QA examples nearest to ``query``.

    Args:
        query: Question text to embed and compare against the indexed
            ``question_embedding`` column.
        k: Maximum number of answers to return.

    Returns:
        A list of answer values taken from the retrieved examples. It holds
        fewer than ``k`` items when the lookup returns fewer examples.
    """
    _, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    # Slice instead of indexing range(k): the lookup may return fewer than k
    # examples, and indexing past the end would raise IndexError.
    return list(retrieved_examples["answer"][:k])

# Build the final prompt
def format_prompt(prompt, law_docs, qa_docs):
    """Assemble the user message from the question and the retrieved context.

    Args:
        prompt: The user's question.
        law_docs: Retrieved legal passages. Each entry is either the passage
            text itself or a sequence whose first element is the passage
            text, such as a ``(passage, distance)`` tuple.
        qa_docs: Retrieved QA answers, one entry per answer.

    Returns:
        A string with the question, a "Legal Context" section listing one
        passage per line, and a "Legal QA" section listing one answer per
        line.
    """
    # Only unwrap non-string entries: indexing a plain string with [0] would
    # keep just its first character.
    law_lines = [doc if isinstance(doc, str) else doc[0] for doc in law_docs]

    sections = [f"Question: {prompt}\n\nLegal Context:\n"]
    sections.extend(f"{text}\n" for text in law_lines)
    sections.append("\nLegal QA:\n")
    sections.extend(f"{doc}\n" for doc in qa_docs)
    return "".join(sections)

# Chatbot response function
def talk(prompt, history):
    """Stream an answer to ``prompt`` using retrieved legal context.

    Args:
        prompt: The user's latest message.
        history: Earlier conversation turns supplied by the chat interface;
            not used by this function.

    Yields:
        The answer text generated so far, growing with each streamed chunk.
    """
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)

    # Pass the (passage, distance) pairs straight through; format_prompt picks
    # the passage text out of each pair.
    formatted_prompt = format_prompt(prompt, law_results, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # Cap the prompt length to avoid running out of GPU memory
    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]

    # Ask the model to generate
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    # Llama 3 Instruct ends an assistant turn with <|eot_id|>, which may
    # differ from tokenizer.eos_token_id; stop on either.
    stop_ids = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id not in stop_ids:
        stop_ids.append(eot_id)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.75,
        eos_token_id=stop_ids,
    )
    # Run generate() in a worker thread so this generator can drain the
    # streamer while tokens are still being produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Gradio interface setup
TITLE = "Legal RAG Chatbot"

DESCRIPTION = """
A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers.
"""

# Chat UI backed by talk(), which yields the answer text as it is generated.
demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

# Launch the Gradio demo
demo.launch(debug=True)