File size: 22,913 Bytes
29bd8b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
# print("""
#  __  __    _    ___ _   _ _____ _____ _   _    _    _   _  ____ _____ 
# |  \/  |  / \  |_ _| \ | |_   _| ____| \ | |  / \  | \ | |/ ___| ____|
# | |\/| | / _ \  | ||  \| | | | |  _| |  \| | / _ \ |  \| | |   |  _|  
# | |  | |/ ___ \ | || |\  | | | | |___| |\  |/ ___ \| |\  | |___| |___ 
# |_|  |_/_/   \_\___|_| \_| |_| |_____|_| \_/_/   \_\_| \_|\____|_____|
                                                                      
#                  ____  ____  _____    _    _  __
#                 | __ )|  _ \| ____|  / \  | |/ /
#                 |  _ \| |_) |  _|   / _ \ | ' / 
#                 | |_) |  _ <| |___ / ___ \| . \ 
#                 |____/|_| \_\_____/_/   \_\_|\_\
# """)
import os
# os.system("pip uninstall -y gradio")
# os.system("pip install gradio==3.50.2")
# os.system("pip uninstall -y spaces")
# os.system("pip install spaces==0.8")
# NOTE(review): force-reinstalls a pinned torch (2.0.1) via pip at import time, before
# `torch` is imported further below. The exit statuses returned by os.system are ignored,
# so a failed uninstall/install is not detected here.
os.system("pip uninstall -y torch")
os.system("pip install torch==2.0.1")

import sys
import copy
import random
import tempfile
import shutil
import logging
from pathlib import Path
from functools import partial

import spaces
import gradio as gr
import torch
import numpy as np
import pandas as pd
from Bio.PDB.Polypeptide import protein_letters_3to1
from biopandas.pdb import PandasPdb
from colour import Color
from colour import RGB_TO_COLOR_NAMES

from mutils.proteins import AMINO_ACID_CODES_1
from mutils.pdb import download_pdb
from mutils.mutations import Mutation
from ppiref.extraction import PPIExtractor
from ppiref.utils.ppi import PPIPath
from ppiref.utils.residue import Residue
from ppiformer.tasks.node import DDGPPIformer
from ppiformer.utils.api import download_from_zenodo
from ppiformer.utils.api import predict_ddg as predict_ddg_
from ppiformer.utils.torch import fill_diagonal
from ppiformer.definitions import PPIFORMER_WEIGHTS_DIR


import pkg_resources
import sys

def print_package_versions():
    """Print installed distributions (``name==version``, sorted) and the Python version.

    Uses the standard-library :mod:`importlib.metadata` instead of the deprecated
    ``pkg_resources`` API. Distribution names are lower-cased, and a distribution that
    is visible on several ``sys.path`` entries is listed only once.
    """
    from importlib import metadata

    # (name, version) for every distribution importlib can see; entries without a
    # Name field are skipped below.
    names_and_versions = (
        (dist.metadata.get("Name"), dist.version) for dist in metadata.distributions()
    )
    installed_packages = sorted(
        {f"{name.lower()}=={version}" for name, version in names_and_versions if name}
    )

    print("Installed packages and their versions:")
    for package in installed_packages:
        print(package)

    print("\nPython version:")
    print(sys.version)

print_package_versions()


# Route INFO-and-above log records (timestamp, level, message) to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)

# Seed Python's global RNG; random.shuffle in plot_3dmol (chain colour order) draws from it.
random.seed(0)


@spaces.GPU
def predict_ddg(models, ppi, muts, return_attn):
    """Run the PPIformer models on ``muts`` for structure ``ppi`` and return CPU tensors.

    Models are moved to CUDA when available (otherwise CPU) before calling the
    ``ppiformer`` prediction API.

    Returns:
        The detached ddG prediction tensor on CPU, or a ``(ddg, attention)`` pair of
        detached CPU tensors when ``return_attn`` is truthy.
    """
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"[INFO] Device on prediction: {target_device}")
    placed_models = [m.to(target_device) for m in models]

    outcome = predict_ddg_(placed_models, ppi, muts, return_attn=return_attn)
    if not return_attn:
        return outcome.detach().cpu()

    ddg_pred, attns = outcome
    return ddg_pred.detach().cpu(), attns.detach().cpu()


