import os
import json
import random
import argparse
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
from huggingface_hub import upload_folder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from collections import Counter
from transformers.integrations import TensorBoardCallback
from transformers import (
    Wav2Vec2FeatureExtractor, HubertConfig, HubertForSequenceClassification,
    Trainer, TrainingArguments,
    EarlyStoppingCallback
)

# Module-level configuration shared by every function in this script.
MODEL = "ntu-spml/distilhubert" # base model checkpoint
FEATURE_EXTRACTOR = Wav2Vec2FeatureExtractor.from_pretrained(MODEL) # feature extractor of the base model (loaded at import time)
seed = 123
MAX_DURATION = 1.00 # maximum clip duration, in seconds (multiplied by SAMPLING_RATE to get a sample count)
SAMPLING_RATE = FEATURE_EXTRACTOR.sampling_rate # expected to be 16 kHz for this checkpoint
token = os.getenv("HF_TOKEN")
config_file = "models_config.json"
batch_size = 1024 # TODO: review whether this is still needed
num_workers = 12 # number of DataLoader workers (chosen to match the CPU core count)

class AudioDataset(Dataset):
    """Audio classification dataset laid out as one sub-directory per class.

    Every file found in ``<dataset_path>/<label_dir>`` becomes one sample whose
    label id is ``label2id[label_dir]``. Audio is loaded and preprocessed
    lazily in ``__getitem__``.

    Args:
        dataset_path: Root directory that contains the class sub-directories.
        label2id: Mapping from sub-directory name to integer label id.
        filter_white_noise: Stored on the instance; no method of this class
            currently reads it.
        undersample_normal: When truthy (and ``label2id`` is non-empty), the
            ``'1s_normal'`` class is randomly undersampled on construction.
    """

    def __init__(self, dataset_path, label2id, filter_white_noise, undersample_normal):
        self.dataset_path = dataset_path
        self.label2id = label2id
        self.file_paths = []
        self.filter_white_noise = filter_white_noise
        self.labels = []
        for label_dir, label_id in self.label2id.items():
            label_path = os.path.join(self.dataset_path, label_dir)
            if not os.path.isdir(label_path):
                continue
            # Sort the listing: os.listdir order is arbitrary, and a stable
            # order is needed for a seeded split to be reproducible.
            for file_name in sorted(os.listdir(label_path)):
                self.file_paths.append(os.path.join(label_path, file_name))
                self.labels.append(label_id)
        if undersample_normal and self.label2id:
            self.undersample_normal_class()

    def undersample_normal_class(self):
        """Randomly drop ``'1s_normal'`` samples down to the largest other class.

        Does nothing when ``'1s_normal'`` is not in ``label2id`` or when there
        are no samples of any other class. If the normal class is already
        smaller than the target, all of its samples are kept.
        """
        normal_label = self.label2id.get('1s_normal')
        if normal_label is None:
            return
        label_counts = Counter(self.labels)
        other_counts = [count for label, count in label_counts.items() if label != normal_label]
        if not other_counts:
            return
        normal_indices = [i for i, label in enumerate(self.labels) if label == normal_label]
        # random.sample raises ValueError if asked for more items than exist,
        # so never request more than the normal class actually has.
        target_count = min(max(other_counts), len(normal_indices))
        # A set gives O(1) membership tests while rebuilding the lists below.
        keep_indices = set(random.sample(normal_indices, target_count))
        kept = [
            (path, label)
            for i, (path, label) in enumerate(zip(self.file_paths, self.labels))
            if label != normal_label or i in keep_indices
        ]
        self.file_paths = [path for path, _ in kept]
        self.labels = [label for _, label in kept]

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``{"input_values": tensor, "labels": tensor}`` for sample ``idx``."""
        audio_path = self.file_paths[idx]
        label = self.labels[idx]
        input_values = self.preprocess_audio(audio_path)
        return {
            "input_values": input_values,
            "labels": torch.tensor(label)
        }

    def preprocess_audio(self, audio_path):
        """Load one audio file and turn it into model input values.

        Steps: load, resample to ``SAMPLING_RATE`` if needed, down-mix to mono,
        peak-normalize, truncate or zero-pad to ``MAX_DURATION`` seconds, and
        run the result through ``FEATURE_EXTRACTOR``.

        Args:
            audio_path: Path of the audio file to load.

        Returns:
            The feature extractor's ``input_values`` with singleton dimensions
            squeezed out.
        """
        waveform, sample_rate = torchaudio.load(
            audio_path,
            normalize=True,
        )
        if sample_rate != SAMPLING_RATE:  # resample when the file is not at the model's rate
            resampler = torchaudio.transforms.Resample(sample_rate, SAMPLING_RATE)
            waveform = resampler(waveform)
        if waveform.shape[0] > 1:  # multi-channel: average the channels into mono
            waveform = waveform.mean(dim=0, keepdim=True)
        # Peak normalization; the epsilon avoids division by zero on silent clips.
        # TODO: try removing this since normalization is already done elsewhere;
        # note that without the 1e-6 epsilon accuracy was very poor.
        waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-6)
        max_length = int(SAMPLING_RATE * MAX_DURATION)
        if waveform.shape[1] > max_length:
            waveform = waveform[:, :max_length]  # truncate
        else:
            waveform = torch.nn.functional.pad(waveform, (0, max_length - waveform.shape[1]))  # zero-pad on the right
        inputs = FEATURE_EXTRACTOR(
            waveform.squeeze(),
            sampling_rate=SAMPLING_RATE,  # passed explicitly, just in case
            return_tensors="pt",
        )
        return inputs.input_values.squeeze()
    
def is_white_noise(audio):
    """Flag a signal whose mean and spread are both close to zero.

    Args:
        audio: Tensor holding the audio samples.

    Returns:
        A zero-dimensional boolean tensor that is true when the absolute mean
        is below 0.001 and the standard deviation is below 0.01.
    """
    mean_is_tiny = torch.abs(torch.mean(audio)) < 0.001
    spread_is_tiny = torch.std(audio) < 0.01
    return mean_is_tiny & spread_is_tiny

def seed_everything(seed_value=None):
    """Seed every random number generator this script relies on.

    Python's ``random`` module is seeded as well as PyTorch, because the
    train/test split, the undersampling and the sample selection in this file
    all go through ``random``.

    Args:
        seed_value: Seed to use. Defaults to the module-level ``seed``.
    """
    if seed_value is None:
        seed_value = seed
    random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    # TODO: check whether anything else is needed. For stricter reproducibility
    # one could also set torch.backends.cudnn.deterministic = True and
    # torch.backends.cudnn.benchmark = False.

def build_label_mappings(dataset_path):
    """Build label<->id mappings from the sub-directories of ``dataset_path``.

    Directory names are sorted before ids are assigned so that the mapping is
    the same on every run and every machine (``os.listdir`` returns entries in
    arbitrary order).

    Args:
        dataset_path: Directory whose immediate sub-directories are the classes.

    Returns:
        A ``(label2id, id2label)`` tuple of dictionaries; non-directory entries
        are ignored.
    """
    class_names = sorted(
        entry
        for entry in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, entry))
    )
    label2id = {name: label_id for label_id, name in enumerate(class_names)}
    id2label = {label_id: name for name, label_id in label2id.items()}
    return label2id, id2label

def compute_class_weights(labels):
    """Return one sampling weight per sample, inversely proportional to class size.

    Args:
        labels: Sequence of class labels, one per sample.

    Returns:
        A list aligned with ``labels`` where each entry is
        ``len(labels) / (number of samples sharing that label)``.
    """
    frequency = Counter(labels)
    n_samples = len(labels)
    return [n_samples / frequency[label] for label in labels]

def create_dataloader(dataset_path, filter_white_noise, undersample_normal, test_size=0.2, shuffle=True, pin_memory=True):
    """Build the train and test dataloaders for the dataset at ``dataset_path``.

    The dataset is split at random into train and test subsets. The train
    loader draws samples through a class-balancing ``WeightedRandomSampler``.

    Args:
        dataset_path: Root directory with one sub-directory per class.
        filter_white_noise: Forwarded to ``AudioDataset``.
        undersample_normal: Forwarded to ``AudioDataset``.
        test_size: Fraction of the samples assigned to the test subset.
        shuffle: Whether the test loader shuffles its samples.
        pin_memory: Forwarded to both ``DataLoader`` instances.

    Returns:
        ``(train_dataloader, test_dataloader, id2label)``.
    """
    label2id, id2label = build_label_mappings(dataset_path)
    full_dataset = AudioDataset(dataset_path, label2id, filter_white_noise, undersample_normal)

    # Random split by shuffling the index list and cutting it once.
    n_items = len(full_dataset)
    order = list(range(n_items))
    random.shuffle(order)
    cut = int(n_items * (1 - test_size))
    train_indices, test_indices = order[:cut], order[cut:]
    train_subset = Subset(full_dataset, train_indices)
    test_subset = Subset(full_dataset, test_indices)

    # Per-sample weights so that rare classes are drawn as often as common ones.
    sample_weights = compute_class_weights([full_dataset.labels[i] for i in train_indices])
    balanced_sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(train_subset),
        replacement=True,
    )

    loader_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(train_subset, sampler=balanced_sampler, **loader_options)
    test_loader = DataLoader(test_subset, shuffle=shuffle, **loader_options)
    return train_loader, test_loader, id2label

def load_model(model_path, id2label, num_labels):
    """Load a HuBERT sequence-classification model and move it to the best device.

    Args:
        model_path: Checkpoint name or path given to ``from_pretrained``.
        id2label: Mapping from label id to label name stored in the config.
        num_labels: Number of output classes.

    Returns:
        The model, placed on CUDA when available and on the CPU otherwise.
    """
    model_config = HubertConfig.from_pretrained(
        pretrained_model_name_or_path=model_path,
        num_labels=num_labels,
        id2label=id2label,
        finetuning_task="audio-classification",
    )
    classifier = HubertForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=model_path,
        config=model_config,
        # TODO: check whether float32 is required or float16 could be used instead.
        torch_dtype=torch.float32,
    )
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")
    return classifier.to(target_device)

def train_params(dataset_path, filter_white_noise, undersample_normal):
    """Prepare everything needed for training from the base checkpoint.

    Args:
        dataset_path: Root directory with one sub-directory per class.
        filter_white_noise: Forwarded to ``create_dataloader``.
        undersample_normal: Forwarded to ``create_dataloader``.

    Returns:
        ``(model, train_dataloader, test_dataloader, id2label)``.
    """
    loaders = create_dataloader(dataset_path, filter_white_noise, undersample_normal)
    train_dataloader, test_dataloader, id2label = loaders
    n_classes = len(id2label)
    model = load_model(MODEL, id2label, num_labels=n_classes)
    return model, train_dataloader, test_dataloader, id2label

def predict_params(dataset_path, model_path, filter_white_noise, undersample_normal):
    """Load a trained model together with the label mapping of a dataset.

    Only the label mapping is needed here, so it is read directly from the
    directory layout instead of building the full dataset, split, sampler and
    dataloaders and then discarding them.

    Args:
        dataset_path: Root directory with one sub-directory per class.
        model_path: Checkpoint name or path of the model to load.
        filter_white_noise: Unused; kept so existing callers keep working.
        undersample_normal: Unused; kept so existing callers keep working.

    Returns:
        ``(model, id2label)``.
    """
    _, id2label = build_label_mappings(dataset_path)
    model = load_model(model_path, id2label, num_labels=len(id2label))
    return model, id2label

def compute_metrics(pred):
    """Compute classification metrics for a ``Trainer`` evaluation step.

    Args:
        pred: Object exposing ``label_ids`` and ``predictions`` (the logits,
            or a tuple whose first element is the logits).

    Returns:
        Dictionary with ``accuracy``, weighted ``f1``/``precision``/``recall``
        and the ``confusion_matrix`` as a nested list.
    """
    labels = pred.label_ids
    logits = pred.predictions
    # Some models return extra outputs next to the logits; keep only the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    # zero_division=0 yields the same values as the default but without a
    # warning when a class is never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    cm = confusion_matrix(labels, preds)
    # NOTE(review): the confusion matrix is not a scalar; scalar-only loggers
    # may skip it or warn about it - confirm this is acceptable.
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'confusion_matrix': cm.tolist()
    }

def main(training_args, output_dir, dataset_path, filter_white_noise, undersample_normal):
    """Train, evaluate, publish and smoke-test the audio classifier.

    Args:
        training_args: ``TrainingArguments`` used to configure the ``Trainer``.
        output_dir: Local directory where the trained model is saved; also
            used as the repository name when uploading to the organization.
        dataset_path: Root directory with one sub-directory per class.
        filter_white_noise: Forwarded to the dataset construction.
        undersample_normal: Forwarded to the dataset construction.
    """
    seed_everything()
    model, train_dataloader, test_dataloader, id2label = train_params(dataset_path, filter_white_noise, undersample_normal)
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=5,
        early_stopping_threshold=0.001
    )
    # NOTE(review): only the underlying datasets are handed to the Trainer, so
    # the WeightedRandomSampler and batch size configured in create_dataloader
    # do not appear to be used during training - confirm this is intended.
    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataloader.dataset,
        eval_dataset=test_dataloader.dataset,
        callbacks=[TensorBoardCallback, early_stopping_callback]
    )
    torch.cuda.empty_cache()  # free cached GPU memory before training
    trainer.train()  # pass resume_from_checkpoint to continue a previous run
    os.makedirs(output_dir, exist_ok=True)
    trainer.save_model(output_dir)  # save the model locally
    eval_results = trainer.evaluate()
    print(f"Evaluation results: {eval_results}")
    trainer.push_to_hub(token=token)  # upload the model to the user profile
    upload_folder(repo_id=f"A-POR-LOS-8000/{output_dir}", folder_path=output_dir, token=token)  # upload the local folder to the organization

    test_subset = test_dataloader.dataset
    full_dataset = test_subset.dataset

    def predict(audio_path):
        """Return ``(predicted_label, logits)`` for a single audio file."""
        # Reuse the dataset's preprocessing so inference matches training;
        # it returns an unbatched tensor, hence the added batch dimension.
        input_values = full_dataset.preprocess_audio(audio_path).unsqueeze(0)
        with torch.no_grad():
            logits = model(input_values.to(model.device)).logits
        predicted_class_id = logits.argmax().item()
        return id2label[predicted_class_id], logits

    # Draw the smoke-test files from the held-out subset only (the Subset's
    # underlying dataset also contains the training files), and never ask for
    # more samples than are available.
    held_out_paths = [full_dataset.file_paths[i] for i in test_subset.indices]
    test_samples = random.sample(held_out_paths, min(15, len(held_out_paths)))
    model.eval()
    for sample in test_samples:
        predicted_label, logits = predict(sample)
        print(f"File: {sample}")
        print(f"Predicted label: {predicted_label}")
        print(f"Logits: {logits}")
        print("---")

def load_config(model_name):
    """Read the configuration entry for ``model_name`` from ``config_file``.

    Args:
        model_name: Top-level key of the JSON configuration to load.

    Returns:
        The entry's dictionary, with its ``"training_args"`` value replaced by
        a ``TrainingArguments`` instance built from the original mapping.
    """
    with open(config_file, 'r') as handle:
        model_config = json.load(handle)[model_name]
    model_config["training_args"] = TrainingArguments(**model_config["training_args"])
    return model_config

if __name__ == "__main__":
    # Command-line entry point: pick which model to train and launch training.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n",
        choices=["mon", "class"],
        required=True,
        help="Elegir qué modelo entrenar",
    )
    args = parser.parse_args()

    # (filter_white_noise, undersample_normal) for each selectable model.
    flags_by_model = {
        "mon": (False, False),
        "class": (True, True),
    }
    filter_white_noise, undersample_normal = flags_by_model[args.n]

    config = load_config(args.n)
    training_args = config["training_args"]
    output_dir = config["output_dir"]
    dataset_path = config["dataset_path"]
    main(training_args, output_dir, dataset_path, filter_white_noise, undersample_normal)