import base64
import os
import gradio as gr
from transformers import pipeline
import numpy as np
import librosa
from datetime import datetime
from datasets import (
    load_dataset,
    concatenate_datasets,
    Dataset,
    Features,
    Value,
    Audio,
)

# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

# Hugging Face evaluation dataset
HF_DATASET_NAME = "BounharAbdelaziz/Moroccan-STT-Eval-Dataset"

# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

# Model paths
MODEL_PATHS = {
    "NANO": "BounharAbdelaziz/Morocco-Darija-STT-tiny-v1.3",
    "SMALL": "BounharAbdelaziz/Morocco-Darija-STT-small-v1.3",
    "LARGE": "BounharAbdelaziz/Morocco-Darija-STT-large-v1.3",
}

# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

# Access token for the models
STT_MODEL_TOKEN = os.environ.get("STT_MODEL_TOKEN")

# Access token for the evaluation dataset
STT_EVAL_DATASET_TOKEN = os.environ.get("STT_EVAL_DATASET_TOKEN")

# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

def encode_image_to_base64(image_path):
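    """Read an image file and return its contents as a base64-encoded string."""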
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode()
    return encoded_string

# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

def create_html_image(image_path):
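    """Build an HTML snippet that embeds the image inline as base64, so no static file hosting is needed."""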
    img_base64 = encode_image_to_base64(image_path)
    html_string = f"""
    <div style="display: flex; justify-content: center; align-items: center; width: 100%; text-align: center;">
        <div style="max-width: 800px; margin: auto;">
            <img src="data:image/jpeg;base64,{img_base64}"
                 style="max-width: 75%; height: auto; display: block; margin: 0 auto; margin-top: 50px;"
                 alt="Displayed Image">
        </div>
    </div>
    """
    return html_string

# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

def save_to_hf_dataset(audio_signal, model_choice, transcription):
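    """Append one recording + transcription to the HF eval dataset and push it.

    Note: this reloads and re-pushes the entire dataset on every call; simple,
    but it will get slow as the dataset grows.
    """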
    print("[INFO] Loading dataset...")

    dataset = load_dataset(HF_DATASET_NAME, token=STT_EVAL_DATASET_TOKEN)
    print("[INFO] Dataset loaded successfully.")

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    new_entry = {
        "audio": [{"array": audio_signal, "sampling_rate": 16000}],
        "transcription": [transcription],
        "model_used": [model_choice],
        "timestamp": [timestamp],
    }

    new_dataset = Dataset.from_dict(
        new_entry,
        features=Features({
            "audio": Audio(sampling_rate=16000),
            "transcription": Value("string"),
            "model_used": Value("string"),
            "timestamp": Value("string"),
        })
    )

    print("[INFO] Adding the new entry to the dataset...")
    train_dataset = dataset["train"]
    updated_train_dataset = concatenate_datasets([train_dataset, new_dataset])
    dataset["train"] = updated_train_dataset

    print("[INFO] Pushing the updated dataset...")
    dataset.push_to_hub(HF_DATASET_NAME, token=STT_EVAL_DATASET_TOKEN)

    print("[INFO] Dataset updated and pushed successfully.")

# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

def load_model(model_name):
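    """Load an ASR pipeline for the chosen model size (Nano / Small / Large)."""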
    model_id = MODEL_PATHS[model_name.upper()]
    return pipeline("automatic-speech-recognition", model=model_id, token=STT_MODEL_TOKEN)
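
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

# Optional (not wired into the app): load_model() builds a new pipeline on
# every request, so switching models is slow. A minimal caching sketch using
# functools.lru_cache keeps one pipeline per model size in memory.
from functools import lru_cache

@lru_cache(maxsize=len(MODEL_PATHS))
def load_model_cached(model_name):
    # Same behavior as load_model(), but reuses pipelines across requests
    return load_model(model_name)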

# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

def process_audio(audio, model_choice, save_data):
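    """Transcribe the recorded audio with the selected model and optionally log it to the eval dataset."""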
    
    # Guard against an empty submission (no recording yet)
    if audio is None:
        return "No audio recorded. Please record something and click Stop first."

    # Saving to the eval dataset is temporarily disabled (known dataset issue)
    save_data = False
    pipe = load_model(model_choice)

    # Gradio returns (sample_rate, signal) for numpy-type audio inputs
    sample_rate, audio_signal = audio
    audio_signal = audio_signal.astype(np.float32)

    # Downmix stereo recordings to mono before resampling and inference
    if audio_signal.ndim > 1:
        audio_signal = audio_signal.mean(axis=1)

    # Normalize int16-range samples to [-1.0, 1.0]
    if np.abs(audio_signal).max() > 1.0:
        audio_signal = audio_signal / 32768.0
    
    if sample_rate != 16000:
        print(f"[INFO] Resampling audio from {sample_rate}Hz to 16000Hz")
        audio_signal = librosa.resample(
            y=audio_signal, 
            orig_sr=sample_rate,
            target_sr=16000
        )
    
    result = pipe(audio_signal)
    transcription = result["text"]
    
    if save_data:
        print("[INFO] Saving data to eval dataset...")
        save_to_hf_dataset(audio_signal, model_choice, transcription)
    
    return transcription

# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

def create_interface():
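    """Assemble the Gradio Blocks UI: logo, model picker, microphone input, and transcription output."""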
    with gr.Blocks(css="footer{display:none !important}") as app:
        base_path = os.path.dirname(__file__)
        local_image_path = os.path.join(base_path, "logo_image.png")
        gr.HTML(create_html_image(local_image_path))
        
        gr.Markdown("# πŸ‡²πŸ‡¦ πŸš€ Moroccan Fast Speech-to-Text Transcription 😍")

        gr.Markdown("⚠️ **Nota bene**: Make sure to click on **Stop** before hitting the **Transcribe** button")
        gr.Markdown("πŸ“Œ The **Large** model should be available soon. Stay tuned!")
        
        with gr.Row():
            model_choice = gr.Dropdown(
                choices=["Nano", "Small", "Large"],
                value="Small",
                label="Select one of the models"
            )
        
        with gr.Row():
            audio_input = gr.Audio(
                sources=["microphone"],
                type="numpy",
                label="Record Audio",
            )
        
        with gr.Row():
            save_data = gr.Checkbox(
                label="Contribute to the evaluation benchmark (coming soon)",
                value=False,
            )
        
        submit_btn = gr.Button("Transcribe πŸ”₯")
        output_text = gr.Textbox(label="Transcription")
        
        gr.Markdown("""
        ### πŸ“„πŸ“Œ Notice to our dearest users πŸ€— (coming soon)
        - By transcribing your audio, you’re actively contributing to the development of a benchmark evaluation dataset for Moroccan speech-to-text models.  
        - Your transcriptions will be logged into a dedicated Hugging Face dataset, playing a crucial role in advancing research and innovation in speech recognition for Moroccan dialects and languages.  
        - Together, we’re building tools that better understand and serve the unique linguistic landscape of Morocco.
        - We count on your **thoughtfulness and responsibility** when using the app. Thank you for your contribution! 🌟
        """)
        
        submit_btn.click(
            fn=process_audio,
            inputs=[audio_input, model_choice, save_data],
            outputs=output_text
        )
        
        gr.Markdown("<br/>")
    
    return app
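
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #

# Entry point (assumed: the app is launched directly, e.g. on a HF Space)
if __name__ == "__main__":
    app = create_interface()
    app.launch()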