Spaces:
Paused
Paused
import spaces | |
import os | |
import gradio as gr | |
import torch | |
import torchaudio | |
from transformers import pipeline | |
from pytube import YouTube | |
import re | |
import numpy as np | |
from scipy.signal import wiener | |
from io import BytesIO | |
# Transcription pipeline for the fine-tuned Wav2Vec2-BERT Kabardian model
# (transformers infers the task from the hub model). It is built once at
# import time and shared by every request. Use the first CUDA device when one
# exists, otherwise fall back to CPU (-1) instead of failing at startup.
pipe = pipeline(
    model="anzorq/w2v-bert-2.0-kbd-v2",
    device=0 if torch.cuda.is_available() else -1,
)
# Define the replacements for Kabardian transcription.
# Each pair is (Cyrillic spelling, single symbol). Pipeline output is passed
# through replace_symbols_back(), which maps each symbol (right) back to its
# Cyrillic spelling (left).
replacements = [
    ('гъ', 'ɣ'), ('дж', 'j'), ('дз', 'ӡ'), ('жь', 'ʐ'), ('кӏ', 'қ'),
    ('кхъ', 'qҳ'), ('къ', 'q'), ('лъ', 'ɬ'), ('лӏ', 'ԯ'), ('пӏ', 'ԥ'),
    ('тӏ', 'ҭ'), ('фӏ', 'ჶ'), ('хь', 'h'), ('хъ', 'ҳ'), ('цӏ', 'ҵ'),
    ('щӏ', 'ɕ'), ('я', 'йа'),
]

# Reverse replacements for transcription: symbol -> Cyrillic.
reverse_replacements = {v: k for k, v in replacements}

# Regex alternation picks the first alternative that matches at a position,
# so longer keys must come first. Otherwise a multi-character symbol such as
# 'qҳ' could be split by a shorter key ('q') depending on list order.
reverse_pattern = re.compile(
    '|'.join(
        re.escape(key)
        for key in sorted(reverse_replacements, key=len, reverse=True)
    )
)


def replace_symbols_back(text):
    """Replace every transcription symbol in *text* with its Cyrillic spelling.

    Characters that are not keys of ``reverse_replacements`` are left as-is.

    >>> replace_symbols_back('qҳ')
    'кхъ'
    """
    return reverse_pattern.sub(lambda match: reverse_replacements[match.group(0)], text)
def preprocess_audio(audio_tensor, original_sample_rate, apply_normalization, target_sample_rate=16000):
    """Convert a waveform to float32 mono at the model's sample rate.

    Args:
        audio_tensor: Waveform tensor, assumed to be (channels, frames) as
            returned by ``torchaudio.load``. Channels are averaged over dim 0.
        original_sample_rate: Sample rate of ``audio_tensor`` in Hz.
        apply_normalization: If true, scale so the peak absolute sample is 1.0.
            Silent or empty audio is left unscaled rather than divided by zero.
        target_sample_rate: Output sample rate in Hz. Defaults to 16 kHz, the
            rate the rest of this module assumes.

    Returns:
        A float32 tensor of shape (1, frames) at ``target_sample_rate``.
    """
    audio_tensor = audio_tensor.to(dtype=torch.float32)
    audio_tensor = torch.mean(audio_tensor, dim=0, keepdim=True)  # Convert to mono

    if apply_normalization and audio_tensor.numel() > 0:
        peak = torch.max(torch.abs(audio_tensor))
        # Dividing by a zero peak (silence) would fill the tensor with NaN.
        if peak > 0:
            audio_tensor = audio_tensor / peak

    if original_sample_rate != target_sample_rate:
        audio_tensor = torchaudio.functional.resample(
            audio_tensor, orig_freq=original_sample_rate, new_freq=target_sample_rate
        )
    return audio_tensor
