# speecht5_tts / handler.py
# Author: Dupaja
# Commit 5f8cefd: Add handling for fractions and hopefully other number uses
import librosa
import numpy as np
import torch
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
import time
import re
import inflect
from typing import Dict, List, Any
# Stylised letters from the Unicode "Mathematical Alphanumeric Symbols" block.
# They are spelled with \U escapes so the pattern cannot be corrupted by a
# wrong file encoding (a mis-decoded literal would match unrelated characters
# such as Greek letters or curly quotes instead).
_SPECIAL_CHARACTERS_RE = re.compile(
    '['
    '\U0001D4F5'  # MATHEMATICAL BOLD SCRIPT SMALL L
    '\U0001D5BE'  # MATHEMATICAL SANS-SERIF SMALL E
    '\U0001D4DE'  # MATHEMATICAL BOLD SCRIPT CAPITAL O
    '\U0001D69F'  # MATHEMATICAL MONOSPACE SMALL V
    '\U0001D51F'  # MATHEMATICAL FRAKTUR SMALL B
    ']'
)


def contains_special_characters(s):
    """Return True if ``s`` contains any of the stylised Unicode letters above.

    Args:
        s: The string to inspect.

    Returns:
        ``True`` when at least one of the listed characters occurs anywhere
        in ``s``, otherwise ``False`` (including for the empty string).
    """
    return bool(_SPECIAL_CHARACTERS_RE.search(s))
def check_punctuation(s):
    """Return the trailing punctuation mark of ``s``, if it has one.

    Only ``.``, ``,``, ``!`` and ``?`` are recognised, and only the very last
    character is considered.

    Args:
        s: The string (typically a single word) to inspect.

    Returns:
        The final character when it is one of the recognised marks,
        otherwise the empty string.
    """
    for mark in ('.', ',', '!', '?'):
        if s.endswith(mark):
            return mark
    return ''
# A run of text that starts with digits (optionally bracketed), e.g. "[6/7]".
_MIXED_NUMBER_RE = re.compile(r'\[?\b\d+[^)\] ]*\]?')
_DIGIT_RUN_RE = re.compile(r'\d+')
_DECIMAL_RE = re.compile(r'\d+\.\d+')


def convert_numbers_to_text(input_string):
    """Spell out the numbers in ``input_string`` as English words.

    The text is split on whitespace and every token is handled on its own.
    One trailing ``.``, ``,``, ``!`` or ``?`` is set aside before the token is
    examined and re-attached afterwards, so "1995." is recognised as a number
    and punctuation is never duplicated.

    Per token:

    * tokens flagged by ``contains_special_characters`` are left untouched;
    * four-digit tokens between 1001 and 1999 are read as years in two halves
      ("nineteen eighty-four", "nineteen oh five", "nineteen hundred"); every
      other four-digit token is read as a plain cardinal;
    * other all-digit tokens, optionally with thousands commas ("1,250"),
      are read as cardinals;
    * decimals such as "3.14" are passed to inflect as a whole;
    * any remaining token that starts with digits (e.g. "[6/7]", "3/4") has
      each digit run replaced in place ("[six/seven]").

    Args:
        input_string: The text to normalise.

    Returns:
        The tokens joined with single spaces, numbers replaced by words.
        Original whitespace (including newlines) is not preserved.
    """
    p = inflect.engine()

    def cardinal(value):
        # inflect groups thousands with commas ("one thousand, two hundred");
        # strip them so they are not treated as punctuation further on.
        return p.number_to_words(int(value)).replace(',', '')

    def year_words(year):
        # Only 1001-1999 are spoken as two halves; 1000, leading-zero tokens
        # and everything from 2000 upwards read better as plain cardinals.
        if year <= 1000 or year >= 2000:
            return cardinal(year)
        first_part, second_part = divmod(year, 100)
        if second_part == 0:
            return cardinal(first_part) + " hundred"
        if second_part < 10:
            return cardinal(first_part) + " oh " + cardinal(second_part)
        return cardinal(first_part) + " " + cardinal(second_part)

    def spell_digit_runs(match):
        # Replace every digit run inside the matched span, keeping the
        # surrounding characters ("/", brackets, letters) as they are.
        return _DIGIT_RUN_RE.sub(lambda m: cardinal(m.group(0)), match.group(0))

    new_words = []
    for word in input_string.split():
        punct = check_punctuation(word)
        # Examine the token without its trailing punctuation mark.
        core = word[:-1] if punct else word

        if contains_special_characters(core):
            pass
        elif core.isdigit() and len(core) == 4:
            core = year_words(int(core))
        elif core.replace(',', '').isdigit():
            core = cardinal(core.replace(',', ''))
        elif _DECIMAL_RE.fullmatch(core):
            core = p.number_to_words(core).replace(',', '')
        else:
            core = _MIXED_NUMBER_RE.sub(spell_digit_runs, core)

        new_words.append(core + punct)
    return ' '.join(new_words)
