Spaces:
Sleeping
Sleeping
import urllib.parse | |
from datetime import datetime | |
from email.mime.multipart import MIMEMultipart | |
from email.mime.text import MIMEText | |
from email.utils import formatdate, make_msgid | |
from functools import cache | |
import html | |
import os | |
from pathlib import Path | |
import smtplib | |
import sys | |
import tempfile | |
import pandas as pd | |
from bokeh.models import NumberFormatter, BooleanFormatter, HTMLTemplateFormatter | |
import gradio as gr | |
import pytz | |
import panel as pn | |
import seaborn as sns | |
from markdown import markdown | |
from rdkit import Chem, RDConfig | |
from rdkit.Chem import Crippen, Descriptors, rdMolDescriptors, Lipinski, rdmolops, Draw, rdDepictor | |
import requests | |
from app import static | |
sys.path.append(str(Path(RDConfig.RDContribDir) / 'SA_Score')) | |
import sascorer | |
# Display names for raw result columns; applied with DataFrame.rename in
# create_result_table_html. Two raw path columns share the 'Pose' alias.
COL_ALIASES = {
    'out_path': 'Pose',
    'ligand_conf_path': 'Pose',
    'ID1': 'Compound ID',
    'ID2': 'Target ID',
    'X1': 'Fragment SMILES',
    'X1^': 'Compound SMILES',
    'name': 'Complex Name',
}
# Dtype per raw result column; every listed column is a string column.
# NOTE(review): presumably passed as `dtype=` when result tables are loaded -
# the consumer is not visible in this file, confirm before relying on it.
COL_DTYPE = dict.fromkeys(
    (
        'out_path',
        'ligand_conf_path',
        'ID1',
        'ID2',
        'X1',
        'X1^',
        'name',
    ),
    'str',
)
def lipinski(mol):
    """
    Check a molecule against Lipinski's rule of five.

    The molecule passes only if all of the following hold:
        Hydrogen bond donors <= 5
        Hydrogen bond acceptors <= 10
        Molecular weight <= 500 daltons
        logP <= 5

    Returns True when every criterion is met, False otherwise.
    """
    # Criteria are checked in order and evaluation stops at the first failure.
    if not (Lipinski.NumHDonors(mol) <= 5):
        return False
    if not (Lipinski.NumHAcceptors(mol) <= 10):
        return False
    if not (Descriptors.MolWt(mol) <= 500):
        return False
    return Crippen.MolLogP(mol) <= 5
def reos(mol):
    """
    Apply the Rapid Elimination Of Swill (REOS) filter.

    The molecule passes only if every property lies in its inclusive range:
        Molecular weight between 200 and 500
        LogP between -5.0 and +5.0
        H-bond donor count between 0 and 5
        H-bond acceptor count between 0 and 10
        Formal charge between -2 and +2
        Rotatable bond count between 0 and 8
        Heavy atom count between 15 and 50

    Returns True when all ranges are satisfied, False otherwise.
    """
    # (descriptor, lower bound, upper bound); evaluated lazily in this order.
    bounds = (
        (Descriptors.MolWt, 200, 500),
        (Crippen.MolLogP, -5.0, 5.0),
        (Lipinski.NumHDonors, 0, 5),
        (Lipinski.NumHAcceptors, 0, 10),
        (rdmolops.GetFormalCharge, -2, 2),
        (rdMolDescriptors.CalcNumRotatableBonds, 0, 8),
        (rdMolDescriptors.CalcNumHeavyAtoms, 15, 50),
    )
    return all(low <= descriptor(mol) <= high for descriptor, low, high in bounds)
def ghose(mol):
    """
    Apply the Ghose drug-likeness filter.

    The molecule passes only if every property lies in its inclusive range:
        Molecular weight between 160 and 480
        LogP between -0.4 and +5.6
        Atom count between 20 and 70
        Molar refractivity between 40 and 130

    Returns True when all ranges are satisfied, False otherwise.
    """
    limits = (
        (Descriptors.MolWt, 160, 480),
        (Crippen.MolLogP, -0.4, 5.6),
        (rdMolDescriptors.CalcNumAtoms, 20, 70),
        (Crippen.MolMR, 40, 130),
    )
    for descriptor, lower, upper in limits:
        # Stop at the first property that falls outside its range.
        if not (lower <= descriptor(mol) <= upper):
            return False
    return True
