|
import os |
|
import sys |
|
import uuid |
|
|
|
import gradio as gr |
|
import mediapy |
|
import numpy as np |
|
import cv2 |
|
import matplotlib |
|
import torch |
|
|
|
from locotrack_pytorch.models.locotrack_model import load_model |
|
from viz_utils import paint_point_track |
|
|
|
|
|
# Width (pixels) of the interactive preview video; preprocess_video_input derives
# the preview height from the uploaded video's aspect ratio.
PREVIEW_WIDTH = 768
# Resolution the model input video is resized to with mediapy.resize_video, in the
# same (height, width) order used for the preview resize; track() maps query and
# track coordinates between the preview and this resolution.
VIDEO_INPUT_RESO = (256, 256)
# Radius (preview pixels) of the filled circles drawn for query points.
POINT_SIZE = 4
# Uploaded videos longer than this many frames are truncated to the first FRAME_LIMIT frames.
FRAME_LIMIT = 300
|
|
|
|
|
def get_point(frame_num, video_queried_preview, query_points, query_points_color, query_count, evt: gr.SelectData):
    """Register a query point where the user clicked and draw it on the preview.

    Args:
        frame_num: Frame index currently selected with the frame slider.
        video_queried_preview: Per-frame preview images with query dots drawn on
            them; the entry for the selected frame is replaced with the newly
            annotated frame.
        query_points: Per-frame lists of ``(x, y, frame_num)`` tuples in preview
            pixel coordinates; the new point is appended to the current frame's list.
        query_points_color: Per-frame lists of RGB tuples, kept parallel to
            ``query_points``.
        query_count: Number of query points registered so far (drives the color).
        evt: Gradio selection event; ``evt.index`` holds the clicked ``(x, y)``.

    Returns:
        Tuple ``(annotated_frame, video_queried_preview, query_points,
        query_points_color, query_count + 1)``.
    """
    print(f"You selected {(evt.index[0], evt.index[1], frame_num)}")

    frame_idx = int(frame_num)
    query_points[frame_idx].append((evt.index[0], evt.index[1], frame_num))

    # Cycle through 20 evenly spaced positions of the colormap so that
    # consecutive points receive visibly different colors.
    rgba = matplotlib.colormaps.get_cmap("gist_rainbow")(query_count % 20 / 20)
    rgb = tuple(int(channel * 255) for channel in rgba[:3])
    print(f"Color: {rgb}")
    query_points_color[frame_idx].append(rgb)

    click_x, click_y = evt.index
    # Thickness -1 asks OpenCV for a filled circle.
    annotated = cv2.circle(video_queried_preview[frame_idx], (click_x, click_y), POINT_SIZE, rgb, -1)
    video_queried_preview[frame_idx] = annotated

    return annotated, video_queried_preview, query_points, query_points_color, query_count + 1
|
|
|
|
|
def undo_point(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count):
    """Remove the most recently added query point on the current frame.

    The frame is rebuilt from a copy of the clean ``video_preview`` frame with
    the remaining points redrawn, so the removed dot disappears. When the frame
    has no points, the state is returned untouched.

    Returns:
        Tuple ``(frame_to_display, video_queried_preview, query_points,
        query_points_color, query_count)`` where ``query_count`` is decremented
        only when a point was actually removed.
    """
    idx = int(frame_num)
    frame_points = query_points[idx]

    if not frame_points:
        return video_queried_preview[idx], video_queried_preview, query_points, query_points_color, query_count

    frame_colors = query_points_color[idx]
    frame_points.pop()
    frame_colors.pop()

    # Start from an untouched copy of the preview frame and redraw the survivors.
    redrawn = video_preview[idx].copy()
    for (px, py, _), dot_color in zip(frame_points, frame_colors):
        redrawn = cv2.circle(redrawn, (px, py), POINT_SIZE, dot_color, -1)

    video_queried_preview[idx] = redrawn
    return redrawn, video_queried_preview, query_points, query_points_color, query_count - 1
