|
import librosa |
|
import numpy as np |
|
import torch |
|
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan |
|
from datasets import load_dataset |
|
import time |
|
import re |
|
import inflect |
|
from typing import Dict, List, Any |
|
|
|
def contains_special_characters(s):
    """Return True if *s* contains any character from the special set."""
    # NOTE(review): this character class looks mojibake-garbled — confirm
    # the intended characters against the original normalization spec.
    match = re.search(r'[π΅πΎπππ]', s)
    return match is not None
|
|
|
def check_punctuation(s):
    """Return the trailing mark of *s* if it is '.', ',', '!' or '?', else ''."""
    for mark in ('.', ',', '!', '?'):
        if s.endswith(mark):
            return mark
    return ''
|
|
|
def convert_numbers_to_text(input_string):
    """Spell out digits in *input_string* as English words.

    Two passes: (1) tokens where digits are glued to other characters
    (e.g. "3rd", "[12]") get every digit run expanded in place; (2) pure
    numeric words are converted whole, with 4-digit numbers below 2000
    read as years ("1984" -> "nineteen eighty-four").
    """
    p = inflect.engine()
    new_string = input_string

    # Pass 1: digit runs mixed with other characters. Pure-digit tokens
    # are skipped here so the year/number logic in pass 2 handles them.
    mixed_patterns = re.findall(r'\[?\b\d+[^)\] ]*\]?', new_string)
    for pattern in mixed_patterns:
        core = pattern.strip('[]').rstrip('.,!?')
        if core.replace(',', '').isdigit():
            continue
        # BUGFIX: the original re.sub call had the pattern/replacement
        # arguments swapped, and re-substituted into the unmodified token
        # each iteration so multi-number tokens lost earlier replacements.
        converted = pattern
        for number in re.findall(r'\d+', pattern):
            converted = re.sub(number, p.number_to_words(number), converted, count=1)
        new_string = new_string.replace(pattern, converted)

    # Pass 2: word-by-word handling of pure numbers and years.
    new_words = []
    for word in new_string.split():
        punct = check_punctuation(word)
        if punct:
            # BUGFIX: strip the detected punctuation before the digit
            # checks; the original left it attached, so "1999." was never
            # converted and the mark was emitted twice ("1999..").
            word = word[:-1]

        if contains_special_characters(word):
            pass  # leave tokens with special marker characters untouched
        elif word.isdigit() and len(word) == 4:
            year = int(word)
            if year < 2000:
                # Read as a year: split into century and remainder.
                # NOTE(review): round centuries come out as e.g.
                # "nineteen zero" for 1900 — confirm desired rendering.
                first_part = year // 100
                second_part = year % 100
                word = p.number_to_words(first_part) + " " + p.number_to_words(second_part)
            else:
                # BUGFIX: the original `elif year < 9999` left 9999 untouched.
                word = p.number_to_words(year)
        elif word.replace(',', '').isdigit():
            # Comma-grouped numbers, e.g. "1,000".
            word = word.replace(',', '')
            word = p.number_to_words(int(word)).replace(',', '')

        new_words.append(word + punct)

    return ' '.join(new_words)