def veber(mol):
    """
    Apply the Veber rule-of-thumb filter for orally active drugs
    (Veber et al., J Med Chem. 2002; 45(12): 2615-23):
        Rotatable bonds <= 10
        Topological polar surface area <= 140

    Returns True when both criteria are met, False otherwise.
    """
    if rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10:
        return rdMolDescriptors.CalcTPSA(mol) <= 140
    return False
def rule_of_three(mol):
    """
    Apply the Rule of Three fragment filter
    (Congreve et al., Drug Discov. Today. 8 (19): 876–7, (2003)):
        Molecular weight <= 300
        LogP <= 3
        H-bond donor count <= 3
        H-bond acceptor count <= 3
        Rotatable bond count <= 3

    Returns True when every criterion is met, False otherwise.
    """
    # (descriptor, inclusive upper limit); evaluated lazily in this order.
    upper_limits = (
        (Descriptors.MolWt, 300),
        (Crippen.MolLogP, 3),
        (Lipinski.NumHDonors, 3),
        (Lipinski.NumHAcceptors, 3),
        (rdMolDescriptors.CalcNumRotatableBonds, 3),
    )
    return all(descriptor(mol) <= limit for descriptor, limit in upper_limits)
@cache
def load_smarts_patterns(smarts_path):
    """
    Load SMARTS patterns from a CSV file and compile them to query molecules.

    The result is memoized per ``smarts_path`` so the CSV is read and the
    patterns compiled once per process rather than once per filtered
    molecule. Callers get the same list object on every call and must treat
    it as read-only.

    Args:
        smarts_path: Path to a CSV file with a 'smarts' column. Must be
            hashable (e.g. str or Path) because it is the cache key.

    Returns:
        A list with one entry per CSV row, in file order: the result of
        Chem.MolFromSmarts for that row, which is None for a pattern RDKit
        could not parse.
    """
    smarts_df = pd.read_csv(Path(smarts_path))
    return [Chem.MolFromSmarts(smarts) for smarts in smarts_df['smarts']]
def smarts_filter(mol, smarts_mols):
    """
    Return True when `mol` matches none of the given SMARTS query molecules.

    Entries of `smarts_mols` that are None are skipped. Matching stops at the
    first pattern found in `mol`, in which case False is returned.
    """
    return not any(
        pattern is not None and mol.HasSubstructMatch(pattern)
        for pattern in smarts_mols
    )
def pains(mol):
    """Return True when `mol` matches no SMARTS pattern from data/filters/pains.csv."""
    return smarts_filter(mol, load_smarts_patterns("data/filters/pains.csv"))
def mlsmr(mol):
    """Return True when `mol` matches no SMARTS pattern from data/filters/mlsmr.csv."""
    return smarts_filter(mol, load_smarts_patterns("data/filters/mlsmr.csv"))
def dundee(mol):
    """Return True when `mol` matches no SMARTS pattern from data/filters/dundee.csv."""
    return smarts_filter(mol, load_smarts_patterns("data/filters/dundee.csv"))
def glaxo(mol):
    """Return True when `mol` matches no SMARTS pattern from data/filters/glaxo.csv."""
    return smarts_filter(mol, load_smarts_patterns("data/filters/glaxo.csv"))
def bms(mol):
    """Return True when `mol` matches no SMARTS pattern from data/filters/bms.csv."""
    return smarts_filter(mol, load_smarts_patterns("data/filters/bms.csv"))
