Spaces:
Running
on
A10G
Running
on
A10G
import sys | |
sys.path.append("../../") | |
import os | |
import json | |
import time | |
import psutil | |
import argparse | |
import cv2 | |
import torch | |
import torchvision | |
import numpy as np | |
import gradio as gr | |
from tools.painter import mask_painter | |
from track_anything import TrackingAnything | |
from model.misc import get_device | |
from utils.download_util import load_file_from_url, download_url_to_file | |
# The sample clips are downloaded at start-up because mp4 files cannot be
# committed to git without LFS. They are stored in a "test_sample" directory
# next to this script, which is the same location the example list passed to
# gr.Examples is built from.
sample_videos_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_sample/")
# Make sure the target directory exists before anything is written into it.
os.makedirs(sample_videos_path, exist_ok=True)
_sample_video_urls = [
    "https://github-production-user-asset-6210df.s3.amazonaws.com/14334509/281805130-e57c7016-5a6d-4d3b-9df9-b4ea6372cc87.mp4",
    "https://github-production-user-asset-6210df.s3.amazonaws.com/14334509/281828039-5def0fc9-3a22-45b7-838d-6bf78b6772c3.mp4",
    "https://github-production-user-asset-6210df.s3.amazonaws.com/76810782/281807801-69b9f70c-1e56-428d-9b1b-4870c5e533a7.mp4",
    "https://github-production-user-asset-6210df.s3.amazonaws.com/76810782/281808625-ad98f03f-99c7-4008-acf1-3d7beb48f13b.mp4",
    "https://github-production-user-asset-6210df.s3.amazonaws.com/14334509/281828066-ee09ae82-916f-4a2e-a6c7-6fc50645fd20.mp4",
]
# The n-th URL is saved as "test-sample<n>.mp4".
for _sample_index, _sample_url in enumerate(_sample_video_urls):
    download_url_to_file(_sample_url, os.path.join(sample_videos_path, "test-sample{}.mp4".format(_sample_index)))
def parse_augment():
    """Parse the command-line options of the demo.

    Returns:
        argparse.Namespace with ``device``, ``sam_model_type``, ``port`` and
        ``mask_save``. When ``--device`` is not given (or is empty),
        ``device`` is set to ``str(get_device())``.
    """
    def _str2bool(value):
        """Convert a textual flag value ("true", "0", "no", ...) to a bool."""
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("0", "false", "f", "no", "n", "off", ""):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default=None)
    parser.add_argument('--sam_model_type', type=str, default="vit_h")
    parser.add_argument('--port', type=int, default=8000, help="only useful when running gradio applications")
    # Without a converter, "--mask_save False" would be stored as the truthy
    # string "False". The value is parsed as a boolean instead; a bare
    # "--mask_save" (no value) switches the option on.
    parser.add_argument('--mask_save', type=_str2bool, nargs='?', const=True, default=False)
    args = parser.parse_args()
    if not args.device:
        args.device = str(get_device())
    return args
# convert points input to prompt state | |
def get_prompt(click_state, click_input):
    """Append newly clicked points to ``click_state`` and build a click prompt.

    Args:
        click_state: two-element list ``[points, labels]``; both inner lists
            are extended in place with the new clicks.
        click_input: JSON text encoding a list of ``[x, y, label]`` entries.

    Returns:
        dict with the keys ``prompt_type``, ``input_point``, ``input_label``
        and ``multimask_output``. The point and label lists are the same
        objects that are stored in ``click_state``.
    """
    points = click_state[0]
    labels = click_state[1]
    # "click" instead of "input" so the builtin input() is not shadowed.
    for click in json.loads(click_input):
        points.append(click[:2])
        labels.append(click[2])
    click_state[0] = points
    click_state[1] = labels
    return {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        # Kept as the string "True" because that is what callers receive today.
        "multimask_output": "True",
    }
