Spaces:
Runtime error
Runtime error
watchtowerss
commited on
Commit
·
23d6e96
1
Parent(s):
98d86dc
memory usage reduce for tracking
Browse files- app.py +95 -45
- inpainter/base_inpainter.py +20 -4
- track_anything.py +21 -6
- tracker/.DS_Store +0 -0
app.py
CHANGED
@@ -8,15 +8,14 @@ import sys
|
|
8 |
sys.path.append(sys.path[0]+"/tracker")
|
9 |
sys.path.append(sys.path[0]+"/tracker/model")
|
10 |
from track_anything import TrackingAnything
|
11 |
-
from track_anything import parse_augment
|
12 |
import requests
|
13 |
import json
|
14 |
import torchvision
|
15 |
import torch
|
16 |
-
from tools.interact_tools import SamControler
|
17 |
-
from tracker.base_tracker import BaseTracker
|
18 |
from tools.painter import mask_painter
|
19 |
import psutil
|
|
|
20 |
try:
|
21 |
from mmcv.cnn import ConvModule
|
22 |
except:
|
@@ -71,6 +70,7 @@ def get_prompt(click_state, click_input):
|
|
71 |
return prompt
|
72 |
|
73 |
|
|
|
74 |
# extract frames from upload video
|
75 |
def get_frames_from_video(video_input, video_state):
|
76 |
"""
|
@@ -81,49 +81,72 @@ def get_frames_from_video(video_input, video_state):
|
|
81 |
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
82 |
"""
|
83 |
video_path = video_input
|
84 |
-
frames = []
|
|
|
|
|
|
|
85 |
|
|
|
|
|
86 |
operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")]
|
87 |
try:
|
88 |
cap = cv2.VideoCapture(video_path)
|
89 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
while cap.isOpened():
|
91 |
ret, frame = cap.read()
|
92 |
if ret == True:
|
93 |
current_memory_usage = psutil.virtual_memory().percent
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
98 |
break
|
99 |
else:
|
100 |
break
|
|
|
101 |
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
|
|
|
|
|
102 |
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
104 |
# initialize video_state
|
105 |
video_state = {
|
|
|
106 |
"video_name": os.path.split(video_path)[-1],
|
107 |
"origin_images": frames,
|
108 |
"painted_images": frames.copy(),
|
109 |
-
"masks": [np.zeros((
|
110 |
"logits": [None]*len(frames),
|
111 |
"select_frame_number": 0,
|
112 |
"fps": fps
|
113 |
}
|
114 |
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
|
115 |
model.samcontroler.sam_controler.reset_image()
|
116 |
-
model.samcontroler.sam_controler.set_image(
|
117 |
-
return video_state, video_info,
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
gr.update(visible=True), gr.update(visible=True), \
|
122 |
-
gr.update(visible=True), gr.update(visible=True), \
|
123 |
-
gr.update(visible=True, value=operation_log)
|
124 |
|
125 |
def run_example(example):
|
126 |
-
return
|
127 |
# get the select frame from gradio slider
|
128 |
def select_template(image_selection_slider, video_state, interactive_state):
|
129 |
|
@@ -134,21 +157,22 @@ def select_template(image_selection_slider, video_state, interactive_state):
|
|
134 |
# once select a new template frame, set the image in sam
|
135 |
|
136 |
model.samcontroler.sam_controler.reset_image()
|
137 |
-
model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
|
138 |
|
139 |
# update the masks when select a new template frame
|
140 |
# if video_state["masks"][image_selection_slider] is not None:
|
141 |
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
|
142 |
operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")]
|
143 |
|
144 |
-
return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
|
145 |
|
146 |
# set the tracking end frame
|
147 |
def get_end_number(track_pause_number_slider, video_state, interactive_state):
|
|
|
148 |
interactive_state["track_end_number"] = track_pause_number_slider
|
149 |
operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")]
|
150 |
|
151 |
-
return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
|
152 |
|
153 |
def get_resize_ratio(resize_ratio_slider, interactive_state):
|
154 |
interactive_state["resize_ratio"] = resize_ratio_slider
|
@@ -172,18 +196,18 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
|
|
172 |
|
173 |
# prompt for sam model
|
174 |
model.samcontroler.sam_controler.reset_image()
|
175 |
-
model.samcontroler.sam_controler.set_image(video_state["origin_images"][video_state["select_frame_number"]])
|
176 |
prompt = get_prompt(click_state=click_state, click_input=coordinate)
|
177 |
|
178 |
mask, logit, painted_image = model.first_frame_click(
|
179 |
-
image=video_state["origin_images"][video_state["select_frame_number"]],
|
180 |
points=np.array(prompt["input_point"]),
|
181 |
labels=np.array(prompt["input_label"]),
|
182 |
multimask=prompt["multimask_output"],
|
183 |
)
|
184 |
video_state["masks"][video_state["select_frame_number"]] = mask
|
185 |
video_state["logits"][video_state["select_frame_number"]] = logit
|
186 |
-
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
|
187 |
|
188 |
operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")]
|
189 |
return painted_image, video_state, interactive_state, operation_log
|
@@ -203,7 +227,7 @@ def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
|
203 |
|
204 |
def clear_click(video_state, click_state):
|
205 |
click_state = [[],[]]
|
206 |
-
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
207 |
operation_log = [("",""), ("Clear points history and refresh the image.","Normal")]
|
208 |
return template_frame, click_state, operation_log
|
209 |
|
@@ -216,7 +240,7 @@ def remove_multi_mask(interactive_state, mask_dropdown):
|
|
216 |
|
217 |
def show_mask(video_state, interactive_state, mask_dropdown):
|
218 |
mask_dropdown.sort()
|
219 |
-
select_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
220 |
|
221 |
for i in range(len(mask_dropdown)):
|
222 |
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
@@ -253,18 +277,18 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
253 |
template_mask[0][0]=1
|
254 |
operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")]
|
255 |
# return video_output, video_state, interactive_state, operation_error
|
256 |
-
masks, logits,
|
257 |
# clear GPU memory
|
258 |
model.xmem.clear_memory()
|
259 |
|
260 |
if interactive_state["track_end_number"]:
|
261 |
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
262 |
video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
|
263 |
-
video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] =
|
264 |
else:
|
265 |
video_state["masks"][video_state["select_frame_number"]:] = masks
|
266 |
video_state["logits"][video_state["select_frame_number"]:] = logits
|
267 |
-
video_state["painted_images"][video_state["select_frame_number"]:] =
|
268 |
|
269 |
video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
270 |
interactive_state["inference_times"] += 1
|
@@ -283,20 +307,16 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
283 |
for mask in video_state["masks"]:
|
284 |
np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
|
285 |
i+=1
|
286 |
-
# save_mask(video_state["masks"], video_state["video_name"])
|
287 |
#### shanggao code for mask save
|
288 |
return video_output, video_state, interactive_state, operation_log
|
289 |
|
290 |
-
|
291 |
-
# def extract_sole_mask(video_state, mask_dropdown):
|
292 |
-
# combined_masks =
|
293 |
-
# unique_masks = np.unique(combined_masks)
|
294 |
-
# return 0
|
295 |
|
296 |
# inpaint
|
297 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
298 |
operation_log = [("",""), ("Removed the selected masks.","Normal")]
|
299 |
|
|
|
300 |
frames = np.asarray(video_state["origin_images"])
|
301 |
fps = video_state["fps"]
|
302 |
inpaint_masks = np.asarray(video_state["masks"])
|
@@ -319,13 +339,39 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
|
|
319 |
except:
|
320 |
operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
|
321 |
inpainted_frames = video_state["origin_images"]
|
322 |
-
|
323 |
-
|
324 |
return video_output, operation_log
|
325 |
|
326 |
|
327 |
# generate video after vos inference
|
328 |
-
def generate_video_from_frames(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
329 |
"""
|
330 |
Generates a video from a list of frames.
|
331 |
|
@@ -375,8 +421,8 @@ folder ="./checkpoints"
|
|
375 |
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
376 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
377 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
378 |
-
# args.port =
|
379 |
-
# args.device = "cuda:
|
380 |
# args.mask_save = True
|
381 |
|
382 |
# initialize sam, xmem, e2fgvi models
|
@@ -409,6 +455,7 @@ with gr.Blocks() as iface:
|
|
409 |
|
410 |
video_state = gr.State(
|
411 |
{
|
|
|
412 |
"video_name": "",
|
413 |
"origin_images": None,
|
414 |
"painted_images": None,
|
@@ -458,7 +505,7 @@ with gr.Blocks() as iface:
|
|
458 |
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
|
459 |
|
460 |
with gr.Column():
|
461 |
-
run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=
|
462 |
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
|
463 |
video_output = gr.Video(autosize=True, visible=False).style(height=360)
|
464 |
with gr.Row():
|
@@ -471,9 +518,10 @@ with gr.Blocks() as iface:
|
|
471 |
inputs=[
|
472 |
video_input, video_state
|
473 |
],
|
474 |
-
outputs=[video_state, video_info, template_frame,
|
475 |
-
|
476 |
-
tracking_video_predict_button, video_output, mask_dropdown,
|
|
|
477 |
)
|
478 |
|
479 |
# second step: select images from slider
|
@@ -532,6 +580,8 @@ with gr.Blocks() as iface:
|
|
532 |
video_input.clear(
|
533 |
lambda: (
|
534 |
{
|
|
|
|
|
535 |
"origin_images": None,
|
536 |
"painted_images": None,
|
537 |
"masks": None,
|
|
|
8 |
sys.path.append(sys.path[0]+"/tracker")
|
9 |
sys.path.append(sys.path[0]+"/tracker/model")
|
10 |
from track_anything import TrackingAnything
|
11 |
+
from track_anything import parse_augment, save_image_to_userfolder, read_image_from_userfolder
|
12 |
import requests
|
13 |
import json
|
14 |
import torchvision
|
15 |
import torch
|
|
|
|
|
16 |
from tools.painter import mask_painter
|
17 |
import psutil
|
18 |
+
import time
|
19 |
try:
|
20 |
from mmcv.cnn import ConvModule
|
21 |
except:
|
|
|
70 |
return prompt
|
71 |
|
72 |
|
73 |
+
|
74 |
# extract frames from upload video
|
75 |
def get_frames_from_video(video_input, video_state):
|
76 |
"""
|
|
|
81 |
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
82 |
"""
|
83 |
video_path = video_input
|
84 |
+
frames = [] # save image path
|
85 |
+
user_name = time.time()
|
86 |
+
video_state["video_name"] = os.path.split(video_path)[-1]
|
87 |
+
video_state["user_name"] = user_name
|
88 |
|
89 |
+
os.makedirs(os.path.join("/tmp/{}/originimages/{}".format(video_state["user_name"], video_state["video_name"])), exist_ok=True)
|
90 |
+
os.makedirs(os.path.join("/tmp/{}/paintedimages/{}".format(video_state["user_name"], video_state["video_name"])), exist_ok=True)
|
91 |
operation_log = [("",""),("Upload video already. Try click the image for adding targets to track and inpaint.","Normal")]
|
92 |
try:
|
93 |
cap = cv2.VideoCapture(video_path)
|
94 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
95 |
+
if not cap.isOpened():
|
96 |
+
operation_log = [("No frames extracted, please input video file with '.mp4.' '.mov'.", "Error")]
|
97 |
+
print("No frames extracted, please input video file with '.mp4.' '.mov'.")
|
98 |
+
return None, None, None, None, \
|
99 |
+
None, None, None, None, \
|
100 |
+
None, None, None, None, \
|
101 |
+
None, None, gr.update(visible=True, value=operation_log)
|
102 |
+
image_index = 0
|
103 |
while cap.isOpened():
|
104 |
ret, frame = cap.read()
|
105 |
if ret == True:
|
106 |
current_memory_usage = psutil.virtual_memory().percent
|
107 |
+
|
108 |
+
# try solve memory usage problem, save image to disk instead of memory
|
109 |
+
frames.append(save_image_to_userfolder(video_state, image_index, frame, True))
|
110 |
+
image_index +=1
|
111 |
+
# frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
112 |
+
if current_memory_usage > 90:
|
113 |
+
operation_log = [("Memory usage is too high (>90%). Stop the video extraction. Please reduce the video resolution or frame rate.", "Error")]
|
114 |
+
print("Memory usage is too high (>90%). Please reduce the video resolution or frame rate.")
|
115 |
break
|
116 |
else:
|
117 |
break
|
118 |
+
|
119 |
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
|
120 |
+
# except:
|
121 |
+
operation_log = [("read_frame_source:{} error. {}\n".format(video_path, str(e)), "Error")]
|
122 |
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
|
123 |
+
return None, None, None, None, \
|
124 |
+
None, None, None, None, \
|
125 |
+
None, None, None, None, \
|
126 |
+
None, None, gr.update(visible=True, value=operation_log)
|
127 |
+
first_image = read_image_from_userfolder(frames[0])
|
128 |
+
image_size = (first_image.shape[0], first_image.shape[1])
|
129 |
# initialize video_state
|
130 |
video_state = {
|
131 |
+
"user_name": user_name,
|
132 |
"video_name": os.path.split(video_path)[-1],
|
133 |
"origin_images": frames,
|
134 |
"painted_images": frames.copy(),
|
135 |
+
"masks": [np.zeros((image_size[0], image_size[1]), np.uint8)]*len(frames),
|
136 |
"logits": [None]*len(frames),
|
137 |
"select_frame_number": 0,
|
138 |
"fps": fps
|
139 |
}
|
140 |
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
|
141 |
model.samcontroler.sam_controler.reset_image()
|
142 |
+
model.samcontroler.sam_controler.set_image(first_image)
|
143 |
+
return video_state, video_info, first_image, gr.update(visible=True, maximum=len(frames), value=1), \
|
144 |
+
gr.update(visible=True, maximum=len(frames), value=len(frames)), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
|
145 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), \
|
146 |
+
gr.update(visible=True), gr.update(visible=True), gr.update(visible=True, value=operation_log),
|
|
|
|
|
|
|
147 |
|
148 |
def run_example(example):
|
149 |
+
return example
|
150 |
# get the select frame from gradio slider
|
151 |
def select_template(image_selection_slider, video_state, interactive_state):
|
152 |
|
|
|
157 |
# once select a new template frame, set the image in sam
|
158 |
|
159 |
model.samcontroler.sam_controler.reset_image()
|
160 |
+
model.samcontroler.sam_controler.set_image(read_image_from_userfolder(video_state["origin_images"][image_selection_slider]))
|
161 |
|
162 |
# update the masks when select a new template frame
|
163 |
# if video_state["masks"][image_selection_slider] is not None:
|
164 |
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
|
165 |
operation_log = [("",""), ("Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider),"Normal")]
|
166 |
|
167 |
+
return read_image_from_userfolder(video_state["painted_images"][image_selection_slider]), video_state, interactive_state, operation_log
|
168 |
|
169 |
# set the tracking end frame
|
170 |
def get_end_number(track_pause_number_slider, video_state, interactive_state):
|
171 |
+
track_pause_number_slider -= 1
|
172 |
interactive_state["track_end_number"] = track_pause_number_slider
|
173 |
operation_log = [("",""),("Set the tracking finish at frame {}".format(track_pause_number_slider),"Normal")]
|
174 |
|
175 |
+
return read_image_from_userfolder(video_state["painted_images"][track_pause_number_slider]),interactive_state, operation_log
|
176 |
|
177 |
def get_resize_ratio(resize_ratio_slider, interactive_state):
|
178 |
interactive_state["resize_ratio"] = resize_ratio_slider
|
|
|
196 |
|
197 |
# prompt for sam model
|
198 |
model.samcontroler.sam_controler.reset_image()
|
199 |
+
model.samcontroler.sam_controler.set_image(read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]]))
|
200 |
prompt = get_prompt(click_state=click_state, click_input=coordinate)
|
201 |
|
202 |
mask, logit, painted_image = model.first_frame_click(
|
203 |
+
image=read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]]),
|
204 |
points=np.array(prompt["input_point"]),
|
205 |
labels=np.array(prompt["input_label"]),
|
206 |
multimask=prompt["multimask_output"],
|
207 |
)
|
208 |
video_state["masks"][video_state["select_frame_number"]] = mask
|
209 |
video_state["logits"][video_state["select_frame_number"]] = logit
|
210 |
+
video_state["painted_images"][video_state["select_frame_number"]] = save_image_to_userfolder(video_state, index=video_state["select_frame_number"], image=cv2.cvtColor(np.asarray(painted_image),cv2.COLOR_BGR2RGB),type=False)
|
211 |
|
212 |
operation_log = [("",""), ("Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment","Normal")]
|
213 |
return painted_image, video_state, interactive_state, operation_log
|
|
|
227 |
|
228 |
def clear_click(video_state, click_state):
|
229 |
click_state = [[],[]]
|
230 |
+
template_frame = read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]])
|
231 |
operation_log = [("",""), ("Clear points history and refresh the image.","Normal")]
|
232 |
return template_frame, click_state, operation_log
|
233 |
|
|
|
240 |
|
241 |
def show_mask(video_state, interactive_state, mask_dropdown):
|
242 |
mask_dropdown.sort()
|
243 |
+
select_frame = read_image_from_userfolder(video_state["origin_images"][video_state["select_frame_number"]])
|
244 |
|
245 |
for i in range(len(mask_dropdown)):
|
246 |
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
|
|
277 |
template_mask[0][0]=1
|
278 |
operation_log = [("Error! Please add at least one mask to track by clicking the left image.","Error"), ("","")]
|
279 |
# return video_output, video_state, interactive_state, operation_error
|
280 |
+
masks, logits, painted_images_path = model.generator(images=following_frames, template_mask=template_mask, video_state=video_state)
|
281 |
# clear GPU memory
|
282 |
model.xmem.clear_memory()
|
283 |
|
284 |
if interactive_state["track_end_number"]:
|
285 |
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
286 |
video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
|
287 |
+
video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images_path
|
288 |
else:
|
289 |
video_state["masks"][video_state["select_frame_number"]:] = masks
|
290 |
video_state["logits"][video_state["select_frame_number"]:] = logits
|
291 |
+
video_state["painted_images"][video_state["select_frame_number"]:] = painted_images_path
|
292 |
|
293 |
video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
294 |
interactive_state["inference_times"] += 1
|
|
|
307 |
for mask in video_state["masks"]:
|
308 |
np.save(os.path.join('./result/mask/{}'.format(video_state["video_name"].split('.')[0]), '{:05d}.npy'.format(i)), mask)
|
309 |
i+=1
|
|
|
310 |
#### shanggao code for mask save
|
311 |
return video_output, video_state, interactive_state, operation_log
|
312 |
|
313 |
+
|
|
|
|
|
|
|
|
|
314 |
|
315 |
# inpaint
|
316 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
317 |
operation_log = [("",""), ("Removed the selected masks.","Normal")]
|
318 |
|
319 |
+
# solve memory
|
320 |
frames = np.asarray(video_state["origin_images"])
|
321 |
fps = video_state["fps"]
|
322 |
inpaint_masks = np.asarray(video_state["masks"])
|
|
|
339 |
except:
|
340 |
operation_log = [("Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size.","Error"), ("","")]
|
341 |
inpainted_frames = video_state["origin_images"]
|
342 |
+
video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
343 |
+
video_output = generate_video_from_paintedframes(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps)
|
344 |
return video_output, operation_log
|
345 |
|
346 |
|
347 |
# generate video after vos inference
|
348 |
+
def generate_video_from_frames(frames_path, output_path, fps=30):
|
349 |
+
"""
|
350 |
+
Generates a video from a list of frames.
|
351 |
+
|
352 |
+
Args:
|
353 |
+
frames (list of numpy arrays): The frames to include in the video.
|
354 |
+
output_path (str): The path to save the generated video.
|
355 |
+
fps (int, optional): The frame rate of the output video. Defaults to 30.
|
356 |
+
"""
|
357 |
+
# height, width, layers = frames[0].shape
|
358 |
+
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
359 |
+
# video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
360 |
+
# print(output_path)
|
361 |
+
# for frame in frames:
|
362 |
+
# video.write(frame)
|
363 |
+
|
364 |
+
# video.release()
|
365 |
+
frames = []
|
366 |
+
for file in frames_path:
|
367 |
+
frames.append(read_image_from_userfolder(file))
|
368 |
+
frames = torch.from_numpy(np.asarray(frames))
|
369 |
+
if not os.path.exists(os.path.dirname(output_path)):
|
370 |
+
os.makedirs(os.path.dirname(output_path))
|
371 |
+
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
372 |
+
return output_path
|
373 |
+
|
374 |
+
def generate_video_from_paintedframes(frames, output_path, fps=30):
|
375 |
"""
|
376 |
Generates a video from a list of frames.
|
377 |
|
|
|
421 |
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
422 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
423 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
424 |
+
# args.port = 12213
|
425 |
+
# args.device = "cuda:1"
|
426 |
# args.mask_save = True
|
427 |
|
428 |
# initialize sam, xmem, e2fgvi models
|
|
|
455 |
|
456 |
video_state = gr.State(
|
457 |
{
|
458 |
+
"user_name": "",
|
459 |
"video_name": "",
|
460 |
"origin_images": None,
|
461 |
"painted_images": None,
|
|
|
505 |
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
|
506 |
|
507 |
with gr.Column():
|
508 |
+
run_status = gr.HighlightedText(value=[("Text","Error"),("to be","Label 2"),("highlighted","Label 3")], visible=True)
|
509 |
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
|
510 |
video_output = gr.Video(autosize=True, visible=False).style(height=360)
|
511 |
with gr.Row():
|
|
|
518 |
inputs=[
|
519 |
video_input, video_state
|
520 |
],
|
521 |
+
outputs=[video_state, video_info, template_frame, image_selection_slider,
|
522 |
+
track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button,
|
523 |
+
template_frame, tracking_video_predict_button, video_output, mask_dropdown,
|
524 |
+
remove_mask_button, inpaint_video_predict_button, run_status]
|
525 |
)
|
526 |
|
527 |
# second step: select images from slider
|
|
|
580 |
video_input.clear(
|
581 |
lambda: (
|
582 |
{
|
583 |
+
"user_name": "",
|
584 |
+
"video_name": "",
|
585 |
"origin_images": None,
|
586 |
"painted_images": None,
|
587 |
"masks": None,
|
inpainter/base_inpainter.py
CHANGED
@@ -1,17 +1,28 @@
|
|
1 |
import os
|
2 |
import glob
|
3 |
from PIL import Image
|
4 |
-
|
5 |
import torch
|
6 |
import yaml
|
7 |
import cv2
|
8 |
import importlib
|
9 |
import numpy as np
|
10 |
from tqdm import tqdm
|
11 |
-
|
12 |
from inpainter.util.tensor_util import resize_frames, resize_masks
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
class BaseInpainter:
|
16 |
def __init__(self, E2FGVI_checkpoint, device) -> None:
|
17 |
"""
|
@@ -46,7 +57,7 @@ class BaseInpainter:
|
|
46 |
ref_index.append(i)
|
47 |
return ref_index
|
48 |
|
49 |
-
def inpaint(self,
|
50 |
"""
|
51 |
frames: numpy array, T, H, W, 3
|
52 |
masks: numpy array, T, H, W
|
@@ -56,6 +67,11 @@ class BaseInpainter:
|
|
56 |
Output:
|
57 |
inpainted_frames: numpy array, T, H, W, 3
|
58 |
"""
|
|
|
|
|
|
|
|
|
|
|
59 |
assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
|
60 |
assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
|
61 |
masks = masks.copy()
|
|
|
1 |
import os
|
2 |
import glob
|
3 |
from PIL import Image
|
|
|
4 |
import torch
|
5 |
import yaml
|
6 |
import cv2
|
7 |
import importlib
|
8 |
import numpy as np
|
9 |
from tqdm import tqdm
|
|
|
10 |
from inpainter.util.tensor_util import resize_frames, resize_masks
|
11 |
|
12 |
+
def read_image_from_userfolder(image_path):
|
13 |
+
# if type:
|
14 |
+
image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
15 |
+
# else:
|
16 |
+
# image = cv2.cvtColor(cv2.imread("/tmp/{}/paintedimages/{}/{:08d}.png".format(username, video_state["video_name"], index+ ".png")), cv2.COLOR_BGR2RGB)
|
17 |
+
return image
|
18 |
+
|
19 |
+
def save_image_to_userfolder(video_state, index, image, type:bool):
|
20 |
+
if type:
|
21 |
+
image_path = "/tmp/{}/originimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
|
22 |
+
else:
|
23 |
+
image_path = "/tmp/{}/paintedimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
|
24 |
+
cv2.imwrite(image_path, image)
|
25 |
+
return image_path
|
26 |
class BaseInpainter:
|
27 |
def __init__(self, E2FGVI_checkpoint, device) -> None:
|
28 |
"""
|
|
|
57 |
ref_index.append(i)
|
58 |
return ref_index
|
59 |
|
60 |
+
def inpaint(self, frames_path, masks, dilate_radius=15, ratio=1):
|
61 |
"""
|
62 |
frames: numpy array, T, H, W, 3
|
63 |
masks: numpy array, T, H, W
|
|
|
67 |
Output:
|
68 |
inpainted_frames: numpy array, T, H, W, 3
|
69 |
"""
|
70 |
+
frames = []
|
71 |
+
for file in frames_path:
|
72 |
+
frames.append(read_image_from_userfolder(file))
|
73 |
+
frames = np.asarray(frames)
|
74 |
+
|
75 |
assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
|
76 |
assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'
|
77 |
masks = masks.copy()
|
track_anything.py
CHANGED
@@ -6,9 +6,22 @@ from tracker.base_tracker import BaseTracker
|
|
6 |
from inpainter.base_inpainter import BaseInpainter
|
7 |
import numpy as np
|
8 |
import argparse
|
|
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
class TrackingAnything():
|
13 |
def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
|
14 |
self.args = args
|
@@ -39,23 +52,25 @@ class TrackingAnything():
|
|
39 |
# mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
40 |
# return mask, logit, painted_image
|
41 |
|
42 |
-
def generator(self, images: list, template_mask:np.ndarray):
|
43 |
|
44 |
masks = []
|
45 |
logits = []
|
46 |
painted_images = []
|
47 |
for i in tqdm(range(len(images)), desc="Tracking image"):
|
48 |
if i ==0:
|
49 |
-
mask, logit, painted_image = self.xmem.track(images[i], template_mask)
|
50 |
masks.append(mask)
|
51 |
logits.append(logit)
|
52 |
-
painted_images.append(painted_image)
|
|
|
53 |
|
54 |
else:
|
55 |
-
mask, logit, painted_image = self.xmem.track(images[i])
|
56 |
masks.append(mask)
|
57 |
logits.append(logit)
|
58 |
-
painted_images.append(painted_image)
|
|
|
59 |
return masks, logits, painted_images
|
60 |
|
61 |
|
|
|
6 |
from inpainter.base_inpainter import BaseInpainter
|
7 |
import numpy as np
|
8 |
import argparse
|
9 |
+
import cv2
|
10 |
|
11 |
+
def read_image_from_userfolder(image_path):
|
12 |
+
# if type:
|
13 |
+
image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
14 |
+
# else:
|
15 |
+
# image = cv2.cvtColor(cv2.imread("/tmp/{}/paintedimages/{}/{:08d}.png".format(username, video_state["video_name"], index+ ".png")), cv2.COLOR_BGR2RGB)
|
16 |
+
return image
|
17 |
|
18 |
+
def save_image_to_userfolder(video_state, index, image, type:bool):
|
19 |
+
if type:
|
20 |
+
image_path = "/tmp/{}/originimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
|
21 |
+
else:
|
22 |
+
image_path = "/tmp/{}/paintedimages/{}/{:08d}.png".format(video_state["user_name"], video_state["video_name"], index)
|
23 |
+
cv2.imwrite(image_path, image)
|
24 |
+
return image_path
|
25 |
class TrackingAnything():
|
26 |
def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
|
27 |
self.args = args
|
|
|
52 |
# mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
53 |
# return mask, logit, painted_image
|
54 |
|
55 |
+
def generator(self, images: list, template_mask:np.ndarray, video_state:dict):
|
56 |
|
57 |
masks = []
|
58 |
logits = []
|
59 |
painted_images = []
|
60 |
for i in tqdm(range(len(images)), desc="Tracking image"):
|
61 |
if i ==0:
|
62 |
+
mask, logit, painted_image = self.xmem.track(read_image_from_userfolder(images[i]), template_mask)
|
63 |
masks.append(mask)
|
64 |
logits.append(logit)
|
65 |
+
# painted_images.append(painted_image)
|
66 |
+
painted_images.append(save_image_to_userfolder(video_state, index=i, image=cv2.cvtColor(np.asarray(painted_image),cv2.COLOR_BGR2RGB), type=False))
|
67 |
|
68 |
else:
|
69 |
+
mask, logit, painted_image = self.xmem.track(read_image_from_userfolder(images[i]))
|
70 |
masks.append(mask)
|
71 |
logits.append(logit)
|
72 |
+
# painted_images.append(painted_image)
|
73 |
+
painted_images.append(save_image_to_userfolder(video_state, index=i, image=cv2.cvtColor(np.asarray(painted_image),cv2.COLOR_BGR2RGB), type=False))
|
74 |
return masks, logits, painted_images
|
75 |
|
76 |
|
tracker/.DS_Store
CHANGED
Binary files a/tracker/.DS_Store and b/tracker/.DS_Store differ
|
|