EnverLee commited on
Commit
8ccd1df
·
verified ·
1 Parent(s): e1facd9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -69
app.py CHANGED
@@ -1,32 +1,55 @@
1
  import gradio as gr
2
  from datasets import load_dataset
3
-
4
  import os
5
- import spaces
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
7
  import torch
8
  from threading import Thread
9
  from sentence_transformers import SentenceTransformer
10
- from datasets import load_dataset
11
- import time
 
 
 
 
 
12
 
13
- token = os.environ["HF_TOKEN"]
14
  ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
15
 
16
- dataset = load_dataset("jihye-moon/LawQA-Ko")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
 
 
 
18
  data = dataset["train"]
19
- data = data.add_faiss_index("question", "answer") # column name that has the embeddings of the dataset
20
 
 
 
 
21
 
 
22
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
23
-
24
- # use quantization to lower GPU usage
25
  bnb_config = BitsAndBytesConfig(
26
  load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
27
  )
28
-
29
- tokenizer = AutoTokenizer.from_pretrained(model_id,token=token)
30
  model = AutoModelForCausalLM.from_pretrained(
31
  model_id,
32
  torch_dtype=torch.bfloat16,
@@ -34,66 +57,63 @@ model = AutoModelForCausalLM.from_pretrained(
34
  quantization_config=bnb_config,
35
  token=token
36
  )
37
- terminators = [
38
- tokenizer.eos_token_id,
39
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
40
- ]
41
 
42
- SYS_PROMPT = """You are an assistant for answering questions.
43
- You are given the extracted parts of a long document and a question. Provide a conversational answer.
44
  If you don't know the answer, just say "I do not know." Don't make up an answer."""
45
 
 
 
 
 
 
46
 
47
-
48
- def search(query: str, k: int = 3 ):
49
- """a function that embeds a new query and returns the most probable results"""
50
- embedded_query = ST.encode(query) # embed new query
51
- scores, retrieved_examples = data.get_nearest_examples( # retrieve results
52
- "embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
53
- k=k # get only top k results
54
  )
55
- return scores, retrieved_examples
56
-
57
- def format_prompt(prompt,retrieved_documents,k):
58
- """using the retrieved documents we will prompt the model to generate our responses"""
59
- PROMPT = f"Question:{prompt}\nContext:"
60
- for idx in range(k) :
61
- PROMPT+= f"{retrieved_documents['text'][idx]}\n"
 
 
 
62
  return PROMPT
63
 
 
 
 
 
 
 
 
 
 
64
 
65
- @spaces.GPU(duration=150)
66
- def talk(prompt,history):
67
- k = 1 # number of retrieved documents
68
- scores , retrieved_documents = search(prompt, k)
69
- formatted_prompt = format_prompt(prompt,retrieved_documents,k)
70
- formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM
71
- messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}]
72
- # tell the model to generate
73
  input_ids = tokenizer.apply_chat_template(
74
- messages,
75
- add_generation_prompt=True,
76
- return_tensors="pt"
77
  ).to(model.device)
78
- outputs = model.generate(
79
- input_ids,
80
- max_new_tokens=1024,
81
- eos_token_id=terminators,
82
- do_sample=True,
83
- temperature=0.6,
84
- top_p=0.9,
85
- )
86
  streamer = TextIteratorStreamer(
87
- tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
88
- )
 
89
  generate_kwargs = dict(
90
- input_ids= input_ids,
91
  streamer=streamer,
92
  max_new_tokens=1024,
93
  do_sample=True,
94
  top_p=0.95,
95
  temperature=0.75,
96
- eos_token_id=terminators,
97
  )
98
  t = Thread(target=model.generate, kwargs=generate_kwargs)
99
  t.start()
@@ -101,25 +121,16 @@ def talk(prompt,history):
101
  outputs = []
102
  for text in streamer:
103
  outputs.append(text)
104
- print(outputs)
105
  yield "".join(outputs)
106
 
107
-
108
- TITLE = "# RAG"
109
 
110
  DESCRIPTION = """
111
- A rag pipeline with a chatbot feature
112
-
113
- Resources used to build this project :
114
-
115
- * embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
116
- * dataset : https://huggingface.co/datasets/not-lain/wikipedia
117
- * faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
118
- * chatbot : https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
119
- * Full documentation : https://huggingface.co/blog/not-lain/rag-chatbot-using-llama3
120
  """
121
 
