File size: 3,312 Bytes
d0328af
5e5f5c4
 
 
1f0c7b9
 
 
d0328af
1f0c7b9
d0328af
1f0c7b9
 
e81c73d
1f0c7b9
 
 
 
 
e81c73d
1f0c7b9
 
e81c73d
1f0c7b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b77f6e
 
 
 
1f0c7b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b77f6e
 
 
1f0c7b9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
# from PIL import Image
# import base64
# import io
import glob
import numpy as np
from gradio_molecule3d import Molecule3D

import os


def load(value: str):
    """Resolve a test-set entry to its protein PDB path and reference-ligand SDF path.

    Args:
        value: Path of a ``.pdb`` file relative to ``./test_set/`` (the form used
            by the dropdown choices, e.g. ``"<pocket_dir>/<name>.pdb"``).

    Returns:
        ``[pdb_path, sdf_path]``: the PDB path under ``./test_set/`` and the single
        ``.sdf`` file in the same directory whose name starts with the PDB's stem.

    Raises:
        ValueError: if no matching ``.sdf`` file, or more than one, is found.
    """
    full_pdb_path = './test_set/' + value
    # Remove the ".pdb" suffix exactly. str.rstrip('.pdb') strips any run of the
    # characters '.', 'p', 'd', 'b' (e.g. "x_prep.pdb" -> "x_pre"), which made the
    # glob below over-match sibling files.
    if full_pdb_path.endswith('.pdb'):
        stem = full_pdb_path[:-len('.pdb')]
    else:
        stem = full_pdb_path
    # Escape the stem so '[', '*', '?' in file names are matched literally.
    sdf_paths = glob.glob(glob.escape(stem) + '*.sdf')
    if len(sdf_paths) != 1:
        raise ValueError(
            f'expected exactly one reference .sdf for {full_pdb_path!r}, '
            f'found {len(sdf_paths)}: {sorted(sdf_paths)}'
        )
    return [full_pdb_path, sdf_paths[0]]

def _collect_test_set_pdbs():
    """Return sorted receptor ``.pdb`` paths found in ``./test_set/*/``, relative to ``./test_set/``.

    Skips any path mentioning ``ligand`` or ``complex``, ``_tmp.pdb`` scratch
    files, and any relative path that contains ``pdb/``.
    """
    prefix = './test_set/'
    kept = []
    for path in sorted(glob.glob(prefix + '*/*.pdb')):
        if any(tag in path for tag in ('ligand', 'complex')):
            continue
        if prefix not in path or '_tmp.pdb' in path:
            continue
        relative = path[len(prefix):]
        if 'pdb/' in relative:
            continue
        kept.append(relative)
    return kept


# Dropdown choices: candidate receptor files from the local CrossDocked test set.
pdb_files = _collect_test_set_pdbs()


# Molecule3D representation specs keyed by model index. Presumably index 0 is the
# first path in the value list a callback returns (the protein) and index 1 the
# second (the ligand) — TODO confirm against gradio_molecule3d.
reps = [
    dict(model=0, style="cartoon", color="whiteCarbon"),
    dict(model=1, style="stick", color="redCarbon", residue_range=""),
    # Disabled: a third model rendered as green-carbon sticks.
    # dict(model=2, style="stick", color="greenCarbon", residue_range=""),
]

from sample_for_pocket import call, OUT_DIR, Metrics
from rdkit import Chem
import json


def generate(value: str):
    """Run sampling for the selected pocket and refresh the candidate dropdown.

    Args:
        value: Test-set PDB path relative to ``./test_set/`` (the dropdown value).

    Returns:
        A ``gr.update`` that sets the candidates dropdown's choices to the sorted
        ``.sdf`` files found in ``OUT_DIR`` and selects the first one, or ``None``
        when no file was found.
    """
    protein_path, ligand_path = load(value)
    # Presumably writes the sampled molecules as .sdf files under OUT_DIR —
    # TODO confirm in sample_for_pocket.call.
    call(protein_path, ligand_path)

    out_fns = sorted(glob.glob(f'{OUT_DIR}/*.sdf'))
    # Return the update instead of discarding it. Returning the bare list would
    # hand Gradio a list as the single-select dropdown's *value* rather than
    # replacing its choices. Guard the empty case instead of raising IndexError.
    return gr.update(choices=out_fns, value=out_fns[0] if out_fns else None)


