# PPIformer-CPU / app.py
# Anton Bushuiev
# Initial commit
# 29bd8b5
# raw
# history blame
# 22.9 kB
# print("""
# __ __ _ ___ _ _ _____ _____ _ _ _ _ _ ____ _____
# | \/ | / \ |_ _| \ | |_ _| ____| \ | | / \ | \ | |/ ___| ____|
# | |\/| | / _ \ | || \| | | | | _| | \| | / _ \ | \| | | | _|
# | | | |/ ___ \ | || |\ | | | | |___| |\ |/ ___ \| |\ | |___| |___
# |_| |_/_/ \_\___|_| \_| |_| |_____|_| \_/_/ \_\_| \_|\____|_____|
# ____ ____ _____ _ _ __
# | __ )| _ \| ____| / \ | |/ /
# | _ \| |_) | _| / _ \ | ' /
# | |_) | _ <| |___ / ___ \| . \
# |____/|_| \_\_____/_/ \_\_|\_\
# """)
import os
# os.system("pip uninstall -y gradio")
# os.system("pip install gradio==3.50.2")
# os.system("pip uninstall -y spaces")
# os.system("pip install spaces==0.8")
# Replace the installed torch with the pinned 2.0.1 build at start-up. This runs
# before `import torch` further down, so the import picks up the pinned build.
# NOTE(review): the exit status of both shell commands is ignored, so a failed
# (un)install goes unnoticed here - confirm that this is intended.
os.system("pip uninstall -y torch")
os.system("pip install torch==2.0.1")
import sys
import copy
import random
import tempfile
import shutil
import logging
from pathlib import Path
from functools import partial
import spaces
import gradio as gr
import torch
import numpy as np
import pandas as pd
from Bio.PDB.Polypeptide import protein_letters_3to1
from biopandas.pdb import PandasPdb
from colour import Color
from colour import RGB_TO_COLOR_NAMES
from mutils.proteins import AMINO_ACID_CODES_1
from mutils.pdb import download_pdb
from mutils.mutations import Mutation
from ppiref.extraction import PPIExtractor
from ppiref.utils.ppi import PPIPath
from ppiref.utils.residue import Residue
from ppiformer.tasks.node import DDGPPIformer
from ppiformer.utils.api import download_from_zenodo
from ppiformer.utils.api import predict_ddg as predict_ddg_
from ppiformer.utils.torch import fill_diagonal
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR
import pkg_resources
import sys
def print_package_versions():
    """Print all installed distributions as ``name==version`` and the Python version.

    The distribution lines are taken from ``pkg_resources.working_set`` and are
    printed in sorted order, one per line, followed by ``sys.version``.
    """
    pins = [f"{dist.key}=={dist.version}" for dist in pkg_resources.working_set]
    pins.sort()

    print("Installed packages and their versions:")
    for pin in pins:
        print(pin)

    print("\nPython version:")
    print(sys.version)
# Dump the installed package versions and the interpreter version once at start-up.
print_package_versions()

# Send INFO-and-above log records to stdout, prefixed with a timestamp and level.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Seed the stdlib RNG; `plot_3dmol` shuffles the fallback chain colours with it.
random.seed(0)
@spaces.GPU
def predict_ddg(models, ppi, muts, return_attn):
    """Run the ddG predictor on CUDA when available (else CPU) and return CPU tensors.

    Args:
        models: Iterable of models; each one is moved to the selected device.
        ppi: Structure argument, forwarded unchanged to the underlying predictor.
        muts: Mutations argument, forwarded unchanged to the underlying predictor.
        return_attn: When truthy, attention is requested and returned as well.

    Returns:
        The detached prediction tensor on CPU, or a ``(predictions, attentions)``
        pair of detached CPU tensors when ``return_attn`` is truthy.
    """
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    print(f"[INFO] Device on prediction: {device}")

    on_device = []
    for model in models:
        on_device.append(model.to(device))
    models = on_device

    # Guard clause: without attention the predictor returns a single tensor.
    if not return_attn:
        ddg_pred = predict_ddg_(models, ppi, muts, return_attn=return_attn)
        return ddg_pred.detach().cpu()

    ddg_pred, attns = predict_ddg_(models, ppi, muts, return_attn=return_attn)
    return ddg_pred.detach().cpu(), attns.detach().cpu()
