zliang commited on
Commit
8705301
1 Parent(s): a96314b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -292
app.py CHANGED
@@ -1,10 +1,16 @@
1
 
2
 
 
 
 
3
 
4
  import os
 
 
5
  import streamlit as st
6
  import numpy as np
7
  import fitz # PyMuPDF
 
8
  from ultralytics import YOLO
9
  from sklearn.cluster import KMeans
10
  from sklearn.metrics.pairwise import cosine_similarity
@@ -13,27 +19,22 @@ from langchain_community.document_loaders import PyMuPDFLoader
13
  from langchain_openai import OpenAIEmbeddings
14
  from langchain_text_splitters import RecursiveCharacterTextSplitter
15
  from langchain_core.prompts import ChatPromptTemplate
16
- from sklearn.decomposition import PCA
17
  from langchain_openai import ChatOpenAI
18
- import string
19
  import re
20
-
 
21
 
22
  # Load the trained model
23
- model = YOLO("best.pt")
24
- openai_api_key = os.environ.get("openai_api_key")
25
 
26
  # Define the class indices for figures, tables, and text
27
- figure_class_index = 4 # class index for figures
28
- table_class_index = 3 # class index for tables
29
-
30
- # Global variables to store embeddings and contents
31
- global_embeddings = None
32
- global_split_contents = None
33
 
 
34
  def clean_text(text):
35
- text = re.sub(r'\s+', ' ', text).strip()
36
- return text
37
 
38
  def remove_references(text):
39
  reference_patterns = [
@@ -47,9 +48,10 @@ def remove_references(text):
47
  return text
48
 
49
  def save_uploaded_file(uploaded_file):
50
- with open(uploaded_file.name, 'wb') as f:
51
- f.write(uploaded_file.getbuffer())
52
- return uploaded_file.name
 
53
 
54
  def summarize_pdf(pdf_file_path, num_clusters=10):
55
  embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
@@ -70,66 +72,21 @@ def summarize_pdf(pdf_file_path, num_clusters=10):
70
  loader = PyMuPDFLoader(pdf_file_path)
71
  docs = loader.load()
72
  full_text = "\n".join(doc.page_content for doc in docs)
73
- cleaned_full_text = remove_references(full_text)
74
- cleaned_full_text = clean_text(cleaned_full_text)
75
 
76
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0,separators=["\n\n", "\n",".", " "])
77
  split_contents = text_splitter.split_text(cleaned_full_text)
78
  embeddings = embeddings_model.embed_documents(split_contents)
79
 
80
- X = np.array(embeddings)
81
  kmeans = KMeans(n_clusters=num_clusters, init='k-means++', random_state=0).fit(embeddings)
82
- cluster_centers = kmeans.cluster_centers_
83
-
84
- closest_point_indices = []
85
- for center in cluster_centers:
86
- distances = np.linalg.norm(embeddings - center, axis=1)
87
- closest_point_indices.append(np.argmin(distances))
88
-
89
  extracted_contents = [split_contents[idx] for idx in closest_point_indices]
90
- results = chain.invoke({"topic": ' '.join(extracted_contents)})
91
-
92
- summary_sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', results)
93
- summary_embeddings = embeddings_model.embed_documents(summary_sentences)
94
- extracted_embeddings = embeddings_model.embed_documents(extracted_contents)
95
- similarity_matrix = cosine_similarity(summary_embeddings, extracted_embeddings)
96
 
97
- cited_results = results
98
- relevant_sources = []
99
- source_mapping = {}
100
- sentence_to_source = {}
101
- similarity_threshold = 0.6
102
 
