fadliaulawi committed on
Commit
099741f
1 Parent(s): 74be325

Restructure LLM calls

Browse files
Files changed (2) hide show
  1. app.py +41 -38
  2. process.py +186 -181
app.py CHANGED
@@ -3,24 +3,14 @@ import os
3
  import pandas as pd
4
  import streamlit as st
5
 
 
6
  from datetime import datetime
7
  from langchain_community.document_loaders.pdf import PyPDFLoader
8
  from langchain_core.documents.base import Document
9
  from langchain_text_splitters import TokenTextSplitter
10
- from process import get_entity, get_entity_one, get_table, validate
11
  from tempfile import NamedTemporaryFile
12
  from stqdm import stqdm
13
- from threading import Thread
14
-
15
- class CustomThread(Thread):
16
- def __init__(self, func, chunk):
17
- super().__init__()
18
- self.func = func
19
- self.chunk = chunk
20
- self.result = ''
21
-
22
- def run(self):
23
- self.result = self.func(self.chunk)
24
 
25
  buffer = io.BytesIO()
26
 
@@ -34,18 +24,26 @@ uploaded_files = st.file_uploader("Upload Paper(s) here :", type="pdf", accept_m
34
  col1, col2 = st.columns(2)
35
 
36
  with col1:
37
- chunk_option = st.selectbox(
38
- 'Token amounts per process:',
39
- (24000, 16000, 8000), key='token'
 
 
 
 
 
40
  )
41
- chunk_overlap = 0
42
 
43
  with col2:
44
- model = st.selectbox(
45
- 'Model selection: (UNDER DEVELOPED)',
46
- # 128000, 32768, 1048576
47
- ('gpt-4-turbo', 'llama-3-sonar-large-32k-chat', 'gemini-1.5-pro-latest'), key='model'
48
  )
 
 
 
 
49
 
50
  if uploaded_files:
51
  journals = []
@@ -58,6 +56,8 @@ if uploaded_files:
58
  for uploaded_file in stqdm(uploaded_files):
59
  with NamedTemporaryFile(dir='.', suffix=".pdf", delete=eval(os.getenv('DELETE_TEMP_PDF', 'True'))) as pdf:
60
  pdf.write(uploaded_file.getbuffer())
 
 
61
  loader = PyPDFLoader(pdf.name)
62
  pages = loader.load()
63
 
@@ -65,6 +65,7 @@ if uploaded_files:
65
  chunk_overlap = 0
66
  docs = pages
67
 
 
68
  if chunk_option:
69
  docs = [Document('\n'.join([page.page_content for page in pages]))]
70
  docs[0].metadata = {'source': pages[0].metadata['source']}
@@ -77,23 +78,22 @@ if uploaded_files:
77
  )
78
  chunks = text_splitter.split_documents(docs)
79
 
80
- threads = []
81
- threads.append(CustomThread(get_entity, (chunks, 'gsd')))
82
- threads.append(CustomThread(get_entity, (chunks, 'summ')))
83
- threads.append(CustomThread(get_entity, (chunks, 'all')))
84
- threads.append(CustomThread(get_entity_one, [c.page_content for c in chunks[:1]]))
85
- threads.append(CustomThread(get_table, pdf.name))
86
-
87
- [t.start() for t in threads]
88
- [t.join() for t in threads]
89
-
90
- result_gsd = threads[0].result
91
- result_summ = threads[1].result
92
- result = threads[2].result
93
- result_one = threads[3].result
94
- res_gene, res_snp, res_dis = threads[4].result
95
-
96
- # Combine
97
  result['Genes'] = res_gene + result_gsd['Genes']
98
  result['SNPs'] = res_snp + result_gsd['SNPs']
99
  result['Diseases'] = res_dis + result_gsd['Diseases']
@@ -119,9 +119,12 @@ if uploaded_files:
119
 
120
  dataframe = pd.DataFrame(result)
