dwb2023 committed
Commit 2de41dc · verified · 1 parent: a589aff

update app

Files changed (1):
  1. app.py +109 -68
app.py CHANGED
@@ -9,44 +9,46 @@ import pandas as pd
 import gradio as gr
 import yt_dlp as youtube_dl
 from transformers import (
-    BitsAndBytesConfig,
     AutoModelForSpeechSeq2Seq,
     AutoTokenizer,
     AutoFeatureExtractor,
     pipeline,
 )
 from transformers.pipelines.audio_utils import ffmpeg_read
-import torch  # If you're using PyTorch
+import torch
 from datasets import load_dataset, Dataset, DatasetDict
 import spaces
 
 # Constants
-MODEL_NAME = "openai/whisper-large-v3"
-BATCH_SIZE = 8
-YT_LENGTH_LIMIT_S = 4800  # 1 hour 20 minutes
+MODEL_NAME = "openai/whisper-large-v3-turbo"
+BATCH_SIZE = 8  # Optimized for better GPU utilization
+YT_LENGTH_LIMIT_S = 10800  # 3 hours
 DATASET_NAME = "dwb2023/yt-transcripts-v3"
 
 # Environment setup
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
 # Model setup
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.bfloat16
-)
-
 model = AutoModelForSpeechSeq2Seq.from_pretrained(
     MODEL_NAME,
-    quantization_config=bnb_config,
     use_cache=False,
     device_map="auto"
 )
 
+# Flash Attention setup for memory and speed optimization if supported
+try:
+    from flash_attn import flash_attn_fn
+    model.config.use_flash_attention = True
+except ImportError:
+    print("Flash Attention is not available. Proceeding without it.")
+
+# Note: torch.compile is not compatible with Flash Attention or the chunked long-form algorithm.
+
+# Processor setup
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
 
+# Pipeline setup
 pipe = pipeline(
     task="automatic-speech-recognition",
     model=model,
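One caveat on the Flash Attention block added above: the flash-attn package exports `flash_attn_func`, not `flash_attn_fn`, so the `try` import always raises ImportError and the fallback branch runs; Transformers also does not read an ad-hoc `model.config.use_flash_attention` attribute. The supported switch is the `attn_implementation` argument at load time, and FA2 needs half-precision weights. A minimal sketch under those assumptions, keeping this commit's MODEL_NAME:

```python
import torch
from transformers import AutoModelForSpeechSeq2Seq

MODEL_NAME = "openai/whisper-large-v3-turbo"  # as set in this commit

try:
    # Let Transformers wire in Flash Attention 2 itself (requires flash-attn installed).
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,                # FA2 requires fp16/bf16 weights
        attn_implementation="flash_attention_2",
        use_cache=False,
        device_map="auto",
    )
except (ImportError, ValueError):
    # flash-attn missing or unsupported on this hardware: plain attention fallback.
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_NAME, use_cache=False, device_map="auto"
    )
```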
@@ -56,7 +58,12 @@ pipe = pipeline(
 )
 
 def reset_and_update_dataset(new_data):
-    # Define the schema for an empty DataFrame
+    """
+    Resets and updates the dataset with new transcription data.
+
+    Args:
+        new_data (dict): Dictionary containing the new data to be added to the dataset.
+    """
     schema = {
         "url": pd.Series(dtype="str"),
         "transcription": pd.Series(dtype="str"),
@@ -67,22 +74,24 @@ def reset_and_update_dataset(new_data):
         "description": pd.Series(dtype="str"),
         "datetime": pd.Series(dtype="datetime64[ns]")
     }
-
-    # Create an empty DataFrame with the defined schema
     df = pd.DataFrame(schema)
-
-    # Append the new data
     df = pd.concat([df, pd.DataFrame([new_data])], ignore_index=True)
-
-    # Convert back to dataset
     updated_dataset = Dataset.from_pandas(df)
-
-    # Push the updated dataset to the hub
     dataset_dict = DatasetDict({"train": updated_dataset})
     dataset_dict.push_to_hub(DATASET_NAME)
     print("Dataset reset and updated successfully!")
 
 def download_yt_audio(yt_url, filename):
+    """
+    Downloads audio from a YouTube video using yt_dlp.
+
+    Args:
+        yt_url (str): URL of the YouTube video.
+        filename (str): Path to save the downloaded audio.
+
+    Returns:
+        dict: Information about the YouTube video.
+    """
     info_loader = youtube_dl.YoutubeDL()
     try:
         info = info_loader.extract_info(yt_url, download=False)
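For reference, a hypothetical call to `reset_and_update_dataset`; every value below is a placeholder, and the keys mirror the schema above plus the fields that `save_transcription` writes later in this diff:

```python
from datetime import datetime

reset_and_update_dataset({
    "url": "https://www.youtube.com/watch?v=example",  # placeholder URL
    "transcription": "example transcription text",
    "title": "Example video",
    "duration": 123,
    "uploader": "example-channel",
    "upload_date": "20240101",
    "description": "N/A",
    "datetime": datetime.now().isoformat(),
})
```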
@@ -104,15 +113,20 @@ def download_yt_audio(yt_url, filename):
 
 @spaces.GPU(duration=120)
 def yt_transcribe(yt_url, task):
-    # Load the dataset
+    """
+    Transcribes a YouTube video and saves the transcription if it doesn't already exist.
+
+    Args:
+        yt_url (str): URL of the YouTube video.
+        task (str): Task to perform - "transcribe" or "translate".
+
+    Returns:
+        str: The transcription of the video.
+    """
     dataset = load_dataset(DATASET_NAME, split="train")
-
-    # Check if the transcription already exists
     for row in dataset:
         if row['url'] == yt_url:
-            return row['transcription']  # Return the existing transcription
-
-    # If transcription does not exist, perform the transcription
+            return row['transcription']
     with tempfile.TemporaryDirectory() as tmpdirname:
         filepath = os.path.join(tmpdirname, "video.mp4")
         info = download_yt_audio(yt_url, filepath)
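The cache check in `yt_transcribe` still iterates the whole split row by row on every request. A hypothetical helper (the name and shape are illustrative, not part of this commit) that performs the same lookup through column access, avoiding materializing each row as a dict:

```python
from datasets import load_dataset

DATASET_NAME = "dwb2023/yt-transcripts-v3"  # as defined in app.py

def find_cached_transcription(yt_url):
    """Return the cached transcription for yt_url, or None if absent."""
    dataset = load_dataset(DATASET_NAME, split="train")
    # Column access builds two plain lists once instead of decoding every row.
    cached = dict(zip(dataset["url"], dataset["transcription"]))
    return cached.get(yt_url)
```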
@@ -126,54 +140,56 @@ def yt_transcribe(yt_url, task):
             generate_kwargs={"task": task},
             return_timestamps=True,
         )["text"]
-
-        # Extract additional fields
-        try:
-            title = info.get("title", "N/A")
-            duration = info.get("duration", 0)
-            uploader = info.get("uploader", "N/A")
-            upload_date = info.get("upload_date", "N/A")
-            description = info.get("description", "N/A")
-        except KeyError:
-            title = "N/A"
-            duration = 0
-            uploader = "N/A"
-            upload_date = "N/A"
-            description = "N/A"
-
-        save_transcription(yt_url, text, title, duration, uploader, upload_date, description)
+        save_transcription(yt_url, text, info)
         return text
 
-def save_transcription(yt_url, transcription, title, duration, uploader, upload_date, description):
+def save_transcription(yt_url, transcription, info):
+    """
+    Saves the transcription data to the dataset.
+
+    Args:
+        yt_url (str): URL of the YouTube video.
+        transcription (str): The transcribed text.
+        info (dict): Additional information about the video.
+    """
     data = {
         "url": yt_url,
         "transcription": transcription,
-        "title": title,
-        "duration": duration,
-        "uploader": uploader,
-        "upload_date": upload_date,
-        "description": description,
+        "title": info.get("title", "N/A"),
+        "duration": info.get("duration", 0),
+        "uploader": info.get("uploader", "N/A"),
+        "upload_date": info.get("upload_date", "N/A"),
+        "description": info.get("description", "N/A"),
         "datetime": datetime.now().isoformat()
     }
-
-    # Load the existing dataset
     dataset = load_dataset(DATASET_NAME, split="train")
-
-    # Convert to pandas dataframe
     df = dataset.to_pandas()
-
-    # Append the new data
     df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)
-
-    # Convert back to dataset
     updated_dataset = Dataset.from_pandas(df)
-
-    # Push the updated dataset to the hub
     dataset_dict = DatasetDict({"train": updated_dataset})
     dataset_dict.push_to_hub(DATASET_NAME)
 
+@spaces.GPU
+def transcribe(inputs, task):
+    """
+    Transcribes an audio input.
+
+    Args:
+        inputs (str): Path to the audio file.
+        task (str): Task to perform - "transcribe" or "translate".
+
+    Returns:
+        str: The transcription of the audio.
+    """
+    if inputs is None:
+        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
+    text = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)["text"]
+    return text
+
+# Gradio App Setup
 demo = gr.Blocks()
 
+# YouTube Transcribe Tab
 yt_transcribe_interface = gr.Interface(
     fn=yt_transcribe,
     inputs=[
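`save_transcription` round-trips the entire split through pandas and re-pushes it for every new video, so each save costs more as the dataset grows. A sketch of the same append using only `datasets` primitives; it assumes the new record's keys match the split's features exactly:

```python
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset

DATASET_NAME = "dwb2023/yt-transcripts-v3"  # as defined in app.py

def append_and_push(data):
    """Hypothetical pandas-free variant of the tail of save_transcription."""
    dataset = load_dataset(DATASET_NAME, split="train")
    updated = concatenate_datasets([dataset, Dataset.from_list([data])])
    DatasetDict({"train": updated}).push_to_hub(DATASET_NAME)
```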
@@ -185,20 +201,45 @@ yt_transcribe_interface = gr.Interface(
         gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
     ],
     outputs="text",
-    title="👂👁️👅👃✋ KnowledgeScribe 📝 🧠💡🎓🚀",
+    title="YouTube Transcription",
     description=(
-        f"""**KnowledgeScribe** is your all-in-one transcription and summarization tool designed to help your LLM extract and distill knowledge from various sources, including YouTube videos and Arxiv papers.
-        \n\nCurrently leverages the following datasets and models:
-        \n- [{DATASET_NAME}](https://huggingface.co/datasets/{DATASET_NAME}/viewer) dataset
-        \n- [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) model
-        """
+        f"Transcribe and archive YouTube videos using the {MODEL_NAME} model. "
+        "The transcriptions are saved for future reference, so repeated requests are faster!"
     ),
     allow_flagging="never",
 )
 
+# Microphone Transcribe Tab
+mf_transcribe_interface = gr.Interface(
+    fn=transcribe,
+    inputs=[
+        gr.Audio(sources="microphone", type="filepath"),
+        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
+    ],
+    outputs="text",
+    title="Microphone Transcription",
+    description="Transcribe audio captured through your microphone.",
+    allow_flagging="never",
+)
+
+# File Upload Transcribe Tab
+file_transcribe_interface = gr.Interface(
+    fn=transcribe,
+    inputs=[
+        gr.Audio(sources="upload", type="filepath", label="Audio file"),
+        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
+    ],
+    outputs="text",
+    title="Audio File Transcription",
+    description="Transcribe uploaded audio files of arbitrary length.",
+    allow_flagging="never",
+)
+
+# Organize Tabs in the Gradio App
 with demo:
     gr.TabbedInterface(
-        [yt_transcribe_interface], ["YouTube"]
+        [yt_transcribe_interface, mf_transcribe_interface, file_transcribe_interface],
+        ["YouTube", "Microphone", "Audio File"]
     )
 
-demo.queue().launch()
+demo.queue().launch()
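Since the `gr.Blocks()` context holds nothing but the `gr.TabbedInterface`, the same app can be expressed with `gr.TabbedInterface` directly; `queue()` is what lets requests line up for the `@spaces.GPU`-decorated functions. A sketch of that variant (the `max_size` cap is an assumption, not part of this commit):

```python
import gradio as gr

# yt_transcribe_interface, mf_transcribe_interface and file_transcribe_interface
# are the three gr.Interface objects built above.
demo = gr.TabbedInterface(
    [yt_transcribe_interface, mf_transcribe_interface, file_transcribe_interface],
    ["YouTube", "Microphone", "Audio File"],
)
demo.queue(max_size=20).launch()
```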
 