import time
import json 

import gradio as gr

from gradio_molecule3d import Molecule3D
from run_on_seq import run_on_sample_seqs
from env_consts import RUN_CONFIG_PATH, OUTPUT_PATH



def predict(input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2):
    """Run inference for a pair of protein chains and time the run.

    Args:
        input_seq_1: Text from the protein 1 sequence textbox.
        input_msa_1: Protein 1 MSA upload (A3M). NOTE(review): accepted but
            not forwarded to ``run_on_sample_seqs``; confirm whether the
            model is meant to consume it.
        input_protein_1: Protein 1 structure upload (PDB).
        input_seq_2: Text from the protein 2 sequence textbox.
        input_msa_2: Protein 2 MSA upload (A3M); likewise not forwarded.
        input_protein_2: Protein 2 structure upload (PDB).

    Returns:
        A tuple ``(pdb_path, metrics_json, run_time)``:
        ``OUTPUT_PATH`` (presumably where ``run_on_sample_seqs`` writes the
        predicted complex — confirm against run_on_seq), a JSON string of
        the reported metrics (currently an empty object), and the elapsed
        time of the call in seconds.
    """
    # perf_counter is monotonic and high-resolution, so the measured
    # duration cannot go negative or jump if the system clock changes.
    start_time = time.perf_counter()

    # The output PDB should hold the complex as chains A and B (the viewer
    # representations colour exactly those chains). Metrics may be any
    # JSON-serializable dict, e.g. {"mean_plddt": 80, "binding_affinity": 2}.
    metrics = {}
    run_on_sample_seqs(
        input_seq_1,
        input_protein_1,
        input_seq_2,
        input_protein_2,
        OUTPUT_PATH,
        RUN_CONFIG_PATH,
    )

    run_time = time.perf_counter() - start_time

    return OUTPUT_PATH, json.dumps(metrics), run_time

with gr.Blocks() as app:
    # UI layout: one column of inputs per protein, a run button, example
    # inputs, and the outputs (3D structure viewer, metrics, runtime).

    gr.Markdown("# Template for inference")

    gr.Markdown("Title, description, and other information about the model")
    with gr.Row():
        with gr.Column():
            input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
            input_msa_1 = gr.File(label="Input MSA Protein 1 (A3M)")
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
        with gr.Column():
            input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
            input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
            input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")

    # define any options here

    # for automated inference the default options are used
    # slider_option = gr.Slider(0,10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Radio Option")

    btn = gr.Button("Run Inference")

    # Each example row fills only the sequence and PDB inputs; the MSA
    # uploads are left empty. The PDB paths are relative — presumably
    # resolved against the app's working directory (confirm on deploy).
    gr.Examples(
        [
            [
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_A.pdb",
                "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
                "3v1c_B.pdb",
            ],
        ],
        [input_seq_1, input_protein_1, input_seq_2, input_protein_2],
    )

    # Viewer representations for model 0: a cartoon for each chain plus
    # side-chain sticks, with chain A in white and chain B in green.
    reps = [
        {
            "model": 0,
            "style": "cartoon",
            "chain": "A",
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "style": "cartoon",
            "chain": "B",
            "color": "greenCarbon",
        },
        {
            "model": 0,
            "chain": "A",
            "style": "stick",
            "sidechain": True,
            "color": "whiteCarbon",
        },
        {
            "model": 0,
            "chain": "B",
            "style": "stick",
            "sidechain": True,
            "color": "greenCarbon",
        },
    ]

    # outputs
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    # Input order must match predict's parameter order; the output order
    # matches predict's (pdb_path, metrics_json, run_time) return tuple.
    btn.click(
        predict,
        inputs=[input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2],
        outputs=[out, metrics, run_time],
    )

# Start the server only when run as a script (e.g. `python app.py`), so
# importing this module does not block on launching the UI.
if __name__ == "__main__":
    app.launch(show_error=True)