def process_inputs(inputs, temp_dir):
    """Validate the raw UI inputs and materialize them as files inside ``temp_dir``.

    Args:
        inputs: Tuple ``(pdb_code, pdb_path, partners, muts, muts_path)`` in the order
            of the Gradio input components. An uploaded file takes precedence over the
            PDB code, and a mutations file takes precedence over the text box.
        temp_dir: ``Path`` under which the structure (``pdb/``) and the extracted
            interface (``ppi/``) are written.

    Returns:
        ``(pdb_path, ppi_path, muts, muts_on_interface)``: path of the full structure
        (renamed to include the partners), path of the extracted PPI, the mutations as
        strings, and for each mutation whether its wild type is found in the extracted
        PPI (``True``) or only in the full structure (``False``).

    Raises:
        gr.Error: on missing/invalid inputs, failed download, failed PPI extraction,
            unparsable mutations, mutations on non-partner chains, or wild types that
            are not present in the structure.
    """
    pdb_code, pdb_path, partners, muts, muts_path = inputs

    # Check inputs
    if not pdb_code and not pdb_path:
        raise gr.Error("PPI structure not specified.")

    if pdb_code and pdb_path:
        gr.Warning("Both PDB code and PDB file specified. Using PDB file.")

    # Parse partners, ignoring empty entries (e.g. from a trailing comma in "A,B,")
    partners = [p.strip() for p in (partners or '').split(',') if p.strip()]
    if not partners:
        raise gr.Error("Partners not specified.")

    if not muts and not muts_path:
        raise gr.Error("Mutations not specified.")

    if muts and muts_path:
        gr.Warning("Both mutations and mutations file specified. Using mutations file.")

    # Prepare PDB input
    if pdb_path:
        # Uploads may arrive as a plain path string or as an object exposing the path
        # via ``.name``. Only the file name is kept, converted to PPIRef format ('_' -> '-').
        upload_path = Path(getattr(pdb_path, 'name', pdb_path))
        pdb_path = temp_dir / 'pdb' / upload_path.name.replace('_', '-')
        pdb_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(upload_path, pdb_path)
    else:
        pdb_code = pdb_code.strip().lower()
        pdb_path = temp_dir / 'pdb' / f'{pdb_code}.pdb'
        try:
            download_pdb(pdb_code, path=pdb_path)
        except Exception as e:
            raise gr.Error("PDB download failed.") from e

    # Add partners to file name; Path.replace overwrites a file left by an earlier request
    pdb_path = pdb_path.replace(pdb_path.with_stem(f"{pdb_path.stem}-{'-'.join(partners)}"))

    # Extract PPI into temp dir
    ppi_dir = temp_dir / 'ppi'
    try:
        extractor = PPIExtractor(out_dir=ppi_dir, nest_out_dir=True, join=True, radius=10.0)
        extractor.extract(pdb_path, partners=partners)
        ppi_path = PPIPath.construct(ppi_dir, pdb_path.stem, partners)
    except Exception as e:
        raise gr.Error("PPI extraction failed.") from e

    # Prepare mutations input
    if muts_path:
        muts = Path(getattr(muts_path, 'name', muts_path)).read_text()

    # Basic format: ';'-separated mutations. Each item is stripped so that newlines
    # after ';' (e.g. one mutation per line in a file) do not reach the parser.
    try:
        muts = [Mutation.from_str(m.strip()) for m in muts.split(';') if m.strip()]
    except Exception as e:
        raise gr.Error(f'Mutations parsing failed: {e}') from e
    if not muts:
        # e.g. input consisting only of separators such as ";"
        raise gr.Error("Mutations not specified.")

    # Partners: every point mutation must lie on one of the selected chains
    for mut in muts:
        for pmut in mut.muts:
            if pmut.chain not in partners:
                raise gr.Error(f'Chain of point mutation {pmut} is not in the list of partners {partners}.')

    # Consistency with provided .pdb: prefer the extracted interface, fall back to the whole structure
    muts_on_interface = []
    for mut in muts:
        if mut.wt_in_pdb(ppi_path):
            val = True
        elif mut.wt_in_pdb(pdb_path):
            val = False
        else:
            raise gr.Error(f'Wild-type of mutation {mut} is not in the provided .pdb file.')
        muts_on_interface.append(val)

    muts = [str(m) for m in muts]

    return pdb_path, ppi_path, muts, muts_on_interface