103
- for i, sentence in enumerate(summary_sentences):
104
- if sentence in sentence_to_source:
105
- continue
106
- max_similarity = max(similarity_matrix[i])
107
- if max_similarity >= similarity_threshold:
108
- most_similar_idx = np.argmax(similarity_matrix[i])
109
- if most_similar_idx not in source_mapping:
110
- source_mapping[most_similar_idx] = len(relevant_sources) + 1
111
- relevant_sources.append((most_similar_idx, extracted_contents[most_similar_idx]))
112
- citation_idx = source_mapping[most_similar_idx]
113
- citation = f"([Source {citation_idx}](#source-{citation_idx}))"
114
- cited_sentence = re.sub(r'([.!?])$', f" {citation}\\1", sentence)
115
- sentence_to_source[sentence] = citation_idx
116
- cited_results = cited_results.replace(sentence, cited_sentence)
117
-
118
- sources_list = "\n\n## Sources:\n"
119
- for idx, (original_idx, content) in enumerate(relevant_sources):
120
- sources_list += f"""
121
- <details style="margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px; background-color: #f9f9f9;">
122
- <summary style="font-weight: bold; cursor: pointer;">Source {idx + 1}</summary>
123
- <pre style="white-space: pre-wrap; word-wrap: break-word; margin-top: 10px;">{content}</pre>
124
- </details>
125
- """
126
- cited_results += sources_list
127
- return cited_results
128
 
129
  def qa_pdf(pdf_file_path, query, num_clusters=5, similarity_threshold=0.6):
130
- global global_embeddings, global_split_contents
131
-
132
- # Initialize models and embeddings
133
  embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
134
  llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=openai_api_key, temperature=0.3)