# Display name -> callable computing that score for a single molecule.
# All entries are RDKit descriptor functions except the first, which comes
# from the SA_Score contrib module (sascorer) added to sys.path above.
SCORE_MAP = {
    'Synthetic Accessibility': sascorer.calculateScore,
    'LogP': Crippen.MolLogP,
    'Molecular Weight': Descriptors.MolWt,
    'Number of Atoms': rdMolDescriptors.CalcNumAtoms,
    'Number of Heavy Atoms': rdMolDescriptors.CalcNumHeavyAtoms,
    'Molar Refractivity': Crippen.MolMR,
    'H-Bond Donor Count': Lipinski.NumHDonors,
    'H-Bond Acceptor Count': Lipinski.NumHAcceptors,
    'Rotatable Bond Count': rdMolDescriptors.CalcNumRotatableBonds,
    'Topological Polar Surface Area': rdMolDescriptors.CalcTPSA,
}
# Display name -> predicate defined in this module; each takes a molecule and
# returns True when the molecule passes that filter.
FILTER_MAP = {
    # TODO support number_of_violations
    # Property-range filters.
    'REOS': reos,
    "Lipinski's Rule of Five": lipinski,
    'Ghose': ghose,
    'Rule of Three': rule_of_three,
    'Veber': veber,
    # SMARTS substructure filters backed by CSV files under data/filters/.
    'PAINS': pains,
    'MLSMR': mlsmr,
    'Dundee': dundee,
    'Glaxo': glaxo,
    'BMS': bms,
}
def get_timezone_by_ip(ip):
    """
    Look up the timezone name for an IP address via worldtimeapi.org.

    This is a best-effort lookup: any failure (network error, timeout, HTTP
    error status, malformed or unexpected JSON) falls back to 'UTC'.

    Args:
        ip: IP address to look up; interpolated into the request URL.

    Returns:
        The 'timezone' field of the API response, or 'UTC' on any failure.
    """
    try:
        # A timeout keeps an unresponsive service from blocking the caller
        # indefinitely; requests has no default timeout.
        response = requests.get(f'https://worldtimeapi.org/api/ip/{ip}', timeout=5)
        response.raise_for_status()
        return response.json()['timezone']
    except Exception:
        # Deliberate catch-all: the timezone is only cosmetic for the caller.
        return 'UTC'
def ts_to_str(timestamp, timezone):
    """
    Format a UNIX timestamp as a local time string in the given timezone.

    Args:
        timestamp: Seconds since the epoch.
        timezone: Timezone name understood by pytz.timezone.

    Returns:
        A string formatted as '%Y-%m-%d %H:%M:%S (%Z%z)' in that timezone.
    """
    # Start from an aware UTC datetime so the conversion below is unambiguous.
    utc_moment = datetime.fromtimestamp(timestamp, pytz.utc)
    local_moment = utc_moment.astimezone(pytz.timezone(timezone))
    return local_moment.strftime('%Y-%m-%d %H:%M:%S (%Z%z)')