def plot_3dmol(pdb_path, ppi_path, mut, attn, attn_mut_id=0):
    """Build an HTML ``<iframe>`` showing the complex in 3Dmol.js with attention highlights.

    The full structure is drawn as per-chain cartoons. Residues of ``ppi_path`` are drawn
    as sticks whose colour saturation, radius and opacity scale with the normalized
    attention they receive from the mutated residues. Wild types of ``mut`` are drawn in
    red, labelled, and zoomed to.

    Args:
        pdb_path: Path to the full structure loaded into the viewer.
        ppi_path: Path to the structure whose residues are drawn as sticks. Its residues
            (one CA row per residue, in groupby order) are assumed to line up with the
            attention token axis — TODO confirm.
        mut: Mutation string (possibly multi-point) whose wild-type residues are highlighted.
        attn: Attention tensor from ``predict_ddg``. It is indexed below as
            ``attn[models, mutation, 0, layers, 0, heads, tokens, tokens]`` — layout
            inferred from that indexing; confirm against PPIformer.
        attn_mut_id: Index of ``mut`` along the mutation axis of ``attn``.

    Returns:
        An ``<iframe>`` HTML string embedding the viewer page through ``srcdoc``.
    """
    # NOTE 3DMol.js adapted from https://huggingface.co/spaces/huhlim/cg2all/blob/main/app.py
    import json
    from html import escape as html_escape

    # Read PDB for 3Dmol.js, renaming OT1/OT2 atoms to O/OXT
    with open(pdb_path, "r") as fp:
        mol = fp.read()
    mol = mol.replace("OT1", "O  ").replace("OT2", "OXT")
    # The PDB text is embedded in a JS template literal below: escape backslashes,
    # backticks and "${" so file content cannot terminate or interpolate the literal.
    mol_js = mol.replace("\\", "\\\\").replace("`", "\\`").replace("${", "\\${")

    # Read PPI to customize 3Dmol.js visualization: one row (the CA atom) per residue
    ppi_df = PandasPdb().read_pdb(ppi_path).df['ATOM']
    ppi_df = ppi_df.groupby(list(Residue._fields)).apply(lambda df: df[df['atom_name'] == 'CA'].iloc[0]).reset_index(drop=True)
    ppi_df['id'] = ppi_df.apply(lambda row: ':'.join([row['residue_name'], row['chain_id'], str(row['residue_number']), row['insertion']]), axis=1)
    # Drop the trailing ':' left by an empty insertion code
    ppi_df['id'] = ppi_df['id'].apply(lambda x: x[:-1] if x[-1] == ':' else x)
    muts_id = Mutation.from_str(mut).wt_to_graphein()  # flatten ids of all sp muts
    ppi_df['mutated'] = ppi_df.apply(lambda row: row['id'] in muts_id, axis=1)

    # Prepare attention coefficients per residue (normalized sum of direct attention from mutated residues)
    attn = torch.nan_to_num(attn, nan=1e-10)
    attn_sub = attn[:, attn_mut_id, 0, :, 0, :, :, :]  # models, layers, heads, tokens, tokens
    idx_mutated = torch.from_numpy(ppi_df.index[ppi_df['mutated']].to_numpy())
    attn_sub = fill_diagonal(attn_sub, 1e-10)
    attn_mutated = attn_sub[..., idx_mutated, :]
    attns_per_token = torch.sum(attn_mutated, dim=(0, 1, 2, 3))
    # Min-max normalize to [0, 1]. When all scores are equal (e.g. no mutated residue was
    # matched, so every score is 0) the division would yield NaN, which would propagate
    # into the colour saturation, stick radius and opacity below; use zeros instead.
    attn_range = attns_per_token.max() - attns_per_token.min()
    if attn_range > 0:
        attns_per_token = (attns_per_token - attns_per_token.min()) / attn_range
    else:
        attns_per_token = torch.zeros_like(attns_per_token)
    attns_per_token += 1e-10
    ppi_df['attn'] = attns_per_token.numpy()

    # Chains ordered by their highest-attention residue; the first ones get the preferred colours
    chains = ppi_df.sort_values('attn', ascending=False)['chain_id'].unique()

    # Customize 3Dmol.js visualization https://3dmol.csb.pitt.edu/doc/
    style_specs = []  # [selection, style] argument pairs for viewer.addStyle
    zoom_atoms = []

    # Cartoon chains
    preferred_colors = ['LimeGreen', 'HotPink', 'RoyalBlue']
    all_colors = [c[0] for c in RGB_TO_COLOR_NAMES.values()]
    all_colors = [c for c in all_colors if c not in preferred_colors + ['Black', 'White']]
    random.shuffle(all_colors)
    all_colors = preferred_colors + all_colors
    all_colors = [Color(c) for c in all_colors]
    chain_to_color = dict(zip(chains, all_colors))
    for chain in chains:
        style_specs.append([{"chain": chain}, {"cartoon": {"color": chain_to_color[chain].hex_l, "opacity": 0.6}}])

    # Stick PPI and atoms for zoom
    # TODO Insertions
    for _, row in ppi_df.iterrows():
        attn_value = float(row['attn'])  # plain float so json.dumps below always accepts it
        residue_sel = {'chain': row['chain_id'], 'resi': str(row['residue_number'])}
        if row['mutated']:
            style_specs.append([residue_sel, {'stick': {'color': 'red', 'radius': 0.2, 'opacity': 1.0}}])
            zoom_atoms.append(row['atom_number'])
        else:
            color = copy.deepcopy(chain_to_color[row['chain_id']])
            color.saturation = attn_value
            style_specs.append([residue_sel, {'stick': {'color': color.hex_l, 'radius': attn_value / 5, 'opacity': attn_value}}])

    # Convert style dicts to JS lines (JSON is valid JS object-literal syntax)
    style_js = ''.join(
        'viewer.addStyle(' + ', '.join(json.dumps(spec) for spec in specs) + ');\n'
        for specs in style_specs
    )

    # Convert zoom atoms to 3DMol.js selection and add labels for mutated residues
    zoom_animation_duration = 500
    sel = '{"or": [' + ', '.join('{"serial": ' + str(a) + '}' for a in zoom_atoms) + ']}'
    zoom = 'viewer.zoomTo(' + sel + ',' + f'{zoom_animation_duration});'
    for atom in zoom_atoms:
        sel = '{"serial": ' + str(atom) + '}'
        row = ppi_df[ppi_df['atom_number'] == atom].iloc[0]
        label = protein_letters_3to1[row['residue_name']] + row['chain_id'] + str(row['residue_number']) + row['insertion']
        style_js += 'viewer.addLabel(' + f'"{label}",' + '{fontSize:16, fontColor:"red", backgroundOpacity: 0.0},' + sel + ');\n'

    # Construct 3Dmol.js visualization script embedded in HTML
    page = (
        """<!DOCTYPE html>
        <html>
        <head>    
    <meta http-equiv="content-type" content="text/html; charset=UTF-8" />
    <style>
    body{
        font-family:sans-serif
    }
    .mol-container {
    width: 100%;
    height: 600px;
    position: relative;
    }
    .mol-container select{
        background-image:None;
    }
    </style>
     <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.6.3/jquery.min.js" integrity="sha512-STof4xm1wgkfm7heWqFJVn58Hm3EtS31XFaagaa8VMReCXAkQnJZ+jEy8PCC/iT18dFy95WcExNHFTqLyp72eQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
    <script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
    </head>
    <body>  
    <div id="container" class="mol-container"></div>
  
            <script>
               let pdb = `"""
        + mol_js
        + """`  
      
             $(document).ready(function () {
                let element = $("#container");
                let config = { backgroundColor: "white" };
                let viewer = $3Dmol.createViewer(element, config);
                viewer.addModel(pdb, "pdb");
                viewer.setStyle({"model": 0}, {"ray_opaque_background": "off"}, {"stick": {"color": "lightgrey", "opacity": 0.5}});
          """
        + style_js
        + zoom
        + """
                viewer.render();
              })
        </script>
        </body></html>"""
    )

    # srcdoc is an HTML attribute value: escape the whole document so quotes coming from
    # the PDB text (e.g. the "'" in nucleotide atom names such as O5') cannot end it early.
    # The browser decodes the entities back before rendering the page.
    return f"""<iframe style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera; 
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms 
    allow-scripts allow-same-origin allow-popups 
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen="" 
    allowpaymentrequest="" frameborder="0" srcdoc='{html_escape(page, quote=True)}'></iframe>"""