# extract frames from upload video | |
def get_frames_from_video(video_input, video_state):
    """Decode the uploaded video into RGB frames and build a fresh video state.

    Args:
        video_input: path of the uploaded video file.
        video_state: previous state; it is not read, a new dict replaces it.

    Returns:
        A 19-element tuple in the order of the outputs wired to the
        "Get video info" button: the new video state, the info text, the
        first frame, updates for the start/end frame sliders, eleven
        ``visible=True`` updates, a reset of the mask dropdown and two
        status-log updates.

    Raises:
        gr.Error: if no frame could be read from ``video_input``.
    """
    video_path = video_input
    frames = []
    user_name = time.time()
    operation_log = [("",""),("Video uploaded! Try to click the image shown in step2 to add masks.","Normal")]
    cap = None
    fps = 0.0
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV delivers BGR; the rest of the app works on RGB frames.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
        print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
    finally:
        # Always give the capture handle back, even when decoding failed.
        if cap is not None:
            cap.release()

    if not frames:
        # Report a readable message instead of failing later on frames[0].
        raise gr.Error("No frames could be read from the uploaded video.")

    image_size = (frames[0].shape[0], frames[0].shape[1])
    # initialize video_state
    video_state = {
        "user_name": user_name,
        "video_name": os.path.split(video_path)[-1],
        "origin_images": frames,
        "painted_images": frames.copy(),
        # One independent array per frame, so an in-place write to one
        # frame's mask cannot leak into the others.
        "masks": [np.zeros(image_size, np.uint8) for _ in frames],
        "logits": [None]*len(frames),
        "select_frame_number": 0,
        "fps": fps
    }
    video_info = "Video Name: {},\nFPS: {},\nTotal Frames: {},\nImage Size:{}".format(video_state["video_name"], round(video_state["fps"], 0), len(frames), image_size)
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
    return (
        video_state,
        video_info,
        video_state["origin_images"][0],
        gr.update(visible=True, maximum=len(frames), value=1),
        gr.update(visible=True, maximum=len(frames), value=len(frames)),
        # Reveal the eleven components that follow the sliders in the outputs list.
        *[gr.update(visible=True) for _ in range(11)],
        gr.update(visible=True, choices=[], value=[]),
        gr.update(visible=True, value=operation_log),
        gr.update(visible=True, value=operation_log),
    )
# get the select frame from gradio slider | |
def select_template(image_selection_slider, video_state, interactive_state, mask_dropdown=None):
    """Make the frame chosen on the start-frame slider the template frame.

    Args:
        image_selection_slider: 1-based frame number from the slider.
        video_state: video state dict; ``select_frame_number`` is updated.
        interactive_state: passed through unchanged.
        mask_dropdown: unused. It defaults to ``None`` because the slider
            event is wired with only three inputs.

    Returns:
        (painted frame at the selected index, video_state, interactive_state,
        status log, status log).
    """
    # Slider values start at 1, list indices at 0.
    image_selection_slider -= 1
    video_state["select_frame_number"] = image_selection_slider

    # once select a new template frame, set the image in sam
    model.samcontroler.sam_controler.reset_image()
    model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
    operation_log = [("",""), ("Select tracking start frame {}. Try to click the image to add masks for tracking.".format(image_selection_slider),"Normal")]
    return video_state["painted_images"][image_selection_slider], video_state, interactive_state, operation_log, operation_log
# set the tracking end frame | |
def get_end_number(track_pause_number_slider, video_state, interactive_state):
    """Record the tracking end frame and return that frame for preview.

    Args:
        track_pause_number_slider: 1-based number of the last frame to track.
        video_state: video state dict; only ``painted_images`` is read.
        interactive_state: dict; ``track_end_number`` is set to the slider value.

    Returns:
        (painted frame to preview, interactive_state, status log, status log).
    """
    # Stored unchanged: the tracker uses it as an exclusive slice end, which
    # for a 1-based inclusive frame number is the number itself.
    interactive_state["track_end_number"] = track_pause_number_slider
    painted_images = video_state["painted_images"]
    # The preview needs a 0-based index. Indexing with the raw slider value
    # overruns the list when the slider sits at its maximum (len(frames)).
    preview_index = min(max(track_pause_number_slider, 1), len(painted_images)) - 1
    operation_log = [("",""),("Select tracking finish frame {}.Try to click the image to add masks for tracking.".format(track_pause_number_slider),"Normal")]
    return painted_images[preview_index], interactive_state, operation_log, operation_log
