import json
import os
import shutil
import sys
from collections import defaultdict

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

from utils import compute_average_wer, download_dataset


def main():
    """
    Main function to orchestrate the multilingual data generation process.

    This function performs the following steps:
    1. Downloads multilingual evaluation data if requested (i.e. when the first
       command-line argument is ``download``), after removing any previous copy
       of the local dataset directory.
    2. Processes multilingual evaluation files: every ``*.json`` file under
       ``argmaxinc/whisperkit-evals-multilingual/WhisperKit`` whose name does not
       contain ``summary``. Entries of each file's ``"results"`` list that lack a
       ``"reference_language"`` key are skipped.
    3. Calculates and saves results, including Word Error Rate (WER) and
       language detection confusion matrices.
    """
    source_repo = "argmaxinc/whisperkit-evals-multilingual"
    source_subfolder = "WhisperKit"
    source_directory = f"{source_repo}/{source_subfolder}"
    if len(sys.argv) > 1 and sys.argv[1] == "download":
        try:
            shutil.rmtree(source_repo)
        except FileNotFoundError:
            # Only "directory absent" is expected here. A bare ``except`` would
            # also hide permission errors (leaving stale files to be mixed with
            # the fresh download) and even KeyboardInterrupt.
            print("Nothing to remove.")
        download_dataset(source_repo, source_repo, source_subfolder)

    # One bucket per "<model>/<forced|not_forced>" configuration.
    results = defaultdict(
        lambda: {
            "average_wer": [],
            "language_wer": defaultdict(list),
            "language_detection": [],
        }
    )

    # Filled in place by calculate_and_save_results.
    confusion_matrices = {}

    for subdir, _, files in os.walk(source_directory):
        for filename in files:
            if not filename.endswith(".json") or "summary" in filename:
                continue

            file_path = os.path.join(subdir, filename)
            # Transcripts are multilingual: decode as UTF-8 explicitly instead of
            # relying on the platform's locale encoding.
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)

            # NOTE(review): assumes the model name is the directory two levels
            # above the file's folder (three when a "forced" folder appears in
            # the path) — confirm against the dataset layout.
            subdir_components = subdir.split(os.path.sep)
            is_forced = "forced" in subdir_components
            model = subdir_components[-4] if is_forced else subdir_components[-3]

            key = f"{model}/{'forced' if is_forced else 'not_forced'}"

            for item in data["results"]:
                if "reference_language" not in item:
                    continue
                reference_language = item["reference_language"]
                detected_language = item["predicted_language"]

                result = {
                    "reference": item["reference"],
                    "prediction": item["prediction"],
                }

                # results[key] is only touched once an item qualifies, so every
                # bucket handed downstream has at least one detection pair.
                results[key]["average_wer"].append(result)
                results[key]["language_wer"][reference_language].append(result)
                results[key]["language_detection"].append(
                    (reference_language, detected_language)
                )

    calculate_and_save_results(results, confusion_matrices)


def calculate_and_save_results(
    results, confusion_matrices, output_dir="dashboard_data"
):
    """
    Calculates final multilingual metrics and saves them to CSV and JSON files.

    :param results: Dictionary containing raw multilingual evaluation data. Keys
        have the form ``"<model>/<forced|not_forced>"``; each value provides
        ``"average_wer"`` (list of reference/prediction records),
        ``"language_wer"`` (mapping of reference language to such records) and
        ``"language_detection"`` (list of ``(reference, detected)`` language pairs).
    :param confusion_matrices: Dictionary to store confusion matrices for language
        detection; updated in place as
        ``{model: {"forced" | "not_forced": {"matrix": ..., "labels": ...}}}``.
    :param output_dir: Directory receiving the output files. It is created if it
        does not exist. Defaults to ``"dashboard_data"``.

    This function processes the raw multilingual data, calculates average metrics,
    creates confusion matrices for language detection, and saves the results to:
    1. ``<output_dir>/multilingual_results.csv`` with WER data for each model and
       language (one ``WER_<lang>`` column per reference language).
    2. ``<output_dir>/multilingual_confusion_matrices.json`` with row-normalized
       confusion matrices for language detection.
    """
    wer_data = []
    for key, data in results.items():
        # Split off the forced/not_forced suffix, then replace every "_" in the
        # model name with "/".
        model, forced = key.rsplit("/", 1)
        model = model.replace("_", "/")
        row = {
            "Model": model,
            "Forced Tokens": forced == "forced",
            "Average WER": compute_average_wer(data["average_wer"]),
        }
        for lang, wers in data["language_wer"].items():
            row[f"WER_{lang}"] = compute_average_wer(wers)
        wer_data.append(row)

        if not data["language_detection"]:
            # zip(*[]) cannot be unpacked into two sequences; nothing to tabulate.
            continue

        true_languages, detected_languages = zip(*data["language_detection"])
        unique_languages = sorted(set(true_languages))
        # NOTE(review): with `labels=`, sklearn drops samples whose true or
        # predicted label is not in the list, so detections of a language that
        # never occurs as a reference are excluded and each row is normalized
        # over in-set predictions only.
        cm = confusion_matrix(
            true_languages, detected_languages, labels=unique_languages
        )

        # Row-normalize; rows with no counted samples stay all-zero instead of
        # producing NaN from a 0/0 division.
        row_sums = cm.sum(axis=1)
        cm_normalized = np.zeros_like(cm, dtype=float)
        non_zero_rows = row_sums != 0
        cm_normalized[non_zero_rows] = (
            cm[non_zero_rows] / row_sums[non_zero_rows, np.newaxis]
        )

        confusion_matrices.setdefault(model, {})[forced] = {
            "matrix": cm_normalized.tolist(),
            "labels": unique_languages,
        }

    # Previously the writes below raised if the output directory was missing.
    os.makedirs(output_dir, exist_ok=True)

    df = pd.DataFrame(wer_data)
    df.to_csv(os.path.join(output_dir, "multilingual_results.csv"), index=False)

    matrices_path = os.path.join(
        output_dir, "multilingual_confusion_matrices.json"
    )
    with open(matrices_path, "w", encoding="utf-8") as f:
        json.dump(confusion_matrices, f, indent=2)


if __name__ == "__main__":
    # Run the pipeline only when executed as a script; pass "download" as the
    # first command-line argument to re-fetch the dataset before processing.
    main()