def send_email(job_info):
    """
    Send a job notification email if the job has a recipient address.

    SMTP settings and templates come from the environment variables
    EMAIL_SERV, EMAIL_PORT, EMAIL_ADDR, EMAIL_PASS, EMAIL_FORM (body
    template) and EMAIL_SUBJ (subject template). Both templates are filled
    with ``str.format(**info)``, where ``info`` is a copy of ``job_info`` in
    which every truthy value whose key ends in "time" has been converted to
    a string by ts_to_str, using the timezone looked up from
    ``job_info['ip']``. The body is sent as plain text plus an HTML
    alternative rendered from the same text with ``markdown``.

    Sending is best effort: any failure is reported through gr.Warning and
    never raised. Nothing happens when ``job_info`` has no truthy 'email'.

    Args:
        job_info: Mapping describing the job; not modified.
    """
    if not job_info.get('email'):
        return

    try:
        email_info = job_info.copy()
        email_serv = os.getenv('EMAIL_SERV')
        email_port = os.getenv('EMAIL_PORT')
        email_addr = os.getenv('EMAIL_ADDR')
        email_pass = os.getenv('EMAIL_PASS')
        email_form = os.getenv('EMAIL_FORM')
        email_subj = os.getenv('EMAIL_SUBJ')

        # Resolve the timezone at most once (it is a network call), and only
        # when there is actually a timestamp to convert.
        timezone = None
        for key, value in job_info.items():
            if key.endswith("time") and value:
                if timezone is None:
                    timezone = get_timezone_by_ip(email_info['ip'])
                email_info[key] = ts_to_str(value, timezone)

        body_text = email_form.format(**email_info)

        msg = MIMEMultipart("alternative")
        msg["From"] = email_addr
        msg["To"] = email_info['email']
        msg["Subject"] = email_subj.format(**email_info)
        msg["Date"] = formatdate(localtime=True)
        msg["Message-ID"] = make_msgid()
        # multipart/alternative parts go from least to most preferred
        # (RFC 2046): plain text first, HTML last so clients render the HTML.
        msg.attach(MIMEText(body_text, 'plain'))
        msg.attach(MIMEText(markdown(body_text), 'html'))

        # The context manager closes the connection even if login or sending
        # fails.
        with smtplib.SMTP(email_serv, int(email_port)) as server:
            # NOTE(review): STARTTLS is disabled, so the login below is sent
            # unencrypted unless the transport is secured by other means.
            # server.starttls()
            server.login(email_addr, email_pass)
            server.sendmail(email_addr, email_info['email'], msg.as_string())

        gr.Info('Email notification sent.')
    except Exception as e:
        gr.Warning('Failed to send email notification due to error: ' + str(e))
def read_molecule(path):
    """
    Read a molecule file, dispatching on the file extension.

    Args:
        path: File path as a string.

    Returns:
        For '.pdb', '.mol', '.mol2' and '.sdf': the value returned by the
        matching RDKit reader, called with sanitize=False and removeHs=True
        (for '.sdf', the first record of the supplier).
        For '.pdr': the raw text content of the file.

    Raises:
        ValueError: If the extension is not one of the above.
    """
    if path.endswith('.pdb'):
        return Chem.MolFromPDBFile(path, sanitize=False, removeHs=True)
    if path.endswith('.pdr'):
        # NOTE(review): '.pdr' may be a typo for '.pqr' - confirm with callers.
        # The context manager closes the handle instead of leaking it.
        with open(path, 'r') as handle:
            return handle.read()
    if path.endswith('.mol'):
        return Chem.MolFromMolFile(path, sanitize=False, removeHs=True)
    if path.endswith('.mol2'):
        return Chem.MolFromMol2File(path, sanitize=False, removeHs=True)
    if path.endswith('.sdf'):
        return Chem.SDMolSupplier(path, sanitize=False, removeHs=True)[0]
    raise ValueError('Unknown file extension')
def read_molecule_file(in_file, allowed_extentions):
    """
    Read a molecule file and convert it to a text block for display.

    Args:
        in_file: A path string, or an object whose ``.name`` is the path.
        allowed_extentions: Collection of accepted extensions, without the
            leading dot.

    Returns:
        A ``(content, extension, message)`` tuple. On success ``content`` is
        a PDB block for 'pdb' input or a MOL block for 'mol', 'mol2' and
        'sdf' input (in which case ``extension`` is reported as 'mol'), and
        ``message`` is None. On failure ``content`` and ``extension`` are
        None and ``message`` is built from static.INVALID_FORMAT_MSG or
        static.ERROR_FORMAT_MSG.

    Raises:
        NotImplementedError: If the extension is allowed and readable but
            has no conversion here.
    """
    path = in_file if isinstance(in_file, str) else in_file.name
    extension = path.split('.')[-1]
    if extension not in allowed_extentions:
        return None, None, static.INVALID_FORMAT_MSG.format(extension=extension)

    try:
        mol = read_molecule(path)
    except Exception as e:
        # Single quotes are stripped from the error text before it is put in
        # the message template (presumably to keep the rendered markup valid).
        message = str(e).replace('\'', '')
        return None, None, static.ERROR_FORMAT_MSG.format(message=message)

    # RDKit readers signal a parse failure by returning None rather than
    # raising; report it instead of crashing in the conversion below.
    if mol is None:
        return None, None, static.ERROR_FORMAT_MSG.format(
            message='The file could not be parsed as a molecule.')

    # Exact comparison: `extension in 'pdb'` was a substring test that also
    # matched 'p', 'd', 'b', 'pd', 'db' and the empty string.
    if extension == 'pdb':
        content = Chem.MolToPDBBlock(mol)
    elif extension in ['mol', 'mol2', 'sdf']:
        content = Chem.MolToMolBlock(mol, kekulize=False)
        extension = 'mol'
    else:
        raise NotImplementedError
    return content, extension, None
