# Scraped page header (not part of the program):
# DrishtiSharma's picture
# Update app.py
# 11b23f3
# raw
# history blame
# 5.07 kB
import tempfile

import gradio as gr
import librosa
import torch
from transformers import AutoFeatureExtractor, AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
def load_and_fix_data(input_file, model_sampling_rate):
    """Load an audio file as a single-channel waveform at the model's rate.

    Args:
        input_file: Audio source handed straight to ``librosa.load``.
        model_sampling_rate: Target sampling rate the waveform is resampled to.

    Returns:
        The waveform from ``librosa.load``, down-mixed to one channel if it
        came back multi-channel and resampled when its rate differs from
        ``model_sampling_rate``.
    """
    speech, sample_rate = librosa.load(input_file)

    # Safety net only: librosa.load down-mixes by default. librosa lays
    # multi-channel audio out as (channels, samples), so use to_mono rather
    # than indexing columns 0 and 1.
    if speech.ndim > 1:
        speech = librosa.to_mono(speech)

    if sample_rate != model_sampling_rate:
        # orig_sr / target_sr must be passed by keyword: they are
        # keyword-only in librosa >= 0.10 and accepted by keyword before that.
        speech = librosa.resample(
            speech, orig_sr=sample_rate, target_sr=model_sampling_rate
        )
    return speech
# --- Speech recognition -----------------------------------------------------
# One checkpoint id feeds both the feature extractor and the ASR pipeline.
_ASR_CHECKPOINT = "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
feature_extractor = AutoFeatureExtractor.from_pretrained(_ASR_CHECKPOINT)
# Rate that recorded audio is resampled to before transcription.
sampling_rate = feature_extractor.sampling_rate
asr = pipeline("automatic-speech-recognition", model=_ASR_CHECKPOINT)

# --- Text rewriting (seq2seq) -----------------------------------------------
# Prepended to the transcription before tokenization; currently empty.
prefix = ''
model_checkpoint = "hackathon-pln-es/es_text_neutralizer"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
def postproc(input_sentence, preds):
    """Clean up generated text and restore capitalized words from the input.

    Steps, in order: contract "de el" to "del", collapse double spaces,
    capitalize the sentence if it starts in lowercase, tighten spacing before
    periods/commas, then re-capitalize words that were capitalized in
    ``input_sentence`` (other than its first word) and appear lowercased in
    ``preds``.

    Args:
        input_sentence: The original (pre-rewrite) sentence.
        preds: The generated text to post-process.

    Returns:
        The cleaned-up text. Post-processing is best-effort: if an argument is
        not a string, ``preds`` is returned as far as it had been processed.
    """
    try:
        preds = preds.replace('De el', 'Del').replace('de el', 'del').replace('  ', ' ')
        # Guard against empty text before looking at the first character.
        # NOTE: str.capitalize() also lowercases the rest of the string; the
        # loop below restores capitals that were present in the input.
        if preds and preds[0].islower():
            preds = preds.capitalize()
        preds = preds.replace(' . ', '. ').replace(' , ', ', ')

        # Restore capitalized words (e.g. names) from the input sentence.
        words = input_sentence.split(' ')
        first_word = words[0]
        prev_letter = ''
        for word in words:
            if not word:
                continue
            lowered = word.lower()
            if word[0].isupper() and word != first_word and lowered in preds:
                if prev_letter == '.':
                    # Word opens a new sentence: only fix it right after ". ".
                    preds = preds.replace('. ' + lowered + ' ', '. ' + word + ' ')
                elif word[-1] == '.':
                    # Word already carries its trailing period.
                    preds = preds.replace(lowered, word)
                else:
                    preds = preds.replace(lowered + ' ', word + ' ')
            prev_letter = word[-1]
        preds = preds.strip()  # drop trailing whitespace
    except (AttributeError, IndexError, TypeError):
        # Best-effort: non-string input leaves preds as it was.
        pass
    return preds
