ncoop57 committed on
Commit
52636a1
1 Parent(s): 7f85c65

Update app with new example and configuration tools

Browse files
Files changed (1) hide show
  1. app.py +31 -20
app.py CHANGED
@@ -11,41 +11,44 @@ from PIL import Image
11
 
12
  @st.cache(allow_output_mutation=True, max_entries=1)
13
  def get_model():
 
14
  clip = CLIPModel()
15
- model = SentenceTransformer(modules=[clip]).to(dtype=torch.float32, device=torch.device('cpu'))
16
- return model
17
 
18
 
19
def get_embedding(model, query, video):
    """Encode a text query and every frame of a video with the same model.

    Args:
        model: Object exposing ``encode(data, device=...)``.
        query: Text description to encode.
        video: Iterable of frames; each frame is handed to
            ``Image.fromarray`` (presumably numpy arrays -- confirm with caller).

    Returns:
        A ``(text_emb, img_embs)`` tuple of whatever ``model.encode`` returns
        for the query and for the list of frame images.
    """
    # The query is encoded first, before any frame is converted.
    query_vec = model.encode(query, device='cpu')

    # Wrap each raw frame as a PIL image so the model can consume it.
    frames = [Image.fromarray(frame) for frame in video]

    return query_vec, model.encode(frames, device='cpu')
29
 
30
- def find_frames(url, model, desc, top_k, text):
31
- text.text("Processing video...")
32
  probe = ffmpeg.probe(url)
33
  video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
34
  width = int(video_stream['width'])
35
  height = int(video_stream['height'])
36
  out, _ = (
37
  ffmpeg
38
- .input(url, t=60)
39
  .output('pipe:', format='rawvideo', pix_fmt='rgb24')
40
  .run(capture_stdout=True)
41
  )
 
 
42
  video = (
43
  np
44
  .frombuffer(out, np.uint8)
45
  .reshape([-1, height, width, 3])
46
  )[::10]
47
 
48
- txt_embd, img_embds = get_embedding(model, desc, video)
49
  cos_scores = np.array(util.cos_sim(txt_embd, img_embds))
50
  ids = np.argsort(cos_scores)[0][-top_k:]
51
 
@@ -53,13 +56,25 @@ def find_frames(url, model, desc, top_k, text):
53
  text.empty()
54
  st.image(imgs)
55
 
56
def main_page(model):
    """Render the home page, which consists only of the title.

    Args:
        model: Accepted so every page callable shares one signature; it is
            not used here.
    """
    st.title("Introducing Youtube CLIFS")
 
 
58
 
59
- def clifs_page(model):
60
  st.title("CLIFS")
61
 
62
  st.sidebar.markdown("### Controls:")
 
 
 
 
 
 
 
63
  top_k = st.sidebar.slider(
64
  "Top K",
65
  min_value=1,
@@ -68,26 +83,22 @@ def clifs_page(model):
68
  )
69
  desc = st.sidebar.text_input(
70
  "Search Description",
71
- value="Two white puppies",
72
  help="Text description of what you want to find in the video",
73
  )
74
  url = st.sidebar.text_input(
75
  "Youtube Video URL",
76
- value='https://youtu.be/I3AaW9ZevIU',
77
  help="Youtube video you'd like to search through",
78
  )
79
 
80
  submit_button = st.sidebar.button("Search")
81
  if submit_button:
82
- text = st.text("Downloading video...")
83
- hook = lambda d: my_hook(d, )
84
  ydl_opts = {"format": "mp4[height=360]"}
85
  with youtube_dl.YoutubeDL(ydl_opts) as ydl:
86
  info_dict = ydl.extract_info(url, download=False)
87
  video_url = info_dict.get("url", None)
88
- find_frames(video_url, model, desc, top_k, text)
89
- print(video_url)
90
- # ydl.download([url])
91
 
