File size: 3,312 Bytes
d0328af 5e5f5c4 1f0c7b9 d0328af 1f0c7b9 d0328af 1f0c7b9 e81c73d 1f0c7b9 e81c73d 1f0c7b9 e81c73d 1f0c7b9 3b77f6e 1f0c7b9 3b77f6e 1f0c7b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import gradio as gr
# from PIL import Image
# import base64
# import io
import glob
import numpy as np
from gradio_molecule3d import Molecule3D
import os
def load(value: str):
    """Resolve a test-set entry to its protein PDB path and reference ligand SDF path.

    Args:
        value: Path of a ``.pdb`` file relative to ``./test_set/`` (an entry of
            the pocket dropdown).

    Returns:
        ``[pdb_path, sdf_path]``: the PDB path under ``./test_set/`` and the single
        ``.sdf`` file next to it whose name starts with the PDB's stem.

    Raises:
        ValueError: if not exactly one ``<stem>*.sdf`` file exists.
    """
    full_pdb_path = './test_set/' + value
    # Remove the ".pdb" suffix exactly. ``str.rstrip('.pdb')`` strips any trailing
    # run of the characters '.', 'p', 'd', 'b' (e.g. "2bdp.pdb" -> "2"), which
    # widened the glob below and could pick up an unrelated ligand.
    if full_pdb_path.endswith('.pdb'):
        stem = full_pdb_path[:-len('.pdb')]
    else:
        stem = full_pdb_path
    # Escape the stem so '[', '*' or '?' in a file name are matched literally.
    pattern = glob.escape(stem) + '*.sdf'
    sdf_paths = glob.glob(pattern)
    # Explicit check instead of ``assert`` (asserts are stripped under ``python -O``).
    if len(sdf_paths) != 1:
        raise ValueError(
            f"expected exactly one reference ligand matching {pattern!r}, "
            f"found {len(sdf_paths)}"
        )
    return [full_pdb_path, sdf_paths[0]]
# Entries for the pocket dropdown: protein PDBs one directory below ./test_set/,
# listed relative to ./test_set/ and sorted by path. Paths mentioning 'ligand',
# 'complex' or '_tmp.pdb' are skipped, as are relative paths containing 'pdb/'
# (i.e. a pocket directory whose name ends in "pdb").
pdb_files = sorted(
    candidate
    for candidate in glob.glob("./test_set/*/*.pdb")
    if 'ligand' not in candidate
    and 'complex' not in candidate
    and './test_set/' in candidate
    and '_tmp.pdb' not in candidate
)
pdb_files = [
    relative
    for relative in (candidate[len('./test_set/'):] for candidate in pdb_files)
    if 'pdb/' not in relative
]
# Molecule3D representations, one per model index: model 0 drawn as a white
# cartoon, model 1 as red sticks over all residues. Presumably the model index
# follows the order of files the handlers return ([pdb, sdf]) -- confirm against
# the gradio_molecule3d docs. (A third, green-stick model entry was drafted and
# left disabled.)
reps = [
    dict(model=0, style="cartoon", color="whiteCarbon"),
    dict(model=1, style="stick", color="redCarbon", residue_range=""),
]
from sample_for_pocket import call, OUT_DIR, Metrics
from rdkit import Chem
import json
def generate(value: str):
    """Run the sampler on the selected pocket and refresh the candidates dropdown.

    Args:
        value: Test-set entry chosen in the pocket dropdown (passed to ``load``).

    Returns:
        A ``gr.update`` for the candidates ``Dropdown``. Its choices become the
        ``.sdf`` files found in ``OUT_DIR``, sorted by name. Its value becomes
        the first of them, or ``None`` when no file was found.
    """
    protein_path, ligand_path = load(value)
    call(protein_path, ligand_path)
    # NOTE(review): this lists every .sdf in OUT_DIR; if ``call`` does not clear
    # the directory first, files from an earlier run are listed too -- confirm.
    # Sorting is lexicographic ("10.sdf" sorts before "2.sdf").
    out_fns = sorted(glob.glob(f'{OUT_DIR}/*.sdf'))
    # Gradio applies only what the handler *returns*. Previously the update was
    # built and discarded, and the raw list of paths was sent as the dropdown's
    # value rather than as its choices.
    return gr.update(choices=out_fns, value=out_fns[0] if out_fns else None)
def show(value: str, out_fn: str):
    """Pair the selected pocket's protein with a generated ligand for display.

    Args:
        value: Test-set entry chosen in the pocket dropdown (passed to ``load``).
        out_fn: Generated ``.sdf`` path chosen in the candidates dropdown.

    Returns:
        ``[protein_pdb_path, out_fn]`` for the "Generated Molecule" viewer.
    """
    # ``load`` still runs its reference-ligand lookup and check; only the
    # protein path it resolves is used here.
    resolved = load(value)
    return [resolved[0], out_fn]
class NpEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that also serializes NumPy scalars and arrays.

    NumPy integers, floats and booleans become the matching Python scalars.
    ``ndarray`` values become (nested) lists via ``tolist()``. Anything else is
    deferred to the base class, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is neither np.integer nor a Python bool, so the stock encoder
        # rejects it (it is what element-wise comparisons on arrays yield).
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
# def evaluate(value: str, out_fn: str):
# protein_path, ligand_path = load(value)
# metrics = Metrics(protein_path, ligand_path, out_fn).evaluate()
# return json.dumps(metrics, indent=4, cls=NpEncoder)
with gr.Blocks() as demo:
    gr.Markdown("# MolCRAFT: Structure-Based Drug Design in Continuous Parameter Space [ICML 2024]")

    # Pocket picker pre-selecting a random entry. np.random.choice raises
    # ValueError on an empty sequence, which stopped the app from starting when
    # ./test_set/ yielded no PDBs; fall back to no selection in that case.
    dropdown = gr.Dropdown(
        label="choose a pdb from CrossDocked test set:",
        choices=pdb_files,
        value=np.random.choice(pdb_files) if pdb_files else None,
    )
    ref_complex = Molecule3D(label="Protein Pocket & Reference Ligand", reps=reps)
    # out_ligand = Molecule3D(label='reference molecule', reps=reps)
    btn1 = gr.Button("visualize")
    # Render the selected protein together with its reference ligand.
    btn1.click(load, inputs=dropdown, outputs=ref_complex)

    btn2 = gr.Button('generate')
    # Initial choices before any generation; btn2 writes generate()'s return
    # value into this dropdown.
    # NOTE(review): hard-codes './output/' while generate() globs OUT_DIR --
    # confirm the two agree.
    OUT_FILES = [f'./output/{i}.sdf' for i in range(10)]
    candidates = gr.Dropdown(
        label="choose a generated molecule:",
        choices=OUT_FILES,
        value=OUT_FILES[0],
        interactive=True,
    )
    btn2.click(generate, inputs=[dropdown], outputs=[candidates])

    gen_complex = Molecule3D(label='Generated Molecule', reps=reps)
    btn3 = gr.Button('visualize')
    # Render the selected protein together with the chosen generated ligand.
    btn3.click(show, inputs=[dropdown, candidates], outputs=[gen_complex])
    # metrics = gr.Textbox(label='metrics')
    # btn4 = gr.Button('evaluate')
    # btn4.click(evaluate, inputs=[dropdown, candidates], outputs=[metrics])

if __name__ == '__main__':
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)
|