import torch
import os
import logging
import soundfile as sf
import numpy as np
from huggingface_hub import hf_hub_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# --- CONSTANTES ---
REPO_ID = "dofbi/galsenai-xtts-v2-wolof-inference"
LOCAL_DIR = "./models"

class WolofXTTSInference:
    def __init__(self, repo_id=REPO_ID, local_dir=LOCAL_DIR):
        # Configuration du logging
        logging.basicConfig(
            level=logging.INFO, 
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

        # Créer le dossier local s'il n'existe pas
        os.makedirs(local_dir, exist_ok=True)

        # Téléchargement des fichiers nécessaires
        try:
            # Créer les sous-dossiers nécessaires
            os.makedirs(os.path.join(local_dir, "Anta_GPT_XTTS_Wo"), exist_ok=True)
            os.makedirs(os.path.join(local_dir, "XTTS_v2.0_original_model_files"), exist_ok=True)

            # Télécharger le checkpoint
            self.model_path = hf_hub_download(
                repo_id=repo_id, 
                filename="Anta_GPT_XTTS_Wo/best_model_89250.pth", 
                local_dir=local_dir
            )

            # Télécharger le fichier de configuration
            self.config_path = hf_hub_download(
                repo_id=repo_id, 
                filename="Anta_GPT_XTTS_Wo/config.json", 
                local_dir=local_dir
            )

            # Télécharger le vocabulaire
            self.vocab_path = hf_hub_download(
                repo_id=repo_id, 
                filename="XTTS_v2.0_original_model_files/vocab.json", 
                local_dir=local_dir
            )

            # Télécharger l'audio de référence
            self.reference_audio = hf_hub_download(
                repo_id=repo_id, 
                filename="anta_sample.wav", 
                local_dir=local_dir
            )

        except Exception as e:
            self.logger.error(f"Erreur lors du téléchargement des fichiers : {e}")
            raise

        # Sélection du device
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        
        # Initialisation du modèle
        self.model = self._load_model()

    def _load_model(self):
        """Charge le modèle XTTS"""
        try:
            self.logger.info("Chargement du modèle XTTS...")
            
            # Initialisation du modèle
            config = XttsConfig()
            config.load_json(self.config_path)
            model = Xtts.init_from_config(config)
            
            # Chargement du checkpoint avec load_checkpoint
            model.load_checkpoint(config,
                                checkpoint_path=self.model_path,
                                vocab_path=self.vocab_path,
                                use_deepspeed=False
                             )

            model.to(self.device)
            model.eval()  # Mettre le modèle en mode évaluation
            
            self.logger.info("Modèle chargé avec succès!")
            return model
        
        except Exception as e:
            self.logger.error(f"Erreur lors du chargement du modèle : {e}")
            raise

    def generate_audio(
        self, 
        text: str, 
        reference_audio: str = None, 
        speed: float = 1.06, 
        language: str = "wo",
        output_path: str = None
    ) -> tuple[np.ndarray, int]:
        """
        Génère de l'audio à partir du texte fourni
        
        Args:
            text (str): Texte à convertir en audio
            reference_audio (str, optional): Chemin vers l'audio de référence. Defaults to None.
            speed (float, optional): Vitesse de lecture. Defaults to 1.06.
            language (str, optional): Langue du texte. Defaults to "wo".
            output_path (str, optional): Chemin de sauvegarde de l'audio généré. Defaults to None.
        
        Returns:
            tuple[np.ndarray, int]: audio_array, sample_rate
        """
        if not text:
           raise ValueError("Le texte ne peut pas être vide.")
        
        try:
            # Utiliser l'audio de référence fourni ou par défaut
            ref_audio = reference_audio or self.reference_audio
            
            # Obtenir les embeddings
            gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
                audio_path=[ref_audio],
                gpt_cond_len=self.model.config.gpt_cond_len,
                max_ref_length=self.model.config.max_ref_len,
                sound_norm_refs=self.model.config.sound_norm_refs
            )
            
            # Génération de l'audio
            result = self.model.inference(
                text=text.lower(),
                gpt_cond_latent=gpt_cond_latent,
                speaker_embedding=speaker_embedding,
                do_sample=False,
                speed=speed,
                language=language,
                enable_text_splitting=True
            )
            
            # Récupérer le taux d'échantillonnage
            sample_rate = self.model.config.audio.sample_rate
            
            # Sauvegarde optionnelle
            if output_path:
                sf.write(output_path, result["wav"], sample_rate)
                self.logger.info(f"Audio sauvegardé dans {output_path}")
            
            return result["wav"], sample_rate
        
        except Exception as e:
            self.logger.error(f"Erreur lors de la génération de l'audio : {e}")
            raise
        
    def generate_audio_from_config(self, text: str, config: dict, output_path: str = None) -> tuple[np.ndarray, int]:
        """
        Génère de l'audio à partir du texte et d'un dictionnaire de configuration.

        Args:
            text (str): Texte à convertir en audio
            config (dict): Dictionnaire de configuration (speed, language, reference_audio)
            output_path (str, optional): Chemin de sauvegarde de l'audio généré. Defaults to None.
        
        Returns:
             tuple[np.ndarray, int]: audio_array, sample_rate
        """
        speed = config.get('speed', 1.06)
        language = config.get('language', "wo")
        reference_audio = config.get('reference_audio', None)
        return self.generate_audio(text=text, reference_audio=reference_audio, speed=speed, language=language, output_path=output_path)
        

# Exemple d'utilisation
if __name__ == "__main__":
    tts = WolofXTTSInference()
    
    # Exemple de génération d'audio
    text = "Màngi tuddu Aadama, di baat bii waa Galsen A.I defar ngir wax ak yéen ci wolof!"
    
    # Simple
    audio, sr = tts.generate_audio(
        text, 
        output_path="generated_audio.wav"
    )

    # Avec une config
    config_gen_audio = {
        "speed": 1.2,
        "language": "wo",
    }
    audio, sr = tts.generate_audio_from_config(
        text=text,
        config=config_gen_audio, 
        output_path="generated_audio_config.wav"
    )