import os
import tqdm
import torch
import torchaudio
import numpy as np
from torch.utils.data import DataLoader
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
from torch.nn import functional as F

class CustomDataset(torch.utils.data.Dataset):
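    """Dataset over a list of audio file paths.

    Each item is loaded with torchaudio, downmixed to mono, resampled to
    ``sampling_rate`` if needed, and truncated to ``max_audio_len`` seconds.
    Padding is left to the collate function.
    """
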
    def __init__(self, dataset, basedir=None, sampling_rate=16000, max_audio_len=5):
        self.dataset = dataset
        self.basedir = basedir
        self.sampling_rate = sampling_rate
        self.max_audio_len = max_audio_len

    def __len__(self):
        return len(self.dataset)

    def _cutorpad(self, audio):
        # Truncate clips longer than max_audio_len seconds; shorter clips are
        # left as-is and padded later by the collate function.
        effective_length = self.sampling_rate * self.max_audio_len

        if len(audio) > effective_length:
            audio = audio[:effective_length]

        return audio

    def __getitem__(self, index):
        if self.basedir is None:
            filepath = self.dataset[index]
        else:
            filepath = os.path.join(self.basedir, self.dataset[index])

        speech_array, sr = torchaudio.load(filepath)

        # Downmix multi-channel audio to mono
        if speech_array.shape[0] > 1:
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)

        # Resample if the file's rate differs from the target sampling rate
        if sr != self.sampling_rate:
            transform = torchaudio.transforms.Resample(sr, self.sampling_rate)
            speech_array = transform(speech_array)
            sr = self.sampling_rate

        speech_array = speech_array.squeeze().numpy()
        speech_array = self._cutorpad(speech_array)

        # Padding and attention masks are produced later by the collate function
        return {"input_values": speech_array}

class CollateFunc:
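    """Pad a list of raw waveforms into a uniform batch using the model's feature extractor."""
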
    def __init__(self, processor, max_length=None, padding=True, pad_to_multiple_of=None, sampling_rate=16000):
        self.padding = padding
        self.processor = processor
        self.max_length = max_length
        self.sampling_rate = sampling_rate
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, batch):
        input_features = []

        for audio in batch:
            input_tensor = self.processor(audio["input_values"], sampling_rate=self.sampling_rate).input_values
            input_tensor = np.squeeze(input_tensor)
            input_features.append({"input_values": input_tensor})

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        return batch

def predict(test_dataloader, model, device):
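    """Run the model over the dataloader and return a list of predicted class indices."""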
    model.to(device)
    model.eval()
    preds = []

    with torch.no_grad():
        for batch in tqdm.tqdm(test_dataloader):
            input_values = batch['input_values'].to(device)

            logits = model(input_values).logits
            scores = F.softmax(logits, dim=-1)

            pred = torch.argmax(scores, dim=1).cpu().numpy()
            preds.extend(pred)

    return preds

def get_gender(model_name_or_path, audio_paths, device):
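    """Classify the speaker's gender by majority vote over the predictions for audio_paths."""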
    num_labels = 2

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name_or_path)
    model = AutoModelForAudioClassification.from_pretrained(
        pretrained_model_name_or_path=model_name_or_path,
        num_labels=num_labels,
    )

    test_dataset = CustomDataset(audio_paths)
    data_collator = CollateFunc(
        processor=feature_extractor,
        padding=True,
        sampling_rate=16000,
    )

    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=16,
        collate_fn=data_collator,
        shuffle=False,
        num_workers=10,  # adjust to the number of available CPU cores
    )

    preds = predict(test_dataloader=test_dataloader, model=model, device=device)

    # Map class indices to labels; this assumes the checkpoint was trained with
    # 0 -> female and 1 -> male (model.config.id2label can be checked to confirm)
    label_mapping = {0: "female", 1: "male"}

    # Majority vote: take the most common predicted label across all clips
    most_common_label = max(set(preds), key=preds.count)
    predicted_label = label_mapping[most_common_label]

    return predicted_label
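
if __name__ == "__main__":
    # Minimal usage sketch (assumption): the checkpoint name and audio paths below
    # are placeholders, not part of the original script. Any wav2vec2-style audio
    # classification checkpoint fine-tuned with binary gender labels should work.
    example_model = "model_name_or_path"  # hypothetical checkpoint identifier
    example_audio = ["sample_1.wav", "sample_2.wav"]  # hypothetical local files

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gender = get_gender(example_model, example_audio, device)
    print(f"Predicted gender: {gender}")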