# MolCRAFT / app.py
# Last change by Atomu2014: "remove PIL py3dmol" (commit 5e5f5c4)
import gradio as gr
# from PIL import Image
# import base64
# import io
import glob
import numpy as np
from gradio_molecule3d import Molecule3D
import os
def load(value: str, root: str = './test_set/'):
    """Resolve a test-set entry to its protein PDB and reference-ligand SDF paths.

    Args:
        value: Path of a ``.pdb`` file relative to ``root`` (the form listed in
            the pocket dropdown), e.g. ``"<target_dir>/<name>.pdb"``.
        root: Directory that ``value`` is relative to. Defaults to the
            ``./test_set/`` directory this demo reads from.

    Returns:
        ``[pdb_path, sdf_path]``, where ``sdf_path`` is the single ``.sdf`` file
        whose path starts with the PDB path minus its ``.pdb`` suffix.

    Raises:
        FileNotFoundError: if no matching ``.sdf`` file exists.
        ValueError: if more than one ``.sdf`` file matches, so the reference
            ligand would be ambiguous.
    """
    full_pdb_path = os.path.join(root, value)
    # Remove the ".pdb" suffix exactly. str.rstrip('.pdb') would strip any
    # trailing run of the characters '.', 'p', 'd', 'b' (e.g. "bad.pdb" -> "ba"),
    # widening the glob below so it can match unrelated ligands.
    if full_pdb_path.endswith('.pdb'):
        stem = full_pdb_path[:-len('.pdb')]
    else:
        stem = full_pdb_path
    # Escape the stem so '[', '*' or '?' in file names are matched literally.
    sdf_paths = glob.glob(glob.escape(stem) + '*.sdf')
    if not sdf_paths:
        raise FileNotFoundError(f'no reference ligand SDF found for {full_pdb_path!r}')
    if len(sdf_paths) > 1:
        raise ValueError(
            f'expected exactly one SDF matching {stem!r} + "*.sdf", '
            f'found {len(sdf_paths)}: {sorted(sdf_paths)}'
        )
    return [full_pdb_path, sdf_paths[0]]
# Pocket choices for the dropdown: every "./test_set/<dir>/<file>.pdb" except
# ligand/complex files, temporary "_tmp.pdb" files and anything under a "pdb/"
# path, stored relative to "./test_set/" and sorted.
pdb_files = sorted(
    candidate
    for candidate in glob.glob("./test_set/*/*.pdb")
    if 'ligand' not in candidate and 'complex' not in candidate
)
pdb_files = [
    candidate[len('./test_set/'):]
    for candidate in pdb_files
    if './test_set/' in candidate and '_tmp.pdb' not in candidate
]
pdb_files = [relative for relative in pdb_files if 'pdb/' not in relative]
# Display styles passed to every Molecule3D viewer. Each entry targets one
# loaded file by its "model" index; presumably index 0 is the protein and
# index 1 the ligand, matching the [protein, ligand] lists returned by
# load()/show() — confirm against gradio_molecule3d's docs.
# (A third entry — model 2, "stick", "greenCarbon" — was previously disabled.)
reps = [
    dict(model=0, style="cartoon", color="whiteCarbon"),
    dict(model=1, style="stick", color="redCarbon", residue_range=""),
]
from sample_for_pocket import call, OUT_DIR, Metrics
from rdkit import Chem
import json
def generate(value: str):
    """Sample ligands for the selected pocket and refresh the candidates dropdown.

    Resolves ``value`` with ``load``, runs ``call`` (from ``sample_for_pocket``)
    on the protein / reference-ligand pair, then lists the ``.sdf`` files in
    ``OUT_DIR``.

    Args:
        value: Test-set PDB entry selected in the pocket dropdown.

    Returns:
        A Gradio update for the candidates dropdown. Its choices are the sorted
        ``.sdf`` paths in ``OUT_DIR``, and its value is the first of them, or
        ``None`` when no file was produced.
    """
    protein_path, ligand_path = load(value)
    call(protein_path, ligand_path)
    out_fns = sorted(glob.glob(f'{OUT_DIR}/*.sdf'))
    # Gradio only applies an update that the handler *returns*. Also avoid
    # IndexError when sampling wrote no files.
    return gr.update(choices=out_fns, value=out_fns[0] if out_fns else None)
def show(value: str, out_fn: str):
    """Return the files to render for a generated candidate.

    The selected test-set entry is resolved with ``load``. Only its protein path
    is kept; it is paired with the chosen generated ligand file ``out_fn``.
    """
    pdb_path, _reference_sdf = load(value)
    files_to_render = [pdb_path, out_fn]
    return files_to_render
class NpEncoder(json.JSONEncoder):
    """JSON encoder that also serializes NumPy scalars and arrays.

    NumPy integers, floats and booleans become the matching Python builtins.
    Arrays become (nested) lists via ``ndarray.tolist()``. Any other object is
    passed to :meth:`json.JSONEncoder.default`, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is neither an np.integer nor a Python bool, so without this
        # branch boolean values (e.g. results of array comparisons) fail to
        # serialize.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
# def evaluate(value: str, out_fn: str):
# protein_path, ligand_path = load(value)
# metrics = Metrics(protein_path, ligand_path, out_fn).evaluate()
# return json.dumps(metrics, indent=4, cls=NpEncoder)
with gr.Blocks() as demo:
    gr.Markdown("# MolCRAFT: Structure-Based Drug Design in Continuous Parameter Space [ICML 2024]")

    # Pocket picker, initialised to a random entry. np.random.choice raises
    # ValueError on an empty list, so fall back to no selection when
    # ./test_set is missing or empty instead of crashing at import time.
    dropdown = gr.Dropdown(
        label="choose a pdb from CrossDocked test set:",
        choices=pdb_files,
        value=np.random.choice(pdb_files) if pdb_files else None,
    )
    ref_complex = Molecule3D(label="Protein Pocket & Reference Ligand", reps=reps)
    # out_ligand = Molecule3D(label='reference molecule', reps=reps)
    btn1 = gr.Button("visualize")
    # load() returns [protein_pdb, reference_sdf]; both are rendered in ref_complex.
    btn1.click(load, inputs=dropdown, outputs=ref_complex)

    btn2 = gr.Button('generate')
    # Initial candidate list: assumed default sample file names.
    # NOTE(review): these should agree with OUT_DIR from sample_for_pocket — confirm.
    OUT_FILES = [f'./output/{i}.sdf' for i in range(10)]
    candidates = gr.Dropdown(
        label="choose a generated molecule:",
        choices=OUT_FILES,
        value=OUT_FILES[0],
        interactive=True,
    )
    btn2.click(generate, inputs=[dropdown], outputs=[candidates])

    gen_complex = Molecule3D(label='Generated Molecule', reps=reps)
    btn3 = gr.Button('visualize')
    # show() pairs the selected pocket's protein with the chosen generated ligand.
    btn3.click(show, inputs=[dropdown, candidates], outputs=[gen_complex])
    # metrics = gr.Textbox(label='metrics')
    # btn4 = gr.Button('evaluate')
    # btn4.click(evaluate, inputs=[dropdown, candidates], outputs=[metrics])
def _launch_demo():
    """Start the Gradio server for ``demo`` with ``share=True`` (Gradio's public share-link option)."""
    demo.launch(share=True)


if __name__ == '__main__':
    _launch_demo()