122
-
123
  demo = gr.ChatInterface(
124
  fn=talk,
125
  chatbot=gr.Chatbot(
@@ -131,9 +142,11 @@ demo = gr.ChatInterface(
131
  bubble_full_width=False,
132
  ),
133
  theme="Soft",
134
- examples=[["what's anarchy ? "]],
135
  title=TITLE,
136
  description=DESCRIPTION,
137
-
138
  )
 
 
139
  demo.launch(debug=True)
 
 
1
  import gradio as gr
2
  from datasets import load_dataset
 
3
  import os
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
5
  import torch
6
  from threading import Thread
7
  from sentence_transformers import SentenceTransformer
8
+ import faiss
9
+ import fitz # PyMuPDF
10
+
11
+ # 환경 변수에서 Hugging Face 토큰 가져오기
12
+ token = os.environ.get("HF_TOKEN")
13
+ if not token:
14
+ raise ValueError("Hugging Face token is missing. Please set it in your environment variables.")
15
 
16
+ # 임베딩 모델 로드
17
  ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
18
 
19
+ # PDF에서 텍스트 추출
20
+ def extract_text_from_pdf(pdf_path):
21
+ doc = fitz.open(pdf_path)
22
+ text = ""
23
+ for page in doc:
24
+ text += page.get_text()
25
+ return text
26
+
27
+ # 법률 문서 PDF 경로 지정 및 텍스트 추출
28
+ pdf_path = "./pdfs/law.pdf" # 여기에 실제 PDF 경로를 입력하세요.
29
+ law_text = extract_text_from_pdf(pdf_path)
30
+
31
+ # 법률 문서 텍스트를 문장 단위로 나누고 임베딩
32
+ law_sentences = law_text.split('\n')
33
+ law_embeddings = ST.encode(law_sentences)
34
 
35
+ # FAISS 인덱스 생성 및 임베딩 추가
36
+ index = faiss.IndexFlatL2(law_embeddings.shape[1])
37
+ index.add(law_embeddings)
38
+
39
+ # Hugging Face에서 법률 상담 데이터셋 로드
40
+ dataset = load_dataset("jihye-moon/LawQA-Ko")
41
  data = dataset["train"]
 
42
 
43
+ # 질문 컬럼을 임베딩하여 새로운 컬럼에 추가
44
+ data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
45
+ data.add_faiss_index(column="question_embedding")
46
 
47
+ # LLaMA 모델 설정
48
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 
 
49
  bnb_config = BitsAndBytesConfig(
50
  load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
51
  )
52
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
 
53
  model = AutoModelForCausalLM.from_pretrained(
54
  model_id,
55
  torch_dtype=torch.bfloat16,
 
57
  quantization_config=bnb_config,
58
  token=token
59
  )
 
 
 
 
60
 
61
+ SYS_PROMPT = """You are an assistant for answering legal questions.
62
+ You are given the extracted parts of legal documents and a question. Provide a conversational answer.
63
  If you don't know the answer, just say "I do not know." Don't make up an answer."""
64
 
65
+ # 법률 문서 검색 함수
66
+ def search_law(query, k=5):
67
+ query_embedding = ST.encode([query])
68
+ D, I = index.search(query_embedding, k)
69
+ return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
70
 
71
+ # 법률 상담 데이터 검색 함수
72
+ def search_qa(query, k=3):
73
+ scores, retrieved_examples = data.get_nearest_examples(
74
+ "question_embedding", ST.encode(query), k=k
 
 
 
75
  )
76
+ return [retrieved_examples["answer"][i] for i in range(k)]
77
+
78
+ # 최종 프롬프트 생성
79
+ def format_prompt(prompt, law_docs, qa_docs):
80
+ PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
81
+ for doc in law_docs:
82
+ PROMPT += f"{doc[0]}\n"
83
+ PROMPT += "\nLegal QA:\n"
84
+ for doc in qa_docs:
85
+ PROMPT += f"{doc}\n"
86
  return PROMPT
87
 
88
+ # 챗봇 응답 함수
89
+ def talk(prompt, history):
90
+ law_results = search_law(prompt, k=3)
91
+ qa_results = search_qa(prompt, k=3)
92
+
93
+ retrieved_law_docs = [result[0] for result in law_results]
94
+ formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
95
+ formatted_prompt = formatted_prompt[:2000] # GPU 메모리 부족을 피하기 위해 프롬프트 제한
96
+ messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
97
 
98
+ # 모델에게 생성 지시
 
 
 
 
 
 
 
99
  input_ids = tokenizer.apply_chat_template(
100
+ messages,
101
+ add_generation_prompt=True,
102
+ return_tensors="pt"
103
  ).to(model.device)
104
+
 
 
 
 
 
 
 
105
  streamer = TextIteratorStreamer(
106
+ tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
107
+ )
108
+
109
  generate_kwargs = dict(
110
+ input_ids=input_ids,
111
  streamer=streamer,
112
  max_new_tokens=1024,
113
  do_sample=True,
114
  top_p=0.95,
115
  temperature=0.75,
116
+ eos_token_id=tokenizer.eos_token_id,
117
  )
118
  t = Thread(target=model.generate, kwargs=generate_kwargs)
119
  t.start()
 
121
  outputs = []
122
  for text in streamer:
123
  outputs.append(text)
 
124
  yield "".join(outputs)
125
 
126
+ # Gradio 인터페이스 설정
127
+ TITLE = "Legal RAG Chatbot"
128
 
129
  DESCRIPTION = """
130
+ A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
131
+ This chatbot can search legal documents and previous legal QA pairs to provide answers.
 
 
 
 
 
 
 
132
  """
133
 
 
134
  demo = gr.ChatInterface(
135
  fn=talk,
136
  chatbot=gr.Chatbot(
 
142
  bubble_full_width=False,
143
  ),
144
  theme="Soft",
145
+ examples=[["What are the regulations on data privacy?"]],
146
  title=TITLE,
147
  description=DESCRIPTION,
 
148
  )
149
+
150
+ # Gradio 데모 실행
151
  demo.launch(debug=True)
152
+