zliang committed
Commit 8bbef17 • Parent(s): d44d458

Update app.py

Files changed (1)
app.py +445 -445
app.py CHANGED
@@ -1,446 +1,446 @@
import os
import streamlit as st
import numpy as np
import fitz  # PyMuPDF
from ultralytics import YOLO
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
from langchain_core.output_parsers import StrOutputParser
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from sklearn.decomposition import PCA
from langchain_openai import ChatOpenAI
import string
import re
- from termcolor import colored
+

# Load the trained model
model = YOLO("runs\\detect\\train7\\weights\\best.pt")
openai_api_key = os.environ.get("openai_api_key")

# Define the class indices for figures and tables
figure_class_index = 4  # class index for figures
table_class_index = 3  # class index for tables

# Global variables to store embeddings and contents
global_embeddings = None
global_split_contents = None

def clean_text(text):
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def remove_references(text):
    # Truncate everything from the first reference-section heading onward
    reference_patterns = [
        r'\bReferences\b', r'\breferences\b', r'\bBibliography\b', r'\bCitations\b',
        r'\bWorks Cited\b', r'\bReference\b', r'\breference\b'
    ]
    lines = text.split('\n')
    for i, line in enumerate(lines):
        if any(re.search(pattern, line, re.IGNORECASE) for pattern in reference_patterns):
            return '\n'.join(lines[:i])
    return text

def save_uploaded_file(uploaded_file):
    with open(uploaded_file.name, 'wb') as f:
        f.write(uploaded_file.getbuffer())
    return uploaded_file.name

def summarize_pdf(pdf_file_path, num_clusters=10):
    embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
    llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=openai_api_key, temperature=0.3)
    prompt = ChatPromptTemplate.from_template(
        """Could you please provide a concise and comprehensive summary of the given Contexts?
The summary should capture the main points and key details of the text while conveying the author's intended meaning accurately.
Please ensure that the summary is well-organized and easy to read, with clear headings and subheadings to guide the reader through each section.
The length of the summary should be appropriate to capture the main points and key details of the text, without including unnecessary information or becoming overly long.
example of summary:
## Summary:
## Key points:
Contexts: {topic}"""
    )
    output_parser = StrOutputParser()
    chain = prompt | llm | output_parser

    loader = PyMuPDFLoader(pdf_file_path)
    docs = loader.load()
    full_text = "\n".join(doc.page_content for doc in docs)
    cleaned_full_text = remove_references(full_text)
    cleaned_full_text = clean_text(cleaned_full_text)

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
    split_contents = text_splitter.split_text(cleaned_full_text)
    embeddings = embeddings_model.embed_documents(split_contents)

    # Cluster the chunk embeddings, then keep the chunk closest to each centroid
    X = np.array(embeddings)
    kmeans = KMeans(n_clusters=num_clusters, init='k-means++', random_state=0).fit(X)
    cluster_centers = kmeans.cluster_centers_

    closest_point_indices = []
    for center in cluster_centers:
        distances = np.linalg.norm(X - center, axis=1)
        closest_point_indices.append(np.argmin(distances))

    extracted_contents = [split_contents[idx] for idx in closest_point_indices]
    results = chain.invoke({"topic": ' '.join(extracted_contents)})

    # Match each summary sentence back to its closest source chunk by cosine similarity
    summary_sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', results)
    summary_embeddings = embeddings_model.embed_documents(summary_sentences)
    extracted_embeddings = embeddings_model.embed_documents(extracted_contents)
    similarity_matrix = cosine_similarity(summary_embeddings, extracted_embeddings)

    cited_results = results
    relevant_sources = []
    source_mapping = {}
    sentence_to_source = {}
    similarity_threshold = 0.6

    for i, sentence in enumerate(summary_sentences):
        if sentence in sentence_to_source:
            continue
        max_similarity = max(similarity_matrix[i])
        if max_similarity >= similarity_threshold:
            most_similar_idx = np.argmax(similarity_matrix[i])
            if most_similar_idx not in source_mapping:
                source_mapping[most_similar_idx] = len(relevant_sources) + 1
                relevant_sources.append((most_similar_idx, extracted_contents[most_similar_idx]))
            citation_idx = source_mapping[most_similar_idx]
            citation = f"([Source {citation_idx}](#source-{citation_idx}))"
            cited_sentence = re.sub(r'([.!?])$', f" {citation}\\1", sentence)
            sentence_to_source[sentence] = citation_idx
            cited_results = cited_results.replace(sentence, cited_sentence)

    sources_list = "\n\n## Sources:\n"
    for idx, (original_idx, content) in enumerate(relevant_sources):
        sources_list += f"""
<details style="margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px; background-color: #f9f9f9;">
<summary style="font-weight: bold; cursor: pointer;">Source {idx + 1}</summary>
<pre style="white-space: pre-wrap; word-wrap: break-word; margin-top: 10px;">{content}</pre>
</details>
"""
    cited_results += sources_list
    return cited_results