# TTS model identifier; predict_and_ctc_lm_decode downloads it as
# f"tts_models/{model_name}".
model_name = "es/mai/tacotron2-DDC"
def predict_and_ctc_lm_decode(input_file, speaker_idx: str=None):
    """Transcribe audio, rewrite the text with the seq2seq model, and speak it.

    Pipeline: ``load_and_fix_data`` -> ``asr`` transcription ->
    ``model.generate`` -> ``postproc`` -> TTS synthesis into a ``.wav`` file.

    Args:
        input_file: Audio file handed to ``load_and_fix_data``.
        speaker_idx: Optional speaker identifier forwarded to
            ``synthesizer.tts``.

    Returns:
        Path of the generated ``.wav`` file. It is created with
        ``delete=False``, so it is not removed automatically.
    """
    # NOTE(review): `first_generation`, `MAX_TXT_LEN`, `manager` and
    # `Synthesizer` are used below but are not defined or imported in the
    # visible file. They must exist at module level before this can run
    # (`manager` / `Synthesizer` presumably come from Coqui TTS - confirm).
    speech = load_and_fix_data(input_file, sampling_rate)
    transcribed_text = asr(speech, chunk_length_s=5, stride_length_s=1)["text"]

    inputs = tokenizer([prefix + transcribed_text], return_tensors="pt", padding=True)
    # Send the inputs to wherever the model lives rather than depending on a
    # separate global device name.
    target_device = model.device
    generate_kwargs = {
        "input_ids": inputs["input_ids"].to(target_device),
        "attention_mask": inputs["attention_mask"].to(target_device),
        "do_sample": False,  # deterministic decoding in both modes
    }
    if not first_generation:
        # Later generations use a small beam search with a repetition penalty.
        generate_kwargs.update(
            num_beams=2,
            repetition_penalty=2.5,
            early_stopping=True,
        )
    with torch.no_grad():
        output_sequence = model.generate(**generate_kwargs)

    decoded = tokenizer.decode(
        output_sequence[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    preds = postproc(transcribed_text, preds=decoded)

    # Always bind `text`, and synthesize the (possibly truncated) text so the
    # length limit actually applies to what is spoken.
    text = preds
    if len(text) > MAX_TXT_LEN:
        text = text[:MAX_TXT_LEN]
        print(f"Input text was cutoff since it went over the {MAX_TXT_LEN} character limit.")
    print(text, model_name)

    # Download the TTS model and, if it names one, its default vocoder.
    model_path, config_path, model_item = manager.download_model(f"tts_models/{model_name}")
    vocoder_name = model_item["default_vocoder"]
    vocoder_path = None
    vocoder_config_path = None
    if vocoder_name is not None:
        vocoder_path, vocoder_config_path, _ = manager.download_model(vocoder_name)

    synthesizer = Synthesizer(
        model_path, config_path, None, None, vocoder_path, vocoder_config_path,
    )
    wavs = synthesizer.tts(text, speaker_idx)

    # delete=False keeps the file on disk so the caller can read it by path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        synthesizer.save_wav(wavs, fp)
        return fp.name
# Assemble the demo piece by piece, then launch it.
mic_input = gr.inputs.Audio(source="microphone", type="filepath", label="Record your audio")
audio_output = gr.outputs.Audio(label="Output")
demo_description = (
    "This is a Gradio demo for generating gender neutralized audios. "
    "To use it, simply provide an audio input (via microphone or audio recording), "
    "which will then be transcribed and gender-neutralized using a pre-trained models. "
    "Finally, with the help of Coqui's TTS model, gender neutralised audio is generated."
)

demo = gr.Interface(
    predict_and_ctc_lm_decode,
    inputs=[mic_input],
    outputs=audio_output,
    examples=[["audio1.wav"], ["travel.wav"]],
    title="Generate-Gender-Neutralized-Audios",
    description=demo_description,
    #article="<p><center><img src='........e'></center></p>",
    layout="horizontal",
    theme="huggingface",
)
demo.launch(enable_queue=True, cache_examples=True)