# Author: fadliaulawi
# Commit b688e46: "Add gene-trait validation"
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from prompt import *
from prompt_old import *
from utils import *
import os
import json
import re
# Load variables from a .env file into os.environ at import time;
# Validation.__init__ reads PERPLEXITY_API_KEY from the environment.
load_dotenv()
class Validation:
    """Clean and validate gene / rsID / trait rows extracted from a document.

    Uses a chat LLM (chosen in ``__init__``) for the gene-trait and gene-rsID
    plausibility checks performed in ``validate``.
    """

    # Characters treated as separators between several gene symbols in one cell.
    GENE_SEPARATORS = (',', '/', '|', '-', '(', ')')
    # Characters of context kept on each side of a gene mention when building
    # evidence passages for the gene/rsID LLM check.
    WINDOW = 1000

    def __init__(self, llm):
        """Create the chat model used for validation.

        Args:
            llm: model name. Names starting with 'gpt' use ChatOpenAI, names
                starting with 'gemini' use ChatGoogleGenerativeAI, anything else
                goes to the Perplexity endpoint through ChatOpenAI and requires
                the PERPLEXITY_API_KEY environment variable.
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(temperature=0, model_name=llm, api_key=os.environ['PERPLEXITY_API_KEY'], base_url="https://api.perplexity.ai")

    @staticmethod
    def _split_rows(df, column, separators):
        """Expand cells of ``column`` holding several values into extra rows.

        A row whose value contains one of ``separators`` is kept unchanged and
        is followed by a copy holding the text before the first occurrence of
        the first matching separator (in ``separators`` order) and a copy
        holding the remainder; both copies are expanded again the same way
        (pre-order). The combined row itself is kept too — presumably so that
        genuinely hyphenated symbols can still be matched later
        (NOTE(review): confirm this is intended).

        Returns a new DataFrame with a fresh 0..n-1 index.
        """
        positions, values = [], []
        for pos, value in enumerate(df[column]):
            stack = [value]
            while stack:
                current = stack.pop()
                positions.append(pos)
                values.append(current)
                for sep in separators:
                    if sep in current:
                        head, *tail = current.split(sep)
                        # Push the remainder first so the head is emitted first.
                        stack.append(sep.join(tail))
                        stack.append(head)
                        break
        return df.iloc[positions].assign(**{column: values}).reset_index(drop=True)

    @staticmethod
    def _rsid_candidates(snp):
        """Return the rsID(s) that a lower-cased rsID cell most likely stands for.

        Repairs 'l' captured instead of '1' and a missing or mangled 'rs'
        prefix ('ts123', 's123', '123', 'r123'). For 'r5...' the '5' may be a
        misread 's', so both readings are returned ('r5123' -> ['rs5123',
        'rs123']). An empty cell yields [''], and any other non-rsID yields []
        so the row is dropped.
        """
        snp = snp.replace('l', '1')
        if snp == '' or re.fullmatch(r'rs\d+', snp):
            return [snp]
        if re.fullmatch(r'ts\d+', snp):
            return ['r' + snp[1:]]
        if re.fullmatch(r's\d+', snp):
            return ['r' + snp]
        if re.fullmatch(r'\d+', snp):
            return ['rs' + snp]
        if re.fullmatch(r'r\d+', snp):
            snp = 'rs' + snp[1:]
            if snp[2] == '5' and len(snp) > 3:
                # Keep both readings; one row per candidate (a comma-joined
                # value would never match the text later on).
                return [snp, 'rs' + snp[3:]]
            return [snp]
        return []

    @classmethod
    def _normalize_rsids(cls, df):
        """Replace each 'rsID' cell by its repaired candidate(s), one row each.

        Rows whose rsID cannot be interpreted are dropped; the index is reset.
        """
        candidates = df['rsID'].map(cls._rsid_candidates)
        return (df.assign(rsID=candidates)
                .explode('rsID')
                # explode turns an empty candidate list into NaN.
                .dropna(subset=['rsID'])
                .reset_index(drop=True))

    @staticmethod
    def _drop_redundant_partials(df):
        """Drop rows lacking a gene and/or an rsID when they add no information.

        Rows with neither value are dropped. A row with only an rsID is dropped
        when that rsID appears in some complete (gene + rsID) row, and a row
        with only a gene likewise, independent of row order. The index is reset.
        """
        has_gene = df['Genes'] != ''
        has_snp = df['rsID'] != ''
        complete = has_gene & has_snp
        known_genes = set(df.loc[complete, 'Genes'])
        known_snps = set(df.loc[complete, 'rsID'])
        redundant = ((~has_gene & ~has_snp)
                     | (~has_gene & df['rsID'].isin(known_snps))
                     | (~has_snp & df['Genes'].isin(known_genes)))
        return df[~redundant].reset_index(drop=True)

    @staticmethod
    def _is_negative(response):
        """Return True if an LLM answer contains the standalone word 'no'.

        A plain substring test would also fire on words such as 'phenotype',
        'genome', 'know' or 'unknown'.
        """
        return re.search(r'\bno\b', response, re.IGNORECASE) is not None

    @classmethod
    def _context_windows(cls, text, gene, snp, window=None):
        """Return passages around each mention of ``gene`` that also contain ``snp``.

        Each passage spans ``window`` characters (default ``WINDOW``) on either
        side of the gene mention, clipped to the text.
        """
        if window is None:
            window = cls.WINDOW
        windows = []
        for match in re.finditer(re.escape(gene), text):
            start = max(0, match.start() - window)
            end = min(len(text), match.start() + len(gene) + window)
            passage = text[start:end]
            if snp in passage:
                windows.append(passage)
        return windows

    def validate(self, df, text, api):
        """Clean, split and validate extracted gene / rsID / trait rows.

        Args:
            df: DataFrame with 'Genes', 'rsID', 'Traits' and 'Source' columns.
            text: the source text the rows were extracted from.
            api: if true, also cross-check genes against NCBI dbSNP
                (deprecated path).

        Returns:
            (df, df_clean): the validated rows, and a snapshot taken after
            splitting and rsID normalization but before any validation.
        """
        # Positional 0..n-1 index and string-only cells for the steps below.
        df = df.fillna('').reset_index(drop=True)
        df['Genes'] = df['Genes'].astype(str).str.replace(' ', '').str.upper()
        df['rsID'] = df['rsID'].astype(str).str.replace(' ', '').str.lower()

        # Check if there are multiple Genes / rsIDs in one cell
        df = self._split_rows(df, 'Genes', self.GENE_SEPARATORS)
        df = self._split_rows(df, 'rsID', (',',))

        # Check if there are SNPs not well captured
        df = self._normalize_rsids(df)
        df_clean = df.copy()

        # WARNING: DEPRECATED
        # Validate genes and SNPs with APIs
        if api:
            dbsnp = {}
            for i in df.index:
                snp = df.loc[i, 'rsID']
                gene = df.loc[i, 'Genes']
                if snp not in dbsnp:
                    dbsnp[snp] = []
                    try:
                        res = call(f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmode=json&id={snp[2:]}').json()['result'][snp[2:]]
                        if 'error' not in res:
                            dbsnp[snp].extend([r['name'] for r in res['genes']])
                    except Exception as e:
                        # Best effort: an unreachable API leaves no known genes.
                        print("Error at API", e)
                    dbsnp[snp] = list(set(dbsnp[snp]))
                if gene not in dbsnp[snp]:
                    for other in permutate(gene):
                        if other in dbsnp[snp]:
                            df.loc[i, 'Genes'] = other
                            print(f'{gene} corrected to {other}')
                            break
                    else:
                        df = df.drop(i)

        # WARNING: DEPRECATED
        # Check with GWAS ground truth
        if False:
            for i in df.index:
                gene = df.loc[i, 'Genes']
                snp = df.loc[i, 'rsID']
                perms = permutate(gene)
                for perm in perms:
                    if perm in ground_truth and snp in ground_truth[perm]:
                        df.loc[i, 'Genes'] = perm
                        if gene != perm:
                            print(f'{gene} corrected to {perm} with {snp}')
                        else:
                            print(f'{gene} and {snp} safe')
                        break
                else:
                    print(f'{gene} and {snp} not found')
                    df = df.drop(i)

        # Check Genes and rsIDs appear in Text: keep a row only if one of the
        # variants returned by permutate(gene) and the rsID both occur in the
        # text, and store the variant that was found.
        unmatched = []
        for i in df.index:
            gene = df.loc[i, 'Genes']
            snp = df.loc[i, 'rsID']
            for perm in permutate(gene):
                if perm in text and snp in text:
                    df.loc[i, 'Genes'] = perm
                    if gene != perm:
                        print(f'{gene} corrected to {perm} with {snp}')
                    else:
                        print(f'{gene} and {snp} safe')
                    break
            else:
                print(f'{gene} and {snp} not found')
                unmatched.append(i)
        df = df.drop(unmatched)

        # Drop (duplicate) entries with empty values
        df = self._drop_redundant_partials(df)

        # Validate genes and traits with LLM: blank out rejected traits.
        for i in df.index:
            gene = df.loc[i, 'Genes']
            trait = df.loc[i, 'Traits']
            if not trait:
                continue
            result = self.llm.invoke(input=prompt_validate_gene_trait.format(gene, trait)).content
            if self._is_negative(result):
                df.loc[i, 'Traits'] = ''

        # WARNING: DEPRECATED
        # Validate genes and traits with LLM (for each 20 rows)
        if False:
            idx = 0
            df_llm = pd.DataFrame()
            while True:
                json_table = df[idx:idx+20].to_json(orient='records')
                str_json_table = json.dumps(json.loads(json_table), indent=2)
                result = self.llm.invoke(input=prompt_validation.format(str_json_table)).content
                result = result[result.find('['):result.rfind(']')+1]
                try:
                    # NOTE(review): eval() on LLM output is unsafe; use
                    # json.loads if this path is ever re-enabled.
                    result = eval(result)
                except SyntaxError:
                    result = []
                df_llm = pd.concat([df_llm, pd.DataFrame(result)])
                idx += 20
                if idx not in df.index:
                    break
            df = df_llm.copy()

        df.drop_duplicates(['Genes', 'rsID'], ignore_index=True, inplace=True)

        # Evaluate chunk of texts: ask the LLM whether each text-sourced
        # gene/rsID pair is supported by passages where both occur together.
        df_text = df[df['Source'] == 'Text']
        rejected = []
        for i in df_text.index:
            gene = df_text.loc[i, 'Genes']
            snp = df_text.loc[i, 'rsID']
            if not gene or not snp:
                continue
            windows = self._context_windows(text, gene, snp)
            result = self.llm.invoke(input=prompt_validate_gene_rsid.format(gene, snp, windows)).content
            if self._is_negative(result):
                rejected.append(i)
        df = df.drop(rejected).reset_index(drop=True)

        return df, df_clean