def process_inputs(inputs, temp_dir):
    """Validate the raw GUI inputs and prepare the structure files and mutations.

    Args:
        inputs: Sequence ``(pdb_code, pdb_path, partners, muts, muts_path)`` in
            the order of the GUI input components. An uploaded PDB file wins
            over a PDB code, and a mutations file wins over the mutations text.
        temp_dir: ``pathlib.Path`` of the working directory. Structures are
            stored under ``pdb/`` and extracted interfaces under ``ppi/``.

    Returns:
        Tuple ``(pdb_path, ppi_path, muts, muts_on_interface)``:
        ``pdb_path`` is the structure file renamed so that its stem ends with
        the dash-joined partners, ``ppi_path`` is the path built by
        ``PPIPath.construct`` for the extracted interface, ``muts`` is the list
        of mutations as strings and ``muts_on_interface`` is a parallel list of
        booleans (``True`` when the wild type is found in the interface file).

    Raises:
        gr.Error: If a required input is missing, the PDB download or the PPI
            extraction fails, the mutations cannot be parsed or are empty, a
            point mutation refers to a chain outside ``partners``, or a wild
            type is found neither in the interface nor in the full structure.
    """
    pdb_code, pdb_path, partners, muts, muts_path = inputs

    # Check inputs
    if not pdb_code and not pdb_path:
        raise gr.Error("PPI structure not specified.")
    if pdb_code and pdb_path:
        gr.Warning("Both PDB code and PDB file specified. Using PDB file.")
    if not partners:
        raise gr.Error("Partners not specified.")
    if not muts and not muts_path:
        raise gr.Error("Mutations not specified.")
    if muts and muts_path:
        gr.Warning("Both mutations and mutations file specified. Using mutations file.")

    # Prepare PDB input
    if pdb_path:
        # convert file name to PPIRef format
        # NOTE(review): `pdb_path.name` is presumably the uploaded file's path as
        # provided by Gradio - confirm against the Gradio version in use.
        new_pdb_path = temp_dir / f"pdb/{pdb_path.name.replace('_', '-')}"
        new_pdb_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(str(pdb_path), str(new_pdb_path))
        pdb_path = new_pdb_path
    else:
        pdb_code = pdb_code.strip().lower()
        pdb_path = temp_dir / f'pdb/{pdb_code}.pdb'
        try:
            download_pdb(pdb_code, path=pdb_path)
        except Exception as e:
            # `Exception` (not a bare except) so that KeyboardInterrupt/SystemExit
            # propagate; chaining keeps the real cause in the server traceback.
            raise gr.Error("PDB download failed.") from e

    # Parse partners
    partners = [partner.strip() for partner in partners.split(',')]

    # Add partners to file name
    pdb_path = pdb_path.rename(pdb_path.with_stem(f"{pdb_path.stem}-{'-'.join(partners)}"))

    # Extract PPI into temp dir
    try:
        ppi_dir = temp_dir / 'ppi'
        extractor = PPIExtractor(out_dir=ppi_dir, nest_out_dir=True, join=True, radius=10.0)
        extractor.extract(pdb_path, partners=partners)
        ppi_path = PPIPath.construct(ppi_dir, pdb_path.stem, partners)
    except Exception as e:
        raise gr.Error("PPI extraction failed.") from e

    # Prepare mutations input
    if muts_path:
        muts = Path(muts_path).read_text()

    # Check mutations
    # Basic format
    try:
        muts = [Mutation.from_str(m) for m in muts.strip().split(';') if m.strip()]
    except Exception as e:
        raise gr.Error(f'Mutations parsing failed: {e}') from e
    # Input consisting only of separators/whitespace (e.g. ";") is truthy as a
    # string but contains no mutation; report it instead of failing later.
    if not muts:
        raise gr.Error("Mutations not specified.")

    # Partners
    for mut in muts:
        for pmut in mut.muts:
            if pmut.chain not in partners:
                raise gr.Error(f'Chain of point mutation {pmut} is not in the list of partners {partners}.')

    # Consistency with provided .pdb
    muts_on_interface = []
    for mut in muts:
        if mut.wt_in_pdb(ppi_path):
            on_interface = True
        elif mut.wt_in_pdb(pdb_path):
            on_interface = False
        else:
            raise gr.Error(f'Wild-type of mutation {mut} is not in the provided .pdb file.')
        muts_on_interface.append(on_interface)

    muts = [str(m) for m in muts]
    return pdb_path, ppi_path, muts, muts_on_interface