def qa_pdf(pdf_file_path, query, num_clusters=5, similarity_threshold=0.6):
    global global_embeddings, global_split_contents

    # Initialize models and embeddings
    embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small", api_key=openai_api_key)
    llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=openai_api_key, temperature=0.3)
    prompt = ChatPromptTemplate.from_template(
        """Please provide a detailed and accurate answer to the given question based on the provided contexts.
Ensure that the answer is comprehensive and directly addresses the query.
If necessary, include relevant examples or details from the text.
Question: {question}
Contexts: {contexts}"""
    )
    output_parser = StrOutputParser()
    chain = prompt | llm | output_parser

    # Load and process the PDF if not already loaded
    if global_embeddings is None or global_split_contents is None:
        loader = PyMuPDFLoader(pdf_file_path)
        docs = loader.load()
        full_text = "\n".join(doc.page_content for doc in docs)
        cleaned_full_text = remove_references(full_text)
        cleaned_full_text = clean_text(cleaned_full_text)

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=0, separators=["\n\n", "\n", ".", " "])
        global_split_contents = text_splitter.split_text(cleaned_full_text)
        global_embeddings = embeddings_model.embed_documents(global_split_contents)

    # Embed the query and find the most relevant contexts
    query_embedding = embeddings_model.embed_query(query)
    similarity_scores = cosine_similarity([query_embedding], global_embeddings)[0]
    top_indices = np.argsort(similarity_scores)[-num_clusters:]
    relevant_contents = [global_split_contents[i] for i in top_indices]

    # Generate the answer using the LLM chain
    results = chain.invoke({"question": query, "contexts": ' '.join(relevant_contents)})

    # Split the answer into sentences and embed them
    answer_sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', results)
    answer_embeddings = embeddings_model.embed_documents(answer_sentences)
    relevant_embeddings = embeddings_model.embed_documents(relevant_contents)
    similarity_matrix = cosine_similarity(answer_embeddings, relevant_embeddings)

    # Map sentences to sources and create citations
    cited_results = results
    relevant_sources = []
    source_mapping = {}
    sentence_to_source = {}

    for i, sentence in enumerate(answer_sentences):
        if sentence in sentence_to_source:
            continue
        max_similarity = max(similarity_matrix[i])
        if max_similarity >= similarity_threshold:
            most_similar_idx = np.argmax(similarity_matrix[i])
            if most_similar_idx not in source_mapping:
                source_mapping[most_similar_idx] = len(relevant_sources) + 1
                relevant_sources.append((most_similar_idx, relevant_contents[most_similar_idx]))
            citation_idx = source_mapping[most_similar_idx]
            citation = f"<strong style='color:blue;'>[Source {citation_idx}]</strong>"
            cited_sentence = re.sub(r'([.!?])$', f" {citation}\\1", sentence)
            sentence_to_source[sentence] = citation_idx
            cited_results = cited_results.replace(sentence, cited_sentence)

    # Format the sources for markdown rendering
    sources_list = "\n\n## Sources:\n"
    for idx, (original_idx, content) in enumerate(relevant_sources):
        sources_list += f"""
<details style="margin: 10px 0; padding: 10px; border: 1px solid #ccc; border-radius: 5px; background-color: #f9f9f9;">
<summary style="font-weight: bold; cursor: pointer;">Source {idx + 1}</summary>
<pre style="white-space: pre-wrap; word-wrap: break-word; margin-top: 10px;">{content}</pre>
</details>
"""
    cited_results += sources_list
    return cited_results


