import tempfile

import gradio as gr
import torch
import soundfile as sf
import numpy as np
import yaml

from inference import MasteringStyleTransfer
from utils import download_youtube_audio
from config import args

# Single MasteringStyleTransfer instance, built once at import time from the
# config ``args`` and shared by every processing callback in this module.
# NOTE(review): construction presumably loads model weights, which is why it is
# not re-created per request — confirm against inference.MasteringStyleTransfer.
mastering_transfer = MasteringStyleTransfer(args)

def _to_soundfile_array(audio):
    """Convert model output audio into a NumPy array laid out for ``sf.write``.

    ``soundfile`` expects ``(frames,)`` or ``(frames, channels)``. The layout
    returned by ``MasteringStyleTransfer.process_audio`` is not visible from
    here, so the common cases are normalised:

    - torch tensors are detached and moved to CPU;
    - singleton axes (e.g. a batch axis) are squeezed away;
    - a 2-D array whose first axis is shorter than its second is treated as
      channels-first and transposed.
    """
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()
    audio = np.squeeze(np.asarray(audio))
    if audio.ndim == 2 and audio.shape[0] < audio.shape[1]:
        # NOTE(review): heuristic — assumes there are fewer channels than samples.
        audio = audio.T
    return audio


def _write_temp_wav(audio, sr, prefix):
    """Write ``audio`` at sample rate ``sr`` to a fresh WAV file and return its path.

    A uniquely named temp file per call keeps concurrent requests from
    overwriting each other's output, which a fixed filename would not. The file
    is left on disk (``delete=False``) so Gradio can serve it; nothing here
    cleans it up.
    """
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".wav", delete=False) as tmp:
        path = tmp.name
    # int(): the returned sample rate may be a NumPy/torch scalar.
    sf.write(path, _to_soundfile_array(audio), int(sr))
    return path


def _run_mastering(input_audio, reference_audio, ito_reference_audio, perform_ito, file_tag):
    """Run mastering (plus optional ITO) and package the results for the UI.

    Args:
        input_audio: Audio to be mastered, passed through unchanged to
            ``mastering_transfer.process_audio``.
        reference_audio: Style reference for the initial mastering pass.
        ito_reference_audio: Reference used for ITO.
        perform_ito: Whether to run ITO.
        file_tag: Suffix inserted into the temp-file name prefixes (e.g. ``"_yt"``).

    Returns:
        A 6-tuple matching the Gradio outputs: the mastered-audio file path, the
        ITO-audio file path (``None`` when no ITO audio was produced), the
        predicted-parameter string, the ITO-parameter string, the top-10
        parameter-difference string, and the ITO log.
    """
    (output_audio, predicted_params, ito_output_audio,
     ito_predicted_params, ito_log, sr, _) = mastering_transfer.process_audio(
        input_audio, reference_audio, ito_reference_audio, {}, perform_ito, log_ito=True
    )

    # Persist the audio so the returned paths actually exist.
    output_path = _write_temp_wav(output_audio, sr, f"output_mastered{file_tag}_")
    ito_output_path = (
        _write_temp_wav(ito_output_audio, sr, f"ito_output_mastered{file_tag}_")
        if ito_output_audio is not None
        else None
    )

    param_output = mastering_transfer.get_param_output_string(predicted_params)
    if ito_predicted_params is None:
        ito_param_output = "ITO not performed"
        top_10_diff = "ITO not performed"
    else:
        ito_param_output = mastering_transfer.get_param_output_string(ito_predicted_params)
        top_10_diff = mastering_transfer.get_top_10_diff_string(predicted_params, ito_predicted_params)

    return output_path, ito_output_path, param_output, ito_param_output, top_10_diff, ito_log


def process_audio(input_audio, reference_audio, perform_ito, ito_reference_audio=None):
    """Master uploaded audio against a reference, optionally running ITO.

    Args:
        input_audio: Value of the "Input Audio" component.
        reference_audio: Value of the "Reference Audio" component.
        perform_ito: Whether to run ITO.
        ito_reference_audio: Optional separate ITO reference; when ``None``,
            ``reference_audio`` is used instead.

    Returns:
        The 6-tuple described in ``_run_mastering``.
    """
    # Compare against None explicitly: truth-testing array-valued audio raises
    # NumPy's "truth value of an array is ambiguous" error.
    ito_ref = ito_reference_audio if ito_reference_audio is not None else reference_audio
    return _run_mastering(input_audio, reference_audio, ito_ref, perform_ito, file_tag="")


def process_with_ito(input_audio, reference_audio, perform_ito, use_same_reference, ito_reference_audio):
    """Gradio callback for the upload tab.

    Selects the ITO reference (the main reference when ``use_same_reference``
    is set) and delegates to ``process_audio``.
    """
    ito_ref = reference_audio if use_same_reference else ito_reference_audio
    return process_audio(input_audio, reference_audio, perform_ito, ito_ref)