def plot_3dmol(pdb_path, ppi_path, mut, attn, attn_mut_id=0):
    """Render the complex as a 3Dmol.js viewer embedded in an ``<iframe>``.

    The structure from ``pdb_path`` is loaded into the viewer. Every residue
    read from ``ppi_path`` is drawn as sticks: mutated residues in red (they are
    also labelled and zoomed to), all others in their chain colour with
    saturation, radius and opacity driven by the normalised attention they
    receive from the mutated residues.

    Args:
        pdb_path: Path of the PDB file shown in the viewer.
        ppi_path: Path of the PDB file whose residues are styled one by one.
        mut: Mutation string, parsed with ``Mutation.from_str``.
        attn: Attention tensor, indexed as ``attn[:, attn_mut_id, 0, :, 0, :, :, :]``;
            the result is treated as (models, layers, heads, tokens, tokens).
            NOTE(review): the token order is assumed to match the row order of
            the per-residue table built below - confirm against the model.
        attn_mut_id: Index of ``mut`` along the second dimension of ``attn``.

    Returns:
        HTML string with a single ``<iframe>`` whose ``srcdoc`` attribute holds
        the complete viewer page.
    """
    # Function-scope imports: only this function needs them.
    import json
    from html import escape

    # NOTE 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py

    # Read PDB for 3Dmol.js
    with open(pdb_path, "r") as fp:
        mol = fp.read()
    # PDB records are fixed-column, so the replacement must keep the original
    # three-character width of the atom name or the rest of the line shifts.
    mol = mol.replace("OT1", "O  ")
    mol = mol.replace("OT2", "OXT")

    # Read PPI to customize 3Dmol.js visualization
    ppi_df = PandasPdb().read_pdb(ppi_path).df['ATOM']
    ppi_df = ppi_df.groupby(list(Residue._fields)).apply(lambda df: df[df['atom_name'] == 'CA'].iloc[0]).reset_index(drop=True)
    ppi_df['id'] = ppi_df.apply(lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]), axis=1)
    ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
    muts_id = Mutation.from_str(mut).wt_to_graphein()  # flatten ids of all sp muts
    ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)

    # Prepare attention coefficients per residue (normalized sum of direct attention from mutated residues)
    attn = torch.nan_to_num(attn, nan=1e-10)
    attn_sub = attn[:, attn_mut_id, 0, :, 0, :, :, :]  # models, layers, heads, tokens, tokens
    idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
    attn_sub = fill_diagonal(attn_sub, 1e-10)
    attn_mutated = attn_sub[..., idx_mutated, :]
    attns_per_token = torch.sum(attn_mutated, dim=(0, 1, 2, 3))
    attn_min = attns_per_token.min()
    attn_range = attns_per_token.max() - attn_min
    if attn_range > 0:
        attns_per_token = (attns_per_token - attn_min) / attn_range
    else:
        # All tokens received the same attention: min-max scaling would divide
        # by zero and produce NaN colours, so fall back to a flat zero profile.
        attns_per_token = torch.zeros_like(attns_per_token)
    attns_per_token = attns_per_token + 1e-10
    ppi_df['attn'] = attns_per_token.numpy()
    chains = ppi_df.sort_values('attn', ascending=False)['chain_id'].unique()

    # Customize 3Dmol.js visualization https://3dmol.csb.pitt.edu/doc/
    styles = []
    zoom_atoms = []
    labels = []

    # Cartoon chains
    preferred_colors = ['LimeGreen', 'HotPink', 'RoyalBlue']
    all_colors = [c[0] for c in RGB_TO_COLOR_NAMES.values()]
    all_colors = [c for c in all_colors if c not in preferred_colors + ['Black', 'White']]
    random.shuffle(all_colors)
    all_colors = preferred_colors + all_colors
    all_colors = [Color(c) for c in all_colors]
    chain_to_color = dict(zip(chains, all_colors))
    for chain in chains:
        styles.append([{"chain": chain}, {"cartoon": {"color": chain_to_color[chain].hex_l, "opacity": 0.6}}])

    # Stick PPI and atoms for zoom
    # TODO Insertions
    for _, row in ppi_df.iterrows():
        selector = {'chain': row['chain_id'], 'resi': str(row['residue_number'])}
        if row['mutated']:
            styles.append([selector, {'stick': {'color': 'red', 'radius': 0.2, 'opacity': 1.0}}])
            zoom_atoms.append(int(row['atom_number']))
            labels.append(
                protein_letters_3to1[row['residue_name']] + row['chain_id']
                + str(row['residue_number']) + row['insertion']
            )
        else:
            weight = float(row['attn'])
            # Copy so that the shared per-chain colour is not modified.
            color = copy.deepcopy(chain_to_color[row['chain_id']])
            color.saturation = weight
            styles.append([selector, {'stick': {'color': color.hex_l, 'radius': weight / 5, 'opacity': weight}}])

    # Convert style dicts to JS lines (JSON object literals are valid JS)
    styles_js = ''.join(
        'viewer.addStyle(' + ', '.join(json.dumps(part) for part in parts) + ');\n'
        for parts in styles
    )

    # Convert zoom atoms to 3DMol.js selection and add labels for mutated residues
    zoom_animation_duration = 500
    sel = json.dumps({'or': [{'serial': atom} for atom in zoom_atoms]})
    zoom = f'viewer.zoomTo({sel},{zoom_animation_duration});'
    for atom, label in zip(zoom_atoms, labels):
        styles_js += (
            'viewer.addLabel(' + json.dumps(label) + ','
            + '{fontSize:16, fontColor:"red", backgroundOpacity: 0.0},'
            + json.dumps({'serial': atom}) + ');\n'
        )

    # The PDB text may come from a user upload. Emit it as a JSON string literal
    # instead of splicing it into a JS template literal (which breaks on a
    # backtick or "${"), and escape "<" so that "</script>" inside the file
    # cannot terminate the surrounding <script> element.
    mol_js = json.dumps(mol).replace('<', '\\u003c')

    # Construct 3Dmol.js visualization script embedded in HTML
    page = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 600px;
