einanao committed
Commit 0b43831
1 Parent(s): 552a118

draft streamlit app

Files changed (2):
  1. app.py +262 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,262 @@
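+ # Draft Streamlit app: download a YouTube video's audio, transcribe it with
+ # whisper, estimate each transcript segment's information density with a small
+ # language model, and re-render the audio so that information-dense passages
+ # play slower and sparse ones play faster.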
+ import os
+ import json
+
+ import numpy as np
+ import ffmpeg
+ import whisper
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from sklearn.tree import DecisionTreeRegressor
+ import torch
+ import youtube_dl
+ import pandas as pd
+ import streamlit as st
+ import altair as alt
+
+ DATA_DIR = "./data"
+ if not os.path.exists(DATA_DIR):
+     os.makedirs(DATA_DIR)
+
+ YDL_OPTS = {
+     "download_archive": os.path.join(DATA_DIR, "archive.txt"),
+     "format": "bestaudio/best",
+     "outtmpl": os.path.join(DATA_DIR, "%(title)s.%(ext)s"),
+     "postprocessors": [
+         {
+             "key": "FFmpegExtractAudio",
+             "preferredcodec": "mp3",
+             "preferredquality": "192",
+         }
+     ],
+ }
+
+ # Define the device first so the model can be moved onto it; the inputs are
+ # moved to the same device in compute_info_densities.
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ llm = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to(device)
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
+
+
+ def download(url, ydl_opts):
+     with youtube_dl.YoutubeDL(ydl_opts) as ydl:
+         result = ydl.extract_info(url)
+         fname = ydl.prepare_filename(result)
+     return fname
+
+
+ def transcribe(audio_path, transcript_path):
+     # Cache the whisper transcript on disk so repeated runs skip transcription.
+     if os.path.exists(transcript_path):
+         with open(transcript_path, "r") as f:
+             result = json.load(f)
+     else:
+         whisper_model = whisper.load_model("base")
+         result = whisper_model.transcribe(audio_path)
+         with open(transcript_path, "w") as f:
+             json.dump(result, f)
+     return result["segments"]
+
+
+ def compute_seg_durations(segments):
+     return [s["end"] - s["start"] for s in segments]
+
+
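+ # A segment's information density is its total negative log-likelihood under
+ # the LM (its surprisal, in nats) divided by its duration in seconds: roughly
+ # how much new information the speaker delivers per second.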
+ def compute_info_densities(
+     segments, seg_durations, llm, tokenizer, device, ctxt_len=512
+ ):
+     seg_encodings = [tokenizer(seg["text"], return_tensors="pt") for seg in segments]
+     input_ids = [enc.input_ids.to(device) for enc in seg_encodings]
+     seg_lens = [x.shape[1] for x in input_ids]
+     cat_input_ids = torch.cat(input_ids, dim=1)
+     end = 0
+     seg_nlls = []
+     n = cat_input_ids.shape[1]
+     for seg_len in seg_lens:
+         end = min(n, end + seg_len)
+         start = max(0, end - ctxt_len)
+         ctxt_ids = cat_input_ids[:, start:end]
+         target_ids = ctxt_ids.clone()
+         # Mask out the context tokens so only the current segment is scored.
+         target_ids[:, :-seg_len] = -100
+         with torch.no_grad():
+             avg_nll = llm(ctxt_ids, labels=target_ids).loss.cpu().numpy()
+         # The loss is averaged over unmasked tokens; rescale to a total NLL.
+         nll = avg_nll * seg_len
+         seg_nlls.append(nll)
+     seg_nlls = np.array(seg_nlls)
+     info_densities = seg_nlls / seg_durations
+     return info_densities
+
+
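+ # Smoothing: fit a depth-limited regression tree to the density sequence so the
+ # playback rate changes in at most max_leaf_nodes contiguous steps instead of
+ # jumping at every segment; min_sec_leaf keeps each step from being shorter
+ # than roughly that many seconds.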
+ def smooth_info_densities(info_densities, seg_durations, max_leaf_nodes, min_sec_leaf):
+     min_samples_leaf = int(np.ceil(min_sec_leaf / np.mean(seg_durations)))
+     tree = DecisionTreeRegressor(
+         max_leaf_nodes=max_leaf_nodes, min_samples_leaf=min_samples_leaf
+     )
+     X = np.arange(0, len(info_densities), 1)[:, np.newaxis]
+     tree.fit(X, info_densities)
+     smoothed_info_densities = tree.predict(X)
+     return smoothed_info_densities
+
+
+ def squash_segs(segments, info_densities):
+     # Merge consecutive segments that share the same smoothed density into a
+     # single clip, so each clip gets one playback rate.
+     start = segments[0]["start"]
+     seg_times = []
+     seg_densities = [info_densities[0]]
+     for i in range(1, len(segments)):
+         curr_density = info_densities[i]
+         if curr_density != info_densities[i - 1]:
+             seg = segments[i]
+             seg_start = seg["start"]
+             seg_times.append((start, seg_start))
+             seg_densities.append(curr_density)
+             start = seg_start
+     seg_times.append((start, segments[-1]["end"]))
+     return seg_times, seg_densities
+
+
+ def compute_speedups(info_densities):
+     # Aim for a constant information rate: clips that are denser than average
+     # get a speedup below 1 (slower), sparser clips get a speedup above 1.
+     avg_density = np.mean(info_densities)
+     speedups = avg_density / info_densities
+     return speedups
+
+
+ def compute_actual_speedup(durations, speedups, total_duration):
+     spedup_durations = durations / speedups
+     spedup_total_duration = spedup_durations.sum()
+     actual_speedup_factor = total_duration / spedup_total_duration
+     return spedup_total_duration, actual_speedup_factor
+
+
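+ # Rounding and clipping the per-clip rates changes the overall speedup, so
+ # bisect on a global multiplier until the realized factor is within thresh of
+ # the requested one.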
+ def postprocess_speedups(
+     speedups, factor, min_speedup, max_speedup, durations, total_duration, thresh=0.01
+ ):
+     assert min_speedup <= factor <= max_speedup
+     tuned_factor = np.array([factor / 10, factor * 10])
+     actual_speedup_factor = None
+     while (
+         actual_speedup_factor is None
+         or abs(actual_speedup_factor - factor) / factor > thresh
+     ):
+         mid = tuned_factor.mean()
+         tuned_speedups = speedups * mid
+         tuned_speedups = np.round(tuned_speedups, decimals=2)
+         tuned_speedups = np.clip(tuned_speedups, min_speedup, max_speedup)
+         _, actual_speedup_factor = compute_actual_speedup(
+             durations, tuned_speedups, total_duration
+         )
+         tuned_factor[0 if actual_speedup_factor < factor else 1] = mid
+     return tuned_speedups
+
+
+ def cat_clips(seg_times, speedups, audio_path, output_path):
+     # Trim out each clip, retime it with atempo, and concatenate the results.
+     if os.path.exists(output_path):
+         os.remove(output_path)
+     in_file = ffmpeg.input(audio_path)
+     segs = []
+     for (start, end), speedup in zip(seg_times, speedups):
+         seg = in_file.filter("atrim", start=start, end=end).filter("atempo", speedup)
+         segs.append(seg)
+     cat = ffmpeg.concat(*segs, v=0, a=1)
+     cat.output(output_path).run()
+
+
+ def format_duration(duration):
+     # Render a duration in seconds as HH:MM:SS.
+     s = duration % 60
+     m = (duration // 60) % 60
+     h = duration // 3600
+     return "%02d:%02d:%02d" % (h, m, s)
+
+
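+ # End-to-end pipeline: download audio, transcribe it, score and smooth the
+ # information densities, derive per-clip speedups, and render the output.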
+ def strike(url, speedup_factor, min_speedup, max_speedup, max_num_segments):
+     min_speedup = max(0.5, min_speedup)  # ffmpeg's atempo filter bottoms out at 0.5x
+
+     # "bestaudio" is not guaranteed to be .m4a, so strip whatever extension
+     # youtube-dl picked; the postprocessor converts the download to mp3.
+     name = download(url, YDL_OPTS)
+     name = os.path.splitext(os.path.basename(name))[0]
+
+     audio_path = os.path.join(DATA_DIR, "%s.mp3" % name)
+     transcript_path = os.path.join(DATA_DIR, "%s.json" % name)
+     output_path = os.path.join(DATA_DIR, "%s_smooth.mp3" % name)
+
+     segments = transcribe(audio_path, transcript_path)
+
+     seg_durations = compute_seg_durations(segments)
+
+     info_densities = compute_info_densities(
+         segments, seg_durations, llm, tokenizer, device
+     )
+
+     total_duration = segments[-1]["end"] - segments[0]["start"]
+     min_sec_leaf = total_duration / max_num_segments
+     smoothed_info_densities = smooth_info_densities(
+         info_densities, seg_durations, max_num_segments, min_sec_leaf
+     )
+
+     squashed_times, squashed_densities = squash_segs(segments, smoothed_info_densities)
+     squashed_durations = np.array([end - start for start, end in squashed_times])
+
+     speedups = compute_speedups(squashed_densities)
+     speedups = postprocess_speedups(
+         speedups,
+         speedup_factor,
+         min_speedup,
+         max_speedup,
+         squashed_durations,
+         total_duration,
+     )
+
+     cat_clips(squashed_times, speedups, audio_path, output_path)
+
+     spedup_total_duration, actual_speedup_factor = compute_actual_speedup(
+         squashed_durations, speedups, total_duration
+     )
+     st.write("original duration: %s" % format_duration(total_duration))
+     st.write("new duration: %s" % format_duration(spedup_total_duration))
+     st.write("speedup: %0.2f" % actual_speedup_factor)
+
+     # Plot the per-segment information rate; dividing the NLL (in nats) by
+     # ln(2) converts it to bits, so the y-axis is bits per second.
+     times = np.array([(seg["start"] + seg["end"]) / 2 for seg in segments])
+     times /= 60
+     annotations = [seg["text"] for seg in segments]
+     data = [times, info_densities / np.log(2), annotations]
+     cols = ["time (minutes)", "bits per second", "transcript"]
+     df = pd.DataFrame(list(zip(*data)), columns=cols)
+
+     lines = (
+         alt.Chart(df, title="information rate")
+         .mark_line(color="gray", opacity=0.5)
+         .encode(
+             x=cols[0],
+             y=cols[1],
+         )
+     )
+     dots = (
+         alt.Chart(df)
+         .mark_circle(size=50, opacity=1)
+         .encode(x=cols[0], y=cols[1], tooltip=["transcript"])
+     )
+     st.altair_chart((lines + dots).interactive(), use_container_width=True)
+
+     # Repeat each clip's endpoints so the speedup plots as a step function.
+     times = sum([list(x) for x in squashed_times], [])
+     times = np.array(times)
+     times /= 60
+     data = [times, np.repeat(speedups, 2)]
+     cols = ["time (minutes)", "speedup"]
+     df = pd.DataFrame(list(zip(*data)), columns=cols)
+     st.line_chart(df, x=cols[0], y=cols[1])
+
+     return output_path
+
+
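+ # Streamlit UI: "speedup" sets the overall target factor; "variance in speedup
+ # over time" caps how many distinct playback rates the smoothing tree may use.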
+ with st.form("my_form"):
+     url = st.text_input(
+         "youtube url", value="https://www.youtube.com/watch?v=_3MBQm7GFIM"
+     )
+     speedup_factor = st.slider("speedup", min_value=1.0, max_value=10.0, value=1.5)
+     min_speedup = 1
+     max_speedup = st.slider("maximum speedup", min_value=1.0, max_value=10.0, value=2.0)
+     speedup_factor = min(speedup_factor, max_speedup)
+     max_num_segments = st.slider(
+         "variance in speedup over time", min_value=2, max_value=100, value=20
+     )
+     submitted = st.form_submit_button("submit")
+     if submitted:
+         output_path = strike(
+             url, speedup_factor, min_speedup, max_speedup, max_num_segments
+         )
+         st.audio(output_path)
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ ffmpeg-python==0.2.0
+ numpy==1.23.4
+ scikit-learn==1.1.3
+ torch==1.13.0
+ transformers==4.24.0
+ whisper @ git+https://github.com/openai/whisper.git@9f70a352f9f8630ab3aa0d06af5cb9532bd8c21d
+ youtube-dl==2021.12.17
+ streamlit==1.14.0
+ pandas==1.5.1
+ altair==4.2.0