92
  PAGES = {
93
  "Home": main_page,
@@ -99,12 +110,12 @@ PAGES = {
99
  def run():
100
  st.set_page_config(page_title="Youtube CLIFS")
101
  # main body
102
- model = get_model()
103
 
104
  st.sidebar.title('Navigation')
105
  selection = st.sidebar.radio("Go to", list(PAGES.keys()))
106
 
107
- page = PAGES[selection](model)
108
 
109
 
110
 
 
11
 
12
  @st.cache(allow_output_mutation=True, max_entries=1)
13
  def get_model():
14
+ txt_model = SentenceTransformer('clip-ViT-B-32-multilingual-v1').to(dtype=torch.float32, device=torch.device('cpu'))
15
  clip = CLIPModel()
16
+ vis_model = SentenceTransformer(modules=[clip]).to(dtype=torch.float32, device=torch.device('cpu'))
17
+ return txt_model, vis_model
18
 
19
 
20
def get_embedding(txt_model, vis_model, query, video):
    """Encode a text query with one model and video frames with another.

    Args:
        txt_model: Object exposing ``encode(data, device=...)``; used for the
            query.
        vis_model: Object exposing ``encode(data, device=...)``; used for the
            frame images.
        query: Text description to encode.
        video: Iterable of frames; each frame is handed to
            ``Image.fromarray`` (presumably numpy arrays -- confirm with caller).

    Returns:
        A ``(text_emb, img_embs)`` tuple of whatever the two ``encode`` calls
        return.
    """
    # The query is encoded first, before any frame is converted.
    query_vec = txt_model.encode(query, device='cpu')

    # Wrap each raw frame as a PIL image so the vision model can consume it.
    frames = [Image.fromarray(frame) for frame in video]

    return query_vec, vis_model.encode(frames, device='cpu')
30
 
31
+ def find_frames(url, txt_model, vis_model, desc, seconds, top_k):
32
+ text = st.text("Downloading video...")
33
  probe = ffmpeg.probe(url)
34
  video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
35
  width = int(video_stream['width'])
36
  height = int(video_stream['height'])
37
  out, _ = (
38
  ffmpeg
39
+ .input(url, t=seconds)
40
  .output('pipe:', format='rawvideo', pix_fmt='rgb24')
41
  .run(capture_stdout=True)
42
  )
43
+
44
+ text.text("Processing video...")
45
  video = (
46
  np
47
  .frombuffer(out, np.uint8)
48
  .reshape([-1, height, width, 3])
49
  )[::10]
50
 
51
+ txt_embd, img_embds = get_embedding(txt_model, vis_model, desc, video)
52
  cos_scores = np.array(util.cos_sim(txt_embd, img_embds))
53
  ids = np.argsort(cos_scores)[0][-top_k:]
54
 
 
56
  text.empty()
57
  st.image(imgs)
58
 
59
# Load the home-page markdown once at import time. The encoding is pinned to
# UTF-8 so the text decodes the same way regardless of the platform's locale
# default (open() otherwise falls back to the locale encoding).
with open("HOME.md", "r", encoding="utf-8") as f:
    HOME_PAGE = f.read()
61
+
62
def main_page(txt_model, vis_model):
    """Render the home page: the title followed by the ``HOME_PAGE`` markdown.

    Args:
        txt_model: Accepted so every page callable shares one signature; it is
            not used here.
        vis_model: Likewise unused on this page.
    """
    st.title("Introducing Youtube CLIFS")

    # HOME_PAGE is the module-level markdown text.
    st.markdown(HOME_PAGE)
66
 
67
+ def clifs_page(txt_model, vis_model):
68
  st.title("CLIFS")
69
 
70
  st.sidebar.markdown("### Controls:")
71
+ seconds = st.sidebar.slider(
72
+ "How many seconds of video to consider?",
73
+ min_value=10,
74
+ max_value=120,
75
+ value=60,
76
+ step=1,
77
+ )
78
  top_k = st.sidebar.slider(
79
  "Top K",
80
  min_value=1,
 
83
  )
84
  desc = st.sidebar.text_input(
85
  "Search Description",
86
+ value="Pancake in the shape of an otter", # panqueque en forma de nutria
87
  help="Text description of what you want to find in the video",
88
  )
89
  url = st.sidebar.text_input(
90
  "Youtube Video URL",
91
+ value='https://youtu.be/xUv6XgPwGaQ',
92
  help="Youtube video you'd like to search through",
93
  )
94
 
95
  submit_button = st.sidebar.button("Search")
96
  if submit_button:
 
 
97
  ydl_opts = {"format": "mp4[height=360]"}
98
  with youtube_dl.YoutubeDL(ydl_opts) as ydl:
99
  info_dict = ydl.extract_info(url, download=False)
100
  video_url = info_dict.get("url", None)
101
+ find_frames(video_url, txt_model, vis_model, desc, seconds, top_k)
 
 
102
 
103
  PAGES = {
104
  "Home": main_page,
 
110
  def run():
111
  st.set_page_config(page_title="Youtube CLIFS")
112
  # main body
113
+ txt_model, vis_model = get_model()
114
 
115
  st.sidebar.title('Navigation')
116
  selection = st.sidebar.radio("Go to", list(PAGES.keys()))
117
 
118
+ page = PAGES[selection](txt_model, vis_model)
119
 
120
 
121