Update app.py
app.py CHANGED
@@ -45,7 +45,7 @@ data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batc
 data.add_faiss_index(column="question_embedding")
 
 # LLaMA model setup
-model_id = "
+model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
 )
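For context on how this config is consumed: the next hunk's header shows model = AutoModelForCausalLM.from_pretrained( at line 60, so bnb_config is presumably passed as quantization_config. A minimal sketch, assuming device_map="auto" (the rest of the call is not visible in this diff):

# Sketch of the model load that consumes bnb_config.
# device_map="auto" is an assumption; the diff only shows the call's opening line.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,  # the 4-bit NF4 config defined above
    device_map="auto",               # assumption: auto-place layers on available devices
)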
@@ -60,7 +60,7 @@ model = AutoModelForCausalLM.from_pretrained(
 
 SYS_PROMPT = """You are an assistant for answering legal questions.
 You are given the extracted parts of legal documents and a question. Provide a conversational answer.
-If you don't know the answer, just say "I do not know." Don't
+If you don't know the answer, just say "I do not know." Don't make up an answer."""
 
 # Legal document search function
 def search_law(query, k=5):
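The body of search_law is outside the hunk. Given the FAISS index built on question_embedding at line 45, a plausible sketch using the datasets nearest-neighbor API (the return shape and how callers consume it are assumptions):

def search_law(query, k=5):
    # Sketch only: embed the query with the same SentenceTransformer (ST)
    # used at indexing time, then search the FAISS index from line 45.
    query_embedding = ST.encode(query)  # ST is the embedder defined earlier in app.py
    scores, samples = data.get_nearest_examples("question_embedding", query_embedding, k=k)
    return scores, samples  # assumption: the caller formats these into the prompt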
@@ -96,32 +96,25 @@ def talk(prompt, history):
     messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
 
     # Instruct the model to generate
-    input_ids = tokenizer.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        return_tensors="pt"
-    ).to(model.device)
-
-    streamer = TextIteratorStreamer(
-        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
-    )
+    input_ids = tokenizer(messages, return_tensors="pt").input_ids.to(model.device)
 
     generate_kwargs = dict(
         input_ids=input_ids,
-        streamer=streamer,
         max_new_tokens=1024,
         do_sample=True,
         top_p=0.95,
         temperature=0.75,
         eos_token_id=tokenizer.eos_token_id,
     )
-
-
+
+    try:
+        outputs = model.generate(**generate_kwargs)
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    except Exception as e:
+        response = f"Error: {str(e)}"
+
+    return response
 
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
 
 # Gradio interface setup
 TITLE = "Legal RAG Chatbot"
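Two caveats about the new generation path. messages is a list of role/content dicts, which tokenizer(...) cannot tokenize directly; chat-formatted input normally goes through the chat template, as the removed lines did. And tokenizer.decode(outputs[0], ...) decodes the prompt together with the completion, so the reply will echo the question back. A sketch of both fixes, keeping this commit's non-streaming structure:

# Sketch: apply the chat template before generation, and decode only the
# newly generated tokens so the prompt is not echoed back to the user.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

outputs = model.generate(
    input_ids=input_ids,
    max_new_tokens=1024,
    do_sample=True,
    top_p=0.95,
    temperature=0.75,
    eos_token_id=tokenizer.eos_token_id,
)
# Slice off the prompt tokens before decoding.
response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)

Note also that dropping the TextIteratorStreamer turns talk from a generator into a plain function, so Gradio will render the whole answer at once instead of streaming it token by token.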