from dotenv import load_dotenv

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI

from prompt import *
from prompt_old import *
from utils import *

import os
import json
import re

import pandas as pd

load_dotenv()

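
# Rule-based and LLM-assisted validation of gene / rsID / trait rows
# extracted from papers.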
class Validation:

    def __init__(self, llm):
        # Pick the chat backend from the model name prefix; anything else is
        # routed to the OpenAI-compatible Perplexity endpoint.
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(
                temperature=0,
                model_name=llm,
                api_key=os.environ['PERPLEXITY_API_KEY'],
                base_url="https://api.perplexity.ai",
            )
    def validate(self, df, text, api):
        """Clean and validate an extracted Genes/rsID/Traits table against `text`.

        Returns the fully validated dataframe together with the rule-cleaned
        dataframe taken before the dbSNP / text / LLM checks.
        """
        # Normalise gene symbols and rsIDs: strip spaces, force case.
        df = df.fillna('')
        df['Genes'] = df['Genes'].str.replace(' ', '').str.upper()
        df['rsID'] = df['rsID'].str.replace(' ', '').str.lower()
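
        # Split rows whose gene field contains several symbols joined by a
        # separator: the original combined row is kept and two copies are
        # inserted, one with the first symbol and one with the remainder.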
        sym = [',', '/', '|', '-', '(', ')']
        i = 0
        while i < len(df):
            gene = df.loc[i, 'Genes']
            for s in sym:
                if s in gene:
                    genes = gene.split(s)
                    df.loc[i + 0.1], df.loc[i + 0.9] = df.loc[i], df.loc[i]
                    df = df.sort_index().reset_index(drop=True)
                    df.loc[i + 1, 'Genes'], df.loc[i + 2, 'Genes'] = genes[0], s.join(genes[1:])
                    break
            i += 1

        df.reset_index(drop=True, inplace=True)
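
        # Same treatment for rsID fields that hold several comma-separated IDs.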
        i = 0
        while i < len(df):
            rsid = df.loc[i, 'rsID']
            if ',' in rsid:
                rsids = rsid.split(',')
                df.loc[i + 0.1], df.loc[i + 0.9] = df.loc[i], df.loc[i]
                df = df.sort_index().reset_index(drop=True)
                df.loc[i + 1, 'rsID'], df.loc[i + 2, 'rsID'] = rsids[0], ','.join(rsids[1:])
            i += 1

        df.reset_index(drop=True, inplace=True)
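
        # Repair common OCR/extraction artefacts in rsIDs and drop rows whose
        # rsID cannot be coerced into the canonical 'rs<digits>' form.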
        for i in df.index:
            safe = True
            snp = df.loc[i, 'rsID']
            snp = snp.replace('l', '1')  # lowercase 'l' misread as the digit '1'
            if re.fullmatch(r'rs\d+|', snp):
                pass  # already canonical (or empty)
            elif re.fullmatch(r'ts\d+', snp):
                snp = 'r' + snp[1:]
            elif re.fullmatch(r's\d+', snp):
                snp = 'r' + snp
            elif re.fullmatch(r'\d+', snp):
                snp = 'rs' + snp
            elif re.fullmatch(r'r\d+', snp):
                snp = 'rs' + snp[1:]
                if snp[2] == '5':
                    # a leading '5' may itself be a misread 's', so keep both candidates
                    snp += f',rs{snp[3:]}'
            else:
                safe = False
                df = df.drop(i)

            if safe:
                df.loc[i, 'rsID'] = snp

        df.reset_index(drop=True, inplace=True)
        df_clean = df.copy()
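
        # Optional dbSNP cross-check: look up each rsID via the NCBI eutils
        # esummary endpoint and keep the row only if the gene (or one of its
        # symbol permutations) is among the genes dbSNP reports for that SNP.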
        if api:
            dbsnp = {}
            for i in df.index:
                snp = df.loc[i, 'rsID']
                gene = df.loc[i, 'Genes']

                if snp not in dbsnp:
                    dbsnp[snp] = []
                    try:
                        res = call(f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmode=json&id={snp[2:]}').json()['result'][snp[2:]]
                        if 'error' not in res:
                            dbsnp[snp].extend([r['name'] for r in res['genes']])
                    except Exception as e:
                        print("Error at API", e)

                    dbsnp[snp] = list(set(dbsnp[snp]))

                if gene not in dbsnp[snp]:
                    for other in permutate(gene):
                        if other in dbsnp[snp]:
                            df.loc[i, 'Genes'] = other
                            print(f'{gene} corrected to {other}')
                            break
                    else:
                        df = df.drop(i)
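
        # Disabled path: validate (gene, rsID) pairs against a curated
        # ground-truth mapping instead of the source text.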
        if False:
            for i in df.index:
                gene = df.loc[i, 'Genes']
                snp = df.loc[i, 'rsID']
                perms = permutate(gene)

                for perm in perms:
                    if perm in ground_truth and snp in ground_truth[perm]:
                        df.loc[i, 'Genes'] = perm
                        if gene != perm:
                            print(f'{gene} corrected to {perm} with {snp}')
                        else:
                            print(f'{gene} and {snp} safe')
                        break
                else:
                    print(f'{gene} and {snp} not found')
                    df = df.drop(i)
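
        # Keep a pair only if both the gene (or one of its symbol permutations)
        # and the rsID actually occur in the source text.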
        for i in df.index:
            gene = df.loc[i, 'Genes']
            snp = df.loc[i, 'rsID']
            perms = permutate(gene)

            for perm in perms:
                if perm in text and snp in text:
                    df.loc[i, 'Genes'] = perm
                    if gene != perm:
                        print(f'{gene} corrected to {perm} with {snp}')
                    else:
                        print(f'{gene} and {snp} safe')
                    break
            else:
                print(f'{gene} and {snp} not found')
                df = df.drop(i)
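
        # Drop rows that are completely empty, and partial rows whose gene or
        # rsID already appears in a complete (gene, rsID) row.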
        genes = []
        snps = []
        for i in df.index:
            gene = df.loc[i, 'Genes']
            snp = df.loc[i, 'rsID']

            if len(gene) == 0 and len(snp) == 0:
                df = df.drop(i)
            elif len(gene) == 0:
                if snp in snps:
                    df = df.drop(i)
            elif len(snp) == 0:
                if gene in genes:
                    df = df.drop(i)
            else:
                genes.append(gene)
                snps.append(snp)

        df.reset_index(drop=True, inplace=True)
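
        # Ask the LLM whether each extracted gene-trait association is
        # plausible; clear the trait when it answers no.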
        for i in df.index:
            gene = df.loc[i, 'Genes']
            trait = df.loc[i, 'Traits']

            if len(trait) == 0:
                continue

            result = self.llm.invoke(input=prompt_validate_gene_trait.format(gene, trait)).content
            if "no" in result.lower():
                df.loc[i, 'Traits'] = ''
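
        # Disabled path: send the table to the LLM in chunks of 20 rows and
        # rebuild it from the JSON list the model returns.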
        if False:
            idx = 0
            df_llm = pd.DataFrame()
            while True:
                json_table = df[idx:idx + 20].to_json(orient='records')
                str_json_table = json.dumps(json.loads(json_table), indent=2)

                result = self.llm.invoke(input=prompt_validation.format(str_json_table)).content
                result = result[result.find('['):result.rfind(']') + 1]
                try:
                    result = eval(result)
                except SyntaxError:
                    result = []

                df_llm = pd.concat([df_llm, pd.DataFrame(result)])

                idx += 20
                if idx not in df.index:
                    break

            df = df_llm.copy()
        df.reset_index(drop=True, inplace=True)
        df.drop_duplicates(['Genes', 'rsID'], ignore_index=True, inplace=True)
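
        # For rows extracted from free text, collect +/-1000-character windows
        # around each gene mention that also contain the rsID, then ask the
        # LLM whether the pair is genuinely associated in that context.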
        df_text = df[df['Source'] == 'Text']
        for i in df_text.index:
            gene = df_text.loc[i, 'Genes']
            snp = df_text.loc[i, 'rsID']

            if len(gene) == 0 or len(snp) == 0:
                continue

            windows = []
            window = 1000
            matches = [m.start() for m in re.finditer(re.escape(gene), text)]
            for index in matches:
                start = max(0, index - window)
                end = min(len(text), index + len(gene) + window)
                window_text = text[start:end]
                if snp not in window_text:
                    continue

                windows.append(window_text)

            result = self.llm.invoke(input=prompt_validate_gene_rsid.format(gene, snp, windows)).content
            if "no" in result.lower():
                df = df.drop(i)
        df.reset_index(drop=True, inplace=True)

        return df, df_clean
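

# Minimal usage sketch (illustrative only): the example rows, text and model
# name below are assumptions, not part of the pipeline. Running it calls the
# configured LLM, so the matching API key must be set in the environment.
if __name__ == '__main__':
    example_df = pd.DataFrame({
        'Genes': ['APOE'],
        'rsID': ['rs429358'],
        'Traits': ['Alzheimer disease'],
        'Source': ['Text'],
    })
    example_text = 'The APOE variant rs429358 is associated with Alzheimer disease.'

    validator = Validation('gpt-4o-mini')  # assumed model name
    validated, cleaned = validator.validate(example_df, example_text, api=False)
    print(validated)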