import gradio as gr
import numpy as np
import torch
from pathlib import Path
from transformers import pipeline
from transformers.utils import logging
# Logging
# logging.set_verbosity_debug()
logger = logging.get_logger("transformers")
# Pipelines
## Automatic Speech Recognition
## https://huggingface.co/docs/transformers/task_summary#automatic-speech-recognition
## Requires ffmpeg to be installed so the pipeline can decode audio files
asr_device = "cuda:0" if torch.cuda.is_available() else "cpu"
asr_torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
asr_model = "openai/whisper-tiny"
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    torch_dtype=asr_torch_dtype,
    device=asr_device
)
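# A minimal usage sketch (illustrative, not executed here; "sample.wav" and
# `waveform` are hypothetical): the pipeline accepts a filepath or a
# {"sampling_rate": int, "raw": np.ndarray} dict and returns a dict with a
# "text" key:
#   asr("sample.wav")                               # {"text": "<transcript>"}
#   asr({"sampling_rate": 16000, "raw": waveform})  # same output shape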
## Token Classification / Named Entity Recognition
## https://huggingface.co/docs/transformers/task_summary#token-classification
tc_device = "cuda:0" if torch.cuda.is_available() else "cpu"
tc_model = "dslim/distilbert-NER"
tc = pipeline(
    "token-classification",  # a.k.a. "ner"
    model=tc_model,
    device=tc_device
)
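# A minimal usage sketch (field values are illustrative): the pipeline returns
# a list of dicts, one per tagged token:
#   tc("My name is Wolfgang and I live in Berlin")
#   # [{"entity": "B-PER", "score": 0.99, "index": 4, "word": "Wolfgang",
#   #   "start": 11, "end": 19}, ...]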
# ---
# Transformers
# https://www.gradio.app/main/docs/gradio/audio#behavior
# As input component, gr.Audio passes audio in one of these formats
# (depending on its `type` setting, default "numpy"):
# - a str filepath,
# - or a tuple of (sample rate in Hz, audio data as numpy array)
# The broader type hint below also covers the formats the docs list for
# output components (pathlib.Path, URL, bytes).
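# A minimal sketch of the tuple the microphone yields with the default
# type="numpy" (sample rate and values are illustrative):
#   (48000, np.array([0, 12, -7], dtype=np.int16))   # mono
#   (48000, np.zeros((48000, 2), dtype=np.int16))    # stereo, 1 second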
def transcribe(audio: str | Path | bytes | tuple[int, np.ndarray] | None):
    logger.debug(">Transcribe")
    if audio is None:
        return "..."
    # TODO Manage str/Path
    text = ""
    # https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline.__call__
    # The Whisper input format for tuples differs from the output provided by
    # the gradio Audio component, so repack it as a dict
    if asr_model.startswith("openai/whisper") and isinstance(audio, tuple):
        sampling_rate, raw = audio
        # Convert to mono if stereo
        if raw.ndim > 1:
            raw = raw.mean(axis=1)
        # Convert according to asr_torch_dtype
        raw = raw.astype(np.float16 if asr_torch_dtype is torch.float16 else np.float32)
        # Peak-normalize, guarding against all-zero (silent) input
        peak = np.max(np.abs(raw))
        if peak > 0:
            raw /= peak
        inputs = {"sampling_rate": sampling_rate, "raw": raw}
        logger.debug(inputs)
        transcript = asr(inputs)
        text = transcript["text"]
    logger.debug(text)
    return text
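# Usage note: str/Path input currently skips the Whisper branch (see the TODO
# above) and returns "", while microphone/upload tuples are repacked and sent
# through the pipeline.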
def tokenize(text: str):
    logger.debug(">Tokenize")
    entities = tc(text)
    logger.debug(entities)
    # TODO Add Text Classification for sentiment analysis
    return {"text": text, "entities": entities}
def classify(text: str):
    logger.debug(">Classify")
    return None
def transcribe_tokenize(*args):
    # Unpack *args so transcribe receives the audio value itself rather than
    # a one-element tuple wrapped around it
    return tokenize(transcribe(*args))
# ---
# Gradio
## Interfaces
# https://www.gradio.app/main/docs/gradio/audio
input_audio = gr.Audio(
    sources=["upload", "microphone"],
    show_share_button=False
)
## App
asrner_app = gr.Interface(
    transcribe_tokenize,
    inputs=[
        input_audio
    ],
    outputs=[
        gr.HighlightedText()
    ],
    title="ASR>NER",
    description=(
        "Transcribe, Tokenize, Classify"
    ),
    flagging_mode="never"
)
ner_app = gr.Interface(
    tokenize,
    inputs=[
        gr.Textbox()
    ],
    outputs=[
        gr.HighlightedText()
    ],
    title="NER",
    description=(
        "Tokenize, Classify"
    ),
    flagging_mode="never"
)
gradio_app = gr.TabbedInterface(
    interface_list=[
        asrner_app,
        ner_app
    ],
    tab_names=[
        asrner_app.title,
        ner_app.title
    ],
    title="ASRNERSBX"
)
## Start!
gradio_app.launch()