RasmusToivanen
add article, change to gradio 3, remove 300m model
af31d45
raw
history blame
3.77 kB
import gradio as gr
import librosa
import soundfile as sf
import torch
import warnings
import os
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2CTCTokenizer
warnings.filterwarnings("ignore")
#load wav2vec2 tokenizer and model
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from fastapi import FastAPI, HTTPException, File
from transformers import pipeline
# Select the accelerator once; it is shared by every model loaded below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# transformers' pipeline() takes an integer device index (-1 = CPU, 0 = first
# GPU). Without it the ASR pipelines would stay on CPU even when a GPU is
# available, unlike the T5 model below, which is moved to `device`.
pipeline_device = 0 if device.type == "cuda" else -1

# Speech recognition pipelines. chunk_length_s / stride_length_s make the
# pipeline transcribe long audio in 20 s chunks with a 3 s left/right stride,
# so clips longer than a single chunk can still be processed.
pipe_95m = pipeline(
    model="Finnish-NLP/wav2vec2-base-fi-voxpopuli-v2-finetuned",
    chunk_length_s=20,
    stride_length_s=(3, 3),
    device=pipeline_device,
)
pipe_1b = pipeline(
    model="Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2",
    chunk_length_s=20,
    stride_length_s=(3, 3),
    device=pipeline_device,
)

# Seq2seq model that restores casing and punctuation of the raw ASR output.
model_checkpoint = 'Finnish-NLP/t5-small-nl24-casing-punctuation-correction'
# Optional Hugging Face access token from the `hf_token` environment variable
# (None when the variable is unset); read once and reused for both downloads.
hf_token = os.environ.get('hf_token')
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_auth_token=hf_token)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_checkpoint,
    from_flax=False,
    torch_dtype=torch.float32,
    use_auth_token=hf_token,
).to(device)
# define speech-to-text function
def asr_transcript(audio, audio_microphone, model_params):
    """Transcribe Finnish speech, then restore its casing and punctuation.

    Args:
        audio: Audio from the "Upload Audio File" input, or None. Only its
            ``.name`` attribute (a file path) is used.
        audio_microphone: Audio from the microphone input, or None. When it is
            truthy it takes precedence over ``audio``.
        model_params: Dropdown value selecting the ASR model: "1 billion" or
            "95 million".

    Returns:
        A 2-tuple ``(raw_text, corrected_text)`` matching the interface's two
        output boxes: the ASR output and the same text after the
        casing/punctuation model. On invalid input both elements carry the
        same explanatory message, so Gradio always receives two values.
    """
    # A microphone recording wins over an uploaded file.
    audio = audio_microphone if audio_microphone else audio

    if audio is None:
        message = ("Please provide audio by uploading a file or by recording audio "
                   "using microphone by pressing Record (And allow usage of microphone)")
        return message, message
    if not audio:
        # Something was supplied but it is empty/falsy. Return one value per
        # output box (the original returned a single string here).
        return "File not valid", "File not valid"

    if model_params == "1 billion":
        asr_pipeline = pipe_1b
    elif model_params == "95 million":
        asr_pipeline = pipe_95m
    else:
        # Previously this fell through and crashed on `""['text']`.
        message = f"Unknown speech recognition model: {model_params!r}"
        return message, message

    text = asr_pipeline(audio.name)["text"]
    if not text.strip():
        # Nothing was recognised; there is nothing for the corrector to fix.
        return text, text

    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    # NOTE(review): generation is capped at 128 tokens, so the corrected text
    # of a long recording may come back truncated. Consider chunking.
    outputs = model.generate(input_ids, max_length=128)
    case_corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text, case_corrected_text
# Build the web UI. The three inputs are passed positionally to
# asr_transcript as (audio, audio_microphone, model_params); its 2-tuple
# return value fills the two output text boxes in order.
gradio_ui = gr.Interface(
    fn=asr_transcript,
    title="Finnish Automatic Speech Recognition",
    description="Upload an audio clip or record from browser using microphone, and let AI do the hard work of transcribing.",
    article="""
This demo runs two kinds of models one after the other. First, the selected speech recognition (ASR) model transcribes the audio into lowercase text without punctuation.
After that, a sequence-to-sequence model tries to restore the casing and punctuation, which produces the final output.
You can select one of the two speech recognition models listed below:
1. 1 billion: best accuracy, but slowest by a big margin. Based on the multilingual wav2vec2-xlsr model by Meta. More info here https://huggingface.co/Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2
2. 95 million: almost as accurate as model 1, but much faster. Based on Meta's Finnish wav2vec2-base VoxPopuli model. More info here https://huggingface.co/Finnish-NLP/wav2vec2-base-fi-voxpopuli-v2-finetuned

More info about the casing and punctuation correction model can be found here https://huggingface.co/Finnish-NLP/t5-small-nl24-casing-punctuation-correction
""",
    inputs=[
        # type="file" should yield a file-like object; asr_transcript reads its `.name` path.
        gr.inputs.Audio(label="Upload Audio File", type="file", optional=True),
        gr.inputs.Audio(source="microphone", type="file", optional=True, label="Record from microphone"),
        # These choice strings must match the values asr_transcript dispatches on.
        gr.inputs.Dropdown(
            choices=["95 million", "1 billion"],
            type="value",
            default="1 billion",
            label="Select speech recognition model parameter amount",
            optional=False,
        ),
    ],
    outputs=[
        gr.outputs.Textbox(label="Recognized speech"),
        gr.outputs.Textbox(label="Recognized speech with case correction and punctuation"),
    ],
)
gradio_ui.launch()