|
|
|
|
|
def clear_frame_fn(frame_num, video_preview, video_queried_preview, query_points, query_points_color, query_count): |
|
query_count -= len(query_points[int(frame_num)]) |
|
|
|
query_points[int(frame_num)] = [] |
|
query_points_color[int(frame_num)] = [] |
|
|
|
video_queried_preview[int(frame_num)] = video_preview[int(frame_num)].copy() |
|
|
|
return ( |
|
video_preview[int(frame_num)], |
|
video_queried_preview, |
|
query_points, |
|
query_points_color, |
|
query_count |
|
) |
|
|
|
|
|
|
|
def clear_all_fn(frame_num, video_preview):
    """Discard every query point on every frame.

    Returns:
        Tuple ``(clean_current_frame, queried_preview, points, colors, 0)``:
        the selected frame from ``video_preview``, a fresh copy of
        ``video_preview`` to serve as the queried preview, and one new empty
        list per frame for both the points and their colors.
    """
    frame_total = len(video_preview)
    empty_points = [[] for _ in range(frame_total)]
    empty_colors = [[] for _ in range(frame_total)]
    return video_preview[int(frame_num)], video_preview.copy(), empty_points, empty_colors, 0
|
|
|
|
|
def choose_frame(frame_num, video_preview_array):
    """Return the frame of ``video_preview_array`` selected by the frame slider.

    ``frame_num`` is converted with ``int`` so float slider values work as indices.
    """
    frame_index = int(frame_num)
    selected = video_preview_array[frame_index]
    return selected
