Spaces:
Runtime error
Runtime error
watchtowerss
commited on
Commit
·
3c7c9f9
1
Parent(s):
6738b38
operation prompt version
Browse files- app.py +59 -36
- inpainter/base_inpainter.py +10 -10
app.py
CHANGED
@@ -103,7 +103,7 @@ def get_frames_from_video(video_input, video_state):
|
|
103 |
"fps": fps
|
104 |
}
|
105 |
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
|
106 |
-
|
107 |
model.samcontroler.sam_controler.reset_image()
|
108 |
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
|
109 |
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
|
@@ -111,7 +111,8 @@ def get_frames_from_video(video_input, video_state):
|
|
111 |
gr.update(visible=True), gr.update(visible=True), \
|
112 |
gr.update(visible=True), gr.update(visible=True), \
|
113 |
gr.update(visible=True), gr.update(visible=True), \
|
114 |
-
gr.update(visible=True), gr.update(visible=True)
|
|
|
115 |
|
116 |
def run_example(example):
|
117 |
return video_input
|
@@ -130,15 +131,16 @@ def select_template(image_selection_slider, video_state, interactive_state):
|
|
130 |
# update the masks when select a new template frame
|
131 |
# if video_state["masks"][image_selection_slider] is not None:
|
132 |
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
|
|
|
133 |
|
134 |
-
|
135 |
-
return video_state["painted_images"][image_selection_slider], video_state, interactive_state
|
136 |
|
137 |
# set the tracking end frame
|
138 |
def get_end_number(track_pause_number_slider, video_state, interactive_state):
|
139 |
interactive_state["track_end_number"] = track_pause_number_slider
|
|
|
140 |
|
141 |
-
return video_state["painted_images"][track_pause_number_slider],interactive_state
|
142 |
|
143 |
def get_resize_ratio(resize_ratio_slider, interactive_state):
|
144 |
interactive_state["resize_ratio"] = resize_ratio_slider
|
@@ -175,25 +177,31 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
|
|
175 |
video_state["logits"][video_state["select_frame_number"]] = logit
|
176 |
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
|
177 |
|
178 |
-
|
|
|
179 |
|
180 |
def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
181 |
mask = video_state["masks"][video_state["select_frame_number"]]
|
182 |
interactive_state["multi_mask"]["masks"].append(mask)
|
183 |
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
184 |
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
185 |
-
select_frame = show_mask(video_state, interactive_state, mask_dropdown)
|
186 |
-
|
|
|
|
|
187 |
|
188 |
def clear_click(video_state, click_state):
|
189 |
click_state = [[],[]]
|
190 |
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
191 |
-
|
|
|
192 |
|
193 |
-
def remove_multi_mask(interactive_state):
|
194 |
interactive_state["multi_mask"]["mask_names"]= []
|
195 |
interactive_state["multi_mask"]["masks"] = []
|
196 |
-
|
|
|
|
|
197 |
|
198 |
def show_mask(video_state, interactive_state, mask_dropdown):
|
199 |
mask_dropdown.sort()
|
@@ -203,12 +211,13 @@ def show_mask(video_state, interactive_state, mask_dropdown):
|
|
203 |
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
204 |
mask = interactive_state["multi_mask"]["masks"][mask_number]
|
205 |
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
|
206 |
-
|
207 |
-
|
|
|
208 |
|
209 |
# tracking vos
|
210 |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
211 |
-
|
212 |
model.xmem.clear_memory()
|
213 |
if interactive_state["track_end_number"]:
|
214 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
@@ -227,6 +236,12 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
227 |
else:
|
228 |
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
229 |
fps = video_state["fps"]
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
231 |
# clear GPU memory
|
232 |
model.xmem.clear_memory()
|
@@ -259,7 +274,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
259 |
i+=1
|
260 |
# save_mask(video_state["masks"], video_state["video_name"])
|
261 |
#### shanggao code for mask save
|
262 |
-
return video_output, video_state, interactive_state
|
263 |
|
264 |
# extracting masks from mask_dropdown
|
265 |
# def extract_sole_mask(video_state, mask_dropdown):
|
@@ -269,6 +284,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
|
269 |
|
270 |
# inpaint
|
271 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
|
|
272 |
|
273 |
frames = np.asarray(video_state["origin_images"])
|
274 |
fps = video_state["fps"]
|
@@ -286,10 +302,15 @@ def inpaint_video(video_state, interactive_state, mask_dropdown):
|
|
286 |
continue
|
287 |
inpaint_masks[inpaint_masks==i] = 0
|
288 |
# inpaint for videos
|
289 |
-
|
|
|
|
|
|
|
|
|
|
|
290 |
video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
291 |
|
292 |
-
return video_output
|
293 |
|
294 |
|
295 |
# generate video after vos inference
|
@@ -343,7 +364,7 @@ folder ="./checkpoints"
|
|
343 |
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
344 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
345 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
346 |
-
# args.port =
|
347 |
# args.device = "cuda:2"
|
348 |
# args.mask_save = True
|
349 |
|
@@ -396,9 +417,9 @@ with gr.Blocks() as iface:
|
|
396 |
with gr.Row(scale=0.4):
|
397 |
video_input = gr.Video(autosize=True)
|
398 |
with gr.Column():
|
399 |
-
video_info = gr.Textbox()
|
400 |
-
resize_info = gr.Textbox(value="
|
401 |
-
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
|
402 |
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
|
403 |
|
404 |
|
@@ -432,12 +453,13 @@ with gr.Blocks() as iface:
|
|
432 |
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
|
433 |
|
434 |
with gr.Column():
|
435 |
-
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="
|
436 |
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
|
437 |
video_output = gr.Video(autosize=True, visible=False).style(height=360)
|
438 |
with gr.Row():
|
439 |
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
|
440 |
inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
|
|
|
441 |
|
442 |
# first step: get the video information
|
443 |
extract_frames_button.click(
|
@@ -447,16 +469,16 @@ with gr.Blocks() as iface:
|
|
447 |
],
|
448 |
outputs=[video_state, video_info, template_frame,
|
449 |
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
|
450 |
-
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button]
|
451 |
)
|
452 |
|
453 |
# second step: select images from slider
|
454 |
image_selection_slider.release(fn=select_template,
|
455 |
inputs=[image_selection_slider, video_state, interactive_state],
|
456 |
-
outputs=[template_frame, video_state, interactive_state], api_name="select_image")
|
457 |
track_pause_number_slider.release(fn=get_end_number,
|
458 |
inputs=[track_pause_number_slider, video_state, interactive_state],
|
459 |
-
outputs=[template_frame, interactive_state], api_name="end_image")
|
460 |
resize_ratio_slider.release(fn=get_resize_ratio,
|
461 |
inputs=[resize_ratio_slider, interactive_state],
|
462 |
outputs=[interactive_state], api_name="resize_ratio")
|
@@ -465,41 +487,41 @@ with gr.Blocks() as iface:
|
|
465 |
template_frame.select(
|
466 |
fn=sam_refine,
|
467 |
inputs=[video_state, point_prompt, click_state, interactive_state],
|
468 |
-
outputs=[template_frame, video_state, interactive_state]
|
469 |
)
|
470 |
|
471 |
# add different mask
|
472 |
Add_mask_button.click(
|
473 |
fn=add_multi_mask,
|
474 |
inputs=[video_state, interactive_state, mask_dropdown],
|
475 |
-
outputs=[interactive_state, mask_dropdown, template_frame, click_state]
|
476 |
)
|
477 |
|
478 |
remove_mask_button.click(
|
479 |
fn=remove_multi_mask,
|
480 |
-
inputs=[interactive_state],
|
481 |
-
outputs=[interactive_state]
|
482 |
)
|
483 |
|
484 |
# tracking video from select image and mask
|
485 |
tracking_video_predict_button.click(
|
486 |
fn=vos_tracking_video,
|
487 |
inputs=[video_state, interactive_state, mask_dropdown],
|
488 |
-
outputs=[video_output, video_state, interactive_state]
|
489 |
)
|
490 |
|
491 |
# inpaint video from select image and mask
|
492 |
inpaint_video_predict_button.click(
|
493 |
fn=inpaint_video,
|
494 |
inputs=[video_state, interactive_state, mask_dropdown],
|
495 |
-
outputs=[video_output]
|
496 |
)
|
497 |
|
498 |
# click to get mask
|
499 |
mask_dropdown.change(
|
500 |
fn=show_mask,
|
501 |
inputs=[video_state, interactive_state, mask_dropdown],
|
502 |
-
outputs=[template_frame]
|
503 |
)
|
504 |
|
505 |
# clear input
|
@@ -531,7 +553,8 @@ with gr.Blocks() as iface:
|
|
531 |
None,
|
532 |
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
533 |
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
534 |
-
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False),
|
|
|
535 |
|
536 |
),
|
537 |
[],
|
@@ -542,7 +565,7 @@ with gr.Blocks() as iface:
|
|
542 |
video_output,
|
543 |
template_frame,
|
544 |
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
|
545 |
-
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button
|
546 |
],
|
547 |
queue=False,
|
548 |
show_progress=False)
|
@@ -551,7 +574,7 @@ with gr.Blocks() as iface:
|
|
551 |
clear_button_click.click(
|
552 |
fn = clear_click,
|
553 |
inputs = [video_state, click_state,],
|
554 |
-
outputs = [template_frame,click_state],
|
555 |
)
|
556 |
# set example
|
557 |
gr.Markdown("## Examples")
|
@@ -566,7 +589,7 @@ with gr.Blocks() as iface:
|
|
566 |
# cache_examples=True,
|
567 |
)
|
568 |
iface.queue(concurrency_count=1)
|
569 |
-
iface.launch(debug=True)
|
570 |
|
571 |
|
572 |
|
|
|
103 |
"fps": fps
|
104 |
}
|
105 |
video_info = "Video Name: {}, FPS: {}, Total Frames: {}, Image Size:{}".format(video_state["video_name"], video_state["fps"], len(frames), image_size)
|
106 |
+
operation_log = "Upload video already. Try click the image for adding targets to track and inpaint."
|
107 |
model.samcontroler.sam_controler.reset_image()
|
108 |
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
|
109 |
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
|
|
|
111 |
gr.update(visible=True), gr.update(visible=True), \
|
112 |
gr.update(visible=True), gr.update(visible=True), \
|
113 |
gr.update(visible=True), gr.update(visible=True), \
|
114 |
+
gr.update(visible=True), gr.update(visible=True), \
|
115 |
+
gr.update(visible=True, value=operation_log)
|
116 |
|
117 |
def run_example(example):
|
118 |
return video_input
|
|
|
131 |
# update the masks when select a new template frame
|
132 |
# if video_state["masks"][image_selection_slider] is not None:
|
133 |
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
|
134 |
+
operation_log = "Select frame {}. Try click image and add mask for tracking.".format(image_selection_slider)
|
135 |
|
136 |
+
return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log
|
|
|
137 |
|
138 |
# set the tracking end frame
|
139 |
def get_end_number(track_pause_number_slider, video_state, interactive_state):
|
140 |
interactive_state["track_end_number"] = track_pause_number_slider
|
141 |
+
operation_log = "Set the tracking finish at frame {}".format(track_pause_number_slider)
|
142 |
|
143 |
+
return video_state["painted_images"][track_pause_number_slider],interactive_state, operation_log
|
144 |
|
145 |
def get_resize_ratio(resize_ratio_slider, interactive_state):
|
146 |
interactive_state["resize_ratio"] = resize_ratio_slider
|
|
|
177 |
video_state["logits"][video_state["select_frame_number"]] = logit
|
178 |
video_state["painted_images"][video_state["select_frame_number"]] = painted_image
|
179 |
|
180 |
+
operation_log = "Use SAM for segment. You can try add positive and negative points by clicking. Or press Clear clicks button to refresh the image. Press Add mask button when you are satisfied with the segment"
|
181 |
+
return painted_image, video_state, interactive_state, operation_log
|
182 |
|
183 |
def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
184 |
mask = video_state["masks"][video_state["select_frame_number"]]
|
185 |
interactive_state["multi_mask"]["masks"].append(mask)
|
186 |
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
187 |
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
188 |
+
select_frame, run_status = show_mask(video_state, interactive_state, mask_dropdown)
|
189 |
+
|
190 |
+
operation_log = "Added a mask, use the mask select for target tracking or inpainting."
|
191 |
+
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log
|
192 |
|
193 |
def clear_click(video_state, click_state):
|
194 |
click_state = [[],[]]
|
195 |
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
196 |
+
operation_log = "Clear points history and refresh the image."
|
197 |
+
return template_frame, click_state, operation_log
|
198 |
|
199 |
+
def remove_multi_mask(interactive_state, mask_dropdown):
|
200 |
interactive_state["multi_mask"]["mask_names"]= []
|
201 |
interactive_state["multi_mask"]["masks"] = []
|
202 |
+
|
203 |
+
operation_log = "Remove all mask, please add new masks"
|
204 |
+
return interactive_state, gr.update(choices=[],value=[]), operation_log
|
205 |
|
206 |
def show_mask(video_state, interactive_state, mask_dropdown):
|
207 |
mask_dropdown.sort()
|
|
|
211 |
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
212 |
mask = interactive_state["multi_mask"]["masks"][mask_number]
|
213 |
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
|
214 |
+
|
215 |
+
operation_log = "Select {} for tracking or inpainting".format(mask_dropdown)
|
216 |
+
return select_frame, operation_log
|
217 |
|
218 |
# tracking vos
|
219 |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
220 |
+
operation_log = "Track the selected masks, and then you can select the masks for inpainting."
|
221 |
model.xmem.clear_memory()
|
222 |
if interactive_state["track_end_number"]:
|
223 |
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
|
|
236 |
else:
|
237 |
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
238 |
fps = video_state["fps"]
|
239 |
+
|
240 |
+
# operation error
|
241 |
+
if len(np.unique(template_mask))==1:
|
242 |
+
template_mask[0][0]=1
|
243 |
+
operation_log = "Error! Please add at least one mask to track by clicking the left image."
|
244 |
+
# return video_output, video_state, interactive_state, operation_error
|
245 |
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
246 |
# clear GPU memory
|
247 |
model.xmem.clear_memory()
|
|
|
274 |
i+=1
|
275 |
# save_mask(video_state["masks"], video_state["video_name"])
|
276 |
#### shanggao code for mask save
|
277 |
+
return video_output, video_state, interactive_state, operation_log
|
278 |
|
279 |
# extracting masks from mask_dropdown
|
280 |
# def extract_sole_mask(video_state, mask_dropdown):
|
|
|
284 |
|
285 |
# inpaint
|
286 |
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
287 |
+
operation_log = "Removed the selected masks."
|
288 |
|
289 |
frames = np.asarray(video_state["origin_images"])
|
290 |
fps = video_state["fps"]
|
|
|
302 |
continue
|
303 |
inpaint_masks[inpaint_masks==i] = 0
|
304 |
# inpaint for videos
|
305 |
+
|
306 |
+
try:
|
307 |
+
inpainted_frames = model.baseinpainter.inpaint(frames, inpaint_masks, ratio=interactive_state["resize_ratio"]) # numpy array, T, H, W, 3
|
308 |
+
except:
|
309 |
+
operation_log = "Error! You are trying to inpaint without masks input. Please track the selected mask first, and then press inpaint. If VRAM exceeded, please use the resize ratio to scaling down the image size."
|
310 |
+
inpainted_frames = video_state["origin_images"]
|
311 |
video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
312 |
|
313 |
+
return video_output, operation_log
|
314 |
|
315 |
|
316 |
# generate video after vos inference
|
|
|
364 |
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
365 |
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
366 |
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
367 |
+
# args.port = 12212
|
368 |
# args.device = "cuda:2"
|
369 |
# args.mask_save = True
|
370 |
|
|
|
417 |
with gr.Row(scale=0.4):
|
418 |
video_input = gr.Video(autosize=True)
|
419 |
with gr.Column():
|
420 |
+
video_info = gr.Textbox(label="Video Info")
|
421 |
+
resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to git clone the repo and use a machine with more VRAM locally. \
|
422 |
+
Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.", label="Tips for running this demo.")
|
423 |
resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
|
424 |
|
425 |
|
|
|
453 |
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
|
454 |
|
455 |
with gr.Column():
|
456 |
+
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)
|
457 |
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
|
458 |
video_output = gr.Video(autosize=True, visible=False).style(height=360)
|
459 |
with gr.Row():
|
460 |
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
|
461 |
inpaint_video_predict_button = gr.Button(value="Inpaint", visible=False)
|
462 |
+
run_status = gr.Textbox(label="Operation log", visible=False)
|
463 |
|
464 |
# first step: get the video information
|
465 |
extract_frames_button.click(
|
|
|
469 |
],
|
470 |
outputs=[video_state, video_info, template_frame,
|
471 |
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
|
472 |
+
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button, inpaint_video_predict_button, run_status]
|
473 |
)
|
474 |
|
475 |
# second step: select images from slider
|
476 |
image_selection_slider.release(fn=select_template,
|
477 |
inputs=[image_selection_slider, video_state, interactive_state],
|
478 |
+
outputs=[template_frame, video_state, interactive_state, run_status], api_name="select_image")
|
479 |
track_pause_number_slider.release(fn=get_end_number,
|
480 |
inputs=[track_pause_number_slider, video_state, interactive_state],
|
481 |
+
outputs=[template_frame, interactive_state, run_status], api_name="end_image")
|
482 |
resize_ratio_slider.release(fn=get_resize_ratio,
|
483 |
inputs=[resize_ratio_slider, interactive_state],
|
484 |
outputs=[interactive_state], api_name="resize_ratio")
|
|
|
487 |
template_frame.select(
|
488 |
fn=sam_refine,
|
489 |
inputs=[video_state, point_prompt, click_state, interactive_state],
|
490 |
+
outputs=[template_frame, video_state, interactive_state, run_status]
|
491 |
)
|
492 |
|
493 |
# add different mask
|
494 |
Add_mask_button.click(
|
495 |
fn=add_multi_mask,
|
496 |
inputs=[video_state, interactive_state, mask_dropdown],
|
497 |
+
outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status]
|
498 |
)
|
499 |
|
500 |
remove_mask_button.click(
|
501 |
fn=remove_multi_mask,
|
502 |
+
inputs=[interactive_state, mask_dropdown],
|
503 |
+
outputs=[interactive_state, mask_dropdown, run_status]
|
504 |
)
|
505 |
|
506 |
# tracking video from select image and mask
|
507 |
tracking_video_predict_button.click(
|
508 |
fn=vos_tracking_video,
|
509 |
inputs=[video_state, interactive_state, mask_dropdown],
|
510 |
+
outputs=[video_output, video_state, interactive_state, run_status]
|
511 |
)
|
512 |
|
513 |
# inpaint video from select image and mask
|
514 |
inpaint_video_predict_button.click(
|
515 |
fn=inpaint_video,
|
516 |
inputs=[video_state, interactive_state, mask_dropdown],
|
517 |
+
outputs=[video_output, run_status]
|
518 |
)
|
519 |
|
520 |
# click to get mask
|
521 |
mask_dropdown.change(
|
522 |
fn=show_mask,
|
523 |
inputs=[video_state, interactive_state, mask_dropdown],
|
524 |
+
outputs=[template_frame, run_status]
|
525 |
)
|
526 |
|
527 |
# clear input
|
|
|
553 |
None,
|
554 |
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
555 |
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
556 |
+
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False), \
|
557 |
+
gr.update(visible=False), gr.update(visible=False)
|
558 |
|
559 |
),
|
560 |
[],
|
|
|
565 |
video_output,
|
566 |
template_frame,
|
567 |
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
|
568 |
+
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button,inpaint_video_predict_button, run_status
|
569 |
],
|
570 |
queue=False,
|
571 |
show_progress=False)
|
|
|
574 |
clear_button_click.click(
|
575 |
fn = clear_click,
|
576 |
inputs = [video_state, click_state,],
|
577 |
+
outputs = [template_frame,click_state, run_status],
|
578 |
)
|
579 |
# set example
|
580 |
gr.Markdown("## Examples")
|
|
|
589 |
# cache_examples=True,
|
590 |
)
|
591 |
iface.queue(concurrency_count=1)
|
592 |
+
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
|
593 |
|
594 |
|
595 |
|
inpainter/base_inpainter.py
CHANGED
@@ -64,21 +64,21 @@ class BaseInpainter:
|
|
64 |
masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
|
65 |
|
66 |
T, H, W = masks.shape
|
|
|
67 |
# size: (w, h)
|
68 |
if ratio == 1:
|
69 |
size = None
|
|
|
70 |
else:
|
71 |
size = [int(W*ratio), int(H*ratio)]
|
72 |
-
if size
|
73 |
-
|
74 |
-
if size
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
frames = resize_frames(frames, tuple(size)) # T, H, W, 3
|
80 |
# frames and binary_masks are numpy arrays
|
81 |
-
|
82 |
h, w = frames.shape[1:3]
|
83 |
video_length = T
|
84 |
|
@@ -156,7 +156,7 @@ if __name__ == '__main__':
|
|
156 |
base_inpainter = BaseInpainter(checkpoint, device)
|
157 |
# 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
|
158 |
# ratio: (0, 1], ratio for down sample, default value is 1
|
159 |
-
inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=
|
160 |
# ----------------------------------------------
|
161 |
# end
|
162 |
# ----------------------------------------------
|
|
|
64 |
masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
|
65 |
|
66 |
T, H, W = masks.shape
|
67 |
+
masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
|
68 |
# size: (w, h)
|
69 |
if ratio == 1:
|
70 |
size = None
|
71 |
+
binary_masks = masks
|
72 |
else:
|
73 |
size = [int(W*ratio), int(H*ratio)]
|
74 |
+
size = [si+1 if si%2>0 else si for si in size] # only consider even values
|
75 |
+
# shortest side should be larger than 50
|
76 |
+
if min(size) < 50:
|
77 |
+
ratio = 50. / min(H, W)
|
78 |
+
size = [int(W*ratio), int(H*ratio)]
|
79 |
+
binary_masks = resize_masks(masks, tuple(size))
|
80 |
+
frames = resize_frames(frames, tuple(size)) # T, H, W, 3
|
|
|
81 |
# frames and binary_masks are numpy arrays
|
|
|
82 |
h, w = frames.shape[1:3]
|
83 |
video_length = T
|
84 |
|
|
|
156 |
base_inpainter = BaseInpainter(checkpoint, device)
|
157 |
# 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
|
158 |
# ratio: (0, 1], ratio for down sample, default value is 1
|
159 |
+
inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.01) # numpy array, T, H, W, 3
|
160 |
# ----------------------------------------------
|
161 |
# end
|
162 |
# ----------------------------------------------
|