fadliaulawi committed
Commit 099741f
Parent(s): 74be325

Restructure LLM calls

Files changed:
- app.py +41 -38
- process.py +186 -181
app.py
CHANGED
@@ -3,24 +3,14 @@ import os
 import pandas as pd
 import streamlit as st
 
+from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from langchain_community.document_loaders.pdf import PyPDFLoader
 from langchain_core.documents.base import Document
 from langchain_text_splitters import TokenTextSplitter
-from process import
+from process import Process
 from tempfile import NamedTemporaryFile
 from stqdm import stqdm
-from threading import Thread
-
-class CustomThread(Thread):
-    def __init__(self, func, chunk):
-        super().__init__()
-        self.func = func
-        self.chunk = chunk
-        self.result = ''
-
-    def run(self):
-        self.result = self.func(self.chunk)
 
 buffer = io.BytesIO()
 
@@ -34,18 +24,26 @@ uploaded_files = st.file_uploader("Upload Paper(s) here :", type="pdf", accept_m
 col1, col2 = st.columns(2)
 
 with col1:
+    models = (
+        'gpt-4-turbo',
+        'gemini-1.5-pro-latest'
+        # 'llama-3-sonar-large-32k-chat',
+        # 'mixtral-8x7b-instruct',
+    )
+    model = st.selectbox(
+        'Model selection:', models, key='model'
     )
-    chunk_overlap = 0
 
 with col2:
+    tokens = (
+        24000,
+        16000,
+        8000
     )
+    chunk_option = st.selectbox(
+        'Token amounts per process:', tokens, key='token'
+    )
+    chunk_overlap = 0
 
 if uploaded_files:
     journals = []
@@ -58,6 +56,8 @@ if uploaded_files:
     for uploaded_file in stqdm(uploaded_files):
         with NamedTemporaryFile(dir='.', suffix=".pdf", delete=eval(os.getenv('DELETE_TEMP_PDF', 'True'))) as pdf:
             pdf.write(uploaded_file.getbuffer())
+
+            # Load Documents
             loader = PyPDFLoader(pdf.name)
             pages = loader.load()
 
@@ -65,6 +65,7 @@ if uploaded_files:
             chunk_overlap = 0
             docs = pages
 
+            # Split Documents
            if chunk_option:
                docs = [Document('\n'.join([page.page_content for page in pages]))]
                docs[0].metadata = {'source': pages[0].metadata['source']}
@@ -77,23 +78,22 @@ if uploaded_files:
             )
             chunks = text_splitter.split_documents(docs)
 
-            # Combine
+            # Start extraction process in parallel
+            process = Process(model)
+            with ThreadPoolExecutor() as executor:
+                result_gsd = executor.submit(process.get_entity, (chunks, 'gsd'))
+                result_summ = executor.submit(process.get_entity, (chunks, 'summ'))
+                result = executor.submit(process.get_entity, (chunks, 'all'))
+                result_one = executor.submit(process.get_entity_one, [c.page_content for c in chunks[:1]])
+                result_table = executor.submit(process.get_table, pdf.name)
+
+            result_gsd = result_gsd.result()
+            result_summ = result_summ.result()
+            result = result.result()
+            result_one = result_one.result()
+            res_gene, res_snp, res_dis = result_table.result()
+
+            # Combine Result
             result['Genes'] = res_gene + result_gsd['Genes']
             result['SNPs'] = res_snp + result_gsd['SNPs']
             result['Diseases'] = res_dis + result_gsd['Diseases']
@@ -119,9 +119,12 @@ if uploaded_files:
 
             dataframe = pd.DataFrame(result)
             dataframe = dataframe[['Genes', 'SNPs', 'Diseases', 'Title', 'Authors', 'Publisher Name', 'Publication Year', 'Population', 'Sample Size', 'Study Methodology', 'Study Level', 'Conclusion']]
+            dataframe = dataframe[dataframe['Genes'].astype(bool)].reset_index(drop=True)
             dataframe.drop_duplicates(['Genes', 'SNPs'], inplace=True)
             dataframe.reset_index(drop=True, inplace=True)
+
+            # Validate Result
+            cleaned_df, cleaned_llm_df = process.validate(dataframe)
 
             end_time = datetime.now()
             st.write("Success in ", round((end_time.timestamp() - start_time.timestamp()) / 60, 2), "minutes")
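A note on the app.py change above: the removed CustomThread helper stored each worker's return value by hand, while concurrent.futures hands it back (and re-raises any worker exception) through Future.result(). Below is a minimal, self-contained sketch of that fan-out/join pattern; slow_task is a placeholder worker for illustration only, not one of the real Process methods.

from concurrent.futures import ThreadPoolExecutor

def slow_task(name):
    # Placeholder for an I/O-bound call such as an LLM request or an OCR pass.
    return 'done: ' + name

with ThreadPoolExecutor() as executor:
    # submit() schedules the call on a pool thread and returns a Future immediately.
    future_a = executor.submit(slow_task, 'gsd')
    future_b = executor.submit(slow_task, 'summ')

# result() blocks until the task has finished and re-raises any exception it raised.
print(future_a.result(), future_b.result())

In app.py the same shape is used with Process.get_entity, Process.get_entity_one and Process.get_table as the submitted callables.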
process.py
CHANGED
@@ -6,12 +6,12 @@ from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.prompts import PromptTemplate
+from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_openai import ChatOpenAI
 from pdf2image import convert_from_path
 from prompt import prompt_entity_gsd_chunk, prompt_entity_gsd_combine, prompt_entity_summ_chunk, prompt_entity_summ_combine, prompt_entities_chunk, prompt_entities_combine, prompt_entity_one_chunk, prompt_table, prompt_validation
 from table_detector import detection_transform, device, model, ocr, outputs_to_objects
 
-import google.generativeai as genai
 import io
 import json
 import os
@@ -20,11 +20,6 @@ import re
 import torch
 
 load_dotenv()
-genai.configure(api_key=os.environ['GOOGLE_API_KEY'])
-
-llm = ChatOpenAI(temperature=0, model_name="gpt-4-turbo")
-llm_p = ChatOpenAI(temperature=0, model_name="llama-3-sonar-large-32k-chat", api_key=os.environ['PERPLEXITY_API_KEY'], base_url="https://api.perplexity.ai")
-llm_g = genai.GenerativeModel(model_name='gemini-1.5-pro-latest')
 
 prompts = {
     'gsd': [prompt_entity_gsd_chunk, prompt_entity_gsd_combine],
@@ -32,211 +27,221 @@ prompts = {
     'all': [prompt_entities_chunk, prompt_entities_combine]
 }
 
-    reduce_prompt = PromptTemplate.from_template(reduce_template)
-    reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
-
-    if types != 'summ':
-        result = re.findall('(\{[^}]+\})', result)[0]
-        return eval(result)
-
-    return result
-
-    images = convert_from_path(path)
-    print('PDF to Image', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
-    tables = []
-
-            if detected_tables[idx]["label"] == 'table rotated':
-                cropped_table = cropped_table.rotate(270, expand=True)
-
-            res_snp = res['SNPs']
-            res_disease = res['Diseases']
-
-            for snp in res_snp:
-                genes.append(res_gene)
-                snps.append(snp)
-                diseases.append(res_disease)
-
-    print('OCR table to extract', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
-    print(genes, snps, diseases)
-
-    return genes, snps, diseases
-
-def validate(df):
-
-    df = df[df['Genes'].notna()].reset_index(drop=True)
-    df = df.fillna('')
-    df['Genes'] = df['Genes'].str.upper()
-    df['SNPs'] = df['SNPs'].str.lower()
-
-    # Check if there is two gene names
-    sym = ['-', '/', '|']
-    for i in df.index:
-        gene = df.loc[i, 'Genes']
-        for s in sym:
-            if s in gene:
-                genes = gene.split(s)
-                df.loc[i + 0.5] = df.loc[i]
-                df = df.sort_index().reset_index(drop=True)
-                df.loc[i, 'Genes'], df.loc[i + 1, 'Genes'] = genes[0], genes[1]
-
-    # Check if there is SNPs without 'rs'
-    for i in df.index:
-        safe = True
-        snp = df.loc[i, 'SNPs']
-        if re.fullmatch('rs(\d)+|', snp):
-            pass
-        elif re.fullmatch('ts(\d)+', snp):
-            snp = 'r' + snp[1:]
-        elif re.fullmatch('s(\d)+', snp):
-            snp = 'r' + snp
-        elif re.fullmatch('(\d)+', snp):
-            snp = 'rs' + snp
-        else:
-            safe = False
-            df = df.drop(i)
-
-        if safe:
-            df.loc[i, 'SNPs'] = snp
-
-        df.reset_index(drop=True, inplace=True)
-
-    # Validate genes and diseases with LLM
-    json_table = df[['Genes', 'SNPs', 'Diseases']].to_json(orient='records')
-    str_json_table = json.dumps(json.loads(json_table), indent=2)
-
-    result = llm_p.invoke(input=prompt_validation.format(str_json_table)).content
-    print('val')
-    print(result)
-
-    result = result[result.find('['):result.rfind(']')+1]
-    try:
-        result = eval(result)
-    except SyntaxError:
-        result = []
-
-    df_val = pd.DataFrame(result)
-    df_val = df_val.merge(df.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1), 'cross')
+class Process():
+
+    def __init__(self, llm):
+
+        if llm.startswith('gpt'):
+            self.llm = ChatOpenAI(temperature=0, model_name=llm)
+        elif llm.startswith('gemini'):
+            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
+        else:
+            self.llm = ChatOpenAI(temperature=0, model_name=llm, api_key=os.environ['PERPLEXITY_API_KEY'], base_url="https://api.perplexity.ai")
+
+    def get_entity(self, data):
+
+        chunks, types = data
+
+        map_template = prompts[types][0]
+        map_prompt = PromptTemplate.from_template(map_template)
+        map_chain = LLMChain(llm=self.llm, prompt=map_prompt)
+
+        reduce_template = prompts[types][1]
+        reduce_prompt = PromptTemplate.from_template(reduce_template)
+        reduce_chain = LLMChain(llm=self.llm, prompt=reduce_prompt)
+
+        combine_chain = StuffDocumentsChain(
+            llm_chain=reduce_chain, document_variable_name="doc_summaries"
+        )
+
+        reduce_documents_chain = ReduceDocumentsChain(
+            combine_documents_chain=combine_chain,
+            collapse_documents_chain=combine_chain,
+            token_max=100000,
+        )
+
+        map_reduce_chain = MapReduceDocumentsChain(
+            llm_chain=map_chain,
+            reduce_documents_chain=reduce_documents_chain,
+            document_variable_name="docs",
+            return_intermediate_steps=False,
+        )
+
+        result = map_reduce_chain.invoke(chunks)['output_text']
+        print(types)
+        print(result)
+        if types != 'summ':
+            result = re.findall('(\{[^}]+\})', result)[0]
+            return eval(result)
+
+        return result
+
+    def get_entity_one(self, chunks):
+
+        result = self.llm.invoke(prompt_entity_one_chunk.format(chunks)).content
+
+        print('One')
+        print(result)
+        result = re.findall('(\{[^}]+\})', result)[0]
+
+        return eval(result)
+
+    def get_table(self, path):
+
+        start_time = datetime.now()
+        images = convert_from_path(path)
+        print('PDF to Image', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
+        tables = []
+
+        # Loop pages
+        for image in images:
+
+            pixel_values = detection_transform(image).unsqueeze(0).to(device)
+            with torch.no_grad():
+                outputs = model(pixel_values)
+
+            id2label = model.config.id2label
+            id2label[len(model.config.id2label)] = "no object"
+            detected_tables = outputs_to_objects(outputs, image.size, id2label)
+
+            # Loop table in page (if any)
+            for idx in range(len(detected_tables)):
+                cropped_table = image.crop(detected_tables[idx]["bbox"])
+                if detected_tables[idx]["label"] == 'table rotated':
+                    cropped_table = cropped_table.rotate(270, expand=True)
+
+                # TODO: what is the perfect threshold?
+                if detected_tables[idx]['score'] > 0.9:
+                    print(detected_tables[idx])
+                    tables.append(cropped_table)
+
+        print('Detect table from image', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
+        genes = []
+        snps = []
+        diseases = []
+
+        # Loop tables
+        for table in tables:
+
+            buffer = io.BytesIO()
+            table.save(buffer, format='PNG')
+            image = Image(buffer)
+
+            # Extract to dataframe
+            extracted_tables = image.extract_tables(ocr=ocr, implicit_rows=True, borderless_tables=True, min_confidence=0)
+
+            if len(extracted_tables) == 0:
+                continue
+
+            # Combine multiple dataframe
+            df_table = extracted_tables[0].df
+            for extracted_table in extracted_tables[1:]:
+                df_table = pd.concat([df_table, extracted_table.df]).reset_index(drop=True)
+
+            df_table.loc[0] = df_table.loc[0].fillna('')
+
+            # Identify multiple rows (in dataframe) as one row (in image)
+            rows = []
+            indexes = []
+            for i in df_table.index:
+                if not df_table.loc[i].isna().any():
+                    if len(indexes) > 0:
+                        rows.append(indexes)
+                    indexes = []
+                indexes.append(i)
+            rows.append(indexes)
+
+            df_table_cleaned = pd.DataFrame(columns=df_table.columns)
+            for row in rows:
+                row_str = df_table.loc[row[0]]
+                for idx in row[1:]:
+                    row_str += ' ' + df_table.loc[idx].fillna('')
+                row_str = row_str.str.strip()
+                df_table_cleaned.loc[len(df_table_cleaned)] = row_str
+
+            # Ask LLM with JSON data
+            json_table = df_table_cleaned.to_json(orient='records')
+            str_json_table = json.dumps(json.loads(json_table), indent=2)
+
+            result = self.llm.invoke(prompt_table.format(str_json_table)).content
+            print('table')
+            print(result)
+            result = result[result.find('['):result.rfind(']')+1]
+            try:
+                result = eval(result)
+            except SyntaxError:
+                result = []
+
+            for res in result:
+                res_gene = res['Genes']
+                res_snp = res['SNPs']
+                res_disease = res['Diseases']
+
+                for snp in res_snp:
+                    genes.append(res_gene)
+                    snps.append(snp)
+                    diseases.append(res_disease)
+
+        print('OCR table to extract', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
+        print(genes, snps, diseases)
+
+        return genes, snps, diseases
+
+    def validate(self, df):
+
+        df = df.fillna('')
+        df['Genes'] = df['Genes'].str.upper()
+        df['SNPs'] = df['SNPs'].str.lower()
+
+        # Check if there is two gene names
+        sym = ['-', '/', '|']
+        for i in df.index:
+            gene = df.loc[i, 'Genes']
+            for s in sym:
+                if s in gene:
+                    genes = gene.split(s)
+                    df.loc[i + 0.5] = df.loc[i]
+                    df = df.sort_index().reset_index(drop=True)
+                    df.loc[i, 'Genes'], df.loc[i + 1, 'Genes'] = genes[0], genes[1]
+
+        # Check if there is SNPs without 'rs'
+        for i in df.index:
+            safe = True
+            snp = df.loc[i, 'SNPs']
+            if re.fullmatch('rs(\d)+|', snp):
+                pass
+            elif re.fullmatch('ts(\d)+', snp):
+                snp = 'r' + snp[1:]
+            elif re.fullmatch('s(\d)+', snp):
+                snp = 'r' + snp
+            elif re.fullmatch('(\d)+', snp):
+                snp = 'rs' + snp
+            else:
+                safe = False
+                df = df.drop(i)
+
+            if safe:
+                df.loc[i, 'SNPs'] = snp
+
+            df.reset_index(drop=True, inplace=True)
+
+        # Validate genes and diseases with LLM
+        json_table = df[['Genes', 'SNPs', 'Diseases']].to_json(orient='records')
+        str_json_table = json.dumps(json.loads(json_table), indent=2)
+
+        result = self.llm.invoke(input=prompt_validation.format(str_json_table)).content
+        print('val')
+        print(result)
+
+        result = result[result.find('['):result.rfind(']')+1]
+        try:
+            result = eval(result)
+        except SyntaxError:
+            result = []
+
+        df_val = pd.DataFrame(result)
+        df_val = df_val.merge(df.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1), 'cross')
+
+        # TODO: How to validate genes and SNPs with ground truth?
+
+        return df, df_val
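As a usage note on the restructured process.py: the Process constructor routes the selected model name to an LLM client, and the extraction entry points are ordinary methods, so the class can be exercised outside Streamlit. Below is a rough sketch, assuming the repository's environment variables (OPENAI_API_KEY, GOOGLE_API_KEY, PERPLEXITY_API_KEY) are set as in its .env; 'paper.pdf' is a placeholder path, not a file from the repo.

from process import Process

process = Process('gpt-4-turbo')             # 'gpt*'    -> ChatOpenAI
# Process('gemini-1.5-pro-latest')           # 'gemini*' -> ChatGoogleGenerativeAI
# Process('llama-3-sonar-large-32k-chat')    # anything else -> ChatOpenAI against api.perplexity.ai

genes, snps, diseases = process.get_table('paper.pdf')  # table detection + OCR + LLM parsing

Note that get_entity takes a single (chunks, types) tuple rather than two separate arguments, which is why app.py calls executor.submit(process.get_entity, (chunks, 'gsd')) with the pair packed into one positional argument.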