File size: 2,743 Bytes
ec3f286
 
479ac18
 
ec3f286
479ac18
 
 
 
ec3f286
 
 
 
 
479ac18
 
01f35e8
ec3f286
01f35e8
 
 
 
 
 
 
 
 
 
 
 
 
479ac18
ec3f286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479ac18
 
ec3f286
 
479ac18
ec3f286
479ac18
 
ec3f286
 
 
479ac18
 
ec3f286
 
479ac18
 
 
966be4c
01f35e8
479ac18
ec3f286
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import Dict, List, Any
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor
from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer
import torch
import re
from string import punctuation

# Pick the execution device once at import time: use the GPU when CUDA is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class EndpointHandler:
    """Custom inference handler that synthesises speech with a Parler-TTS model.

    The handler loads a tokenizer and a ``ParlerTTSForConditionalGeneration``
    checkpoint from ``path`` and, when called with a request payload, returns
    the generated audio as a nested list of numbers.
    """

    # Runs of capital letters, optionally dot-separated (e.g. "NASA", "U.S.A").
    _ABBREVIATION_RE = re.compile(r'\b[A-Z][A-Z\.]+\b')

    def __init__(self, path=""):
        """Load the tokenizer, the model and the number normalizer.

        Args:
            path: Directory or hub identifier passed to ``from_pretrained``
                for both the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(path).to(device)
        self.number_normalizer = EnglishNumberNormalizer()

    @staticmethod
    def _spell_out(match):
        """Return the matched abbreviation as space-separated letters, dots removed."""
        return " ".join(match.group(0).replace(".", ""))

    def preprocess_text(self, text):
        """Normalise a prompt before tokenization.

        Per the original author's note, this mirrors the preprocessing of the
        Gradio demo app. Steps, in order:

        1. run the number normalizer and strip surrounding whitespace;
        2. replace hyphens with spaces;
        3. append a full stop when the text does not end in punctuation;
        4. spell out upper-case abbreviations letter by letter
           (``"NASA"`` -> ``"N A S A"``).

        Args:
            text: Raw prompt text.

        Returns:
            The normalised text. Empty or whitespace-only input yields ``""``
            instead of raising ``IndexError``.
        """
        text = self.number_normalizer(text).strip()
        text = text.replace("-", " ")
        # Guard against empty text: there is no last character to inspect.
        if text and text[-1] not in punctuation:
            text = f"{text}."

        # Expand every match in a single pass. A global str.replace per
        # abbreviation would also rewrite that abbreviation wherever it occurs
        # inside a longer one ("US" inside "BUS"), corrupting the longer one.
        return self._ABBREVIATION_RE.sub(self._spell_out, text)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for one request payload.

        Args:
            data: Request payload. Recognised keys (all are removed from
                ``data`` as they are read):

                * ``"inputs"``: the text to speak (required, non-empty str).
                * ``"voice_description"``: text describing the desired voice.
                * ``"parameters"``: optional dict of extra keyword arguments
                  forwarded to ``model.generate``; they override the default
                  ``min_new_tokens=10``.

        Returns:
            A one-element list ``[{"generated_audio": samples}]`` where
            ``samples`` is the first generated output converted to a
            (possibly nested) Python list.

        Raises:
            ValueError: If ``"inputs"`` is missing, not a string, or blank.
        """
        inputs = data.pop("inputs", data)
        # NOTE(review): the fallback description "data" looks like a
        # placeholder; kept unchanged for backward compatibility.
        voice_description = data.pop("voice_description", "data")
        parameters = data.pop("parameters", None)

        # Fail early with a clear message rather than deep inside text
        # normalisation when the payload has no usable prompt.
        if not isinstance(inputs, str) or not inputs.strip():
            raise ValueError('"inputs" must be a non-empty string')

        gen_kwargs = {"min_new_tokens": 10}
        if parameters is not None:
            gen_kwargs.update(parameters)

        processed_text = self.preprocess_text(inputs)

        # Tokenize the prompt and the voice description separately; the
        # description feeds the main inputs, the prompt goes in as prompt_*.
        prompt = self.tokenizer(
            text=[processed_text],
            padding=True,
            return_tensors="pt",
        ).to(device)
        description = self.tokenizer(
            text=[voice_description],
            padding=True,
            return_tensors="pt",
        ).to(device)

        with torch.autocast(device):
            outputs = self.model.generate(
                **description,
                prompt_input_ids=prompt.input_ids,
                prompt_attention_mask=prompt.attention_mask,
                **gen_kwargs,
            )

        # Cast to float32 before leaving torch: autocast may yield a
        # reduced-precision tensor (bfloat16 on CPU), which NumPy cannot
        # represent.
        prediction = outputs[0].float().cpu().numpy().tolist()

        return [{"generated_audio": prediction}]