121
  dataframe = dataframe[['Genes', 'SNPs', 'Diseases', 'Title', 'Authors', 'Publisher Name', 'Publication Year', 'Population', 'Sample Size', 'Study Methodology', 'Study Level', 'Conclusion']]
 
122
  dataframe.drop_duplicates(['Genes', 'SNPs'], inplace=True)
123
  dataframe.reset_index(drop=True, inplace=True)
124
- cleaned_df, cleaned_llm_df = validate(dataframe)
 
 
125
 
126
  end_time = datetime.now()
127
  st.write("Success in ", round((end_time.timestamp() - start_time.timestamp()) / 60, 2), "minutes")
 
3
  import pandas as pd
4
  import streamlit as st
5
 
6
+ from concurrent.futures import ThreadPoolExecutor
7
  from datetime import datetime
8
  from langchain_community.document_loaders.pdf import PyPDFLoader
9
  from langchain_core.documents.base import Document
10
  from langchain_text_splitters import TokenTextSplitter
11
+ from process import Process
12
  from tempfile import NamedTemporaryFile
13
  from stqdm import stqdm
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  buffer = io.BytesIO()
16
 
 
24
  col1, col2 = st.columns(2)
25
 
26
  with col1:
27
+ models = (
28
+ 'gpt-4-turbo',
29
+ 'gemini-1.5-pro-latest'
30
+ # 'llama-3-sonar-large-32k-chat',
31
+ # 'mixtral-8x7b-instruct',
32
+ )
33
+ model = st.selectbox(
34
+ 'Model selection:', models, key='model'
35
  )
 
36
 
37
  with col2:
38
+ tokens = (
39
+ 24000,
40
+ 16000,
41
+ 8000
42
  )
43
+ chunk_option = st.selectbox(
44
+ 'Token amounts per process:', tokens, key='token'
45
+ )
46
+ chunk_overlap = 0
47
 
48
  if uploaded_files:
49
  journals = []
 
56
  for uploaded_file in stqdm(uploaded_files):
57
  with NamedTemporaryFile(dir='.', suffix=".pdf", delete=eval(os.getenv('DELETE_TEMP_PDF', 'True'))) as pdf:
58
  pdf.write(uploaded_file.getbuffer())
59
+
60
+ # Load Documents
61
  loader = PyPDFLoader(pdf.name)
62
  pages = loader.load()
63
 
 
65
  chunk_overlap = 0
66
  docs = pages
67
 
68
+ # Split Documents
69
  if chunk_option:
70
  docs = [Document('\n'.join([page.page_content for page in pages]))]
71
  docs[0].metadata = {'source': pages[0].metadata['source']}
 
78
  )
79
  chunks = text_splitter.split_documents(docs)
80
 
81
+ # Start extraction process in parallel
82
+ process = Process(model)
83
+ with ThreadPoolExecutor() as executor:
84
+ result_gsd = executor.submit(process.get_entity, (chunks, 'gsd'))
85
+ result_summ = executor.submit(process.get_entity, (chunks, 'summ'))
86
+ result = executor.submit(process.get_entity, (chunks, 'all'))
87
+ result_one = executor.submit(process.get_entity_one, [c.page_content for c in chunks[:1]])
88
+ result_table = executor.submit(process.get_table, pdf.name)
89
+
90
+ result_gsd = result_gsd.result()
91
+ result_summ = result_summ.result()
92
+ result = result.result()
93
+ result_one = result_one.result()
94
+ res_gene, res_snp, res_dis = result_table.result()
95
+
96
+ # Combine Result
 
97
  result['Genes'] = res_gene + result_gsd['Genes']
98
  result['SNPs'] = res_snp + result_gsd['SNPs']
99
  result['Diseases'] = res_dis + result_gsd['Diseases']
 
119
 
120
  dataframe = pd.DataFrame(result)
