# dwb2023's picture
# yadda yadda... getting rid of flash attention for now
# 4c90570 verified
# raw
# history blame
# 7.24 kB
import os
import json
import time
from datetime import datetime
from pathlib import Path
import tempfile
import pandas as pd
import gradio as gr
import yt_dlp as youtube_dl
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoTokenizer,
AutoFeatureExtractor,
pipeline,
)
from transformers.pipelines.audio_utils import ffmpeg_read
import torch
from datasets import load_dataset, Dataset, DatasetDict
import spaces
# Constants
MODEL_NAME = "openai/whisper-large-v3-turbo"  # checkpoint loaded into the ASR pipeline below
BATCH_SIZE = 8  # Optimized for better GPU utilization
YT_LENGTH_LIMIT_S = 10800  # 3 hours
DATASET_NAME = "dwb2023/yt-transcripts-v3"  # Hub dataset that stores the transcripts
FILE_LIMIT_MB = 1000  # not referenced in this portion of the file

# Environment setup
# NOTE(review): huggingface_hub appears to read this flag when it is first
# imported; the HF libraries are already imported above this line, so confirm
# that setting it here still has an effect.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"

# Pipeline setup
# Built once at import time and shared by every request handler.
pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,
    device=device,
)
def reset_and_update_dataset(new_data):
    """
    Rebuild the dataset from scratch with a single record and push it.

    An empty, typed DataFrame provides the column layout; ``new_data`` is
    appended to it as one row and the result is pushed to ``DATASET_NAME``
    as the "train" split.

    Args:
        new_data (dict): Column name -> value mapping for the row to store.
    """
    # Column name -> pandas dtype for the empty template frame.
    column_dtypes = {
        "url": "str",
        "transcription": "str",
        "title": "str",
        "duration": "int",
        "uploader": "str",
        "upload_date": "datetime64[ns]",
        "description": "str",
        "datetime": "datetime64[ns]",
    }
    template = pd.DataFrame(
        {name: pd.Series(dtype=dtype) for name, dtype in column_dtypes.items()}
    )
    record_frame = pd.DataFrame([new_data])
    combined = pd.concat([template, record_frame], ignore_index=True)

    train_split = Dataset.from_pandas(combined)
    DatasetDict({"train": train_split}).push_to_hub(DATASET_NAME)
    print("Dataset reset and updated successfully!")
def download_yt_audio(yt_url, filename):
    """
    Downloads audio from a YouTube video using yt_dlp.

    Metadata is fetched first, without downloading, so that videos with an
    unknown or excessive duration are rejected before any media is
    transferred.

    Args:
        yt_url (str): URL of the YouTube video.
        filename (str): Path to save the downloaded audio.

    Returns:
        dict: Information about the YouTube video, as returned by
            ``YoutubeDL.extract_info``.

    Raises:
        gr.Error: If yt_dlp fails to extract the metadata or to download the
            media, if the metadata carries no duration, or if the duration
            exceeds ``YT_LENGTH_LIMIT_S``.
    """
    # Use the context manager so the metadata-only instance is closed too.
    with youtube_dl.YoutubeDL() as info_loader:
        try:
            info = info_loader.extract_info(yt_url, download=False)
        except youtube_dl.utils.DownloadError as err:
            raise gr.Error(str(err)) from err

    # Some URLs yield metadata without a usable duration; report that to the
    # user instead of failing with a raw KeyError/TypeError.
    file_length = info.get("duration")
    if file_length is None:
        raise gr.Error("Could not determine the length of the YouTube video.")

    if file_length > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%H:%M:%S", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%H:%M:%S", time.gmtime(file_length))
        raise gr.Error(
            f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video."
        )

    ydl_opts = {"outtmpl": filename, "format": "bestaudio/best"}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except youtube_dl.utils.DownloadError as err:
            # Surface download failures the same way as extraction failures.
            raise gr.Error(str(err)) from err
    return info