position: relative;
}
.mol-container select{
background-image:None;
}
</style>
<script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<script>
let pdb = """
        + mol_js
        + """;
$(document).ready(function () {
let element = $("#container");
let config = { backgroundColor: "white" };
let viewer = $3Dmol.createViewer(element, config);
viewer.addModel(pdb, "pdb");
viewer.setStyle({"model": 0}, {"ray_opaque_background": "off"}, {"stick": {"color": "lightgrey", "opacity": 0.5}});
"""
        + styles_js
        + zoom
        + """
viewer.render();
})
</script>
</body></html>"""
    )

    # The page is placed inside a single-quoted attribute: escape it so that a
    # quote in the page (e.g. the atom name O5' in the PDB text) cannot end the
    # attribute early. The browser decodes the entities when reading `srcdoc`.
    srcdoc = escape(page, quote=True)
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def predict(models, temp_dir, *inputs):
    """Predict ddG for all requested mutations and build the GUI outputs.

    Mutations whose wild type is found in the extracted interface are predicted
    on the interface file; the remaining ones are predicted on the whole
    structure file (a warning is shown for them).

    Args:
        models: Models passed on to ``predict_ddg``.
        temp_dir: ``pathlib.Path`` of the working directory; it is passed to
            ``process_inputs`` and also receives the per-request CSV file.
        *inputs: Raw GUI input values, forwarded to ``process_inputs``.

    Returns:
        Tuple ``(df, path, dropdown, dropdown_choices_to_plot_args)``: a
        ``gr.Dataframe`` with the predictions, the path of the CSV file with the
        same table, a ``gr.Dropdown`` listing the mutations, and a dict mapping
        each mutation string to the positional arguments of ``plot_3dmol``.

    Raises:
        gr.Error: If the inputs are invalid (see ``process_inputs``), contain no
            mutation, or the model prediction fails.
    """
    logging.info('Starting prediction')

    # Process input
    pdb_path, ppi_path, muts, muts_on_interface = process_inputs(inputs, temp_dir)
    if not muts:
        # Nothing to predict; the table and the dropdown below need >= 1 row.
        raise gr.Error("Mutations not specified.")

    # Create dataframe
    df = pd.DataFrame({
        'Mutation': muts,
        'ddG [kcal/mol]': len(muts) * [np.nan],
        '10A Interface': muts_on_interface,
        'Attn Id': len(muts) * [np.nan],
    })

    # Show warning if some mutations are not on the interface
    muts_not_on_interface = df[~df['10A Interface']]['Mutation'].tolist()
    n_muts_not_on_interface = len(muts_not_on_interface)
    if n_muts_not_on_interface:
        n_muts_warn = 5
        muts_not_on_interface = ';'.join(muts_not_on_interface[:n_muts_warn])
        if n_muts_not_on_interface > n_muts_warn:
            muts_not_on_interface += f'... (and {n_muts_not_on_interface - n_muts_warn} more)'
        gr.Warning((
            f"{muts_not_on_interface} {'is' if n_muts_not_on_interface == 1 else 'are'} not on the interface. "
            f"The model will predict the effect{'s' if n_muts_not_on_interface > 1 else ''} of "
            f"mutation{'s' if n_muts_not_on_interface > 1 else ''} on the whole complex. "
            f"This may lead to less accurate predictions."
        ))
    logging.info('Inputs processed')

    # Predict using interface for mutations on the interface and using the whole complex otherwise
    attn_ppi, attn_pdb = None, None
    on_interface_mask = df['10A Interface']
    for on_interface, df_sub, path in [
        (True, df[on_interface_mask], ppi_path),
        (False, df[~on_interface_mask], pdb_path),
    ]:
        if not len(df_sub):
            continue

        # Predict
        try:
            ddg, attn = predict_ddg(models, path, df_sub['Mutation'].tolist(), return_attn=True)
        except Exception as e:
            # Keep the traceback in the server log and chain the cause.
            logging.exception(f"Prediction failed. {str(e)}")
            raise gr.Error(f"Prediction failed. {str(e)}") from e
        ddg = ddg.detach().numpy().tolist()
        logging.info(f'Predictions made for {path}')

        # Update dataframe and attention tensor
        idx = df_sub.index
        df.loc[idx, 'ddG [kcal/mol]'] = ddg
        # Position of each mutation inside the attention tensor of its own batch
        df.loc[idx, 'Attn Id'] = np.arange(len(idx))
        if on_interface:
            attn_ppi = attn
        else:
            attn_pdb = attn
    df['Attn Id'] = df['Attn Id'].astype(int)

    # Round ddG values
    df['ddG [kcal/mol]'] = df['ddG [kcal/mol]'].round(3)

    # Create PPI-specific dropdown
    dropdown = gr.Dropdown(
        df['Mutation'].tolist(), value=df['Mutation'].iloc[0],
        interactive=True, visible=True, label="Mutation to visualize",
    )

    # Predefine plot arguments for all dropdown choices (single pass over the rows)
    dropdown_choices_to_plot_args = {}
    for mut, on_interface, attn_id in zip(df['Mutation'], df['10A Interface'], df['Attn Id']):
        if mut in dropdown_choices_to_plot_args:
            # Repeated mutation: the first occurrence defines the plot.
            continue
        dropdown_choices_to_plot_args[mut] = (
            pdb_path,
            ppi_path if on_interface else pdb_path,
            mut,
            attn_ppi if on_interface else attn_pdb,
            attn_id,
        )

    # Create dataframe file; the interface column is only shown when it is informative
    columns = ['Mutation', 'ddG [kcal/mol]']
    datatypes = ['str', 'number']
    if n_muts_not_on_interface:
        columns.append('10A Interface')
        datatypes.append('bool')
    df = df[columns]
    # One fresh directory per request: a fixed file name in the working directory
    # would let concurrent requests overwrite each other's results.
    path = str(Path(tempfile.mkdtemp(dir=temp_dir)) / 'ppiformer_ddg_predictions.csv')
    df.to_csv(path, index=False)
    df = gr.Dataframe(
        value=df,
        headers=columns,
        datatype=datatypes,
        col_count=(len(columns), 'fixed'),
    )

    logging.info('Prediction results prepared')
    return df, path, dropdown, dropdown_choices_to_plot_args
def update_plot(dropdown, dropdown_choices_to_plot_args):
    """Return the 3D view for the mutation currently selected in the dropdown.

    Args:
        dropdown: Selected dropdown value, used as the lookup key.
        dropdown_choices_to_plot_args: Mapping from dropdown value to the
            positional arguments of ``plot_3dmol``.

    Returns:
        Whatever ``plot_3dmol`` returns for the stored arguments.
    """
    plot_args = dropdown_choices_to_plot_args[dropdown]
    return plot_3dmol(*plot_args)
# Build the Gradio application: page layout, model loading and event wiring.
app = gr.Blocks(
    theme=gr.themes.Default(secondary_hue="pink", primary_hue="green"),
)

with app:

    # Input GUI
    gr.Markdown(value="""