# def create_complex_view_html( | |
# complex_path, pocket_path_dict=None, | |
# interactive_ligands=True, interactive_pockets=True | |
# ): | |
# """Generates HTML for complex visualization.""" | |
# model_i = -1 | |
# viewer_models = "" | |
# if complex_path: | |
# complex_data, extension, html = read_molecule_file(complex_path, allowed_extentions=['pdb']) | |
# viewer_models += f'viewer.addModel(`{complex_data}`, "pdb");' | |
# model_i += 1 | |
# viewer_models += f"viewer.getModel({model_i}).setStyle({{ hetflag: false }}, proteinStyle);" | |
# viewer_models += f"viewer.getModel({model_i}).setStyle({{ hetflag: true }}, ligandStyle);" | |
# if interactive_ligands: | |
# # return ligand residue info when the ligand is clicked | |
# viewer_models += f""" | |
# let selectedLigand = null; | |
# viewer.getModel({model_i}).setClickable( | |
# {{ hetflag: true, byres: true }}, | |
# true, | |
# function (_atom, _viewer, _event, _container) {{ | |
# let currentLigand = {{ resn: _atom.resn, chain: _atom.chain, resi: _atom.resi }}; | |
# | |
# if (selectedLigand === currentLigand) {{ | |
# // Deselect ligand | |
# selectedLigand = null; | |
# _viewer.setStyle( | |
# {{ resn: _atom.resn, chain: _atom.chain, resi: _atom.resi }}, | |
# ligandStyle | |
# ); | |
# console.log("Deselected Residue:", currentLigand); | |
# window.parent.postMessage({{ | |
# name: "ligand_selection", | |
# data: {{ residue: currentLigand, add: false }} | |
# }}, "*"); | |
# }} else {{ | |
# // Select ligand and deselect previous | |
# if (selectedLigand) {{ | |
# _viewer.setStyle( | |
# {{ | |
# resn: selectedLigand.resn, | |
# chain: selectedLigand.chain, | |
# resi: selectedLigand.resi | |
# }}, | |
# ligandStyle | |
# ); | |
# }} | |
# selectedLigand = currentLigand; | |
# _viewer.setStyle( | |
# {{ resn: _atom.resn, chain: _atom.chain, resi: _atom.resi }}, | |
# {{ stick: {{ color: "red", radius: 0.4}} }} | |
# ); | |
# console.log("Selected Residue:", currentLigand); | |
# window.parent.postMessage({{ | |
# name: "ligand_selection", | |
# data: {{ residue: currentLigand, add: true }} | |
# }}, "*"); | |
# }} | |
# _viewer.render(); | |
# }} | |
# ); | |
# """ | |
# if pocket_path_dict: | |
# pocket_data_dict = {k: open(v, 'r').read() for k, v in pocket_path_dict.items()} | |
# for pocket_name, pocket_data in pocket_data_dict.items(): | |
# viewer_models += f'viewer.addModel(`{pocket_data}`, "pqr");' | |
# model_i += 1 | |
# viewer_models += f'viewer.getModel({model_i}).setStyle(pocketStyle);' | |
# if interactive_pockets: | |
# # return the pocket name when the pocket is clicked | |
# viewer_models += f""" | |
# let selectedPocket = null; | |
# viewer.getModel({model_i}).setClickable( | |
# {{ byres: true }}, | |
# true, | |
# function (_atom, _viewer, _event, _container) {{ | |
# let currentPocket = "{pocket_name}"; | |
# | |
# if (selectedPocket == currentPocket) {{ | |
# // Deselect pocket | |
# selectedPocket = null; | |
# _viewer.getModel({model_i}).setStyle( pocketStyle ); | |
# console.log("Deselected Pocket:", currentPocket); | |
# window.parent.postMessage({{ | |
# name: "pocket_selection", | |
# data: {{ pocket: currentPocket, add: false }} | |
# }}, "*"); | |
# }} else {{ | |
# // Select pocket and deselect previous | |
# if (selectedPocket) {{ | |
# _viewer.getModel(selectedPocket).setStyle( pocketStyle ); | |
# }} | |
# selectedPocket = currentPocket; | |
# _viewer.getModel({model_i}).setStyle( | |
# {{ sphere: {{ color: "red", opacity: 0.9}} }} | |
# ); | |
# console.log("Selected Pocket:", currentPocket); | |
# window.parent.postMessage({{ | |
# name: "pocket_selection", | |
# data: {{ pocket: currentPocket, add: true }} | |
# }}, "*"); | |
# }} | |
# _viewer.render(); | |
# }} | |
# ); | |
# """ | |
# | |
# html = static.COMPLEX_RENDERING_TEMPLATE.format(viewer_models=viewer_models) | |
# return static.IFRAME_TEMPLATE.format(html=html) | |
def prepare_df_for_table(result_df):
    """
    Add a 'Compound' column of inline SVG depictions to a results table.

    Rows whose 'mol' value is missing are dropped from `result_df` in place.
    Each remaining molecule is drawn as a 90x56 SVG and stored URL-quoted in
    the new 'Compound' column.

    Args:
        result_df: DataFrame with a 'mol' column of molecules; modified in
            place.

    Returns:
        The same DataFrame object.
    """
    result_df.dropna(subset=['mol'], inplace=True)
    rdDepictor.SetPreferCoordGen(True)

    # One shared options object is reused for every drawing.
    options = Draw.rdMolDraw2D.MolDrawOptions()
    for option_name, option_value in (
        ('clearBackground', False),
        ('bondLineWidth', 0.5),
        ('explicitMethyl', True),
        ('singleColourWedgeBonds', True),
        ('addStereoAnnotation', False),
    ):
        setattr(options, option_name, option_value)
    options.useCDKAtomPalette()

    def svg_uri(mol):
        # A fresh SVG canvas is created per molecule.
        canvas = Draw.MolDraw2DSVG(90, 56)
        canvas.SetDrawOptions(options)
        canvas.DrawMolecule(mol)
        canvas.FinishDrawing()
        # URL-quote the SVG text so it can be embedded in a data: URI.
        return urllib.parse.quote(canvas.GetDrawingText())

    result_df['Compound'] = result_df['mol'].apply(svg_uri)
    return result_df