def spectral_gating(audio_tensor, sample_rate=16_000):
    """Reduce background noise with ``noisereduce.reduce_noise``.

    Args:
        audio_tensor: CPU waveform tensor (converted with ``.numpy()``).
        sample_rate: Sample rate of ``audio_tensor`` in Hz. Defaults to the
            16 kHz produced by ``preprocess_audio``.

    Returns:
        The denoised waveform as a tensor with the input's dtype.

    Raises:
        ImportError: If the ``noisereduce`` package is not installed.
    """
    # `nr` was referenced here without ever being imported, so every call
    # raised NameError. Import lazily so the module still loads without the
    # package, and fail with an explicit message when the filter is used.
    try:
        import noisereduce as nr
    except ImportError as err:
        raise ImportError(
            "Spectral gating requires the 'noisereduce' package (pip install noisereduce)."
        ) from err

    audio_data = audio_tensor.numpy()
    reduced_noise = nr.reduce_noise(y=audio_data, sr=sample_rate)
    return torch.tensor(reduced_noise, dtype=audio_tensor.dtype)
def wiener_filter(audio_tensor, window_size=3):
    """Apply a Wiener filter along the time (last) axis of *audio_tensor*.

    Args:
        audio_tensor: CPU waveform tensor, e.g. (1, frames) mono from
            ``preprocess_audio``. Each leading-axis row is filtered independently.
        window_size: Length, in samples, of the local mean/variance window.

    Returns:
        The filtered waveform as a tensor with the input's dtype.
    """
    audio_data = audio_tensor.numpy()
    # scipy's default window is 3 along *every* axis. For a (1, frames) tensor
    # that window straddles the zero padding above and below the single row,
    # diluting the local mean/variance estimates. Restrict it to the last axis.
    window = [1] * (audio_data.ndim - 1) + [window_size]
    filtered_audio = wiener(audio_data, mysize=window)
    return torch.tensor(filtered_audio, dtype=audio_tensor.dtype)
def transcribe_speech(audio, progress=gr.Progress()):
    """Transcribe Kabardian speech and map model symbols back to Cyrillic.

    Args:
        audio: One of:
            - ``None``, meaning nothing was recorded;
            - a path to an audio file, which is what the microphone/upload
              ``gr.Audio(type="filepath")`` component delivers;
            - a waveform tensor already preprocessed to 16 kHz mono, as passed
              by ``transcribe_from_youtube``.
        progress: Gradio progress tracker.

    Returns:
        ``(text, audio_out)``.
        - For tensor input, ``audio_out`` is that same tensor; callers call
          ``.numpy()`` on it.
        - For file input, ``audio_out`` is a ``(16000, numpy_array)`` tuple that
          ``gr.Audio`` can play.
        - For ``None``, returns ``("No audio received.", None)``.
    """
    if audio is None:
        return "No audio received.", None

    from_file = isinstance(audio, (str, os.PathLike))
    if from_file:
        # The mic/upload tab hands us a filepath; calling `.numpy()` on a str
        # raised AttributeError. Load it and bring it to 16 kHz mono first.
        waveform, sample_rate = torchaudio.load(audio)
        audio = preprocess_audio(waveform, sample_rate, apply_normalization=False)

    progress(0.5, desc="Transcribing audio...")
    audio_np = audio.numpy().squeeze()
    transcription = pipe(audio_np, chunk_length_s=10)['text']
    text = replace_symbols_back(transcription)

    if from_file:
        return text, (16000, audio_np)
    return text, audio
