dofbi's picture
Update inference.py
8e9aeb7 verified
import torch
import os
import logging
import soundfile as sf
import numpy as np
from huggingface_hub import hf_hub_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# --- CONSTANTES ---
REPO_ID = "dofbi/galsenai-xtts-v2-wolof-inference"
LOCAL_DIR = "./models"
class WolofXTTSInference:
def __init__(self, repo_id=REPO_ID, local_dir=LOCAL_DIR):
# Configuration du logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
self.logger = logging.getLogger(__name__)
# Créer le dossier local s'il n'existe pas
os.makedirs(local_dir, exist_ok=True)
# Téléchargement des fichiers nécessaires
try:
# Créer les sous-dossiers nécessaires
os.makedirs(os.path.join(local_dir, "Anta_GPT_XTTS_Wo"), exist_ok=True)
os.makedirs(os.path.join(local_dir, "XTTS_v2.0_original_model_files"), exist_ok=True)
# Télécharger le checkpoint
self.model_path = hf_hub_download(
repo_id=repo_id,
filename="Anta_GPT_XTTS_Wo/best_model_89250.pth",
local_dir=local_dir
)
# Télécharger le fichier de configuration
self.config_path = hf_hub_download(
repo_id=repo_id,
filename="Anta_GPT_XTTS_Wo/config.json",
local_dir=local_dir
)
# Télécharger le vocabulaire
self.vocab_path = hf_hub_download(
repo_id=repo_id,
filename="XTTS_v2.0_original_model_files/vocab.json",
local_dir=local_dir
)
# Télécharger l'audio de référence
self.reference_audio = hf_hub_download(
repo_id=repo_id,
filename="anta_sample.wav",
local_dir=local_dir
)
except Exception as e:
self.logger.error(f"Erreur lors du téléchargement des fichiers : {e}")
raise
# Sélection du device
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Initialisation du modèle
self.model = self._load_model()
def _load_model(self):
"""Charge le modèle XTTS"""
try:
self.logger.info("Chargement du modèle XTTS...")
# Initialisation du modèle
config = XttsConfig()
config.load_json(self.config_path)
model = Xtts.init_from_config(config)
# Chargement du checkpoint avec load_checkpoint
model.load_checkpoint(config,
checkpoint_path=self.model_path,
vocab_path=self.vocab_path,
use_deepspeed=False
)
model.to(self.device)
model.eval() # Mettre le modèle en mode évaluation
self.logger.info("Modèle chargé avec succès!")
return model
except Exception as e:
self.logger.error(f"Erreur lors du chargement du modèle : {e}")
raise
def generate_audio(
self,
text: str,
reference_audio: str = None,
speed: float = 1.06,
language: str = "wo",
output_path: str = None
) -> tuple[np.ndarray, int]:
"""
Génère de l'audio à partir du texte fourni
Args:
text (str): Texte à convertir en audio
reference_audio (str, optional): Chemin vers l'audio de référence. Defaults to None.
speed (float, optional): Vitesse de lecture. Defaults to 1.06.
language (str, optional): Langue du texte. Defaults to "wo".
output_path (str, optional): Chemin de sauvegarde de l'audio généré. Defaults to None.
Returns:
tuple[np.ndarray, int]: audio_array, sample_rate
"""
if not text:
raise ValueError("Le texte ne peut pas être vide.")
try:
# Utiliser l'audio de référence fourni ou par défaut
ref_audio = reference_audio or self.reference_audio
# Obtenir les embeddings
gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
audio_path=[ref_audio],
gpt_cond_len=self.model.config.gpt_cond_len,
max_ref_length=self.model.config.max_ref_len,
sound_norm_refs=self.model.config.sound_norm_refs
)
# Génération de l'audio
result = self.model.inference(
text=text.lower(),
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
do_sample=False,
speed=speed,
language=language,
enable_text_splitting=True
)
# Récupérer le taux d'échantillonnage
sample_rate = self.model.config.audio.sample_rate
# Sauvegarde optionnelle
if output_path:
sf.write(output_path, result["wav"], sample_rate)
self.logger.info(f"Audio sauvegardé dans {output_path}")
return result["wav"], sample_rate
except Exception as e:
self.logger.error(f"Erreur lors de la génération de l'audio : {e}")
raise
def generate_audio_from_config(self, text: str, config: dict, output_path: str = None) -> tuple[np.ndarray, int]:
"""
Génère de l'audio à partir du texte et d'un dictionnaire de configuration.
Args:
text (str): Texte à convertir en audio
config (dict): Dictionnaire de configuration (speed, language, reference_audio)
output_path (str, optional): Chemin de sauvegarde de l'audio généré. Defaults to None.
Returns:
tuple[np.ndarray, int]: audio_array, sample_rate
"""
speed = config.get('speed', 1.06)
language = config.get('language', "wo")
reference_audio = config.get('reference_audio', None)
return self.generate_audio(text=text, reference_audio=reference_audio, speed=speed, language=language, output_path=output_path)
# Exemple d'utilisation
if __name__ == "__main__":
tts = WolofXTTSInference()
# Exemple de génération d'audio
text = "Màngi tuddu Aadama, di baat bii waa Galsen A.I defar ngir wax ak yéen ci wolof!"
# Simple
audio, sr = tts.generate_audio(
text,
output_path="generated_audio.wav"
)
# Avec une config
config_gen_audio = {
"speed": 1.2,
"language": "wo",
}
audio, sr = tts.generate_audio_from_config(
text=text,
config=config_gen_audio,
output_path="generated_audio_config.wav"
)