# blackhole_models_ppi / inference_app.py
# Last commit: 085cf46 (verified) "Update inference_app.py" by OleinikovasV
# File size: 3.36 kB
import time
import json
import gradio as gr
from gradio_molecule3d import Molecule3D
import numpy as np
from biotite.structure.io.pdb import PDBFile
def set_all_to_zero(input_pdb_file_1, input_pdb_file_2, output_file):
    """Write both input structures into one PDB file with every coordinate zeroed.

    This is a placeholder "prediction": no modelling happens. The atoms of the
    two inputs are kept as read, and only their positions are replaced by zeros.

    Args:
        input_pdb_file_1: First protein's PDB file (anything ``PDBFile.read``
            accepts, e.g. a path).
        input_pdb_file_2: Second protein's PDB file.
        output_file: Destination for the combined PDB file.
    """
    zeroed = []
    for pdb_source in (input_pdb_file_1, input_pdb_file_2):
        # NOTE(review): get_structure() is called without a model argument,
        # presumably returning every model; the concatenation below then
        # assumes both inputs have compatible model counts -- confirm.
        atoms = PDBFile.read(pdb_source).get_structure()
        # Same shape and dtype as before, but every position is the origin.
        atoms.coord = np.zeros_like(atoms.coord)
        zeroed.append(atoms)

    # Atoms of the first input come first; chain IDs are left untouched.
    first, second = zeroed
    combined = first + second

    writer = PDBFile()
    writer.set_structure(combined)
    writer.write(output_file)
def predict(input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2):
    """Run (placeholder) complex-structure inference for two proteins.

    This template does no real modelling. It calls ``set_all_to_zero`` to
    write both input structures, with all coordinates set to zero, into a
    single PDB file. Replace that call with real inference.

    Args:
        input_seq_1: Sequence text for protein 1 (unused by the template).
        input_msa_1: Uploaded A3M MSA for protein 1 (unused by the template).
        input_protein_1: Uploaded PDB structure for protein 1 (required).
        input_seq_2: Sequence text for protein 2 (unused by the template).
        input_msa_2: Uploaded A3M MSA for protein 2 (unused by the template).
        input_protein_2: Uploaded PDB structure for protein 2 (required).

    Returns:
        A tuple ``(output_file, metrics_json, run_time)``: the path of the
        written PDB file, the metrics dict serialized as a JSON string, and
        the elapsed time of the call in seconds.

    Raises:
        gr.Error: If either input structure was not provided, so the Gradio
            UI can show a readable message.
    """
    # perf_counter is monotonic and high-resolution, so the measured duration
    # is not skewed by system clock adjustments the way time.time() can be.
    start_time = time.perf_counter()

    # Gradio passes None for an empty file input; without this check the
    # failure would surface as an opaque error from the PDB parser.
    if not input_protein_1 or not input_protein_2:
        raise gr.Error("Both input protein structures (PDB) are required.")

    # Do inference here.
    # Return an output PDB file containing the complex as two chains, A and B.
    output_file = "test_out.pdb"
    set_all_to_zero(input_protein_1, input_protein_2, output_file)

    # Also return a JSON with any metrics you want to report.
    metrics = {"F_nat": 100}

    run_time = time.perf_counter() - start_time
    return output_file, json.dumps(metrics), run_time
# 3D viewer representations for the predicted complex: a cartoon for each
# chain followed by side-chain sticks; chain A uses white carbons and chain B
# green carbons.
reps = [
    {"model": 0, "style": "cartoon", "chain": chain, "color": color}
    for chain, color in (("A", "whiteCarbon"), ("B", "greenCarbon"))
] + [
    {"model": 0, "chain": chain, "style": "stick", "sidechain": True, "color": color}
    for chain, color in (("A", "whiteCarbon"), ("B", "greenCarbon"))
]

with gr.Blocks() as app:
    gr.Markdown("# Template for inference")
    gr.Markdown("Title, description, and other information about the model")

    # One column per interaction partner, shown side by side.
    with gr.Row():
        with gr.Column():
            input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
            input_msa_1 = gr.File(label="Input MSA Protein 1 (A3M)")
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
        with gr.Column():
            input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
            input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
            input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")

    # Same order as predict()'s positional parameters; shared by the examples
    # table and the button handler so the two cannot drift apart.
    inference_inputs = [
        input_seq_1, input_msa_1, input_protein_1,
        input_seq_2, input_msa_2, input_protein_2,
    ]

    # Model options go here; automated inference always uses their defaults.
    # slider_option = gr.Slider(0, 10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Dropdown Option")

    btn = gr.Button("Run Inference")

    # A single example row: no sequences or MSAs, just the two structure files.
    gr.Examples(
        [["", "", "3v1c_A.pdb", "", "", "3v1c_B.pdb"]],
        inference_inputs,
    )

    # Outputs, in the order predict() returns its values.
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    btn.click(predict, inputs=inference_inputs, outputs=[out, metrics, run_time])

app.launch()