# use sam to get the mask | |
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
    """Run a click-prompted segmentation on the currently selected frame.

    Args:
        video_state: video state dict; the mask, logit and painted image of
            the selected frame are overwritten with the new result.
        point_prompt: "Positive" for a positive click, anything else counts
            as a negative click.
        click_state: [[points], [labels]] click history, extended by get_prompt.
        interactive_state: dict holding the positive/negative click counters.
        evt: select event; ``evt.index`` supplies the clicked coordinates.

    Returns:
        (painted image, video_state, interactive_state, status log, status log).
    """
    click_x, click_y = evt.index[0], evt.index[1]
    if point_prompt == "Positive":
        interactive_state["positive_click_times"] += 1
        coordinate = "[[{},{},1]]".format(click_x, click_y)
    else:
        interactive_state["negative_click_times"] += 1
        coordinate = "[[{},{},0]]".format(click_x, click_y)

    # Load the selected frame into the SAM controller before prompting it.
    sam = model.samcontroler.sam_controler
    sam.reset_image()
    sam.set_image(video_state["origin_images"][video_state["select_frame_number"]])

    prompt = get_prompt(click_state=click_state, click_input=coordinate)
    frame_index = video_state["select_frame_number"]
    mask, logit, painted_image = model.first_frame_click(
        image=video_state["origin_images"][frame_index],
        points=np.array(prompt["input_point"]),
        labels=np.array(prompt["input_label"]),
        multimask=prompt["multimask_output"],
    )

    # Store the result for the selected frame.
    video_state["masks"][frame_index] = mask
    video_state["logits"][frame_index] = logit
    video_state["painted_images"][frame_index] = painted_image

    status = [("",""), ("You can try to add positive or negative points by clicking, click Clear clicks button to refresh the image, click Add mask button when you are satisfied with the segment, or click Remove mask button to remove all added masks.","Normal")]
    return painted_image, video_state, interactive_state, status, status
def add_multi_mask(video_state, interactive_state, mask_dropdown):
    """Register the current frame's mask as a new named mask and select it.

    Args:
        video_state: video state dict; the mask of the selected frame is read.
        interactive_state: dict; the mask and its generated name
            ("mask_001", "mask_002", ...) are appended under ``multi_mask``.
        mask_dropdown: list of selected mask names; the new name is appended.

    Returns:
        (interactive_state, dropdown update, frame to display, cleared click
        state ``[[], []]``, status log, status log). If adding the mask fails,
        the frame slot carries a no-op ``gr.update()`` and the status log
        holds an error entry.
    """
    # Fallback for the failure path; without it the return statement would
    # reference an unbound name.
    select_frame = gr.update()
    try:
        mask = video_state["masks"][video_state["select_frame_number"]]
        multi_mask = interactive_state["multi_mask"]
        multi_mask["masks"].append(mask)
        # Names are 1-based and derived from the number of stored masks.
        mask_name = "mask_{:03d}".format(len(multi_mask["masks"]))
        multi_mask["mask_names"].append(mask_name)
        mask_dropdown.append(mask_name)
        select_frame, _, _ = show_mask(video_state, interactive_state, mask_dropdown)
        operation_log = [("",""),("Added a mask, use the mask select for target tracking or inpainting.","Normal")]
    except Exception as err:
        # UI boundary: report the problem to the user instead of crashing
        # the callback, but keep a trace of the cause on the console.
        print("add_multi_mask failed: {}".format(err))
        operation_log = [("Please click the image in step2 to generate masks.", "Error"), ("","")]
    return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]], operation_log, operation_log
def clear_click(video_state, click_state):
    """Discard the click history and return the unpainted selected frame.

    Args:
        video_state: video state dict; ``origin_images`` and
            ``select_frame_number`` are read.
        click_state: previous click history; it is replaced, not modified.

    Returns:
        (original selected frame, empty click state, status log, status log).
    """
    status = [("",""), ("Cleared points history and refresh the image.","Normal")]
    original_frame = video_state["origin_images"][video_state["select_frame_number"]]
    return original_frame, [[], []], status, status