# PPIformer Web
### Computational Design of Protein-Protein Interactions
""")
    gr.Image("assets/readme-dimer-close-up.png")
    gr.Markdown(value="""
[PPIformer](https://github.com/anton-bushuiev/PPIformer/tree/main) is a state-of-the-art predictor of the effects of mutations
on protein-protein interactions (PPIs), as quantified by the binding free energy changes (ddG). PPIformer was shown to successfully
identify known favourable mutations of the [staphylokinase thrombolytics](https://pubmed.ncbi.nlm.nih.gov/10942387/)
and a [human antibody](https://www.pnas.org/doi/10.1073/pnas.2122954119) against the SARS-CoV-2 spike protein. The model was pre-trained
on the [PPIRef](https://github.com/anton-bushuiev/PPIRef)
dataset via a coarse-grained structural masked modeling and fine-tuned on the [SKEMPI v2.0](https://life.bsc.es/pid/skempi2) dataset via log odds.
Please see more details in [our ICLR 2024 paper](https://arxiv.org/abs/2310.18515).
**Inputs.** To use PPIformer on your data, please specify the PPI structure (PDB code or .pdb file), interacting proteins of interest
(chain codes in the file) and mutations (semicolon-separated list or file with mutations in the
[standard format](https://foldxsuite.crg.eu/parameter/mutant-file): wild-type residue, chain, residue number, mutant residue).
For inspiration, you can use one of the examples below: click on one of the rows to pre-fill the inputs. After specifying the inputs,
press the button to predict the effects of mutations on the PPI. Currently the model runs on CPU, so the predictions may take a few minutes.
**Outputs.** After making a prediction with the model, you will see binding free energy changes for each mutation (ddG values in kcal/mol).
A more negative value indicates an improvement in affinity, whereas a more positive value means a reduction in affinity.
Below you will also see a 3D visualization of the PPI with wild types of mutated residues highlighted in red. The visualization additionally shows
the attention coefficients of the model for the nearest neighboring residues, which quantifies the contribution of the residues
to the predicted ddG value. The brighter and thicker a residue is, the more attention the model paid to it.
""")

    with gr.Row(equal_height=True):

        # Left column: structure given as a PDB code or an uploaded file, plus partner chains
        with gr.Column():
            gr.Markdown("## PPI structure")
            with gr.Row(equal_height=True):
                pdb_code = gr.Textbox(
                    label="PDB code",
                    placeholder="1BUI",
                    info="Protein Data Bank identifier for the structure (https://www.rcsb.org/)",
                )
                partners = gr.Textbox(
                    label="Partners",
                    placeholder="A,B,C",
                    info="Protein chain identifiers in the PDB file forming the PPI interface (two or more)",
                )
            pdb_path = gr.File(
                label="Or .pdb file instead of PDB code (your structure will only be used for this prediction and not stored anywhere)",
                file_count="single",
            )

        # Right column: mutations given as text or as a file
        with gr.Column():
            gr.Markdown("## Mutations")
            muts = gr.Textbox(
                label="List of (multi-point) mutations",
                placeholder="SC16A;FC47A;SC16A,FC47A",
                info="SC16A;FC47A;SC16A,FC47A for three mutations: serine to alanine at position 16 in chain C, phenylalanine to alanine at position 47 in chain C, and their double-point combination",
            )
            muts_path = gr.File(label="Or file with mutations", file_count="single")

    # Clickable rows that pre-fill the PDB code, partners and mutations inputs
    examples = gr.Examples(
        examples=[
            ["1BUI", "A,B,C", "SC16A,FC47A;SC16A;FC47A"],
            ["3QIB", "A,B,P,C,D", "YP7F,TP12S;YP7F;TP12S"],
            ["1KNE", "A,P", ';'.join([f"TP6{a}" for a in AMINO_ACID_CODES_1])],
        ],
        inputs=[pdb_code, partners, muts],
        label="Examples (click on a line to pre-fill the inputs)",
        cache_examples=False,
    )

    # Predict GUI
    predict_button = gr.Button(variant="primary", value="Predict effects of mutations on PPI")

    # Output GUI
    gr.Markdown("## Predictions")
    df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
    df = gr.Dataframe(
        headers=["Mutation", "ddG [kcal/mol]"],
        datatype=["str", "number"],
        col_count=(2, "fixed"),
    )
    # Hidden until a prediction returns a configured dropdown
    dropdown = gr.Dropdown(visible=False, interactive=True)
    dropdown_choices_to_plot_args = gr.State([])
    plot = gr.HTML()

    # Bottom info box
    gr.Markdown(value="""
<br/>
## About this web
**Use cases**. The predictor can be used in: (i) Drug Discovery for the development of novel drugs and vaccines for various diseases such as cancer,
neurodegenerative disorders, and infectious diseases, (ii) Biotechnological Applications to develop new biocatalysts for biofuels,
industrial chemicals, and pharmaceuticals (iii) Therapeutic Protein Design to develop therapeutic proteins with enhanced stability,
specificity, and efficacy, and (iv) Mechanistic Studies to gain insights into fundamental biological processes, such as signal transduction,
gene regulation, and immune response.
**Acknowledgement**. Please, use the following citation to acknowledge the use of our service. The web server is provided free of charge for non-commercial use.
> Bushuiev, Anton, Roman Bushuiev, Petr Kouba, Anatolii Filkin, Marketa Gabrielova, Michal Gabriel, Jiri Sedlar, Tomas Pluskal, Jiri Damborsky, Stanislav Mazurenko, Josef Sivic.
> "Learning to design protein-protein interactions with enhanced generalization". The Twelfth International Conference on Learning Representations (ICLR 2024).
> [https://arxiv.org/abs/2310.18515](https://arxiv.org/abs/2310.18515).
**Contact**. Please share your feedback or report any bugs through [GitHub Issues](https://github.com/anton-bushuiev/PPIformer/issues/new), or feel free to contact us directly at [anton.bushuiev@cvut.cz](mailto:anton.bushuiev@cvut.cz).
""")
    gr.Image("assets/logos.png")

    # Download weights from Zenodo
    download_from_zenodo('weights.zip')

    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"[INFO] Device on start: {device}")

    # Load models: three checkpoints, loaded on CPU, switched to eval mode,
    # then moved to the selected device
    models = []
    for i in range(3):
        checkpoint = PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt'
        model = DDGPPIformer.load_from_checkpoint(checkpoint, map_location=torch.device('cpu'))
        models.append(model.eval())
    models = [model.to(device) for model in models]

    # Create temporary directory for storing downloaded PDBs and extracted PPIs
    # (the TemporaryDirectory object is kept referenced at module level)
    temp_dir_obj = tempfile.TemporaryDirectory()
    temp_dir = Path(temp_dir_obj.name)

    # Main logic: bind the models and the working directory, then run on click
    inputs = [pdb_code, pdb_path, partners, muts, muts_path]
    outputs = [df, df_file, dropdown, dropdown_choices_to_plot_args]
    predict = partial(predict, models, temp_dir)
    predict_button.click(predict, inputs=inputs, outputs=outputs)

    # Update plot on dropdown change
    dropdown.change(update_plot, inputs=[dropdown, dropdown_choices_to_plot_args], outputs=[plot])

app.launch(allowed_paths=['./assets'])