|
import gradio as gr |
|
|
|
|
|
|
|
import glob |
|
import numpy as np |
|
from gradio_molecule3d import Molecule3D |
|
|
|
import os |
|
|
|
|
|
def load(value: str):
    """Resolve a test-set entry to its protein PDB and reference-ligand SDF.

    Args:
        value: Path of a ``.pdb`` file relative to ``./test_set/`` (the form
            stored in the dropdown choices), e.g. ``"<dir>/<name>.pdb"``.

    Returns:
        ``[pdb_path, sdf_path]`` where ``pdb_path`` is ``./test_set/<value>``
        and ``sdf_path`` is the single ``.sdf`` file whose path starts with
        the PDB path minus its ``.pdb`` suffix.

    Raises:
        ValueError: if zero or more than one matching ``.sdf`` file exists.
    """
    full_pdb_path = './test_set/' + value
    # Remove the literal ".pdb" suffix. str.rstrip('.pdb') would instead strip
    # any trailing run of '.', 'p', 'd', 'b' characters ("1abd.pdb" -> "1a"),
    # widening the glob below to unrelated files.
    if full_pdb_path.endswith('.pdb'):
        stem = full_pdb_path[:-len('.pdb')]
    else:
        stem = full_pdb_path
    # Escape the stem so any glob metacharacters in file names match literally.
    sdf_paths = glob.glob(glob.escape(stem) + '*.sdf')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(sdf_paths) != 1:
        raise ValueError(
            f'expected exactly one reference SDF matching {stem!r}*.sdf, '
            f'found {len(sdf_paths)}'
        )
    return [full_pdb_path, sdf_paths[0]]
|
|
|
def _list_test_set_pdbs():
    """Return the sorted receptor PDB entries, relative to ``./test_set/``.

    Only files exactly one directory below ``./test_set/`` are considered.
    A path is dropped when it contains ``'ligand'`` or ``'complex'``
    anywhere in the full path, lacks the ``./test_set/`` prefix, or contains
    ``'_tmp.pdb'``. It is also dropped when its relative form contains
    ``'pdb/'``.
    """
    prefix = './test_set/'
    kept = []
    for path in sorted(glob.glob("./test_set/*/*.pdb")):
        if 'ligand' in path or 'complex' in path:
            continue
        if prefix not in path or '_tmp.pdb' in path:
            continue
        relative = path[len(prefix):]
        if 'pdb/' in relative:
            continue
        kept.append(relative)
    return kept


pdb_files = _list_test_set_pdbs()
|
|
|
|
|
# Molecule3D display representations shared by both viewers.
# Model 0 is drawn as a white-carbon cartoon and model 1 as red-carbon sticks.
# This presumably maps to the first file (protein PDB) and second file (ligand
# SDF) in the lists returned by load()/show(). It assumes Molecule3D indexes
# models in file order — confirm against the gradio_molecule3d docs.
reps = [
    dict(model=0, style="cartoon", color="whiteCarbon"),
    dict(model=1, style="stick", color="redCarbon", residue_range=""),
]
|
|
|
from sample_for_pocket import call, OUT_DIR, Metrics |
|
from rdkit import Chem |
|
import json |
|
|
|
|
|
def generate(value: str):
    """Run the sampler for the selected pocket and refresh the candidate list.

    Args:
        value: Test-set entry relative to ``./test_set/`` (see ``load``).

    Returns:
        A ``gr.update`` for the candidate dropdown. Its choices are every
        ``.sdf`` file in ``OUT_DIR``, sorted by name. The first one is
        selected, or the selection is cleared (``None``) when there are none.
    """
    protein_path, ligand_path = load(value)
    # `call` is imported from sample_for_pocket; it presumably writes the
    # generated molecules into OUT_DIR — not visible from this file.
    call(protein_path, ligand_path)

    out_fns = sorted(glob.glob(f'{OUT_DIR}/*.sdf'))

    # The update must be returned to take effect. Building it and discarding
    # it was a no-op, and returning the bare list handed a list to a
    # single-select Dropdown. Guard against an empty result instead of
    # raising IndexError on out_fns[0].
    return gr.update(choices=out_fns, value=out_fns[0] if out_fns else None)
|
|
|
|
|
def show(value: str, out_fn: str):
    """Pair the selected pocket's protein PDB with a generated ligand file.

    ``load(value)`` is still called in full, so any validation it performs
    (the reference-SDF lookup) still applies. Only the protein path from its
    result is used here.

    Returns:
        ``[protein_pdb_path, out_fn]`` for the generated-molecule viewer.
    """
    pdb_path = load(value)[0]
    return [pdb_path, out_fn]
|
|
|
|
|
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types.

    Usage: ``json.dumps(obj, cls=NpEncoder)``.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Conversions:
            - NumPy integer scalars become ``int``.
            - NumPy floating scalars become ``float``.
            - NumPy boolean scalars become ``bool``.
            - Arrays become nested lists via ``ndarray.tolist()``.

        Anything else is delegated to ``json.JSONEncoder.default``, which
        raises ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ does not subclass Python bool, so json cannot encode it
        # natively; without this branch it would raise TypeError.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# MolCRAFT: Structure-Based Drug Design in Continuous Parameter Space [ICML 2024]")

    # Pre-select a random test-set entry. np.random.choice raises ValueError
    # on an empty list, so fall back to no selection when ./test_set yields
    # no entries instead of crashing at startup.
    dropdown = gr.Dropdown(
        label="choose a pdb from CrossDocked test set:",
        choices=pdb_files,
        value=np.random.choice(pdb_files) if pdb_files else None,
    )

    # Viewer for the pocket together with its reference ligand.
    ref_complex = Molecule3D(label="Protein Pocket & Reference Ligand", reps=reps)

    # load() returns [pdb_path, sdf_path] for the reference viewer.
    btn1 = gr.Button("visualize")
    btn1.click(load, inputs=dropdown, outputs=ref_complex)

    btn2 = gr.Button('generate')

    # Placeholder choices until `generate` refreshes them. NOTE(review): this
    # hard-codes ./output/ while generate() globs OUT_DIR — presumably the
    # same directory; confirm.
    OUT_FILES = [f'./output/{i}.sdf' for i in range(10)]
    candidates = gr.Dropdown(label="choose a generated molecule:", choices=OUT_FILES, value=OUT_FILES[0], interactive=True)
    btn2.click(generate, inputs=[dropdown], outputs=[candidates])

    # Viewer for the pocket together with the chosen generated molecule.
    gen_complex = Molecule3D(label='Generated Molecule', reps=reps)
    btn3 = gr.Button('visualize')
    btn3.click(show, inputs=[dropdown, candidates], outputs=[gen_complex])
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Start the Gradio app when run as a script. share=True requests a share
    # link (see the Gradio `Blocks.launch` docs).
    launch_options = dict(share=True)
    demo.launch(**launch_options)
|
|