"""Gradio demo for Sat2Density: given a satellite image and a ground-view style
reference, render a single ground-view panorama at a chosen point, or a video
along a user-drawn trajectory."""

import importlib
import os
import subprocess

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from scipy.interpolate import splev, splprep

import options
import test

# Install the imaginaire dependency at startup.
subprocess.run(["bash", "install_imaginaire.sh"])


def build_model(sat_img, style_img):
    """Build the generator and its input dict for one demo request."""
    # Identify which demo case the chosen style image comes from, so the
    # matching sky mask can be loaded from disk.
    name = ''
    for case in [n for n in os.listdir('demo_img') if 'case' in n]:
        style = Image.open('demo_img/{}/groundview.image.png'.format(case)).convert('RGB')
        if (np.array(style) == style_img).all():
            name = case
            break

    trans = transforms.ToTensor()
    input_dict = {}
    input_dict['sat'] = trans(sat_img)
    input_dict['pano'] = trans(style_img)
    input_dict['paths'] = "demo.png"
    sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))

    # Style conditioning: per-channel histogram (in BGR order) of the sky
    # region of the style panorama, with the 10 darkest bins dropped so the
    # masked-out pixels do not dominate.
    input_a = input_dict['pano'] * sky
    input_dict['sky_histc'] = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
    input_dict['sky_mask'] = sky

    for key in input_dict:
        if isinstance(input_dict[key], torch.Tensor):
            input_dict[key] = input_dict[key].unsqueeze(0)  # add a batch dimension

    args = [
        "--yaml=sat2density_cvact",
        "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth",
        "--task=test_vid",
        "--demo_img=demo_img/case1/satview-input.png",
        "--sty_img=demo_img/case1/groundview.image.png",
        "--save_dir=output",
    ]
    opt_cmd = options.parse_arguments(args=args)
    opt = options.set(opt_cmd=opt_cmd)
    opt.isTrain = False
    opt.name = opt.yaml if opt.name is None else opt.name
    opt.batch_size = 1

    m = importlib.import_module("model.{}".format(opt.model))
    model = m.Model(opt)
    # m.load_dataset(opt)
    model.build_networks(opt)
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()
    model.set_input(input_dict)
    model.style_temp = model.sky_histc
    return model, opt


def get_single(sat_img, style_img, x_offset, y_offset):
    """Render one ground-view panorama at the slider-selected position."""
    model, opt = build_model(sat_img, style_img)
    # Map slider positions in [0, 1] to normalized coordinates in [-1, 1];
    # the y-axis is flipped so larger slider values move the camera up.
    # TODO: remove the hard-coded 256x256 image size.
    opt.origin_H_W = [-(y_offset*256 - 128)/128, (x_offset*256 - 128)/128]
    model.forward(opt)
    rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
    return np.array(rgb*255, dtype=np.uint8)


def get_video(sat_img, style_img, positions):
    """Render a video along a smoothed trajectory through the clicked points."""
    model, opt = build_model(sat_img, style_img)

    # Drop duplicate clicks while preserving order, then fit a smooth B-spline
    # through the remaining points (splprep's default cubic spline needs at
    # least four distinct points) and resample it to 80 camera positions.
    pixels = np.array(list(dict.fromkeys(positions)))
    tck, u = splprep(pixels.T, s=25, per=0)
    u_new = np.linspace(u.min(), u.max(), 80)
    x_new, y_new = splev(u_new, tck)
    smooth_path = np.array([x_new, y_new]).T

    rendered_image_list = []
    rendered_depth_list = []  # collected for inspection; not used further in this demo
    for x, y in smooth_path:
        # Same pixel -> [-1, 1] mapping as in get_single.
        # TODO: remove the hard-coded 256x256 image size.
        opt.origin_H_W = [(y - 128)/128, (x - 128)/128]
        print('Rendering at ({}, {})'.format(x, y))
        model.forward(opt)
        rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
        rendered_image_list.append(np.array(rgb*255, dtype=np.uint8))
        rendered_depth_list.append(model.out_put.depth[0, 0].cpu().detach().numpy())

    # Encode the rendered 128x512 panoramas as an MP4; OpenCV expects BGR frames.
    output_video_path = 'output_video.mp4'
    frame_rate = 15
    frame_width = 512
    frame_height = 128
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))
    for image_np in rendered_image_list:
        out.write(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
    out.release()
    return output_video_path


def copy_image(image):
    # Mirror the uploaded satellite image into the trajectory-drawing canvas.
    return image


def show_image_and_point(image, x, y):
    """Overlay a circular 'render point' annotation at the slider position."""
    x = int(x * image.shape[1])
    y = image.shape[0] - int(y * image.shape[0])  # sliders count y from the bottom
    mask = np.zeros(image.shape[:2])
    radius = min(image.shape[0], image.shape[1]) // 60
    for i in range(x - radius - 2, x + radius + 2):
        for j in range(y - radius - 2, y + radius + 2):
            if (i - x)**2 + (j - y)**2 <= radius**2:
                mask[j, i] = 1
    return (image, [(mask, 'render point')])


def add_select_point(image, evt: gr.SelectData, state1):
    """Record a clicked trajectory point and paint a black dot on the canvas."""
    if state1 is None:
        state1 = []
    x, y = evt.index
    state1.append((x, y))
    print(state1)
    radius = min(image.shape[0], image.shape[1]) // 60
    for i in range(x - radius - 2, x + radius + 2):
        for j in range(y - radius - 2, y + radius + 2):
            if (i - x)**2 + (j - y)**2 <= radius**2:
                image[j, i, :] = 0
    return image, state1


def reset_select_points(image):
    # Restore the clean satellite image and clear the recorded trajectory.
    return image, []


with gr.Blocks() as demo:
    gr.Markdown("# Sat2Density Demos")
    gr.Markdown("### select/upload the satellite image and select the style image")
    with gr.Row():
        with gr.Column():
            sat_img = gr.Image(source='upload', shape=[256, 256], interactive=True)
            img_examples = gr.Examples(
                examples=['demo_img/{}/satview-input.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
                inputs=sat_img, outputs=None, examples_per_page=20)
        with gr.Column():
            style_img = gr.Image()
            style_examples = gr.Examples(
                examples=['demo_img/{}/groundview.image.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
                inputs=style_img, outputs=None, examples_per_page=20)

    gr.Markdown("### select a point to generate a single ground-view image")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    slider_x = gr.Slider(0.2, 0.8, 0.5, label="x-axis position")
                    slider_y = gr.Slider(0.2, 0.8, 0.5, label="y-axis position")
                    btn_single = gr.Button(label="demo1")
                annotation_image = gr.AnnotatedImage()
        out_single = gr.Image()

    gr.Markdown("### draw a trajectory on the map to generate a video")
    state_select_points = gr.State()
    with gr.Row():
        with gr.Column():
            draw_img = gr.Image(shape=[256, 256], interactive=True)
        with gr.Column():
            out_video = gr.Video()
            reset_btn = gr.Button(value="Reset")
            btn_video = gr.Button(label="demo1")

    # Event wiring.
    sat_img.change(copy_image, inputs=sat_img, outputs=draw_img)
    draw_img.select(add_select_point, [draw_img, state_select_points], [draw_img, state_select_points])
    sat_img.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image)
    slider_x.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image, show_progress='hidden')
    slider_y.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image, show_progress='hidden')
    btn_single.click(get_single, inputs=[sat_img, style_img, slider_x, slider_y], outputs=out_single)
    reset_btn.click(reset_select_points, [sat_img], [draw_img, state_select_points])
    btn_video.click(get_video, inputs=[sat_img, style_img, state_select_points], outputs=out_video)  # trigger video rendering

demo.launch()
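
# A minimal offline sketch of the trajectory smoothing used in get_video,
# assuming a few hypothetical click positions; useful for checking the
# splprep/splev step without loading the model (uncomment to run):
#
#   clicks = np.array([(40.0, 60.0), (100.0, 90.0), (180.0, 120.0), (220.0, 200.0)])
#   tck, u = splprep(clicks.T, s=25, per=0)
#   x_new, y_new = splev(np.linspace(u.min(), u.max(), 80), tck)
#   print(np.array([x_new, y_new]).T[:5])  # first few smoothed (x, y) waypoints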