import re
from string import punctuation
from typing import Any, Dict, List

import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer

device = "cuda" if torch.cuda.is_available() else "cpu"

class EndpointHandler:
    def __init__(self, path=""):
        # Load the tokenizer and model from the repository path. Passing
        # torch_dtype=torch.float16 to from_pretrained would halve GPU memory use.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(path).to(device)
        # Spells out digits ("3" -> "three") before tokenization.
        self.number_normalizer = EnglishNumberNormalizer()
    def preprocess_text(self, text):
        """Apply the same text preprocessing as the Parler-TTS Gradio app."""
        text = self.number_normalizer(text).strip()
        text = text.replace("-", " ")
        # Guard against empty input, then make sure the prompt ends with
        # punctuation, which helps generation terminate cleanly.
        if text and text[-1] not in punctuation:
            text = f"{text}."

        # Space out all-caps abbreviations ("NASA" -> "N A S A") so they are
        # spoken letter by letter.
        abbreviations_pattern = r"\b[A-Z][A-Z\.]+\b"
        abbreviations = re.findall(abbreviations_pattern, text)
        for abv in abbreviations:
            if abv in text:
                text = text.replace(abv, " ".join(abv.replace(".", "")))
        return text
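
    # Illustrative example of the steps above (assumed output, not executed):
    #   preprocess_text("NASA spent 3 billion dollars")
    #   -> "N A S A spent three billion dollars."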
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (:obj:`dict`):
                The payload with the text prompt, an optional voice
                description, and optional generation parameters.
        """
        # Parse the request payload.
        inputs = data.pop("inputs", data)
        # Fall back to a generic description when none is supplied (the
        # fallback text below is an assumed default, not prescribed anywhere).
        voice_description = data.pop(
            "voice_description",
            "A female speaker delivers her words at a moderate pace.",
        )
        parameters = data.pop("parameters", None)

        gen_kwargs = {"min_new_tokens": 10}
        if parameters is not None:
            gen_kwargs.update(parameters)

        processed_text = self.preprocess_text(inputs)
        # Tokenize the transcript (the decoder prompt) and the voice
        # description (the encoder input).
        inputs = self.tokenizer(
            text=[processed_text],
            padding=True,
            return_tensors="pt",
        ).to(device)
        voice_description = self.tokenizer(
            text=[voice_description],
            padding=True,
            return_tensors="pt",
        ).to(device)
        # Generate the waveform: the description conditions the model and the
        # transcript is passed as the prompt. autocast enables mixed-precision
        # inference on GPU.
        with torch.autocast(device):
            outputs = self.model.generate(
                **voice_description,
                prompt_input_ids=inputs.input_ids,
                prompt_attention_mask=inputs.attention_mask,
                **gen_kwargs,
            )
        # Postprocess: move the waveform to CPU and return plain Python floats
        # so the response serializes to JSON.
        prediction = outputs[0].cpu().numpy().tolist()
        return [{"generated_audio": prediction}]