# Hugging Face Space page header (captured status: Paused)
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import transformers
from transformers import pipeline
from transformers.utils import logging
# Log
# Uncomment to raise the "transformers" logger to DEBUG and surface the
# logger.debug(...) messages emitted by transcribe():
# logging.set_verbosity_debug()
logger = logging.get_logger("transformers")

# Pipelines
# First CUDA device (index 0) when available, otherwise the CPU; pipeline()
# accepts either form for its `device` argument.
device = 0 if torch.cuda.is_available() else "cpu"
## Automatic Speech Recognition
## https://huggingface.co/docs/transformers/task_summary#automatic-speech-recognition
## Requires ffmpeg to be installed (the pipeline uses it to decode file/URL/bytes input).
asr_model = "openai/whisper-tiny"
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    # Optional half precision; only worth enabling when running on a GPU.
    # torch_dtype=torch.float16,
    device=device,
)
## Token Classification / Named Entity Recognition
## https://huggingface.co/docs/transformers/task_summary#token-classification
tc_model = "dslim/distilbert-NER"
tc = pipeline(
    "token-classification",  # also available under the alias "ner"
    model=tc_model,
    device=device,
)
# --- | |
# Transformers | |
# https://www.gradio.app/main/docs/gradio/audio#behavior | |
# As output component: expects audio data in any of these formats: | |
# - a str or pathlib.Path filepath | |
# - or URL to an audio file, | |
# - or a bytes object (recommended for streaming), | |
# - or a tuple of (sample rate in Hz, audio data as numpy array) | |
def transcribe(audio: str | Path | bytes | tuple[int, np.ndarray] | None): | |
if audio is None: | |
return "..." | |
# TODO Manage str/Path | |
logger.debug("Transcribe") | |
text = "" | |
# https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__ | |
# Whisper input format for tuple differ from output provided by gradio audio component | |
if asr_model.startswith("openai/whisper"): | |
inputs = {"sampling_rate": audio[0], "raw": audio[1]} if type(audio) is tuple else audio | |
transcript = asr(inputs) | |
text = transcript['text'] | |
logger.debug("Tokenize:[" + text + "]") | |
entities = tc(text) | |
#logger.debug("Classify:[" + entities + "]") | |
# TODO Add Text Classification for sentiment analysis | |
return {"text": text, "entities": entities} | |
# ---
# Gradio
## Interfaces
# https://www.gradio.app/main/docs/gradio/audio
# type="numpy" is Gradio's default, stated explicitly because transcribe()
# consumes that format: a (sample rate, int16 samples) tuple, or None.
input_audio = gr.Audio(
    sources=["upload", "microphone"],
    type="numpy",
    show_share_button=False,
)
## App
# Single-input/single-output app: recorded or uploaded audio goes through
# transcribe(), and the resulting {"text", "entities"} dict is rendered with
# the recognised entities highlighted.
gradio_app = gr.Interface(
    transcribe,
    inputs=[input_audio],
    outputs=[gr.HighlightedText()],
    title="ASRNERSBX",
    description="Transcribe, Tokenize, Classify",
    flagging_mode="never",
)
## Start!
# Guarded so that importing this module (e.g. from tests or tooling) does not
# start a server; running the file as a script still launches the app.
if __name__ == "__main__":
    gradio_app.launch()