def split_and_recombine_text(text, desired_length=200, max_length=400):
    """Split text into chunks of a desired length, trying to keep sentences intact.

    The text is first normalised: every whitespace run becomes a single space
    and curly double quotes (U+201C / U+201D) become ASCII ``"``. It is then
    scanned one character at a time. Sentence boundaries (``!``, ``?``, or
    ``.`` followed by a space or the end of the text) outside double quotes
    are recorded, and a chunk is emitted at a boundary once it has reached
    ``desired_length`` characters. When a chunk reaches ``max_length`` a split
    is forced, preferably at the last recorded boundary, otherwise at the
    nearest preceding word break.

    Args:
        text: The text to split.
        desired_length: Chunk length at which a split is made at the next
            sentence boundary.
        max_length: Chunk length at which a split is forced.

    Returns:
        A list of stripped chunks. Chunks that are empty or consist only of
        whitespace and punctuation are dropped, so the list may be empty.
    """
    # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii
    text = re.sub(r'\n\n+', '\n', text)
    text = re.sub(r'\s+', ' ', text)
    # Escapes are used instead of literal curly quotes so a mis-decoded source
    # file cannot turn this class into unrelated characters.
    text = re.sub('[\u201c\u201d]', '"', text)

    rv = []
    in_quote = False
    current = ""
    split_pos = []
    pos = -1
    end_pos = len(text) - 1

    def seek(delta):
        """Move the cursor by ``delta`` characters, keeping ``current`` and ``in_quote`` in sync."""
        nonlocal pos, in_quote, current
        is_neg = delta < 0
        for _ in range(abs(delta)):
            if is_neg:
                pos -= 1
                current = current[:-1]
            else:
                pos += 1
                current += text[pos]
            # Every quote landed on toggles the state, in either direction.
            if text[pos] == '"':
                in_quote = not in_quote
        return text[pos]

    def peek(delta):
        """Return the character ``delta`` positions away, or "" when it is out of range.

        The final character of the text is also reported as "" (the check is
        ``p < end_pos``); callers rely on ``"" in "..."`` being true.
        """
        p = pos + delta
        return text[p] if p < end_pos and p >= 0 else ""

    def commit():
        """Emit the current chunk and reset the per-chunk state."""
        nonlocal rv, current, split_pos
        rv.append(current)
        current = ""
        split_pos = []

    while pos < end_pos:
        c = seek(1)
        # do we need to force a split?
        if len(current) >= max_length:
            if len(split_pos) > 0 and len(current) > (desired_length / 2):
                # we have at least one sentence and we are over half the desired length, seek back to the last split
                d = pos - split_pos[-1]
                seek(-d)
            else:
                # no full sentences, seek back until we are not in the middle of a word and split there
                while c not in '!?.\n ' and pos > 0 and len(current) > desired_length:
                    c = seek(-1)
            commit()
        # check for sentence boundaries
        elif not in_quote and (c in '!?\n' or (c == '.' and peek(1) in '\n ')):
            # seek forward if we have consecutive boundary markers but still within the max length
            while pos < len(text) - 1 and len(current) < max_length and peek(1) in '!?.':
                c = seek(1)
            split_pos.append(pos)
            if len(current) >= desired_length:
                commit()
        # treat end of quote as a boundary if its followed by a space or newline
        elif in_quote and peek(1) == '"' and peek(2) in '\n ':
            seek(2)
            split_pos.append(pos)
    rv.append(current)

    # clean up, remove lines with only whitespace or punctuation
    rv = [s.strip() for s in rv]
    rv = [s for s in rv if len(s) > 0 and not re.match(r'^[\s\.,;:!?]*$', s)]

    return rv
