StevenChen16 committed on
Commit
bccfe43
1 Parent(s): c39e972

Update app.py

Files changed (1): app.py (+73, -73)
app.py CHANGED
@@ -181,85 +181,85 @@ def chat_llama3_8b(message: str,
     Returns:
         str: Generated response with citations if available
     """
-    try:
-        # 1. Get relevant citations from vector store
-        citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7)
-
-        # 2. Format conversation history
-        conversation = []
-        for user, assistant in history:
-            conversation.extend([
-                {"role": "user", "content": str(user)},
-                {"role": "assistant", "content": str(assistant)}
-            ])
-
-        # 3. Construct the final prompt
-        final_message = ""
-        if citation:
-            final_message = f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}"
-        else:
-            final_message = f"{background_prompt}\n{message}"
-
-        conversation.append({"role": "user", "content": final_message})
-
-        # 4. Prepare model inputs
-        input_ids = tokenizer.apply_chat_template(
-            conversation,
-            return_tensors="pt"
-        ).to(model.device)
-
-        # 5. Setup streamer
-        streamer = TextIteratorStreamer(
-            tokenizer,
-            timeout=10.0,
-            skip_prompt=True,
-            skip_special_tokens=True
-        )
-
-        # 6. Configure generation parameters
-        generation_config = {
-            "input_ids": input_ids,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": temperature > 0,
-            "temperature": temperature,
-            "eos_token_id": terminators
-        }
-
-        # 7. Generate in a separate thread
-        thread = Thread(target=model.generate, kwargs=generation_config)
-        thread.start()
-
-        # 8. Stream the output
-        accumulated_text = []
-        final_chunk = False
-
-        for text_chunk in streamer:
-            accumulated_text.append(text_chunk)
-            current_response = "".join(accumulated_text)
-
-            # Check if this is the last chunk
-            try:
-                next_chunk = next(iter(streamer))
-                accumulated_text.append(next_chunk)
-            except (StopIteration, RuntimeError):
-                final_chunk = True
-
-            # Add citations on the final chunk if they exist
-            if final_chunk and citation:
-                formatted_citations = "\n\nReferences:\n" + "\n".join(
-                    f"[{i+1}] {cite.strip()}"
-                    for i, cite in enumerate(citation.split('\n'))
-                    if cite.strip()
-                )
-                current_response += formatted_citations
-
-            yield current_response
-
-    except Exception as e:
-        error_message = f"An error occurred: {str(e)}"
-        print(error_message)  # For logging
-        yield error_message
+    # try:
+    # 1. Get relevant citations from vector store
+    citation = query_vector_store(vector_store, message, k=4, relevance_threshold=0.7)
+
+    # 2. Format conversation history
+    conversation = []
+    for user, assistant in history:
+        conversation.extend([
+            {"role": "user", "content": str(user)},
+            {"role": "assistant", "content": str(assistant)}
+        ])
+
+    # 3. Construct the final prompt
+    final_message = ""
+    if citation:
+        final_message = f"{background_prompt}\nBased on these references:\n{citation}\nPlease answer: {message}"
+    else:
+        final_message = f"{background_prompt}\n{message}"
+
+    conversation.append({"role": "user", "content": final_message})
+
+    # 4. Prepare model inputs
+    input_ids = tokenizer.apply_chat_template(
+        conversation,
+        return_tensors="pt"
+    ).to(model.device)
+
+    # 5. Setup streamer
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        timeout=10.0,
+        skip_prompt=True,
+        skip_special_tokens=True
+    )
+
+    # 6. Configure generation parameters
+    generation_config = {
+        "input_ids": input_ids,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "do_sample": temperature > 0,
+        "temperature": temperature,
+        "eos_token_id": terminators
+    }
+
+    # 7. Generate in a separate thread
+    thread = Thread(target=model.generate, kwargs=generation_config)
+    thread.start()
+
+    # 8. Stream the output
+    accumulated_text = []
+    final_chunk = False
+
+    for text_chunk in streamer:
+        accumulated_text.append(text_chunk)
+        current_response = "".join(accumulated_text)
+
+        # Check if this is the last chunk
+        try:
+            next_chunk = next(iter(streamer))
+            accumulated_text.append(next_chunk)
+        except (StopIteration, RuntimeError):
+            final_chunk = True
+
+        # Add citations on the final chunk if they exist
+        if final_chunk and citation:
+            formatted_citations = "\n\nReferences:\n" + "\n".join(
+                f"[{i+1}] {cite.strip()}"
+                for i, cite in enumerate(citation.split('\n'))
+                if cite.strip()
+            )
+            current_response += formatted_citations
+
+        yield current_response
+
+    # except Exception as e:
+    #     error_message = f"An error occurred: {str(e)}"
+    #     print(error_message)  # For logging
+    #     yield error_message
 
 
 # Gradio block
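
For reference, the block above relies on the standard Transformers streaming pattern: model.generate runs in a worker thread while a TextIteratorStreamer hands decoded text chunks back to a Python generator. A minimal, self-contained sketch of that pattern follows; the model name, dtype, add_generation_prompt=True, and the generation defaults are placeholder assumptions for illustration, not settings taken from this app.

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model: any instruction-tuned chat model with a chat template works.
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto"
)

def stream_reply(conversation, max_new_tokens=256, temperature=0.7):
    """Yield progressively longer responses, mirroring steps 4-8 in the diff."""
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generation_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "temperature": temperature,
    }

    # generate() blocks, so it runs in a background thread while this
    # generator drains the streamer and re-yields the accumulated text.
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    accumulated = []
    for chunk in streamer:
        accumulated.append(chunk)
        yield "".join(accumulated)

# Example: print the growing reply for a single-turn conversation.
for partial in stream_reply([{"role": "user", "content": "What is diversification?"}]):
    print(partial)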
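
The "# Gradio block" context line marks where the UI is defined, and that code sits outside this hunk. Purely as a hedged illustration, a streaming generator with a (message, history, ...) signature like chat_llama3_8b is typically wired to gr.ChatInterface roughly as sketched below; chat_fn, the slider names, and their ranges are assumptions, not the app's actual UI.

import gradio as gr

# Hypothetical stand-in for chat_llama3_8b; the real function streams tokens
# from the model as shown in the diff above.
def chat_fn(message, history, temperature=0.7, max_new_tokens=512):
    reply = f"(echo) {message}"
    # Yield growing prefixes to emulate streamed output.
    for i in range(1, len(reply) + 1):
        yield reply[:i]

demo = gr.ChatInterface(
    fn=chat_fn,
    additional_inputs=[
        gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature"),
        gr.Slider(minimum=128, maximum=4096, value=512, step=1, label="Max new tokens"),
    ],
)

if __name__ == "__main__":
    demo.launch()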