yiyixuxu commited on
Commit
8b1feb9
·
1 Parent(s): 84ee260

initial commit

Browse files
Files changed (3) hide show
  1. app.py +159 -0
  2. packages.txt +2 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import clip
3
+ import cv2, youtube_dl
4
+ from PIL import Image,ImageDraw, ImageFont
5
+ import os
6
+ from functools import partial
7
+ from multiprocessing.pool import Pool
8
+ import shutil
9
+ from pathlib import Path
10
+ import numpy as np
11
+ import datetime
12
+ import gradio as gr
13
+
14
+
15
+ # load model and preprocess
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ model, preprocess = clip.load("ViT-B/32")
18
+
19
+ def select_video_format(url, format_note='480p', ext='mp4'):
20
+ defaults = ['480p', '360p','240p','144p']
21
+ ydl_opts = {}
22
+ ydl = youtube_dl.YoutubeDL(ydl_opts)
23
+ info_dict = ydl.extract_info(url, download=False)
24
+ formats = info_dict.get('formats', None)
25
+ available_format_notes = set([f['format_note'] for f in formats])
26
+ if format_note not in available_format_notes:
27
+ format_note = [d for d in defaults if d in available_format_notes][0]
28
+ formats = [f for f in formats if f['format_note'] == format_note and f['ext'] == ext]
29
+ format = formats[0]
30
+ format_id = format.get('format_id', None)
31
+ fps = format.get('fps', None)
32
+ print(f'format selected: {format}')
33
+ return(format_id, fps)
34
+
35
+ def download_video(url,format_id):
36
+ ydl_opts = {
37
+ 'format':format_id,
38
+ 'outtmpl': "%(id)s.%(ext)s"}
39
+ meta = youtube_dl.YoutubeDL(ydl_opts).extract_info(url)
40
+ save_location = meta['id'] + '.' + meta['ext']
41
+ return(save_location)
42
+
43
+ def read_frames(dest_path):
44
+ original_images = []
45
+ images = []
46
+ for filename in sorted(dest_path.glob('*.jpg'),key=lambda p: int(p.stem)):
47
+ image = Image.open(filename).convert("RGB")
48
+ original_images.append(image)
49
+ images.append(preprocess(image))
50
+ return original_images, images
51
+
52
+ def process_video_parallel(url, skip_frames, dest_path, process_number):
53
+ cap = cv2.VideoCapture(url)
54
+ num_processes = os.cpu_count()
55
+ chunks_per_process = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) // (num_processes * skip_frames)
56
+ count = skip_frames * chunks_per_process * process_number
57
+ print(f"worker: {process_number}, process frames {count} ~ {skip_frames * chunks_per_process * (process_number + 1)}")
58
+ while count < skip_frames * chunks_per_process * (process_number + 1) :
59
+ cap.set(cv2.CAP_PROP_POS_FRAMES, count)
60
+ ret, frame = cap.read()
61
+ if not ret:
62
+ break
63
+ filename =f"{dest_path}/{count}.jpg"
64
+ cv2.imwrite(filename, frame)
65
+ count += skip_frames # Skip 300 frames i.e. 10 seconds for 30 fps
66
+ cap.release()
67
+
68
+
69
+ def vid2frames(url, sampling_interval=1, ext='mp4'):
70
+ # create folder for extracted frames - if folder exists, delete and create a new one
71
+ dest_path = Path('frames')
72
+ try:
73
+ dest_path.mkdir(parents=True)
74
+ except FileExistsError:
75
+ shutil.rmtree(dest_path)
76
+ dest_path.mkdir(parents=True)
77
+ # figure out the format for download,
78
+ # by default select 480p, if not available, choose the best format available
79
+ # mp4
80
+ format_id, fps = select_video_format(url, format_note='480p', ext='mp4')
81
+ # download the video
82
+ video = download_video(url,format_id)
83
+ # calculate skip_frames
84
+ try:
85
+ skip_frames = int(fps * sampling_interval)
86
+ except:
87
+ skip_frames = int(30 * sampling_interval)
88
+ print(f'video saved at: {video}, fps:{fps}, skip_frames: {skip_frames}')
89
+ # extract video frames at given sampling interval with multiprocessing -
90
+ print('extracting frames...')
91
+ n_workers = os.cpu_count()
92
+ with Pool(n_workers) as pool:
93
+ pool.map(partial(process_video_parallel, video, skip_frames, dest_path), range(n_workers))
94
+ return dest_path
95
+
96
+
97
+ def captioned_strip(images, caption=None, times=None, rows=1):
98
+ increased_h = 0 if caption is None else 30
99
+ w, h = images[0].size[0], images[0].size[1]
100
+ img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
101
+ for i, img_ in enumerate(images):
102
+ img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
103
+ if caption is not None:
104
+ draw = ImageDraw.Draw(img)
105
+ font = ImageFont.truetype(
106
+ "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 16
107
+ )
108
+ font_small = ImageFont.truetype(
109
+ "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 12
110
+ )
111
+ draw.text((20, 3), caption, (255, 255, 255), font=font)
112
+ for i,ts in enumerate(times):
113
+ draw.text((
114
+ (i % rows) * w + 40 , #column poistion
115
+ i // rows * h + 33) # row position
116
+ , ts,
117
+ (255, 255, 255), font=font_small)
118
+ return img
119
+
120
+ def run_inference(url, sampling_interval, search_query):
121
+ path_frames = vid2frames(url,sampling_interval)
122
+ original_images, images = read_frames(path_frames)
123
+ image_input = torch.tensor(np.stack(images)).to(device)
124
+ with torch.no_grad():
125
+ image_features = model.encode_image(image_input)
126
+ text_features = model.encode_text(clip.tokenize(search_query).to(device))
127
+
128
+ image_features /= image_features.norm(dim=-1, keepdim=True)
129
+ text_features /= text_features.norm(dim=-1, keepdim=True)
130
+
131
+ similarity = (100.0 * image_features @ text_features.T)
132
+ values, indices = similarity.topk(4, dim=0)
133
+
134
+ best_frames = [original_images[ind] for ind in indices]
135
+ times = [f'{datetime.timedelta(seconds = ind[0].item() * sampling_interval)}' for ind in indices]
136
+ image_output = captioned_strip(best_frames,search_query, times,2)
137
+ title = search_query
138
+ return(title, image_output)
139
+
140
+ inputs = [gr.inputs.Textbox(label="Give us the link to your youtube video!"),
141
+ gr.Number(5),
142
+ gr.inputs.Textbox(label="What do you want to search?")]
143
+ outputs = [
144
+ gr.outputs.HTML(label=""), # To be used as title
145
+ gr.outputs.Image(label=""),
146
+ ]
147
+
148
+ gr.Interface(
149
+ run_inference,
150
+ inputs=inputs,
151
+ outputs=outputs,
152
+ title="It Happened One Frame",
153
+ description='A CLIP-based app that search video frame based on text',
154
+ examples=[
155
+ ['https://youtu.be/v1rkzUIL8oc', 1, "James Cagney dancing down the stairs"],
156
+ ['https://youtu.be/k4R5wZs8cxI', 1, "James Cagney smashes a grapefruit into Mae Clarke's face"]
157
+ ]
158
+ ).launch(debug=True,enable_queue=True)
159
+
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ python3-opencv
2
+ libssl-dev
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ git+https://github.com/openai/CLIP.git
2
+ torch
3
+ youtube_dl
4
+ opencv-python