def remove_multi_mask(interactive_state, mask_dropdown):
    """Drop every stored mask and empty the mask dropdown.

    Args:
        interactive_state: dict; the ``mask_names`` and ``masks`` lists under
            ``multi_mask`` are replaced by empty lists.
        mask_dropdown: current dropdown selection (not used).

    Returns:
        (interactive_state, dropdown update with no choices and no value,
        status log, status log).
    """
    for key in ("mask_names", "masks"):
        interactive_state["multi_mask"][key] = []
    status = [("",""), ("Remove all masks. Try to add new masks","Normal")]
    cleared_dropdown = gr.update(choices=[],value=[])
    return interactive_state, cleared_dropdown, status, status
def show_mask(video_state, interactive_state, mask_dropdown):
    """Paint the masks selected in the dropdown onto the selected frame.

    Args:
        video_state: video state dict; the original selected frame is read.
        interactive_state: dict holding the stored masks under ``multi_mask``.
        mask_dropdown: list of names such as "mask_001"; sorted in place.

    Returns:
        (painted frame, status log, status log).
    """
    mask_dropdown.sort()
    canvas = video_state["origin_images"][video_state["select_frame_number"]]
    for mask_name in mask_dropdown:
        # "mask_003" -> index 2 in the stored mask list.
        mask_index = int(mask_name.split("_")[1]) - 1
        stored_mask = interactive_state["multi_mask"]["masks"][mask_index]
        canvas = mask_painter(canvas, stored_mask.astype('uint8'), mask_color=mask_index+2)

    message = "Added masks {}. If you want to do the inpainting with current masks, please go to step3, and click the Tracking button first and then Inpainting button.".format(mask_dropdown)
    status = [("",""), (message,"Normal")]
    return canvas, status, status
# tracking vos | |
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
    """Propagate the template mask through the video and render the result.

    Args:
        video_state: video state dict; ``masks``, ``logits`` and
            ``painted_images`` are overwritten for the tracked frame range.
        interactive_state: dict with the stored masks, the optional
            ``track_end_number``, the click counters and the ``mask_save`` flag.
        mask_dropdown: names of the masks to track ("mask_001", ...). If masks
            were added but none is selected, "mask_001" is used.

    Returns:
        (path of the rendered tracking video, video_state, interactive_state,
        status log, status log).
    """
    operation_log = [("",""), ("Tracking finished! Try to click the Inpainting button to get the inpainting result.","Normal")]
    model.cutie.clear_memory()

    start = video_state["select_frame_number"]
    # A falsy end number (None or 0) means "until the last frame"; a slice
    # end of None expresses exactly that, so one code path serves both cases.
    end = interactive_state["track_end_number"] or None
    following_frames = video_state["origin_images"][start:end]

    added_masks = interactive_state["multi_mask"]["masks"]
    if added_masks:
        if len(mask_dropdown) == 0:
            mask_dropdown = ["mask_001"]
        mask_dropdown.sort()
        # Merge the selected masks into one label map: mask_00k contributes
        # the label k; the sum is clipped to the label that was just added.
        first_label = int(mask_dropdown[0].split("_")[1])
        template_mask = added_masks[first_label - 1] * first_label
        for mask_name in mask_dropdown[1:]:
            label = int(mask_name.split("_")[1])
            template_mask = np.clip(template_mask + added_masks[label - 1] * label, 0, label)
        video_state["masks"][start] = template_mask
    else:
        template_mask = video_state["masks"][start]
    fps = video_state["fps"]

    # operation error
    if len(np.unique(template_mask)) == 1:
        # The mask holds a single value, i.e. nothing was selected. One pixel
        # is seeded so tracking still runs, but on a copy: the array stored
        # in video_state (possibly shared with other frames) stays untouched.
        template_mask = template_mask.copy()
        template_mask[0][0] = 1
        operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
    masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
    # clear GPU memory
    model.cutie.clear_memory()

    video_state["masks"][start:end] = masks
    video_state["logits"][start:end] = logits
    video_state["painted_images"][start:end] = painted_images

    video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
    interactive_state["inference_times"] += 1

    print("For generating this tracking result, inference times: {}, click times: {}, positive: {}, negative: {}".format(interactive_state["inference_times"],
                                                                                                                                    interactive_state["positive_click_times"]+interactive_state["negative_click_times"],
                                                                                                                                    interactive_state["positive_click_times"],
                                                                                                                                    interactive_state["negative_click_times"]))

    # Optionally dump every frame's mask as <index>.npy under ./result/mask/<video stem>/.
    if interactive_state["mask_save"]:
        mask_dir = './result/mask/{}'.format(video_state["video_name"].split('.')[0])
        os.makedirs(mask_dir, exist_ok=True)
        print("save mask")
        for i, mask in enumerate(video_state["masks"]):
            np.save(os.path.join(mask_dir, '{:05d}.npy'.format(i)), mask)
    return video_output, video_state, interactive_state, operation_log, operation_log
