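# Gradio demo of Whisper-AT: a joint audio tagging and speech recognition model.
# The app transcribes an uploaded or recorded clip and, in the same pass, reports
# the top background-sound labels for each time window.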
import gradio as gr
import whisper_at as whisper
link = "https://github.com/YuanGongND/whisper-AT"
text = "[GitHub]"
paper_link = "https://arxiv.org/pdf/2307.03183.pdf"
paper_text = "[Paper]"
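# All model sizes are loaded once at startup so that switching sizes in the UI
# does not reload a checkpoint on every request.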
model_large = whisper.load_model("large-v1")
model_tiny = whisper.load_model("tiny")
model_tiny_en = whisper.load_model("tiny.en")
model_small = whisper.load_model("small")
mdl_dict = {"tiny": model_tiny, "tiny.en": model_tiny_en, "small": model_small, "large": model_large}
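# The tagging time resolution must be an integer multiple of 0.4 s (see the
# Textbox hint below), so the requested value is snapped to the nearest multiple.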
def round_time_resolution(time_resolution):
    multiple = float(time_resolution) / 0.4
    rounded_multiple = round(multiple)
    rounded_time_resolution = rounded_multiple * 0.4
    return rounded_time_resolution
def predict(audio_path_m, audio_path_t, model_size, time_resolution):
    # Exactly one input (microphone recording or file upload) must be provided.
    if (audio_path_m is None) == (audio_path_t is None):
        error_msg = "Please provide exactly one recording: either upload an audio file or record with the microphone."
        return error_msg, error_msg
    audio_path = audio_path_m or audio_path_t
    audio_tagging_time_resolution = round_time_resolution(time_resolution)
    model = mdl_dict[model_size]
    # Transcribe and tag in one pass; at_time_res sets the audio-tagging window length.
    result = model.transcribe(audio_path, at_time_res=audio_tagging_time_resolution)
    audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-1, include_class_list=list(range(527)))
    # Format ASR segments as "start s-end s: text" lines.
    asr_output = ""
    for segment in result['segments']:
        asr_output += format(segment['start'], ".1f") + 's-' + format(segment['end'], ".1f") + 's: ' + segment['text'] + '\n'
    # Format audio tags as "start s-end s: tag1, tag2, ..." lines.
    at_output = ""
    for segment in audio_tag_result:
        at_output += format(segment['time']['start'], ".1f") + 's-' + format(segment['time']['end'], ".1f") + 's: ' + ', '.join([x[0] for x in segment['audio tags']]) + '\n'
    return asr_output, at_output
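# Gradio interface: two audio inputs (microphone and file upload), a model-size
# selector, and a time-resolution textbox; outputs are the transcript and the audio tags.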
iface = gr.Interface(fn=predict,
                     inputs=[gr.Audio(type="filepath", source='microphone', label='Please either upload an audio file or record using the microphone.', show_label=True),
                             gr.Audio(type="filepath"),
                             gr.Radio(["tiny", "tiny.en", "small", "large"], value='large', label="Model size", info="The larger the model, the better the performance and the slower the speed."),
                             gr.Textbox(value='10', label='Time Resolution in Seconds (must be an integer multiple of 0.4, e.g., 0.4, 2, 10)')],
                     outputs=[gr.Textbox(label="Speech Output"), gr.Textbox(label="Audio Tag Output")],
                     cache_examples=True,
                     title="Quick Demo of Whisper-AT",
                     description="We are glad to introduce Whisper-AT, a new joint audio tagging and speech recognition model that outputs background sound labels in addition to text. " + f"<a href='{paper_link}'>{paper_text}</a> " + f"<a href='{link}'>{text}</a> <br>" +
                                 "Whisper-AT is authored by Yuan Gong, Sameer Khurana, Leonid Karlinsky, and James Glass (MIT & MIT-IBM Watson AI Lab); the paper appears at Interspeech 2023.")
iface.launch(debug=True)