from typing import Dict, List, Any
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor
from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer
import torch
import re
from string import punctuation

device = "cuda" if torch.cuda.is_available() else "cpu"

# Speaker description used when the request payload omits "voice_description".
# (Previously the fallback was the literal string "data", which is not a
# meaningful description for the model to condition on.)
DEFAULT_VOICE_DESCRIPTION = (
    "A female speaker delivers a slightly expressive and animated speech "
    "with a moderate speed and pitch. The recording is of very high quality, "
    "with the speaker's voice sounding clear and very close up."
)

# An upper-case letter followed by upper-case letters/dots, bounded by word
# boundaries, e.g. "NASA" or "U.S"; these get spelled out letter by letter.
_ABBREVIATION_RE = re.compile(r"\b[A-Z][A-Z\.]+\b")


class EndpointHandler:
    """Custom inference handler that synthesizes speech with Parler-TTS.

    The tokenizer and model are loaded once in ``__init__``. Each call
    preprocesses the prompt text, tokenizes it together with a voice
    description, and runs ``model.generate``.
    """

    def __init__(self, path=""):
        """Load the tokenizer and Parler-TTS model.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(path).to(device)
        # Applied to the prompt text in preprocess_text.
        self.number_normalizer = EnglishNumberNormalizer()

    def preprocess_text(self, text):
        """Normalize prompt text the same way as the Parler-TTS Gradio demo.

        Steps: run the number normalizer, replace hyphens with spaces, strip
        surrounding whitespace, make sure the text ends with punctuation
        (appending "." otherwise), and spell out upper-case abbreviations
        letter by letter ("NASA" -> "N A S A", "U.S" -> "U S").

        Args:
            text: Raw prompt text.

        Returns:
            The normalized text, or an empty string if nothing remains.
        """
        text = self.number_normalizer(text).strip()
        text = text.replace("-", " ").strip()
        if not text:
            # Guard: indexing text[-1] below would raise IndexError.
            return text
        if text[-1] not in punctuation:
            text = f"{text}."
        # Substitute each match in place. A global str.replace per match would
        # also rewrite the same letters inside other words or inside a longer
        # abbreviation (e.g. "AB" inside "ABC" or "ABout").
        return _ABBREVIATION_RE.sub(
            lambda match: " ".join(match.group(0).replace(".", "")), text
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Synthesize speech for the text in the request payload.

        Args:
            data: Request payload. Keys read (and popped from ``data``):
                ``"inputs"`` (str, required): the text to speak.
                ``"voice_description"`` (str, optional): natural-language
                    description of the voice; defaults to
                    ``DEFAULT_VOICE_DESCRIPTION``.
                ``"parameters"`` (dict, optional): extra keyword arguments
                    forwarded to ``model.generate``; they override the
                    default ``min_new_tokens=10``.

        Returns:
            A one-element list ``[{"generated_audio": [...]}]`` holding the
            first generated sequence as a list of Python floats.
            NOTE(review): presumably raw waveform samples at
            ``model.config.sampling_rate``; the rate is not returned here.

        Raises:
            ValueError: if ``"inputs"`` is missing, not a string, or empty
                after preprocessing.
        """
        inputs = data.pop("inputs", None)
        if not isinstance(inputs, str):
            raise ValueError(
                'Payload must contain an "inputs" string with the text to speak.'
            )
        voice_description = data.pop("voice_description", DEFAULT_VOICE_DESCRIPTION)
        parameters = data.pop("parameters", None)

        gen_kwargs = {"min_new_tokens": 10}
        if parameters is not None:
            gen_kwargs.update(parameters)

        processed_text = self.preprocess_text(inputs)
        if not processed_text:
            raise ValueError('"inputs" is empty after preprocessing.')

        # The prompt (text to speak) and the description (how to speak it) are
        # tokenized separately; generate() takes the description as its main
        # inputs and the prompt via the prompt_* keyword arguments.
        prompt = self.tokenizer(
            text=[processed_text],
            padding=True,
            return_tensors="pt",
        ).to(device)
        description = self.tokenizer(
            text=[voice_description],
            padding=True,
            return_tensors="pt",
        ).to(device)

        with torch.autocast(device):
            outputs = self.model.generate(
                **description,
                prompt_input_ids=prompt.input_ids,
                prompt_attention_mask=prompt.attention_mask,
                **gen_kwargs,
            )

        # Cast to float32 before leaving torch: under CPU autocast the output
        # may be bfloat16, which Tensor.numpy() cannot convert. The cast is
        # exact for float16/float32 outputs.
        prediction = outputs[0].to(torch.float32).cpu().tolist()
        return [{"generated_audio": prediction}]