# inpaint | |
def inpaint_video(video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown):
    """Inpaint the regions covered by the selected masks and render a video.

    Args:
        video_state: video state dict; ``origin_images``, ``masks``, ``fps``
            and ``video_name`` are read.
        resize_ratio_number, dilate_radius_number, raft_iter_number,
        subvideo_length_number, neighbor_length_number, ref_stride_number:
            slider values forwarded to ``model.baseinpainter.inpaint``.
        mask_dropdown: names of the masks to inpaint ("mask_001", ...). When
            empty, "mask_001" is used.

    Returns:
        (path of the inpainted video, status log, status log).
    """
    operation_log = [("",""), ("Inpainting finished!","Normal")]

    frames = np.asarray(video_state["origin_images"])
    fps = video_state["fps"]
    inpaint_masks = np.asarray(video_state["masks"])
    if len(mask_dropdown) == 0:
        mask_dropdown = ["mask_001"]
    mask_dropdown.sort()
    # convert mask_dropdown to mask numbers
    inpaint_mask_numbers = [int(mask_name.split("_")[1]) for mask_name in mask_dropdown]
    # Clear every label that was not selected. Testing membership directly
    # (rather than looping over 1..number_of_labels) also removes labels that
    # are not contiguous, e.g. when only mask_002 was tracked.
    inpaint_masks[~np.isin(inpaint_masks, inpaint_mask_numbers)] = 0

    # inpaint for videos
    inpainted_frames = model.baseinpainter.inpaint(frames,
                                                   inpaint_masks,
                                                   ratio=resize_ratio_number,
                                                   dilate_radius=dilate_radius_number,
                                                   raft_iter=raft_iter_number,
                                                   subvideo_length=subvideo_length_number,
                                                   neighbor_length=neighbor_length_number,
                                                   ref_stride=ref_stride_number)   # numpy array, T, H, W, 3

    video_output = generate_video_from_frames(inpainted_frames, output_path="./result/inpaint/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
    return video_output, operation_log, operation_log
# generate video after vos inference | |
def generate_video_from_frames(frames, output_path, fps=30):
    """
    Generates a video from a list of frames.

    Args:
        frames (list of numpy arrays): The frames to include in the video.
        output_path (str): The path to save the generated video.
        fps (int, optional): The frame rate of the output video. Defaults to 30.

    Returns:
        str: ``output_path``, unchanged.
    """
    frames = torch.from_numpy(np.asarray(frames))
    output_dir = os.path.dirname(output_path)
    # A bare file name has no directory part; os.makedirs("") would raise.
    if output_dir:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(output_dir, exist_ok=True)
    torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
    return output_path
def restart():
    """Return the values that reset the whole UI to its initial state.

    Returns:
        A 24-element tuple: a fresh video state, a fresh interactive state,
        an empty click state, three ``None`` values, fifteen
        ``visible=False`` updates, an empty string, and two status-log
        updates (the first visible, the second hidden).
    """
    status = [("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")]

    fresh_video_state = {
        "user_name": "",
        "video_name": "",
        "origin_images": None,
        "painted_images": None,
        "masks": None,
        "inpaint_masks": None,
        "logits": None,
        "select_frame_number": 0,
        "fps": 30
    }
    fresh_interactive_state = {
        "inference_times": 0,
        "negative_click_times" : 0,
        "positive_click_times": 0,
        "mask_save": args.mask_save,
        "multi_mask": {
            "mask_names": [],
            "masks": []
        },
        "track_end_number": None,
    }
    # One hide-update per component listed between the three cleared
    # media outputs and the info text box.
    hidden = [gr.update(visible=False) for _ in range(15)]

    return (
        fresh_video_state,
        fresh_interactive_state,
        [[],[]],
        None, None, None,
        *hidden,
        "",
        gr.update(visible=True, value=status),
        gr.update(visible=False, value=status),
    )
# args, defined in track_anything.py | |
# Parse the command-line options once; they are also passed to TrackingAnything.
args = parse_augment()

# Base URL of the release assets; it ends with "/" so file names can be appended.
pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
sam_checkpoint_url_dict = {
    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}
# Local directory the checkpoints are loaded into (name kept as-is,
# including its spelling, since it is a module-level name).
checkpoint_fodler = os.path.join('..', '..', 'weights')

# URLs are assembled with plain string concatenation: os.path.join uses the
# platform's path separator, which is not a valid way to build a URL.
sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[args.sam_model_type], checkpoint_fodler)
cutie_checkpoint = load_file_from_url(pretrain_model_url + 'cutie-base-mega.pth', checkpoint_fodler)
propainter_checkpoint = load_file_from_url(pretrain_model_url + 'ProPainter.pth', checkpoint_fodler)
raft_checkpoint = load_file_from_url(pretrain_model_url + 'raft-things.pth', checkpoint_fodler)
flow_completion_checkpoint = load_file_from_url(pretrain_model_url + 'recurrent_flow_completion.pth', checkpoint_fodler)

