StevenChen16 committed · Commit c39e972 · 1 Parent(s): 6c2ef5e

update chat_llama3_8b function
app.py CHANGED
@@ -165,67 +165,101 @@ def query_vector_store(vector_store: FAISS, query, k=4, relevance_threshold=0.8)
 
 @spaces.GPU(duration=120)
 def chat_llama3_8b(message: str,
-                   ...  (old signature lines 168–171 not captured in this extract)
     """
-    Generate a streaming response using the
-    Will display citations after the response if citations are available.
-    """
-    # Get citations from vector store
-    citation = query_vector_store(vector_store, message, 4, 0.7)
-
-    # Build conversation history
-    conversation = []
-    for user, assistant in history:
-        conversation.extend([
-            {"role": "user", "content": user},
-            {"role": "assistant", "content": assistant}
-        ])
-
-    # Construct the final message with background prompt and citations
-    if citation:
-        message = f"{background_prompt}Based on these citations: {citation}\nPlease answer question: {message}"
-    else:
-        message = f"{background_prompt}{message}"
-
-    conversation.append({"role": "user", "content": message})
-
-    # Generate response
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        eos_token_id=terminators,
-    )
 
-    ...  (old output/streaming lines 208–228 not captured in this extract)
 
 @spaces.GPU(duration=120)
 def chat_llama3_8b(message: str,
+                   history: list,
+                   temperature=0.6,
+                   max_new_tokens=4096
+                   ) -> str:
     """
+    Generate a streaming response using the LLaMA model.
 
+    Args:
+        message (str): The current user message
+        history (list): List of previous conversation turns
+        temperature (float): Sampling temperature (0.0 to 1.0)
+        max_new_tokens (int): Maximum number of tokens to generate
 
+    Returns:
+        str: Generated response with citations if available
+    """
+    try:
+        # 1. Get relevant citations from vector store
+        citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7)
+
+        # 2. Format conversation history
+        conversation = []
+        for user, assistant in history:
+            conversation.extend([
+                {"role": "user", "content": str(user)},
+                {"role": "assistant", "content": str(assistant)}
+            ])
+
+        # 3. Construct the final prompt
+        final_message = ""
+        if citation:
+            final_message = f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}"
+        else:
+            final_message = f"{background_prompt}\n{message}"
+
+        conversation.append({"role": "user", "content": final_message})
+
+        # 4. Prepare model inputs
+        input_ids = tokenizer.apply_chat_template(
+            conversation,
+            return_tensors="pt"
+        ).to(model.device)
+
+        # 5. Setup streamer
+        streamer = TextIteratorStreamer(
+            tokenizer,
+            timeout=10.0,
+            skip_prompt=True,
+            skip_special_tokens=True
+        )
 
+        # 6. Configure generation parameters
+        generation_config = {
+            "input_ids": input_ids,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": temperature > 0,
+            "temperature": temperature,
+            "eos_token_id": terminators
+        }
+
+        # 7. Generate in a separate thread
+        thread = Thread(target=model.generate, kwargs=generation_config)
+        thread.start()
+
+        # 8. Stream the output
+        accumulated_text = []
+        final_chunk = False
+
+        for text_chunk in streamer:
+            accumulated_text.append(text_chunk)
+            current_response = "".join(accumulated_text)
+
+            # Check if this is the last chunk
+            try:
+                next_chunk = next(iter(streamer))
+                accumulated_text.append(next_chunk)
+            except (StopIteration, RuntimeError):
+                final_chunk = True
 
+            # Add citations on the final chunk if they exist
+            if final_chunk and citation:
+                formatted_citations = "\n\nReferences:\n" + "\n".join(
+                    f"[{i+1}] {cite.strip()}"
+                    for i, cite in enumerate(citation.split('\n'))
+                    if cite.strip()
+                )
+                current_response += formatted_citations
+
+            yield current_response
+
+    except Exception as e:
+        error_message = f"An error occurred: {str(e)}"
+        print(error_message)  # For logging
+        yield error_message
 
 
 # Gradio block
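
For context, below is a minimal sketch of how a streaming generator with this signature is typically wired into the Gradio block referenced above. The ChatInterface settings, slider ranges, and the entry-point guard are illustrative assumptions, not the Space's actual configuration from this commit.

# Hypothetical wiring sketch (not taken from this commit): gr.ChatInterface calls
# chat_llama3_8b(message, history, temperature, max_new_tokens) and re-renders the
# chat window with every value the generator yields, so the response streams.
import gradio as gr

demo = gr.ChatInterface(
    fn=chat_llama3_8b,  # the generator defined above
    additional_inputs=[
        gr.Slider(0.0, 1.0, value=0.6, label="Temperature"),               # assumed range/default
        gr.Slider(256, 4096, value=4096, step=8, label="Max new tokens"),  # assumed range/default
    ],
)

if __name__ == "__main__":
    demo.queue().launch()  # the queue is what lets generator outputs stream to the UI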