"""Gradio app transcribing Kabardian speech with the anzorq/w2v-bert-2.0-kbd CTC model.

Audio comes either from the microphone or from a YouTube URL (downloaded with yt-dlp).
"""
import os
import tempfile

import gradio as gr
import spaces
import torch
import torchaudio
import yt_dlp
from transformers import AutoModelForCTC, Wav2Vec2BertProcessor

MODEL_ID = "anzorq/w2v-bert-2.0-kbd"
# Sample rate the audio is resampled to before feature extraction.
TARGET_SAMPLE_RATE = 16000

model = AutoModelForCTC.from_pretrained(MODEL_ID)
processor = Wav2Vec2BertProcessor.from_pretrained(MODEL_ID)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


@spaces.GPU
def transcribe_speech(audio):
    """Transcribe an audio file with the CTC model using greedy (argmax) decoding.

    Args:
        audio: Path to an audio file readable by ``torchaudio.load`` (Gradio passes a
            filepath because the Audio component uses ``type="filepath"``).

    Returns:
        The decoded transcription string.

    Raises:
        gr.Error: If no audio was provided.
    """
    if audio is None:
        # The button can be clicked before anything was recorded.
        raise gr.Error("No audio provided.")

    # torchaudio.load returns a (channels, frames) tensor.
    waveform, sr = torchaudio.load(audio)

    # Downmix multi-channel audio to mono by averaging the channels.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sr != TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SAMPLE_RATE)
        waveform = resampler(waveform)

    # Peak-normalize; skip when the signal is all zeros to avoid dividing by zero (NaNs).
    peak = waveform.abs().max()
    if peak > 0:
        waveform = waveform / peak

    # The feature extractor expects a 1-D raw waveform per example.
    inputs = processor(
        audio=waveform.squeeze(0).numpy(),
        sampling_rate=TARGET_SAMPLE_RATE,
        return_tensors="pt",
    )
    input_features = inputs.input_features.to(device)

    with torch.no_grad():
        logits = model(input_features).logits

    # Greedy decoding: take the most likely token per frame for the single batch item.
    pred_ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(pred_ids)


def transcribe_from_youtube(url):
    """Download the audio track of a YouTube video and transcribe it.

    Not decorated with ``@spaces.GPU``: the download only needs the network, and
    ``transcribe_speech`` requests the GPU itself for the inference step.

    Args:
        url: The YouTube video URL.

    Returns:
        The decoded transcription string.

    Raises:
        gr.Error: If the URL is empty or no audio file was produced.
    """
    if not url or not url.strip():
        raise gr.Error("Please enter a YouTube URL.")

    # A private temporary directory avoids building file names from the URL
    # (which may contain '/' or extra query params) and guarantees cleanup
    # even when the download or transcription fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ydl_opts = {
            'format': 'bestaudio/best',
            # Let yt-dlp choose the extension; FFmpegExtractAudio then writes audio.wav.
            'outtmpl': os.path.join(tmp_dir, 'audio.%(ext)s'),
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'wav',
                'preferredquality': '192',
            }],
            'postprocessor_args': ['-ar', str(TARGET_SAMPLE_RATE)],
            'prefer_ffmpeg': True,
            # A URL with a &list= param would otherwise download every playlist entry.
            'noplaylist': True,
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url.strip()])

        audio_path = os.path.join(tmp_dir, 'audio.wav')
        if not os.path.exists(audio_path):
            raise gr.Error("Failed to download audio from the given URL.")

        return transcribe_speech(audio_path)


with gr.Blocks() as demo:
    with gr.Tab("Microphone Input"):
        gr.Markdown("## Transcribe speech from microphone")
        mic_audio = gr.Audio(sources="microphone", type="filepath", label="Speak into your microphone")
        mic_button = gr.Button("Transcribe")
        mic_output = gr.Textbox(label="Transcription")
        mic_button.click(fn=transcribe_speech, inputs=mic_audio, outputs=mic_output)

    with gr.Tab("YouTube URL"):
        gr.Markdown("## Transcribe speech from YouTube video")
        youtube_url = gr.Textbox(label="Enter YouTube video URL")
        youtube_button = gr.Button("Transcribe")
        youtube_output = gr.Textbox(label="Transcription")
        youtube_button.click(fn=transcribe_from_youtube, inputs=youtube_url, outputs=youtube_output)

demo.launch()