def infer_image_and_get_boxes(image, confidence_threshold=0.6):
    # Run YOLO layout detection and keep confident figure/table boxes
    results = model.predict(image)
    boxes = [
        (int(box.xyxy[0][0]), int(box.xyxy[0][1]), int(box.xyxy[0][2]), int(box.xyxy[0][3]), int(box.cls[0]))
        for result in results for box in result.boxes
        if int(box.cls[0]) in {figure_class_index, table_class_index} and box.conf[0] > confidence_threshold
    ]
    return boxes

def crop_images_from_boxes(image, boxes, scale_factor):
    # Boxes were detected on the low-DPI render, so scale them up to the high-DPI image
    figures = []
    tables = []
    for (x1, y1, x2, y2, cls) in boxes:
        cropped_img = image[int(y1 * scale_factor):int(y2 * scale_factor), int(x1 * scale_factor):int(x2 * scale_factor)]
        if cls == figure_class_index:
            figures.append(cropped_img)
        elif cls == table_class_index:
            tables.append(cropped_img)
    return figures, tables


def process_pdf(pdf_file_path):
    # Detect on cheap low-DPI renders, then re-render only pages with hits at high DPI
    doc = fitz.open(pdf_file_path)
    all_figures = []
    all_tables = []
    low_dpi = 50
    high_dpi = 300
    scale_factor = high_dpi / low_dpi
    low_res_pixmaps = [page.get_pixmap(dpi=low_dpi) for page in doc]

    for page_num, low_res_pix in enumerate(low_res_pixmaps):
        low_res_img = np.frombuffer(low_res_pix.samples, dtype=np.uint8).reshape(low_res_pix.height, low_res_pix.width, 3)
        boxes = infer_image_and_get_boxes(low_res_img)

        if boxes:
            high_res_pix = doc[page_num].get_pixmap(dpi=high_dpi)
            high_res_img = np.frombuffer(high_res_pix.samples, dtype=np.uint8).reshape(high_res_pix.height, high_res_pix.width, 3)
            figures, tables = crop_images_from_boxes(high_res_img, boxes, scale_factor)
            all_figures.extend(figures)
            all_tables.extend(tables)

    return all_figures, all_tables

# Set the page configuration for a modern look
st.set_page_config(page_title="PDF Reading Assistant", page_icon="📄", layout="wide")

