Shreyas094 committed on
Commit
d0388f2
1 Parent(s): e40971e

Update app.py

Files changed (1): app.py (+473, -0)
app.py CHANGED
@@ -1,3 +1,476 @@
+ import os
+ import json
+ import re
+ import gradio as gr
+ import requests
+ from duckduckgo_search import DDGS
+ from typing import List
+ from pydantic import BaseModel, Field
+ from tempfile import NamedTemporaryFile
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from llama_parse import LlamaParse
+ from langchain_core.documents import Document
+ from huggingface_hub import InferenceClient
+ import inspect
+ import logging
+
+
+ # Set up basic configuration for logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ # Environment variables and configurations
+ huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
+ llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
+ ACCOUNT_ID = os.environ.get("CLOUDFLARE_ACCOUNT_ID")
+ API_TOKEN = os.environ.get("CLOUDFLARE_AUTH_TOKEN")
+ API_BASE_URL = "https://api.cloudflare.com/client/v4/accounts/a17f03e0f049ccae0c15cdcf3b9737ce/ai/run/"
+
+ print(f"ACCOUNT_ID: {ACCOUNT_ID}")
+ print(f"CLOUDFLARE_AUTH_TOKEN: {API_TOKEN[:5]}..." if API_TOKEN else "Not set")
+
+ MODELS = [
+     "mistralai/Mistral-7B-Instruct-v0.3",
+     "mistralai/Mixtral-8x7B-Instruct-v0.1",
+     "@cf/meta/llama-3.1-8b-instruct"
+ ]
+
+ # Initialize LlamaParse
+ llama_parser = LlamaParse(
+     api_key=llama_cloud_api_key,
+     result_type="markdown",
+     num_workers=4,
+     verbose=True,
+     language="en",
+ )
+
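+ # Note: API_BASE_URL embeds a fixed account id, while generate_chunked_response
+ # below builds its URL from the ACCOUNT_ID environment variable; the two must
+ # refer to the same Cloudflare account for both code paths to hit the same API.
+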
+ def load_document(file: NamedTemporaryFile, parser: str = "llamaparse") -> List[Document]:
+     """Loads and splits the document into pages."""
+     if parser == "pypdf":
+         loader = PyPDFLoader(file.name)
+         return loader.load_and_split()
+     elif parser == "llamaparse":
+         try:
+             documents = llama_parser.load_data(file.name)
+             return [Document(page_content=doc.text, metadata={"source": file.name}) for doc in documents]
+         except Exception as e:
+             print(f"Error using Llama Parse: {str(e)}")
+             print("Falling back to PyPDF parser")
+             loader = PyPDFLoader(file.name)
+             return loader.load_and_split()
+     else:
+         raise ValueError("Invalid parser specified. Use 'pypdf' or 'llamaparse'.")
+
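+ # A minimal usage sketch for load_document (the PDF path is illustrative; both
+ # parser branches only ever read the .name attribute of the object passed in):
+ #
+ #     class _PathHolder:
+ #         name = "sample.pdf"  # hypothetical local file
+ #     pages = load_document(_PathHolder(), parser="pypdf")
+ #     print(len(pages), "chunks")
+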
+ def get_embeddings():
+     return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+
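+ # Note: all-mpnet-base-v2 produces 768-dimensional vectors; the FAISS index
+ # built below inherits that dimensionality, so swapping the embedding model
+ # later means rebuilding the "faiss_database" directory from scratch.
+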
+ def update_vectors(files, parser):
+     global uploaded_documents
+     logging.info(f"Entering update_vectors with {len(files)} files and parser: {parser}")
+
+     if not files:
+         logging.warning("No files provided for update_vectors")
+         return "Please upload at least one PDF file.", gr.CheckboxGroup(
+             choices=[doc["name"] for doc in uploaded_documents],
+             value=[doc["name"] for doc in uploaded_documents if doc["selected"]],
+             label="Select documents to query"
+         )
+
+     embed = get_embeddings()
+     total_chunks = 0
+
+     all_data = []
+     for file in files:
+         logging.info(f"Processing file: {file.name}")
+         try:
+             data = load_document(file, parser)
+             logging.info(f"Loaded {len(data)} chunks from {file.name}")
+             all_data.extend(data)
+             total_chunks += len(data)
+             # Append new documents instead of replacing
+             if not any(doc["name"] == file.name for doc in uploaded_documents):
+                 uploaded_documents.append({"name": file.name, "selected": True})
+                 logging.info(f"Added new document to uploaded_documents: {file.name}")
+             else:
+                 logging.info(f"Document already exists in uploaded_documents: {file.name}")
+         except Exception as e:
+             logging.error(f"Error processing file {file.name}: {str(e)}")
+
+     logging.info(f"Total chunks processed: {total_chunks}")
+
+     if not all_data:
+         # Guard: FAISS.from_documents raises on an empty list
+         logging.warning("No chunks loaded from any file; skipping vector store update")
+         return "No documents could be processed. Please check the files and try again.", gr.CheckboxGroup(
+             choices=[doc["name"] for doc in uploaded_documents],
+             value=[doc["name"] for doc in uploaded_documents if doc["selected"]],
+             label="Select documents to query"
+         )
+
+     if os.path.exists("faiss_database"):
+         logging.info("Updating existing FAISS database")
+         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+         database.add_documents(all_data)
+     else:
+         logging.info("Creating new FAISS database")
+         database = FAISS.from_documents(all_data, embed)
+
+     database.save_local("faiss_database")
+     logging.info("FAISS database saved")
+
+     return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files using {parser}.", gr.CheckboxGroup(
+         choices=[doc["name"] for doc in uploaded_documents],
+         value=[doc["name"] for doc in uploaded_documents if doc["selected"]],
+         label="Select documents to query"
+     )
+
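+ # FAISS.load_local unpickles stored index metadata, which is why recent
+ # langchain_community versions require allow_dangerous_deserialization=True;
+ # that is acceptable here only because "faiss_database" is written exclusively
+ # by this app and never accepted from users.
+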
+ def generate_chunked_response(prompt, model, max_tokens=1000, num_calls=3, temperature=0.2, should_stop=False):
+     print(f"Starting generate_chunked_response with {num_calls} calls")
+     full_response = ""
+     messages = [{"role": "user", "content": prompt}]
+
+     if model == "@cf/meta/llama-3.1-8b-instruct":
+         # Cloudflare API
+         for i in range(num_calls):
+             print(f"Starting Cloudflare API call {i+1}")
+             if should_stop:
+                 print("Stop clicked, breaking loop")
+                 break
+             try:
+                 response = requests.post(
+                     f"https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/run/@cf/meta/llama-3.1-8b-instruct",
+                     headers={"Authorization": f"Bearer {API_TOKEN}"},
+                     json={
+                         "stream": True,
+                         "messages": [
+                             {"role": "system", "content": "You are a friendly assistant"},
+                             {"role": "user", "content": prompt}
+                         ],
+                         "max_tokens": max_tokens,
+                         "temperature": temperature
+                     },
+                     stream=True
+                 )
+
+                 for line in response.iter_lines():
+                     if should_stop:
+                         print("Stop clicked during streaming, breaking")
+                         break
+                     if line:
+                         try:
+                             json_data = json.loads(line.decode('utf-8').split('data: ')[1])
+                             chunk = json_data['response']
+                             full_response += chunk
+                         except (json.JSONDecodeError, IndexError):
+                             continue
+                 print(f"Cloudflare API call {i+1} completed")
+             except Exception as e:
+                 print(f"Error in generating response from Cloudflare: {str(e)}")
+     else:
+         # Original Hugging Face API logic
+         client = InferenceClient(model, token=huggingface_token)
+
+         for i in range(num_calls):
+             print(f"Starting Hugging Face API call {i+1}")
+             if should_stop:
+                 print("Stop clicked, breaking loop")
+                 break
+             try:
+                 for message in client.chat_completion(
+                     messages=messages,
+                     max_tokens=max_tokens,
+                     temperature=temperature,
+                     stream=True,
+                 ):
+                     if should_stop:
+                         print("Stop clicked during streaming, breaking")
+                         break
+                     if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                         chunk = message.choices[0].delta.content
+                         full_response += chunk
+                 print(f"Hugging Face API call {i+1} completed")
+             except Exception as e:
+                 print(f"Error in generating response from Hugging Face: {str(e)}")
+
+     # Clean up the response
+     clean_response = re.sub(r'<s>\[INST\].*?\[/INST\]\s*', '', full_response, flags=re.DOTALL)
+     clean_response = clean_response.replace("Using the following context:", "").strip()
+     clean_response = clean_response.replace("Using the following context from the PDF documents:", "").strip()
+
+     # Remove duplicate paragraphs and sentences
+     paragraphs = clean_response.split('\n\n')
+     unique_paragraphs = []
+     for paragraph in paragraphs:
+         if paragraph not in unique_paragraphs:
+             sentences = paragraph.split('. ')
+             unique_sentences = []
+             for sentence in sentences:
+                 if sentence not in unique_sentences:
+                     unique_sentences.append(sentence)
+             unique_paragraphs.append('. '.join(unique_sentences))
+
+     final_response = '\n\n'.join(unique_paragraphs)
+
+     print(f"Final clean response: {final_response[:100]}...")
+     return final_response
+
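+ # The streaming loop above assumes the Cloudflare Workers AI server-sent-events
+ # format, i.e. lines of the form `data: {"response": "..."}` terminated by a
+ # `data: [DONE]` sentinel; the (JSONDecodeError, IndexError) handler silently
+ # skips the sentinel and any non-data lines rather than aborting the stream.
+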
+ def duckduckgo_search(query):
+     with DDGS() as ddgs:
+         results = ddgs.text(query, max_results=5)
+         return results
+
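+ # DDGS().text returns a list of dicts shaped like
+ # {"title": ..., "href": ..., "body": ...}; get_response_with_search below
+ # relies on exactly those three keys when it assembles the search context.
+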
+ class CitingSources(BaseModel):
+     sources: List[str] = Field(
+         ...,
+         description="List of sources to cite. Should be a URL of the source."
+     )
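+
+ # CitingSources is not wired into the response path yet; a minimal sketch of
+ # how it could validate a model-extracted link list (the URL is illustrative):
+ #
+ #     cited = CitingSources(sources=["https://example.com/article"])
+ #     print(cited.sources)
+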
+ def chatbot_interface(message, history, use_web_search, model, temperature, num_calls, selected_docs=None):
+     if not message.strip():
+         return "", history
+
+     history = history + [(message, "")]
+
+     try:
+         # respond() requires a selected_docs argument; default to no filter
+         for response in respond(message, history, model, temperature, num_calls, use_web_search, selected_docs or []):
+             history[-1] = (message, response)
+             yield history
+     except gr.CancelledError:
+         yield history
+     except Exception as e:
+         logging.error(f"Unexpected error in chatbot_interface: {str(e)}")
+         history[-1] = (message, f"An unexpected error occurred: {str(e)}")
+         yield history
+
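+ # history is a list of (user, bot) tuples; yielding the whole list after each
+ # chunk is the generator pattern Gradio chat callbacks use to stream partial
+ # answers, re-rendering the last tuple as the response grows.
+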
+ def retry_last_response(history, use_web_search, model, temperature, num_calls):
+     if not history:
+         return history
+
+     last_user_msg = history[-1][0]
+     history = history[:-1]  # Remove the last response
+
+     return chatbot_interface(last_user_msg, history, use_web_search, model, temperature, num_calls)
+
+ def respond(message, history, model, temperature, num_calls, use_web_search, selected_docs):
+     logging.info(f"User Query: {message}")
+     logging.info(f"Model Used: {model}")
+     logging.info(f"Search Type: {'Web Search' if use_web_search else 'PDF Search'}")
+
+     logging.info(f"Selected Documents: {selected_docs}")
+
+     try:
+         if use_web_search:
+             for main_content, sources in get_response_with_search(message, model, num_calls=num_calls, temperature=temperature):
+                 response = f"{main_content}\n\n{sources}"
+                 first_line = response.split('\n')[0] if response else ''
+                 logging.info(f"Generated Response (first line): {first_line}")
+                 yield response
+         else:
+             embed = get_embeddings()
+             if os.path.exists("faiss_database"):
+                 database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+                 retriever = database.as_retriever()
+
+                 # Filter relevant documents based on user selection
+                 all_relevant_docs = retriever.get_relevant_documents(message)
+                 relevant_docs = [doc for doc in all_relevant_docs if doc.metadata["source"] in selected_docs]
+
+                 if not relevant_docs:
+                     yield "No relevant information found in the selected documents. Please try selecting different documents or rephrasing your query."
+                     return
+
+                 context_str = "\n".join([doc.page_content for doc in relevant_docs])
+             else:
+                 context_str = "No documents available."
+                 yield "No documents available. Please upload PDF documents to answer questions."
+                 return
+
+             if model == "@cf/meta/llama-3.1-8b-instruct":
+                 # Use Cloudflare API
+                 for partial_response in get_response_from_cloudflare(prompt="", context=context_str, query=message, num_calls=num_calls, temperature=temperature, search_type="pdf"):
+                     first_line = partial_response.split('\n')[0] if partial_response else ''
+                     logging.info(f"Generated Response (first line): {first_line}")
+                     yield partial_response
+             else:
+                 # Use Hugging Face API
+                 for partial_response in get_response_from_pdf(message, model, selected_docs, num_calls=num_calls, temperature=temperature):
+                     first_line = partial_response.split('\n')[0] if partial_response else ''
+                     logging.info(f"Generated Response (first line): {first_line}")
+                     yield partial_response
+     except Exception as e:
+         logging.error(f"Error with {model}: {str(e)}")
+         if "microsoft/Phi-3-mini-4k-instruct" in model:
+             logging.info("Falling back to Mistral model due to Phi-3 error")
+             fallback_model = "mistralai/Mistral-7B-Instruct-v0.3"
+             yield from respond(message, history, fallback_model, temperature, num_calls, use_web_search, selected_docs)
+         else:
+             yield f"An error occurred with the {model} model: {str(e)}. Please try again or select a different model."
+
+ def get_response_from_cloudflare(prompt, context, query, num_calls=3, temperature=0.2, search_type="pdf"):
+     headers = {
+         "Authorization": f"Bearer {API_TOKEN}",
+         "Content-Type": "application/json"
+     }
+     model = "@cf/meta/llama-3.1-8b-instruct"
+
+     if search_type == "pdf":
+         instruction = f"""Using the following context from the PDF documents:
+ {context}
+ Write a detailed and complete response that answers the following user question: '{query}'"""
+     else:  # web search
+         instruction = f"""Using the following context:
+ {context}
+ Write a detailed and complete research document that fulfills the following user request: '{query}'
+ After writing the document, please provide a list of sources used in your response."""
+
+     inputs = [
+         {"role": "system", "content": instruction},
+         {"role": "user", "content": query}
+     ]
+
+     payload = {
+         "messages": inputs,
+         "stream": True,
+         "temperature": temperature
+     }
+
+     full_response = ""
+     for i in range(num_calls):
+         try:
+             with requests.post(f"{API_BASE_URL}{model}", headers=headers, json=payload, stream=True) as response:
+                 if response.status_code == 200:
+                     for line in response.iter_lines():
+                         if line:
+                             try:
+                                 json_response = json.loads(line.decode('utf-8').split('data: ')[1])
+                                 if 'response' in json_response:
+                                     chunk = json_response['response']
+                                     full_response += chunk
+                                     yield full_response
+                             except (json.JSONDecodeError, IndexError) as e:
+                                 logging.error(f"Error parsing streaming response: {str(e)}")
+                                 continue
+                 else:
+                     logging.error(f"HTTP Error: {response.status_code}, Response: {response.text}")
+                     yield f"I apologize, but I encountered an HTTP error: {response.status_code}. Please try again later."
+         except Exception as e:
+             logging.error(f"Error in generating response from Cloudflare: {str(e)}")
+             yield f"I apologize, but an error occurred: {str(e)}. Please try again later."
+
+     if not full_response:
+         yield "I apologize, but I couldn't generate a response at this time. Please try again later."
+
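+ # Note: full_response accumulates across all num_calls requests, so each
+ # successive call appends a fresh completion to the text yielded so far;
+ # num_calls acts as a "keep elaborating" knob rather than a retry count.
+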
+ def get_response_with_search(query, model, num_calls=3, temperature=0.2):
+     search_results = duckduckgo_search(query)
+     context = "\n".join(f"{result['title']}\n{result['body']}\nSource: {result['href']}\n"
+                         for result in search_results if 'body' in result)
+
+     prompt = f"""Using the following context:
+ {context}
+ Write a detailed and complete research document that fulfills the following user request: '{query}'
+ After writing the document, please provide a list of sources used in your response."""
+
+     if model == "@cf/meta/llama-3.1-8b-instruct":
+         # Use Cloudflare API
+         for response in get_response_from_cloudflare(prompt="", context=context, query=query, num_calls=num_calls, temperature=temperature, search_type="web"):
+             yield response, ""  # Yield streaming response without sources
+     else:
+         # Use Hugging Face API
+         client = InferenceClient(model, token=huggingface_token)
+
+         main_content = ""
+         for i in range(num_calls):
+             for message in client.chat_completion(
+                 messages=[{"role": "user", "content": prompt}],
+                 max_tokens=1000,
+                 temperature=temperature,
+                 stream=True,
+             ):
+                 if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                     chunk = message.choices[0].delta.content
+                     main_content += chunk
+                     yield main_content, ""  # Yield partial main content without sources
+
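+ # Both branches yield (main_content, "") tuples; the second slot is reserved
+ # for a sources string, which respond() appends after a blank line, so as
+ # written the sources section of a web-search answer is always empty.
+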
+ def get_response_from_pdf(query, model, selected_docs, num_calls=3, temperature=0.2):
+     logging.info(f"Entering get_response_from_pdf with query: {query}, model: {model}, selected_docs: {selected_docs}")
+
+     embed = get_embeddings()
+     if os.path.exists("faiss_database"):
+         logging.info("Loading FAISS database")
+         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
+     else:
+         logging.warning("No FAISS database found")
+         yield "No documents available. Please upload PDF documents to answer questions."
+         return
+
+     retriever = database.as_retriever()
+     logging.info(f"Retrieving relevant documents for query: {query}")
+     relevant_docs = retriever.get_relevant_documents(query)
+     logging.info(f"Number of relevant documents retrieved: {len(relevant_docs)}")
+
+     # Filter relevant_docs based on selected documents
+     filtered_docs = [doc for doc in relevant_docs if doc.metadata["source"] in selected_docs]
+     logging.info(f"Number of filtered documents: {len(filtered_docs)}")
+
+     if not filtered_docs:
+         logging.warning(f"No relevant information found in the selected documents: {selected_docs}")
+         yield "No relevant information found in the selected documents. Please try selecting different documents or rephrasing your query."
+         return
+
+     for doc in filtered_docs:
+         logging.info(f"Document source: {doc.metadata['source']}")
+         logging.info(f"Document content preview: {doc.page_content[:100]}...")  # Log first 100 characters of each document
+
+     context_str = "\n".join([doc.page_content for doc in filtered_docs])
+     logging.info(f"Total context length: {len(context_str)}")
+
+     if model == "@cf/meta/llama-3.1-8b-instruct":
+         logging.info("Using Cloudflare API")
+         # Use Cloudflare API with the retrieved context
+         for response in get_response_from_cloudflare(prompt="", context=context_str, query=query, num_calls=num_calls, temperature=temperature, search_type="pdf"):
+             yield response
+     else:
+         logging.info("Using Hugging Face API")
+         # Use Hugging Face API
+         prompt = f"""Using the following context from the PDF documents:
+ {context_str}
+ Write a detailed and complete response that answers the following user question: '{query}'"""
+
+         client = InferenceClient(model, token=huggingface_token)
+
+         response = ""
+         for i in range(num_calls):
+             logging.info(f"API call {i+1}/{num_calls}")
+             for message in client.chat_completion(
+                 messages=[{"role": "user", "content": prompt}],
+                 max_tokens=1000,
+                 temperature=temperature,
+                 stream=True,
+             ):
+                 if message.choices and message.choices[0].delta and message.choices[0].delta.content:
+                     chunk = message.choices[0].delta.content
+                     response += chunk
+                     yield response  # Yield partial response
+
+     logging.info("Finished generating response")
+
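+ # database.as_retriever() defaults to similarity search with k=4, so at most
+ # four chunks survive the selected-docs filter and reach the prompt; passing
+ # search_kwargs={"k": ...} would widen the net for multi-document questions.
+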
+ def vote(data: gr.LikeData):
+     if data.liked:
+         print(f"You upvoted this response: {data.value}")
+     else:
+         print(f"You downvoted this response: {data.value}")
+
+ css = """
+ /* Add your custom CSS here */
+ """
+
+ uploaded_documents = []
+
+ def display_documents():
+     return gr.CheckboxGroup(
+         choices=[doc["name"] for doc in uploaded_documents],
+         value=[doc["name"] for doc in uploaded_documents if doc["selected"]],
+         label="Select documents to query"
+     )
+
+ # Define the checkbox outside the demo block
+ document_selector = gr.CheckboxGroup(label="Select documents to query")
+
+ use_web_search = gr.Checkbox(label="Use Web Search", value=False)
+
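+ # Components created at module level like these stay unattached until they are
+ # passed to gr.ChatInterface(additional_inputs=...) below; Gradio typically
+ # renders additional_inputs in a collapsible section beneath the chat.
+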
  demo = gr.ChatInterface(
      respond,
      additional_inputs=[