def create_result_table_html(summary_df, result_info=None, opts=(), progress=gr.Progress(track_tqdm=True)):
    """
    Render a results table as an interactive Panel Tabulator inside an iframe.

    The 'mol' column is dropped, columns are renamed through COL_ALIASES, and
    'Complex Name', 'Fragment SMILES' and 'Compound SMILES' are moved to the
    right. Numeric columns get a per-column background colour gradient, and
    'Compound' cells are rendered as inline SVG images. The table is saved to
    a temporary HTML file whose escaped content is embedded via
    static.IFRAME_TEMPLATE.

    Args:
        summary_df: DataFrame with a 'mol' column; not modified.
        result_info: Optional mapping with 'output_dir' and 'type'. When
            given, 'Pose' values are prefixed with ``output_dir / type`` and
            the table is wrapped in a titled accordion.
        opts: Not referenced by the current implementation.
        progress: Not referenced in the body (NOTE(review): presumably here
            so Gradio tracks tqdm progress - confirm).

    Returns:
        The iframe HTML string.
    """
    html_df = summary_df.copy().drop(columns=['mol']).rename(columns=COL_ALIASES)

    if result_info:
        output_dir = Path(result_info['output_dir'])
        job_type = result_info['type']
        html_df['Pose'] = html_df['Pose'].apply(
            lambda pose: str(output_dir / job_type / pose))

    hidden_cols = [name for name in html_df.columns if name.endswith('_path')]

    # Push the verbose identifier columns to the right-hand side, keeping the
    # relative order of all columns otherwise unchanged.
    rightmost_cols = ['Complex Name', 'Fragment SMILES', 'Compound SMILES']
    leading = [name for name in html_df.columns if name not in rightmost_cols]
    trailing = [name for name in html_df.columns if name in rightmost_cols]
    html_df = html_df[leading + trailing]
    html_df.index.name = 'Index'

    # TODO: the 'Scaffold' image column is currently disabled.

    num_cols = html_df.select_dtypes('number').columns
    num_col_colors = sns.color_palette('husl', len(num_cols))
    bool_cols = html_df.select_dtypes(bool).columns

    image_zoom_formatter = HTMLTemplateFormatter(
        template='<img src="data:image/svg+xml,<%= value %>" alt="Molecule" class="zoom-img">'
    )
    # Later updates take precedence over earlier ones for a shared column.
    formatters = {name: BooleanFormatter() for name in bool_cols}
    formatters.update({
        name: NumberFormatter(format='0.000')
        for name in html_df.select_dtypes('floating').columns
    })
    formatters.update({
        'Compound': image_zoom_formatter,
        'Pose': {'type': 'molDisplayButtonFormatter'},
    })

    static_url = "gradio_api/file=app/static/"
    pn.extension(
        design='material',
        css_files=[static_url + 'panel.css'],
        js_files={'panel_custom': static_url + 'panel.js'},
    )

    report_table = pn.widgets.Tabulator(
        html_df,
        formatters=formatters,
        frozen_columns=['Index', 'Pose', 'Compound ID', 'Compound'],
        hidden_columns=hidden_cols,
        sizing_mode='stretch_both',
        disabled=True,
        selectable=False,
        pagination='local',
        configuration={'rowHeight': 60},
    )

    # One light-to-saturated colour map per numeric column.
    for col, base_color in zip(num_cols, num_col_colors):
        cmap = sns.light_palette(base_color, as_cmap=True)
        cmap.set_bad(color='white')
        report_table.style.background_gradient(
            subset=html_df.columns == col, cmap=cmap)

    # TODO change this to use common substructures: the summary-statistics
    # pie charts that used to follow the table are currently disabled.

    if result_info:
        subtitle = 'Generated Molecules' if job_type == 'linking' else 'No Linkable Pairs'
        table_title = f"{job_type.title()} Results ({subtitle})"
        report = pn.Column(pn.Accordion(
            (table_title, report_table),
            toggle=True, margin=5, active=[0]
        ))
        aspect_ratio = '1.090 / 1'
    else:
        report = report_table
        aspect_ratio = '1.618 / 1'

    # The saved page is read back and inlined, so the file itself is only
    # needed for the duration of this block.
    with tempfile.TemporaryDirectory() as tmpdir:
        file = Path(tmpdir) / 'report.html'
        report.save(file)
        html_str = file.read_text()

    iframe_html = static.IFRAME_TEMPLATE.format(
        srcdoc=html.escape(html_str), aspect_ratio=aspect_ratio)
    return iframe_html