# initialize sam, cutie, propainter models
model = TrackingAnything(sam_checkpoint, cutie_checkpoint, propainter_checkpoint, raft_checkpoint, flow_completion_checkpoint, args)
# Static page content: heading, introduction, footer article and custom CSS.
title = r"""<h1 align="center">ProPainter: Improving Propagation and Transformer for Video Inpainting</h1>"""

description = r"""
<center><img src='https://github.com/sczhou/ProPainter/raw/main/assets/propainter_logo1_glow.png' alt='Propainter logo' style="width:180px; margin-bottom:20px"></center>
<b>Official Gradio demo</b> for <a href='https://github.com/sczhou/ProPainter' target='_blank'><b>Improving Propagation and Transformer for Video Inpainting (ICCV 2023)</b></a>.<br>
🔥 Propainter is a robust inpainting algorithm.<br>
🤗 Try to drop your video, add the masks and get the inpainting results!<br>
"""

article = r"""
If ProPainter is helpful, please help to ⭐ the <a href='https://github.com/sczhou/ProPainter' target='_blank'>Github Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/sczhou/ProPainter?style=social)](https://github.com/sczhou/ProPainter)
---
📑 **Citation**
<br>
If our work is useful for your research, please consider citing:
```bibtex
@inproceedings{zhou2023propainter,
title={{ProPainter}: Improving Propagation and Transformer for Video Inpainting},
author={Zhou, Shangchen and Li, Chongyi and Chan, Kelvin C.K and Loy, Chen Change},
booktitle={Proceedings of IEEE International Conference on Computer Vision (ICCV)},
year={2023}
}
```
📝 **License**
<br>
This project is licensed under <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">S-Lab License 1.0</a>.
Redistribution and use for non-commercial purposes should follow this license.
📧 **Contact**
<br>
If you have any questions, please feel free to reach me out at <b>shangchenzhou@gmail.com</b>.
<div>
🤗 Find Me:
<a href="https://twitter.com/ShangchenZhou"><img style="margin-top:0.5em; margin-bottom:0.5em" src="https://img.shields.io/twitter/follow/ShangchenZhou?label=%40ShangchenZhou&style=social" alt="Twitter Follow"></a>
<a href="https://github.com/sczhou"><img style="margin-top:0.5em; margin-bottom:2em" src="https://img.shields.io/github/followers/sczhou?style=social" alt="Github Follow"></a>
</div>
"""

