# roest-demo / app.py
# saattrupdan's picture
# feat: Use token env var, add link to Røst model
# 4db4bee
# raw
# history blame
# 1.7 kB
"""Røst ASR demo."""
import os
import warnings
import gradio as gr
import numpy as np
import samplerate
import torch
from punctfix import PunctFixer
from transformers import pipeline
# Silence FutureWarnings emitted by the ML libraries at runtime; they are noise
# for end users of the demo.
warnings.filterwarnings("ignore", category=FutureWarning)

TITLE = "Røst ASR Demo"
DESCRIPTION = """
This is a demo of the Danish speech recognition model
[Røst](https://huggingface.co/alexandrainst/roest-315m). Speak into the microphone and
see the text appear on the screen!
"""

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Speech-to-text pipeline for the Røst model.
#
# Token handling: an explicitly configured HUGGINGFACE_HUB_TOKEN (if set and
# non-empty) is used as-is. Otherwise we pass None, which lets huggingface_hub
# use its own token discovery (e.g. a cached login) and still permits anonymous
# downloads of public models. Passing `True` instead would make a token
# mandatory and fail for users who are not logged in.
transcriber = pipeline(
    task="automatic-speech-recognition",
    model="alexandrainst/roest-315m",
    device=device,
    token=os.getenv("HUGGINGFACE_HUB_TOKEN") or None,
)

# Danish punctuation/casing restorer applied to the raw ASR output.
transcription_fixer = PunctFixer(language="da", device=device)
def _to_float32(audio: np.ndarray) -> np.ndarray:
    """Convert an audio array to float32, scaling integer PCM into [-1, 1]."""
    if np.issubdtype(audio.dtype, np.integer):
        info = np.iinfo(audio.dtype)
        # Midpoint and half-range of the integer type. For signed types the
        # offset is 0 (e.g. int16 -> divide by 32768); unsigned types are
        # re-centred around zero (e.g. uint8 -> subtract 128, divide by 128).
        offset = (int(info.max) + int(info.min) + 1) / 2
        scale = (int(info.max) - int(info.min) + 1) / 2
        return ((audio.astype(np.float64) - offset) / scale).astype(np.float32)
    return audio.astype(np.float32)


def transcribe_audio(
    sampling_rate_and_audio: "tuple[int, np.ndarray] | None",
) -> str:
    """Transcribe the audio.

    The audio is converted to float32, mixed down to mono, resampled to
    16 kHz if needed, transcribed, and finally passed through the punctuation
    fixer.

    Args:
        sampling_rate_and_audio:
            A tuple with the sampling rate and the audio, as produced by a
            `gr.Audio` component of type "numpy". May be None when the user
            submits without recording or uploading anything.

    Returns:
        The transcription, or an empty string when there is no audio or the
        model produced no text.
    """
    target_sampling_rate = 16_000

    # Gradio passes None when the input component is empty.
    if sampling_rate_and_audio is None:
        return ""
    sampling_rate, audio = sampling_rate_and_audio
    if audio.size == 0:
        return ""

    audio = _to_float32(audio)

    # Multi-channel audio arrives as (samples, channels); average to mono.
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)

    if sampling_rate != target_sampling_rate:
        audio = samplerate.resample(
            audio, target_sampling_rate / sampling_rate, "sinc_best"
        )

    transcription = transcriber(inputs=audio)
    if not isinstance(transcription, dict):
        return ""

    text = transcription.get("text") or ""
    # Nothing to punctuate; skip the punctuation model entirely.
    if not text.strip():
        return ""

    return transcription_fixer.punctuate(text=text)
# Gradio UI: audio in (microphone or file upload), transcription out.
# `type="numpy"` makes Gradio pass a (sampling_rate, ndarray) tuple, which is
# the input format `transcribe_audio` expects.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=gr.Audio(sources=["microphone", "upload"], type="numpy"),
    outputs="textbox",
    title=TITLE,
    description=DESCRIPTION,
    allow_flagging="never",
)

# Only start the server when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launching the web app.
if __name__ == "__main__":
    demo.launch()