ncoop57
committed on
Commit
•
52636a1
1
Parent(s):
7f85c65
Update app with new example and configuration tools
Browse files
app.py
CHANGED
@@ -11,41 +11,44 @@ from PIL import Image
|
|
11 |
|
12 |
@st.cache(allow_output_mutation=True, max_entries=1)
|
13 |
def get_model():
|
|
|
14 |
clip = CLIPModel()
|
15 |
-
|
16 |
-
return
|
17 |
|
18 |
|
19 |
-
def get_embedding(
|
20 |
-
text_emb =
|
21 |
|
22 |
# Encode an image:
|
23 |
images = []
|
24 |
for img in video:
|
25 |
images.append(Image.fromarray(img))
|
26 |
-
img_embs =
|
27 |
|
28 |
return text_emb, img_embs
|
29 |
|
30 |
-
def find_frames(url,
|
31 |
-
text.text("
|
32 |
probe = ffmpeg.probe(url)
|
33 |
video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
|
34 |
width = int(video_stream['width'])
|
35 |
height = int(video_stream['height'])
|
36 |
out, _ = (
|
37 |
ffmpeg
|
38 |
-
.input(url, t=
|
39 |
.output('pipe:', format='rawvideo', pix_fmt='rgb24')
|
40 |
.run(capture_stdout=True)
|
41 |
)
|
|
|
|
|
42 |
video = (
|
43 |
np
|
44 |
.frombuffer(out, np.uint8)
|
45 |
.reshape([-1, height, width, 3])
|
46 |
)[::10]
|
47 |
|
48 |
-
txt_embd, img_embds = get_embedding(
|
49 |
cos_scores = np.array(util.cos_sim(txt_embd, img_embds))
|
50 |
ids = np.argsort(cos_scores)[0][-top_k:]
|
51 |
|
@@ -53,13 +56,25 @@ def find_frames(url, model, desc, top_k, text):
|
|
53 |
text.empty()
|
54 |
st.image(imgs)
|
55 |
|
56 |
-
|
|
|
|
|
|
|
57 |
st.title("Introducing Youtube CLIFS")
|
|
|
|
|
58 |
|
59 |
-
def clifs_page(
|
60 |
st.title("CLIFS")
|
61 |
|
62 |
st.sidebar.markdown("### Controls:")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
top_k = st.sidebar.slider(
|
64 |
"Top K",
|
65 |
min_value=1,
|
@@ -68,26 +83,22 @@ def clifs_page(model):
|
|
68 |
)
|
69 |
desc = st.sidebar.text_input(
|
70 |
"Search Description",
|
71 |
-
value="
|
72 |
help="Text description of what you want to find in the video",
|
73 |
)
|
74 |
url = st.sidebar.text_input(
|
75 |
"Youtube Video URL",
|
76 |
-
value='https://youtu.be/
|
77 |
help="Youtube video you'd like to search through",
|
78 |
)
|
79 |
|
80 |
submit_button = st.sidebar.button("Search")
|
81 |
if submit_button:
|
82 |
-
text = st.text("Downloading video...")
|
83 |
-
hook = lambda d: my_hook(d, )
|
84 |
ydl_opts = {"format": "mp4[height=360]"}
|
85 |
with youtube_dl.YoutubeDL(ydl_opts) as ydl:
|
86 |
info_dict = ydl.extract_info(url, download=False)
|
87 |
video_url = info_dict.get("url", None)
|
88 |
-
find_frames(video_url,
|
89 |
-
print(video_url)
|
90 |
-
# ydl.download([url])
|
91 |
|
92 |
PAGES = {
|
93 |
"Home": main_page,
|
@@ -99,12 +110,12 @@ PAGES = {
|
|
99 |
def run():
|
100 |
st.set_page_config(page_title="Youtube CLIFS")
|
101 |
# main body
|
102 |
-
|
103 |
|
104 |
st.sidebar.title('Navigation')
|
105 |
selection = st.sidebar.radio("Go to", list(PAGES.keys()))
|
106 |
|
107 |
-
page = PAGES[selection](
|
108 |
|
109 |
|
110 |
|
|
|
11 |
|
12 |
@st.cache(allow_output_mutation=True, max_entries=1)
def get_model():
    """Build the text encoder and the vision encoder used for frame search.

    Both models are moved to the CPU in float32. The function is wrapped in
    ``st.cache`` with a single cache entry (presumably so the models are not
    rebuilt on every Streamlit rerun).

    Returns:
        A ``(txt_model, vis_model)`` tuple of ``SentenceTransformer`` objects.
    """
    cpu = torch.device('cpu')

    # Text side: the multilingual CLIP text encoder.
    txt_model = SentenceTransformer('clip-ViT-B-32-multilingual-v1')
    txt_model = txt_model.to(dtype=torch.float32, device=cpu)

    # Vision side: a CLIPModel wrapped as the only module of a SentenceTransformer.
    vis_model = SentenceTransformer(modules=[CLIPModel()])
    vis_model = vis_model.to(dtype=torch.float32, device=cpu)

    return txt_model, vis_model
|
18 |
|
19 |
|
20 |
+
def get_embedding(txt_model, vis_model, query, video):
    """Embed a text query and a sequence of video frames.

    Args:
        txt_model: Object whose ``encode`` method embeds the text query.
        vis_model: Object whose ``encode`` method embeds a list of PIL images.
        query: Text description to embed.
        video: Iterable of frames; each frame is passed to ``Image.fromarray``.

    Returns:
        A ``(text_emb, img_embs)`` tuple with the query embedding and the
        frame embeddings, both computed on the CPU.
    """
    query_vec = txt_model.encode(query, device='cpu')

    # Each raw frame is converted to a PIL image before it is encoded.
    frames = [Image.fromarray(frame) for frame in video]
    frame_vecs = vis_model.encode(frames, device='cpu')

    return query_vec, frame_vecs
|
30 |
|
31 |
+
def find_frames(url, txt_model, vis_model, desc, seconds, top_k):
|
32 |
+
text = st.text("Downloading video...")
|
33 |
probe = ffmpeg.probe(url)
|
34 |
video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
|
35 |
width = int(video_stream['width'])
|
36 |
height = int(video_stream['height'])
|
37 |
out, _ = (
|
38 |
ffmpeg
|
39 |
+
.input(url, t=seconds)
|
40 |
.output('pipe:', format='rawvideo', pix_fmt='rgb24')
|
41 |
.run(capture_stdout=True)
|
42 |
)
|
43 |
+
|
44 |
+
text.text("Processing video...")
|
45 |
video = (
|
46 |
np
|
47 |
.frombuffer(out, np.uint8)
|
48 |
.reshape([-1, height, width, 3])
|
49 |
)[::10]
|
50 |
|
51 |
+
txt_embd, img_embds = get_embedding(txt_model, vis_model, desc, video)
|
52 |
cos_scores = np.array(util.cos_sim(txt_embd, img_embds))
|
53 |
ids = np.argsort(cos_scores)[0][-top_k:]
|
54 |
|
|
|
56 |
text.empty()
|
57 |
st.image(imgs)
|
58 |
|
59 |
+
# Load the landing-page markdown once at import time; main_page renders it.
# The encoding is pinned to UTF-8 so decoding does not depend on the
# platform's locale default (e.g. cp1252 on Windows).
with open("HOME.md", "r", encoding="utf-8") as f:
    HOME_PAGE = f.read()
|
61 |
+
|
62 |
+
def main_page(txt_model, vis_model):
    """Render the home page: a title followed by the HOME_PAGE markdown.

    ``txt_model`` and ``vis_model`` are accepted but not used here; every
    entry in ``PAGES`` is called with the same two arguments.
    """
    st.title("Introducing Youtube CLIFS")
    # HOME_PAGE holds the markdown text read from HOME.md at import time.
    st.markdown(HOME_PAGE)
|
66 |
|
67 |
+
def clifs_page(txt_model, vis_model):
|
68 |
st.title("CLIFS")
|
69 |
|
70 |
st.sidebar.markdown("### Controls:")
|
71 |
+
seconds = st.sidebar.slider(
|
72 |
+
"How many seconds of video to consider?",
|
73 |
+
min_value=10,
|
74 |
+
max_value=120,
|
75 |
+
value=60,
|
76 |
+
step=1,
|
77 |
+
)
|
78 |
top_k = st.sidebar.slider(
|
79 |
"Top K",
|
80 |
min_value=1,
|
|
|
83 |
)
|
84 |
desc = st.sidebar.text_input(
|
85 |
"Search Description",
|
86 |
+
value="Pancake in the shape of an otter", # panqueque en forma de nutria
|
87 |
help="Text description of what you want to find in the video",
|
88 |
)
|
89 |
url = st.sidebar.text_input(
|
90 |
"Youtube Video URL",
|
91 |
+
value='https://youtu.be/xUv6XgPwGaQ',
|
92 |
help="Youtube video you'd like to search through",
|
93 |
)
|
94 |
|
95 |
submit_button = st.sidebar.button("Search")
|
96 |
if submit_button:
|
|
|
|
|
97 |
ydl_opts = {"format": "mp4[height=360]"}
|
98 |
with youtube_dl.YoutubeDL(ydl_opts) as ydl:
|
99 |
info_dict = ydl.extract_info(url, download=False)
|
100 |
video_url = info_dict.get("url", None)
|
101 |
+
find_frames(video_url, txt_model, vis_model, desc, seconds, top_k)
|
|
|
|
|
102 |
|
103 |
PAGES = {
|
104 |
"Home": main_page,
|
|
|
110 |
def run():
    """Entry point: configure the page, load the models, and show the chosen page."""
    st.set_page_config(page_title="Youtube CLIFS")

    # Load both encoders up front so they can be handed to whichever page renders.
    txt_model, vis_model = get_model()

    # Sidebar navigation: one radio option per key in PAGES.
    st.sidebar.title('Navigation')
    selection = st.sidebar.radio("Go to", list(PAGES))

    # Look up the page callable for the selection and render it.
    render_page = PAGES[selection]
    render_page(txt_model, vis_model)
|
119 |
|
120 |
|
121 |
|