|
"""Some utility functions for the app.""" |
|
import os
import sys
import time
from base64 import b64encode
from contextlib import closing
from functools import wraps
from io import BytesIO

import boto3
from botocore.config import Config
from gtts import gTTS
from mtranslate import translate
from speech_recognition import AudioFile, Recognizer
from transformers import (BlenderbotSmallForConditionalGeneration,
                          BlenderbotSmallTokenizer)
|
|
|
|
|
def log_execution_time(func):
    """Decorator that prints how long each call to *func* takes.

    Args:
        func (callable): function to be timed

    Returns:
        callable: wrapper that forwards all arguments to *func*,
        prints the elapsed wall-clock time, and returns *func*'s result
    """
    @wraps(func)  # preserve func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measurement cannot go
        # negative if the system clock is adjusted mid-call.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        execution_time = time.perf_counter() - start_time
        print(f"Execution time of {func.__name__}: {execution_time} seconds")
        return result
    return wrapper
|
|
|
def stt(audio: object, language: str) -> str:
    """Converts speech to text via Google's speech recognizer.

    Args:
        audio: record of user speech (path or file-like object)
        language (str): language code forwarded to the recognizer

    Returns:
        str: recognized speech of user
    """
    recognizer = Recognizer()

    with AudioFile(audio) as source:
        recorded_audio = recognizer.record(source)
        return recognizer.recognize_google(recorded_audio, language=language)
|
|
|
|
|
def to_en_translation(text: str, language: str) -> str:
    """Translates text from the specified language into English.

    Args:
        text (str): input text
        language (str): source language of *text*

    Returns:
        str: translated text
    """
    translated = translate(text, to_language="en", from_language=language)
    return translated
|
|
|
|
|
def from_en_translation(text: str, language: str) -> str:
    """Translates English text into the specified language.

    Args:
        text (str): input text (assumed English)
        language (str): target language

    Returns:
        str: translated text
    """
    translated = translate(text, to_language=language, from_language="en")
    return translated
|
|
|
|
|
class TextGenerationPipeline:
    """Minimal text-generation pipeline around the BlenderBot small model.

    Returns:
        str: generated text
    """

    # Loaded once at class-definition time and shared by all instances.
    model_name = "facebook/blenderbot_small-90M"
    tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_name)
    model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_name)

    def __init__(self, **kwargs):
        """Specifying text generation parameters.

        For example: max_length=100 which generates text shorter than
        100 tokens. Visit:
        https://huggingface.co/docs/transformers/main_classes/text_generation
        for more parameters
        """
        # Every kwarg becomes an instance attribute; __call__ later
        # forwards the whole __dict__ verbatim to model.generate.
        self.__dict__.update(kwargs)

    def preprocess(self, text) -> str:
        """Tokenizes input text.

        Args:
            text (str): user specified text

        Returns:
            torch.Tensor (obj): text representation as tensors
        """
        encoded = self.tokenizer(text, return_tensors="pt")
        return encoded

    def postprocess(self, outputs) -> str:
        """Converts generated tensors back into text.

        Args:
            outputs (torch.Tensor obj): model text generation output

        Returns:
            str: generated text
        """
        first_sequence = outputs[0]
        return self.tokenizer.decode(first_sequence, skip_special_tokens=True)

    def __call__(self, text: str) -> str:
        """Generates a reply for the input text.

        Args:
            text (str): user specified text

        Returns:
            str: generated text
        """
        token_batch = self.preprocess(text)
        generated = self.model.generate(**token_batch, **self.__dict__)
        return self.postprocess(generated)
|
|
|
|
|
def tts(text: str, language: str) -> object:
    """Converts text into a gTTS audio object.

    Args:
        text (str): generated answer of bot
        language (str): language for the synthesized speech

    Returns:
        object: text to speech object
    """
    speech = gTTS(text=text, lang=language, slow=False)
    return speech
|
|
|
|
|
def tts_polly(text: str, language: str, voice_id: str = 'Camila') -> object:
    """Synthesizes speech for *text* with AWS Polly's neural engine.

    Args:
        text (str): text to be spoken
        language (str): Polly language code (e.g. "pt-BR")
        voice_id (str): Polly voice to use; defaults to "Camila" to stay
            backward compatible with the previous hard-coded voice

    Returns:
        object: the Polly ``synthesize_speech`` response dict; the MP3
        audio is available under its "AudioStream" key
    """
    # Region comes from the environment, falling back to us-east-1.
    my_config = Config(
        region_name=os.getenv('AWS_REGION', 'us-east-1'),
    )

    client = boto3.client('polly', config=my_config)
    # NOTE(review): not every Polly voice supports every language/engine
    # combination — verify voice_id against the target LanguageCode.
    response = client.synthesize_speech(
        Engine='neural',
        OutputFormat='mp3',
        VoiceId=voice_id,
        LanguageCode=language,
        Text=text)

    return response
|
|
|
|
|
def tts_polly_to_bytesio(polly_object: object) -> bytes:
    """Extracts the raw audio bytes from a Polly synthesize_speech response.

    Args:
        polly_object (object): response dict returned by ``tts_polly``;
            the audio is expected under its "AudioStream" key

    Returns:
        bytes: the audio payload

    Exits the process with status -1 when the response carries no audio
    stream or the stream cannot be read (preserves the original
    best-effort behaviour instead of raising).
    """
    # Guard clause: bail out early when Polly returned no audio.
    if "AudioStream" not in polly_object:
        print("Could not stream audio")
        sys.exit(-1)

    # The stream is consumable only once; close it after reading.
    with closing(polly_object["AudioStream"]) as stream:
        try:
            # read() already yields bytes — the previous BytesIO
            # write/seek/getvalue round-trip was a no-op.
            return stream.read()
        except IOError as error:
            print(error)
            sys.exit(-1)
|
|
|
|
|
def tts_to_bytesio(tts_object: object) -> bytes:
    """Renders a gTTS object into raw audio bytes.

    Args:
        tts_object (object): audio object obtained from gtts

    Returns:
        bytes: audio bytes
    """
    buffer = BytesIO()
    tts_object.write_to_fp(buffer)
    # getvalue() returns the full contents regardless of stream position.
    return buffer.getvalue()
|
|
|
|
|
def html_audio_autoplay(bytes: bytes) -> object:
    """Creates an HTML snippet that autoplays audio in the gradio app.

    Args:
        bytes (bytes): audio bytes (parameter name shadows the builtin
            but is kept for backward compatibility with keyword callers)

    Returns:
        object: html object that provides audio autoplaying
    """
    encoded = b64encode(bytes).decode()
    html = f"""
    <audio controls autoplay>
    <source src="data:audio/wav;base64,{encoded}" type="audio/wav">
    </audio>
    """
    return html
|
|