121
  dataframe = dataframe[['Genes', 'SNPs', 'Diseases', 'Title', 'Authors', 'Publisher Name', 'Publication Year', 'Population', 'Sample Size', 'Study Methodology', 'Study Level', 'Conclusion']]
122
+ dataframe = dataframe[dataframe['Genes'].astype(bool)].reset_index(drop=True)
123
  dataframe.drop_duplicates(['Genes', 'SNPs'], inplace=True)
124
  dataframe.reset_index(drop=True, inplace=True)
125
+
126
+ # Validate Result
127
+ cleaned_df, cleaned_llm_df = process.validate(dataframe)
128
 
129
  end_time = datetime.now()
130
  st.write("Success in ", round((end_time.timestamp() - start_time.timestamp()) / 60, 2), "minutes")
process.py CHANGED
@@ -6,12 +6,12 @@ from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
6
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
7
  from langchain.chains.llm import LLMChain
8
  from langchain.prompts import PromptTemplate
 
9
  from langchain_openai import ChatOpenAI
10
  from pdf2image import convert_from_path
11
  from prompt import prompt_entity_gsd_chunk, prompt_entity_gsd_combine, prompt_entity_summ_chunk, prompt_entity_summ_combine, prompt_entities_chunk, prompt_entities_combine, prompt_entity_one_chunk, prompt_table, prompt_validation
12
  from table_detector import detection_transform, device, model, ocr, outputs_to_objects
13
 
14
- import google.generativeai as genai
15
  import io
16
  import json
17
  import os
@@ -20,11 +20,6 @@ import re
20
  import torch
21
 
22
  load_dotenv()
23
- genai.configure(api_key=os.environ['GOOGLE_API_KEY'])
24
-
25
- llm = ChatOpenAI(temperature=0, model_name="gpt-4-turbo")
26
- llm_p = ChatOpenAI(temperature=0, model_name="llama-3-sonar-large-32k-chat", api_key=os.environ['PERPLEXITY_API_KEY'], base_url="https://api.perplexity.ai")
27
- llm_g = genai.GenerativeModel(model_name='gemini-1.5-pro-latest')
28
 
29
  prompts = {
30
  'gsd': [prompt_entity_gsd_chunk, prompt_entity_gsd_combine],
@@ -32,211 +27,221 @@ prompts = {
32
  'all': [prompt_entities_chunk, prompt_entities_combine]
33
  }
34
 
35
- def get_entity(data):
36
 
37
- chunks, types = data
38
 
39
- map_template = prompts[types][0]
40
- map_prompt = PromptTemplate.from_template(map_template)
41
- map_chain = LLMChain(llm=llm, prompt=map_prompt)
 
 
 
42
 
43
- reduce_template = prompts[types][1]
44
- reduce_prompt = PromptTemplate.from_template(reduce_template)
45
- reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt)
46
 
47
- combine_chain = StuffDocumentsChain(
48
- llm_chain=reduce_chain, document_variable_name="doc_summaries"
49
- )
50
 
51
- reduce_documents_chain = ReduceDocumentsChain(
52
- combine_documents_chain=combine_chain,
53
- collapse_documents_chain=combine_chain,
54
- token_max=100000,
55
- )
56
 
57
- map_reduce_chain = MapReduceDocumentsChain(
58
- llm_chain=map_chain,
59
- reduce_documents_chain=reduce_documents_chain,
60
- document_variable_name="docs",
61
- return_intermediate_steps=False,
62
- )
63
 
64
- result = map_reduce_chain.invoke(chunks)['output_text']
65
- print(types)
66
- print(result)
67
- if types != 'summ':
68
- result = re.findall('(\{[^}]+\})', result)[0]
69
- return eval(result)
70
-
71
- return result
72
 
73
- def get_entity_one(chunks):
 
 
 
 
74
 
75
- result = llm.invoke(prompt_entity_one_chunk.format(chunks)).content
 
 
 
 
 
76
 
