from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from prompt import *
from utils import *
import ast
import os
import json
import re

import pandas as pd

load_dotenv()


class Validation():
    """Post-processing/validation pipeline for gene–SNP tables extracted by an LLM.

    Cleans up a DataFrame with (at least) 'Genes', 'rsID' and 'Source' columns,
    optionally cross-checks entries against the NCBI dbSNP API, validates rows
    against the source text and a validation LLM, and returns the filtered table.
    """

    def __init__(self, llm):
        """Select the chat backend from the model name.

        :param llm: model identifier; 'gpt*' -> OpenAI, 'gemini*' -> Google,
                    anything else -> Perplexity via its OpenAI-compatible API
                    (requires PERPLEXITY_API_KEY in the environment).
        """
        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
        elif llm.startswith('gemini'):
            self.llm = ChatGoogleGenerativeAI(temperature=0, model=llm)
        else:
            self.llm = ChatOpenAI(temperature=0, model_name=llm,
                                  api_key=os.environ['PERPLEXITY_API_KEY'],
                                  base_url="https://api.perplexity.ai")

    def validate(self, df, text, api):
        """Clean and validate a gene/SNP table.

        :param df: DataFrame with 'Genes', 'rsID' and 'Source' columns.
        :param text: full source text the table was extracted from.
        :param api: if truthy, cross-check entries against NCBI dbSNP.
        :returns: (validated df, df_clean) where df_clean is the table after
                  normalization but before API/LLM filtering.
        """
        df = df.fillna('')
        df['Genes'] = df['Genes'].str.replace(' ', '').str.upper()
        df['rsID'] = df['rsID'].str.replace(' ', '').str.lower()

        # Split rows whose 'Genes' cell holds multiple genes joined by a symbol.
        # Two copies of the row are inserted just after index i (float indices
        # i+0.1/i+0.9 sort between i and i+1, then the index is rebuilt), and the
        # copies receive the first gene and the remainder respectively; the
        # remainder is split again on subsequent iterations.
        # NOTE(review): the original combined row at index i is kept — presumably
        # filtered out by the later text/LLM checks; confirm this is intended.
        sym = [',', '/', '|', '-', '(', ')']
        i = 0
        while i < len(df):
            gene = df.loc[i, 'Genes']
            for s in sym:
                if s in gene:
                    genes = gene.split(s)
                    df.loc[i + 0.1], df.loc[i + 0.9] = df.loc[i], df.loc[i]
                    df = df.sort_index().reset_index(drop=True)
                    df.loc[i + 1, 'Genes'], df.loc[i + 2, 'Genes'] = genes[0], s.join(genes[1:])
                    break
            i += 1
        df.reset_index(drop=True, inplace=True)

        # Same splitting scheme for comma-separated rsID lists.
        i = 0
        while i < len(df):
            rsid = df.loc[i, 'rsID']
            if ',' in rsid:
                rsids = rsid.split(',')
                df.loc[i + 0.1], df.loc[i + 0.9] = df.loc[i], df.loc[i]
                df = df.sort_index().reset_index(drop=True)
                df.loc[i + 1, 'rsID'], df.loc[i + 2, 'rsID'] = rsids[0], ','.join(rsids[1:])
            i += 1
        df.reset_index(drop=True, inplace=True)

        # Repair rsIDs that were not captured cleanly (OCR/extraction noise):
        # 'l'->'1', missing/garbled 'rs' prefix. Rows that still don't look like
        # an rsID are dropped.
        for i in df.index:
            safe = True
            snp = df.loc[i, 'rsID']
            snp = snp.replace('l', '1')
            # Trailing '|' deliberately lets the empty string pass unchanged.
            if re.fullmatch(r'rs(\d)+|', snp):
                pass
            elif re.fullmatch(r'ts(\d)+', snp):
                snp = 'r' + snp[1:]
            elif re.fullmatch(r's(\d)+', snp):
                snp = 'r' + snp
            elif re.fullmatch(r'(\d)+', snp):
                snp = 'rs' + snp
            elif re.fullmatch(r'r(\d)+', snp):
                snp = 'rs' + snp[1:]
                # A leading '5' may be a misread 's'; keep both candidates.
                # NOTE(review): this comma-joined alternate is created after the
                # comma-splitting pass above, so it is never split again — verify.
                if snp[2] == '5':
                    snp += f',rs{snp[3:]}'
            else:
                safe = False
                df = df.drop(i)
            if safe:
                df.loc[i, 'rsID'] = snp
        df.reset_index(drop=True, inplace=True)

        # Snapshot of the normalized table, returned alongside the final result.
        df_clean = df.copy()

        # WARNING: DEPRECATED
        # Validate genes and SNPs with the NCBI dbSNP API.
        if api:
            dbsnp = {}
            for i in df.index:
                # FIX: column was 'SNPs', which does not exist — the table uses 'rsID'.
                snp = df.loc[i, 'rsID']
                gene = df.loc[i, 'Genes']
                if snp not in dbsnp:
                    # FIX: initialize the cache entry; the original code called
                    # dbsnp[snp].extend(...) on a missing key (KeyError).
                    dbsnp[snp] = []
                    try:
                        res = call(f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmode=json&id={snp[2:]}').json()['result'][snp[2:]]
                        if 'error' not in res:
                            dbsnp[snp].extend([r['name'] for r in res['genes']])
                    except Exception as e:
                        # Best-effort lookup: on failure the entry stays empty.
                        print("Error at API", e)
                    dbsnp[snp] = list(set(dbsnp[snp]))
                if gene not in dbsnp[snp]:
                    for other in permutate(gene):
                        if other in dbsnp[snp]:
                            df.loc[i, 'Genes'] = other
                            print(f'{gene} corrected to {other}')
                            break
                    else:
                        df = df.drop(i)

        # WARNING: DEPRECATED — kept for reference, intentionally disabled.
        # Check with GWAS ground truth.
        if False:
            for i in df.index:
                gene = df.loc[i, 'Genes']
                snp = df.loc[i, 'rsID']
                perms = permutate(gene)
                for perm in perms:
                    if perm in ground_truth and snp in ground_truth[perm]:
                        df.loc[i, 'Genes'] = perm
                        if gene != perm:
                            print(f'{gene} corrected to {perm} with {snp}')
                        else:
                            print(f'{gene} and {snp} safe')
                        break
                else:
                    print(f'{gene} and {snp} not found')
                    df = df.drop(i)

        # Keep only rows whose gene (or a case/format permutation of it) and SNP
        # both literally occur in the source text.
        for i in df.index:
            gene = df.loc[i, 'Genes']
            snp = df.loc[i, 'rsID']
            perms = permutate(gene)
            for perm in perms:
                if perm in text and snp in text:
                    df.loc[i, 'Genes'] = perm
                    if gene != perm:
                        print(f'{gene} corrected to {perm} with {snp}')
                    else:
                        print(f'{gene} and {snp} safe')
                    break
            else:
                print(f'{gene} and {snp} not found')
                df = df.drop(i)

        # Drop rows where both fields are empty, and half-empty rows whose
        # non-empty field duplicates one already seen in a complete row.
        genes = []
        snps = []
        for i in df.index:
            gene = df.loc[i, 'Genes']
            snp = df.loc[i, 'rsID']
            if len(gene) == 0 and len(snp) == 0:
                df = df.drop(i)
            elif len(gene) == 0:
                if snp in snps:
                    df = df.drop(i)
            elif len(snp) == 0:
                if gene in genes:
                    df = df.drop(i)
            else:
                genes.append(gene)
                snps.append(snp)
        df.reset_index(drop=True, inplace=True)

        # Validate genes and traits with the LLM, 20 rows per call.
        idx = 0
        df_llm = pd.DataFrame()
        # FIX: loop only while rows remain; the original `while True` issued one
        # LLM call even for an empty table.
        while idx < len(df):
            json_table = df[idx:idx + 20].to_json(orient='records')
            str_json_table = json.dumps(json.loads(json_table), indent=2)
            result = self.llm.invoke(input=prompt_validation.format(str_json_table)).content
            result = result[result.find('['):result.rfind(']') + 1]
            try:
                # FIX: literal_eval instead of eval — never execute LLM output.
                result = ast.literal_eval(result)
            except (SyntaxError, ValueError):
                result = []
            df_llm = pd.concat([df_llm, pd.DataFrame(result)])
            idx += 20
        # FIX: if the LLM returned nothing, keep the (empty) original schema so
        # drop_duplicates below does not KeyError on missing columns.
        df = df_llm.copy() if not df_llm.empty else df.iloc[0:0]
        df.reset_index(drop=True, inplace=True)
        df.drop_duplicates(['Genes', 'rsID'], ignore_index=True, inplace=True)

        # For text-sourced rows, ask the LLM to confirm the pair using +/-1000
        # character windows around each occurrence of the gene that also
        # contains the SNP.
        df_text = df[df['Source'] == 'Text']
        for i in df_text.index:
            gene = df_text.loc[i, 'Genes']
            snp = df_text.loc[i, 'rsID']
            if len(gene) == 0 or len(snp) == 0:
                continue
            windows = []
            window = 1000
            matches = [m.start() for m in re.finditer(re.escape(gene), text)]
            for index in matches:
                start = max(0, index - window)
                end = min(len(text), index + len(gene) + window)
                window_text = text[start:end]
                if snp not in window_text:
                    continue
                windows.append(window_text)
            result = self.llm.invoke(input=prompt_validate_text.format(gene, snp, windows)).content
            if "no" in result.lower():
                df = df.drop(i)
        df.reset_index(drop=True, inplace=True)

        return df, df_clean