# Style overrides handed to gr.Blocks; the class names match the
# elem_classes values used by the components in the layout.
css = """
.gradio-container {width: 85% !important}
.gr-monochrome-group {border-radius: 5px !important; border: revert-layer !important; border-width: 2px !important; color: black !important;}
span.svelte-s1r2yt {font-size: 17px !important; font-weight: bold !important; color: #d30f2f !important;}
button {border-radius: 8px !important;}
.add_button {background-color: #4CAF50 !important;}
.remove_button {background-color: #f44336 !important;}
.mask_button_group {gap: 10px !important;}
.video {height: 300px !important;}
.image {height: 300px !important;}
.video .wrap.svelte-lcpz3o {display: flex !important; align-items: center !important; justify-content: center !important;}
.video .wrap.svelte-lcpz3o > :first-child {height: 100% !important;}
.margin_center {width: 50% !important; margin: auto !important;}
.jc_center {justify-content: center !important;}
"""
# Build the demo page: session state, parameter sliders, the three workflow
# steps (upload, add masks, track + inpaint) and the event wiring.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a flattened copy of this file - confirm it against the intended layout.
with gr.Blocks(theme=gr.themes.Monochrome(), css=css) as iface:
    # Per-session state: click history, interaction bookkeeping, video data.
    click_state = gr.State([[],[]])

    interactive_state = gr.State({
        "inference_times": 0,
        "negative_click_times" : 0,
        "positive_click_times": 0,
        "mask_save": args.mask_save,
        "multi_mask": {
            "mask_names": [],
            "masks": []
        },
        "track_end_number": None,
        }
    )

    video_state = gr.State(
        {
        "user_name": "",
        "video_name": "",
        "origin_images": None,
        "painted_images": None,
        "masks": None,
        "inpaint_masks": None,
        "logits": None,
        "select_frame_number": 0,
        "fps": 30
        }
    )

    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Group(elem_classes="gr-monochrome-group"):
        with gr.Row():
            with gr.Accordion('ProPainter Parameters (click to expand)', open=False):
                with gr.Row():
                    resize_ratio_number = gr.Slider(label='Resize ratio',
                                                    minimum=0.01,
                                                    maximum=1.0,
                                                    step=0.01,
                                                    value=1.0)
                    raft_iter_number = gr.Slider(label='Iterations for RAFT inference.',
                                                 minimum=5,
                                                 maximum=20,
                                                 step=1,
                                                 value=20,)
                with gr.Row():
                    dilate_radius_number = gr.Slider(label='Mask dilation for video and flow masking.',
                                                     minimum=0,
                                                     maximum=10,
                                                     step=1,
                                                     value=8,)
                    subvideo_length_number = gr.Slider(label='Length of sub-video for long video inference.',
                                                       minimum=40,
                                                       maximum=200,
                                                       step=1,
                                                       value=80,)
                with gr.Row():
                    neighbor_length_number = gr.Slider(label='Length of local neighboring frames.',
                                                       minimum=5,
                                                       maximum=20,
                                                       step=1,
                                                       value=10,)
                    ref_stride_number = gr.Slider(label='Stride of global reference frames.',
                                                  minimum=5,
                                                  maximum=20,
                                                  step=1,
                                                  value=10,)

    with gr.Column():
        # input video
        gr.Markdown("## Step1: Upload video")
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                video_input = gr.Video(elem_classes="video")
                extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
            with gr.Column(scale=2):
                # Status text names the button exactly as it is labelled above.
                run_status = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")])
                video_info = gr.Textbox(label="Video Info")

        # add masks
        step2_title = gr.Markdown("---\n## Step2: Add masks", visible=False)
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image")
                image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track start frame", visible=False)
                track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frame", visible=False)
            with gr.Column(scale=2, elem_classes="jc_center"):
                run_status2 = gr.HighlightedText(value=[("",""), ("Try to upload your video and click the Get video info button to get started!", "Normal")], visible=False)
                with gr.Row():
                    with gr.Column(scale=2, elem_classes="mask_button_group"):
                        clear_button_click = gr.Button(value="Clear clicks", interactive=True, visible=False)
                        remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False, elem_classes="remove_button")
                        Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False, elem_classes="add_button")
                    point_prompt = gr.Radio(
                        choices=["Positive", "Negative"],
                        value="Positive",
                        label="Point prompt",
                        interactive=True,
                        visible=False,
                        min_width=100,
                        scale=1)
                mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask selection", info=".", visible=False)

        # output video
        step3_title = gr.Markdown("---\n## Step3: Track masks and get the inpainting result", visible=False)
        with gr.Row(equal_height=True):
            with gr.Column(scale=2):
                tracking_video_output = gr.Video(visible=False, elem_classes="video")
                tracking_video_predict_button = gr.Button(value="1. Tracking", visible=False, elem_classes="margin_center")
            with gr.Column(scale=2):
                inpaiting_video_output = gr.Video(visible=False, elem_classes="video")
                inpaint_video_predict_button = gr.Button(value="2. Inpainting", visible=False, elem_classes="margin_center")

    # first step: get the video information
    extract_frames_button.click(
        fn=get_frames_from_video,
        inputs=[
            video_input, video_state
        ],
        outputs=[video_state, video_info, template_frame,
                 image_selection_slider, track_pause_number_slider,point_prompt, clear_button_click, Add_mask_button, template_frame,
                 tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button, inpaint_video_predict_button, step2_title, step3_title,mask_dropdown, run_status, run_status2]
    )

    # second step: select images from slider
    image_selection_slider.release(fn=select_template,
                                   inputs=[image_selection_slider, video_state, interactive_state],
                                   outputs=[template_frame, video_state, interactive_state, run_status, run_status2], api_name="select_image")
    track_pause_number_slider.release(fn=get_end_number,
                                   inputs=[track_pause_number_slider, video_state, interactive_state],
                                   outputs=[template_frame, interactive_state, run_status, run_status2], api_name="end_image")

    # click select image to get mask using sam
    template_frame.select(
        fn=sam_refine,
        inputs=[video_state, point_prompt, click_state, interactive_state],
        outputs=[template_frame, video_state, interactive_state, run_status, run_status2]
    )

    # add different mask
    Add_mask_button.click(
        fn=add_multi_mask,
        inputs=[video_state, interactive_state, mask_dropdown],
        outputs=[interactive_state, mask_dropdown, template_frame, click_state, run_status, run_status2]
    )

    remove_mask_button.click(
        fn=remove_multi_mask,
        inputs=[interactive_state, mask_dropdown],
        outputs=[interactive_state, mask_dropdown, run_status, run_status2]
    )

    # tracking video from select image and mask
    tracking_video_predict_button.click(
        fn=vos_tracking_video,
        inputs=[video_state, interactive_state, mask_dropdown],
        outputs=[tracking_video_output, video_state, interactive_state, run_status, run_status2]
    )

    # inpaint video from select image and mask
    inpaint_video_predict_button.click(
        fn=inpaint_video,
        inputs=[video_state, resize_ratio_number, dilate_radius_number, raft_iter_number, subvideo_length_number, neighbor_length_number, ref_stride_number, mask_dropdown],
        outputs=[inpaiting_video_output, run_status, run_status2]
    )

    # click to get mask
    mask_dropdown.change(
        fn=show_mask,
        inputs=[video_state, interactive_state, mask_dropdown],
        outputs=[template_frame, run_status, run_status2]
    )

    # clear input
    # Changing and clearing the video both reset the UI through restart();
    # the output list is defined once so the two handlers cannot drift apart.
    # Its order must match the tuple returned by restart().
    restart_outputs = [
        video_state,
        interactive_state,
        click_state,
        tracking_video_output, inpaiting_video_output,
        template_frame,
        tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, clear_button_click,
        Add_mask_button, template_frame, tracking_video_predict_button, tracking_video_output, inpaiting_video_output, remove_mask_button,inpaint_video_predict_button, step2_title, step3_title, mask_dropdown, video_info, run_status, run_status2
    ]
    video_input.change(
        fn=restart,
        inputs=[],
        outputs=restart_outputs,
        queue=False,
        show_progress=False)

    video_input.clear(
        fn=restart,
        inputs=[],
        outputs=restart_outputs,
        queue=False,
        show_progress=False)

    # points clear
    clear_button_click.click(
        fn = clear_click,
        inputs = [video_state, click_state,],
        outputs = [template_frame,click_state, run_status, run_status2],
    )

    # set example
    gr.Markdown("## Examples")
    gr.Examples(
        examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample0.mp4", "test-sample1.mp4", "test-sample2.mp4", "test-sample3.mp4", "test-sample4.mp4"]],
        inputs=[video_input],
    )
    gr.Markdown(article)

# Handle one request at a time, then start the server.
iface.queue(concurrency_count=1)
iface.launch(debug=True)