def download_file(url):
    """
    Downloads a small file to a temporary location, preserving its filename.

    The file is written to ``<system temp dir>/gradio/<filename>``, where
    the filename is the last component of the URL path (any query string or
    fragment is ignored).

    Args:
        url: URL of the file to fetch.

    Returns:
        The path of the downloaded file as a string.

    Raises:
        ValueError: If the server answers 404.
        requests.RequestException: On other HTTP errors, connection
            failures or timeouts.
    """
    # requests has no default timeout; bound the wait so a stalled server
    # cannot block the caller forever.
    response = requests.get(url, timeout=30)
    if response.status_code == 404:
        raise ValueError('No record found for the provided PDB ID.')
    response.raise_for_status()

    # Take the name from the URL path only, so '?query' or '#fragment' parts
    # never end up in the filename.
    filename = Path(urllib.parse.urlsplit(url).path).name
    temp_dir = Path(tempfile.gettempdir()) / 'gradio'
    # The directory is not guaranteed to exist yet.
    temp_dir.mkdir(parents=True, exist_ok=True)
    temp_path = temp_dir / filename
    temp_path.write_bytes(response.content)
    return str(temp_path)
def uniprot_to_pdb(uniprot_id):
    """
    Queries the RCSB PDB API to find PDB entries associated with a UniProt ID.

    Best effort: any failure (network error, timeout, HTTP error status,
    unexpected response shape) yields an empty list instead of an exception.

    Args:
        uniprot_id: UniProt accession to search for.

    Returns:
        A list of the "identifier" values from the response's "result_set";
        empty when nothing matched or the query failed.
    """
    base_url = "https://search.rcsb.org/rcsbsearch/v2/query"
    accession_attr = ("rcsb_polymer_entity_container_identifiers"
                      ".reference_sequence_identifiers.database_accession")
    database_attr = ("rcsb_polymer_entity_container_identifiers"
                     ".reference_sequence_identifiers.database_name")
    query_payload = {
        "query": {
            "type": "group",
            "logical_operator": "and",
            "nodes": [
                {
                    "type": "terminal",
                    "service": "text",
                    "parameters": {
                        "operator": "exact_match",
                        "value": uniprot_id,
                        "attribute": accession_attr,
                    }
                },
                {
                    "type": "terminal",
                    "service": "text",
                    "parameters": {
                        "operator": "exact_match",
                        "value": "UniProt",
                        "attribute": database_attr,
                    }
                }
            ]
        },
        "return_type": "entry"
    }
    try:
        # Send POST request with JSON payload; the timeout keeps a stalled
        # service from blocking the caller indefinitely.
        response = requests.post(base_url, json=query_payload, timeout=30)
        response.raise_for_status()
        # An empty 204 reply has no JSON body to parse; treat it as no hits.
        if response.status_code == 204:
            return []
        data = response.json()
        return [entry["identifier"] for entry in data.get("result_set", [])]
    except Exception:
        # Deliberate catch-all: callers treat "lookup failed" as "no hits".
        return []