# Add some custom CSS for a modern look
st.markdown("""
<style>
/* Main background and padding */
.main {
    background-color: #f8f9fa;
    padding: 2rem;
    font-family: 'Arial', sans-serif;
}

/* Section headers */
.section-header {
    font-size: 2rem;
    font-weight: bold;
    color: #343a40;
    margin-top: 2rem;
    margin-bottom: 1rem;
    text-align: center;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}

/* Containers */
.uploaded-file-container, .chat-container, .summary-container, .extract-container {
    padding: 2rem;
    background-color: #ffffff;
    border-radius: 10px;
    margin-bottom: 2rem;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
}

/* Buttons */
.stButton>button {
    background-color: #007bff;
    color: white;
    padding: 0.6rem 1.2rem;
    border-radius: 5px;
    border: none;
    cursor: pointer;
    font-size: 1rem;
    transition: background-color 0.3s ease, transform 0.3s ease;
}
.stButton>button:hover {
    background-color: #0056b3;
    transform: translateY(-2px);
}

/* Chat messages */
.chat-message {
    padding: 1rem;
    border-radius: 10px;
    margin-bottom: 1rem;
    font-size: 1rem;
    transition: all 0.3s ease;
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}
.chat-message.user {
    background-color: #e6f7ff;
    border-left: 5px solid #007bff;
    text-align: left;
}
.chat-message.bot {
    background-color: #fff0f1;
    border-left: 5px solid #dc3545;
    text-align: left;
}

/* Input area */
.input-container {
    display: flex;
    align-items: center;
    gap: 10px;
    margin-top: 1rem;
}
.input-container textarea {
    border: 2px solid #ccc;
    border-radius: 10px;
    padding: 10px;
    width: 100%;
    background-color: #fff;
    transition: border-color 0.3s ease;
    margin: 0;
    font-size: 1rem;
}
.input-container textarea:focus {
    border-color: #007bff;
    outline: none;
}
.input-container button {
    background-color: #007bff;
    color: white;
    padding: 0.6rem 1.2rem;
    border-radius: 5px;
    border: none;
    cursor: pointer;
    font-size: 1rem;
    transition: background-color 0.3s ease, transform 0.3s ease;
}
.input-container button:hover {
    background-color: #0056b3;
    transform: translateY(-2px);
}

/* Expander */
.st-expander {
    border: none;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
    margin-bottom: 2rem;
}

/* Markdown elements */
.stMarkdown {
    font-size: 1rem;
    color: #343a40;
    line-height: 1.6;
}

/* Titles and subtitles */
.stTitle {
    color: #343a40;
    text-align: center;
    margin-bottom: 1rem;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.stSubtitle {
    color: #6c757d;
    text-align: center;
    margin-bottom: 1rem;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
</style>
""", unsafe_allow_html=True)

# Streamlit interface
st.title("📄 PDF Reading Assistant")
st.markdown("### Extract tables, figures, summaries, and answers from your PDF files easily.")

uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
if uploaded_file:
    file_path = save_uploaded_file(uploaded_file)

    with st.container():
        st.markdown("<div class='section-header'>Extract Tables and Figures</div>", unsafe_allow_html=True)
        with st.expander("Click to Extract Tables and Figures", expanded=True):
            with st.container():
                extract_button = st.button("Extract")
                if extract_button:
                    figures, tables = process_pdf(file_path)
                    col1, col2 = st.columns(2)
                    with col1:
                        st.write("### Figures")
                        if figures:
                            for figure in figures:
                                st.image(figure, use_column_width=True)
                        else:
                            st.write("No figures found.")
                    with col2:
                        st.write("### Tables")
                        if tables:
                            for table in tables:
                                st.image(table, use_column_width=True)
                        else:
                            st.write("No tables found.")

    with st.container():
        st.markdown("<div class='section-header'>Get Summary</div>", unsafe_allow_html=True)
        with st.expander("Click to Generate Summary", expanded=True):
            with st.container():
                summary_button = st.button("Generate Summary")
                if summary_button:
                    summary = summarize_pdf(file_path)
                    st.markdown(summary, unsafe_allow_html=True)

    with st.container():
        st.markdown("<div class='section-header'>Chat with your PDF</div>", unsafe_allow_html=True)
        if 'chat_history' not in st.session_state:
            st.session_state['chat_history'] = []

        for chat in st.session_state['chat_history']:
            chat_user_class = "user" if chat["user"] else ""
            chat_bot_class = "bot" if chat["bot"] else ""
            st.markdown(f"<div class='chat-message {chat_user_class}'>{chat['user']}</div>", unsafe_allow_html=True)
            st.markdown(f"<div class='chat-message {chat_bot_class}'>{chat['bot']}</div>", unsafe_allow_html=True)

        with st.form(key="chat_form", clear_on_submit=True):
            user_input = st.text_area("Ask a question about the PDF:", key="user_input")
            submit_button = st.form_submit_button(label="Send")

        if submit_button and user_input:
            st.session_state['chat_history'].append({"user": user_input, "bot": None})
            answer = qa_pdf(file_path, user_input)
            st.session_state['chat_history'][-1]["bot"] = answer
            st.experimental_rerun()