"""Røst ASR demo.""" import os import warnings import gradio as gr import numpy as np import samplerate import torch from punctfix import PunctFixer from transformers import pipeline warnings.filterwarnings("ignore", category=FutureWarning) TITLE = "Røst ASR Demo" DESCRIPTION = """ This is a demo of the Danish speech recognition model [Røst](https://huggingface.co/alexandrainst/roest-315m). Speak into the microphone and see the text appear on the screen! """ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") transcriber = pipeline( task="automatic-speech-recognition", model="alexandrainst/roest-315m", device=device, token=os.getenv("HUGGINGFACE_HUB_TOKEN", True), ) transcription_fixer = PunctFixer(language="da", device=device) def transcribe_audio(sampling_rate_and_audio: tuple[int, np.ndarray]) -> str: """Transcribe the audio. Args: sampling_rate_and_audio: A tuple with the sampling rate and the audio. Returns: The transcription. """ sampling_rate, audio = sampling_rate_and_audio if audio.ndim > 1: audio = np.mean(audio, axis=1) audio = samplerate.resample(audio, 16_000 / sampling_rate, "sinc_best") transcription = transcriber(inputs=audio) if not isinstance(transcription, dict): return "" cleaned_transcription = transcription_fixer.punctuate( text=transcription["text"] ) return cleaned_transcription demo = gr.Interface( fn=transcribe_audio, inputs=gr.Audio(sources=["microphone", "upload"]), outputs="textbox", title=TITLE, description=DESCRIPTION, allow_flagging="never", ) demo.launch()