def fasta_to_pdb(fasta_sequence):
    """
    Queries the RCSB PDB API to find PDB IDs associated with a FASTA sequence.

    The sequence search is sent with an E-value cutoff of 1 and an identity
    cutoff of 0.9. Best effort: any failure (network error, timeout, HTTP
    error status, unexpected response shape) yields an empty list instead
    of an exception.

    Args:
        fasta_sequence: Protein sequence to search for.

    Returns:
        A list of the "identifier" values from the response's "result_set";
        empty when nothing matched or the query failed.
    """
    base_url = "https://search.rcsb.org/rcsbsearch/v2/query"
    query_payload = {
        "query": {
            "type": "terminal",
            "service": "sequence",
            "parameters": {
                "evalue_cutoff": 1,
                "identity_cutoff": 0.9,
                "sequence_type": "protein",
                "value": fasta_sequence
            }
        },
        "request_options": {
            "scoring_strategy": "sequence"
        },
        "return_type": "entry"
    }
    try:
        # Send POST request with JSON payload; the timeout keeps a stalled
        # service from blocking the caller indefinitely.
        response = requests.post(base_url, json=query_payload, timeout=30)
        response.raise_for_status()
        # An empty 204 reply has no JSON body to parse; treat it as no hits.
        if response.status_code == 204:
            return []
        data = response.json()
        return [entry["identifier"] for entry in data.get("result_set", [])]
    except Exception:
        # Deliberate catch-all: callers treat "lookup failed" as "no hits".
        return []