|
|
|
def split_and_recombine_text(text, desired_length=200, max_length=400):
    """Split text it into chunks of a desired length trying to keep sentences intact.

    Scans the text one character at a time, remembering candidate split
    positions (sentence-ending punctuation outside quotes), and commits a
    chunk once it reaches ``desired_length``; a chunk is force-split at
    ``max_length`` even mid-sentence.
    """
    # Collapse blank lines and whitespace runs; normalize quote characters.
    # NOTE(review): the character class below looks mojibake-garbled — it was
    # presumably the curly double quotes; confirm against the original file.
    text = re.sub(r'\n\n+', '\n', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[ββ]', '"', text)

    rv = []                 # completed chunks
    in_quote = False        # currently inside a double-quoted span?
    current = ""            # chunk being accumulated
    split_pos = []          # candidate split offsets seen so far
    pos = -1                # index of the last consumed character
    end_pos = len(text) - 1

    def seek(delta):
        # Move the cursor |delta| characters (backwards when negative),
        # keeping `current` and the quote state in sync; returns the
        # character now under the cursor.
        nonlocal pos, in_quote, current
        is_neg = delta < 0
        for _ in range(abs(delta)):
            if is_neg:
                pos -= 1
                current = current[:-1]
            else:
                pos += 1
                current += text[pos]
            if text[pos] == '"':
                in_quote = not in_quote
        return text[pos]

    def peek(delta):
        # Look `delta` characters ahead/behind without consuming.
        p = pos + delta
        return text[p] if p < end_pos and p >= 0 else ""

    def commit():
        # Flush the accumulated chunk and reset scan state.
        nonlocal rv, current, split_pos
        rv.append(current)
        current = ""
        split_pos = []

    while pos < end_pos:
        c = seek(1)
        # Force a split once the hard limit is reached.
        if len(current) >= max_length:
            if len(split_pos) > 0 and len(current) > (desired_length / 2):
                # Back up to the most recent good split point.
                d = pos - split_pos[-1]
                seek(-d)
            else:
                # No usable split point: walk back to whitespace/punctuation.
                while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
                    c = seek(-1)
            commit()
        # Sentence-ending punctuation outside a quote is a split candidate.
        elif not in_quote and (c in '!?\n' or (c == '.' and peek(1) in '\n ')):
            # Consume any run of trailing punctuation ("?!", "...").
            while pos < len(text) - 1 and len(current) < max_length and peek(1) in '!?.':
                c = seek(1)
            split_pos.append(pos)
            if len(current) >= desired_length:
                commit()
        # A closing quote followed by whitespace also marks a candidate.
        elif in_quote and peek(1) == '"' and peek(2) in '\n ':
            seek(2)
            split_pos.append(pos)
    rv.append(current)

    # Clean up: strip whitespace, drop empty or punctuation-only chunks.
    rv = [s.strip() for s in rv]
    rv = [s for s in rv if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]

    return rv
|
|
|
class EndpointHandler:
    """Inference-endpoint handler: SpeechT5 text-to-speech with HiFi-GAN vocoder."""

    def __init__(self, path=""):
        # Checkpoints for the acoustic model/processor, vocoder, and the
        # speaker x-vector dataset.
        checkpoint = "Dupaja/speecht5_tts"
        vocoder_id = "Dupaja/speecht5_hifigan"
        dataset_id = "Dupaja/cmu-arctic-xvectors"

        self.model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint, low_cpu_mem_usage=True)
        self.processor = SpeechT5Processor.from_pretrained(checkpoint)
        self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_id)

        # Speaker embeddings; index 7306 selects one fixed voice from the
        # CMU ARCTIC x-vector set.
        embeddings_dataset = load_dataset(dataset_id, split="validation", trust_remote_code=True)
        self.embeddings_dataset = embeddings_dataset
        self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Synthesize speech for ``data["inputs"]``.

        Returns a dict with the waveform, its sampling rate, and the
        wall-clock generation time.
        """
        given_text = data.get("inputs", "")
        # Pre-clean characters the tokenizer handles poorly.
        given_text = given_text.replace('&', 'and')
        given_text = given_text.replace('-', ' ')

        start_time = time.time()

        given_text = convert_numbers_to_text(given_text)
        # Chunk long input so each generate_speech call stays short.
        texts = split_and_recombine_text(given_text)

        audios = []
        for t in texts:
            inputs = self.processor(text=t, return_tensors="pt")
            speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)
            audios.append(speech)

        # BUGFIX: np.concatenate raises ValueError on an empty list, so
        # blank/whitespace-only input used to crash; return silence instead.
        if audios:
            final_speech = np.concatenate(audios)
        else:
            final_speech = np.zeros(0, dtype=np.float32)

        run_time_total = time.time() - start_time

        return {
            "statusCode": 200,
            "body": {
                "audio": final_speech,
                "sampling_rate": 16000,  # SpeechT5 HiFi-GAN output rate
                "run_time_total": str(run_time_total),
            }
        }
|
|
|
# Module-level instance created at import time (loads all models once).
# NOTE(review): presumably the serving toolkit looks up this `handler`
# object — confirm the expected entry-point name for the framework in use.
handler = EndpointHandler()