def process_youtube_with_ito(input_url, reference_url, perform_ito, use_same_reference, ito_reference_url):
    """Gradio callback for the YouTube tab.

    Downloads the input and reference audio, then runs the same pipeline as
    the upload tab.

    A separate ITO reference is downloaded only when all of these hold:
    ITO is enabled, "use same reference" is unchecked, and a non-blank URL
    was given. Otherwise the main reference is reused. This mirrors the
    upload tab's fallback and avoids a wasted (or failing) download.

    Returns:
        The 6-tuple described in ``_run_mastering``.
    """
    input_audio = download_youtube_audio(input_url)
    reference_audio = download_youtube_audio(reference_url)

    wants_separate_ref = perform_ito and not use_same_reference
    if wants_separate_ref and ito_reference_url and ito_reference_url.strip():
        ito_ref = download_youtube_audio(ito_reference_url.strip())
    else:
        ito_ref = reference_audio

    return _run_mastering(input_audio, reference_audio, ito_ref, perform_ito, file_tag="_yt")


# Build the two-tab UI (file upload / YouTube URLs) and launch it.
# Visibility toggles return ``gr.update(visible=...)``. The per-component
# ``gr.Column.update`` / ``gr.Audio.update`` / ``gr.Textbox.update`` helpers
# were removed in Gradio 4, whereas ``gr.update`` works in both 3.x and 4.x.
with gr.Blocks() as demo:
    gr.Markdown("# Mastering Style Transfer Demo")

    with gr.Tab("Upload Audio"):
        input_audio = gr.Audio(label="Input Audio")
        reference_audio = gr.Audio(label="Reference Audio")
        perform_ito = gr.Checkbox(label="Perform ITO")
        # ITO-only options stay hidden until "Perform ITO" is checked.
        with gr.Column(visible=False) as ito_options:
            use_same_reference = gr.Checkbox(label="Use same reference audio for ITO", value=True)
            ito_reference_audio = gr.Audio(label="ITO Reference Audio", visible=False)

        def update_ito_options(perform_ito):
            """Show the ITO options column only while ITO is enabled."""
            return gr.update(visible=perform_ito)

        def update_ito_reference(use_same):
            """Show the separate ITO reference input only when not reusing the main reference."""
            return gr.update(visible=not use_same)

        perform_ito.change(fn=update_ito_options, inputs=perform_ito, outputs=ito_options)
        use_same_reference.change(fn=update_ito_reference, inputs=use_same_reference, outputs=ito_reference_audio)

        submit_button = gr.Button("Process")
        output_audio = gr.Audio(label="Output Audio")
        ito_output_audio = gr.Audio(label="ITO Output Audio")
        param_output = gr.Textbox(label="Predicted Parameters", lines=10)
        ito_param_output = gr.Textbox(label="ITO Predicted Parameters", lines=10)
        top_10_diff = gr.Textbox(label="Top 10 Parameter Differences", lines=10)
        ito_log = gr.Textbox(label="ITO Log", lines=20)

        # Output order must match the 6-tuple returned by process_with_ito.
        submit_button.click(
            process_with_ito,
            inputs=[input_audio, reference_audio, perform_ito, use_same_reference, ito_reference_audio],
            outputs=[output_audio, ito_output_audio, param_output, ito_param_output, top_10_diff, ito_log]
        )

    with gr.Tab("YouTube URLs"):
        input_url = gr.Textbox(label="Input YouTube URL")
        reference_url = gr.Textbox(label="Reference YouTube URL")
        perform_ito_yt = gr.Checkbox(label="Perform ITO")
        # ITO-only options stay hidden until "Perform ITO" is checked.
        with gr.Column(visible=False) as ito_options_yt:
            use_same_reference_yt = gr.Checkbox(label="Use same reference audio for ITO", value=True)
            ito_reference_url = gr.Textbox(label="ITO Reference YouTube URL", visible=False)

        def update_ito_options_yt(perform_ito):
            """Show the ITO options column only while ITO is enabled."""
            return gr.update(visible=perform_ito)

        def update_ito_reference_yt(use_same):
            """Show the ITO reference URL box only when not reusing the main reference."""
            return gr.update(visible=not use_same)

        perform_ito_yt.change(fn=update_ito_options_yt, inputs=perform_ito_yt, outputs=ito_options_yt)
        use_same_reference_yt.change(fn=update_ito_reference_yt, inputs=use_same_reference_yt, outputs=ito_reference_url)

        submit_button_yt = gr.Button("Process")
        output_audio_yt = gr.Audio(label="Output Audio")
        ito_output_audio_yt = gr.Audio(label="ITO Output Audio")
        param_output_yt = gr.Textbox(label="Predicted Parameters", lines=10)
        ito_param_output_yt = gr.Textbox(label="ITO Predicted Parameters", lines=10)
        top_10_diff_yt = gr.Textbox(label="Top 10 Parameter Differences", lines=10)
        ito_log_yt = gr.Textbox(label="ITO Log", lines=20)

        # Output order must match the 6-tuple returned by process_youtube_with_ito.
        submit_button_yt.click(
            process_youtube_with_ito,
            inputs=[input_url, reference_url, perform_ito_yt, use_same_reference_yt, ito_reference_url],
            outputs=[output_audio_yt, ito_output_audio_yt, param_output_yt, ito_param_output_yt, top_10_diff_yt, ito_log_yt]
        )

demo.launch()