import time
import json 

import gradio as gr

from gradio_molecule3d import Molecule3D


import numpy as np
from biotite.structure.io.pdb import PDBFile

def set_all_to_zero(input_pdb_file_1, input_pdb_file_2, output_file):
    """Write a placeholder two-chain complex built from two input structures.

    No modelling happens: every atom of the first structure is placed at the
    origin and every atom of the second at (3.7, 0, 0). The two parts are
    relabelled as chains "A" and "B" and written out as a single PDB file.

    Args:
        input_pdb_file_1: PDB input for the first partner (becomes chain A).
        input_pdb_file_2: PDB input for the second partner (becomes chain B).
        output_file: Destination path for the combined PDB.
    """
    # Use only the first model of each input: concatenating full model stacks
    # fails when the two files contain different numbers of models.
    structure1 = PDBFile.read(input_pdb_file_1).get_structure(model=1)
    structure2 = PDBFile.read(input_pdb_file_2).get_structure(model=1)

    # Force distinct chain IDs. Both inputs may use the same chain letter
    # (e.g. two monomers that are each chain "A"), but the app expects the
    # complex as chains A and B.
    structure1.chain_id[:] = "A"
    structure2.chain_id[:] = "B"

    structure1.coord = np.zeros_like(structure1.coord)
    # shift second to avoid interchain "clash"; built in the input dtype
    # rather than via a float64 broadcast.
    shifted = np.zeros_like(structure2.coord)
    shifted[..., 0] = 3.7
    structure2.coord = shifted

    out_structure = structure1 + structure2
    pdb_file = PDBFile()
    pdb_file.set_structure(out_structure)
    pdb_file.write(output_file)


def predict(input_seq_1, input_msa_1, input_protein_1, input_seq_2, input_msa_2, input_protein_2):
    """Run (placeholder) inference on a pair of proteins.

    The sequence and MSA arguments are accepted so the signature matches the
    six UI inputs, but this template only uses the two PDB structures.

    Args:
        input_seq_1, input_seq_2: Sequence text from the FASTA textboxes (unused here).
        input_msa_1, input_msa_2: MSA uploads from the A3M file fields (unused here).
        input_protein_1, input_protein_2: PDB uploads for the two partners;
            ``None`` when nothing was uploaded.

    Returns:
        Tuple of (path of the PDB file written by ``set_all_to_zero``,
        JSON string of metrics, elapsed run time in seconds as a float).

    Raises:
        gr.Error: If either PDB structure is missing, so the UI shows a
            readable message instead of a parser traceback.
    """
    if input_protein_1 is None or input_protein_2 is None:
        raise gr.Error("Please provide a PDB structure for both Protein 1 and Protein 2.")

    # perf_counter is monotonic, so the reported duration is unaffected by
    # system clock adjustments during the run.
    start_time = time.perf_counter()

    # Do inference here
    # return an output pdb file with the protein and two chains A and B.
    # NOTE(review): fixed filename in the working directory; if requests ever
    # run concurrently they would overwrite each other's output.
    output_file = "test_out.pdb"
    set_all_to_zero(input_protein_1, input_protein_2, output_file)
    # also return a JSON with any metrics you want to report
    metrics = {"F_nat": 100}

    run_time = time.perf_counter() - start_time
    return output_file, json.dumps(metrics), run_time

with gr.Blocks() as app:

    gr.Markdown("# Template for inference")
    gr.Markdown("Title, description, and other information about the model")

    # One column per binding partner: sequence text, MSA upload, PDB upload.
    with gr.Row():
        with gr.Column():
            input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
            input_msa_1 = gr.File(label="Input MSA Protein 1 (A3M)")
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
        with gr.Column():
            input_seq_2 = gr.Textbox(lines=3, label="Input Protein 2 sequence (FASTA)")
            input_msa_2 = gr.File(label="Input MSA Protein 2 (A3M)")
            input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")

    # Positional order must match predict()'s parameters; shared by the
    # examples table and the button handler so the two cannot drift apart.
    protein_inputs = [
        input_seq_1,
        input_msa_1,
        input_protein_1,
        input_seq_2,
        input_msa_2,
        input_protein_2,
    ]

    # Extra model options would be declared here; automated inference always
    # runs with their defaults. For example:
    # slider_option = gr.Slider(0,10, label="Slider Option")
    # checkbox_option = gr.Checkbox(label="Checkbox Option")
    # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Dropdown Option")

    btn = gr.Button("Run Inference")

    # A single bundled example: empty sequences/MSAs plus the two chains of 3v1c.
    example_row = ["", "empty_file.a3m", "3v1c_A.pdb", "", "empty_file.a3m", "3v1c_B.pdb"]
    gr.Examples([example_row], protein_inputs)

    # Viewer styling: spheres for both chains, coloured per chain.
    reps = [
        {"model": 0, "style": "sphere", "chain": chain, "color": color}
        for chain, color in (("A", "whiteCarbon"), ("B", "greenCarbon"))
    ]

    # Outputs: 3D structure, metrics JSON, and the elapsed run time.
    out = Molecule3D(reps=reps)
    metrics = gr.JSON(label="Metrics")
    run_time = gr.Textbox(label="Runtime")

    btn.click(predict, inputs=protein_inputs, outputs=[out, metrics, run_time])

app.launch()