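"""Wolof text-to-speech inference with a fine-tuned XTTS v2 model.

Downloads the GalsenAI checkpoint, config, vocabulary, and reference audio
from the Hugging Face Hub, then exposes simple audio-generation helpers.
"""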
import logging
import os
from typing import Optional

import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

REPO_ID = "dofbi/galsenai-xtts-v2-wolof-inference"
LOCAL_DIR = "./models"


class WolofXTTSInference:
    def __init__(self, repo_id=REPO_ID, local_dir=LOCAL_DIR):
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

        os.makedirs(local_dir, exist_ok=True)
        try:
            # hf_hub_download preserves the repo's folder structure under
            # local_dir, so these directories are created up front.
            os.makedirs(os.path.join(local_dir, "Anta_GPT_XTTS_Wo"), exist_ok=True)
            os.makedirs(os.path.join(local_dir, "XTTS_v2.0_original_model_files"), exist_ok=True)

            # Fine-tuned model checkpoint.
            self.model_path = hf_hub_download(
                repo_id=repo_id,
                filename="Anta_GPT_XTTS_Wo/best_model_89250.pth",
                local_dir=local_dir
            )

            # Model configuration.
            self.config_path = hf_hub_download(
                repo_id=repo_id,
                filename="Anta_GPT_XTTS_Wo/config.json",
                local_dir=local_dir
            )

            # Tokenizer vocabulary from the original XTTS v2 release.
            self.vocab_path = hf_hub_download(
                repo_id=repo_id,
                filename="XTTS_v2.0_original_model_files/vocab.json",
                local_dir=local_dir
            )

            # Default reference audio for voice cloning.
            self.reference_audio = hf_hub_download(
                repo_id=repo_id,
                filename="anta_sample.wav",
                local_dir=local_dir
            )

        except Exception as e:
            self.logger.error(f"Error while downloading model files: {e}")
            raise

        # Prefer the GPU when one is available.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.model = self._load_model()
    def _load_model(self):
        """Load the XTTS model from the downloaded checkpoint."""
        try:
            self.logger.info("Loading XTTS model...")

            config = XttsConfig()
            config.load_json(self.config_path)
            model = Xtts.init_from_config(config)

            model.load_checkpoint(
                config,
                checkpoint_path=self.model_path,
                vocab_path=self.vocab_path,
                use_deepspeed=False
            )

            model.to(self.device)
            model.eval()

            self.logger.info("Model loaded successfully!")
            return model

        except Exception as e:
            self.logger.error(f"Error while loading the model: {e}")
            raise
    def generate_audio(
        self,
        text: str,
        reference_audio: Optional[str] = None,
        speed: float = 1.06,
        language: str = "wo",
        output_path: Optional[str] = None
    ) -> tuple[np.ndarray, int]:
        """
        Generate audio from the given text.

        Args:
            text (str): Text to convert to audio.
            reference_audio (str, optional): Path to the reference audio. Defaults to None.
            speed (float, optional): Playback speed. Defaults to 1.06.
            language (str, optional): Language of the text. Defaults to "wo".
            output_path (str, optional): Path where the generated audio is saved. Defaults to None.

        Returns:
            tuple[np.ndarray, int]: audio_array, sample_rate
        """
        if not text:
            raise ValueError("Text cannot be empty.")

        try:
            # Fall back to the bundled reference audio when none is given.
            ref_audio = reference_audio or self.reference_audio

            # Extract the voice-cloning conditioning from the reference clip.
            gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
                audio_path=[ref_audio],
                gpt_cond_len=self.model.config.gpt_cond_len,
                max_ref_length=self.model.config.max_ref_len,
                sound_norm_refs=self.model.config.sound_norm_refs
            )

            result = self.model.inference(
                text=text.lower(),
                gpt_cond_latent=gpt_cond_latent,
                speaker_embedding=speaker_embedding,
                do_sample=False,
                speed=speed,
                language=language,
                enable_text_splitting=True
            )

            sample_rate = self.model.config.audio.sample_rate

            if output_path:
                sf.write(output_path, result["wav"], sample_rate)
                self.logger.info(f"Audio saved to {output_path}")

            return result["wav"], sample_rate

        except Exception as e:
            self.logger.error(f"Error while generating audio: {e}")
            raise

    def generate_audio_from_config(self, text: str, config: dict, output_path: Optional[str] = None) -> tuple[np.ndarray, int]:
        """
        Generate audio from text and a configuration dictionary.

        Args:
            text (str): Text to convert to audio.
            config (dict): Configuration dictionary (speed, language, reference_audio).
            output_path (str, optional): Path where the generated audio is saved. Defaults to None.

        Returns:
            tuple[np.ndarray, int]: audio_array, sample_rate
        """
        speed = config.get('speed', 1.06)
        language = config.get('language', "wo")
        reference_audio = config.get('reference_audio', None)
        return self.generate_audio(
            text=text,
            reference_audio=reference_audio,
            speed=speed,
            language=language,
            output_path=output_path
        )


if __name__ == "__main__":
    tts = WolofXTTSInference()

    # Sample Wolof text (roughly: "My name is Aadama, the voice Galsen A.I
    # built to speak with you in Wolof!").
    text = "Màngi tuddu Aadama, di baat bii waa Galsen A.I defar ngir wax ak yéen ci wolof!"

    # Generate with the default settings.
    audio, sr = tts.generate_audio(
        text,
        output_path="generated_audio.wav"
    )

    # Generate again, driven by a configuration dictionary.
    config_gen_audio = {
        "speed": 1.2,
        "language": "wo",
    }
    audio, sr = tts.generate_audio_from_config(
        text=text,
        config=config_gen_audio,
        output_path="generated_audio_config.wav"
    )