Commit: initial
This view is limited to 50 files because it contains too many changes. See raw diff.
- README.md +122 -13
- __pycache__/options.cpython-38.pyc +0 -0
- __pycache__/test.cpython-38.pyc +0 -0
- __pycache__/utils.cpython-38.pyc +0 -0
- app.py +244 -0
- data/CVACT_Shi.py +119 -0
- data/CVUSA.py +86 -0
- dataset/INSTALL.md +32 -0
- demo_img/case1/groundview.image.png +0 -0
- demo_img/case1/groundview.sky.png +0 -0
- demo_img/case1/satview-input.png +0 -0
- demo_img/case10/groundview.image.png +0 -0
- demo_img/case10/groundview.sky.png +0 -0
- demo_img/case10/satview-input.png +0 -0
- demo_img/case11/groundview.image.png +0 -0
- demo_img/case11/groundview.sky.png +0 -0
- demo_img/case11/satview-input.png +0 -0
- demo_img/case12/groundview.image.png +0 -0
- demo_img/case12/groundview.sky.png +0 -0
- demo_img/case12/satview-input.png +0 -0
- demo_img/case13/groundview.image.png +0 -0
- demo_img/case13/groundview.sky.png +0 -0
- demo_img/case13/satview-input.png +0 -0
- demo_img/case2/groundview.image.png +0 -0
- demo_img/case2/groundview.sky.png +0 -0
- demo_img/case2/satview-input.png +0 -0
- demo_img/case3/groundview.image.png +0 -0
- demo_img/case3/groundview.sky.png +0 -0
- demo_img/case3/satview-input.png +0 -0
- demo_img/case4/groundview.image.png +0 -0
- demo_img/case4/groundview.sky.png +0 -0
- demo_img/case4/satview-input.png +0 -0
- demo_img/case5/groundview.image.png +0 -0
- demo_img/case5/groundview.sky.png +0 -0
- demo_img/case5/satview-input.png +0 -0
- demo_img/case6/groundview.image.png +0 -0
- demo_img/case6/groundview.sky.png +0 -0
- demo_img/case6/satview-input.png +0 -0
- demo_img/case7/groundview.image.png +0 -0
- demo_img/case7/groundview.sky.png +0 -0
- demo_img/case7/satview-input.png +0 -0
- demo_img/case8/groundview.image.png +0 -0
- demo_img/case8/groundview.sky.png +0 -0
- demo_img/case8/satview-input.png +0 -0
- demo_img/case9/groundview.image.png +0 -0
- demo_img/case9/groundview.sky.png +0 -0
- demo_img/case9/satview-input.png +0 -0
- demo_img/runall.sh +30 -0
- imaginaire/__init__.py +4 -0
- imaginaire/__pycache__/__init__.cpython-38.pyc +0 -0
README.md
CHANGED
@@ -1,13 +1,122 @@
# Sat2Density: Faithful Density Learning from Satellite-Ground Image Pairs

> [Ming Qian](https://qianmingduowan.github.io/), Jincheng Xiong, [Gui-Song Xia](http://www.captain-whu.com/xia_En.html), [Nan Xue](https://xuenan.net)
>
> IEEE/CVF International Conference on Computer Vision (ICCV), 2023
>
> [Project](https://sat2density.github.io/) | [Paper](https://arxiv.org/abs/2303.14672) | [Data]() | [Install.md](docs/INSTALL.md)

<p align="center" float="left">
<img src="docs/figures/demo/case1.sat.gif" alt="drawing" width="19%">
<img src="docs/figures/demo-density/case1.gif" alt="drawing" width="38%">
<img src="docs/figures/demo/case1.render.gif" alt="drawing" width="38%">
</p>

<p align="center" float="left">
<img src="docs/figures/demo/case2.sat.gif" alt="drawing" width="19%">
<img src="docs/figures/demo-density/case2.gif" alt="drawing" width="38%">
<img src="docs/figures/demo/case2.render.gif" alt="drawing" width="38%">
</p>

<p align="center" float="left">
<img src="docs/figures/demo/case3.sat.gif" alt="drawing" width="19%">
<img src="docs/figures/demo-density/case3.gif" alt="drawing" width="38%">
<img src="docs/figures/demo/case3.render.gif" alt="drawing" width="38%">
</p>

<p align="center" float="left">
<img src="docs/figures/demo/case4.sat.gif" alt="drawing" width="19%">
<img src="docs/figures/demo-density/case4.gif" alt="drawing" width="38%">
<img src="docs/figures/demo/case4.render.gif" alt="drawing" width="38%">
</p>

## Downloading Checkpoints

Two checkpoints, for CVACT and CVUSA, can be found at [this URL](https://github.com/sat2density/checkpoints/releases). You can also run the following command to download them:
```
bash scripts/download_weights.sh
```

## QuickStart Demo

### Video Synthesis

#### Example Usage
```
python test.py --yaml=sat2density_cvact \
  --test_ckpt_path=2u87bj8w \
  --task=test_vid \
  --demo_img=demo_img/case1/satview-input.png \
  --sty_img=demo_img/case1/groundview.image.png \
  --save_dir=results/case1
```

### Illumination Interpolation

<!-- ```
bash inference/quick_demo_interpolation.sh
``` -->
```
python test.py --task=test_interpolation \
  --yaml=sat2density_cvact \
  --test_ckpt_path=2u87bj8w \
  --sty_img1=demo_img/case9/groundview.image.png \
  --sty_img2=demo_img/case7/groundview.image.png \
  --demo_img=demo_img/case3/satview-input.png \
  --save_dir=results/case2
```

## Train & Inference

- *We trained our model on a single V100 32GB GPU; the training phase takes about 20 hours.*
- *For data preparation, please check out [data.md](dataset/INSTALL.md).*

### Inference

To test the center ground-view synthesis setting (add `--task=vis_test` if you want to save the results):
```bash
# CVACT
python offline_train_test.py --yaml=sat2density_cvact --test_ckpt_path=2u87bj8w
# CVUSA
python offline_train_test.py --yaml=sat2density_cvusa --test_ckpt_path=2cqv8uh4
```

To test inference with different illuminations:
```bash
# CVACT
bash inference/single_style_test_cvact.sh
# CVUSA
bash inference/single_style_test_cvusa.sh
```

To synthesize ground-view videos:
```bash
bash inference/synthesis_video.sh
```

## Training

### Training command

```bash
# CVACT
CUDA_VISIBLE_DEVICES=X python train.py --yaml=sat2density_cvact
# CVUSA
CUDA_VISIBLE_DEVICES=X python train.py --yaml=sat2density_cvusa
```

## Citation

If you use this code for your research, please cite:

```
@inproceedings{qian2021sat2density,
  title={Sat2Density: Faithful Density Learning from Satellite-Ground Image Pairs},
  author={Qian, Ming and Xiong, Jincheng and Xia, Gui-Song and Xue, Nan},
  booktitle={ICCV},
  year={2023}
}
```

## License

This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
For commercial use, please contact [mingqian@whu.edu.cn].

__pycache__/options.cpython-38.pyc
ADDED
Binary file (3.74 kB)

__pycache__/test.cpython-38.pyc
ADDED
Binary file (9.11 kB)

__pycache__/utils.cpython-38.pyc
ADDED
Binary file (8.16 kB)

app.py
ADDED
@@ -0,0 +1,244 @@
```python
import gradio as gr
import numpy as np
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import options
import test
import importlib
from scipy.interpolate import interp1d, splev, splprep
import cv2


def get_single(sat_img, style_img, x_offset, y_offset):
    # Recover which demo case the selected style image belongs to,
    # so that the matching sky mask can be loaded.
    name = ''
    for i in [name for name in os.listdir('demo_img') if 'case' in name]:
        style = Image.open('demo_img/{}/groundview.image.png'.format(i)).convert('RGB')
        style = np.array(style)
        if (style == style_img).all():
            name = i
            break

    input_dict = {}
    trans = transforms.ToTensor()
    input_dict['sat'] = trans(sat_img)
    input_dict['pano'] = trans(style_img)
    input_dict['paths'] = "demo.png"
    sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
    # Per-channel histogram of the sky region (first 10 bins dropped),
    # concatenated in reversed channel order, serves as the style vector.
    input_a = input_dict['pano'] * sky
    sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
    input_dict['sky_histc'] = sky_histc
    input_dict['sky_mask'] = sky

    for key in input_dict.keys():
        if isinstance(input_dict[key], torch.Tensor):
            input_dict[key] = input_dict[key].unsqueeze(0)

    args = ["--yaml=sat2density_cvact",
            "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth",
            "--task=test_vid", "--demo_img=demo_img/case1/satview-input.png",
            "--sty_img=demo_img/case1/groundview.image.png", "--save_dir=output"]
    opt_cmd = options.parse_arguments(args=args)
    opt = options.set(opt_cmd=opt_cmd)
    opt.isTrain = False
    opt.name = opt.yaml if opt.name is None else opt.name
    opt.batch_size = 1

    m = importlib.import_module("model.{}".format(opt.model))
    model = m.Model(opt)

    # m.load_dataset(opt)
    model.build_networks(opt)
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()

    model.set_input(input_dict)

    model.style_temp = model.sky_histc
    # Map the slider position in [0, 1] to normalized image coordinates in [-1, 1].
    opt.origin_H_W = [-(y_offset*256-128)/128, (x_offset*256-128)/128]  # TODO: hard code should be removed in the future

    model.forward(opt)

    rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
    rgb = np.array(rgb*255, dtype=np.uint8)
    return rgb

def get_video(sat_img, style_img, positions):
    # Same case lookup and input preparation as in get_single.
    name = ''
    for i in [name for name in os.listdir('demo_img') if 'case' in name]:
        style = Image.open('demo_img/{}/groundview.image.png'.format(i)).convert('RGB')
        style = np.array(style)
        if (style == style_img).all():
            name = i
            break

    input_dict = {}
    trans = transforms.ToTensor()
    input_dict['sat'] = trans(sat_img)
    input_dict['pano'] = trans(style_img)
    input_dict['paths'] = "demo.png"
    sky = trans(Image.open('demo_img/{}/groundview.sky.png'.format(name)).convert("L"))
    input_a = input_dict['pano'] * sky
    sky_histc = torch.cat([input_a[i].histc()[10:] for i in reversed(range(3))])
    input_dict['sky_histc'] = sky_histc
    input_dict['sky_mask'] = sky

    for key in input_dict.keys():
        if isinstance(input_dict[key], torch.Tensor):
            input_dict[key] = input_dict[key].unsqueeze(0)

    args = ["--yaml=sat2density_cvact",
            "--test_ckpt_path=wandb/run-20230219_141512-2u87bj8w/files/checkpoint/model.pth",
            "--task=test_vid", "--demo_img=demo_img/case1/satview-input.png",
            "--sty_img=demo_img/case1/groundview.image.png", "--save_dir=output"]
    opt_cmd = options.parse_arguments(args=args)
    opt = options.set(opt_cmd=opt_cmd)
    opt.isTrain = False
    opt.name = opt.yaml if opt.name is None else opt.name
    opt.batch_size = 1

    m = importlib.import_module("model.{}".format(opt.model))
    model = m.Model(opt)

    # m.load_dataset(opt)
    model.build_networks(opt)
    ckpt = torch.load(opt.test_ckpt_path, map_location='cpu')
    model.netG.load_state_dict(ckpt['netG'])
    model.netG.eval()

    model.set_input(input_dict)

    model.style_temp = model.sky_histc

    # Deduplicate the clicked points while preserving their order.
    unique_lst = list(dict.fromkeys(positions))
    pixels = []
    for x in positions:
        if x in unique_lst:
            if x not in pixels:
                pixels.append(x)
    pixels = np.array(pixels)
    # Fit a smoothing B-spline through the clicked points and resample
    # 80 evenly spaced camera positions along it.
    tck, u = splprep(pixels.T, s=25, per=0)
    u_new = np.linspace(u.min(), u.max(), 80)
    x_new, y_new = splev(u_new, tck)
    smooth_path = np.array([x_new, y_new]).T

    rendered_image_list = []
    rendered_depth_list = []

    for i, (x, y) in enumerate(smooth_path):
        opt.origin_H_W = [(y-128)/128, (x-128)/128]  # TODO: hard code should be removed in the future
        print('Rendering at ({}, {})'.format(x, y))
        model.forward(opt)

        rgb = model.out_put.pred[0].clamp(min=0, max=1.0).cpu().detach().numpy().transpose((1, 2, 0))
        rgb = np.array(rgb*255, dtype=np.uint8)
        rendered_image_list.append(rgb)

        rendered_depth_list.append(
            model.out_put.depth[0, 0].cpu().detach().numpy()
        )

    output_video_path = 'output_video.mp4'

    # Set the video frame rate, width, and height.
    frame_rate = 15
    frame_width = 512
    frame_height = 128

    # Create an OpenCV video writer using the mp4v codec.
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (frame_width, frame_height))

    # Iterate over the rendered frames and write them to the video
    # (channel order is swapped for OpenCV).
    for image_np in rendered_image_list:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
        out.write(image_np)

    # Release the video writer.
    out.release()

    return "output_video.mp4"

def copy_image(image):
    return image

def show_image_and_point(image, x, y):
    # Draw a filled circle mask at the normalized (x, y) position
    # to visualize the chosen render point.
    x = int(x*image.shape[1])
    y = image.shape[0] - int(y*image.shape[0])
    mask = np.zeros(image.shape[:2])
    radius = min(image.shape[0], image.shape[1])//60
    for i in range(x-radius-2, x+radius+2):
        for j in range(y-radius-2, y+radius+2):
            if (i-x)**2 + (j-y)**2 <= radius**2:
                mask[j, i] = 1
    return (image, [(mask, 'render point')])

def add_select_point(image, evt: gr.SelectData, state1):
    # Record the clicked pixel in the trajectory state and mark it on the image.
    if state1 is None:
        state1 = []
    x, y = evt.index
    state1.append((x, y))
    print(state1)
    radius = min(image.shape[0], image.shape[1])//60
    for i in range(x-radius-2, x+radius+2):
        for j in range(y-radius-2, y+radius+2):
            if (i-x)**2 + (j-y)**2 <= radius**2:
                image[j, i, :] = 0
    return image, state1

def reset_select_points(image):
    return image, []


with gr.Blocks() as demo:
    gr.Markdown("# Sat2Density Demos")
    gr.Markdown("### Select/upload the satellite image and select the style image")
    with gr.Row():
        with gr.Column():
            sat_img = gr.Image(source='upload', shape=[256, 256], interactive=True)
            img_examples = gr.Examples(examples=['demo_img/{}/satview-input.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
                                       inputs=sat_img, outputs=None, examples_per_page=20)
        with gr.Column():
            style_img = gr.Image()
            style_examples = gr.Examples(examples=['demo_img/{}/groundview.image.png'.format(i) for i in os.listdir('demo_img') if 'case' in i],
                                         inputs=style_img, outputs=None, examples_per_page=20)

    gr.Markdown("### Select a point to generate a single ground-view image")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    slider_x = gr.Slider(0.2, 0.8, 0.5, label="x-axis position")
                    slider_y = gr.Slider(0.2, 0.8, 0.5, label="y-axis position")
                    btn_single = gr.Button(label="demo1")

            annotation_image = gr.AnnotatedImage()

        out_single = gr.Image()

    gr.Markdown("### Draw a trajectory on the map to generate a video")
    state_select_points = gr.State()
    with gr.Row():
        with gr.Column():
            draw_img = gr.Image(shape=[256, 256], interactive=True)
        with gr.Column():
            out_video = gr.Video()
            reset_btn = gr.Button(value="Reset")
            btn_video = gr.Button(label="demo1")

    sat_img.change(copy_image, inputs=sat_img, outputs=draw_img)

    draw_img.select(add_select_point, [draw_img, state_select_points], [draw_img, state_select_points])
    sat_img.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image)
    slider_x.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image, show_progress='hidden')
    slider_y.change(show_image_and_point, inputs=[sat_img, slider_x, slider_y], outputs=annotation_image, show_progress='hidden')
    btn_single.click(get_single, inputs=[sat_img, style_img, slider_x, slider_y], outputs=out_single)
    reset_btn.click(reset_select_points, [sat_img], [draw_img, state_select_points])
    btn_video.click(get_video, inputs=[sat_img, style_img, state_select_points], outputs=out_video)  # trigger video generation


demo.launch()
```
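The style conditioning in both `get_single` and `get_video` reduces to a per-channel histogram of the masked sky pixels. Below is a minimal, self-contained sketch of just that step, with random tensors standing in for a real panorama and sky mask:

```python
# Sketch of the sky-style vector built in app.py: mask the RGB panorama to its
# sky region, take each channel's (default 100-bin) histogram, drop the first
# 10 bins, and concatenate the channels in reversed order.
import torch

pano = torch.rand(3, 128, 512)                 # stand-in RGB panorama in [0, 1]
sky = (torch.rand(1, 128, 512) > 0.5).float()  # stand-in binary sky mask
masked = pano * sky
style = torch.cat([masked[c].histc()[10:] for c in reversed(range(3))])
print(style.shape)  # torch.Size([270]) = 3 channels x (100 - 10) bins
```

Running `python app.py` then serves the demo through Gradio (by default on http://127.0.0.1:7860).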
data/CVACT_Shi.py
ADDED
@@ -0,0 +1,119 @@
```python
import torch, os
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.io as sio
import torchvision.transforms as transforms

def data_list(img_root, mode):
    # Keep only the pairs whose ground and satellite images both exist on disk.
    exist_aer_list = os.listdir(os.path.join(img_root, 'satview_correct'))
    exist_grd_list = os.listdir(os.path.join(img_root, 'streetview'))
    allDataList = os.path.join(img_root, 'ACT_data.mat')
    anuData = sio.loadmat(allDataList)

    all_data_list = []
    for i in range(0, len(anuData['panoIds'])):
        grd_id_align = anuData['panoIds'][i] + '_grdView.png'
        sat_id_ori = anuData['panoIds'][i] + '_satView_polish.png'
        all_data_list.append([grd_id_align, sat_id_ori])

    data_list = []

    if mode == 'train':
        training_inds = anuData['trainSet']['trainInd'][0][0] - 1
        trainNum = len(training_inds)
        for k in range(trainNum):
            data_list.append(all_data_list[training_inds[k][0]])
    else:
        val_inds = anuData['valSet']['valInd'][0][0] - 1
        valNum = len(val_inds)
        for k in range(valNum):
            data_list.append(all_data_list[val_inds[k][0]])

    pano_list = [img_root + 'streetview/' + item[0] for item in data_list
                 if item[0] in exist_grd_list and item[1] in exist_aer_list]

    return pano_list

def img_read(img, size=None, datatype='RGB'):
    img = Image.open(img).convert('RGB' if datatype == 'RGB' else "L")
    if size:
        if type(size) is int:
            size = (size, size)
        img = img.resize(size=size, resample=Image.BICUBIC if datatype == 'RGB' else Image.NEAREST)
    img = transforms.ToTensor()(img)
    return img


class Dataset(Dataset):
    def __init__(self, opt, split='train', sub=None, sty_img=None):
        if sty_img:
            assert sty_img.endswith('grdView.png')
            demo_img_path = os.path.join(opt.data.root, 'streetview', sty_img)
            self.pano_list = [demo_img_path]

        elif opt.task in ['test_vid', 'test_interpolation']:
            demo_img_path = os.path.join(opt.data.root, 'streetview', opt.demo_img.replace('satView_polish.png', 'grdView.png'))
            self.pano_list = [demo_img_path]

        else:
            self.pano_list = data_list(img_root=opt.data.root, mode=split)
            if sub:
                self.pano_list = self.pano_list[:sub]

        # Select some ground images to test the influence of different skies;
        # different skies guide different illumination intensities, colors, etc.
        if opt.task == 'test_sty':
            demo_name = [
                'dataset/CVACT/streetview/pPfo7qQ1fP_24rXrJ2Uxog_grdView.png',
                'dataset/CVACT/streetview/YL81FiK9PucIvAkr1FHkpA_grdView.png',
                'dataset/CVACT/streetview/Tzis1jBKHjbXiVB2oRYwAQ_grdView.png',
                'dataset/CVACT/streetview/eqGgeBLGXRhSj6c-0h0KoQ_grdView.png',
                'dataset/CVACT/streetview/pdZmLHYEhe2PHj_8-WHMhw_grdView.png',
                'dataset/CVACT/streetview/ehsu9Q3iTin5t52DM-MwyQ_grdView.png',
                'dataset/CVACT/streetview/agLEcuq3_-qFj7wwGbktVg_grdView.png',
                'dataset/CVACT/streetview/HwQIDdMI3GfHyPGtCSo6aA_grdView.png',
                'dataset/CVACT/streetview/hV8svb3ZVXcQ0AtTRFE1dQ_grdView.png',
                'dataset/CVACT/streetview/fzq2mBfKP3UIczAd9KpMMg_grdView.png',
                'dataset/CVACT/streetview/acRP98sACUIlwl2ZIsEyiQ_grdView.png',
                'dataset/CVACT/streetview/WSh9tNVryLdupUlU0ri2tQ_grdView.png',
                'dataset/CVACT/streetview/FhEuB9NA5o08VJ_TBCbHjw_grdView.png',
                'dataset/CVACT/streetview/YHfpn2Mgu1lqgT2OUeBpOg_grdView.png',
                'dataset/CVACT/streetview/vNhv7ZP1dUkJ93UwFXagJw_grdView.png',
            ]
            self.pano_list = demo_name

        self.opt = opt

    def __len__(self):
        return len(self.pano_list)

    def __getitem__(self, index):
        pano = self.pano_list[index]
        # Derive the satellite and sky-mask paths from the panorama path.
        aer = pano.replace('streetview', 'satview_correct').replace('_grdView', '_satView_polish')
        if self.opt.data.sky_mask:
            sky = pano.replace('streetview', 'pano_sky_mask')
        name = pano
        aer = img_read(aer, size=self.opt.data.sat_size)
        pano = img_read(pano, size=self.opt.data.pano_size)
        if self.opt.data.sky_mask:
            sky = img_read(sky, size=self.opt.data.pano_size, datatype='L')

        input = {}
        input['sat'] = aer
        input['pano'] = pano
        input['paths'] = name
        if self.opt.data.sky_mask:
            input['sky_mask'] = sky
            black_ground = torch.zeros_like(pano)
            # Histogram of the sky region, with the darkest 10 bins dropped.
            if self.opt.data.histo_mode == 'grey':
                input['sky_histc'] = (pano*sky + black_ground*(1-sky)).histc()[10:]
            elif self.opt.data.histo_mode in ['rgb', 'RGB']:
                input_a = (pano*sky + black_ground*(1-sky))
                for idx in range(len(input_a)):
                    if idx == 0:
                        sky_histc = input_a[idx].histc()[10:]
                    else:
                        sky_histc = torch.cat([input_a[idx].histc()[10:], sky_histc], dim=0)
                input['sky_histc'] = sky_histc
        return input
```
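For orientation, here is a hypothetical usage sketch of this dataset class. The `opt` fields mirror exactly the attributes that `__init__` and `__getitem__` read; the concrete values are placeholders, since the real ones come from the YAML configs (e.g. `sat2density_cvact`):

```python
# Hypothetical usage sketch; all config values below are placeholders.
from easydict import EasyDict as edict
from torch.utils.data import DataLoader
from data.CVACT_Shi import Dataset

opt = edict(task='train',
            data=edict(root='dataset/CVACT/',  # trailing slash matters: data_list concatenates strings
                       sky_mask=True, sat_size=256,
                       pano_size=(512, 128), histo_mode='rgb'))
loader = DataLoader(Dataset(opt, split='train'), batch_size=4, shuffle=True)
batch = next(iter(loader))  # keys: 'sat', 'pano', 'paths', 'sky_mask', 'sky_histc'
```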
data/CVUSA.py
ADDED
@@ -0,0 +1,86 @@
```python
import torch, os
from torch.utils.data.dataset import Dataset
from PIL import Image
import torchvision.transforms as transforms
import re
from easydict import EasyDict as edict

def data_list(img_root, mode):
    # Each row of the split CSV pairs an aerial tile with its ground panorama.
    data_list = []
    if mode == 'train':
        split_file = os.path.join(img_root, 'splits/train-19zl.csv')
        with open(split_file) as f:
            lines = f.readlines()
        for i in lines:
            aerial_name = re.split(r',', re.split('\n', i)[0])[0]
            panorama_name = re.split(r',', re.split('\n', i)[0])[1]
            data_list.append([aerial_name, panorama_name])
    else:
        split_file = os.path.join(img_root + 'splits/val-19zl.csv')
        with open(split_file) as f:
            lines = f.readlines()
        for i in lines:
            aerial_name = re.split(r',', re.split('\n', i)[0])[0]
            panorama_name = re.split(r',', re.split('\n', i)[0])[1]
            data_list.append([aerial_name, panorama_name])
    print('length of dataset is: ', len(data_list))
    return [os.path.join(img_root, i[1]) for i in data_list]

def img_read(img, size=None, datatype='RGB'):
    img = Image.open(img).convert('RGB' if datatype == 'RGB' else "L")
    if size:
        if type(size) is int:
            size = (size, size)
        img = img.resize(size=size, resample=Image.BICUBIC if datatype == 'RGB' else Image.NEAREST)
    img = transforms.ToTensor()(img)
    return img


class Dataset(Dataset):
    def __init__(self, opt, split='train', sub=None, sty_img=None):
        self.pano_list = data_list(img_root=opt.data.root, mode=split)
        if sub:
            self.pano_list = self.pano_list[:sub]
        if opt.task == 'test_vid':
            demo_img_path = os.path.join(opt.data.root, 'streetview/panos', opt.demo_img)
            self.pano_list = [demo_img_path]
        if sty_img:
            assert opt.sty_img.split('.')[-1] == 'jpg'
            demo_img_path = os.path.join(opt.data.root, 'streetview/panos', opt.sty_img)
            self.pano_list = [demo_img_path]

        self.opt = opt

    def __len__(self):
        return len(self.pano_list)

    def __getitem__(self, index):
        pano = self.pano_list[index]
        # Derive the satellite and sky-mask paths from the panorama path.
        aer = pano.replace('streetview/panos', 'bingmap/19')
        if self.opt.data.sky_mask:
            sky = pano.replace('streetview/panos', 'sky_mask').replace('jpg', 'png')
        name = pano
        aer = img_read(aer, size=self.opt.data.sat_size)
        pano = img_read(pano, size=self.opt.data.pano_size)
        if self.opt.data.sky_mask:
            sky = img_read(sky, size=self.opt.data.pano_size, datatype='L')

        input = {}
        input['sat'] = aer
        input['pano'] = pano
        input['paths'] = name
        if self.opt.data.sky_mask:
            input['sky_mask'] = sky
            black_ground = torch.zeros_like(pano)
            # Histogram of the sky region, with the darkest 10 bins dropped.
            if self.opt.data.histo_mode == 'grey':
                input['sky_histc'] = (pano*sky + black_ground*(1-sky)).histc()[10:]
            elif self.opt.data.histo_mode in ['rgb', 'RGB']:
                input_a = (pano*sky + black_ground*(1-sky))
                for idx in range(len(input_a)):
                    if idx == 0:
                        sky_histc = input_a[idx].histc()[10:]
                    else:
                        sky_histc = torch.cat([input_a[idx].histc()[10:], sky_histc], dim=0)
                input['sky_histc'] = sky_histc
        return input
```
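The CVUSA loader relies on the same path-rewriting convention as the CVACT one; a tiny sketch of how a single panorama path maps to its aerial tile and sky mask (the file ID is illustrative):

```python
# Illustrative path rewriting, mirroring __getitem__ above (the ID is made up).
pano = 'dataset/CVUSA/streetview/panos/0000001.jpg'
aer = pano.replace('streetview/panos', 'bingmap/19')
sky = pano.replace('streetview/panos', 'sky_mask').replace('jpg', 'png')
print(aer)  # dataset/CVUSA/bingmap/19/0000001.jpg
print(sky)  # dataset/CVUSA/sky_mask/0000001.png
```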
dataset/INSTALL.md
ADDED
@@ -0,0 +1,32 @@
To reproduce our paper, you should first download these four zip files:

- `CVACT/satview_correct.zip`
- `CVACT/streetview.zip`
- `CVUSA/bingmap/19.zip`
- `CVUSA/streetview/panos.zip`

from [here](https://anu365-my.sharepoint.com/:f:/g/personal/u6293587_anu_edu_au/EuOBUDUQNClJvCpQ8bD1hnoBjdRBWxsHOVp946YVahiMGg?e=F4yRAC); the project page is [Sat2StrPanoramaSynthesis](https://github.com/shiyujiao/Sat2StrPanoramaSynthesis).

Then download the sky masks from [here](https://drive.google.com/drive/folders/1pfzwONg4P-Mzvxvzb2HoCpuZFynElPCk?usp=sharing).

Finally, organize the dataset as follows:
```
dataset
├── CVACT
│   ├── streetview
│   ├── satview_correct
│   ├── pano_sky_mask
│   └── ACT_data.mat
└── CVUSA
    ├── bingmap
    │   └── 19
    ├── streetview
    │   └── panos
    ├── sky_mask
    └── splits
```

Tip: The sky masks are processed with [Trans4PASS](https://github.com/jamycheung/Trans4PASS).
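If it helps, here is a small hypothetical sanity check that the layout above is in place (the paths are taken directly from the tree):

```python
# Hypothetical helper: verify the expected dataset layout exists on disk.
import os

expected = [
    'dataset/CVACT/streetview', 'dataset/CVACT/satview_correct',
    'dataset/CVACT/pano_sky_mask', 'dataset/CVACT/ACT_data.mat',
    'dataset/CVUSA/bingmap/19', 'dataset/CVUSA/streetview/panos',
    'dataset/CVUSA/sky_mask', 'dataset/CVUSA/splits',
]
missing = [p for p in expected if not os.path.exists(p)]
print('missing:', missing if missing else 'none')
```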
demo_img/case1/groundview.image.png
ADDED
demo_img/case1/groundview.sky.png
ADDED
demo_img/case1/satview-input.png
ADDED
demo_img/case10/groundview.image.png
ADDED
demo_img/case10/groundview.sky.png
ADDED
demo_img/case10/satview-input.png
ADDED
demo_img/case11/groundview.image.png
ADDED
demo_img/case11/groundview.sky.png
ADDED
demo_img/case11/satview-input.png
ADDED
demo_img/case12/groundview.image.png
ADDED
demo_img/case12/groundview.sky.png
ADDED
demo_img/case12/satview-input.png
ADDED
demo_img/case13/groundview.image.png
ADDED
demo_img/case13/groundview.sky.png
ADDED
demo_img/case13/satview-input.png
ADDED
demo_img/case2/groundview.image.png
ADDED
demo_img/case2/groundview.sky.png
ADDED
demo_img/case2/satview-input.png
ADDED
demo_img/case3/groundview.image.png
ADDED
demo_img/case3/groundview.sky.png
ADDED
demo_img/case3/satview-input.png
ADDED
demo_img/case4/groundview.image.png
ADDED
demo_img/case4/groundview.sky.png
ADDED
demo_img/case4/satview-input.png
ADDED
demo_img/case5/groundview.image.png
ADDED
demo_img/case5/groundview.sky.png
ADDED
demo_img/case5/satview-input.png
ADDED
demo_img/case6/groundview.image.png
ADDED
demo_img/case6/groundview.sky.png
ADDED
demo_img/case6/satview-input.png
ADDED
demo_img/case7/groundview.image.png
ADDED
demo_img/case7/groundview.sky.png
ADDED
demo_img/case7/satview-input.png
ADDED
demo_img/case8/groundview.image.png
ADDED
demo_img/case8/groundview.sky.png
ADDED
demo_img/case8/satview-input.png
ADDED
demo_img/case9/groundview.image.png
ADDED
demo_img/case9/groundview.sky.png
ADDED
demo_img/case9/satview-input.png
ADDED
demo_img/runall.sh
ADDED
@@ -0,0 +1,30 @@
```bash
# for case in `ls -d demo_img/case*`
for case_id in 1 2 3 4
do
    case=demo_img/case$case_id
    echo $case
    python test.py --yaml=sat2density_cvact \
        --test_ckpt_path=2u87bj8w \
        --task=test_vid \
        --demo_img=$case/satview-input.png \
        --sty_img=$case/groundview.image.png \
        --save_dir=results/$case
    # Two-pass GIF encoding: generate an optimized palette first, then map
    # the frames through it (better quality than a direct single pass).
    # ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png results/$case/render.gif
    ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png -vf "palettegen" results/$case-palette.png
    ffmpeg -framerate 10 -i results/$case/rendered_images+depths/%5d.png -i results/$case-palette.png -filter_complex "paletteuse" results/$case/render.gif

    ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png -vf "palettegen" results/$case-palette.png
    ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png -i results/$case-palette.png -filter_complex "paletteuse" results/$case/sat.gif
    # ffmpeg -framerate 10 -i results/$case/sat_images/%5d.png results/$case/sat.gif
done

# Copy the generated GIFs into the docs figures used by the README.
# for case in `ls -d demo_img/case*`
for case_id in 1 2 3 4
do
    case=demo_img/case$case_id
    sat_gif=results/$case/sat.gif
    render_gif=results/$case/render.gif
    # echo $sat_gif
    cp $sat_gif docs/figures/demo/case$case_id.sat.gif
    cp $render_gif docs/figures/demo/case$case_id.render.gif
done
```
imaginaire/__init__.py
ADDED
@@ -0,0 +1,4 @@
```python
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
```
imaginaire/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (135 Bytes)