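# Gradio demo for Sat2Density: synthesizes ground-view panoramas from a
# satellite image and a style (sky) reference, either at a single viewpoint
# or as a video along a user-drawn trajectory.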
import importlib
import os
import subprocess

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from scipy.interpolate import interp1d, splev, splprep

import options
import test
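
# One-time setup: install the Imaginaire library the model code depends on.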
subprocess.run(["bash", "install_imaginaire.sh"])
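
# Render a single ground-view image at the viewpoint selected by the sliders.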
def get_single(sat_img, style_img, x_offset, y_offset):
    # Identify which bundled demo case the selected style image comes from,
    # so that its matching sky mask can be loaded below.
    name = ''
    for case in [d for d in os.listdir('demo_img') if 'case' in d]:
        style = Image.open('demo_img/{}/groundview.image.png'.format(case)).convert('RGB')
        style = np.array(style)
        if (style == style_img).all():
            name = case
            break
    input_dict = {}
    trans = transforms.ToTensor()
    input_dict['sat'] = trans(sat_img)
    input_dict['pano'] = trans(style_img)
    input_dict['paths'] = "demo.png"
    # The sky mask isolates sky pixels; the per-channel histogram of the masked
    # panorama (dropping the 10 lowest-intensity bins, channels concatenated in
    # reverse order) serves as the style code.
    sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
    input_a = input_dict['pano'] * sky
    sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
    input_dict['sky_histc'] = sky_histc
    input_dict['sky_mask'] = sky
    # Add a batch dimension to every tensor input.
    for key in input_dict.keys():
        if isinstance(input_dict[key], torch.Tensor):
            input_dict[key] = input_dict[key].unsqueeze(0)
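    # Build test-time options exactly as the command-line test script would;
    # the paths point at the bundled demo checkpoint and images.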
    args = ["--yaml=sat2density_cvact",
            "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth",
            "--task=test_vid",
            "--demo_img=demo_img/case1/satview-input.png",
            "--sty_img=demo_img/case1/groundview.image.png",
            "--save_dir=output"]
    opt_cmd = options.parse_arguments(args=args)
    opt = options.set(opt_cmd=opt_cmd)
    opt.isTrain = False
    opt.name = opt.yaml if opt.name is None else opt.name
    opt.batch_size = 1
    # Instantiate the model, load the pretrained generator weights, and feed it
    # the prepared inputs.
    m = importlib.import_module("model.{}".format(opt.model))
    model = m.Model(opt)
    # m.load_dataset(opt)
    model.build_networks(opt)
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()
    model.set_input(input_dict)
    model.style_temp = model.sky_histc
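    # Map the normalized slider position (0-1) to the model's [-1, 1] viewpoint
    # coordinates; the sign flip on y appears to convert from image coordinates,
    # where the row index grows downward.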
    opt.origin_H_W = [-(y_offset*256 - 128)/128, (x_offset*256 - 128)/128]  # TODO: hard-coded 256x256 mapping should be removed in the future
    model.forward(opt)
    rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
    rgb = np.array(rgb*255, dtype=np.uint8)
    return rgb
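
# Render ground views along a user-drawn trajectory and assemble them into a
# video. Model setup mirrors get_single above.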
def get_video(sat_img, style_img, positions):
    # Locate the matching demo case (same lookup as in get_single).
    name = ''
    for case in [d for d in os.listdir('demo_img') if 'case' in d]:
        style = Image.open('demo_img/{}/groundview.image.png'.format(case)).convert('RGB')
        style = np.array(style)
        if (style == style_img).all():
            name = case
            break
    input_dict = {}
    trans = transforms.ToTensor()
    input_dict['sat'] = trans(sat_img)
    input_dict['pano'] = trans(style_img)
    input_dict['paths'] = "demo.png"
    sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
    input_a = input_dict['pano'] * sky
    sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
    input_dict['sky_histc'] = sky_histc
    input_dict['sky_mask'] = sky
    for key in input_dict.keys():
        if isinstance(input_dict[key], torch.Tensor):
            input_dict[key] = input_dict[key].unsqueeze(0)
    args = ["--yaml=sat2density_cvact",
            "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth",
            "--task=test_vid",
            "--demo_img=demo_img/case1/satview-input.png",
            "--sty_img=demo_img/case1/groundview.image.png",
            "--save_dir=output"]
    opt_cmd = options.parse_arguments(args=args)
    opt = options.set(opt_cmd=opt_cmd)
    opt.isTrain = False
    opt.name = opt.yaml if opt.name is None else opt.name
    opt.batch_size = 1
    m = importlib.import_module("model.{}".format(opt.model))
    model = m.Model(opt)
    # m.load_dataset(opt)
    model.build_networks(opt)
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()
    model.set_input(input_dict)
    model.style_temp = model.sky_histc
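    # Deduplicate clicked points (preserving order), then fit a smoothing
    # B-spline through them and resample 80 evenly spaced camera positions.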
    pixels = np.array(list(dict.fromkeys(positions)))
    # With the default cubic spline, splprep needs at least four distinct points.
    tck, u = splprep(pixels.T, s=25, per=0)
    u_new = np.linspace(u.min(), u.max(), 80)
    x_new, y_new = splev(u_new, tck)
    smooth_path = np.array([x_new, y_new]).T
    rendered_image_list = []
    rendered_depth_list = []
    for x, y in smooth_path:
        # Map pixel coordinates on the 256x256 map to the model's [-1, 1] range.
        opt.origin_H_W = [(y - 128)/128, (x - 128)/128]  # TODO: hard-coded 256x256 mapping should be removed in the future
        print('Rendering at ({}, {})'.format(x, y))
        model.forward(opt)
        rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
        rgb = np.array(rgb*255, dtype=np.uint8)
        rendered_image_list.append(rgb)
        rendered_depth_list.append(model.out_put.depth[0, 0].cpu().detach().numpy())  # collected but unused here
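    # Encode the rendered frames into an MP4 file with OpenCV.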
    output_video_path = 'output_video.mp4'
    frame_rate = 15
    frame_width = 512
    frame_height = 128
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))
    for image_np in rendered_image_list:
        # VideoWriter expects BGR channel order, so convert from RGB.
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        out.write(image_np)
    out.release()
    return output_video_path
def copy_image(image):
    return image
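
# Overlay a circular marker at the slider-selected point, returned in the
# (image, annotations) format expected by gr.AnnotatedImage.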
def show_image_and_point(image, x, y):
    # Sliders are normalized to [0, 1]; flip y because image rows grow downward.
    x = int(x*image.shape[1])
    y = image.shape[0] - int(y*image.shape[0])
    mask = np.zeros(image.shape[:2])
    radius = min(image.shape[0], image.shape[1])//60
    for i in range(x - radius - 2, x + radius + 2):
        for j in range(y - radius - 2, y + radius + 2):
            if (i - x)**2 + (j - y)**2 <= radius**2:
                mask[j, i] = 1
    return (image, [(mask, 'render point')])
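
# Record a clicked point, mark it on the canvas, and append it to the
# trajectory stored in the session state.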
def add_select_point(image, evt: gr.SelectData, state1):
    if state1 is None:
        state1 = []
    x, y = evt.index
    state1.append((x, y))
    print(state1)
    # Draw a filled black disc at the clicked location.
    radius = min(image.shape[0], image.shape[1])//60
    for i in range(x - radius - 2, x + radius + 2):
        for j in range(y - radius - 2, y + radius + 2):
            if (i - x)**2 + (j - y)**2 <= radius**2:
                image[j, i, :] = 0
    return image, state1
def reset_select_points(image):
    # Restore the clean satellite image and clear the stored trajectory.
    return image, []
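
# --- Gradio UI ---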
with gr.Blocks() as demo:
    gr.Markdown("# Sat2Density Demos")
    gr.Markdown("### Select or upload a satellite image, then select a style image")
    with gr.Row():
        with gr.Column():
            sat_img = gr.Image(source='upload', shape=[256, 256], interactive=True)
            img_examples = gr.Examples(examples=['demo_img/{}/satview-input.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
                                       inputs=sat_img, outputs=None, examples_per_page=20)
        with gr.Column():
            style_img = gr.Image()
            style_examples = gr.Examples(examples=['demo_img/{}/groundview.image.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
                                         inputs=style_img, outputs=None, examples_per_page=20)
gr.Markdown("### select a certain point to generate single groundview image") | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Row(): | |
with gr.Column(): | |
slider_x = gr.Slider(0.2, 0.8, 0.5, label="x-axis position") | |
slider_y = gr.Slider(0.2, 0.8, 0.5, label="y-axis position") | |
btn_single = gr.Button(label="demo1") | |
annotation_image = gr.AnnotatedImage() | |
out_single = gr.Image() | |
gr.Markdown("### draw a trajectory on the map to generate video") | |
state_select_points = gr.State() | |
with gr.Row(): | |
with gr.Column(): | |
draw_img = gr.Image(shape=[256, 256], interactive=True) | |
with gr.Column(): | |
out_video = gr.Video() | |
reset_btn =gr.Button(value="Reset") | |
btn_video = gr.Button(label="demo1") | |
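
    # Wire UI events to the callbacks defined above.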
    sat_img.change(copy_image, inputs=sat_img, outputs=draw_img)
    draw_img.select(add_select_point, [draw_img, state_select_points], [draw_img, state_select_points])
    sat_img.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image)
    slider_x.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image, show_progress='hidden')
    slider_y.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image, show_progress='hidden')
    btn_single.click(get_single, inputs=[sat_img, style_img, slider_x, slider_y], outputs=out_single)
    reset_btn.click(reset_select_points, [sat_img], [draw_img, state_select_points])
    btn_video.click(get_video, inputs=[sat_img, style_img, state_select_points], outputs=out_video)  # trigger video rendering

demo.launch()