def transcribe_from_youtube(url, apply_wiener_filter, apply_normalization, apply_spectral_gating, progress=gr.Progress()):
    """Download a YouTube video's audio, optionally clean it, and transcribe it.

    Args:
        url: YouTube video URL.
        apply_wiener_filter: Run ``wiener_filter`` on the audio.
        apply_normalization: Peak-normalize during ``preprocess_audio``.
        apply_spectral_gating: Run ``spectral_gating`` on the audio.
        progress: Gradio progress tracker.

    Returns:
        ``(transcription, (16000, numpy_array))`` on success. The tuple is a
        value ``gr.Audio`` can play. On any failure, returns
        ``(error_message, None)`` so the error appears in the textbox.
    """
    # This is the UI event boundary: every failure, including a bad URL or a
    # pytube error during download, is reported as text instead of escaping.
    try:
        progress(0, "Downloading YouTube audio...")
        yt = YouTube(url)
        stream = yt.streams.filter(only_audio=True).first()
        if stream is None:
            raise ValueError("No audio-only stream available for this video.")
        audio_data = BytesIO()
        stream.stream_to_buffer(audio_data)
        audio_data.seek(0)

        audio, original_sample_rate = torchaudio.load(audio_data)
        audio = preprocess_audio(audio, original_sample_rate, apply_normalization)
        if apply_wiener_filter:
            progress(0.3, "Applying Wiener filter...")
            audio = wiener_filter(audio)
        if apply_spectral_gating:
            progress(0.4, "Applying Spectral Gating filter...")
            audio = spectral_gating(audio)

        transcription, _ = transcribe_speech(audio, progress=progress)
        # gr.Audio cannot play a BytesIO, and torchaudio.save to a file-like
        # object needs an explicit format. A (sample_rate, samples) tuple is a
        # supported gr.Audio output value and avoids both problems.
        audio_np = audio.numpy().squeeze()
    except Exception as e:
        return str(e), None
    return transcription, (16000, audio_np)
def populate_metadata(url):
    """Return ``(thumbnail_url, title)`` for a YouTube URL, or ``(None, None)``.

    This is wired to the URL textbox's ``change`` event, so it runs on every
    edit. Empty or partially typed URLs are expected there. Any lookup failure
    therefore clears the preview instead of raising an error in the UI.
    """
    if not url or not url.strip():
        return None, None
    try:
        yt = YouTube(url)
        return yt.thumbnail_url, yt.title
    except Exception:
        # Best-effort preview only; the Transcribe action reports real errors.
        return None, None
# Gradio UI with two tabs:
# - microphone/upload transcription;
# - YouTube transcription with optional audio clean-up.
# Each tab's widgets get distinct names instead of rebinding the same globals.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML(
        """
<div style="text-align: center; max-width: 500px; margin: 0 auto;">
<div>
<h1>Kabardian Speech Transcription</h1>
</div>
<p style="margin-bottom: 10px; font-size: 94%">
Kabardian speech to text transcription using a fine-tuned Wav2Vec2-BERT model
</p>
</div>
"""
    )

    with gr.Tab("Microphone Input"):
        gr.Markdown("## Transcribe speech from microphone")
        # type="filepath": transcribe_speech receives a path string.
        mic_audio = gr.Audio(sources=['microphone', 'upload'], type="filepath", label="Record or upload an audio")
        mic_button = gr.Button("Transcribe")
        mic_transcription = gr.Textbox(label="Transcription")
        mic_processed_audio = gr.Audio(label="Processed Audio")
        mic_button.click(
            fn=transcribe_speech,
            inputs=mic_audio,
            outputs=[mic_transcription, mic_processed_audio],
        )

    with gr.Tab("YouTube URL"):
        gr.Markdown("## Transcribe speech from YouTube video")
        youtube_url = gr.Textbox(label="Enter YouTube video URL")
        with gr.Accordion("Audio Improvements", open=False):
            apply_normalization = gr.Checkbox(label="Normalize audio volume", value=True)
            apply_spectral_gating = gr.Checkbox(label="Apply Spectral Gating filter", info="Noise reduction", value=True)
            apply_wiener = gr.Checkbox(label="Apply Wiener filter", info="Noise reduction", value=False)
        with gr.Row():
            thumbnail = gr.Image(label="Thumbnail", height=240, width=240, scale=1)
            video_title = gr.Label(label="Video Title", scale=2)
        yt_button = gr.Button("Transcribe")
        yt_transcription = gr.Textbox(label="Transcription", placeholder="Transcription Output", lines=10)
        yt_processed_audio = gr.Audio(label="Processed Audio")
        # Input order must match transcribe_from_youtube's positional parameters.
        yt_button.click(
            fn=transcribe_from_youtube,
            inputs=[youtube_url, apply_wiener, apply_normalization, apply_spectral_gating],
            outputs=[yt_transcription, yt_processed_audio],
        )
        youtube_url.change(populate_metadata, inputs=[youtube_url], outputs=[thumbnail, video_title])

if __name__ == "__main__":
    demo.launch()