|
|
|
|
|
def extract_feature(video_input, model_size="small"):
    """Compute LocoTrack feature grids for a video at the model input resolution.

    Args:
        video_input: Video array with pixel values in [0, 255]; presumably shaped
            (frames, height, width, channels) as produced by mediapy — TODO confirm.
        model_size: Model variant passed to ``load_model`` (the UI offers
            "small" and "base").

    Returns:
        The result of ``model.get_feature_grids`` for the batched, normalized video.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 is an autocast dtype supported on both CUDA and CPU. Many PyTorch
    # releases reject float16 for CPU autocast (they warn and disable autocast),
    # which left a float16 input running against float32 weights and failing.
    dtype = torch.bfloat16

    model = load_model(model_size=model_size).to(device)

    # Map pixel values from [0, 255] to [-1, 1] and add a leading batch dimension.
    normalized = (video_input / 255.0) * 2 - 1
    video_tensor = torch.tensor(normalized).unsqueeze(0).to(device, dtype)

    with torch.autocast(device_type=device, dtype=dtype):
        with torch.no_grad():
            feature = model.get_feature_grids(video_tensor)

    return feature
|
|
|
|
|
def preprocess_video_input(video_path, model_size):
    """Load an uploaded video and build all UI state derived from it.

    The video is read with mediapy and truncated to FRAME_LIMIT frames. Two
    resized copies are built: a PREVIEW_WIDTH-wide preview for display and
    drawing, and a VIDEO_INPUT_RESO copy for the model. Feature grids for the
    model copy are then computed.

    Args:
        video_path: Path of the uploaded video file.
        model_size: Model variant used for feature extraction.

    Returns:
        An 18-tuple in the order of the upload event's outputs:
        ``(video_arr, preview, queried_preview_copy, input_video, feature,
        fps, accordion_update, model_size, first_preview_frame, slider_update,
        query_points, query_points_color, is_tracked_query, query_count,
        undo_update, clear_frame_update, clear_all_update, track_update)``.

    Raises:
        gr.Error: If the video contains no frames.
    """
    video_arr = mediapy.read_video(video_path)
    video_fps = video_arr.metadata.fps
    num_frames = video_arr.shape[0]

    if num_frames == 0:
        # Without this check the failure would be an opaque IndexError on preview_video[0].
        raise gr.Error("The uploaded video does not contain any frames.")

    if num_frames > FRAME_LIMIT:
        gr.Warning(f"The video is too long. Only the first {FRAME_LIMIT} frames will be used.", duration=5)
        video_arr = video_arr[:FRAME_LIMIT]
        num_frames = FRAME_LIMIT

    height, width = video_arr.shape[1:3]
    new_width = PREVIEW_WIDTH
    # Round the preview height down to an even number (at least 2). The tracked
    # preview is later written with mediapy.write_video, whose default H.264 /
    # yuv420p encoding requires even frame dimensions; e.g. a 1080x1920 portrait
    # video would otherwise yield an odd height of 1365 and fail to encode.
    new_height = max(2, int(PREVIEW_WIDTH * height / width) // 2 * 2)

    preview_video = np.array(mediapy.resize_video(video_arr, (new_height, new_width)))
    input_video = np.array(mediapy.resize_video(video_arr, VIDEO_INPUT_RESO))

    video_feature = extract_feature(input_video, model_size)

    return (
        video_arr,
        preview_video,
        preview_video.copy(),
        input_video,
        video_feature,
        video_fps,
        gr.update(open=False),
        model_size,
        preview_video[0],
        gr.update(minimum=0, maximum=num_frames - 1, value=0, interactive=True),
        [[] for _ in range(num_frames)],
        [[] for _ in range(num_frames)],
        [[] for _ in range(num_frames)],
        0,
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
|
|
|
|
|
def track(
    model_size,
    video_preview,
    video_input,
    video_feature,
    video_fps,
    query_points,
    query_points_color,
    query_count,
):
    """Track all query points through the video and render the result to an mp4.

    Args:
        model_size: Model variant passed to ``load_model``.
        video_preview: Preview-resolution video used for drawing and output.
        video_input: Model-resolution video with pixel values in [0, 255]
            (presumably uint8 — TODO confirm).
        video_feature: Feature grids computed by ``extract_feature`` at upload time.
        video_fps: Frame rate used when writing the output video.
        query_points: Per-frame lists of ``(x, y, frame)`` tuples in preview pixels.
        query_points_color: Per-frame lists of RGB tuples parallel to ``query_points``.
        query_count: Number of registered query points.

    Returns:
        Path of the written mp4 file, or ``None`` (after a UI warning) when there
        are no query points.
    """
    # Flatten frame-by-frame; colors below are flattened in the same order.
    flat_points = [point for frame_points in query_points for point in frame_points]
    if query_count == 0 or not flat_points:
        gr.Warning("Please add query points before tracking.", duration=5)
        return None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 is an autocast dtype supported on both CUDA and CPU; float16 is
    # rejected by CPU autocast in many PyTorch releases.
    dtype = torch.bfloat16

    preview_height, preview_width = video_preview.shape[1:3]
    input_height, input_width = VIDEO_INPUT_RESO

    # Rescale (x, y) from preview pixels to model-input pixels (frame unchanged),
    # then reorder each query from (x, y, frame) to (frame, y, x).
    queries = torch.tensor(flat_points).float()
    queries *= torch.tensor([input_width, input_height, 1]) / torch.tensor([[preview_width, preview_height, 1]])
    # Keep queries in float32: bfloat16 has only 8 significant bits, so frame
    # indices above 256 (FRAME_LIMIT allows up to 299) and sub-pixel coordinates
    # would be rounded — e.g. frame 257 becomes 256.
    queries = queries[None].flip(-1).to(device, torch.float32)

    # Map pixel values from [0, 255] to [-1, 1] and add a leading batch dimension.
    video_tensor = torch.tensor((video_input / 255.0) * 2 - 1).unsqueeze(0).to(device, dtype)

    model = load_model(model_size=model_size).to(device)
    with torch.autocast(device_type=device, dtype=dtype):
        with torch.no_grad():
            output = model(video_tensor, queries, feature_grids=video_feature)

    # Cast to float32 before rescaling: an autocast-dtype (bfloat16) result
    # would both quantize preview coordinates and fail in .numpy(), since NumPy
    # has no bfloat16 dtype.
    tracks = output['tracks'][0].float().cpu()
    tracks = tracks * torch.tensor([preview_width, preview_height]) / torch.tensor([input_width, input_height])
    tracks = tracks.numpy()

    pred_occ = torch.sigmoid(output['occlusion'])
    if 'expected_dist' in output:
        # A point counts as occluded if either the occlusion probability or the
        # sigmoid of the expected-distance logit indicates it.
        pred_occ = 1 - (1 - pred_occ) * (1 - torch.sigmoid(output['expected_dist']))
    pred_occ = (pred_occ > 0.5)[0].cpu().numpy()

    colors = np.array([color for frame_colors in query_points_color for color in frame_colors])

    painted_video = paint_point_track(
        video_preview,
        tracks,
        ~pred_occ,
        colors,
    )

    video_path = os.path.join(os.path.dirname(__file__), "tmp")
    os.makedirs(video_path, exist_ok=True)
    video_file_path = os.path.join(video_path, uuid.uuid4().hex + ".mp4")

    mediapy.write_video(video_file_path, painted_video, fps=video_fps)

    return video_file_path
|
|
|
|
|
with gr.Blocks() as demo:
    # Session state shared by the event handlers wired up below.
    video = gr.State()
    video_queried_preview = gr.State()
    video_preview = gr.State()
    video_input = gr.State()
    video_feautre = gr.State()
    video_fps = gr.State(24)
    model_size = gr.State("small")

    query_points = gr.State([])
    query_points_color = gr.State([])
    is_tracked_query = gr.State([])
    query_count = gr.State(0)

    gr.Markdown("# LocoTrack Demo")
    gr.Markdown("This is an interactive demo for LocoTrack. For more details, please refer to the [GitHub repository](https://github.com/KU-CVLAB/LocoTrack) or the [paper](https://arxiv.org/abs/2407.15420).")

    gr.Markdown("## First step: Choose the model size and upload your video")
    with gr.Row():
        with gr.Accordion("Your video input", open=True) as video_in_drawer:
            model_size_selection = gr.Radio(label="Model Size", choices=["small", "base"], value="small")
            video_in = gr.Video(label="Video Input", format="mp4")

    gr.Markdown("## Second step: Add query points to track")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                query_frames = gr.Slider(
                    minimum=0, maximum=100, value=0, step=1, label="Choose Frame", interactive=False
                )
            with gr.Row():
                undo = gr.Button("Undo", interactive=False)
                clear_frame = gr.Button("Clear Frame", interactive=False)
                clear_all = gr.Button("Clear All", interactive=False)
            with gr.Row():
                current_frame = gr.Image(label="Click to add query points", type="numpy", interactive=False)
            with gr.Row():
                track_button = gr.Button("Track", interactive=False)

        with gr.Column():
            output_video = gr.Video(label="Output Video", interactive=False, autoplay=True, loop=True)

    # Every handler that edits query points refreshes these same five components.
    point_edit_outputs = [current_frame, video_queried_preview, query_points, query_points_color, query_count]
    # Undo and "Clear Frame" read the same state.
    point_edit_inputs = [
        query_frames, video_preview, video_queried_preview, query_points, query_points_color, query_count,
    ]

    upload_outputs = [
        video, video_preview, video_queried_preview, video_input, video_feautre, video_fps,
        video_in_drawer, model_size, current_frame, query_frames,
        query_points, query_points_color, is_tracked_query, query_count,
        undo, clear_frame, clear_all, track_button,
    ]
    video_in.upload(
        fn=preprocess_video_input,
        inputs=[video_in, model_size_selection],
        outputs=upload_outputs,
        queue=False,
    )

    query_frames.change(
        fn=choose_frame,
        inputs=[query_frames, video_queried_preview],
        outputs=[current_frame],
        queue=False,
    )

    current_frame.select(
        fn=get_point,
        inputs=[query_frames, video_queried_preview, query_points, query_points_color, query_count],
        outputs=point_edit_outputs,
        queue=False,
    )

    # Registered in the same order as before: Undo first, then Clear Frame.
    for edit_button, edit_handler in ((undo, undo_point), (clear_frame, clear_frame_fn)):
        edit_button.click(
            fn=edit_handler,
            inputs=point_edit_inputs,
            outputs=point_edit_outputs,
            queue=False,
        )

    clear_all.click(
        fn=clear_all_fn,
        inputs=[query_frames, video_preview],
        outputs=point_edit_outputs,
        queue=False,
    )

    track_button.click(
        fn=track,
        inputs=[
            model_size, video_preview, video_input, video_feautre, video_fps,
            query_points, query_points_color, query_count,
        ],
        outputs=[output_video],
        queue=True,
    )

demo.launch(show_api=False, show_error=True, debug=True)