135
  prompt = ChatPromptTemplate.from_template(
@@ -142,40 +99,37 @@ def qa_pdf(pdf_file_path, query, num_clusters=5, similarity_threshold=0.6):
142
  output_parser = StrOutputParser()
143
  chain = prompt | llm | output_parser
144
 
145
- # Load and process the PDF if not already loaded
146
- if global_embeddings is None or global_split_contents is None:
147
- loader = PyMuPDFLoader(pdf_file_path)
148
- docs = loader.load()
149
- full_text = "\n".join(doc.page_content for doc in docs)
150
- cleaned_full_text = remove_references(full_text)
151
- cleaned_full_text = clean_text(cleaned_full_text)
152
 
153
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
154
- global_split_contents = text_splitter.split_text(cleaned_full_text)
155
- global_embeddings = embeddings_model.embed_documents(global_split_contents)
156
 
157
- # Embed the query and find the most relevant contexts
158
  query_embedding = embeddings_model.embed_query(query)
159
- similarity_scores = cosine_similarity([query_embedding], global_embeddings)[0]
160
  top_indices = np.argsort(similarity_scores)[-num_clusters:]
161
- relevant_contents = [global_split_contents[i] for i in top_indices]
162
 
163
- # Generate the answer using the LLM chain
164
  results = chain.invoke({"question": query, "contexts": ' '.join(relevant_contents)})
165
 
166
- # Split the answer into sentences and embed them
167
- answer_sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', results)
168
- answer_embeddings = embeddings_model.embed_documents(answer_sentences)
169
- relevant_embeddings = embeddings_model.embed_documents(relevant_contents)
170
- similarity_matrix = cosine_similarity(answer_embeddings, relevant_embeddings)
171
 
172
- # Map sentences to sources and create citations
173
- cited_results = results
 
 
 
 
 
 
174
  relevant_sources = []
175
  source_mapping = {}
176
  sentence_to_source = {}
177
 
178
- for i, sentence in enumerate(answer_sentences):
179
  if sentence in sentence_to_source:
180
  continue
181
  max_similarity = max(similarity_matrix[i])
@@ -183,34 +137,42 @@ def qa_pdf(pdf_file_path, query, num_clusters=5, similarity_threshold=0.6):
183
  most_similar_idx = np.argmax(similarity_matrix[i])
184
  if most_similar_idx not in source_mapping:
185
  source_mapping[most_similar_idx] = len(relevant_sources) + 1
186
- relevant_sources.append((most_similar_idx, relevant_contents[most_similar_idx]))
187
  citation_idx = source_mapping[most_similar_idx]
188
- citation = f"<strong style='color:blue;'>[Source {citation_idx}]</strong>"
189
  cited_sentence = re.sub(r'([.!?])$', f" {citation}\\1", sentence)
190
  sentence_to_source[sentence] = citation_idx
191
- cited_results = cited_results.replace(sentence, cited_sentence)
192
 
193
- # Format the sources for markdown rendering
194
  sources_list = "\n\n## Sources:\n"
195
  for idx, (original_idx, content) in enumerate(relevant_sources):
196
- sources_list += f"""
197
- <details style="margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px; background-color: #f9f9f9;">
198
- <summary style="font-weight: bold; cursor: pointer;">Source {idx + 1}</summary>
199
- <pre style="white-space: pre-wrap; word-wrap: break-word; margin-top: 10px;">{content}</pre>
200
  </details>
201
  """
202
- cited_results += sources_list
203
- return cited_results
204
 
 
 
 
 
 
 
 
205
 
206
- def infer_image_and_get_boxes(image, confidence_threshold=0.6):
 
 
 
 
 
207
  results = model.predict(image)
208
- boxes = [
209
  (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
210
  for result in results for box in result.boxes
211
  if int(box.cls[0]) in {figure_class_index, table_class_index} and box.conf[0] > confidence_threshold
212
  ]
213
- return boxes
214
 
215
  def crop_images_from_boxes(image, boxes, scale_factor):
216
  figures = []
@@ -223,7 +185,6 @@ def crop_images_from_boxes(image, boxes, scale_factor):
223
  tables.append(cropped_img)
224
  return figures, tables
225
 
226
-
227
  def process_pdf(pdf_file_path):
228
  doc = fitz.open(pdf_file_path)
229
  all_figures = []
@@ -246,213 +207,173 @@ def process_pdf(pdf_file_path):
246
 
247
  return all_figures, all_tables
248
 
249
- # Set the page configuration for a modern look
 
 
 
 
 
 
 
250
 
251
- # Set the page configuration for a modern look
252
- # Set the page configuration for a modern look
253
- st.set_page_config(page_title="PDF Reading Assistant", page_icon="📄", layout="wide")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
- # Add some custom CSS for a modern look
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  st.markdown("""
257
  <style>
258
- /* Main background and padding */
259
- .main {
260
- background-color: #f8f9fa;
261
- padding: 2rem;
262
- font-family: 'Arial', sans-serif;
 
 
 
 
263
  }
264
-
265
- /* Section headers */
266
- .section-header {
267
- font-size: 2rem;
268
- font-weight: bold;
269
- color: #343a40;
270
- margin-top: 2rem;
271
- margin-bottom: 1rem;
272
- text-align: center;
273
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
274
  }
275
-
276
- /* Containers */
277
- .uploaded-file-container, .chat-container, .summary-container, .extract-container {
278
- padding: 2rem;
279
- background-color: #ffffff;
280
- border-radius: 10px;
281
- margin-bottom: 2rem;
282
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
283
  }
284
-
285
- /* Buttons */
286
- .stButton>button {
287
- background-color: #007bff;
288
- color: white;
289
- padding: 0.6rem 1.2rem;
290
- border-radius: 5px;
291
- border: none;
292
- cursor: pointer;
293
- font-size: 1rem;
294
- transition: background-color 0.3s ease, transform 0.3s ease;
295
  }
296
- .stButton>button:hover {
297
- background-color: #0056b3;
298
- transform: translateY(-2px);
299
  }
300
-
301
- /* Chat messages */
302
- .chat-message {
303
- padding: 1rem;
304
- border-radius: 10px;
305
- margin-bottom: 1rem;
306
- font-size: 1rem;
307
- transition: all 0.3s ease;
308
- box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
309
- }
310
- .chat-message.user {
311
- background-color: #e6f7ff;
312
- border-left: 5px solid #007bff;
313
- text-align: left;
314
  }
315
- .chat-message.bot {
316
- background-color: #fff0f1;
317
- border-left: 5px solid #dc3545;
318
- text-align: left;
319
- }
320
-
321
- /* Input area */
322
- .input-container {
323
- display: flex;
324
- align-items: center;
325
- gap: 10px;
326
- margin-top: 1rem;
327
  }
328
- .input-container textarea {
329
- border: 2px solid #ccc;
330
- border-radius: 10px;
331
- padding: 10px;
332
  width: 100%;
333
- background-color: #fff;
334
- transition: border-color 0.3s ease;
335
- margin: 0;
336
- font-size: 1rem;
 
337
  }
338
- .input-container textarea:focus {
339
- border-color: #007bff;
340
- outline: none;
341
  }
342
- .input-container button {
343
- background-color: #007bff;
 
344
  color: white;
345
- padding: 0.6rem 1.2rem;
346
- border-radius: 5px;
347
- border: none;
348
- cursor: pointer;
349
- font-size: 1rem;
350
- transition: background-color 0.3s ease, transform 0.3s ease;
351
- }
352
- .input-container button:hover {
353
- background-color: #0056b3;
354
- transform: translateY(-2px);
355
- }
356
-
357
- /* Expander */
358
- .st-expander {
359
  border: none;
360
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
361
- margin-bottom: 2rem;
 
 
 
362
  }
363
-
364
- /* Markdown elements */
365
- .stMarkdown {
366
- font-size: 1rem;
367
- color: #343a40;
368
- line-height: 1.6;
369
- }
370
-
371
- /* Titles and subtitles */
372
- .stTitle {
373
- color: #343a40;
374
- text-align: center;
375
- margin-bottom: 1rem;
376
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
377
- }
378
- .stSubtitle {
379
- color: #6c757d;
380
- text-align: center;
381
- margin-bottom: 1rem;
382
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
383
  }
384
  </style>
 
 
 
 
385
  """, unsafe_allow_html=True)
386
-
387
- # Streamlit interface
388
- # Streamlit interface
389
- st.title("📄 PDF Reading Assistant")
390
- st.markdown("### Extract tables, figures, summaries, and answers from your PDF files easily.")
391
-
392
- uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
393
- if uploaded_file:
394
- file_path = save_uploaded_file(uploaded_file)
395
-
396
- if 'figures' not in st.session_state:
397
- st.session_state['figures'] = None
398
- if 'tables' not in st.session_state:
399
- st.session_state['tables'] = None
400
- if 'summary' not in st.session_state:
401
- st.session_state['summary'] = None
402
-
403
- with st.container():
404
- st.markdown("<div class='section-header'>Extract Tables and Figures</div>", unsafe_allow_html=True)
405
- with st.expander("Click to Extract Tables and Figures", expanded=True):
406
- with st.container():
407
- extract_button = st.button("Extract")
408
- if extract_button:
409
- figures, tables = process_pdf(file_path)
410
- st.session_state['figures'] = figures
411
- st.session_state['tables'] = tables
412
-
413
- if st.session_state['figures']:
414
- col1, col2 = st.columns(2)
415
- with col1:
416
- st.write("### Figures")
417
- for figure in st.session_state['figures']:
418
- st.image(figure, use_column_width=True)
419
- with col2:
420
- st.write("### Tables")
421
- for table in st.session_state['tables']:
422
- st.image(table, use_column_width=True)
423
- else:
424
- st.write("No figures or tables found.")
425
-
426
- with st.container():
427
- st.markdown("<div class='section-header'>Get Summary</div>", unsafe_allow_html=True)
428
- with st.expander("Click to Generate Summary", expanded=True):
429
- with st.container():
430
- summary_button = st.button("Generate Summary")
431
- if summary_button:
432
- summary = summarize_pdf(file_path)
433
- st.session_state['summary'] = summary
434
-
435
- if st.session_state['summary']:
436
- st.markdown(st.session_state['summary'], unsafe_allow_html=True)
437
-
438
- with st.container():
439
- st.markdown("<div class='section-header'>Chat with your PDF</div>", unsafe_allow_html=True)
440
- st.write("### Chat with your PDF")
441
- if 'chat_history' not in st.session_state:
442
- st.session_state['chat_history'] = []
443
-
444
- for chat in st.session_state['chat_history']:
445
- chat_user_class = "user" if chat["user"] else ""
446
- chat_bot_class = "bot" if chat["bot"] else ""
447
- st.markdown(f"<div class='chat-message {chat_user_class}'>{chat['user']}</div>", unsafe_allow_html=True)
448
- st.markdown(f"<div class='chat-message {chat_bot_class}'>{chat['bot']}</div>", unsafe_allow_html=True)
449
-
450
- with st.form(key="chat_form", clear_on_submit=True):
451
- user_input = st.text_area("Ask a question about the PDF:", key="user_input")
452
- submit_button = st.form_submit_button(label="Send")
453
-
454
- if submit_button and user_input:
455
- st.session_state['chat_history'].append({"user": user_input, "bot": None})
456
- answer = qa_pdf(file_path, user_input)
457
- st.session_state['chat_history'][-1]["bot"] = answer
458
- st.rerun()
 
1
 
2
 
3
+ # Load the trained model
4
+ model = YOLO("best.pt")
5
+ openai_api_key = os.environ.get("openai_api_key")
6
 
7
  import os
8
+ import io
9
+ import base64
10
  import streamlit as st
11
  import numpy as np
12
  import fitz # PyMuPDF
13
+ import tempfile
14
  from ultralytics import YOLO
15
  from sklearn.cluster import KMeans
16
  from sklearn.metrics.pairwise import cosine_similarity
 
19
  from langchain_openai import OpenAIEmbeddings
20
  from langchain_text_splitters import RecursiveCharacterTextSplitter
21
  from langchain_core.prompts import ChatPromptTemplate
 
22
  from langchain_openai import ChatOpenAI
 
23
  import re
24
+ from PIL import Image
25
+ from streamlit_chat import message
26
 
27
  # Load the trained model
28
+ model = YOLO("runs\\detect\\train7\\weights\\best.pt")
29
+ openai_api_key = "sk-proj-J7kj0kbG1m0eIMPWMdjoT3BlbkFJqwZNqQeOYJ9UH6I0efPi"
30
 
31
  # Define the class indices for figures, tables, and text
32
+ figure_class_index = 4
33
+ table_class_index = 3
 
 
 
 
34
 
35
+ # Utility functions
36
  def clean_text(text):
37
+ return re.sub(r'\s+', ' ', text).strip()
 
38
 
39
  def remove_references(text):
40
  reference_patterns = [
 
48
  return text
49
 
50
  def save_uploaded_file(uploaded_file):
51
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
52
+ temp_file.write(uploaded_file.getbuffer())
53
+ temp_file.close()
54
+ return temp_file.name
55
 
56
  def summarize_pdf(pdf_file_path, num_clusters=10):
57
  embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
 
72
  loader = PyMuPDFLoader(pdf_file_path)
73
  docs = loader.load()
74
  full_text = "\n".join(doc.page_content for doc in docs)
75
+ cleaned_full_text = clean_text(remove_references(full_text))
 
76
 
77
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
78
  split_contents = text_splitter.split_text(cleaned_full_text)
79
  embeddings = embeddings_model.embed_documents(split_contents)
80
 
 
81
  kmeans = KMeans(n_clusters=num_clusters, init='k-means++', random_state=0).fit(embeddings)
82
+ closest_point_indices = [np.argmin(np.linalg.norm(embeddings - center, axis=1)) for center in kmeans.cluster_centers_]
 
 
 
 
 
 
83
  extracted_contents = [split_contents[idx] for idx in closest_point_indices]
 
 
 
 
 
 
84
 
85
+ results = chain.invoke({"topic": ' '.join(extracted_contents)})
 
 
 
 
86
 
87
+ return generate_citations(results, extracted_contents)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  def qa_pdf(pdf_file_path, query, num_clusters=5, similarity_threshold=0.6):
 
 
 
90
  embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
91
  llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=openai_api_key, temperature=0.3)
92
  prompt = ChatPromptTemplate.from_template(
 
99
  output_parser = StrOutputParser()
100
  chain = prompt | llm | output_parser
101
 
102
+ loader = PyMuPDFLoader(pdf_file_path)
103
+ docs = loader.load()
104
+ full_text = "\n".join(doc.page_content for doc in docs)
105
+ cleaned_full_text = clean_text(remove_references(full_text))
 
 
 
106
 
107
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
108
+ split_contents = text_splitter.split_text(cleaned_full_text)
109
+ embeddings = embeddings_model.embed_documents(split_contents)
110
 
 
111
  query_embedding = embeddings_model.embed_query(query)
112
+ similarity_scores = cosine_similarity([query_embedding], embeddings)[0]
113
  top_indices = np.argsort(similarity_scores)[-num_clusters:]
114
+ relevant_contents = [split_contents[i] for i in top_indices]
115
 
 
116
  results = chain.invoke({"question": query, "contexts": ' '.join(relevant_contents)})
117
 
118
+ return generate_citations(results, relevant_contents, similarity_threshold)
 
 
 
 
119
 
120
+ def generate_citations(text, contents, similarity_threshold=0.6):
121
+ embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
122
+ text_sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
123
+ text_embeddings = embeddings_model.embed_documents(text_sentences)
124
+ content_embeddings = embeddings_model.embed_documents(contents)
125
+ similarity_matrix = cosine_similarity(text_embeddings, content_embeddings)
126
+
127
+ cited_text = text
128
  relevant_sources = []
129
  source_mapping = {}
130
  sentence_to_source = {}
131
 
132
+ for i, sentence in enumerate(text_sentences):
133
  if sentence in sentence_to_source:
134
  continue
135
  max_similarity = max(similarity_matrix[i])
 
137
  most_similar_idx = np.argmax(similarity_matrix[i])
138
  if most_similar_idx not in source_mapping:
139
  source_mapping[most_similar_idx] = len(relevant_sources) + 1
140
+ relevant_sources.append((most_similar_idx, contents[most_similar_idx]))
141
  citation_idx = source_mapping[most_similar_idx]
142
+ citation = f"([Source {citation_idx}](#source-{citation_idx}))"
143
  cited_sentence = re.sub(r'([.!?])$', f" {citation}\\1", sentence)
144
  sentence_to_source[sentence] = citation_idx
145
+ cited_text = cited_text.replace(sentence, cited_sentence)
146
 
 
147
  sources_list = "\n\n## Sources:\n"
148
  for idx, (original_idx, content) in enumerate(relevant_sources):
149
+ sources_list += f"""
150
+ <details style="margin: 1px 0; padding: 5px; border: 1px solid #ccc; border-radius: 8px; background-color: #f9f9f9; transition: all 0.3s ease;">
151
+ <summary style="font-weight: bold; cursor: pointer; outline: none; padding: 5px 0; transition: color 0.3s ease;">Source {idx + 1}</summary>
152
+ <pre style="white-space: pre-wrap; word-wrap: break-word; margin: 1px 0; padding: 10px; background-color: #fff; border-radius: 5px; border: 1px solid #ddd; box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);">{content}</pre>
153
  </details>
154
  """
 
 
155
 
156
+ # Add dummy blanks after the last source
157
+ dummy_blanks = """
158
+ <div style="margin: 20px 0;"></div>
159
+ <div style="margin: 20px 0;"></div>
160
+ <div style="margin: 20px 0;"></div>
161
+ <div style="margin: 20px 0;"></div>
162
+ <div style="margin: 20px 0;"></div>
163
 
164
+ """
165
+
166
+ cited_text += sources_list + dummy_blanks
167
+ return cited_text
168
+
169
+ def infer_image_and_get_boxes(image, confidence_threshold=0.8):
170
  results = model.predict(image)
171
+ return [
172
  (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
173
  for result in results for box in result.boxes
174
  if int(box.cls[0]) in {figure_class_index, table_class_index} and box.conf[0] > confidence_threshold
175
  ]
 
176
 
177
  def crop_images_from_boxes(image, boxes, scale_factor):
178
  figures = []
 
185
  tables.append(cropped_img)
186
  return figures, tables
187
 
 
188
  def process_pdf(pdf_file_path):
189
  doc = fitz.open(pdf_file_path)
190
  all_figures = []
 
207
 
208
  return all_figures, all_tables
209
 
210
+ def image_to_base64(img):
211
+ buffered = io.BytesIO()
212
+ img = Image.fromarray(img)
213
+ img.save(buffered, format="PNG")
214
+ return base64.b64encode(buffered.getvalue()).decode()
215
+
216
+ def on_btn_click():
217
+ del st.session_state.chat_history[:]
218
 
219
+ # Streamlit interface
220
+
221
+ # Custom CSS for the file uploader
222
+ uploadercss='''
223
+ <style>
224
+ [data-testid='stFileUploader'] {
225
+ width: max-content;
226
+ }
227
+ [data-testid='stFileUploader'] section {
228
+ padding: 0;
229
+ float: left;
230
+ }
231
+ [data-testid='stFileUploader'] section > input + div {
232
+ display: none;
233
+ }
234
+ [data-testid='stFileUploader'] section + div {
235
+ float: right;
236
+ padding-top: 0;
237
+ }
238
+
239
+ </style>
240
+ '''
241
+
242
+ st.set_page_config(page_title="PDF Reading Assistant", page_icon="📄")
243
+
244
+ # Initialize chat history in session state if not already present
245
+ if 'chat_history' not in st.session_state:
246
+ st.session_state.chat_history = []
247
+
248
+ st.title("📄 PDF Reading Assistant")
249
+ st.markdown("### Extract tables, figures, summaries, and answers from your PDF files easily.")
250
+ chat_placeholder = st.empty()
251
+
252
+ # File uploader for PDF
253
+ uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
254
+ st.markdown(uploadercss, unsafe_allow_html=True)
255
+ if uploaded_file:
256
+ file_path = save_uploaded_file(uploaded_file)
257
 
258
+ # Chat container where all messages will be displayed
259
+ chat_container = st.container()
260
+ user_input = st.chat_input("Ask a question about the pdf......", key="user_input")
261
+ with chat_container:
262
+ # Scrollable chat messages
263
+ for idx, chat in enumerate(st.session_state.chat_history):
264
+ if chat.get("user"):
265
+ message(chat["user"], is_user=True, allow_html=True, key=f"user_{idx}", avatar_style="initials", seed="user")
266
+ if chat.get("bot"):
267
+ message(chat["bot"], is_user=False, allow_html=True, key=f"bot_{idx}",seed="bot")
268
+
269
+ # Input area and buttons for user interaction
270
+ with st.form(key="chat_form", clear_on_submit=True,border=False):
271
+
272
+ col1, col2, col3 = st.columns([1, 1, 1])
273
+ with col1:
274
+ summary_button = st.form_submit_button("Generate Summary")
275
+ with col2:
276
+ extract_button = st.form_submit_button("Extract Tables and Figures")
277
+ with col3:
278
+ st.form_submit_button("Clear message", on_click=on_btn_click)
279
+
280
+ # Handle responses based on user input and button presses
281
+ if summary_button:
282
+ with st.spinner("Generating summary..."):
283
+ summary = summarize_pdf(file_path)
284
+ st.session_state.chat_history.append({"user": "Generate Summary", "bot": summary})
285
+ st.rerun()
286
+
287
+ if extract_button:
288
+ with st.spinner("Extracting tables and figures..."):
289
+ figures, tables = process_pdf(file_path)
290
+ if figures:
291
+ st.session_state.chat_history.append({"user": "Figures"})
292
+
293
+ for idx, figure in enumerate(figures):
294
+ figure_base64 = image_to_base64(figure)
295
+ result_html = f'<img src="data:image/png;base64,{figure_base64}" style="width:100%; display:block;" alt="Figure {idx+1}"/>'
296
+ st.session_state.chat_history.append({"bot": f"Figure {idx+1} {result_html}"})
297
+ if tables:
298
+ st.session_state.chat_history.append({"user": "Tables"})
299
+ for idx, table in enumerate(tables):
300
+ table_base64 = image_to_base64(table)
301
+ result_html = f'<img src="data:image/png;base64,{table_base64}" style="width:100%; display:block;" alt="Table {idx+1}"/>'
302
+ st.session_state.chat_history.append({"bot": f"Table {idx+1} {result_html}"})
303
+ st.rerun()
304
+
305
+ if user_input:
306
+ st.session_state.chat_history.append({"user": user_input, "bot": None})
307
+ with st.spinner("Processing..."):
308
+ answer = qa_pdf(file_path, user_input)
309
+ st.session_state.chat_history[-1]["bot"] = answer
310
+ st.rerun()
311
+
312
+ # Additional CSS and JavaScript to ensure the chat container is scrollable and scrolls to the bottom
313
  st.markdown("""
314
  <style>
315
+ #chat-container {
316
+ max-height: 500px;
317
+ overflow-y: auto;
318
+ padding: 1rem;
319
+ border: 1px solid #ddd;
320
+ border-radius: 8px;
321
+ background-color: #fefefe;
322
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
323
+ transition: background-color 0.3s ease;
324
  }
325
+ #chat-container:hover {
326
+ background-color: #f9f9f9;
 
 
 
 
 
 
 
 
327
  }
328
+ .stChatMessage {
329
+ padding: 0.75rem;
330
+ margin: 0.75rem 0;
331
+ border-radius: 8px;
332
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
333
+ transition: background-color 0.3s ease;
 
 
334
  }
335
+ .stChatMessage--user {
336
+ background-color: #E3F2FD;
 
 
 
 
 
 
 
 
 
337
  }
338
+ .stChatMessage--user:hover {
339
+ background-color: #BBDEFB;
 
340
  }
341
+ .stChatMessage--bot {
342
+ background-color: #EDE7F6;
 
 
 
 
 
 
 
 
 
 
 
 
343
  }
344
+ .stChatMessage--bot:hover {
345
+ background-color: #D1C4E9;
 
 
 
 
 
 
 
 
 
 
346
  }
347
+ textarea {
 
 
 
348
  width: 100%;
349
+ padding: 1rem;
350
+ border: 1px solid #ddd;
351
+ border-radius: 8px;
352
+ box-shadow: inset 0 1px 3px rgba(0, 0, 0, 0.1);
353
+ transition: border-color 0.3s ease, box-shadow 0.3s ease;
354
  }
355
+ textarea:focus {
356
+ border-color: #4CAF50;
357
+ box-shadow: 0 0 5px rgba(76, 175, 80, 0.5);
358
  }
359
+ .stButton > button {
360
+ width: 100%;
361
+ background-color: #4CAF50;
362
  color: white;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  border: none;
364
+ border-radius: 8px;
365
+ padding: 0.75rem;
366
+ font-size: 16px;
367
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
368
+ transition: background-color 0.3s ease, box-shadow 0.3s ease;
369
  }
370
+ .stButton > button:hover {
371
+ background-color: #45A049;
372
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  }
374
  </style>
375
+ <script>
376
+ const chatContainer = document.getElementById('chat-container');
377
+ chatContainer.scrollTop = chatContainer.scrollHeight;
378
+ </script>
379
  """, unsafe_allow_html=True)