77
- print('One')
78
- print(result)
79
- result = re.findall('(\{[^}]+\})', result)[0]
 
 
 
 
 
80
 
81
- return eval(result)
82
 
83
- def get_table(path):
 
 
 
 
 
 
84
 
85
- start_time = datetime.now()
86
- images = convert_from_path(path)
87
- print('PDF to Image', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
88
- tables = []
89
 
90
- # Loop pages
91
- for image in images:
 
 
92
 
93
- pixel_values = detection_transform(image).unsqueeze(0).to(device)
94
- with torch.no_grad():
95
- outputs = model(pixel_values)
96
 
97
- id2label = model.config.id2label
98
- id2label[len(model.config.id2label)] = "no object"
99
- detected_tables = outputs_to_objects(outputs, image.size, id2label)
100
 
101
- # Loop table in page (if any)
102
- for idx in range(len(detected_tables)):
103
- cropped_table = image.crop(detected_tables[idx]["bbox"])
104
- if detected_tables[idx]["label"] == 'table rotated':
105
- cropped_table = cropped_table.rotate(270, expand=True)
106
 
107
- # TODO: what is the perfect threshold?
108
- if detected_tables[idx]['score'] > 0.9:
109
- print(detected_tables[idx])
110
- tables.append(cropped_table)
 
111
 
112
- print('Detect table from image', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
113
- genes = []
114
- snps = []
115
- diseases = []
116
 
117
- # Loop tables
118
- for table in tables:
 
 
119
 
120
- buffer = io.BytesIO()
121
- table.save(buffer, format='PNG')
122
- image = Image(buffer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
- # Extract to dataframe
125
- extracted_tables = image.extract_tables(ocr=ocr, implicit_rows=True, borderless_tables=True, min_confidence=0)
126
-
127
- if len(extracted_tables) == 0:
128
- continue
129
-
130
- # Combine multiple dataframe
131
- df_table = extracted_tables[0].df
132
- for extracted_table in extracted_tables[1:]:
133
- df_table = pd.concat([df_table, extracted_table.df]).reset_index(drop=True)
134
-
135
- df_table.loc[0] = df_table.loc[0].fillna('')
136
-
137
- # Identify multiple rows (in dataframe) as one row (in image)
138
- rows = []
139
- indexes = []
140
- for i in df_table.index:
141
- if not df_table.loc[i].isna().any():
142
- if len(indexes) > 0:
143
- rows.append(indexes)
144
- indexes = []
145
- indexes.append(i)
146
- rows.append(indexes)
147
-
148
- df_table_cleaned = pd.DataFrame(columns=df_table.columns)
149
- for row in rows:
150
- row_str = df_table.loc[row[0]]
151
- for idx in row[1:]:
152
- row_str += ' ' + df_table.loc[idx].fillna('')
153
- row_str = row_str.str.strip()
154
- df_table_cleaned.loc[len(df_table_cleaned)] = row_str
155
-
156
- # Ask LLM with JSON data
157
- json_table = df_table_cleaned.to_json(orient='records')
 
 
 
 
 
 
 
 
158
  str_json_table = json.dumps(json.loads(json_table), indent=2)
159
 
160
- result = llm.invoke(prompt_table.format(str_json_table)).content
161
- print('table')
162
  print(result)
 
163
  result = result[result.find('['):result.rfind(']')+1]
164
  try:
165
  result = eval(result)
166
  except SyntaxError:
167
  result = []
168
 
169
- for res in result:
170
- res_gene = res['Genes']
171
- res_snp = res['SNPs']
172
- res_disease = res['Diseases']
173
-
174
- for snp in res_snp:
175
- genes.append(res_gene)
176
- snps.append(snp)
177
- diseases.append(res_disease)
178
-
179
- print('OCR table to extract', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
180
- print(genes, snps, diseases)
181
-
182
- return genes, snps, diseases
183
-
184
- def validate(df):
185
-
186
- df = df[df['Genes'].notna()].reset_index(drop=True)
187
- df = df.fillna('')
188
- df['Genes'] = df['Genes'].str.upper()
189
- df['SNPs'] = df['SNPs'].str.lower()
190
-
191
- # Check if there is two gene names
192
- sym = ['-', '/', '|']
193
- for i in df.index:
194
- gene = df.loc[i, 'Genes']
195
- for s in sym:
196
- if s in gene:
197
- genes = gene.split(s)
198
- df.loc[i + 0.5] = df.loc[i]
199
- df = df.sort_index().reset_index(drop=True)
200
- df.loc[i, 'Genes'], df.loc[i + 1, 'Genes'] = genes[0], genes[1]
201
-
202
- # Check if there is SNPs without 'rs'
203
- for i in df.index:
204
- safe = True
205
- snp = df.loc[i, 'SNPs']
206
- if re.fullmatch('rs(\d)+|', snp):
207
- pass
208
- elif re.fullmatch('ts(\d)+', snp):
209
- snp = 'r' + snp[1:]
210
- elif re.fullmatch('s(\d)+', snp):
211
- snp = 'r' + snp
212
- elif re.fullmatch('(\d)+', snp):
213
- snp = 'rs' + snp
214
- else:
215
- safe = False
216
- df = df.drop(i)
217
-
218
- if safe:
219
- df.loc[i, 'SNPs'] = snp
220
-
221
- df.reset_index(drop=True, inplace=True)
222
-
223
- # Validate genes and diseases with LLM
224
- json_table = df[['Genes', 'SNPs', 'Diseases']].to_json(orient='records')
225
- str_json_table = json.dumps(json.loads(json_table), indent=2)
226
-
227
- result = llm_p.invoke(input=prompt_validation.format(str_json_table)).content
228
- print('val')
229
- print(result)
230
-
231
- result = result[result.find('['):result.rfind(']')+1]
232
- try:
233
- result = eval(result)
234
- except SyntaxError:
235
- result = []
236
-
237
- df_val = pd.DataFrame(result)
238
- df_val = df_val.merge(df.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1), 'cross')
239
 
240
- # TODO: How to validate genes and SNPs with ground truth?
241
 
242
- return df, df_val
 
6
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
7
  from langchain.chains.llm import LLMChain
8
  from langchain.prompts import PromptTemplate
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
10
  from langchain_openai import ChatOpenAI
11
  from pdf2image import convert_from_path
12
  from prompt import prompt_entity_gsd_chunk, prompt_entity_gsd_combine, prompt_entity_summ_chunk, prompt_entity_summ_combine, prompt_entities_chunk, prompt_entities_combine, prompt_entity_one_chunk, prompt_table, prompt_validation
13
  from table_detector import detection_transform, device, model, ocr, outputs_to_objects
14
 
 
15
  import io
16
  import json
17
  import os
 
20
  import torch
21
 
22
  load_dotenv()
 
 
 
 
 
23
 
24
  prompts = {
25
  'gsd': [prompt_entity_gsd_chunk, prompt_entity_gsd_combine],
 
27
  'all': [prompt_entities_chunk, prompt_entities_combine]
28
  }
29
 
30
+ class Process():
31
 
32
+ def __init__(self, llm):
33
 
34
+ if llm.startswith('gpt'):
35
+ self.llm = ChatOpenAI(temperature=0, model_name=llm)
36
+ elif llm.startswith('gemini'):
37
+ self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
38
+ else:
39
+ self.llm = ChatOpenAI(temperature=0, model_name=llm, api_key=os.environ['PERPLEXITY_API_KEY'], base_url="https://api.perplexity.ai")
40
 
41
+ def get_entity(self, data):
 
 
42
 
43
+ chunks, types = data
 
 
44
 
45
+ map_template = prompts[types][0]
46
+ map_prompt = PromptTemplate.from_template(map_template)
47
+ map_chain = LLMChain(llm=self.llm, prompt=map_prompt)
 
 
48
 
49
+ reduce_template = prompts[types][1]
50
+ reduce_prompt = PromptTemplate.from_template(reduce_template)
51
+ reduce_chain = LLMChain(llm=self.llm, prompt=reduce_prompt)
 
 
 
52
 
53
+ combine_chain = StuffDocumentsChain(
54
+ llm_chain=reduce_chain, document_variable_name="doc_summaries"
55
+ )
 
 
 
 
 
56
 
57
+ reduce_documents_chain = ReduceDocumentsChain(
58
+ combine_documents_chain=combine_chain,
59
+ collapse_documents_chain=combine_chain,
60
+ token_max=100000,
61
+ )
62
 
63
+ map_reduce_chain = MapReduceDocumentsChain(
64
+ llm_chain=map_chain,
65
+ reduce_documents_chain=reduce_documents_chain,
66
+ document_variable_name="docs",
67
+ return_intermediate_steps=False,
68
+ )
69
 
70
+ result = map_reduce_chain.invoke(chunks)['output_text']
71
+ print(types)
72
+ print(result)
73
+ if types != 'summ':
74
+ result = re.findall('(\{[^}]+\})', result)[0]
75
+ return eval(result)
76
+
77
+ return result
78
 
79
+ def get_entity_one(self, chunks):
80
 
81
+ result = self.llm.invoke(prompt_entity_one_chunk.format(chunks)).content
82
+
83
+ print('One')
84
+ print(result)
85
+ result = re.findall('(\{[^}]+\})', result)[0]
86
+
87
+ return eval(result)
88
 
89
+ def get_table(self, path):
 
 
 
90
 
91
+ start_time = datetime.now()
92
+ images = convert_from_path(path)
93
+ print('PDF to Image', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
94
+ tables = []
95
 
96
+ # Loop pages
97
+ for image in images:
 
98
 
99
+ pixel_values = detection_transform(image).unsqueeze(0).to(device)
100
+ with torch.no_grad():
101
+ outputs = model(pixel_values)
102
 
103
+ id2label = model.config.id2label
104
+ id2label[len(model.config.id2label)] = "no object"
105
+ detected_tables = outputs_to_objects(outputs, image.size, id2label)
 
 
106
 
107
+ # Loop table in page (if any)
108
+ for idx in range(len(detected_tables)):
109
+ cropped_table = image.crop(detected_tables[idx]["bbox"])
110
+ if detected_tables[idx]["label"] == 'table rotated':
111
+ cropped_table = cropped_table.rotate(270, expand=True)
112
 
113
+ # TODO: what is the perfect threshold?
114
+ if detected_tables[idx]['score'] > 0.9:
115
+ print(detected_tables[idx])
116
+ tables.append(cropped_table)
117
 
118
+ print('Detect table from image', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
119
+ genes = []
120
+ snps = []
121
+ diseases = []
122
 
123
+ # Loop tables
124
+ for table in tables:
125
+
126
+ buffer = io.BytesIO()
127
+ table.save(buffer, format='PNG')
128
+ image = Image(buffer)
129
+
130
+ # Extract to dataframe
131
+ extracted_tables = image.extract_tables(ocr=ocr, implicit_rows=True, borderless_tables=True, min_confidence=0)
132
+
133
+ if len(extracted_tables) == 0:
134
+ continue
135
+
136
+ # Combine multiple dataframe
137
+ df_table = extracted_tables[0].df
138
+ for extracted_table in extracted_tables[1:]:
139
+ df_table = pd.concat([df_table, extracted_table.df]).reset_index(drop=True)
140
+
141
+ df_table.loc[0] = df_table.loc[0].fillna('')
142
+
143
+ # Identify multiple rows (in dataframe) as one row (in image)
144
+ rows = []
145
+ indexes = []
146
+ for i in df_table.index:
147
+ if not df_table.loc[i].isna().any():
148
+ if len(indexes) > 0:
149
+ rows.append(indexes)
150
+ indexes = []
151
+ indexes.append(i)
152
+ rows.append(indexes)
153
+
154
+ df_table_cleaned = pd.DataFrame(columns=df_table.columns)
155
+ for row in rows:
156
+ row_str = df_table.loc[row[0]]
157
+ for idx in row[1:]:
158
+ row_str += ' ' + df_table.loc[idx].fillna('')
159
+ row_str = row_str.str.strip()
160
+ df_table_cleaned.loc[len(df_table_cleaned)] = row_str
161
+
162
+ # Ask LLM with JSON data
163
+ json_table = df_table_cleaned.to_json(orient='records')
164
+ str_json_table = json.dumps(json.loads(json_table), indent=2)
165
+
166
+ result = self.llm.invoke(prompt_table.format(str_json_table)).content
167
+ print('table')
168
+ print(result)
169
+ result = result[result.find('['):result.rfind(']')+1]
170
+ try:
171
+ result = eval(result)
172
+ except SyntaxError:
173
+ result = []
174
+
175
+ for res in result:
176
+ res_gene = res['Genes']
177
+ res_snp = res['SNPs']
178
+ res_disease = res['Diseases']
179
+
180
+ for snp in res_snp:
181
+ genes.append(res_gene)
182
+ snps.append(snp)
183
+ diseases.append(res_disease)
184
+
185
+ print('OCR table to extract', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
186
+ print(genes, snps, diseases)
187
 
188
+ return genes, snps, diseases
189
+
190
+ def validate(self, df):
191
+
192
+ df = df.fillna('')
193
+ df['Genes'] = df['Genes'].str.upper()
194
+ df['SNPs'] = df['SNPs'].str.lower()
195
+
196
+ # Check if there is two gene names
197
+ sym = ['-', '/', '|']
198
+ for i in df.index:
199
+ gene = df.loc[i, 'Genes']
200
+ for s in sym:
201
+ if s in gene:
202
+ genes = gene.split(s)
203
+ df.loc[i + 0.5] = df.loc[i]
204
+ df = df.sort_index().reset_index(drop=True)
205
+ df.loc[i, 'Genes'], df.loc[i + 1, 'Genes'] = genes[0], genes[1]
206
+
207
+ # Check if there is SNPs without 'rs'
208
+ for i in df.index:
209
+ safe = True
210
+ snp = df.loc[i, 'SNPs']
211
+ if re.fullmatch('rs(\d)+|', snp):
212
+ pass
213
+ elif re.fullmatch('ts(\d)+', snp):
214
+ snp = 'r' + snp[1:]
215
+ elif re.fullmatch('s(\d)+', snp):
216
+ snp = 'r' + snp
217
+ elif re.fullmatch('(\d)+', snp):
218
+ snp = 'rs' + snp
219
+ else:
220
+ safe = False
221
+ df = df.drop(i)
222
+
223
+ if safe:
224
+ df.loc[i, 'SNPs'] = snp
225
+
226
+ df.reset_index(drop=True, inplace=True)
227
+
228
+ # Validate genes and diseases with LLM
229
+ json_table = df[['Genes', 'SNPs', 'Diseases']].to_json(orient='records')
230
  str_json_table = json.dumps(json.loads(json_table), indent=2)
231
 
232
+ result = self.llm.invoke(input=prompt_validation.format(str_json_table)).content
233
+ print('val')
234
  print(result)
235
+
236
  result = result[result.find('['):result.rfind(']')+1]
237
  try:
238
  result = eval(result)
239
  except SyntaxError:
240
  result = []
241
 
242
+ df_val = pd.DataFrame(result)
243
+ df_val = df_val.merge(df.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1), 'cross')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
+ # TODO: How to validate genes and SNPs with ground truth?
246
 
247
+ return df, df_val