@spaces.GPU(duration=120)
def yt_transcribe(yt_url, task):
    """
    Transcribes a YouTube video, reusing the archived transcription if one exists.

    For the "transcribe" task the archive dataset is consulted first, and a
    newly produced transcription is saved to it. The archive keeps one text
    per URL and does not record which task produced it, so "translate"
    requests bypass the archive in both directions; otherwise a translate
    request could be answered with an untranslated transcript, or a stored
    translation could later be served as a transcription.

    Args:
        yt_url (str): URL of the YouTube video.
        task (str): Task to perform - "transcribe" or "translate".

    Returns:
        str: The transcription (or translation) of the video.

    Raises:
        gr.Error: If no URL was submitted, or if the audio download step
            rejects the video.
    """
    if not yt_url or not yt_url.strip():
        raise gr.Error("No YouTube URL submitted! Please paste a video URL before submitting your request.")
    # Surrounding whitespace from pasting should not defeat the archive lookup.
    yt_url = yt_url.strip()

    use_archive = task == "transcribe"

    if use_archive:
        dataset = load_dataset(DATASET_NAME, split="train")
        # Scan only the URL column; the full row is decoded on a hit only.
        for idx, url in enumerate(dataset["url"]):
            if url == yt_url:
                return dataset[idx]["transcription"]

    with tempfile.TemporaryDirectory() as tmpdirname:
        filepath = os.path.join(tmpdirname, "video.mp4")
        info = download_yt_audio(yt_url, filepath)
        # Read the bytes before the temporary directory is removed.
        with open(filepath, "rb") as f:
            video_data = f.read()

    sampling_rate = pipe.feature_extractor.sampling_rate
    audio = ffmpeg_read(video_data, sampling_rate)
    inputs = {"array": audio, "sampling_rate": sampling_rate}

    text = pipe(
        inputs,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )["text"]

    if use_archive:
        save_transcription(yt_url, text, info)
    return text
def save_transcription(yt_url, transcription, info):
    """
    Append one transcript record to the dataset and push it to the Hub.

    The current "train" split is loaded, the new row is added at the end,
    and the whole split is pushed back to ``DATASET_NAME``.

    Args:
        yt_url (str): URL of the YouTube video.
        transcription (str): The transcribed text.
        info (dict): Video metadata; missing keys fall back to defaults.
    """
    # Metadata keys copied from ``info`` and the fallback used when absent.
    metadata_defaults = (
        ("title", "N/A"),
        ("duration", 0),
        ("uploader", "N/A"),
        ("upload_date", "N/A"),
        ("description", "N/A"),
    )

    record = {"url": yt_url, "transcription": transcription}
    for key, fallback in metadata_defaults:
        record[key] = info.get(key, fallback)
    record["datetime"] = datetime.now().isoformat()

    # NOTE(review): this is a load / append / push sequence with no locking;
    # confirm whether overlapping requests could overwrite each other's rows.
    existing = load_dataset(DATASET_NAME, split="train").to_pandas()
    combined = pd.concat([existing, pd.DataFrame([record])], ignore_index=True)

    train_split = Dataset.from_pandas(combined)
    DatasetDict({"train": train_split}).push_to_hub(DATASET_NAME)
@spaces.GPU
def transcribe(inputs, task):
    """
    Run the speech-recognition pipeline on a submitted audio file.

    Args:
        inputs (str): Path to the audio file, or None when nothing was submitted.
        task (str): Task to perform - "transcribe" or "translate".

    Returns:
        str: The text produced by the pipeline.

    Raises:
        gr.Error: If ``inputs`` is None.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    result = pipe(
        inputs,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )
    return result["text"]
# Gradio App Setup
# Container that hosts the three tabs assembled at the bottom of this section.
demo = gr.Blocks()

# YouTube Transcribe Tab
yt_description = (
    f"Transcribe and archive YouTube videos using the {MODEL_NAME} model. "
    "The transcriptions are saved for future reference, so repeated requests are faster!"
)
yt_transcribe_interface = gr.Interface(
    fn=yt_transcribe,
    inputs=[
        gr.Textbox(
            lines=1,
            placeholder="Paste the URL to a YouTube video here",
            label="YouTube URL",
        ),
        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
    ],
    outputs="text",
    title="YouTube Transcription",
    description=yt_description,
    allow_flagging="never",
)

# Microphone Transcribe Tab
mf_transcribe_interface = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(sources="microphone", type="filepath"),
        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
    ],
    outputs="text",
    title="Microphone Transcription",
    description="Transcribe audio captured through your microphone.",
    allow_flagging="never",
)

# File Upload Transcribe Tab
file_transcribe_interface = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(sources="upload", type="filepath", label="Audio file"),
        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
    ],
    outputs="text",
    title="Audio File Transcription",
    description="Transcribe uploaded audio files of arbitrary length.",
    allow_flagging="never",
)

# Organize Tabs in the Gradio App
# The two lists are parallel: each interface is shown under the title at the
# same position.
tab_interfaces = [
    yt_transcribe_interface,
    mf_transcribe_interface,
    file_transcribe_interface,
]
tab_titles = ["YouTube", "Microphone", "Audio File"]

with demo:
    gr.TabbedInterface(tab_interfaces, tab_titles)

# Enable the request queue, then start serving the app.
demo.queue().launch()