def predict(models, temp_dir, *inputs):
    """Predict ddG for the UI inputs and prepare every output component.

    Mutations whose wild type lies in the extracted interface are predicted on the PPI
    file, and the others on the whole structure (with a UI warning).

    Args:
        models: Loaded PPIformer models, passed as a list to ``predict_ddg``.
        temp_dir: Working directory forwarded to ``process_inputs``.
        *inputs: Raw values of the Gradio input components in the order expected by
            ``process_inputs``.

    Returns:
        ``(dataframe, csv_path, dropdown, plot_args)``, where ``plot_args`` maps each
        mutation string to the positional arguments of ``plot_3dmol``.

    Raises:
        gr.Error: if the inputs are invalid or the prediction fails.
    """
    logging.info('Starting prediction')

    # Process input
    pdb_path, ppi_path, muts, muts_on_interface = process_inputs(inputs, temp_dir)
    if not muts:
        # The results table and dropdown below need at least one mutation.
        raise gr.Error("Mutations not specified.")

    # Create dataframe
    df = pd.DataFrame({
        'Mutation': muts,
        'ddG [kcal/mol]': len(muts) * [np.nan],
        '10A Interface': muts_on_interface,
        'Attn Id': len(muts) * [np.nan],
    })

    # Show warning if some mutations are not on the interface
    muts_not_on_interface = df[~df['10A Interface']]['Mutation'].tolist()
    n_muts_not_on_interface = len(muts_not_on_interface)
    if n_muts_not_on_interface:
        n_muts_warn = 5
        muts_not_on_interface = ';'.join(muts_not_on_interface[:n_muts_warn])
        if n_muts_not_on_interface > n_muts_warn:
            muts_not_on_interface += f'... (and {n_muts_not_on_interface - n_muts_warn} more)'
        gr.Warning((
            f"{muts_not_on_interface} {'is' if n_muts_not_on_interface == 1 else 'are'} not on the interface. "
            f"The model will predict the effect{'s' if n_muts_not_on_interface > 1 else ''} of "
            f"mutation{'s' if n_muts_not_on_interface > 1 else ''} on the whole complex. "
            f"This may lead to less accurate predictions."
        ))

    logging.info('Inputs processed')

    # Predict using interface for mutations on the interface and using the whole complex otherwise.
    # 'Attn Id' is each mutation's position within its own prediction batch; plot_3dmol
    # uses it to index the mutation axis of the matching attention tensor.
    attn_by_interface = {True: None, False: None}
    for on_interface, struct_path in ((True, ppi_path), (False, pdb_path)):
        df_sub = df[df['10A Interface'] == on_interface]
        if df_sub.empty:
            continue

        # Predict
        try:
            ddg, attn = predict_ddg(models, struct_path, df_sub['Mutation'].tolist(), return_attn=True)
        except Exception as e:
            print(f"Prediction failed. {str(e)}")
            raise gr.Error(f"Prediction failed. {str(e)}") from e

        logging.info(f'Predictions made for {struct_path}')

        # Update dataframe and attention tensor
        idx = df_sub.index
        df.loc[idx, 'ddG [kcal/mol]'] = ddg.detach().numpy().tolist()
        df.loc[idx, 'Attn Id'] = np.arange(len(idx))
        attn_by_interface[on_interface] = attn
    df['Attn Id'] = df['Attn Id'].astype(int)

    # Round ddG values
    df['ddG [kcal/mol]'] = df['ddG [kcal/mol]'].round(3)

    # Create PPI-specific dropdown
    dropdown = gr.Dropdown(
        df['Mutation'].tolist(), value=df['Mutation'].iloc[0],
        interactive=True,  visible=True, label="Mutation to visualize",
    )

    # Predefine plot arguments for all dropdown choices in a single pass
    # (for duplicated mutations the first occurrence wins)
    dropdown_choices_to_plot_args = {}
    for mut, on_interface, attn_id in zip(df['Mutation'], df['10A Interface'], df['Attn Id']):
        dropdown_choices_to_plot_args.setdefault(mut, (
            pdb_path,
            ppi_path if on_interface else pdb_path,
            mut,
            attn_by_interface[bool(on_interface)],
            attn_id,
        ))

    # Create dataframe file
    csv_path = 'ppiformer_ddg_predictions.csv'
    if n_muts_not_on_interface:
        df = df[['Mutation', 'ddG [kcal/mol]', '10A Interface']]
        df.to_csv(csv_path, index=False)
        df = gr.Dataframe(
            value=df,
            headers=['Mutation', 'ddG [kcal/mol]', '10A Interface'],
            datatype=['str', 'number', 'bool'],
            col_count=(3, 'fixed'),
        )
    else:
        df = df[['Mutation', 'ddG [kcal/mol]']]
        df.to_csv(csv_path, index=False)
        df = gr.Dataframe(
            value=df,
            headers=['Mutation', 'ddG [kcal/mol]'],
            datatype=['str', 'number'],
            col_count=(2, 'fixed'),
        )

    logging.info('Prediction results prepared')

    return df, csv_path, dropdown, dropdown_choices_to_plot_args


