import gradio as gr
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
from sentence_transformers import SentenceTransformer
import faiss
import fitz  # PyMuPDF
import os

# Load the sentence-embedding model
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
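# Note (assumption): mxbai-embed-large-v1 is primarily an English embedding
# model; a multilingual or Korean-specific model may retrieve Korean legal
# text more accurately.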

# PDF์—์„œ ํ…์ŠคํŠธ ์ถ”์ถœ
def extract_text_from_pdf(pdf_path):
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    return text
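# PyMuPDF's page.get_text() keeps the PDF's line breaks, which the newline
# split in process_default_pdf() below relies on to form retrieval units.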

# Path to the bundled default PDF
default_pdf_path = "laws.pdf"

# FAISS ์ธ๋ฑ์Šค ์ดˆ๊ธฐํ™”
index = None
law_sentences = []

# Process the bundled default PDF
def process_default_pdf():
    global index, law_sentences

    # PDF์—์„œ ํ…์ŠคํŠธ ์ถ”์ถœ
    law_text = extract_text_from_pdf(default_pdf_path)

    # ๋ฌธ์žฅ์„ ๋‚˜๋ˆ„๊ณ  ์ž„๋ฒ ๋”ฉ ์ƒ์„ฑ
    law_sentences = law_text.split('\n')
    law_embeddings = ST.encode(law_sentences)

    # FAISS ์ธ๋ฑ์Šค ์ƒ์„ฑ ๋ฐ ์ž„๋ฒ ๋”ฉ ์ถ”๊ฐ€
    index = faiss.IndexFlatL2(law_embeddings.shape[1])
    index.add(law_embeddings)

# ์ฒ˜์Œ์— ๊ธฐ๋ณธ PDF ํŒŒ์ผ ์ฒ˜๋ฆฌ
process_default_pdf()
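# IndexFlatL2 performs exact (brute-force) L2 search, which is adequate at
# this scale; SentenceTransformer.encode returns the float32 arrays FAISS
# expects.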

# Search the indexed law text
def search_law(query, k=5):
    query_embedding = ST.encode([query])
    D, I = index.search(query_embedding, k)
    return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
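# Illustrative usage (hypothetical query; shows the return shape only):
#   hits = search_law("์ž„๋Œ€์ฐจ ๊ณ„์•ฝ ํ•ด์ง€")
#   # -> [(matching line from laws.pdf, L2 distance), ...] with k pairs,
#   #    ordered nearest first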

# Load the Korean legal-consultation QA dataset from Hugging Face
dataset = load_dataset("jihye-moon/LawQA-Ko")
data = dataset["train"]

# Embed the question column and store it in a new column
data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
data.add_faiss_index(column="question_embedding")
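# add_faiss_index builds an in-memory FAISS index over the embedding column;
# get_nearest_examples() in search_qa() below queries it.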

# Gemma model setup
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
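# device_map="auto" requires the accelerate package and places the model on
# the available GPU(s), falling back to CPU; bfloat16 halves memory use
# relative to float32.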

SYS_PROMPT = """You are an assistant for answering legal questions.
You are given the extracted parts of legal documents and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer.
You must answer in Korean.
You're a LAWEYE legal advisor bot. Your job is to provide Korean legal assistance by asking Korean speakers questions, then offering advice or guidance based on the information and law provisions provided. Make sure you ask only one question at a time.
...
"""

# ๋ฒ•๋ฅ  ์ƒ๋‹ด ๋ฐ์ดํ„ฐ ๊ฒ€์ƒ‰ ํ•จ์ˆ˜
def search_qa(query, k=3):
    scores, retrieved_examples = data.get_nearest_examples(
        "question_embedding", ST.encode(query), k=k
    )
    return [retrieved_examples["answer"][i] for i in range(k)]

# Build the final prompt from the retrieved context
def format_prompt(prompt, law_docs, qa_docs):
    PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
    for doc in law_docs:
        PROMPT += f"{doc}\n"  # each doc is already a retrieved sentence string
    PROMPT += "\nLegal QA:\n"
    for doc in qa_docs:
        PROMPT += f"{doc}\n"
    return PROMPT
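# Sketch of the assembled prompt (contents vary per query):
#   Question: <user question>
#
#   Legal Context:
#   <retrieved law sentence> x3
#
#   Legal QA:
#   <retrieved answer> x3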

# ์ฑ—๋ด‡ ์‘๋‹ต ํ•จ์ˆ˜
def talk(prompt, history):
    law_results = search_law(prompt, k=3)
    qa_results = search_qa(prompt, k=3)

    retrieved_law_docs = [result[0] for result in law_results]
    formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
    formatted_prompt = formatted_prompt[:2000]  # Truncate to avoid exhausting GPU memory

    # Gemma's chat template does not support a separate system role, so fold
    # the system prompt into the single user turn.
    messages = [{"role": "user", "content": SYS_PROMPT + "\n\n" + formatted_prompt}]

    # Tokenize via the model's chat template and append the generation prompt
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.75,
        eos_token_id=tokenizer.eos_token_id,
    )

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
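# talk() is a generator, so gr.ChatInterface streams each yielded partial
# string to the UI, producing token-by-token output.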

# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
TITLE = "Legal RAG Chatbot"
DESCRIPTION = """A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
This chatbot can search legal documents and previous legal QA pairs to provide answers."""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",
    examples=[["What are the regulations on data privacy?"]],
    title=TITLE,
    description=DESCRIPTION,
)

# Gradio ๋ฐ๋ชจ ์‹คํ–‰
demo.launch(debug=True)