# print(""" # __ __ _ ___ _ _ _____ _____ _ _ _ _ _ ____ _____ # | \/ | / \ |_ _| \ | |_ _| ____| \ | | / \ | \ | |/ ___| ____| # | |\/| | / _ \ | || \| | | | | _| | \| | / _ \ | \| | | | _| # | | | |/ ___ \ | || |\ | | | | |___| |\ |/ ___ \| |\ | |___| |___ # |_| |_/_/ \_\___|_| \_| |_| |_____|_| \_/_/ \_\_| \_|\____|_____| # ____ ____ _____ _ _ __ # | __ )| _ \| ____| / \ | |/ / # | _ \| |_) | _| / _ \ | ' / # | |_) | _ <| |___ / ___ \| . \ # |____/|_| \_\_____/_/ \_\_|\_\ # """) import os # os.system("pip uninstall -y gradio") # os.system("pip install gradio==3.50.2") # os.system("pip uninstall -y spaces") # os.system("pip install spaces==0.8") os.system("pip uninstall -y torch") os.system("pip install torch==2.0.1") import sys import copy import random import tempfile import shutil import logging from pathlib import Path from functools import partial import spaces import gradio as gr import torch import numpy as np import pandas as pd from Bio.PDB.Polypeptide import protein_letters_3to1 from biopandas.pdb import PandasPdb from colour import Color from colour import RGB_TO_COLOR_NAMES from mutils.proteins import AMINO_ACID_CODES_1 from mutils.pdb import download_pdb from mutils.mutations import Mutation from ppiref.extraction import PPIExtractor from ppiref.utils.ppi import PPIPath from ppiref.utils.residue import Residue from ppiformer.tasks.node import DDGPPIformer from ppiformer.utils.api import download_from_zenodo from ppiformer.utils.api import predict_ddg as predict_ddg_ from ppiformer.utils.torch import fill_diagonal from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR import pkg_resources import sys def print_package_versions(): installed_packages = sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]) print("Installed packages and their versions:") for package in installed_packages: print(package) print("\nPython version:") print(sys.version) print_package_versions() logging.basicConfig( level=logging.INFO, format='%(asctime)s - 
%(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)] ) random.seed(0) @spaces.GPU def predict_ddg(models, ppi, muts, return_attn): device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f"[INFO] Device on prediction: {device}") models = [model.to(device) for model in models] if return_attn: ddg_pred, attns = predict_ddg_(models, ppi, muts, return_attn=return_attn) return ddg_pred.detach().cpu(), attns.detach().cpu() else: ddg_pred = predict_ddg_(models, ppi, muts, return_attn=return_attn) return ddg_pred.detach().cpu() def process_inputs(inputs, temp_dir): pdb_code, pdb_path, partners, muts, muts_path = inputs # Check inputs if not pdb_code and not pdb_path: raise gr.Error("PPI structure not specified.") if pdb_code and pdb_path: gr.Warning("Both PDB code and PDB file specified. Using PDB file.") if not partners: raise gr.Error("Partners not specified.") if not muts and not muts_path: raise gr.Error("Mutations not specified.") if muts and muts_path: gr.Warning("Both mutations and mutations file specified. 
Using mutations file.") # Prepare PDB input if pdb_path: # convert file name to PPIRef format new_pdb_path = temp_dir / f"pdb/{pdb_path.name.replace('_', '-')}" new_pdb_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy(str(pdb_path), str(new_pdb_path)) pdb_path = new_pdb_path pdb_path = Path(pdb_path) else: try: pdb_code = pdb_code.strip().lower() pdb_path = temp_dir / f'pdb/{pdb_code}.pdb' download_pdb(pdb_code, path=pdb_path) except: raise gr.Error("PDB download failed.") # Parse partners partners = list(map(lambda x: x.strip(), partners.split(','))) # Add partners to file name pdb_path = pdb_path.rename(pdb_path.with_stem(f"{pdb_path.stem}-{'-'.join(partners)}")) # Extract PPI into temp dir try: ppi_dir = temp_dir / 'ppi' extractor = PPIExtractor(out_dir=ppi_dir, nest_out_dir=True, join=True, radius=10.0) extractor.extract(pdb_path, partners=partners) ppi_path = PPIPath.construct(ppi_dir, pdb_path.stem, partners) except: raise gr.Error("PPI extraction failed.") # Prepare mutations input if muts_path: muts_path = Path(muts_path) muts = muts_path.read_text() # Check mutations # Basic format try: muts = [Mutation.from_str(m) for m in muts.strip().split(';') if m.strip()] except Exception as e: raise gr.Error(f'Mutations parsing failed: {e}') # Partners for mut in muts: for pmut in mut.muts: if pmut.chain not in partners: raise gr.Error(f'Chain of point mutation {pmut} is not in the list of partners {partners}.') # Consistency with provided .pdb muts_on_interface = [] for mut in muts: if mut.wt_in_pdb(ppi_path): val = True elif mut.wt_in_pdb(pdb_path): val = False else: raise gr.Error(f'Wild-type of mutation {mut} is not in the provided .pdb file.') muts_on_interface.append(val) muts = [str(m) for m in muts] return pdb_path, ppi_path, muts, muts_on_interface def plot_3dmol(pdb_path, ppi_path, mut, attn, attn_mut_id=0): # NOTE 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py # Read PDB for 3Dmol.js with open(pdb_path, "r") 
as fp: lines = fp.readlines() mol = "" for l in lines: mol += l mol = mol.replace("OT1", "O ") mol = mol.replace("OT2", "OXT") # Read PPI to customize 3Dmol.js visualization ppi_df = PandasPdb().read_pdb(ppi_path).df['ATOM'] ppi_df = ppi_df.groupby(list(Residue._fields)).apply(lambda df: df[df['atom_name'] == 'CA'].iloc[0]).reset_index(drop=True) ppi_df['id'] = ppi_df.apply(lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]), axis=1) ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x) muts_id = Mutation.from_str(mut).wt_to_graphein() # flatten ids of all sp muts ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1) # Prepare attention coeffictients per residue (normalized sum of direct attention from mutated residues) attn = torch.nan_to_num(attn, nan=1e-10) attn_sub = attn[:, attn_mut_id, 0, :, 0, :, :, :] # models, layers, heads, tokens, tokens idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy()) attn_sub = fill_diagonal(attn_sub, 1e-10) attn_mutated = attn_sub[..., idx_mutated, :] attn_mutated.shape attns_per_token = torch.sum(attn_mutated, dim=(0, 1, 2, 3)) attns_per_token = (attns_per_token - attns_per_token.min()) / (attns_per_token.max() - attns_per_token.min()) attns_per_token += 1e-10 ppi_df['attn'] = attns_per_token.numpy() chains = ppi_df.sort_values('attn', ascending=False)['chain_id'].unique() # Customize 3Dmol.js visualization https://3dmol.csb.pitt.edu/doc/ styles = [] zoom_atoms = [] # Cartoon chains preferred_colors = ['LimeGreen', 'HotPink', 'RoyalBlue'] all_colors = [c[0] for c in RGB_TO_COLOR_NAMES.values()] all_colors = [c for c in all_colors if c not in preferred_colors + ['Black', 'White']] random.shuffle(all_colors) all_colors = preferred_colors + all_colors all_colors = [Color(c) for c in all_colors] chain_to_color = dict(zip(chains, all_colors)) for chain in chains: styles.append([{"chain": chain}, {"cartoon": {"color": 
chain_to_color[chain].hex_l, "opacity": 0.6}}]) # Stick PPI and atoms for zoom # TODO Insertions for _, row in ppi_df.iterrows(): color = copy.deepcopy(chain_to_color[row['chain_id']]) color.saturation = row['attn'] color = color.hex_l if row['mutated']: styles.append([ {'chain': row['chain_id'], 'resi': str(row['residue_number'])}, {'stick': {'color': 'red', 'radius': 0.2, 'opacity': 1.0}} ]) zoom_atoms.append(row['atom_number']) else: styles.append([ {'chain': row['chain_id'], 'resi': str(row['residue_number'])}, {'stick': {'color': color, 'radius': row['attn'] / 5, 'opacity': row['attn']}} ]) # Convert style dicts to JS lines styles = ''.join(['viewer.addStyle(' + ', '.join([str(s).replace("'", '"') for s in dcts]) + ');\n' for dcts in styles]) # Convert zoom atoms to 3DMol.js selection and add labels for mutated residues zoom_animation_duration = 500 sel = '{\"or\": [' + ', '.join(["{\"serial\": " + str(a) + "}" for a in zoom_atoms]) + ']}' zoom = 'viewer.zoomTo(' + sel + ',' + f'{zoom_animation_duration});' for atom in zoom_atoms: sel = '{\"serial\": ' + str(atom) + '}' row = ppi_df[ppi_df['atom_number'] == atom].iloc[0] label = protein_letters_3to1[row['residue_name']] + row['chain_id'] + str(row['residue_number']) + row['insertion'] styles += 'viewer.addLabel(' + f"\"{label}\"," + "{fontSize:16, fontColor:\"red\", backgroundOpacity: 0.0}," + sel + ');\n' # Construct 3Dmol.js visualization script embedded in HTML html = ( """
""" ) return f"""""" def predict(models, temp_dir, *inputs): logging.info('Starting prediction') # Process input pdb_path, ppi_path, muts, muts_on_interface = process_inputs(inputs, temp_dir) # Create dataframe df = pd.DataFrame({ 'Mutation': muts, 'ddG [kcal/mol]': len(muts) * [np.nan], '10A Interface': muts_on_interface, 'Attn Id': len(muts) * [np.nan], }) # Show warning if some mutations are not on the interface muts_not_on_interface = df[~df['10A Interface']]['Mutation'].tolist() n_muts_not_on_interface = len(muts_not_on_interface) if n_muts_not_on_interface: n_muts_warn = 5 muts_not_on_interface = ';'.join(muts_not_on_interface[:n_muts_warn]) if n_muts_not_on_interface > n_muts_warn: muts_not_on_interface += f'... (and {n_muts_not_on_interface - n_muts_warn} more)' gr.Warning(( f"{muts_not_on_interface} {'is' if n_muts_not_on_interface == 1 else 'are'} not on the interface. " f"The model will predict the effect{'s' if n_muts_not_on_interface > 1 else ''} of " f"mutation{'s' if n_muts_not_on_interface > 1 else ''} on the whole complex. " f"This may lead to less accurate predictions." )) logging.info('Inputs processed') # Predict using interface for mutations on the interface and using the whole complex otherwise attn_ppi, attn_pdb = None, None for df_sub, path in [ [df[df['10A Interface']], ppi_path], [df[~df['10A Interface']], pdb_path] ]: if not len(df_sub): continue # Predict try: ddg, attn = predict_ddg(models, path, df_sub['Mutation'].tolist(), return_attn=True) except Exception as e: print(f"Prediction failed. {str(e)}") raise gr.Error(f"Prediction failed. 
{str(e)}") ddg = ddg.detach().numpy().tolist() logging.info(f'Predictions made for {path}') # Update dataframe and attention tensor idx = df_sub.index df.loc[idx, 'ddG [kcal/mol]'] = ddg df.loc[idx, 'Attn Id'] = np.arange(len(idx)) if path == ppi_path: attn_ppi = attn else: attn_pdb = attn df['Attn Id'] = df['Attn Id'].astype(int) # Round ddG values df['ddG [kcal/mol]'] = df['ddG [kcal/mol]'].round(3) # Create PPI-specific dropdown dropdown = gr.Dropdown( df['Mutation'].tolist(), value=df['Mutation'].iloc[0], interactive=True, visible=True, label="Mutation to visualize", ) # Predefine plot arguments for all dropdown choices dropdown_choices_to_plot_args = { mut: ( pdb_path, ppi_path if df[df['Mutation'] == mut]['10A Interface'].iloc[0] else pdb_path, mut, attn_ppi if df[df['Mutation'] == mut]['10A Interface'].iloc[0] else attn_pdb, df[df['Mutation'] == mut]['Attn Id'].iloc[0] ) for mut in df['Mutation'] } # Create dataframe file path = 'ppiformer_ddg_predictions.csv' if n_muts_not_on_interface: df = df[['Mutation', 'ddG [kcal/mol]', '10A Interface']] df.to_csv(path, index=False) df = gr.Dataframe( value=df, headers=['Mutation', 'ddG [kcal/mol]', '10A Interface'], datatype=['str', 'number', 'bool'], col_count=(3, 'fixed'), ) else: df = df[['Mutation', 'ddG [kcal/mol]']] df.to_csv(path, index=False) df = gr.Dataframe( value=df, headers=['Mutation', 'ddG [kcal/mol]'], datatype=['str', 'number'], col_count=(2, 'fixed'), ) logging.info('Prediction results prepared') return df, path, dropdown, dropdown_choices_to_plot_args def update_plot(dropdown, dropdown_choices_to_plot_args): return plot_3dmol(*dropdown_choices_to_plot_args[dropdown]) app = gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="pink")) with app: # Input GUI gr.Markdown(value=""" # PPIformer Web ### Computational Design of Protein-Protein Interactions """) gr.Image("assets/readme-dimer-close-up.png") gr.Markdown(value=""" 
[PPIformer](https://github.com/anton-bushuiev/PPIformer/tree/main) is a state-of-the-art predictor of the effects of mutations on protein-protein interactions (PPIs), as quantified by changes in binding free energy (ddG). PPIformer was shown to successfully identify known favourable mutations of the [staphylokinase thrombolytics](https://pubmed.ncbi.nlm.nih.gov/10942387/) and a [human antibody](https://www.pnas.org/doi/10.1073/pnas.2122954119) against the SARS-CoV-2 spike protein. The model was pre-trained on the [PPIRef](https://github.com/anton-bushuiev/PPIRef) dataset via coarse-grained structural masked modeling and fine-tuned on the [SKEMPI v2.0](https://life.bsc.es/pid/skempi2) dataset via log odds. Please see [our ICLR 2024 paper](https://arxiv.org/abs/2310.18515) for more details.

**Inputs.** To use PPIformer on your data, please specify the PPI structure (PDB code or .pdb file), the interacting proteins of interest (chain codes in the file) and the mutations (a semicolon-separated list or a file with mutations in the [standard format](https://foldxsuite.crg.eu/parameter/mutant-file): wild-type residue, chain, residue number, mutant residue). For inspiration, you can use one of the examples below: click on one of the rows to pre-fill the inputs. After specifying the inputs, press the button to predict the effects of mutations on the PPI. Currently the model runs on CPU, so the predictions may take a few minutes.

**Outputs.** After making a prediction with the model, you will see binding free energy changes for each mutation (ddG values in kcal/mol). A more negative value indicates an improvement in affinity, whereas a more positive value means a reduction in affinity. Below you will also see a 3D visualization of the PPI with the wild-type residues at the mutated positions highlighted in red. The visualization additionally shows the model's attention coefficients for the nearest neighboring residues, which quantify the contribution of these residues to the predicted ddG value.
The brighter and thicker a residue is, the more attention the model paid to it. """) with gr.Row(equal_height=True): with gr.Column(): gr.Markdown("## PPI structure") with gr.Row(equal_height=True): pdb_code = gr.Textbox(placeholder="1BUI", label="PDB code", info="Protein Data Bank identifier for the structure (https://www.rcsb.org/)") partners = gr.Textbox(placeholder="A,B,C", label="Partners", info="Protein chain identifiers in the PDB file forming the PPI interface (two or more)") pdb_path = gr.File(file_count="single", label="Or .pdb file instead of PDB code (your structure will only be used for this prediction and not stored anywhere)") with gr.Column(): gr.Markdown("## Mutations") muts = gr.Textbox(placeholder="SC16A;FC47A;SC16A,FC47A", label="List of (multi-point) mutations", info="SC16A;FC47A;SC16A,FC47A for three mutations: serine to alanine at position 16 in chain C, phenylalanine to alanine at position 47 in chain C, and their double-point combination") muts_path = gr.File(file_count="single", label="Or file with mutations") examples = gr.Examples( examples=[ ["1BUI", "A,B,C", "SC16A,FC47A;SC16A;FC47A"], ["3QIB", "A,B,P,C,D", "YP7F,TP12S;YP7F;TP12S"], ["1KNE", "A,P", ';'.join([f"TP6{a}" for a in AMINO_ACID_CODES_1])] ], inputs=[pdb_code, partners, muts], label="Examples (click on a line to pre-fill the inputs)", cache_examples=False ) # Predict GUI predict_button = gr.Button(value="Predict effects of mutations on PPI", variant="primary") # Output GUI gr.Markdown("## Predictions") df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True) df = gr.Dataframe( headers=["Mutation", "ddG [kcal/mol]"], datatype=["str", "number"], col_count=(2, "fixed"), ) dropdown = gr.Dropdown(interactive=True, visible=False) dropdown_choices_to_plot_args = gr.State([]) plot = gr.HTML() # Bottom info box gr.Markdown(value="""