def update_plot(dropdown, dropdown_choices_to_plot_args):
    """Render the 3D view for the mutation currently selected in the dropdown.

    Args:
        dropdown: Selected mutation string (may be empty/None when nothing is selected).
        dropdown_choices_to_plot_args: Mapping from mutation string to the positional
            arguments of ``plot_3dmol``, as produced by ``predict``. Before any prediction
            the UI state holds its initial empty value.

    Returns:
        The viewer HTML, or an empty string when there is no valid selection.
    """
    if not dropdown or not dropdown_choices_to_plot_args or dropdown not in dropdown_choices_to_plot_args:
        return ""
    return plot_3dmol(*dropdown_choices_to_plot_args[dropdown])


# Gradio UI: description, inputs, examples, predict button, outputs, and event wiring.
# Model weights are downloaded and loaded while the Blocks context is being built.
app = gr.Blocks(theme=gr.themes.Default(primary_hue="green", secondary_hue="pink"))
with app:

    # Input GUI
    gr.Markdown(value="""
        # PPIformer Web
        ### Computational Design of Protein-Protein Interactions
    """)
    gr.Image("assets/readme-dimer-close-up.png")
    gr.Markdown(value="""
        [PPIformer](https://github.com/anton-bushuiev/PPIformer/tree/main) is a state-of-the-art predictor of the effects of mutations 
        on protein-protein interactions (PPIs), as quantified by the binding free energy changes (ddG). PPIformer was shown to successfully 
        identify known favourable mutations of the [staphylokinase thrombolytics](https://pubmed.ncbi.nlm.nih.gov/10942387/) 
        and a [human antibody](https://www.pnas.org/doi/10.1073/pnas.2122954119) against the SARS-CoV-2 spike protein. The model was pre-trained 
        on the [PPIRef](https://github.com/anton-bushuiev/PPIRef) 
        dataset via a coarse-grained structural masked modeling and fine-tuned on the [SKEMPI v2.0](https://life.bsc.es/pid/skempi2) dataset via log odds. 
        Please see more details in [our ICLR 2024 paper](https://arxiv.org/abs/2310.18515).
                
        **Inputs.** To use PPIformer on your data, please specify the PPI structure (PDB code or .pdb file), interacting proteins of interest 
        (chain codes in the file) and mutations (semicolon-separated list or file with mutations in the 
        [standard format](https://foldxsuite.crg.eu/parameter/mutant-file): wild-type residue, chain, residue number, mutant residue). 
        For inspiration, you can use one of the examples below: click on one of the rows to pre-fill the inputs. After specifying the inputs, 
        press the button to predict the effects of mutations on the PPI. Currently the model runs on CPU, so the predictions may take a few minutes.
                
        **Outputs.** After making a prediction with the model, you will see binding free energy changes for each mutation (ddG values in kcal/mol). 
        A more negative value indicates an improvement in affinity, whereas a more positive value means a reduction in affinity. 
        Below you will also see a 3D visualization of the PPI with wild types of mutated residues highlighted in red. The visualization additionally shows
        the attention coefficients of the model for the nearest neighboring residues, which quantifies the contribution of the residues 
        to the predicted ddG value. The brighter and thicker a residue is, the more attention the model paid to it.
    """)

    # Structure inputs (left column) and mutation inputs (right column)
    with gr.Row(equal_height=True):
        with gr.Column():
            gr.Markdown("## PPI structure")
            with gr.Row(equal_height=True):
                pdb_code = gr.Textbox(placeholder="1BUI", label="PDB code", info="Protein Data Bank identifier for the structure (https://www.rcsb.org/)")
                partners = gr.Textbox(placeholder="A,B,C", label="Partners", info="Protein chain identifiers in the PDB file forming the PPI interface (two or more)")
            pdb_path = gr.File(file_count="single", label="Or .pdb file instead of PDB code (your structure will only be used for this prediction and not stored anywhere)")

        with gr.Column():
            gr.Markdown("## Mutations")
            muts = gr.Textbox(placeholder="SC16A;FC47A;SC16A,FC47A", label="List of (multi-point) mutations", info="SC16A;FC47A;SC16A,FC47A for three mutations: serine to alanine at position 16 in chain C, phenylalanine to alanine at position 47 in chain C, and their double-point combination")
            muts_path = gr.File(file_count="single", label="Or file with mutations")

    # Clickable examples that pre-fill PDB code, partners and mutations (not cached)
    examples = gr.Examples(
        examples=[
            ["1BUI", "A,B,C", "SC16A,FC47A;SC16A;FC47A"],
            ["3QIB", "A,B,P,C,D", "YP7F,TP12S;YP7F;TP12S"],
            ["1KNE", "A,P", ';'.join([f"TP6{a}" for a in AMINO_ACID_CODES_1])]
        ],
        inputs=[pdb_code, partners, muts],
        label="Examples (click on a line to pre-fill the inputs)",
        cache_examples=False
    )

    # Predict GUI
    predict_button = gr.Button(value="Predict effects of mutations on PPI", variant="primary")

    # Output GUI
    gr.Markdown("## Predictions")
    df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)
    df = gr.Dataframe(
        headers=["Mutation", "ddG [kcal/mol]"],
        datatype=["str", "number"],
        col_count=(2, "fixed"),
    )
    # Hidden until `predict` returns a populated, visible dropdown
    dropdown = gr.Dropdown(interactive=True, visible=False)
    # NOTE(review): initial value is a list, while `predict` stores a dict
    # (mutation -> plot_3dmol arguments) in this state.
    dropdown_choices_to_plot_args = gr.State([])
    plot = gr.HTML()

    # Bottom info box
    gr.Markdown(value="""
        <br/>
        
        ## About this web
                
        **Use cases**. The predictor can be used in: (i) Drug Discovery for the development of novel drugs and vaccines for various diseases such as cancer, 
        neurodegenerative disorders, and infectious diseases, (ii) Biotechnological Applications to develop new biocatalysts for biofuels, 
        industrial chemicals, and pharmaceuticals (iii) Therapeutic Protein Design to develop therapeutic proteins with enhanced stability, 
        specificity, and efficacy, and (iv) Mechanistic Studies to gain insights into fundamental biological processes, such as signal transduction, 
        gene regulation, and immune response.
                
        **Acknowledgement**. Please, use the following citation to acknowledge the use of our service. The web server is provided free of charge for non-commercial use.
        > Bushuiev, Anton, Roman Bushuiev, Petr Kouba, Anatolii Filkin, Marketa Gabrielova, Michal Gabriel, Jiri Sedlar, Tomas Pluskal, Jiri Damborsky, Stanislav Mazurenko, Josef Sivic. 
        > "Learning to design protein-protein interactions with enhanced generalization". The Twelfth International Conference on Learning Representations (ICLR 2024). 
        > [https://arxiv.org/abs/2310.18515](https://arxiv.org/abs/2310.18515).
                
        **Contact**. Please share your feedback or report any bugs through [GitHub Issues](https://github.com/anton-bushuiev/PPIformer/issues/new), or feel free to contact us directly at [anton.bushuiev@cvut.cz](mailto:anton.bushuiev@cvut.cz).
    """)
    gr.Image("assets/logos.png")

    # Download weights from Zenodo (executed once, when this script builds the app)
    download_from_zenodo('weights.zip')

    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"[INFO] Device on start: {device}")

    # Load models: checkpoints ddg_regression/0..2.ckpt, loaded on CPU in eval mode,
    # then moved to `device`; the whole list is passed to every prediction.
    models = [
        DDGPPIformer.load_from_checkpoint(
            PPIFORMER_WEIGHTS_DIR / f'ddg_regression/{i}.ckpt',
            map_location=torch.device('cpu')
        ).eval()
        for i in range(3)
    ]
    models = [model.to(device) for model in models]

    # Create temporary directory for storing downloaded PDBs and extracted PPIs.
    # `temp_dir_obj` stays referenced at module level so the directory is not cleaned up
    # while the app is running.
    temp_dir_obj = tempfile.TemporaryDirectory()
    temp_dir = Path(temp_dir_obj.name)

    # Main logic
    inputs = [pdb_code, pdb_path, partners, muts, muts_path]
    outputs = [df, df_file, dropdown, dropdown_choices_to_plot_args]
    # NOTE: rebinds the module-level name `predict` to the partial with models/temp_dir bound
    predict = partial(predict, models, temp_dir)
    predict_button.click(predict, inputs=inputs, outputs=outputs)

    # Update plot on dropdown change
    dropdown.change(update_plot, inputs=[dropdown, dropdown_choices_to_plot_args], outputs=[plot])

# Serve the app; ./assets is allowed so the images referenced above can be served.
app.launch(allowed_paths=['./assets'])