#!/usr/bin/env python

from __future__ import annotations

import collections
import os

import gradio as gr

from constants import UploadTarget
from inference import InferencePipeline
from trainer import Trainer


def create_training_demo(trainer: Trainer,
                         pipe: InferencePipeline | None = None,
                         disable_run_button: bool = False) -> gr.Blocks:
    def read_log() -> str:
        with open(trainer.log_file) as f:
            lines = f.readlines()
        return ''.join(lines[-10:])

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    gr.Markdown('Training Data')
                    training_video = gr.File(label='Training video')
                    training_prompt = gr.Textbox(
                        label='Training prompt',
                        max_lines=1,
                        placeholder='A man is surfing')
                    gr.Markdown('''
                        - Upload a video and write a `Training Prompt` that describes the video.
                        ''')

            with gr.Column():
                with gr.Box():
                    gr.Markdown('Training Parameters')
                    with gr.Row():
                        base_model = gr.Text(
                            label='Base Model',
                            value='CompVis/stable-diffusion-v1-4',
                            max_lines=1)
                        resolution = gr.Dropdown(choices=['512', '768'],
                                                 value='512',
                                                 label='Resolution',
                                                 visible=False)

                    hf_token = gr.Text(label='Hugging Face Write Token',
                                       type='password',
                                       visible=os.getenv('HF_TOKEN') is None)
                    with gr.Accordion(label='Advanced options', open=False):
                        num_training_steps = gr.Number(
                            label='Number of Training Steps',
                            value=300,
                            precision=0)
                        learning_rate = gr.Number(label='Learning Rate',
                                                  value=0.000035)
                        gradient_accumulation = gr.Number(
                            label='Number of Gradient Accumulation',
                            value=1,
                            precision=0)
                        seed = gr.Slider(label='Seed',
                                         minimum=0,
                                         maximum=100000,
                                         step=1,
                                         randomize=True,
                                         value=0)
                        fp16 = gr.Checkbox(label='FP16', value=True)
                        use_8bit_adam = gr.Checkbox(label='Use 8bit Adam',
                                                    value=False)
                        checkpointing_steps = gr.Number(
                            label='Checkpointing Steps',
                            value=1000,
                            precision=0)
                        validation_epochs = gr.Number(
                            label='Validation Epochs', value=100, precision=0)
                    gr.Markdown('''
                        - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
                        - Expected time to train a model for 300 steps: ~20 minutes with T4
                        - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
                        ''')

        with gr.Row():
            with gr.Column():
                gr.Markdown('Output Model')
                output_model_name = gr.Text(label='Name of your model',
                                            placeholder='The surfer man',
                                            max_lines=1)
                validation_prompt = gr.Text(
                    label='Validation Prompt',
                    placeholder=
                    'prompt to test the model, e.g: a dog is surfing')
            with gr.Column():
                gr.Markdown('Upload Settings')
                with gr.Row():
                    upload_to_hub = gr.Checkbox(label='Upload model to Hub',
                                                value=True)
                    use_private_repo = gr.Checkbox(label='Private', value=True)
                    delete_existing_repo = gr.Checkbox(
                        label='Delete existing repo of the same name',
                        value=False)
                    upload_to = gr.Radio(
                        label='Upload to',
                        choices=[_.value for _ in UploadTarget],
                        value=UploadTarget.MODEL_LIBRARY.value)

        pause_space_after_training = gr.Checkbox(
            label='Pause this Space after training',
            value=False,
            interactive=bool(os.getenv('SPACE_ID')),
            visible=False)
        run_button = gr.Button('Start Training',
                               interactive=not disable_run_button)

        with gr.Box():
            gr.Text(label='Log',
                    value=read_log,
                    lines=10,
                    max_lines=10,
                    every=1)

        if pipe is not None:
            run_button.click(fn=pipe.clear)
        run_button.click(fn=trainer.run,
                         inputs=[
                             training_video,
                             training_prompt,
                             output_model_name,
                             delete_existing_repo,
                             validation_prompt,
                             base_model,
                             resolution,
                             num_training_steps,
                             learning_rate,
                             gradient_accumulation,
                             seed,
                             fp16,
                             use_8bit_adam,
                             checkpointing_steps,
                             validation_epochs,
                             upload_to_hub,
                             use_private_repo,
                             delete_existing_repo,
                             upload_to,
                             pause_space_after_training,
                             hf_token,
                         ])
    return demo


if __name__ == '__main__':
    # Standalone entry point: wrap a fresh Trainer in the training UI and
    # serve it through Gradio's queue, allowing at most one queued event and
    # keeping the queue closed to direct API access.
    training_app = create_training_demo(Trainer())
    training_app.queue(api_open=False, max_size=1).launch()