class EndpointHandler:
    """Text-to-speech request handler built on SpeechT5.

    Construction loads the acoustic model, its processor, the HiFi-GAN
    vocoder and one fixed speaker embedding. Calling the instance turns the
    request text into a single waveform.
    """

    def __init__(self, path=""):
        """Load the models and the speaker embedding.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        # Alternative upstream identifiers, kept for reference:
        #   microsoft/speecht5_tts, microsoft/speecht5_hifigan,
        #   Matthijs/cmu-arctic-xvectors
        checkpoint = "Dupaja/speecht5_tts"
        vocoder_id = "Dupaja/speecht5_hifigan"
        dataset_id = "Dupaja/cmu-arctic-xvectors"
        self.model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint, low_cpu_mem_usage=True)
        self.processor = SpeechT5Processor.from_pretrained(checkpoint)
        self.vocoder = SpeechT5HifiGan.from_pretrained(vocoder_id)

        embeddings_dataset = load_dataset(dataset_id, split="validation", trust_remote_code=True)
        self.embeddings_dataset = embeddings_dataset
        # A single fixed voice: entry 7306 of the validation split, with a
        # leading batch dimension added.
        self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Synthesise speech for ``data["inputs"]``.

        The text is lightly normalised (``&`` becomes "and", hyphens become
        spaces, numbers are spelled out), split into chunks, and each chunk
        is synthesised separately; the pieces are concatenated in order.

        Args:
            data: Request payload; the text is read from the ``"inputs"`` key
                and defaults to the empty string when the key is missing.

        Returns:
            A dict with ``"statusCode"`` 200 and a ``"body"`` holding the
            concatenated ``"audio"`` array, the ``"sampling_rate"`` (16000)
            and ``"run_time_total"`` in seconds, formatted as a string. When
            the text yields no speakable chunk, ``"audio"`` is an empty
            array.
        """
        given_text = data.get("inputs", "")
        given_text = given_text.replace('&', 'and')
        given_text = given_text.replace('-', ' ')

        start_time = time.time()

        given_text = convert_numbers_to_text(given_text)
        texts = split_and_recombine_text(given_text)

        audios = []
        for t in texts:
            inputs = self.processor(text=t, return_tensors="pt")
            speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)
            audios.append(speech)

        if audios:
            final_speech = np.concatenate(audios)
        else:
            # np.concatenate raises ValueError on an empty list; empty or
            # punctuation-only input should produce silence, not an error.
            # NOTE(review): float32 is assumed to match the vocoder output
            # dtype - confirm.
            final_speech = np.zeros(0, dtype=np.float32)

        run_time_total = time.time() - start_time

        # Return the expected response format
        return {
            "statusCode": 200,
            "body": {
                "audio": final_speech,  # Consider encoding this to a suitable format
                "sampling_rate": 16000,
                "run_time_total": str(run_time_total),
            }
        }
# Build the handler once at import time; EndpointHandler.__init__ loads the
# model, processor, vocoder and speaker-embedding dataset.
handler = EndpointHandler()