def show(value: str, out_fn: str):
    """Pair the selected protein with a generated ligand for the 3D viewer.

    Args:
        value: Test-set PDB path relative to ``./test_set/`` (the dropdown value).
        out_fn: Path of the generated ``.sdf`` picked in the candidates dropdown.

    Returns:
        ``[protein_pdb_path, out_fn]`` for the ``Molecule3D`` output.
    """
    # Only the protein path is used, but `load` still runs its reference-ligand
    # lookup and may raise if that lookup fails.
    receptor_path = load(value)[0]
    # TODO: optionally surface the SDF's properties, e.g. via
    # Chem.SDMolSupplier(out_fn, removeHs=False)[0].GetPropsAsDict().
    return [receptor_path, out_fn]


class NpEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that also serializes NumPy scalars and arrays.

    Use as ``json.dumps(obj, cls=NpEncoder)``. NumPy integers, floats and
    booleans become the matching Python scalars; ``np.ndarray`` becomes a
    (nested) list via ``tolist()``. Any other unsupported object falls through
    to the base class, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of np.integer, so it needs its own branch;
        # without it, boolean metrics raised TypeError.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
    

# def evaluate(value: str, out_fn: str):
#     protein_path, ligand_path = load(value)
#     metrics = Metrics(protein_path, ligand_path, out_fn).evaluate()
#     return json.dumps(metrics, indent=4, cls=NpEncoder)



with gr.Blocks() as demo:
    gr.Markdown("# MolCRAFT: Structure-Based Drug Design in Continuous Parameter Space [ICML 2024]")
    # Pre-select a random pocket. Fall back to no selection when the test set is
    # missing or empty, since np.random.choice raises ValueError on an empty list
    # and would crash the app at import time.
    initial_pdb = np.random.choice(pdb_files) if pdb_files else None
    dropdown = gr.Dropdown(label="choose a pdb from CrossDocked test set:", choices=pdb_files, value=initial_pdb)
    ref_complex = Molecule3D(label="Protein Pocket & Reference Ligand", reps=reps)
    # out_ligand = Molecule3D(label='reference molecule', reps=reps)

    # Show the selected protein together with its reference ligand.
    btn1 = gr.Button("visualize")
    btn1.click(load, inputs=dropdown, outputs=ref_complex)

    # Sample ligands for the selected pocket; `generate`'s return value is
    # written to the `candidates` dropdown.
    btn2 = gr.Button('generate')
    # Placeholder choices shown before anything has been generated.
    # NOTE(review): hard-coded to ./output/ while `generate` lists files from
    # OUT_DIR — confirm the two agree.
    OUT_FILES = [f'./output/{i}.sdf' for i in range(10)]
    candidates = gr.Dropdown(label="choose a generated molecule:", choices=OUT_FILES, value=OUT_FILES[0], interactive=True)
    btn2.click(generate, inputs=[dropdown], outputs=[candidates])

    # Show the selected protein together with the chosen generated ligand.
    gen_complex = Molecule3D(label='Generated Molecule', reps=reps)
    btn3 = gr.Button('visualize')
    btn3.click(show, inputs=[dropdown, candidates], outputs=[gen_complex])

    # metrics = gr.Textbox(label='metrics')
    # btn4 = gr.Button('evaluate')
    # btn4.click(evaluate, inputs=[dropdown, candidates], outputs=[metrics])

def main():
    """Launch the demo; ``share=True`` requests a public Gradio share link."""
    demo.launch(share=True)


if __name__ == '__main__':
    main()