wangerniu committed · Commit 629144d · 1 Parent(s): 5de8ec7
maplocnet

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +4 -0
- conf/maplocnet.yaml +100 -0
- dataset/UAV/dataset.py +116 -0
- dataset/UAV/prepara_dataset.py +270 -0
- dataset/__init__.py +4 -0
- dataset/dataset.py +93 -0
- dataset/image.py +140 -0
- dataset/torch.py +111 -0
- demo.py +354 -0
- evaluation/kitti.py +89 -0
- evaluation/mapillary.py +111 -0
- evaluation/run.py +252 -0
- evaluation/utils.py +40 -0
- evaluation/viz.py +178 -0
- flagged/inp/10d2e4a8712491181c2f48b61f5003b216d2b9f9/tmp48n9eoyh.png +0 -0
- flagged/inp/e1b18d44d9e381d586209f73a015fed7f688822b/tmp86ith_2q.png +0 -0
- flagged/log.csv +3 -0
- flagged/output/tmp59657zop.json +1 -0
- flagged/output/tmpbs17s28d.json +1 -0
- images/00000.jpg +0 -0
- images/00011.jpg +0 -0
- images/00022.jpg +0 -0
- images/00033.jpg +0 -0
- images/cat_dog.png +0 -0
- label.txt +1000 -0
- logger.py +28 -0
- main.py +98 -0
- models/__init__.py +34 -0
- models/base.py +123 -0
- models/feature_extractor.py +231 -0
- models/feature_extractor_v2.py +192 -0
- models/map_encoder.py +67 -0
- models/maplocnet.py +204 -0
- models/metrics.py +118 -0
- models/utils.py +87 -0
- models/voting.py +365 -0
- module.py +171 -0
- osm/analysis.py +182 -0
- osm/data.py +230 -0
- osm/download.py +118 -0
- osm/parser.py +255 -0
- osm/raster.py +103 -0
- osm/reader.py +310 -0
- osm/tiling.py +310 -0
- osm/viz.py +159 -0
- requirements.txt +20 -0
- train.py +217 -0
- train.sh +1 -0
- utils/exif.py +356 -0
- utils/geo.py +130 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
.idea
temp
temp.py
weight
conf/maplocnet.yaml
ADDED
@@ -0,0 +1,100 @@
data:
  root: '/root/DATASET/UAV2MAP/UAV/'
  train_citys:
    - Paris
    - Berlin
    - London
    - Tokyo
    - NewYork
  val_citys:
    - Toronto
  image_size: 256
  train:
    batch_size: 12
    num_workers: 4
  val:
    batch_size: ${..train.batch_size}
    num_workers: ${.batch_size}
  num_classes:
    areas: 7
    ways: 10
    nodes: 33
  pixel_per_meter: 1
  crop_size_meters: 64
  max_init_error: 48
  add_map_mask: true
  resize_image: 512
  pad_to_square: true
  rectify_pitch: true
  augmentation:
    rot90: true
    flip: true
    image:
      apply: true
      brightness: 0.5
      contrast: 0.4
      saturation: 0.4
      hue: 0.159  # 0.5/3.14
model:
  image_size: ${data.image_size}
  latent_dim: 128
  val_citys: ${data.val_citys}
  image_encoder:
    name: feature_extractor_v2
    backbone:
      encoder: resnet50
      pretrained: true
      output_dim: 8
      num_downsample: null
      remove_stride_from_first_conv: false
  name: orienternet
  matching_dim: 8
  z_max: 32
  x_max: 32
  pixel_per_meter: 1
  num_scale_bins: 33
  num_rotations: 64
  map_encoder:
    embedding_dim: 16
    output_dim: 8
    num_classes:
      areas: 7
      ways: 10
      nodes: 33
    backbone:
      encoder: vgg19
      pretrained: false
      output_scales:
        - 0
      num_downsample: 3
      decoder:
        - 128
        - 64
        - 64
      padding: replicate
      unary_prior: false
  bev_net:
    num_blocks: 4
    latent_dim: 128
    output_dim: 8
    confidence: true
experiment:
  name: maplocanet_0906_diffhight
  gpus: 6
  seed: 0
training:
  lr: 0.0001
  lr_scheduler: null
  finetune_from_checkpoint: null
  trainer:
    val_check_interval: 1000
    log_every_n_steps: 100
    # limit_val_batches: 1000
    max_steps: 200000
    devices: ${experiment.gpus}
  checkpointing:
    monitor: "loss/total/val"
    save_top_k: 10
    mode: min

    # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
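The `${..train.batch_size}` and `${.batch_size}` values above are OmegaConf relative interpolations: one leading dot refers to the containing node, each extra dot goes one level up. A minimal sketch (not part of this commit) of how they resolve:

from omegaconf import OmegaConf

cfg = OmegaConf.create("""
data:
  train:
    batch_size: 12
  val:
    batch_size: ${..train.batch_size}  # up to `data`, then train.batch_size
    num_workers: ${.batch_size}        # sibling key inside `val`
""")
print(cfg.data.val.batch_size, cfg.data.val.num_workers)  # 12 12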
dataset/UAV/dataset.py
ADDED
@@ -0,0 +1,116 @@
import torch
from torch.utils.data import Dataset
import os
import cv2
# @Time    : 2023-02-13 22:56
# @Author  : Wang Zhen
# @Email   : frozenzhencola@163.com
# @File    : SatelliteTool.py
# @Project : TGRS_seqmatch_2023_1
import numpy as np
import random
from utils.geo import BoundaryBox, Projection
from osm.tiling import TileManager, MapTileManager
from pathlib import Path
from torchvision import transforms
from torch.utils.data import DataLoader


class UavMapPair(Dataset):
    def __init__(
        self,
        root: Path,
        city: str,
        training: bool,
        transform
    ):
        super().__init__()

        # self.root = root

        # city = 'Manhattan'
        # root = '/root/DATASET/CrossModel/'
        # root = Path(root)
        self.uav_image_path = root / city / 'uav'
        self.map_path = root / city / 'map'
        self.map_vis = root / city / 'map_vis'
        info_path = root / city / 'info.csv'

        self.info = np.loadtxt(str(info_path), dtype=str, delimiter=",", skiprows=1)

        self.transform = transform
        self.training = training

    def random_center_crop(self, image):
        height, width = image.shape[:2]

        # Randomly pick a crop size
        crop_size = random.randint(min(height, width) // 2, min(height, width))

        # Compute the top-left corner of the crop
        start_x = (width - crop_size) // 2
        start_y = (height - crop_size) // 2

        # Crop the image
        cropped_image = image[start_y:start_y + crop_size, start_x:start_x + crop_size]

        return cropped_image

    def __getitem__(self, index: int):
        (id, uav_name, map_name,
         uav_long, uav_lat,
         map_long, map_lat,
         tile_size_meters, pixel_per_meter,
         u, v, yaw, dis) = self.info[index]

        uav_image = cv2.imread(str(self.uav_image_path / uav_name))
        if self.training:
            uav_image = self.random_center_crop(uav_image)
        uav_image = cv2.cvtColor(uav_image, cv2.COLOR_BGR2RGB)
        if self.transform:
            uav_image = self.transform(uav_image)
        map = np.load(str(self.map_path / map_name))

        return {
            'map': torch.from_numpy(np.ascontiguousarray(map)).long(),
            'image': torch.tensor(uav_image),
            'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float(),
            'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float(),
            "uv": torch.tensor([float(u), float(v)]).float(),
        }

    def __len__(self):
        return len(self.info)


if __name__ == '__main__':

    root = Path('/root/DATASET/OrienterNet/UavMap/')
    city = 'NewYork'

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    dataset = UavMapPair(
        root=root,
        city=city,
        training=False,  # `training` is a required argument
        transform=transform
    )
    datasetloder = DataLoader(dataset, batch_size=3)
    for batch, i in enumerate(datasetloder):
        pass
        # Convert the PyTorch tensor to a PIL image
        # pil_image = Image.fromarray(i['uav_image'][0].permute(1, 2, 0).byte().numpy())

        # Display the image
        # Convert the PyTorch tensor to a NumPy array
        # numpy_array = i['uav_image'][0].numpy()
        #
        # # Display the image
        # plt.imshow(numpy_array.transpose(1, 2, 0))
        # plt.axis('off')
        # plt.show()
        #
        # map_viz, label = Colormap.apply(i['map'][0])
        # map_viz = map_viz * 255
        # map_viz = map_viz.astype(np.uint8)
        # plot_images([map_viz], titles=["OpenStreetMap raster"])
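As a quick sanity check (hypothetical, not in the commit), one can inspect the batch structure produced by UavMapPair via a loader like the one built in the __main__ block above; the field names follow the dict returned by __getitem__:

from torch.utils.data import DataLoader

batch = next(iter(DataLoader(dataset, batch_size=2)))
print(batch['image'].shape)           # e.g. torch.Size([2, 3, 256, 256]) after Resize(256)
print(batch['map'].shape)             # rasterized OSM tile, dtype torch.int64
print(batch['uv'])                    # ground-truth pixel location of the UAV in the tile
print(batch['roll_pitch_yaw'][:, 2])  # yaw in degrees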
dataset/UAV/prepara_dataset.py
ADDED
@@ -0,0 +1,270 @@
import torch
from torch.utils.data import Dataset
import os
import cv2
# @Time    : 2023-02-13 22:56
# @Author  : Wang Zhen
# @Email   : frozenzhencola@163.com
# @File    : SatelliteTool.py
# @Project : TGRS_seqmatch_2023_1
import numpy as np
import random
from utils.geo import BoundaryBox, Projection
from osm.tiling import TileManager, MapTileManager
from pathlib import Path
from torchvision import transforms
from tqdm import tqdm
import time
import math
from geopy import Point, distance
from osm.viz import Colormap, plot_nodes


def generate_random_coordinate(latitude, longitude, dis):
    # Draw a random bearing
    random_angle = random.uniform(0, 360)
    # print("random_angle", random_angle)
    # Compute the latitude/longitude of the destination point
    start_point = Point(latitude, longitude)
    destination = distance.distance(kilometers=dis / 1000).destination(start_point, random_angle)

    return destination.latitude, destination.longitude


def rotate_corp(src, angle):
    # Height, width and channel count of the source image
    rows, cols, channel = src.shape

    # Rotate around the image center
    # Arguments: rotation center, angle in degrees, scale
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    # rows, cols = 700, 700
    # Grow the output canvas to fit the rotated image
    cos = np.abs(M[0, 0])
    sin = np.abs(M[0, 1])
    new_w = rows * sin + cols * cos
    new_h = rows * cos + cols * sin
    M[0, 2] += (new_w - cols) * 0.5
    M[1, 2] += (new_h - rows) * 0.5
    w = int(np.round(new_w))
    h = int(np.round(new_h))
    rotated = cv2.warpAffine(src, M, (w, h))

    # rotated = cv2.warpAffine(src, M, (cols, rows))

    c = int(w / 2)
    w = int(rows * math.sqrt(2) / 4)
    rotated2 = rotated[c - w:c + w, c - w:c + w, :]
    return rotated2


class SatelliteGeoTools:
    """
    Reads a satellite-image .tfw file and converts between
    pixel, Web-Mercator, and GPS coordinates.
    """
    def __init__(self, tfw_path):
        self.SatelliteParameter = self.Parsetfw(tfw_path)

    def Parsetfw(self, tfw_path):
        info = []
        f = open(tfw_path)
        for _ in range(6):
            line = f.readline()
            line = line.strip('\n')
            info.append(float(line))
        f.close()
        return info

    def Pix2Geo(self, x, y):
        A, D, B, E, C, F = self.SatelliteParameter
        x1 = A * x + B * y + C
        y1 = D * x + E * y + F
        # print(x1, y1)
        s_long, s_lat = self.MercatorTolonlat(x1, y1)
        return s_long, s_lat

    def Geo2Pix(self, lon, lat):
        """
        https://baike.baidu.com/item/TFW%E6%A0%BC%E5%BC%8F/6273151?fr=aladdin
        x' = Ax + By + C
        y' = Dx + Ey + F
        :return:
        """
        x1, y1 = self.LonlatToMercator(lon, lat)
        A, D, B, E, C, F = self.SatelliteParameter
        M = np.array([[A, B, C],
                      [D, E, F],
                      [0, 0, 1]])
        M_INV = np.linalg.inv(M)
        XY = np.matmul(M_INV, np.array([x1, y1, 1]).T)
        return int(XY[0]), int(XY[1])

    def MercatorTolonlat(self, mx, my):
        x = mx / 20037508.3427892 * 180
        y = my / 20037508.3427892 * 180
        # y = 180/math.pi*(2*math.atan(math.exp(y*math.pi/180))-math.pi/2)
        y = 180.0 / np.pi * (2.0 * np.arctan(np.exp(y * np.pi / 180.0)) - np.pi / 2.0)
        return x, y

    def LonlatToMercator(self, lon, lat):
        x = lon * 20037508.342789 / 180
        y = np.log(np.tan((90 + lat) * np.pi / 360)) / (np.pi / 180)
        y = y * 20037508.34789 / 180
        return x, y


def geodistance(lng1, lat1, lng2, lat2):
    lng1, lat1, lng2, lat2 = map(np.radians, [lng1, lat1, lng2, lat2])
    dlon = lng2 - lng1
    dlat = lat2 - lat1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    distance = 2 * np.arcsin(np.sqrt(a)) * 6371 * 1000  # mean Earth radius: 6371 km
    return distance


class PreparaDataset:
    def __init__(
        self,
        root: Path,
        city: str,
        patch_size: int,
        tile_size_meters: float
    ):
        super().__init__()

        # self.root = root

        # city = 'Manhattan'
        # root = '/root/DATASET/CrossModel/'
        imagepath = root / city / '{}.tif'.format(city)
        tfwpath = root / city / '{}.tfw'.format(city)

        self.osmpath = root / city / '{}.osm'.format(city)

        self.TileManager = MapTileManager(self.osmpath)
        image = cv2.imread(str(imagepath))
        self.image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        self.ST = SatelliteGeoTools(str(tfwpath))

        self.patch_size = patch_size
        self.tile_size_meters = tile_size_meters

    def get_osm(self, prior_latlon, uav_latlon):
        latlon = np.array(prior_latlon)
        proj = Projection(*latlon)
        center = proj.project(latlon)

        uav_latlon = np.array(uav_latlon)

        XY = proj.project(uav_latlon)
        # tile_size_meters = 128
        bbox = BoundaryBox(center, center) + self.tile_size_meters
        # bbox = BoundaryBox(center, center)
        # Query OpenStreetMap for this area
        self.pixel_per_meter = 1
        start_time = time.time()
        canvas = self.TileManager.from_bbox(proj, bbox, self.pixel_per_meter)
        end_time = time.time()
        execution_time = end_time - start_time
        # print("execution time:", execution_time, "s")
        # canvas = tiler.query(bbox)
        XY = [XY[0] + self.tile_size_meters, -XY[1] + self.tile_size_meters]
        return canvas, XY

    def random_corp(self):
        # Pick the top-left corner of the crop region (with a 1000 px safety margin)
        x = random.randint(1000, self.image.shape[1] - self.patch_size - 1000)
        y = random.randint(1000, self.image.shape[0] - self.patch_size - 1000)
        x1 = x + self.patch_size
        y1 = y + self.patch_size
        return x, x1, y, y1

    def generate(self):
        x, x1, y, y1 = self.random_corp()
        uav_center_x, uav_center_y = int((x + x1) // 2), int((y + y1) // 2)
        uav_center_long, uav_center_lat = self.ST.Pix2Geo(uav_center_x, uav_center_y)
        # print(uav_center_long, uav_center_lat)
        self.image_patch = self.image[y:y1, x:x1]

        map_center_lat, map_center_long = generate_random_coordinate(uav_center_lat, uav_center_long, self.tile_size_meters)
        map, XY = self.get_osm([map_center_lat, map_center_long], [uav_center_lat, uav_center_long])

        yaw = np.random.random() * 360
        self.image_patch = rotate_corp(self.image_patch, yaw)
        # return self.image_patch, self.osm_patch
        # XY = [X + self.tile_size_meters
        return {
            'uav_image': self.image_patch,
            'uav_long_lat': [uav_center_long, uav_center_lat],
            'map_long_lat': [map_center_long, map_center_lat],
            'tile_size_meters': map.raster.shape[1],
            'pixel_per_meter': self.pixel_per_meter,
            'yaw': yaw,
            'map': map.raster,
            "uv": XY
        }


if __name__ == '__main__':

    import argparse

    parser = argparse.ArgumentParser(description='manual to this script')
    parser.add_argument('--city', type=str, default=None, required=True)
    parser.add_argument('--num', type=int, default=10000)
    args = parser.parse_args()

    root = Path('/root/DATASET/OrienterNet/UavMap/')
    city = args.city
    dataset = PreparaDataset(
        root=root,
        city=city,
        patch_size=512,
        tile_size_meters=128,
    )

    uav_path = root / city / 'uav'
    if not uav_path.exists():
        uav_path.mkdir(parents=True)

    map_path = root / city / 'map'
    if not map_path.exists():
        map_path.mkdir(parents=True)

    map_vis_path = root / city / 'map_vis'
    if not map_vis_path.exists():
        map_vis_path.mkdir(parents=True)

    info_path = root / city / 'info.csv'

    # num = 1000
    num = args.num
    info = [['id', 'uav_name', 'map_name', 'uav_long', 'uav_lat', 'map_long', 'map_lat', 'tile_size_meters', 'pixel_per_meter', 'u', 'v', 'yaw']]
    # info = []
    for i in tqdm(range(num)):
        data = dataset.generate()
        # print(str(uav_path / "{:05d}.jpg".format(i)))

        cv2.imwrite(str(uav_path / "{:05d}.jpg".format(i)), cv2.cvtColor(data['uav_image'], cv2.COLOR_RGB2BGR))

        np.save(str(map_path / "{:05d}.npy".format(i)), data['map'])

        map_viz, label = Colormap.apply(data['map'])
        map_viz = map_viz * 255
        map_viz = map_viz.astype(np.uint8)
        cv2.imwrite(str(map_vis_path / "{:05d}.jpg".format(i)), cv2.cvtColor(map_viz, cv2.COLOR_RGB2BGR))

        uav_center_long, uav_center_lat = data['uav_long_lat']
        map_center_long, map_center_lat = data['map_long_lat']
        info.append([
            i,
            "{:05d}.jpg".format(i),
            "{:05d}.npy".format(i),
            uav_center_long,
            uav_center_lat,
            map_center_long,
            map_center_lat,
            data["tile_size_meters"],
            data["pixel_per_meter"],
            data['uv'][0],
            data['uv'][1],
            data['yaw']
        ])
    # print(info)
    np.savetxt(info_path, info, delimiter=',', fmt="%s")
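A small numerical check (illustrative only, not in the commit) of the Web-Mercator helpers and the haversine geodistance above; SatelliteGeoTools.__new__ is used to skip the .tfw parsing that the constructor requires:

st = SatelliteGeoTools.__new__(SatelliteGeoTools)  # bypass Parsetfw for this check
x, y = st.LonlatToMercator(116.4, 39.9)
print(st.MercatorTolonlat(x, y))               # ≈ (116.4, 39.9): the round trip is the identity
print(geodistance(116.4, 39.9, 121.5, 31.2))   # ≈ 1.07e6 m, great-circle Beijing to Shanghai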
dataset/__init__.py
ADDED
@@ -0,0 +1,4 @@
# from .UAV.dataset import UavMapPair
from .dataset import UavMapDatasetModule

# modules = {"UAV": UavMapPair}
dataset/dataset.py
ADDED
@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List
# from logger import logger
import numpy as np
# import torch
# import torch.utils.data as torchdata
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
from dataset.UAV.dataset import UavMapPair
from torch.utils.data import Dataset, ConcatDataset, DataLoader, random_split
import torchvision.transforms as tvf


# Custom data module, subclassing pl.LightningDataModule
class UavMapDatasetModule(pl.LightningDataModule):

    def __init__(self, cfg: Dict[str, Any]):
        super().__init__()

        # default_cfg = OmegaConf.create(self.default_cfg)
        # OmegaConf.set_struct(default_cfg, True)  # cannot add new keys
        # self.cfg = OmegaConf.merge(default_cfg, cfg)
        self.cfg = cfg

        tfs = [tvf.ToTensor(), tvf.Resize(self.cfg.image_size)]
        # Compose over a copy so that the train-only ColorJitter appended
        # below does not leak into the validation transforms.
        self.val_tfs = tvf.Compose(list(tfs))

        if cfg.augmentation.image.apply:
            args = OmegaConf.masked_copy(
                cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
            )
            tfs.append(tvf.ColorJitter(**args))
        self.train_tfs = tvf.Compose(tfs)

        self.init()

    def init(self):
        self.train_dataset = ConcatDataset([
            UavMapPair(root=Path(self.cfg.root), city=city, training=True, transform=self.train_tfs)
            for city in self.cfg.train_citys
        ])

        self.val_dataset = ConcatDataset([
            UavMapPair(root=Path(self.cfg.root), city=city, training=False, transform=self.val_tfs)
            for city in self.cfg.val_citys
        ])

        # self.val_datasets = {
        #     city: UavMapPair(root=Path(self.cfg.root), city=city, transform=self.val_tfs)
        #     for city in self.cfg.val_citys
        # }
        # logger.info("train data len:{},val data len:{}".format(len(self.train_dataset), len(self.val_dataset)))
        # # Define the split ratio
        # train_ratio = 0.8  # training-set fraction
        # # Compute the number of samples in each split
        # train_size = int(len(self.dataset) * train_ratio)
        # val_size = len(self.dataset) - train_size
        # self.train_dataset, self.val_dataset = random_split(self.dataset, [train_size, val_size])

    def train_dataloader(self):
        train_loader = DataLoader(self.train_dataset,
                                  batch_size=self.cfg.train.batch_size,
                                  num_workers=self.cfg.train.num_workers,
                                  shuffle=True, pin_memory=True)
        return train_loader

    def val_dataloader(self):
        val_loader = DataLoader(self.val_dataset,
                                batch_size=self.cfg.val.batch_size,
                                num_workers=self.cfg.val.num_workers,
                                shuffle=True, pin_memory=True)
        # my_dict = {k: v for k, v in self.val_datasets}
        # val_loaders = {city: DataLoader(dataset,
        #                                 batch_size=self.cfg.val.batch_size,
        #                                 num_workers=self.cfg.val.num_workers,
        #                                 shuffle=False, pin_memory=True) for city, dataset in self.val_datasets.items()}
        return val_loader
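Hypothetical wiring of the data module with the config from this commit (only the `data` subtree of conf/maplocnet.yaml is consumed):

from omegaconf import OmegaConf
from dataset import UavMapDatasetModule

cfg = OmegaConf.load('conf/maplocnet.yaml')
dm = UavMapDatasetModule(cfg.data)
train_loader = dm.train_dataloader()  # ConcatDataset over cfg.data.train_citys
val_loader = dm.val_dataloader()      # ConcatDataset over cfg.data.val_citys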
dataset/image.py
ADDED
@@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Callable, Optional, Union, Sequence

import numpy as np
import torch
import torchvision.transforms.functional as tvf
import collections
from scipy.spatial.transform import Rotation

from utils.geometry import from_homogeneous, to_homogeneous
from utils.wrappers import Camera


def rectify_image(
    image: torch.Tensor,
    cam: Camera,
    roll: float,
    pitch: Optional[float] = None,
    valid: Optional[torch.Tensor] = None,
):
    *_, h, w = image.shape
    grid = torch.meshgrid(
        [torch.arange(w, device=image.device), torch.arange(h, device=image.device)],
        indexing="xy",
    )
    grid = torch.stack(grid, -1).to(image.dtype)

    if pitch is not None:
        args = ("ZX", (roll, pitch))
    else:
        args = ("Z", roll)
    R = Rotation.from_euler(*args, degrees=True).as_matrix()
    R = torch.from_numpy(R).to(image)

    grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T
    grid_rect = cam.denormalize(from_homogeneous(grid_rect))
    grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1
    rectified = torch.nn.functional.grid_sample(
        image[None],
        grid_norm[None],
        align_corners=False,
        mode="bilinear",
    ).squeeze(0)
    if valid is None:
        valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
    else:
        valid = (
            torch.nn.functional.grid_sample(
                valid[None, None].float(),
                grid_norm[None],
                align_corners=False,
                mode="nearest",
            )[0, 0]
            > 0
        )
    return rectified, valid


def resize_image(
    image: torch.Tensor,
    size: Union[int, Sequence, np.ndarray],
    fn: Optional[Callable] = None,
    camera: Optional[Camera] = None,
    valid: np.ndarray = None,
):
    """Resize an image to a fixed size, or according to max or min edge."""
    *_, h, w = image.shape
    if fn is not None:
        assert isinstance(size, int)
        scale = size / fn(h, w)
        h_new, w_new = int(round(h * scale)), int(round(w * scale))
        scale = (scale, scale)
    else:
        if isinstance(size, (collections.abc.Sequence, np.ndarray)):
            w_new, h_new = size
        elif isinstance(size, int):
            w_new = h_new = size
        else:
            raise ValueError(f"Incorrect new size: {size}")
        scale = (w_new / w, h_new / h)
    if (w, h) != (w_new, h_new):
        mode = tvf.InterpolationMode.BILINEAR
        image = tvf.resize(image, (h_new, w_new), interpolation=mode, antialias=True)
        image.clip_(0, 1)
        if camera is not None:
            camera = camera.scale(scale)
        if valid is not None:
            valid = tvf.resize(
                valid.unsqueeze(0),
                (h_new, w_new),
                interpolation=tvf.InterpolationMode.NEAREST,
            ).squeeze(0)
    ret = [image, scale]
    if camera is not None:
        ret.append(camera)
    if valid is not None:
        ret.append(valid)
    return ret


def pad_image(
    image: torch.Tensor,
    size: Union[int, Sequence, np.ndarray],
    camera: Optional[Camera] = None,
    valid: torch.Tensor = None,
    crop_and_center: bool = False,
):
    if isinstance(size, int):
        w_new = h_new = size
    elif isinstance(size, (collections.abc.Sequence, np.ndarray)):
        w_new, h_new = size
    else:
        raise ValueError(f"Incorrect new size: {size}")
    *c, h, w = image.shape
    if crop_and_center:
        diff = np.array([w - w_new, h - h_new])
        left, top = left_top = np.round(diff / 2).astype(int)
        right, bottom = diff - left_top
    else:
        assert h <= h_new
        assert w <= w_new
        top = bottom = left = right = 0
    slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)]
    slice_in = np.s_[
        ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0)
    ]
    if (w, h) == (w_new, h_new):
        out = image
    else:
        out = torch.zeros((*c, h_new, w_new), dtype=image.dtype)
    out[slice_out] = image[slice_in]
    if camera is not None:
        camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new))
    out_valid = torch.zeros((h_new, w_new), dtype=torch.bool)
    out_valid[slice_out] = True if valid is None else valid[slice_in]
    if camera is not None:
        return out, out_valid, camera
    else:
        return out, out_valid
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
import collections
|
4 |
+
import os
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch.utils.data import get_worker_info
|
8 |
+
from torch.utils.data._utils.collate import (
|
9 |
+
default_collate_err_msg_format,
|
10 |
+
np_str_obj_array_pattern,
|
11 |
+
)
|
12 |
+
from lightning_fabric.utilities.seed import pl_worker_init_function
|
13 |
+
from lightning_utilities.core.apply_func import apply_to_collection
|
14 |
+
from lightning_fabric.utilities.apply_func import move_data_to_device
|
15 |
+
|
16 |
+
|
17 |
+
def collate(batch):
|
18 |
+
"""Difference with PyTorch default_collate: it can stack other tensor-like objects.
|
19 |
+
Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
|
20 |
+
https://github.com/cvg/pixloc
|
21 |
+
Released under the Apache License 2.0
|
22 |
+
"""
|
23 |
+
if not isinstance(batch, list): # no batching
|
24 |
+
return batch
|
25 |
+
elem = batch[0]
|
26 |
+
elem_type = type(elem)
|
27 |
+
if isinstance(elem, torch.Tensor):
|
28 |
+
out = None
|
29 |
+
if torch.utils.data.get_worker_info() is not None:
|
30 |
+
# If we're in a background process, concatenate directly into a
|
31 |
+
# shared memory tensor to avoid an extra copy
|
32 |
+
numel = sum(x.numel() for x in batch)
|
33 |
+
storage = elem.storage()._new_shared(numel, device=elem.device)
|
34 |
+
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
|
35 |
+
return torch.stack(batch, 0, out=out)
|
36 |
+
elif (
|
37 |
+
elem_type.__module__ == "numpy"
|
38 |
+
and elem_type.__name__ != "str_"
|
39 |
+
and elem_type.__name__ != "string_"
|
40 |
+
):
|
41 |
+
if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
|
42 |
+
# array of string classes and object
|
43 |
+
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
44 |
+
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
|
45 |
+
|
46 |
+
return collate([torch.as_tensor(b) for b in batch])
|
47 |
+
elif elem.shape == (): # scalars
|
48 |
+
return torch.as_tensor(batch)
|
49 |
+
elif isinstance(elem, float):
|
50 |
+
return torch.tensor(batch, dtype=torch.float64)
|
51 |
+
elif isinstance(elem, int):
|
52 |
+
return torch.tensor(batch)
|
53 |
+
elif isinstance(elem, (str, bytes)):
|
54 |
+
return batch
|
55 |
+
elif isinstance(elem, collections.abc.Mapping):
|
56 |
+
return {key: collate([d[key] for d in batch]) for key in elem}
|
57 |
+
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
|
58 |
+
return elem_type(*(collate(samples) for samples in zip(*batch)))
|
59 |
+
elif isinstance(elem, collections.abc.Sequence):
|
60 |
+
# check to make sure that the elements in batch have consistent size
|
61 |
+
it = iter(batch)
|
62 |
+
elem_size = len(next(it))
|
63 |
+
if not all(len(elem) == elem_size for elem in it):
|
64 |
+
raise RuntimeError("each element in list of batch should be of equal size")
|
65 |
+
transposed = zip(*batch)
|
66 |
+
return [collate(samples) for samples in transposed]
|
67 |
+
else:
|
68 |
+
# try to stack anyway in case the object implements stacking.
|
69 |
+
try:
|
70 |
+
return torch.stack(batch, 0)
|
71 |
+
except TypeError as e:
|
72 |
+
if "expected Tensor as element" in str(e):
|
73 |
+
return batch
|
74 |
+
else:
|
75 |
+
raise e
|
76 |
+
|
77 |
+
|
78 |
+
def set_num_threads(nt):
|
79 |
+
"""Force numpy and other libraries to use a limited number of threads."""
|
80 |
+
try:
|
81 |
+
import mkl
|
82 |
+
except ImportError:
|
83 |
+
pass
|
84 |
+
else:
|
85 |
+
mkl.set_num_threads(nt)
|
86 |
+
torch.set_num_threads(1)
|
87 |
+
os.environ["IPC_ENABLE"] = "1"
|
88 |
+
for o in [
|
89 |
+
"OPENBLAS_NUM_THREADS",
|
90 |
+
"NUMEXPR_NUM_THREADS",
|
91 |
+
"OMP_NUM_THREADS",
|
92 |
+
"MKL_NUM_THREADS",
|
93 |
+
]:
|
94 |
+
os.environ[o] = str(nt)
|
95 |
+
|
96 |
+
|
97 |
+
def worker_init_fn(i):
|
98 |
+
info = get_worker_info()
|
99 |
+
pl_worker_init_function(info.id)
|
100 |
+
num_threads = info.dataset.cfg.get("num_threads")
|
101 |
+
if num_threads is not None:
|
102 |
+
set_num_threads(num_threads)
|
103 |
+
|
104 |
+
|
105 |
+
def unbatch_to_device(data, device="cpu"):
|
106 |
+
data = move_data_to_device(data, device)
|
107 |
+
data = apply_to_collection(data, torch.Tensor, lambda x: x.squeeze(0))
|
108 |
+
data = apply_to_collection(
|
109 |
+
data, list, lambda x: x[0] if len(x) == 1 and isinstance(x[0], str) else x
|
110 |
+
)
|
111 |
+
return data
|
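The custom collate behaves like PyTorch's default_collate for tensors and mappings but keeps non-stackable elements such as strings as lists; a minimal illustration (not part of the commit):

import torch

samples = [
    {"uv": torch.tensor([1.0, 2.0]), "name": "a"},
    {"uv": torch.tensor([3.0, 4.0]), "name": "b"},
]
batch = collate(samples)
print(batch["uv"].shape)  # torch.Size([2, 2]): tensors are stacked
print(batch["name"])      # ['a', 'b']: strings stay a list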
demo.py
ADDED
@@ -0,0 +1,354 @@
import matplotlib.pyplot as plt
# from demo import Demo, read_input_image, read_input_image_test
from evaluation.viz import plot_example_single
from dataset.torch import unbatch_to_device
from typing import Optional, Tuple
import cv2
import torch
import numpy as np
import time
from logger import logger
from evaluation.run import resolve_checkpoint_path, pretrained_models
from models.maplocnet import MapLocNet
from models.voting import fuse_gps, argmax_xyr
# from data.image import resize_image, pad_image, rectify_image
from osm.raster import Canvas
from utils.wrappers import Camera
from utils.io import read_image
from utils.geo import BoundaryBox, Projection
from utils.exif import EXIF
import requests
from pathlib import Path
from dataset.image import resize_image, pad_image, rectify_image
# from maploc.demo import Demo, read_input_image
from dataset import UavMapDatasetModule
import torchvision.transforms as tvf
from sklearn.decomposition import PCA
from PIL import Image
# import pyproj
# Query OpenStreetMap for this area
from osm.tiling import TileManager
from utils.viz_localization import (
    likelihood_overlay,
    plot_dense_rotations,
    add_circle_inset,
)
# Show the inputs to the model: image and raster map
from osm.viz import Colormap, plot_nodes
from utils.viz_2d import plot_images, features_to_RGB
import random
from geopy.distance import geodesic


def vis_image_feature(F):
    def normalize(x):
        return x / np.linalg.norm(x, axis=-1, keepdims=True)

    # F = neural_map.numpy()
    F = F[:, 0:180, 0:180]
    flatten = []
    c, h, w = F.shape
    print(F.shape)
    F = np.rollaxis(F, 0, 3)
    F_flat = F.reshape(-1, c)
    flatten.append(F_flat)
    flatten = normalize(flatten)[0]

    flatten = np.nan_to_num(flatten, nan=0)
    pca = PCA(n_components=3)

    print(flatten.shape)
    flatten = pca.fit_transform(flatten)
    flatten = (normalize(flatten) + 1) / 2

    # h, w = F.shape[-2:]
    F_rgb, flatten = np.split(flatten, [h * w], axis=0)
    F_rgb = F_rgb.reshape((h, w, 3))
    return F_rgb


def distance(lat1, lon1, lat2, lon2):
    point1 = (lat1, lon1)
    point2 = (lat2, lon2)
    distance_m = geodesic(point1, point2).meters
    return distance_m

# # Example:
# lat1, lon1 = 39.9, 116.4  # Beijing
# lat2, lon2 = 31.2, 121.5  # Shanghai

# print(distance(lat1, lon1, lat2, lon2))


def show_result(map_vis_image, pre_uv, pre_yaw):
    # Create a gray mask with the same size as the original image
    gray_mask = np.zeros_like(map_vis_image)
    gray_mask.fill(128)  # fill with gray

    # Blend the gray mask with the original image
    image = cv2.addWeighted(map_vis_image, 1, gray_mask, 0, 0)
    # Draw the ground truth

    # Draw the prediction
    u, v = pre_uv
    x1, y1 = int(u), int(v)  # replace with the actual start point
    angle = pre_yaw - 90  # replace with the actual arrow angle
    # Compute the end point of the arrow
    length = 20
    x2 = int(x1 + length * np.cos(np.radians(angle)))
    y2 = int(y1 + length * np.sin(np.radians(angle)))
    # Draw the arrow on the image
    cv2.arrowedLine(image, (x1, y1), (x2, y2), (0, 0, 0), 2, 5, 0, 0.3)
    # cv2.circle(image, (x1, y1), radius=2, color=(255, 0, 255), thickness=-1)
    return image


def xyz_to_latlon(x, y, z):
    # Define the WGS84 CRS (requires the pyproj import commented out above)
    wgs84 = pyproj.CRS('EPSG:4326')

    # Define the geocentric XYZ CRS
    xyz = pyproj.CRS('+proj=geocent +datum=WGS84 +units=m +no_defs')

    # Build the coordinate transformer
    transformer = pyproj.Transformer.from_crs(xyz, wgs84)

    # Transform the coordinates
    lon, lat, _ = transformer.transform(x, y, z)

    return lat, lon


class Demo:
    def __init__(
        self,
        experiment_or_path: Optional[str] = "OrienterNet_MGL",
        device=None,
        **kwargs
    ):
        if experiment_or_path in pretrained_models:
            experiment_or_path, _ = pretrained_models[experiment_or_path]
        path = resolve_checkpoint_path(experiment_or_path)
        ckpt = torch.load(path, map_location=(lambda storage, loc: storage))
        config = ckpt["hyper_parameters"]
        config.model.update(kwargs)
        config.model.image_encoder.backbone.pretrained = False

        model = MapLocNet(config.model).eval()
        state = {k[len("model."):]: v for k, v in ckpt["state_dict"].items()}
        model.load_state_dict(state, strict=True)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)

        self.model = model
        self.config = config
        self.device = device

    def prepare_data(
        self,
        image: np.ndarray,
        camera: Camera,
        canvas: Canvas,
        roll_pitch: Optional[Tuple[float]] = None,
    ):
        image = torch.from_numpy(image).permute(2, 0, 1).float().div_(255)

        return {
            'map': torch.from_numpy(canvas.raster).long(),
            'image': image,
            # 'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float().unsqueeze(0),
            # 'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float().unsqueeze(0),
            # "uv": torch.tensor([float(u), float(v)]).float().unsqueeze(0),
        }
        # return dict(
        #     image=image,
        #     map=torch.from_numpy(canvas.raster).long(),
        #     camera=camera.float(),
        #     valid=valid,
        # )

    def localize(self, image: np.ndarray, camera: Camera, canvas: Canvas, **kwargs):
        data = self.prepare_data(image, camera, canvas, **kwargs)
        data_ = {k: v.to(self.device)[None] for k, v in data.items()}
        # data_np = {k: v.cpu().numpy()[None] for k, v in data.items()}
        # logger.info(data_)
        # np.save(data_np, 'data_.npy')
        start = time.time()
        with torch.no_grad():
            pred = self.model(data_)

        end = time.time()
        xy_gps = canvas.bbox.center
        uv_gps = torch.from_numpy(canvas.to_uv(xy_gps))

        lp_xyr = pred["log_probs"].squeeze(0)
        # tile_size = canvas.bbox.size.min() / 2
        # sigma = tile_size - 20  # 20 meters margin
        # lp_xyr = fuse_gps(
        #     lp_xyr,
        #     uv_gps.to(lp_xyr),
        #     self.config.model.pixel_per_meter,
        #     sigma=sigma,
        # )
        xyr = argmax_xyr(lp_xyr).cpu()

        prob = lp_xyr.exp().cpu()
        neural_map = pred["map"]["map_features"][0].squeeze(0).cpu()
        print('total time:', end - start)
        return xyr[:2], xyr[2], prob, neural_map, data["image"], data_, pred


def load_test_data(
    root: Path,
    city: str,
    index: int,
):
    uav_image_path = root / city / 'uav'
    map_path = root / city / 'map'
    map_vis = root / city / 'map_vis'
    info_path = root / city / 'info.csv'
    osm_path = root / city / '{}.osm'.format(city)

    info = np.loadtxt(str(info_path), dtype=str, delimiter=",", skiprows=1)

    (id, uav_name, map_name,
     uav_long, uav_lat,
     map_long, map_lat,
     tile_size_meters, pixel_per_meter,
     u, v, yaw, dis) = info[index]
    print(info[index])
    uav_image_rgb = cv2.imread(str(uav_image_path / uav_name))
    uav_image_rgb = cv2.cvtColor(uav_image_rgb, cv2.COLOR_BGR2RGB)

    # w, h, c = uav_image_rgb.shape
    # # Coordinates of the crop region
    # x = w // 2  # start x
    # y = h // 2  # start y
    # w = 150  # width
    # h = 150  # height

    # # Crop the image
    # uav_image_rgb = uav_image_rgb[y-h:y+h, x-w:x+w]

    map_vis_image = cv2.imread(str(map_vis / uav_name))
    map_vis_image = cv2.cvtColor(map_vis_image, cv2.COLOR_BGR2RGB)

    map = np.load(str(map_path / map_name))

    tfs = []
    tfs.append(tvf.ToTensor())
    tfs.append(tvf.Resize(256))
    val_tfs = tvf.Compose(tfs)

    uav_image = val_tfs(uav_image_rgb)
    # print(id, uav_name, map_name,
    #       uav_long, uav_lat,
    #       map_long, map_lat,
    #       tile_size_meters, pixel_per_meter,
    #       u, v, yaw, dis)
    uav_path = str(uav_image_path / uav_name)
    return {
        'map': torch.from_numpy(np.ascontiguousarray(map)).long().unsqueeze(0),
        'image': torch.tensor(uav_image).unsqueeze(0),
        'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float().unsqueeze(0),
        'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float().unsqueeze(0),
        "uv": torch.tensor([float(u), float(v)]).float().unsqueeze(0),
    }, uav_image_rgb, map_vis_image, uav_path, [float(map_lat), float(map_long)]


def crop_image(image, width, height):
    # Compute the top-left corner of the crop region
    x = int((image.shape[1] - width) / 2)
    y = int((image.shape[0] - height) / 2)

    # Crop the image
    cropped_image = image[y:y + height, x:x + width]
    return cropped_image


def crop_square(image):
    # Get the width and height of the image
    height, width = image.shape[:2]

    # Length of the shorter side
    min_length = min(height, width)

    # Compute the crop coordinates
    top = (height - min_length) // 2
    bottom = top + min_length
    left = (width - min_length) // 2
    right = left + min_length

    # Crop the image to a square
    cropped_image = image[top:bottom, left:right]

    return cropped_image


def read_input_image_test(
    image,
    prior_latlon,
    tile_size_meters,
):
    # image = read_image(image_path)
    # # Crop the image
    # # Width and height of the crop
    # width = 1080 * 2
    # height = 1080 * 2
    # image = crop_square(image)
    # # print("input image:", image.shape)
    # image = crop_image(image, width, height)
    # # print("crop_image:", image.shape)
    image = cv2.resize(image, (256, 256))
    roll_pitch = None

    latlon = None
    if prior_latlon is not None:
        latlon = prior_latlon
        logger.info("Using prior latlon %s.", prior_latlon)

    if latlon is None:
        with open(image_path, "rb") as fid:
            exif = EXIF(fid, lambda: image.shape[:2])
            geo = exif.extract_geo()
        if geo:
            alt = geo.get("altitude", 0)  # read if available
            latlon = (geo["latitude"], geo["longitude"], alt)
            logger.info("Using prior location from EXIF.")
            # print(latlon)
        else:
            logger.info("Could not find any prior location in the image EXIF metadata.")

    latlon = np.array(latlon)

    proj = Projection(*latlon)
    center = proj.project(latlon)
    bbox = BoundaryBox(center, center) + float(tile_size_meters)
    camera = None
    image = cv2.resize(image, (256, 256))
    return image, camera, roll_pitch, proj, bbox, latlon


if __name__ == '__main__':
    experiment_or_path = "weight/last-step-checkpointing.ckpt"
    # experiment_or_path = "experiments/maplocanet_0906_diffhight/last-step-checkpointing.ckpt"
    image_path = 'images/00000.jpg'
    prior_latlon = (37.75704325989902, -122.435941445631)
    tile_size_meters = 128
    demo = Demo(experiment_or_path=experiment_or_path, num_rotations=128, device='cpu')
    image = read_image(image_path)  # read_input_image_test expects an image array, not a path
    image, camera, gravity, proj, bbox, true_prior_latlon = read_input_image_test(
        image,
        prior_latlon=prior_latlon,
        tile_size_meters=tile_size_meters,  # try 64, 256, etc.
    )
    tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10, ppm=1, tile_size=tile_size_meters)
    # tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10, ppm=1, path=root/city/'{}.osm'.format(city), tile_size=1)
    canvas = tiler.query(bbox)
    uv, yaw, prob, neural_map, image_rectified, data_, pred = demo.localize(
        image, camera, canvas)
    prior_latlon_pred = proj.unproject(canvas.to_xy(uv))
    pass
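A hypothetical follow-up to the __main__ block above, assuming proj.unproject returns a (lat, lon) pair: the geodesic distance helper defined in this file can report how far the prediction landed from the prior.

lat_pred, lon_pred = prior_latlon_pred
print("error: %.1f m" % distance(lat_pred, lon_pred, *prior_latlon))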
evaluation/kitti.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import OmegaConf, DictConfig

from .. import logger
from ..data import KittiDataModule
from .run import evaluate


default_cfg_single = OmegaConf.create({})
# For the sequential evaluation, we need to center the map around the GT location,
# since random offsets would accumulate and leave only the GT location with a valid mask.
# This should not have much impact on the results.
default_cfg_sequential = OmegaConf.create(
    {
        "data": {
            "mask_radius": KittiDataModule.default_cfg["max_init_error"],
            "prior_range_rotation": KittiDataModule.default_cfg[
                "max_init_error_rotation"
            ]
            + 1,
            "max_init_error": 0,
            "max_init_error_rotation": 0,
        },
        "chunking": {
            "max_length": 100,  # about 10s?
        },
    }
)


def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int] = (1, 3, 5),
    **kwargs,
):
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    cfg = OmegaConf.merge(default, cfg)
    dataset = KittiDataModule(cfg.get("data", {}))

    metrics = evaluate(
        experiment,
        cfg,
        dataset,
        split=split,
        sequential=sequential,
        viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
        **kwargs,
    )

    keys = ["directional_error", "yaw_max_error"]
    if sequential:
        keys += ["directional_seq_error", "yaw_seq_error"]
    for k in keys:
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment", type=str, required=True)
    parser.add_argument(
        "--split", type=str, default="test", choices=["test", "val", "train"]
    )
    parser.add_argument("--sequential", action="store_true")
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--num", type=int)
    parser.add_argument("dotlist", nargs="*")
    args = parser.parse_args()
    cfg = OmegaConf.from_cli(args.dotlist)
    run(
        args.split,
        args.experiment,
        cfg,
        args.sequential,
        output_dir=args.output_dir,
        num=args.num,
    )
evaluation/mapillary.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import OmegaConf, DictConfig

from .. import logger
from ..conf import data as conf_data_dir
from ..data import MapillaryDataModule
from .run import evaluate


split_overrides = {
    "val": {
        "scenes": [
            "sanfrancisco_soma",
            "sanfrancisco_hayes",
            "amsterdam",
            "berlin",
            "lemans",
            "montrouge",
            "toulouse",
            "nantes",
            "vilnius",
            "avignon",
            "helsinki",
            "milan",
            "paris",
        ],
    },
}
data_cfg_train = OmegaConf.load(Path(conf_data_dir.__file__).parent / "mapillary.yaml")
data_cfg = OmegaConf.merge(
    data_cfg_train,
    {
        "return_gps": True,
        "add_map_mask": True,
        "max_init_error": 32,
        "loading": {"val": {"batch_size": 1, "num_workers": 0}},
    },
)
default_cfg_single = OmegaConf.create({"data": data_cfg})
default_cfg_sequential = OmegaConf.create(
    {
        **default_cfg_single,
        "chunking": {
            "max_length": 10,
        },
    }
)


def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int] = (1, 3, 5),
    **kwargs,
):
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    default = OmegaConf.merge(default, split_overrides[split])
    cfg = OmegaConf.merge(default, cfg)
    dataset = MapillaryDataModule(cfg.get("data", {}))

    metrics = evaluate(experiment, cfg, dataset, split, sequential=sequential, **kwargs)

    keys = [
        "xy_max_error",
        "xy_gps_error",
        "yaw_max_error",
    ]
    if sequential:
        keys += [
            "xy_seq_error",
            "xy_gps_seq_error",
            "yaw_seq_error",
            "yaw_gps_seq_error",
        ]
    for k in keys:
        if k not in metrics:
            logger.warning("Key %s not in metrics.", k)
            continue
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--experiment", type=str, required=True)
    parser.add_argument("--split", type=str, default="val", choices=["val"])
    parser.add_argument("--sequential", action="store_true")
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--num", type=int)
    parser.add_argument("dotlist", nargs="*")
    args = parser.parse_args()
    cfg = OmegaConf.from_cli(args.dotlist)
    run(
        args.split,
        args.experiment,
        cfg,
        args.sequential,
        output_dir=args.output_dir,
        num=args.num,
    )
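
The script above is the Mapillary evaluation entry point. For reference, a minimal launch could look like the sketch below; the `maploc` package name is inferred from the relative imports and the logger name, and `outputs/mgl_val` is a hypothetical directory:

    python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL \
        --output_dir outputs/mgl_val model.num_rotations=256

Trailing dotlist arguments are parsed by OmegaConf.from_cli and merged on top of the defaults defined above.
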
evaluation/run.py
ADDED
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import functools
from itertools import islice
from typing import Callable, Dict, Optional, Tuple
from pathlib import Path

import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from torchmetrics import MetricCollection
from pytorch_lightning import seed_everything
from tqdm import tqdm

from logger import logger, EXPERIMENTS_PATH
from dataset.torch import collate, unbatch_to_device
from models.voting import argmax_xyr, fuse_gps
from models.metrics import AngleError, LateralLongitudinalError, Location2DError
# Note: evaluate_sequential below references GPSAligner and RigidAligner;
# this import must be re-enabled before running sequential evaluation.
# from models.sequential import GPSAligner, RigidAligner
from module import GenericModule
from utils.io import download_file, DATA_URL
from evaluation.viz import plot_example_single, plot_example_sequential
from evaluation.utils import write_dump


pretrained_models = dict(
    OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)),
)


def resolve_checkpoint_path(experiment_or_path: str) -> Path:
    path = Path(experiment_or_path)
    if not path.exists():
        # provided the name of an experiment
        path = Path(EXPERIMENTS_PATH, *experiment_or_path.split("/"))
        if not path.exists():
            if experiment_or_path in set(p for p, _ in pretrained_models.values()):
                download_file(f"{DATA_URL}/{experiment_or_path}", path)
            else:
                raise FileNotFoundError(path)
    if path.is_file():
        return path
    # provided only the experiment name
    maybe_path = path / "last-step-v1.ckpt"
    if not maybe_path.exists():
        maybe_path = path / "last.ckpt"
    if not maybe_path.exists():
        raise FileNotFoundError(f"Could not find any checkpoint in {path}.")
    return maybe_path


@torch.no_grad()
def evaluate_single_image(
    dataloader: torch.utils.data.DataLoader,
    model: GenericModule,
    num: Optional[int] = None,
    callback: Optional[Callable] = None,
    progress: bool = True,
    mask_index: Optional[Tuple[int]] = None,
    has_gps: bool = False,
):
    ppm = model.model.conf.pixel_per_meter
    metrics = MetricCollection(model.model.metrics())
    metrics["directional_error"] = LateralLongitudinalError(ppm)
    if has_gps:
        metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
        metrics["xy_fused_error"] = Location2DError("uv_fused", ppm)
        metrics["yaw_fused_error"] = AngleError("yaw_fused")
    metrics = metrics.to(model.device)

    for i, batch_ in enumerate(
        islice(tqdm(dataloader, total=num, disable=not progress), num)
    ):
        batch = model.transfer_batch_to_device(batch_, model.device, i)
        # Ablation: mask semantic classes
        if mask_index is not None:
            mask = batch["map"][0, mask_index[0]] == (mask_index[1] + 1)
            batch["map"][0, mask_index[0]][mask] = 0
        pred = model(batch)

        if has_gps:
            (uv_gps,) = pred["uv_gps"] = batch["uv_gps"]
            pred["log_probs_fused"] = fuse_gps(
                pred["log_probs"], uv_gps, ppm, sigma=batch["accuracy_gps"]
            )
            uvt_fused = argmax_xyr(pred["log_probs_fused"])
            pred["uv_fused"] = uvt_fused[..., :2]
            pred["yaw_fused"] = uvt_fused[..., -1]
            del uv_gps, uvt_fused

        results = metrics(pred, batch)
        if callback is not None:
            callback(
                i, model, unbatch_to_device(pred), unbatch_to_device(batch_), results
            )
        del batch_, batch, pred, results

    return metrics.cpu()


@torch.no_grad()
def evaluate_sequential(
    dataset: torch.utils.data.Dataset,
    chunk2idx: Dict,
    model: GenericModule,
    num: Optional[int] = None,
    shuffle: bool = False,
    callback: Optional[Callable] = None,
    progress: bool = True,
    num_rotations: int = 512,
    mask_index: Optional[Tuple[int]] = None,
    has_gps: bool = True,
):
    chunk_keys = list(chunk2idx)
    if shuffle:
        chunk_keys = [chunk_keys[i] for i in torch.randperm(len(chunk_keys))]
    if num is not None:
        chunk_keys = chunk_keys[:num]
    lengths = [len(chunk2idx[k]) for k in chunk_keys]
    logger.info(
        "Min/med/max lengths: %d/%d/%d, total number of images: %d",
        min(lengths),
        np.median(lengths),
        max(lengths),
        sum(lengths),
    )
    viz = callback is not None

    metrics = MetricCollection(model.model.metrics())
    ppm = model.model.conf.pixel_per_meter
    metrics["directional_error"] = LateralLongitudinalError(ppm)
    metrics["xy_seq_error"] = Location2DError("uv_seq", ppm)
    metrics["yaw_seq_error"] = AngleError("yaw_seq")
    metrics["directional_seq_error"] = LateralLongitudinalError(ppm, key="uv_seq")
    if has_gps:
        metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
        metrics["xy_gps_seq_error"] = Location2DError("uv_gps_seq", ppm)
        metrics["yaw_gps_seq_error"] = AngleError("yaw_gps_seq")
    metrics = metrics.to(model.device)

    keys_save = ["uvr_max", "uv_max", "yaw_max", "uv_expectation"]
    if has_gps:
        keys_save.append("uv_gps")
    if viz:
        keys_save.append("log_probs")

    for chunk_index, key in enumerate(tqdm(chunk_keys, disable=not progress)):
        indices = chunk2idx[key]
        aligner = RigidAligner(track_priors=viz, num_rotations=num_rotations)
        if has_gps:
            aligner_gps = GPSAligner(track_priors=viz, num_rotations=num_rotations)
        batches = []
        preds = []
        for i in indices:
            data = dataset[i]
            data = model.transfer_batch_to_device(data, model.device, 0)
            pred = model(collate([data]))

            canvas = data["canvas"]
            data["xy_geo"] = xy = canvas.to_xy(data["uv"].double())
            data["yaw"] = yaw = data["roll_pitch_yaw"][-1].double()
            aligner.update(pred["log_probs"][0], canvas, xy, yaw)

            if has_gps:
                (uv_gps) = pred["uv_gps"] = data["uv_gps"][None]
                xy_gps = canvas.to_xy(uv_gps.double())
                aligner_gps.update(xy_gps, data["accuracy_gps"], canvas, xy, yaw)

            if not viz:
                data.pop("image")
                data.pop("map")
            batches.append(data)
            preds.append({k: pred[k][0] for k in keys_save})
            del pred

        xy_gt = torch.stack([b["xy_geo"] for b in batches])
        yaw_gt = torch.stack([b["yaw"] for b in batches])
        aligner.compute()
        xy_seq, yaw_seq = aligner.transform(xy_gt, yaw_gt)
        if has_gps:
            aligner_gps.compute()
            xy_gps_seq, yaw_gps_seq = aligner_gps.transform(xy_gt, yaw_gt)
        results = []
        for i in range(len(indices)):
            preds[i]["uv_seq"] = batches[i]["canvas"].to_uv(xy_seq[i]).float()
            preds[i]["yaw_seq"] = yaw_seq[i].float()
            if has_gps:
                preds[i]["uv_gps_seq"] = (
                    batches[i]["canvas"].to_uv(xy_gps_seq[i]).float()
                )
                preds[i]["yaw_gps_seq"] = yaw_gps_seq[i].float()
            results.append(metrics(preds[i], batches[i]))
        if viz:
            callback(chunk_index, model, batches, preds, results, aligner)
        del aligner, preds, batches, results
    return metrics.cpu()


def evaluate(
    experiment: str,
    cfg: DictConfig,
    dataset,
    split: str,
    sequential: bool = False,
    output_dir: Optional[Path] = None,
    callback: Optional[Callable] = None,
    num_workers: int = 1,
    viz_kwargs=None,
    **kwargs,
):
    if experiment in pretrained_models:
        experiment, cfg_override = pretrained_models[experiment]
        cfg = OmegaConf.merge(OmegaConf.create(dict(model=cfg_override)), cfg)

    logger.info("Evaluating model %s with config %s", experiment, cfg)
    checkpoint_path = resolve_checkpoint_path(experiment)
    model = GenericModule.load_from_checkpoint(
        checkpoint_path, cfg=cfg, find_best=not experiment.endswith(".ckpt")
    )
    model = model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    dataset.prepare_data()
    dataset.setup()

    if output_dir is not None:
        output_dir.mkdir(exist_ok=True, parents=True)
        if callback is None:
            if sequential:
                callback = plot_example_sequential
            else:
                callback = plot_example_single
            callback = functools.partial(
                callback, out_dir=output_dir, **(viz_kwargs or {})
            )
    kwargs = {**kwargs, "callback": callback}

    seed_everything(dataset.cfg.seed)
    if sequential:
        dset, chunk2idx = dataset.sequence_dataset(split, **cfg.chunking)
        metrics = evaluate_sequential(dset, chunk2idx, model, **kwargs)
    else:
        loader = dataset.dataloader(split, shuffle=True, num_workers=num_workers)
        metrics = evaluate_single_image(loader, model, **kwargs)

    results = metrics.compute()
    logger.info("All results: %s", results)
    if output_dir is not None:
        write_dump(output_dir, experiment, cfg, results, metrics)
        logger.info("Outputs have been written to %s.", output_dir)
    return metrics
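
resolve_checkpoint_path above tries, in order: a literal filesystem path, a directory under EXPERIMENTS_PATH, and finally a pretrained-model download. A small sketch of programmatic use (the experiment name is hypothetical):

    from evaluation.run import resolve_checkpoint_path

    # Resolves to e.g. experiments/my_experiment/last-step-v1.ckpt or .../last.ckpt,
    # raising FileNotFoundError if neither checkpoint exists.
    ckpt = resolve_checkpoint_path("my_experiment")
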
evaluation/utils.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
from omegaconf import OmegaConf

from utils.io import write_json


def compute_recall(errors):
    num_elements = len(errors)
    sort_idx = np.argsort(errors)
    errors = np.array(errors.copy())[sort_idx]
    recall = (np.arange(num_elements) + 1) / num_elements
    recall = np.r_[0, recall]
    errors = np.r_[0, errors]
    return errors, recall


def compute_auc(errors, recall, thresholds):
    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t, side="right")
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        auc = np.trapz(r, x=e) / t
        aucs.append(auc * 100)
    return aucs


def write_dump(output_dir, experiment, cfg, results, metrics):
    dump = {
        "experiment": experiment,
        "cfg": OmegaConf.to_container(cfg),
        "results": results,
        "errors": {},
    }
    for k, m in metrics.items():
        if hasattr(m, "get_errors"):
            dump["errors"][k] = m.get_errors().numpy()
    write_json(output_dir / "log.json", dump)
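
To make the recall/AUC conventions above concrete, here is a small self-contained worked example using the two functions defined in this file:

    import numpy as np

    errors = [0.5, 2.0, 4.0, 8.0]  # toy localization errors in meters
    e, r = compute_recall(errors)
    # e = [0, 0.5, 2, 4, 8], r = [0, 0.25, 0.5, 0.75, 1.0]: recall reached at error e[i]
    print(compute_auc(e, r, thresholds=[1, 3, 5]))
    # [18.75, 37.5, 52.5]: area under the recall curve up to each threshold, in percent
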
evaluation/viz.py
ADDED
@@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import torch
import matplotlib.pyplot as plt

from utils.io import write_torch_image
from utils.viz_2d import plot_images, features_to_RGB, save_plot
from utils.viz_localization import (
    likelihood_overlay,
    plot_pose,
    plot_dense_rotations,
    add_circle_inset,
)
from osm.viz import Colormap, plot_nodes


def plot_example_single(
    idx,
    model,
    pred,
    data,
    results,
    plot_bev=True,
    out_dir=None,
    fig_for_paper=False,
    show_gps=False,
    show_fused=False,
    show_dir_error=False,
    show_masked_prob=False,
):
    scene, name, rasters, uv_gt = (data[k] for k in ("scene", "name", "map", "uv"))
    uv_gps = data.get("uv_gps")
    yaw_gt = data["roll_pitch_yaw"][-1].numpy()
    image = data["image"].permute(1, 2, 0)
    if "valid" in data:
        image = image.masked_fill(~data["valid"].unsqueeze(-1), 0.3)

    lp_uvt = lp_uv = pred["log_probs"]
    if show_fused and "log_probs_fused" in pred:
        lp_uvt = lp_uv = pred["log_probs_fused"]
    elif not show_masked_prob and "scores_unmasked" in pred:
        lp_uvt = lp_uv = pred["scores_unmasked"]
    has_rotation = lp_uvt.ndim == 3
    if has_rotation:
        lp_uv = lp_uvt.max(-1).values
    if lp_uv.min() > -np.inf:
        lp_uv = lp_uv.clip(min=np.percentile(lp_uv, 1))
    prob = lp_uv.exp()
    uv_p, yaw_p = pred["uv_max"], pred.get("yaw_max")
    if show_fused and "uv_fused" in pred:
        uv_p, yaw_p = pred["uv_fused"], pred.get("yaw_fused")
    feats_map = pred["map"]["map_features"][0]
    (feats_map_rgb,) = features_to_RGB(feats_map.numpy())

    text1 = rf'$\Delta xy$: {results["xy_max_error"]:.1f}m'
    if has_rotation:
        text1 += rf', $\Delta\theta$: {results["yaw_max_error"]:.1f}°'
    if show_fused and "xy_fused_error" in results:
        text1 += rf', $\Delta xy_{{fused}}$: {results["xy_fused_error"]:.1f}m'
        text1 += rf', $\Delta\theta_{{fused}}$: {results["yaw_fused_error"]:.1f}°'
    if show_dir_error and "directional_error" in results:
        err_lat, err_lon = results["directional_error"]
        text1 += rf", $\Delta$lateral/longitudinal={err_lat:.1f}m/{err_lon:.1f}m"
    if "xy_gps_error" in results:
        text1 += rf', $\Delta xy_{{GPS}}$: {results["xy_gps_error"]:.1f}m'

    map_viz = Colormap.apply(rasters)
    overlay = likelihood_overlay(prob.numpy(), map_viz.mean(-1, keepdims=True))
    plot_images(
        [image, map_viz, overlay, feats_map_rgb],
        titles=[text1, "map", "likelihood", "neural map"],
        dpi=75,
        cmaps="jet",
    )
    fig = plt.gcf()
    axes = fig.axes
    axes[1].images[0].set_interpolation("none")
    axes[2].images[0].set_interpolation("none")
    Colormap.add_colorbar()
    plot_nodes(1, rasters[2])

    if show_gps and uv_gps is not None:
        plot_pose([1], uv_gps, c="blue")
    plot_pose([1], uv_gt, yaw_gt, c="red")
    plot_pose([1], uv_p, yaw_p, c="k")
    plot_dense_rotations(2, lp_uvt.exp())
    inset_center = pred["uv_max"] if results["xy_max_error"] < 5 else uv_gt
    axins = add_circle_inset(axes[2], inset_center)
    axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50, zorder=15)
    axes[0].text(
        0.003,
        0.003,
        f"{scene}/{name}",
        transform=axes[0].transAxes,
        fontsize=3,
        va="bottom",
        ha="left",
        color="w",
    )
    plt.show()
    if out_dir is not None:
        name_ = name.replace("/", "_")
        p = str(out_dir / f"{scene}_{name_}_{{}}.pdf")
        save_plot(p.format("pred"))
        plt.close()

        if fig_for_paper:
            # !cp ../datasets/MGL/{scene}/images/{name}.jpg {out_dir}/{scene}_{name}.jpg
            plot_images([map_viz])
            plt.gca().images[0].set_interpolation("none")
            plot_nodes(0, rasters[2])
            plot_pose([0], uv_gt, yaw_gt, c="red")
            plot_pose([0], pred["uv_max"], pred["yaw_max"], c="k")
            save_plot(p.format("map"))
            plt.close()
            plot_images([lp_uv], cmaps="jet")
            plot_dense_rotations(0, lp_uvt.exp())
            save_plot(p.format("loglikelihood"), dpi=100)
            plt.close()
            plot_images([overlay])
            plt.gca().images[0].set_interpolation("none")
            axins = add_circle_inset(plt.gca(), inset_center)
            axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50)
            save_plot(p.format("likelihood"))
            plt.close()
            write_torch_image(
                p.format("neuralmap").replace("pdf", "jpg"), feats_map_rgb
            )
            write_torch_image(p.format("image").replace("pdf", "jpg"), image.numpy())

    if not plot_bev:
        return

    feats_q = pred["features_bev"]
    mask_bev = pred["valid_bev"]
    prior = None
    if "log_prior" in pred["map"]:
        prior = pred["map"]["log_prior"][0].sigmoid()
    if "bev" in pred and "confidence" in pred["bev"]:
        conf_q = pred["bev"]["confidence"]
    else:
        conf_q = torch.norm(feats_q, dim=0)
    conf_q = conf_q.masked_fill(~mask_bev, np.nan)
    (feats_q_rgb,) = features_to_RGB(feats_q.numpy(), masks=[mask_bev.numpy()])
    # feats_map_rgb, feats_q_rgb, = features_to_RGB(
    #     feats_map.numpy(), feats_q.numpy(), masks=[None, mask_bev])
    norm_map = torch.norm(feats_map, dim=0)

    plot_images(
        [conf_q, feats_q_rgb, norm_map] + ([] if prior is None else [prior]),
        titles=["BEV confidence", "BEV features", "map norm"]
        + ([] if prior is None else ["map prior"]),
        dpi=50,
        cmaps="jet",
    )
    plt.show()

    if out_dir is not None:
        save_plot(p.format("bev"))
        plt.close()


def plot_example_sequential(
    idx,
    model,
    pred,
    data,
    results,
    plot_bev=True,
    out_dir=None,
    fig_for_paper=False,
    show_gps=False,
    show_fused=False,
    show_dir_error=False,
    show_masked_prob=False,
):
    # Placeholder: sequential visualization is not implemented in this release.
    return
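
plot_example_single is designed to be passed as the callback of evaluate_single_image (see evaluation/run.py); a hedged sketch of that wiring, with a hypothetical output directory:

    import functools
    from pathlib import Path
    from evaluation.viz import plot_example_single

    callback = functools.partial(
        plot_example_single, out_dir=Path("viz_out"), show_gps=True
    )
    # metrics = evaluate_single_image(loader, model, num=10, callback=callback)
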
flagged/inp/10d2e4a8712491181c2f48b61f5003b216d2b9f9/tmp48n9eoyh.png
ADDED
flagged/inp/e1b18d44d9e381d586209f73a015fed7f688822b/tmp86ith_2q.png
ADDED
flagged/log.csv
ADDED
@@ -0,0 +1,3 @@
inp,longitude,latitude,Area,output,flag,username,timestamp
E:\MapLocNetDemo\Demo\flagged\inp\10d2e4a8712491181c2f48b61f5003b216d2b9f9\tmp48n9eoyh.png,70.1,40,256,E:\MapLocNetDemo\Demo\flagged\output\tmp59657zop.json,,,2023-09-22 10:07:17.488625
E:\MapLocNetDemo\Demo\flagged\inp\e1b18d44d9e381d586209f73a015fed7f688822b\tmp86ith_2q.png,70.1,40,256,E:\MapLocNetDemo\Demo\flagged\output\tmpbs17s28d.json,,,2023-09-22 10:07:21.485967
flagged/output/tmp59657zop.json
ADDED
@@ -0,0 +1 @@
{"label": "bull mastiff\n", "confidences": [{"label": "bull mastiff\n", "confidence": 0.24759389460086823}, {"label": "pug\n", "confidence": 0.0916372761130333}, {"label": "Great Dane\n", "confidence": 0.08652031421661377}]}
flagged/output/tmpbs17s28d.json
ADDED
@@ -0,0 +1 @@
{"label": "bull mastiff\n", "confidences": [{"label": "bull mastiff\n", "confidence": 0.24759389460086823}, {"label": "pug\n", "confidence": 0.0916372761130333}, {"label": "Great Dane\n", "confidence": 0.08652031421661377}]}
images/00000.jpg
ADDED
images/00011.jpg
ADDED
images/00022.jpg
ADDED
images/00033.jpg
ADDED
images/cat_dog.png
ADDED
label.txt
ADDED
@@ -0,0 +1,1000 @@
tench
goldfish
great white shark
tiger shark
hammerhead
electric ray
stingray
cock
hen
ostrich
brambling
goldfinch
house finch
junco
indigo bunting
robin
bulbul
jay
magpie
chickadee
water ouzel
kite
bald eagle
vulture
great grey owl
European fire salamander
common newt
eft
spotted salamander
axolotl
bullfrog
tree frog
tailed frog
loggerhead
leatherback turtle
mud turtle
terrapin
box turtle
banded gecko
common iguana
American chameleon
whiptail
agama
frilled lizard
alligator lizard
Gila monster
green lizard
African chameleon
Komodo dragon
African crocodile
American alligator
triceratops
thunder snake
ringneck snake
hognose snake
green snake
king snake
garter snake
water snake
vine snake
night snake
boa constrictor
rock python
Indian cobra
green mamba
sea snake
horned viper
diamondback
sidewinder
trilobite
harvestman
scorpion
black and gold garden spider
barn spider
garden spider
black widow
tarantula
wolf spider
tick
centipede
black grouse
ptarmigan
ruffed grouse
prairie chicken
peacock
quail
partridge
African grey
macaw
sulphur-crested cockatoo
lorikeet
coucal
bee eater
hornbill
hummingbird
jacamar
toucan
drake
red-breasted merganser
goose
black swan
tusker
echidna
platypus
wallaby
koala
wombat
jellyfish
sea anemone
brain coral
flatworm
nematode
conch
snail
slug
sea slug
chiton
chambered nautilus
Dungeness crab
rock crab
fiddler crab
king crab
American lobster
spiny lobster
crayfish
hermit crab
isopod
white stork
black stork
spoonbill
flamingo
little blue heron
American egret
bittern
crane
limpkin
European gallinule
American coot
bustard
ruddy turnstone
red-backed sandpiper
redshank
dowitcher
oystercatcher
pelican
king penguin
albatross
grey whale
killer whale
dugong
sea lion
Chihuahua
Japanese spaniel
Maltese dog
Pekinese
Shih-Tzu
Blenheim spaniel
papillon
toy terrier
Rhodesian ridgeback
Afghan hound
basset
beagle
bloodhound
bluetick
black-and-tan coonhound
Walker hound
English foxhound
redbone
borzoi
Irish wolfhound
Italian greyhound
whippet
Ibizan hound
Norwegian elkhound
otterhound
Saluki
Scottish deerhound
Weimaraner
Staffordshire bullterrier
American Staffordshire terrier
Bedlington terrier
Border terrier
Kerry blue terrier
Irish terrier
Norfolk terrier
Norwich terrier
Yorkshire terrier
wire-haired fox terrier
Lakeland terrier
Sealyham terrier
Airedale
cairn
Australian terrier
Dandie Dinmont
Boston bull
miniature schnauzer
giant schnauzer
standard schnauzer
Scotch terrier
Tibetan terrier
silky terrier
soft-coated wheaten terrier
West Highland white terrier
Lhasa
flat-coated retriever
curly-coated retriever
golden retriever
Labrador retriever
Chesapeake Bay retriever
German short-haired pointer
vizsla
English setter
Irish setter
Gordon setter
Brittany spaniel
clumber
English springer
Welsh springer spaniel
cocker spaniel
Sussex spaniel
Irish water spaniel
kuvasz
schipperke
groenendael
malinois
briard
kelpie
komondor
Old English sheepdog
Shetland sheepdog
collie
Border collie
Bouvier des Flandres
Rottweiler
German shepherd
Doberman
miniature pinscher
Greater Swiss Mountain dog
Bernese mountain dog
Appenzeller
EntleBucher
boxer
bull mastiff
Tibetan mastiff
French bulldog
Great Dane
Saint Bernard
Eskimo dog
malamute
Siberian husky
dalmatian
affenpinscher
basenji
pug
Leonberg
Newfoundland
Great Pyrenees
Samoyed
Pomeranian
chow
keeshond
Brabancon griffon
Pembroke
Cardigan
toy poodle
miniature poodle
standard poodle
Mexican hairless
timber wolf
white wolf
red wolf
coyote
dingo
dhole
African hunting dog
hyena
red fox
kit fox
Arctic fox
grey fox
tabby
tiger cat
Persian cat
Siamese cat
Egyptian cat
cougar
lynx
leopard
snow leopard
jaguar
lion
tiger
cheetah
brown bear
American black bear
ice bear
sloth bear
mongoose
meerkat
tiger beetle
ladybug
ground beetle
long-horned beetle
leaf beetle
dung beetle
rhinoceros beetle
weevil
fly
bee
ant
grasshopper
cricket
walking stick
cockroach
mantis
cicada
leafhopper
lacewing
dragonfly
damselfly
admiral
ringlet
monarch
cabbage butterfly
sulphur butterfly
lycaenid
starfish
sea urchin
sea cucumber
wood rabbit
hare
Angora
hamster
porcupine
fox squirrel
marmot
beaver
guinea pig
sorrel
zebra
hog
wild boar
warthog
hippopotamus
ox
water buffalo
bison
ram
bighorn
ibex
hartebeest
impala
gazelle
Arabian camel
llama
weasel
mink
polecat
black-footed ferret
otter
skunk
badger
armadillo
three-toed sloth
orangutan
gorilla
chimpanzee
gibbon
siamang
guenon
patas
baboon
macaque
langur
colobus
proboscis monkey
marmoset
capuchin
howler monkey
titi
spider monkey
squirrel monkey
Madagascar cat
indri
Indian elephant
African elephant
lesser panda
giant panda
barracouta
eel
coho
rock beauty
anemone fish
sturgeon
gar
lionfish
puffer
abacus
abaya
academic gown
accordion
acoustic guitar
aircraft carrier
airliner
airship
altar
ambulance
amphibian
analog clock
apiary
apron
ashcan
assault rifle
backpack
bakery
balance beam
balloon
ballpoint
Band Aid
banjo
bannister
barbell
barber chair
barbershop
barn
barometer
barrel
barrow
baseball
basketball
bassinet
bassoon
bathing cap
bath towel
bathtub
beach wagon
beacon
beaker
bearskin
beer bottle
beer glass
bell cote
bib
bicycle-built-for-two
bikini
binder
binoculars
birdhouse
boathouse
bobsled
bolo tie
bonnet
bookcase
bookshop
bottlecap
bow
bow tie
brass
brassiere
breakwater
breastplate
broom
bucket
buckle
bulletproof vest
bullet train
butcher shop
cab
caldron
candle
cannon
canoe
can opener
cardigan
car mirror
carousel
carpenter's kit
carton
car wheel
cash machine
cassette
cassette player
castle
catamaran
CD player
cello
cellular telephone
chain
chainlink fence
chain mail
chain saw
chest
chiffonier
chime
china cabinet
Christmas stocking
church
cinema
cleaver
cliff dwelling
cloak
clog
cocktail shaker
coffee mug
coffeepot
coil
combination lock
computer keyboard
confectionery
container ship
convertible
corkscrew
cornet
cowboy boot
cowboy hat
cradle
crane
crash helmet
crate
crib
Crock Pot
croquet ball
crutch
cuirass
dam
desk
desktop computer
dial telephone
diaper
digital clock
digital watch
dining table
dishrag
dishwasher
disk brake
dock
dogsled
dome
doormat
drilling platform
drum
drumstick
dumbbell
Dutch oven
electric fan
electric guitar
electric locomotive
entertainment center
envelope
espresso maker
face powder
feather boa
file
fireboat
fire engine
fire screen
flagpole
flute
folding chair
football helmet
forklift
fountain
fountain pen
four-poster
freight car
French horn
frying pan
fur coat
garbage truck
gasmask
gas pump
goblet
go-kart
golf ball
golfcart
gondola
gong
gown
grand piano
greenhouse
grille
grocery store
guillotine
hair slide
hair spray
half track
hammer
hamper
hand blower
hand-held computer
handkerchief
hard disc
harmonica
harp
harvester
hatchet
holster
home theater
honeycomb
hook
hoopskirt
horizontal bar
horse cart
hourglass
iPod
iron
jack-o'-lantern
jean
jeep
jersey
jigsaw puzzle
jinrikisha
joystick
kimono
knee pad
knot
lab coat
ladle
lampshade
laptop
lawn mower
lens cap
letter opener
library
lifeboat
lighter
limousine
liner
lipstick
Loafer
lotion
loudspeaker
loupe
lumbermill
magnetic compass
mailbag
mailbox
maillot
maillot
manhole cover
maraca
marimba
mask
matchstick
maypole
maze
measuring cup
medicine chest
megalith
microphone
microwave
military uniform
milk can
minibus
miniskirt
minivan
missile
mitten
mixing bowl
mobile home
Model T
modem
monastery
monitor
moped
mortar
mortarboard
mosque
mosquito net
motor scooter
mountain bike
mountain tent
mouse
mousetrap
moving van
muzzle
nail
neck brace
necklace
nipple
notebook
obelisk
oboe
ocarina
odometer
oil filter
organ
oscilloscope
overskirt
oxcart
oxygen mask
packet
paddle
paddlewheel
padlock
paintbrush
pajama
palace
panpipe
paper towel
parachute
parallel bars
park bench
parking meter
passenger car
patio
pay-phone
pedestal
pencil box
pencil sharpener
perfume
Petri dish
photocopier
pick
pickelhaube
picket fence
pickup
pier
piggy bank
pill bottle
pillow
ping-pong ball
pinwheel
pirate
pitcher
plane
planetarium
plastic bag
plate rack
plow
plunger
Polaroid camera
pole
police van
poncho
pool table
pop bottle
pot
potter's wheel
power drill
prayer rug
printer
prison
projectile
projector
puck
punching bag
purse
quill
quilt
racer
racket
radiator
radio
radio telescope
rain barrel
recreational vehicle
reel
reflex camera
refrigerator
remote control
restaurant
revolver
rifle
rocking chair
rotisserie
rubber eraser
rugby ball
rule
running shoe
safe
safety pin
saltshaker
sandal
sarong
sax
scabbard
scale
school bus
schooner
scoreboard
screen
screw
screwdriver
seat belt
sewing machine
shield
shoe shop
shoji
shopping basket
shopping cart
shovel
shower cap
shower curtain
ski
ski mask
sleeping bag
slide rule
sliding door
slot
snorkel
snowmobile
snowplow
soap dispenser
soccer ball
sock
solar dish
sombrero
soup bowl
space bar
space heater
space shuttle
spatula
speedboat
spider web
spindle
sports car
spotlight
stage
steam locomotive
steel arch bridge
steel drum
stethoscope
stole
stone wall
stopwatch
stove
strainer
streetcar
stretcher
studio couch
stupa
submarine
suit
sundial
sunglass
sunglasses
sunscreen
suspension bridge
swab
sweatshirt
swimming trunks
swing
switch
syringe
table lamp
tank
tape player
teapot
teddy
television
tennis ball
thatch
theater curtain
thimble
thresher
throne
tile roof
toaster
tobacco shop
toilet seat
torch
totem pole
tow truck
toyshop
tractor
trailer truck
tray
trench coat
tricycle
trimaran
tripod
triumphal arch
trolleybus
trombone
tub
turnstile
typewriter keyboard
umbrella
unicycle
upright
vacuum
vase
vault
velvet
vending machine
vestment
viaduct
violin
volleyball
waffle iron
wall clock
wallet
wardrobe
warplane
washbasin
washer
water bottle
water jug
water tower
whiskey jug
whistle
wig
window screen
window shade
Windsor tie
wine bottle
wing
wok
wooden spoon
wool
worm fence
wreck
yawl
yurt
web site
comic book
crossword puzzle
street sign
traffic light
book jacket
menu
plate
guacamole
consomme
hot pot
trifle
ice cream
ice lolly
French loaf
bagel
pretzel
cheeseburger
hotdog
mashed potato
head cabbage
broccoli
cauliflower
zucchini
spaghetti squash
acorn squash
butternut squash
cucumber
artichoke
bell pepper
cardoon
mushroom
Granny Smith
strawberry
orange
lemon
fig
pineapple
banana
jackfruit
custard apple
pomegranate
hay
carbonara
chocolate sauce
dough
meat loaf
pizza
potpie
burrito
red wine
espresso
cup
eggnog
alp
bubble
cliff
coral reef
geyser
lakeside
promontory
sandbar
seashore
valley
volcano
ballplayer
groom
scuba diver
rapeseed
daisy
yellow lady's slipper
corn
acorn
hip
buckeye
coral fungus
agaric
gyromitra
stinkhorn
earthstar
hen-of-the-woods
bolete
ear
toilet tissue
logger.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from pathlib import Path
import logging

import pytorch_lightning  # noqa: F401


formatter = logging.Formatter(
    fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

logger = logging.getLogger("maploc")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False

pl_logger = logging.getLogger("pytorch_lightning")
if len(pl_logger.handlers):
    pl_logger.handlers[0].setFormatter(formatter)

repo_dir = Path(__file__).parent
EXPERIMENTS_PATH = repo_dir / "experiments/"
DATASETS_PATH = repo_dir / "datasets/"
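
Downstream modules use this logger via `from logger import logger`; a minimal usage sketch showing the resulting format:

    from logger import logger, EXPERIMENTS_PATH

    logger.info("Experiments are stored under %s", EXPERIMENTS_PATH)
    # Prints something like:
    # [2023-09-22 10:07:17 maploc INFO] Experiments are stored under .../experiments
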
main.py
ADDED
@@ -0,0 +1,98 @@
import cv2
import gradio as gr
import torch
from torchvision import transforms
import requests
from PIL import Image
from demo import Demo, read_input_image_test, show_result, vis_image_feature
from osm.tiling import TileManager
from osm.viz import Colormap, plot_nodes
from utils.viz_2d import plot_images
import numpy as np
from utils.viz_2d import features_to_RGB
from utils.viz_localization import (
    likelihood_overlay,
    plot_dense_rotations,
    add_circle_inset,
)
from osm.viz import GeoPlotter
import matplotlib.pyplot as plt
import random
from geopy.distance import geodesic

experiment_or_path = "weight/last-step-checkpointing.ckpt"
# experiment_or_path = "experiments/maplocanet_0906_diffhight/last-step-checkpointing.ckpt"
image_path = 'images/00000.jpg'

# prior_latlon = (37.75704325989902, -122.435941445631)
# tile_size_meters = 128
model = Demo(experiment_or_path=experiment_or_path, num_rotations=128, device='cpu')

def demo_localize(image, long, lat, tile_size_meters):
    # inp = Image.fromarray(inp.astype('uint8'), 'RGB')
    # inp = transforms.ToTensor()(inp).unsqueeze(0)
    prior_latlon = (lat, long)
    image, camera, gravity, proj, bbox, true_prior_latlon = read_input_image_test(
        image,
        prior_latlon=prior_latlon,
        tile_size_meters=tile_size_meters,  # try 64, 256, etc.
    )
    tiler = TileManager.from_bbox(projection=proj, bbox=bbox, ppm=1, tile_size=tile_size_meters)
    # tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10, ppm=1, path=root/city/'{}.osm'.format(city), tile_size=1)
    canvas = tiler.query(bbox)
    uv, yaw, prob, neural_map, image_rectified, data_, pred = model.localize(
        image, camera, canvas)
    prior_latlon_pred = proj.unproject(canvas.to_xy(uv))

    map_viz = Colormap.apply(canvas.raster)
    map_vis_image_result = map_viz * 255
    map_vis_image_result = show_result(map_vis_image_result.astype(np.uint8), uv, yaw)
    # map_vis_image_result = show_result(map_vis_image_result.astype(np.uint8), True_uv,
    #                                    uv,
    #                                    90.0 - yaw_T,
    #                                    yaw)
    # return prior_latlon_pred
    uab_feature_rgb = vis_image_feature(pred['features_image'][0].cpu().numpy())
    map_viz = cv2.resize(map_viz, (prob.numpy().shape[0], prob.numpy().shape[1]))
    overlay = likelihood_overlay(prob.numpy().max(-1), map_viz.mean(-1, keepdims=True))
    (neural_map_rgb,) = features_to_RGB(neural_map.numpy())
    fig = plot_images([image, map_vis_image_result / 255, overlay, uab_feature_rgb, neural_map_rgb],
                      titles=["UAV image", "map", "likelihood", "UAV feature", "map feature"])
    # plot_images([overlay, neural_map_rgb], titles=["prediction", "neural map"])
    # ax = plt.gcf().axes[2]
    # ax.scatter(*canvas.to_uv(bbox.center), s=5, c="red")
    # plot_dense_rotations(ax, prob, w=0.005, s=1 / 25)
    # add_circle_inset(ax, uv)

    # Plot as an interactive figure
    bbox_latlon = proj.unproject(canvas.bbox)
    plot2 = GeoPlotter(zoom=16.5)
    plot2.raster(map_viz, bbox_latlon, opacity=0.5)
    plot2.raster(likelihood_overlay(prob.numpy().max(-1)), proj.unproject(bbox))
    plot2.points(prior_latlon[:2], "red", name="location prior", size=10)
    plot2.points(proj.unproject(canvas.to_xy(uv)), "black", name="argmax", size=10)
    plot2.bbox(bbox_latlon, "blue", name="map tile")
    # plot2.fig.show()
    return fig, plot2.fig, str(prior_latlon_pred)

# model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
# Title of the demo
title = "MapLocNet"
# Description shown under the title; supports Markdown
description = "UAV Vision-based Geo-Localization Using Vectorized Maps"

# outputs = gr.outputs.Label(num_top_classes=3)
outputs = gr.Plot()
interface = gr.Interface(fn=demo_localize,
                         inputs=["image",
                                 gr.Number(label="Prior location - longitude"),
                                 gr.Number(label="Prior location - latitude"),
                                 gr.Radio([64, 128, 256], label="Search radius (meters)", info="vectorized map size"),
                                 # gr.inputs.RadioGroup(label="Search radius (meters)", ["English", "French", "Spanish"]),
                                 # gr.Slider(64, 512, label='Search radius (meters)')
                                 ],
                         outputs=["plot", "plot", "text"],
                         title=title,
                         description=description,
                         examples=[['images/00000.jpg', -122.435941445631, 37.75704325989902, 128]])
interface.launch(share=True)
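
Running `python main.py` launches the Gradio interface (with share=True it also prints a public URL). The handler can also be exercised directly; a sketch assuming the checkpoint under weight/ is present and that Gradio's "image" input passes an RGB numpy array:

    import cv2

    rgb = cv2.imread('images/00000.jpg')[..., ::-1]  # BGR -> RGB
    fig, geo_fig, latlon = demo_localize(
        rgb, long=-122.435941445631, lat=37.75704325989902, tile_size_meters=128
    )
    print(latlon)  # predicted (lat, lon) near the San Francisco example location
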
models/__init__.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
# https://github.com/cvg/pixloc
# Released under the Apache License 2.0

import inspect

from .base import BaseModel


def get_class(mod_name, base_path, BaseClass):
    """Get the class object which inherits from BaseClass and is defined in
    the module named mod_name, child of base_path.
    """
    mod_path = "{}.{}".format(base_path, mod_name)
    mod = __import__(mod_path, fromlist=[""])
    classes = inspect.getmembers(mod, inspect.isclass)
    # Filter classes defined in the module
    classes = [c for c in classes if c[1].__module__ == mod_path]
    # Filter classes inherited from BaseModel
    classes = [c for c in classes if issubclass(c[1], BaseClass)]
    assert len(classes) == 1, classes
    return classes[0][1]


def get_model(name):
    if name == "localizer":
        name = "localizer_basic"
    elif name == "rotation_localizer":
        name = "localizer_basic_rotation"
    elif name == "bev_localizer":
        name = "localizer_bev_plane"
    return get_class(name, __name__, BaseModel)
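
A hedged usage sketch of get_model: it imports models.<name> and returns the single BaseModel subclass defined there (models/maplocnet.py exists in this repo; the config shown is illustrative):

    from models import get_model

    MapLocNet = get_model("maplocnet")  # -> the BaseModel subclass in models/maplocnet.py
    # model = MapLocNet(cfg.model)      # instantiate with an OmegaConf/dict config
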
models/base.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
# https://github.com/cvg/pixloc
# Released under the Apache License 2.0

"""
Base class for trainable models.
"""

from abc import ABCMeta, abstractmethod
from copy import copy

import omegaconf
from omegaconf import OmegaConf
from torch import nn


class BaseModel(nn.Module, metaclass=ABCMeta):
    """
    What the child model is expected to declare:
        default_conf: dictionary of the default configuration of the model.
        It recursively updates the default_conf of all parent classes, and
        it is updated by the user-provided configuration passed to __init__.
        Configurations can be nested.

        required_data_keys: list of expected keys in the input data dictionary.

        strict_conf (optional): boolean. If false, BaseModel does not raise
        an error when the user provides an unknown configuration entry.

        _init(self, conf): initialization method, where conf is the final
        configuration object (also accessible with `self.conf`). Accessing
        unknown configuration entries will raise an error.

        _forward(self, data): method that returns a dictionary of batched
        prediction tensors based on a dictionary of batched input data tensors.

        loss(self, pred, data): method that returns a dictionary of losses,
        computed from model predictions and input data. Each loss is a batch
        of scalars, i.e. a torch.Tensor of shape (B,).
        The total loss to be optimized has the key `'total'`.

        metrics(self, pred, data): method that returns a dictionary of metrics,
        each as a batch of scalars.
    """

    base_default_conf = {
        "name": None,
        "trainable": True,  # if false: do not optimize the model parameters
        "freeze_batch_normalization": False,  # use test-time statistics
    }
    default_conf = {}
    required_data_keys = []
    strict_conf = True

    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model."""
        super().__init__()
        default_conf = OmegaConf.merge(
            self.base_default_conf, OmegaConf.create(self.default_conf)
        )
        if self.strict_conf:
            OmegaConf.set_struct(default_conf, True)

        # fixme: backward compatibility
        if "pad" in conf and "pad" not in default_conf:  # backward compat.
            with omegaconf.read_write(conf):
                with omegaconf.open_dict(conf):
                    conf["interpolation"] = {"pad": conf.pop("pad")}

        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

        if not conf.trainable:
            for p in self.parameters():
                p.requires_grad = False

    def train(self, mode=True):
        super().train(mode)

        def freeze_bn(module):
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()

        if self.conf.freeze_batch_normalization:
            self.apply(freeze_bn)

        return self

    def forward(self, data):
        """Check the data and call the _forward method of the child model."""

        def recursive_key_check(expected, given):
            for key in expected:
                assert key in given, f"Missing key {key} in data"
                if isinstance(expected, dict):
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    def loss(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    def metrics(self):
        return {}  # no metrics
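Every network in this commit follows the same contract: subclass BaseModel, declare default_conf and required_data_keys, and implement _init/_forward. A minimal sketch of that contract (the class, conf keys, and shapes below are illustrative, not part of the repository):

    import torch
    from models.base import BaseModel

    class TinyModel(BaseModel):
        default_conf = {"dim": 16}       # merged into base_default_conf
        required_data_keys = ["image"]   # checked by BaseModel.forward

        def _init(self, conf):
            # conf is read-only; unknown keys raise thanks to strict_conf
            self.proj = torch.nn.Conv2d(3, conf.dim, 1)

        def _forward(self, data):
            return {"features": self.proj(data["image"])}

    model = TinyModel({"name": "tiny"})
    pred = model({"image": torch.rand(1, 3, 32, 32)})  # (1, 16, 32, 32)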
models/feature_extractor.py
ADDED
@@ -0,0 +1,231 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
# https://github.com/cvg/pixloc
# Released under the Apache License 2.0

"""
Flexible UNet model which takes any Torchvision backbone as encoder.
Predicts multi-level features and makes sure that they are well aligned.
"""

import torch
import torch.nn as nn
import torchvision

from .base import BaseModel
from .utils import checkpointed


class DecoderBlock(nn.Module):
    def __init__(
        self, previous, skip, out, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()

        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        layers = []
        for i in range(num_convs):
            conv = nn.Conv2d(
                previous + skip if i == 0 else out,
                out,
                kernel_size=3,
                padding=1,
                bias=norm is None,
                padding_mode=padding,
            )
            layers.append(conv)
            if norm is not None:
                layers.append(norm(out))
            layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, previous, skip):
        upsampled = self.upsample(previous)
        # If the shape of the input map `skip` is not a multiple of 2,
        # it will not match the shape of the upsampled map `upsampled`.
        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
        # If it uses ceil_mode=True (not supported here), we should pad it.
        _, _, hu, wu = upsampled.shape
        _, _, hs, ws = skip.shape
        assert (hu <= hs) and (wu <= ws), "Using ceil_mode=True in pooling?"
        # assert (hu == hs) and (wu == ws), 'Careful about padding'
        skip = skip[:, :, :hu, :wu]
        return self.layers(torch.cat([upsampled, skip], dim=1))


class AdaptationBlock(nn.Sequential):
    def __init__(self, inp, out):
        conv = nn.Conv2d(inp, out, kernel_size=1, padding=0, bias=True)
        super().__init__(conv)


class FeatureExtractor(BaseModel):
    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_scales": [0, 2, 4],  # what scales to adapt and output
        "output_dim": 128,  # number of channels in output feature maps
        "encoder": "vgg16",  # string (torchvision net) or list of channels
        "num_downsample": 4,  # how many downsample blocks (if VGG-style net)
        "decoder": [64, 64, 64, 64],  # list of channels of decoder
        "decoder_norm": "nn.BatchNorm2d",  # normalization in decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
        "padding": "zeros",
    }
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf):
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            assert conf.input_dim == 3
        Encoder = getattr(torchvision.models, conf.encoder)
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None)
        Block = checkpointed(torch.nn.Sequential, do=conf.checkpointed)
        assert max(conf.output_scales) <= conf.num_downsample

        if conf.encoder.startswith("vgg"):
            # Parse the layers and pack them into downsampling blocks.
            # It's easy for VGG-style nets because of their linear structure.
            # This does not handle strided convs and residual connections.
            skip_dims = []
            previous_dim = None
            blocks = [[]]
            for i, layer in enumerate(encoder.features):
                if isinstance(layer, torch.nn.Conv2d):
                    # Change the first conv layer if the input dim mismatches
                    if i == 0 and conf.input_dim != layer.in_channels:
                        args = {k: getattr(layer, k) for k in layer.__constants__}
                        args.pop("output_padding")
                        layer = torch.nn.Conv2d(
                            **{**args, "in_channels": conf.input_dim}
                        )
                    previous_dim = layer.out_channels
                elif isinstance(layer, torch.nn.MaxPool2d):
                    assert previous_dim is not None
                    skip_dims.append(previous_dim)
                    if (conf.num_downsample + 1) == len(blocks):
                        break
                    blocks.append([])  # start a new block
                    if conf.do_average_pooling:
                        assert layer.dilation == 1
                        layer = torch.nn.AvgPool2d(
                            kernel_size=layer.kernel_size,
                            stride=layer.stride,
                            padding=layer.padding,
                            ceil_mode=layer.ceil_mode,
                            count_include_pad=False,
                        )
                blocks[-1].append(layer)
            encoder = [Block(*b) for b in blocks]
        elif conf.encoder.startswith("resnet"):
            # Manually define the ResNet blocks such that the downsampling comes first
            assert conf.encoder[len("resnet"):] in ["18", "34", "50", "101"]
            assert conf.input_dim == 3, "Unsupported for now."
            block1 = torch.nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
            block2 = torch.nn.Sequential(encoder.maxpool, encoder.layer1)
            block3 = encoder.layer2
            block4 = encoder.layer3
            block5 = encoder.layer4
            blocks = [block1, block2, block3, block4, block5]
            # Extract the output dimension of each block
            skip_dims = [encoder.conv1.out_channels]
            for i in range(1, 5):
                modules = getattr(encoder, f"layer{i}")[-1]._modules
                conv = sorted(k for k in modules if k.startswith("conv"))[-1]
                skip_dims.append(modules[conv].out_channels)
            # Add a dummy block such that the first one does not downsample
            encoder = [torch.nn.Identity()] + [Block(b) for b in blocks]
            skip_dims = [3] + skip_dims
            # Trim based on the requested encoder size
            encoder = encoder[: conf.num_downsample + 1]
            skip_dims = skip_dims[: conf.num_downsample + 1]
        else:
            raise NotImplementedError(conf.encoder)

        assert (conf.num_downsample + 1) == len(encoder)
        encoder = nn.ModuleList(encoder)

        return encoder, skip_dims

    def _init(self, conf):
        # Encoder
        self.encoder, skip_dims = self.build_encoder(conf)
        self.skip_dims = skip_dims

        def update_padding(module):
            if isinstance(module, nn.Conv2d):
                module.padding_mode = conf.padding

        if conf.padding != "zeros":
            self.encoder.apply(update_padding)

        # Decoder
        if conf.decoder is not None:
            assert len(conf.decoder) == (len(skip_dims) - 1)
            Block = checkpointed(DecoderBlock, do=conf.checkpointed)
            norm = eval(conf.decoder_norm) if conf.decoder_norm else None  # noqa

            previous = skip_dims[-1]
            decoder = []
            for out, skip in zip(conf.decoder, skip_dims[:-1][::-1]):
                decoder.append(
                    Block(previous, skip, out, norm=norm, padding=conf.padding)
                )
                previous = out
            self.decoder = nn.ModuleList(decoder)

        # Adaptation layers
        adaptation = []
        for idx, i in enumerate(conf.output_scales):
            if conf.decoder is None or i == (len(self.encoder) - 1):
                input_ = skip_dims[i]
            else:
                input_ = conf.decoder[-1 - i]

            # out_dim can be an int (same for all scales) or a list (per scale)
            dim = conf.output_dim
            if not isinstance(dim, int):
                dim = dim[idx]

            block = AdaptationBlock(input_, dim)
            adaptation.append(block)
        self.adaptation = nn.ModuleList(adaptation)
        self.scales = [2**s for s in conf.output_scales]

    def _forward(self, data):
        image = data["image"]
        if self.conf.pretrained:
            mean, std = image.new_tensor(self.mean), image.new_tensor(self.std)
            image = (image - mean[:, None, None]) / std[:, None, None]

        skip_features = []
        features = image
        for block in self.encoder:
            features = block(features)
            skip_features.append(features)

        if self.conf.decoder:
            pre_features = [skip_features[-1]]
            for block, skip in zip(self.decoder, skip_features[:-1][::-1]):
                pre_features.append(block(pre_features[-1], skip))
            pre_features = pre_features[::-1]  # fine to coarse
        else:
            pre_features = skip_features

        out_features = []
        for adapt, i in zip(self.adaptation, self.conf.output_scales):
            out_features.append(adapt(pre_features[i]))
        pred = {"feature_maps": out_features, "skip_features": skip_features}
        return pred

    def loss(self, pred, data):
        raise NotImplementedError

    def metrics(self, pred, data):
        raise NotImplementedError
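The encoder/decoder pair produces one feature map per requested output scale, all aligned with the input image. A quick shape sketch under the default vgg16 configuration (pretrained is disabled here only to avoid a weight download; values are illustrative):

    import torch
    from models.feature_extractor import FeatureExtractor

    fe = FeatureExtractor({"encoder": "vgg16", "pretrained": False})
    pred = fe({"image": torch.rand(1, 3, 256, 256)})
    # output_scales [0, 2, 4] correspond to strides [1, 4, 16]
    for f, s in zip(pred["feature_maps"], fe.scales):
        assert f.shape[-1] == 256 // s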
models/feature_extractor_v2.py
ADDED
@@ -0,0 +1,192 @@
import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

from .base import BaseModel

logger = logging.getLogger(__name__)


class DecoderBlock(nn.Module):
    def __init__(
        self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()
        layers = []
        for i in range(num_convs):
            conv = nn.Conv2d(
                previous if i == 0 else out,
                out,
                kernel_size=ksize,
                padding=ksize // 2,
                bias=norm is None,
                padding_mode=padding,
            )
            layers.append(conv)
            if norm is not None:
                layers.append(norm(out))
            layers.append(nn.ReLU(inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, previous, skip):
        _, _, hp, wp = previous.shape
        _, _, hs, ws = skip.shape
        scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
        upsampled = nn.functional.interpolate(
            previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
        )
        # If the shape of the input map `skip` is not a multiple of 2,
        # it will not match the shape of the upsampled map `upsampled`.
        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
        # If it uses ceil_mode=True (not supported here), we should pad it.
        _, _, hu, wu = upsampled.shape
        _, _, hs, ws = skip.shape
        if (hu <= hs) and (wu <= ws):
            skip = skip[:, :, :hu, :wu]
        elif (hu >= hs) and (wu >= ws):
            skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
        else:
            raise ValueError(
                f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
            )

        return self.layers(skip) + upsampled


class FPN(nn.Module):
    def __init__(self, in_channels_list, out_channels, **kw):
        super().__init__()
        self.first = nn.Conv2d(
            in_channels_list[-1], out_channels, 1, padding=0, bias=True
        )
        self.blocks = nn.ModuleList(
            [
                DecoderBlock(c, out_channels, ksize=1, **kw)
                for c in in_channels_list[::-1][1:]
            ]
        )
        self.out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, layers):
        feats = None
        for idx, x in enumerate(reversed(layers.values())):
            if feats is None:
                feats = self.first(x)
            else:
                feats = self.blocks[idx - 1](feats, x)
        out = self.out(feats)
        return out


def remove_conv_stride(conv):
    conv_new = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        bias=conv.bias is not None,
        stride=1,
        padding=conv.padding,
    )
    conv_new.weight = conv.weight
    conv_new.bias = conv.bias
    return conv_new


class FeatureExtractor(BaseModel):
    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_dim": 128,  # number of channels in output feature maps
        "encoder": "resnet50",  # torchvision net as string
        "remove_stride_from_first_conv": False,
        "num_downsample": None,  # how many downsample blocks
        "decoder_norm": "nn.BatchNorm2d",  # normalization in decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
    }
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf):
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            assert conf.input_dim == 3
        Encoder = getattr(torchvision.models, conf.encoder)

        kw = {}
        if conf.encoder.startswith("resnet"):
            layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
            kw["replace_stride_with_dilation"] = [False, False, False]
        elif conf.encoder == "vgg13":
            layers = [
                "features.3",
                "features.8",
                "features.13",
                "features.18",
                "features.23",
            ]
        elif conf.encoder == "vgg16":
            layers = [
                "features.3",
                "features.8",
                "features.15",
                "features.22",
                "features.29",
            ]
        else:
            raise NotImplementedError(conf.encoder)

        if conf.num_downsample is not None:
            layers = layers[: conf.num_downsample]
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
        encoder = create_feature_extractor(encoder, return_nodes=layers)
        if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
            encoder.conv1 = remove_conv_stride(encoder.conv1)

        if conf.do_average_pooling:
            raise NotImplementedError
        if conf.checkpointed:
            raise NotImplementedError

        return encoder, layers

    def _init(self, conf):
        # Preprocessing
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)

        # Encoder
        self.encoder, self.layers = self.build_encoder(conf)
        s = 128
        inp = torch.zeros(1, 3, s, s)
        features = list(self.encoder(inp).values())
        self.skip_dims = [x.shape[1] for x in features]
        self.layer_strides = [s / f.shape[-1] for f in features]
        self.scales = [self.layer_strides[0]]

        # Decoder
        norm = eval(conf.decoder_norm) if conf.decoder_norm else None  # noqa
        self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)

        logger.debug(
            "Built feature extractor with layers {name:dim:stride}:\n"
            f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
            f"and output scales {self.scales}."
        )

    def _forward(self, data):
        image = data["image"]
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]

        skip_features = self.encoder(image)
        output = self.decoder(skip_features)
        pred = {"feature_maps": [output], "skip_features": skip_features}
        return pred
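With the resnet50 backbone selected in conf/maplocnet.yaml, the FPN fuses five skip levels back to the stride of the finest returned node ("relu", stride 2), so a 256-pixel input yields a single 128-pixel feature map. A shape sketch (pretrained disabled only to avoid a weight download):

    import torch
    from models.feature_extractor_v2 import FeatureExtractor

    fe = FeatureExtractor({"encoder": "resnet50", "pretrained": False, "output_dim": 8})
    pred = fe({"image": torch.rand(1, 3, 256, 256)})
    feat = pred["feature_maps"][0]
    assert feat.shape == (1, 8, 128, 128)  # one fused map at stride 2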
models/map_encoder.py
ADDED
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
import torch.nn as nn

from .base import BaseModel
from .feature_extractor import FeatureExtractor


class MapEncoder(BaseModel):
    default_conf = {
        "embedding_dim": "???",
        "output_dim": None,
        "num_classes": "???",
        "backbone": "???",
        "unary_prior": False,
    }

    def _init(self, conf):
        self.embeddings = torch.nn.ModuleDict(
            {
                k: torch.nn.Embedding(n + 1, conf.embedding_dim)
                for k, n in conf.num_classes.items()
            }
        )
        # num_classes: {'areas': 7, 'ways': 10, 'nodes': 33}
        input_dim = len(conf.num_classes) * conf.embedding_dim
        output_dim = conf.output_dim
        if output_dim is None:
            output_dim = conf.backbone.output_dim
        if conf.unary_prior:
            output_dim += 1
        if conf.backbone is None:
            self.encoder = nn.Conv2d(input_dim, output_dim, 1)
        elif conf.backbone == "simple":
            self.encoder = nn.Sequential(
                nn.Conv2d(input_dim, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, output_dim, 3, padding=1),
            )
        else:
            self.encoder = FeatureExtractor(
                {
                    **conf.backbone,
                    "input_dim": input_dim,
                    "output_dim": output_dim,
                }
            )

    def _forward(self, data):
        embeddings = [
            self.embeddings[k](data["map"][:, i])
            for i, k in enumerate(("areas", "ways", "nodes"))
        ]
        embeddings = torch.cat(embeddings, dim=-1).permute(0, 3, 1, 2)
        if isinstance(self.encoder, BaseModel):
            features = self.encoder({"image": embeddings})["feature_maps"]
        else:
            features = [self.encoder(embeddings)]
        pred = {}
        if self.conf.unary_prior:
            pred["log_prior"] = [f[:, -1] for f in features]
            features = [f[:, :-1] for f in features]
        pred["map_features"] = features
        return pred
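The rasterized map enters as integer class indices, one channel per OSM layer (areas, ways, nodes); each index is looked up in its own embedding table and the three embeddings are concatenated channel-wise before encoding. A shape sketch using the class counts from conf/maplocnet.yaml and the simplest backbone (None, a 1x1 conv); all values illustrative:

    import torch
    from models.map_encoder import MapEncoder

    enc = MapEncoder({
        "embedding_dim": 16,
        "num_classes": {"areas": 7, "ways": 10, "nodes": 33},
        "backbone": None,
        "output_dim": 8,
    })
    tiles = torch.randint(0, 7, (2, 3, 256, 256))  # (B, 3 layers, H, W) of class ids
    pred = enc({"map": tiles})
    assert pred["map_features"][0].shape == (2, 8, 256, 256)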
models/maplocnet.py
ADDED
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import torch
from torch.nn.functional import normalize

from . import get_model
from models.base import BaseModel
# from models.bev_net import BEVNet
# from models.bev_projection import CartesianProjection, PolarProjectionDepth
from models.voting import (
    argmax_xyr,
    conv2d_fft_batchwise,
    expectation_xyr,
    log_softmax_spatial,
    mask_yaw_prior,
    nll_loss_xyr,
    nll_loss_xyr_smoothed,
    TemplateSampler,
    UAVTemplateSampler,
    UAVTemplateSamplerFast,
)
from .map_encoder import MapEncoder
from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall


class MapLocNet(BaseModel):
    default_conf = {
        "image_size": "???",
        "val_citys": "???",
        "image_encoder": "???",
        "map_encoder": "???",
        "bev_net": "???",
        "latent_dim": "???",
        "matching_dim": "???",
        "scale_range": [0, 9],
        "num_scale_bins": "???",
        "z_min": None,
        "z_max": "???",
        "x_max": "???",
        "pixel_per_meter": "???",
        "num_rotations": "???",
        "add_temperature": False,
        "normalize_features": False,
        "padding_matching": "replicate",
        "apply_map_prior": True,
        "do_label_smoothing": False,
        "sigma_xy": 1,
        "sigma_r": 2,
        # deprecated
        "depth_parameterization": "scale",
        "norm_depth_scores": False,
        "normalize_scores_by_dim": False,
        "normalize_scores_by_num_valid": True,
        "prior_renorm": True,
        "retrieval_dim": None,
    }

    def _init(self, conf):
        assert not self.conf.norm_depth_scores
        assert self.conf.depth_parameterization == "scale"
        assert not self.conf.normalize_scores_by_dim
        assert self.conf.normalize_scores_by_num_valid
        assert self.conf.prior_renorm

        Encoder = get_model(conf.image_encoder.get("name", "feature_extractor_v2"))
        self.image_encoder = Encoder(conf.image_encoder.backbone)
        self.map_encoder = MapEncoder(conf.map_encoder)
        # self.bev_net = None if conf.bev_net is None else BEVNet(conf.bev_net)

        ppm = conf.pixel_per_meter
        # self.projection_polar = PolarProjectionDepth(
        #     conf.z_max,
        #     ppm,
        #     conf.scale_range,
        #     conf.z_min,
        # )
        # self.projection_bev = CartesianProjection(
        #     conf.z_max, conf.x_max, ppm, conf.z_min
        # )
        # self.template_sampler = TemplateSampler(
        #     self.projection_bev.grid_xz, ppm, conf.num_rotations
        # )
        # self.template_sampler = UAVTemplateSamplerFast(conf.num_rotations, w=conf.image_size // 2)
        self.template_sampler = UAVTemplateSampler(conf.num_rotations)
        # self.scale_classifier = torch.nn.Linear(conf.latent_dim, conf.num_scale_bins)
        # if conf.bev_net is None:
        #     self.feature_projection = torch.nn.Linear(
        #         conf.latent_dim, conf.matching_dim
        #     )
        if conf.add_temperature:
            temperature = torch.nn.Parameter(torch.tensor(0.0))
            self.register_parameter("temperature", temperature)

    def exhaustive_voting(self, f_bev, f_map):
        if self.conf.normalize_features:
            f_bev = normalize(f_bev, dim=1)
            f_map = normalize(f_map, dim=1)

        # Build the templates and exhaustively match against the map.
        # if confidence_bev is not None:
        #     f_bev = f_bev * confidence_bev.unsqueeze(1)
        # f_bev = f_bev.masked_fill(~valid_bev.unsqueeze(1), 0.0)
        # torch.save(f_bev, 'f_bev.pt')
        # torch.save(f_map, 'f_map.pt')

        templates = self.template_sampler(f_bev)  # [B, num_rotations, C, H, W]
        # torch.save(templates, 'templates.pt')
        with torch.autocast("cuda", enabled=False):
            scores = conv2d_fft_batchwise(
                f_map.float(),
                templates.float(),
                padding_mode=self.conf.padding_matching,
            )
        if self.conf.add_temperature:
            scores = scores * torch.exp(self.temperature)

        # Reweight the different rotations based on the number of valid pixels
        # in each template. Axis-aligned rotations have the maximum number of valid pixels.
        # valid_templates = self.template_sampler(valid_bev.float()[None]) > (1 - 1e-4)
        # num_valid = valid_templates.float().sum((-3, -2, -1))
        # scores = scores / num_valid[..., None, None]
        return scores

    def _forward(self, data):
        pred = {}
        pred_map = pred["map"] = self.map_encoder(data)
        f_map = pred_map["map_features"][0]  # [B, C, H, W], e.g. [batch, 8, 256, 256]

        # Extract image features.
        level = 0
        f_image = self.image_encoder(data)["feature_maps"][level]  # [B, C, H/2, W/2]
        # print("f_map:", f_map.shape)

        scores = self.exhaustive_voting(f_image, f_map)  # [B, num_rotations, H, W]
        scores = scores.moveaxis(1, -1)  # B,H,W,N
        if "log_prior" in pred_map and self.conf.apply_map_prior:
            scores = scores + pred_map["log_prior"][0].unsqueeze(-1)
        # pred["scores_unmasked"] = scores.clone()
        if "map_mask" in data:
            scores.masked_fill_(~data["map_mask"][..., None], -np.inf)
        if "yaw_prior" in data:
            mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)
        log_probs = log_softmax_spatial(scores)
        # torch.save(scores, 'scores.pt')
        with torch.no_grad():
            uvr_max = argmax_xyr(scores).to(scores)
            uvr_avg, _ = expectation_xyr(log_probs.exp())

        return {
            **pred,
            "scores": scores,
            "log_probs": log_probs,
            "uvr_max": uvr_max,
            "uv_max": uvr_max[..., :2],
            "yaw_max": uvr_max[..., 2],
            "uvr_expectation": uvr_avg,
            "uv_expectation": uvr_avg[..., :2],
            "yaw_expectation": uvr_avg[..., 2],
            "features_image": f_image,
        }

    def loss(self, pred, data):
        xy_gt = data["uv"]
        yaw_gt = data["roll_pitch_yaw"][..., -1]
        if self.conf.do_label_smoothing:
            nll = nll_loss_xyr_smoothed(
                pred["log_probs"],
                xy_gt,
                yaw_gt,
                self.conf.sigma_xy / self.conf.pixel_per_meter,
                self.conf.sigma_r,
                mask=data.get("map_mask"),
            )
        else:
            nll = nll_loss_xyr(pred["log_probs"], xy_gt, yaw_gt)
        loss = {"total": nll, "nll": nll}
        if self.training and self.conf.add_temperature:
            loss["temperature"] = self.temperature.expand(len(nll))
        return loss

    def metrics(self):
        return {
            "xy_max_error": Location2DError("uv_max", self.conf.pixel_per_meter),
            "xy_expectation_error": Location2DError(
                "uv_expectation", self.conf.pixel_per_meter
            ),
            "yaw_max_error": AngleError("yaw_max"),
            "xy_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
            "xy_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
            "xy_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),

            # "x_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
            # "x_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
            # "x_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
            #
            # "y_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
            # "y_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
            # "y_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),

            "yaw_recall_1°": AngleRecall(1.0, "yaw_max"),
            "yaw_recall_3°": AngleRecall(3.0, "yaw_max"),
            "yaw_recall_5°": AngleRecall(5.0, "yaw_max"),
        }
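The score volume carries one plane per yaw bin, so a single argmax over (H, W, N) jointly decodes position and heading. A toy illustration of the decoding step, independent of the network (only names from models/voting.py; the sizes are made up):

    import torch
    from models.voting import argmax_xyr, log_softmax_spatial

    scores = torch.full((1, 8, 8, 4), -10.0)  # B, H, W, num_rotations
    scores[0, 2, 5, 1] = 0.0                  # peak at y=2, x=5, rotation bin 1
    log_probs = log_softmax_spatial(scores)   # joint softmax over H, W and N
    uvr = argmax_xyr(scores)
    print(uvr)  # tensor([[ 5.,  2., 90.]]): x, y, yaw = 1 * 360/4 degrees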
models/metrics.py
ADDED
@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
import torchmetrics
from torchmetrics.utilities.data import dim_zero_cat

from .utils import deg2rad, rotmat2d


def location_error(uv, uv_gt, ppm=1):
    return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm


def location_error_single(uv, uv_gt, ppm=1):
    return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm


def angle_error(t, t_gt):
    error = torch.abs(t % 360 - t_gt.to(t) % 360)
    error = torch.minimum(error, 360 - error)
    return error


class Location2DRecall(torchmetrics.MeanMetric):
    def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs):
        self.threshold = threshold
        self.ppm = pixel_per_meter
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, pred, data):
        self.cuda()
        error = location_error(pred[self.key], data["uv"], self.ppm)
        # print(error, self.threshold)
        super().update(
            (error <= torch.tensor(self.threshold, device=error.device)).float()
        )


class Location1DRecall(torchmetrics.MeanMetric):
    def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs):
        self.threshold = threshold
        self.ppm = pixel_per_meter
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, pred, data):
        self.cuda()
        error = location_error(pred[self.key], data["uv"], self.ppm)
        # print(error, self.threshold)
        super().update(
            (error <= torch.tensor(self.threshold, device=error.device)).float()
        )


class AngleRecall(torchmetrics.MeanMetric):
    def __init__(self, threshold, key="yaw_max", *args, **kwargs):
        self.threshold = threshold
        self.key = key

        super().__init__(*args, **kwargs)

    def update(self, pred, data):
        self.cuda()
        error = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1])
        super().update((error <= self.threshold).float())


class MeanMetricWithRecall(torchmetrics.Metric):
    full_state_update = True

    def __init__(self):
        super().__init__()
        self.add_state("value", default=[], dist_reduce_fx="cat")

    def compute(self):
        return dim_zero_cat(self.value).mean(0)

    def get_errors(self):
        return dim_zero_cat(self.value)

    def recall(self, thresholds):
        self.cuda()
        error = self.get_errors()
        thresholds = error.new_tensor(thresholds)
        return (error.unsqueeze(-1) < thresholds).float().mean(0) * 100


class AngleError(MeanMetricWithRecall):
    def __init__(self, key):
        super().__init__()
        self.key = key

    def update(self, pred, data):
        self.cuda()
        value = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1])
        if value.numel():
            self.value.append(value)


class Location2DError(MeanMetricWithRecall):
    def __init__(self, key, pixel_per_meter):
        super().__init__()
        self.key = key
        self.ppm = pixel_per_meter

    def update(self, pred, data):
        self.cuda()
        value = location_error(pred[self.key], data["uv"], self.ppm)
        if value.numel():
            self.value.append(value)


class LateralLongitudinalError(MeanMetricWithRecall):
    def __init__(self, pixel_per_meter, key="uv_max"):
        super().__init__()
        self.ppm = pixel_per_meter
        self.key = key

    def update(self, pred, data):
        self.cuda()
        yaw = deg2rad(data["roll_pitch_yaw"][..., -1])
        shift = (pred[self.key] - data["uv"]) * yaw.new_tensor([-1, 1])
        shift = (rotmat2d(yaw) @ shift.unsqueeze(-1)).squeeze(-1)
        error = torch.abs(shift) / self.ppm
        value = error.view(-1, 2)
        if value.numel():
            self.value.append(value)
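angle_error wraps the difference onto [0, 180], so 359° vs 1° counts as 2° rather than 358°; the yaw recalls above depend on this. A quick check:

    import torch
    from models.metrics import angle_error

    t = torch.tensor([359.0, 10.0])
    t_gt = torch.tensor([1.0, 350.0])
    print(angle_error(t, t_gt))  # tensor([ 2., 20.])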
models/utils.py
ADDED
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import math
from typing import Optional

import torch


def checkpointed(cls, do=True):
    """Adapted from the DISK implementation of Michał Tyszkiewicz."""
    assert issubclass(cls, torch.nn.Module)

    class Checkpointed(cls):
        def forward(self, *args, **kwargs):
            super_fwd = super(Checkpointed, self).forward
            if any((torch.is_tensor(a) and a.requires_grad) for a in args):
                return torch.utils.checkpoint.checkpoint(super_fwd, *args, **kwargs)
            else:
                return super_fwd(*args, **kwargs)

    return Checkpointed if do else cls


class GlobalPooling(torch.nn.Module):
    def __init__(self, kind):
        super().__init__()
        if kind == "mean":
            self.fn = torch.nn.Sequential(
                torch.nn.Flatten(2), torch.nn.AdaptiveAvgPool1d(1), torch.nn.Flatten()
            )
        elif kind == "max":
            self.fn = torch.nn.Sequential(
                torch.nn.Flatten(2), torch.nn.AdaptiveMaxPool1d(1), torch.nn.Flatten()
            )
        else:
            raise ValueError(f"Unknown pooling type {kind}.")

    def forward(self, x):
        return self.fn(x)


@torch.jit.script
def make_grid(
    w: float,
    h: float,
    step_x: float = 1.0,
    step_y: float = 1.0,
    orig_x: float = 0,
    orig_y: float = 0,
    y_up: bool = False,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    x, y = torch.meshgrid(
        [
            torch.arange(orig_x, w + orig_x, step_x, device=device),
            torch.arange(orig_y, h + orig_y, step_y, device=device),
        ],
        indexing="xy",
    )
    if y_up:
        y = y.flip(-2)
    grid = torch.stack((x, y), -1)
    return grid


@torch.jit.script
def rotmat2d(angle: torch.Tensor) -> torch.Tensor:
    c = torch.cos(angle)
    s = torch.sin(angle)
    R = torch.stack([c, -s, s, c], -1).reshape(angle.shape + (2, 2))
    return R


@torch.jit.script
def rotmat2d_grad(angle: torch.Tensor) -> torch.Tensor:
    c = torch.cos(angle)
    s = torch.sin(angle)
    R = torch.stack([-s, -c, c, -s], -1).reshape(angle.shape + (2, 2))
    return R


def deg2rad(x):
    return x * math.pi / 180


def rad2deg(x):
    return x * 180 / math.pi
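make_grid returns per-pixel (x, y) coordinates and rotmat2d batches 2x2 rotation matrices; together they drive the template resampling in models/voting.py below. A small check:

    import math
    import torch
    from models.utils import make_grid, rotmat2d

    grid = make_grid(4.0, 3.0)  # shape (3, 4, 2), with grid[i, j] == (j, i)
    assert grid.shape == (3, 4, 2) and grid[1, 2].tolist() == [2.0, 1.0]

    R = rotmat2d(torch.tensor(math.pi / 2))
    print(R @ torch.tensor([1.0, 0.0]))  # a 90° rotation maps (1, 0) ~ to (0, 1)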
models/voting.py
ADDED
@@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Optional, Tuple

import numpy as np
import torch
from torch.fft import irfftn, rfftn
from torch.nn.functional import grid_sample, log_softmax, pad

from .metrics import angle_error
from .utils import make_grid, rotmat2d
from torchvision.transforms.functional import rotate


class UAVTemplateSamplerFast(torch.nn.Module):
    def __init__(self, num_rotations, w=128, optimize=True):
        super().__init__()

        h, w = w, w
        grid_xy = make_grid(
            w=w,
            h=h,
            step_x=1,
            step_y=1,
            orig_y=-h // 2,
            orig_x=-h // 2,
            y_up=True,
        ).cuda()

        if optimize:
            assert (num_rotations % 4) == 0
            angles = torch.arange(0, 90, 90 / (num_rotations // 4)).cuda()
        else:
            angles = torch.arange(
                0, 360, 360 / num_rotations, device=grid_xy.device
            )  # fixed: was grid_xz_bev.device, which is undefined in this class
        rotmats = rotmat2d(angles / 180 * np.pi)
        grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy)

        grid_ij_rot = (grid_xy_rot - grid_xy[..., :1, :1, :]) * grid_xy.new_tensor(
            [1, -1]
        )
        grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1

        self.optimize = optimize
        self.num_rots = num_rotations
        self.register_buffer("angles", angles, persistent=False)
        self.register_buffer("grid_norm", grid_norm, persistent=False)

    def forward(self, image_bev):
        grid = self.grid_norm
        b, c = image_bev.shape[:2]
        n, h, w = grid.shape[:3]
        grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2)
        image = (
            image_bev[:, None]
            .repeat_interleave(n, 1)
            .reshape(b * n, *image_bev.shape[1:])
        )
        # print(image.shape, grid.shape, self.grid_norm.shape)
        kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape(
            b, n, c, h, w
        )

        if self.optimize:  # we have computed only the first quadrant
            kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)]
            kernels = torch.cat([kernels] + kernels_quad234, 1)

        return kernels


class UAVTemplateSampler(torch.nn.Module):
    def __init__(self, num_rotations):
        super().__init__()

        self.num_rotations = num_rotations

    def Template(self, input_features):
        # Number of rotation angles
        num_angles = self.num_rotations
        # Insert a second dimension holding one copy per rotation angle
        input_shape = torch.tensor(input_features.shape)
        output_shape = torch.cat(
            (input_shape[:1], torch.tensor([num_angles]), input_shape[1:])
        ).tolist()
        expanded_features = torch.zeros(output_shape, device=input_features.device)

        # Generate the sequence of rotation angles
        # (fixed: previously hard-coded 64 instead of num_angles)
        rotation_angles = torch.linspace(360, 0, num_angles + 1)[:-1]
        # rotation_angles = torch.flip(rotation_angles, dims=[0])
        # Apply each rotation angle to the input features
        for i in range(len(rotation_angles)):
            # print(rotation_angles[i].item())
            rotated_feature = rotate(input_features, rotation_angles[i].item(), fill=0)
            expanded_features[:, i, :, :, :] = rotated_feature

        # The rotated copies stacked along dim 1 form the final output volume
        # output_features = torch.stack(rotated_features, dim=1)
        # Output shape: [B, num_angles, C, H, W]
        return expanded_features

    def forward(self, image_bev):
        kernels = self.Template(image_bev)
        return kernels


class TemplateSampler(torch.nn.Module):
    def __init__(self, grid_xz_bev, ppm, num_rotations, optimize=True):
        super().__init__()

        Δ = 1 / ppm
        h, w = grid_xz_bev.shape[:2]
        ksize = max(w, h * 2 + 1)
        radius = ksize * Δ
        grid_xy = make_grid(
            radius,
            radius,
            step_x=Δ,
            step_y=Δ,
            orig_y=(Δ - radius) / 2,
            orig_x=(Δ - radius) / 2,
            y_up=True,
        )

        if optimize:
            assert (num_rotations % 4) == 0
            angles = torch.arange(
                0, 90, 90 / (num_rotations // 4), device=grid_xz_bev.device
            )
        else:
            angles = torch.arange(
                0, 360, 360 / num_rotations, device=grid_xz_bev.device
            )
        rotmats = rotmat2d(angles / 180 * np.pi)
        grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy)

        grid_ij_rot = (grid_xy_rot - grid_xz_bev[..., :1, :1, :]) * grid_xy.new_tensor(
            [1, -1]
        )
        grid_ij_rot = grid_ij_rot / Δ
        grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1

        self.optimize = optimize
        self.num_rots = num_rotations
        self.register_buffer("angles", angles, persistent=False)
        self.register_buffer("grid_norm", grid_norm, persistent=False)

    def forward(self, image_bev):
        grid = self.grid_norm
        b, c = image_bev.shape[:2]
        n, h, w = grid.shape[:3]
        grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2)
        image = (
            image_bev[:, None]
            .repeat_interleave(n, 1)
            .reshape(b * n, *image_bev.shape[1:])
        )
        kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape(
            b, n, c, h, w
        )

        if self.optimize:  # we have computed only the first quadrant
            kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)]
            kernels = torch.cat([kernels] + kernels_quad234, 1)

        return kernels

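UAVTemplateSampler simply rotates the whole aerial feature map once per yaw bin; the resulting stack feeds conv2d_fft_batchwise below as correlation kernels. A shape sketch with the 64 rotations from conf/maplocnet.yaml (tensor sizes are illustrative):

    import torch
    from models.voting import UAVTemplateSampler

    sampler = UAVTemplateSampler(num_rotations=64)
    f_image = torch.rand(2, 8, 128, 128)            # B, C, H, W
    templates = sampler(f_image)
    assert templates.shape == (2, 64, 8, 128, 128)  # B, N, C, H, W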
def conv2d_fft_batchwise(signal, kernel, padding="same", padding_mode="constant"):
    if padding == "same":
        padding = [i // 2 for i in kernel.shape[-2:]]
    padding_signal = [p for p in padding[::-1] for _ in range(2)]
    signal = pad(signal, padding_signal, mode=padding_mode)
    assert signal.size(-1) % 2 == 0

    padding_kernel = [
        pad for i in [1, 2] for pad in [0, signal.size(-i) - kernel.size(-i)]
    ]
    kernel_padded = pad(kernel, padding_kernel)

    signal_fr = rfftn(signal, dim=(-1, -2))
    kernel_fr = rfftn(kernel_padded, dim=(-1, -2))

    kernel_fr.imag *= -1  # flip the kernel
    output_fr = torch.einsum("bc...,bdc...->bd...", signal_fr, kernel_fr)
    output = irfftn(output_fr, dim=(-1, -2))

    crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [
        slice(0, (signal.size(i) - kernel.size(i) + 1)) for i in [-2, -1]
    ]
    output = output[crop_slices].contiguous()

    return output


class SparseMapSampler(torch.nn.Module):
    def __init__(self, num_rotations):
        super().__init__()
        # fixed: was 360 / self.conf.num_rotations, but this class has no conf
        angles = torch.arange(0, 360, 360 / num_rotations)
        rotmats = rotmat2d(angles / 180 * np.pi)
        self.num_rotations = num_rotations
        self.register_buffer("rotmats", rotmats, persistent=False)

    def forward(self, image_map, p2d_bev):
        h, w = image_map.shape[-2:]
        locations = make_grid(w, h, device=p2d_bev.device)
        p2d_candidates = torch.einsum(
            "kji,...i->...kj", self.rotmats.to(p2d_bev), p2d_bev
        )  # fixed: removed a stray comma in the einsum equation
        p2d_candidates = p2d_candidates[..., None, None, :, :] + locations.unsqueeze(-1)
        # ... x N x W x H x K x 2

        p2d_norm = (p2d_candidates / (image_map.new_tensor([w, h]) - 1)) * 2 - 1
        valid = torch.all((p2d_norm >= -1) & (p2d_norm <= 1), -1)
        value = grid_sample(
            image_map, p2d_norm.flatten(-4, -2), align_corners=True, mode="bilinear"
        )
        value = value.reshape(image_map.shape[:2] + valid.shape[-4:])  # fixed: shape[-4:]
        return valid, value


def sample_xyr(volume, xy_grid, angle_grid, nearest_for_inf=False):
    # (B, C, H, W, N) to (B, C, H, W, N+1)
    volume_padded = pad(volume, [0, 1, 0, 0, 0, 0], mode="circular")

    size = xy_grid.new_tensor(volume.shape[-3:-1][::-1])
    xy_norm = xy_grid / (size - 1)  # align_corners=True
    angle_norm = (angle_grid / 360) % 1
    grid = torch.concat([angle_norm.unsqueeze(-1), xy_norm], -1)
    grid_norm = grid * 2 - 1

    valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
    value = grid_sample(volume_padded, grid_norm, align_corners=True, mode="bilinear")

    # If one of the values used for linear interpolation is infinite,
    # we fall back to nearest to avoid propagating inf.
    if nearest_for_inf:
        value_nearest = grid_sample(
            volume_padded, grid_norm, align_corners=True, mode="nearest"
        )
        value = torch.where(~torch.isfinite(value) & valid, value_nearest, value)

    return value, valid


def nll_loss_xyr(log_probs, xy, angle):
    log_prob, _ = sample_xyr(
        log_probs.unsqueeze(1), xy[:, None, None, None], angle[:, None, None, None]
    )
    nll = -log_prob.reshape(-1)  # remove C,H,W,N
    return nll


def nll_loss_xyr_smoothed(log_probs, xy, angle, sigma_xy, sigma_r, mask=None):
    *_, nx, ny, nr = log_probs.shape
    grid_x = torch.arange(nx, device=log_probs.device, dtype=torch.float)
    dx = (grid_x - xy[..., None, 0]) / sigma_xy
    grid_y = torch.arange(ny, device=log_probs.device, dtype=torch.float)
    dy = (grid_y - xy[..., None, 1]) / sigma_xy
    dr = (
        torch.arange(0, 360, 360 / nr, device=log_probs.device, dtype=torch.float)
        - angle[..., None]
    ) % 360
    dr = torch.minimum(dr, 360 - dr) / sigma_r
    diff = (
        dx[..., None, :, None] ** 2
        + dy[..., :, None, None] ** 2
        + dr[..., None, None, :] ** 2
    )
    pdf = torch.exp(-diff / 2)
    if mask is not None:
        pdf.masked_fill_(~mask[..., None], 0)
        log_probs = log_probs.masked_fill(~mask[..., None], 0)
    pdf /= pdf.sum((-1, -2, -3), keepdim=True)
    return -torch.sum(pdf * log_probs.to(torch.float), dim=(-1, -2, -3))


def log_softmax_spatial(x, dims=3):
    return log_softmax(x.flatten(-dims), dim=-1).reshape(x.shape)


@torch.jit.script
def argmax_xy(scores: torch.Tensor) -> torch.Tensor:
    indices = scores.flatten(-2).max(-1).indices
    width = scores.shape[-1]
    x = indices % width
    y = torch.div(indices, width, rounding_mode="floor")
    return torch.stack((x, y), -1)


@torch.jit.script
def expectation_xy(prob: torch.Tensor) -> torch.Tensor:
    h, w = prob.shape[-2:]
    grid = make_grid(float(w), float(h), device=prob.device).to(prob)
    return torch.einsum("...hw,hwd->...d", prob, grid)


@torch.jit.script
def expectation_xyr(
    prob: torch.Tensor, covariance: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    h, w, num_rotations = prob.shape[-3:]
    x, y = torch.meshgrid(
        [
            torch.arange(w, device=prob.device, dtype=prob.dtype),
            torch.arange(h, device=prob.device, dtype=prob.dtype),
        ],
        indexing="xy",
    )
    grid_xy = torch.stack((x, y), -1)
    xy_mean = torch.einsum("...hwn,hwd->...d", prob, grid_xy)

    angles = torch.arange(0, 1, 1 / num_rotations, device=prob.device, dtype=prob.dtype)
    angles = angles * 2 * np.pi
    grid_cs = torch.stack([torch.cos(angles), torch.sin(angles)], -1)
    cs_mean = torch.einsum("...hwn,nd->...d", prob, grid_cs)
    angle = torch.atan2(cs_mean[..., 1], cs_mean[..., 0])
    angle = (angle * 180 / np.pi) % 360

    if covariance:
        xy_cov = torch.einsum("...hwn,...hwd,...hwk->...dk", prob, grid_xy, grid_xy)
        xy_cov = xy_cov - torch.einsum("...d,...k->...dk", xy_mean, xy_mean)
    else:
        xy_cov = None

    xyr_mean = torch.cat((xy_mean, angle.unsqueeze(-1)), -1)
    return xyr_mean, xy_cov


@torch.jit.script
def argmax_xyr(scores: torch.Tensor) -> torch.Tensor:
    indices = scores.flatten(-3).max(-1).indices
    width, num_rotations = scores.shape[-2:]
    wr = width * num_rotations
    y = torch.div(indices, wr, rounding_mode="floor")
    x = torch.div(indices % wr, num_rotations, rounding_mode="floor")
    angle_index = indices % num_rotations
    angle = angle_index * 360 / num_rotations
    xyr = torch.stack((x, y, angle), -1)
    return xyr


@torch.jit.script
def mask_yaw_prior(
    scores: torch.Tensor, yaw_prior: torch.Tensor, num_rotations: int
) -> torch.Tensor:
    step = 360 / num_rotations
    step_2 = step / 2
    angles = torch.arange(step_2, 360 + step_2, step, device=scores.device)
    yaw_init, yaw_range = yaw_prior.chunk(2, dim=-1)
    rot_mask = angle_error(angles, yaw_init) < yaw_range
    return scores.masked_fill_(~rot_mask[:, None, None], -np.inf)


def fuse_gps(log_prob, uv_gps, ppm, sigma=10, gaussian=False):
    grid = make_grid(*log_prob.shape[-3:-1][::-1]).to(log_prob)
    dist = torch.sum((grid - uv_gps) ** 2, -1)
    sigma_pixel = sigma * ppm
    if gaussian:
        gps_log_prob = -1 / 2 * dist / sigma_pixel**2
    else:
        gps_log_prob = torch.where(dist < sigma_pixel**2, 1, -np.inf)
    log_prob_fused = log_softmax_spatial(log_prob + gps_log_prob.unsqueeze(-1))
    return log_prob_fused
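conv2d_fft_batchwise correlates every template with the map in one frequency-domain pass: conjugating the kernel spectrum (kernel_fr.imag *= -1) turns FFT convolution into cross-correlation. A small equivalence check against direct correlation for a single template (N=1); sizes are illustrative and the tolerance accounts for float32 FFT error:

    import torch
    import torch.nn.functional as F
    from models.voting import conv2d_fft_batchwise

    signal = torch.rand(1, 8, 32, 32)   # map features: (B, C, H, W)
    kernel = torch.rand(1, 1, 8, 5, 5)  # templates: (B, N, C, h, w)
    out_fft = conv2d_fft_batchwise(signal, kernel)

    # Direct reference: PyTorch's conv2d already computes cross-correlation,
    # so no kernel flip is needed on this side.
    out_direct = F.conv2d(F.pad(signal, [2, 2, 2, 2]), kernel[0])
    assert out_fft.shape == out_direct.shape == (1, 1, 32, 32)
    assert torch.allclose(out_fft, out_direct, atol=1e-3)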
module.py
ADDED
@@ -0,0 +1,171 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from pathlib import Path

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torchmetrics import MeanMetric, MetricCollection

import logger
from models import get_model


class AverageKeyMeter(MeanMetric):
    def __init__(self, key, *args, **kwargs):
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, dict):
        value = dict[self.key]
        value = value[torch.isfinite(value)]
        return super().update(value)


class GenericModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        name = cfg.model.get("name")
        name = "orienternet" if name in ("localizer_bev_depth", None) else name
        self.model = get_model(name)(cfg.model)
        self.cfg = cfg
        self.save_hyperparameters(cfg)

        self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
        self.losses_val = None  # we do not know the loss keys in advance

        # Per-city validation metrics (disabled):
        # self.citys = self.cfg.data.val_citys
        # for i in range(len(self.citys)):
        #     city = self.citys[i]
        #     setattr(self, "metric_vals_{}".format(i),
        #             MetricCollection(self.model.metrics(), prefix="val_{}/".format(city)))
        # self.losse_vals = [None for city in self.cfg.data.val_citys]

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        self.log_dict(
            {f"loss/{k}/train": v.mean() for k, v in losses.items()},
            prog_bar=True,
            rank_zero_only=True,
        )
        return losses["total"].mean()

    # Per-dataloader validation (disabled):
    # def validation_step(self, batch, batch_idx, dataloader_idx):
    #     city = self.citys[dataloader_idx]
    #     pred = self(batch)
    #     losses = self.model.loss(pred, batch)
    #     if hasattr(self, "losse_val_{}".format(dataloader_idx)) is False:
    #         setattr(self, "losse_val_{}".format(dataloader_idx), MetricCollection(
    #             {k: AverageKeyMeter(k).to(self.device) for k in losses},
    #             prefix="loss_{}/".format(city),
    #             postfix="/val_{}".format(city),
    #         ))
    #     getattr(self, "metric_vals_{}".format(dataloader_idx))(pred, batch)
    #     self.log_dict(getattr(self, "metric_vals_{}".format(dataloader_idx))(pred, batch), sync_dist=True)
    #     getattr(self, "losse_val_{}".format(dataloader_idx)).update(losses)
    #     self.log_dict(getattr(self, "losse_val_{}".format(dataloader_idx)).compute(), sync_dist=True)

    def validation_step(self, batch, batch_idx):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        if self.losses_val is None:
            self.losses_val = MetricCollection(
                {k: AverageKeyMeter(k).to(self.device) for k in losses},
                prefix="loss/",
                postfix="/val",
            )
        self.metrics_val(pred, batch)
        self.log_dict(self.metrics_val, sync_dist=True)
        self.losses_val.update(losses)
        self.log_dict(self.losses_val, sync_dist=True)

    def on_validation_epoch_start(self):
        # Was `validation_epoch_start(self, batch)`, which is not a Lightning
        # hook and was never called; the meters must be reset here so their
        # keys are rebuilt on the first validation step.
        self.losses_val = None
        # self.losse_val = [None for city in self.cfg.data.val_citys]

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
        ret = {"optimizer": optimizer}
        cfg_scheduler = self.cfg.training.get("lr_scheduler")
        if cfg_scheduler is not None:
            scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
                optimizer=optimizer, **cfg_scheduler.get("args", {})
            )
            ret["lr_scheduler"] = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "loss/total/val",
                "strict": True,
                "name": "learning_rate",
            }
        return ret

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=None,
        hparams_file=None,
        strict=True,
        cfg=None,
        find_best=False,
    ):
        assert hparams_file is None, "hparams are not supported."

        checkpoint = torch.load(
            checkpoint_path, map_location=map_location or (lambda storage, loc: storage)
        )
        if find_best:
            best_score, best_name = None, None
            modes = {"min": torch.lt, "max": torch.gt}
            for key, state in checkpoint["callbacks"].items():
                if not key.startswith("ModelCheckpoint"):
                    continue
                mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
                if best_score is None or modes[mode](
                    state["best_model_score"], best_score
                ):
                    best_score = state["best_model_score"]
                    best_name = Path(state["best_model_path"]).name
            logger.info("Loading best checkpoint %s", best_name)
            if best_name != checkpoint_path:
                return cls.load_from_checkpoint(
                    Path(checkpoint_path).parent / best_name,
                    map_location,
                    hparams_file,
                    strict,
                    cfg,
                    find_best=False,
                )

        logger.info(
            "Using checkpoint %s from epoch %d and step %d.",
            checkpoint_path.name,
            checkpoint["epoch"],
            checkpoint["global_step"],
        )
        cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
        if list(cfg_ckpt.keys()) == ["cfg"]:  # backward compatibility
            cfg_ckpt = cfg_ckpt["cfg"]
        cfg_ckpt = OmegaConf.create(cfg_ckpt)

        if cfg is None:
            cfg = {}
        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)
        with open_dict(cfg_ckpt):
            cfg = OmegaConf.merge(cfg_ckpt, cfg)

        return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)
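A hedged sketch (not part of the commit) of how GenericModule is typically constructed; the cfg keys besides model.name mirror conf/maplocnet.yaml, and training.lr is an assumption read by configure_optimizers:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "model": {"name": "orienternet"},  # resolved through models.get_model
        "training": {"lr": 1e-4},  # consumed by configure_optimizers
    }
)
# Instantiation needs the complete model config (see conf/maplocnet.yaml):
# module = GenericModule(cfg)
# Or restore from a checkpoint, picking the best ModelCheckpoint entry
# (the path below is hypothetical):
# module = GenericModule.load_from_checkpoint(Path("weight/last.ckpt"), find_best=True)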
osm/analysis.py
ADDED
@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from collections import Counter, defaultdict
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go

from .parser import (
    filter_area,
    filter_node,
    filter_way,
    match_to_group,
    parse_area,
    parse_node,
    parse_way,
    Patterns,
)
from .reader import OSMData


def recover_hierarchy(counter: Counter) -> Dict:
    """Recover a two-level hierarchy from the flat group labels."""
    groups = defaultdict(dict)
    for k, v in sorted(counter.items(), key=lambda x: -x[1]):
        if ":" in k:
            prefix, group = k.split(":")
            if prefix in groups and isinstance(groups[prefix], int):
                # `prefix` was previously stored as a bare count; move that
                # count into the nested dict instead of discarding it (the
                # original lines cleared the dict before reading the count).
                count = groups[prefix]
                groups[prefix] = {}
                groups[prefix][prefix] = count
            groups[prefix][group] = v
        else:
            groups[k] = v
    return dict(groups)


def bar_autolabel(rects, fontsize):
    """Attach a text label above each bar in *rects*, displaying its height."""
    for rect in rects:
        width = rect.get_width()
        plt.gca().annotate(
            f"{width}",
            xy=(width, rect.get_y() + rect.get_height() / 2),
            xytext=(3, 0),  # 3 points vertical offset
            textcoords="offset points",
            ha="left",
            va="center",
            fontsize=fontsize,
        )


def plot_histogram(counts, fontsize, dpi):
    fig, ax = plt.subplots(dpi=dpi, figsize=(8, 20))

    labels = []
    for k, v in counts.items():
        if isinstance(v, dict):
            labels += list(v.keys())
            v = list(v.values())
        else:
            labels.append(k)
            v = [v]
        bars = plt.barh(
            len(labels) - len(v) + np.arange(len(v)), v, height=0.9, label=k
        )
        bar_autolabel(bars, fontsize)

    ax.set_yticklabels(labels, fontsize=fontsize)
    ax.axes.xaxis.set_ticklabels([])
    ax.xaxis.tick_top()
    ax.invert_yaxis()
    plt.yticks(np.arange(len(labels)))
    plt.xscale("log")
    plt.legend(ncol=len(counts), loc="upper center")


def count_elements(elems: Dict[int, str], filter_fn, parse_fn) -> Dict:
    """Count the number of elements in each group."""
    counts = Counter()
    for elem in filter(filter_fn, elems.values()):
        group = parse_fn(elem.tags)
        if group is None:
            continue
        counts[group] += 1
    counts = recover_hierarchy(counts)
    return counts


def plot_osm_histograms(osm: OSMData, fontsize=8, dpi=150):
    counts = count_elements(osm.nodes, filter_node, parse_node)
    plot_histogram(counts, fontsize, dpi)
    plt.title("nodes")

    counts = count_elements(osm.ways, filter_way, parse_way)
    plot_histogram(counts, fontsize, dpi)
    plt.title("ways")

    counts = count_elements(osm.ways, filter_area, parse_area)
    plot_histogram(counts, fontsize, dpi)
    plt.title("areas")


def plot_sankey_hierarchy(osm: OSMData):
    triplets = []
    for node in filter(filter_node, osm.nodes.values()):
        label = parse_node(node.tags)
        if label is None:
            continue
        group = match_to_group(label, Patterns.nodes)
        if group is None:
            group = match_to_group(label, Patterns.ways)
        if group is None:
            group = "null"
        if ":" in label:
            key, tag = label.split(":")
            if tag == "yes":
                tag = key
        else:
            key = tag = label
        triplets.append((key, tag, group))
    keys, tags, groups = list(zip(*triplets))
    counts_key_tag = Counter(zip(keys, tags))
    counts_key_tag_group = Counter(triplets)

    key2tags = defaultdict(set)
    for k, t in zip(keys, tags):
        key2tags[k].add(t)
    key2tags = {k: sorted(t) for k, t in key2tags.items()}
    keytag2group = dict(zip(zip(keys, tags), groups))
    key_names = sorted(set(keys))
    tag_names = [(k, t) for k in key_names for t in key2tags[k]]

    group_names = []
    for k in key_names:
        for t in key2tags[k]:
            g = keytag2group[k, t]
            if g not in group_names and g != "null":
                group_names.append(g)
    group_names += ["null"]

    key2idx = dict(zip(key_names, range(len(key_names))))
    tag2idx = {kt: i + len(key2idx) for i, kt in enumerate(tag_names)}
    group2idx = {n: i + len(key2idx) + len(tag2idx) for i, n in enumerate(group_names)}

    key_counts = Counter(keys)
    key_text = [f"{k} {key_counts[k]}" for k in key_names]
    tag_counts = Counter(list(zip(keys, tags)))
    tag_text = [f"{t} {tag_counts[k, t]}" for k, t in tag_names]
    group_counts = Counter(groups)
    group_text = [f"{k} {group_counts[k]}" for k in group_names]

    fig = go.Figure(
        data=[
            go.Sankey(
                orientation="h",
                node=dict(
                    pad=15,
                    thickness=20,
                    line=dict(color="black", width=0.5),
                    label=key_text + tag_text + group_text,
                    x=[0] * len(key_names)
                    + [1] * len(tag_names)
                    + [2] * len(group_names),
                    color="blue",
                ),
                arrangement="fixed",
                link=dict(
                    source=[key2idx[k] for k, _ in counts_key_tag]
                    + [tag2idx[k, t] for k, t, _ in counts_key_tag_group],
                    target=[tag2idx[k, t] for k, t in counts_key_tag]
                    + [group2idx[g] for _, _, g in counts_key_tag_group],
                    value=list(counts_key_tag.values())
                    + list(counts_key_tag_group.values()),
                ),
            )
        ]
    )
    fig.update_layout(autosize=False, width=800, height=2000, font_size=10)
    fig.show()
    return fig
osm/data.py
ADDED
@@ -0,0 +1,230 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple

import numpy as np

from .parser import (
    filter_area,
    filter_node,
    filter_way,
    match_to_group,
    parse_area,
    parse_node,
    parse_way,
    Patterns,
)
from .reader import OSMData, OSMNode, OSMRelation, OSMWay


logger = logging.getLogger(__name__)


def glue(ways: List[OSMWay]) -> List[List[OSMNode]]:
    result: List[List[OSMNode]] = []
    to_process: Set[Tuple[OSMNode]] = set()

    for way in ways:
        if way.is_cycle():
            result.append(way.nodes)
        else:
            to_process.add(tuple(way.nodes))

    while to_process:
        nodes: List[OSMNode] = list(to_process.pop())
        glued: Optional[List[OSMNode]] = None
        other_nodes: Optional[Tuple[OSMNode]] = None

        for other_nodes in to_process:
            glued = try_to_glue(nodes, list(other_nodes))
            if glued is not None:
                break

        if glued is not None:
            to_process.remove(other_nodes)
            if is_cycle(glued):
                result.append(glued)
            else:
                to_process.add(tuple(glued))
        else:
            result.append(nodes)

    return result


def is_cycle(nodes: List[OSMNode]) -> bool:
    """Is way a cycle way or an area boundary."""
    return nodes[0] == nodes[-1]


def try_to_glue(nodes: List[OSMNode], other: List[OSMNode]) -> Optional[List[OSMNode]]:
    """Create new combined way if ways share endpoints."""
    if nodes[0] == other[0]:
        return list(reversed(other[1:])) + nodes
    if nodes[0] == other[-1]:
        return other[:-1] + nodes
    if nodes[-1] == other[-1]:
        return nodes + list(reversed(other[:-1]))
    if nodes[-1] == other[0]:
        return nodes + other[1:]
    return None


def multipolygon_from_relation(rel: OSMRelation, osm: OSMData):
    inner_ways = []
    outer_ways = []
    for member in rel.members:
        if member.type_ == "way":
            if member.role == "inner":
                if member.ref in osm.ways:
                    inner_ways.append(osm.ways[member.ref])
            elif member.role == "outer":
                if member.ref in osm.ways:
                    outer_ways.append(osm.ways[member.ref])
            else:
                logger.warning(f'Unknown member role "{member.role}".')
    if outer_ways:
        inners_path = glue(inner_ways)
        outers_path = glue(outer_ways)
        # Return outers first to match the `outers, inners` unpacking in
        # MapArea.from_relation (the original returned them swapped).
        return outers_path, inners_path


@dataclass
class MapElement:
    id_: int
    label: str
    group: str
    tags: Optional[Dict[str, str]]


@dataclass
class MapNode(MapElement):
    xy: np.ndarray

    @classmethod
    def from_osm(cls, node: OSMNode, label: str, group: str):
        return cls(
            node.id_,
            label,
            group,
            node.tags,
            xy=node.xy,
        )


@dataclass
class MapLine(MapElement):
    xy: np.ndarray

    @classmethod
    def from_osm(cls, way: OSMWay, label: str, group: str):
        xy = np.stack([n.xy for n in way.nodes])
        return cls(
            way.id_,
            label,
            group,
            way.tags,
            xy=xy,
        )


@dataclass
class MapArea(MapElement):
    outers: List[np.ndarray]
    inners: List[np.ndarray] = field(default_factory=list)

    @classmethod
    def from_relation(cls, rel: OSMRelation, label: str, group: str, osm: OSMData):
        outers_inners = multipolygon_from_relation(rel, osm)
        if outers_inners is None:
            return None
        outers, inners = outers_inners
        outers = [np.stack([n.xy for n in way]) for way in outers]
        inners = [np.stack([n.xy for n in way]) for way in inners]
        return cls(
            rel.id_,
            label,
            group,
            rel.tags,
            outers=outers,
            inners=inners,
        )

    @classmethod
    def from_way(cls, way: OSMWay, label: str, group: str):
        xy = np.stack([n.xy for n in way.nodes])
        return cls(
            way.id_,
            label,
            group,
            way.tags,
            outers=[xy],
        )


class MapData:
    def __init__(self):
        self.nodes: Dict[int, MapNode] = {}
        self.lines: Dict[int, MapLine] = {}
        self.areas: Dict[int, MapArea] = {}

    @classmethod
    def from_osm(cls, osm: OSMData):
        self = cls()

        for node in filter(filter_node, osm.nodes.values()):
            label = parse_node(node.tags)
            if label is None:
                continue
            group = match_to_group(label, Patterns.nodes)
            if group is None:
                group = match_to_group(label, Patterns.ways)
            if group is None:
                continue  # missing
            self.nodes[node.id_] = MapNode.from_osm(node, label, group)

        for way in filter(filter_way, osm.ways.values()):
            label = parse_way(way.tags)
            if label is None:
                continue
            group = match_to_group(label, Patterns.ways)
            if group is None:
                group = match_to_group(label, Patterns.nodes)
            if group is None:
                continue  # missing
            self.lines[way.id_] = MapLine.from_osm(way, label, group)

        for area in filter(filter_area, osm.ways.values()):
            label = parse_area(area.tags)
            if label is None:
                continue
            group = match_to_group(label, Patterns.areas)
            if group is None:
                group = match_to_group(label, Patterns.ways)
            if group is None:
                group = match_to_group(label, Patterns.nodes)
            if group is None:
                continue  # missing
            self.areas[area.id_] = MapArea.from_way(area, label, group)

        for rel in osm.relations.values():
            if rel.tags.get("type") != "multipolygon":
                continue
            label = parse_area(rel.tags)
            if label is None:
                continue
            group = match_to_group(label, Patterns.areas)
            if group is None:
                group = match_to_group(label, Patterns.ways)
            if group is None:
                group = match_to_group(label, Patterns.nodes)
            if group is None:
                continue  # missing
            area = MapArea.from_relation(rel, label, group, osm)
            assert rel.id_ not in self.areas  # not sure if there can be collision
            if area is not None:
                self.areas[rel.id_] = area

        return self
osm/download.py
ADDED
@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import json
from pathlib import Path
from typing import Dict, Optional

import urllib3

from utils.geo import BoundaryBox
import urllib.request
import requests


def get_osm(
    boundary_box: BoundaryBox,
    cache_path: Optional[Path] = None,
    overwrite: bool = False,
) -> Dict:
    if not overwrite and cache_path is not None and cache_path.is_file():
        with cache_path.open() as fp:
            return json.load(fp)

    (bottom, left), (top, right) = boundary_box.min_, boundary_box.max_
    content: bytes = get_web_data(
        # "https://api.openstreetmap.org/api/0.6/map.json",
        "https://openstreetmap.erniubot.live/api/0.6/map.json",
        # 'https://overpass-api.de/api/map',
        # 'http://localhost:29505/api/map',
        # "https://lz4.overpass-api.de/api/interpreter",
        {"bbox": f"{left},{bottom},{right},{top}"},
    )

    content_str = content.decode("utf-8")
    if content_str.startswith("You requested too many nodes"):
        raise ValueError(content_str)

    if cache_path is not None:
        with cache_path.open("bw+") as fp:
            fp.write(content)
    return json.loads(content_str)


# Note: this first definition is shadowed by the retrying version below.
def get_web_data(address: str, parameters: Dict[str, str]) -> bytes:
    # logger.info("Getting %s...", address)
    # proxy_address = "http://107.173.122.186:3128"
    #
    # # Set the proxy server address and port.
    # proxies = {
    #     'http': proxy_address,
    #     'https': proxy_address
    # }

    # Send the GET request and return the response data.
    # response = requests.get(address, params=parameters, timeout=100, proxies=proxies)
    print('url:', address)
    response = requests.get(address, params=parameters, timeout=100)
    return response.content


def get_web_data(address: str, parameters: Dict[str, str]) -> bytes:
    # logger.info("Getting %s...", address)
    # Retry until the request succeeds; transient network errors are common
    # when fetching large OSM extracts.
    while True:
        try:
            response = requests.get(address, params=parameters, timeout=100)
            request = requests.Request('GET', address, params=parameters)
            prepared_request = request.prepare()
            # Get the full URL (kept for debugging; otherwise unused).
            full_url = prepared_request.url
            break

        except Exception as e:
            # Print the error and retry.
            print(f"Error: {e}")
            print("Retrying...")

    return response.content


# def get_web_data_2(address: str, parameters: Dict[str, str]) -> bytes:
#     # logger.info("Getting %s...", address)
#     proxy_address = "http://107.173.122.186:3128"
#     http = urllib3.PoolManager(proxy_url=proxy_address)
#     result = http.request("GET", address, parameters, timeout=100)
#     return result.data
#
#
# def get_web_data_1(address: str, parameters: Dict[str, str]) -> bytes:
#     # Set the proxy server address and port.
#     proxy_address = "http://107.173.122.186:3128"
#     # Create a ProxyHandler object.
#     proxy_handler = urllib.request.ProxyHandler({'http': proxy_address})
#     # Build the query string.
#     query_string = urllib.parse.urlencode(parameters)
#     # Build the full URL.
#     url = address + '?' + query_string
#     print(url)
#     # Create an OpenerDirector that routes requests through the proxy.
#     opener = urllib.request.build_opener(proxy_handler)
#     # Send the request through the opener.
#     response = opener.open(url)
#     # Or send the GET request directly (no proxy):
#     # response = urllib.request.urlopen(url, timeout=100)
#     # Read the response content.
#     data = response.read()
#     return data
osm/parser.py
ADDED
@@ -0,0 +1,255 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
import re
from typing import List

from .reader import OSMData, OSMElement, OSMNode, OSMWay

IGNORE_TAGS = {"source", "phone", "entrance", "inscription", "note", "name"}


def parse_levels(string: str) -> List[float]:
    """Parse string representation of level sequence value."""
    try:
        cleaned = string.replace(",", ";").replace(" ", "")
        return list(map(float, cleaned.split(";")))
    except ValueError:
        logging.debug("Cannot parse level description from `%s`.", string)
        return []


def filter_level(elem: OSMElement):
    level = elem.tags.get("level")
    if level is not None:
        levels = parse_levels(level)
        # In the US, ground floor levels are sometimes marked as level=1
        # so let's be conservative and include it.
        if not (0 in levels or 1 in levels):
            return False
    layer = elem.tags.get("layer")
    if layer is not None:
        layer = parse_levels(layer)
        if len(layer) > 0 and max(layer) < 0:
            return False
    return (
        elem.tags.get("location") != "underground"
        and elem.tags.get("parking") != "underground"
    )


def filter_node(node: OSMNode):
    return len(node.tags.keys() - IGNORE_TAGS) > 0 and filter_level(node)


def is_area(way: OSMWay):
    if way.nodes[0] != way.nodes[-1]:
        return False
    if way.tags.get("area") == "no":
        return False
    filters = [
        "area",
        "building",
        "amenity",
        "indoor",
        "landuse",
        "landcover",
        "leisure",
        "public_transport",
        "shop",
    ]
    for f in filters:
        if f in way.tags and way.tags.get(f) != "no":
            return True
    if way.tags.get("natural") in {"wood", "grassland", "water"}:
        return True
    return False


def filter_area(way: OSMWay):
    return len(way.tags.keys() - IGNORE_TAGS) > 0 and is_area(way) and filter_level(way)


def filter_way(way: OSMWay):
    return not filter_area(way) and way.tags != {} and filter_level(way)


def parse_node(tags):
    keys = tags.keys()
    for key in [
        "amenity",
        "natural",
        "highway",
        "barrier",
        "shop",
        "tourism",
        "public_transport",
        "emergency",
        "man_made",
    ]:
        if key in keys:
            if "disused" in tags[key]:
                continue
            return f"{key}:{tags[key]}"
    return None


def parse_area(tags):
    if "building" in tags:
        group = "building"
        kind = tags["building"]
        if kind == "yes":
            for key in ["amenity", "tourism"]:
                if key in tags:
                    kind = tags[key]
                    break
        if kind != "yes":
            group += f":{kind}"
        return group
    if "area:highway" in tags:
        return f'highway:{tags["area:highway"]}'
    for key in [
        "amenity",
        "landcover",
        "leisure",
        "shop",
        "highway",
        "tourism",
        "natural",
        "waterway",
        "landuse",
    ]:
        if key in tags:
            return f"{key}:{tags[key]}"
    return None


def parse_way(tags):
    keys = tags.keys()
    for key in ["highway", "barrier", "natural"]:
        if key in keys:
            return f"{key}:{tags[key]}"
    return None


def match_to_group(label, patterns):
    for group, pattern in patterns.items():
        if re.match(pattern, label):
            return group
    return None


class Patterns:
    areas = dict(
        building="building($|:.*?)*",
        parking="amenity:parking",
        playground="leisure:(playground|pitch)",
        grass="(landuse:grass|landcover:grass|landuse:meadow|landuse:flowerbed|natural:grassland)",
        park="leisure:(park|garden|dog_park)",
        forest="(landuse:forest|natural:wood)",
        water="(natural:water|waterway:*)",
    )
    # + ways: road, path
    # + node: fountain, bicycle_parking

    ways = dict(
        fence="barrier:(fence|yes)",
        wall="barrier:(wall|retaining_wall)",
        hedge="barrier:hedge",
        kerb="barrier:kerb",
        building_outline="building($|:.*?)*",
        cycleway="highway:cycleway",
        path="highway:(pedestrian|footway|steps|path|corridor)",
        road="highway:(motorway|trunk|primary|secondary|tertiary|service|construction|track|unclassified|residential|.*_link)",
        busway="highway:busway",
        tree_row="natural:tree_row",  # maybe merge with node?
    )
    # + nodes: bollard

    nodes = dict(
        tree="natural:tree",
        stone="(natural:stone|barrier:block)",
        crossing="highway:crossing",
        lamp="highway:street_lamp",
        traffic_signal="highway:traffic_signals",
        bus_stop="highway:bus_stop",
        stop_sign="highway:stop",
        junction="highway:motorway_junction",
        bus_stop_position="public_transport:stop_position",
        gate="barrier:(gate|lift_gate|swing_gate|cycle_barrier)",
        bollard="barrier:bollard",
        shop="(shop.*?|amenity:(bank|post_office))",
        restaurant="amenity:(restaurant|fast_food)",
        bar="amenity:(cafe|bar|pub|biergarten)",
        pharmacy="amenity:pharmacy",
        fuel="amenity:fuel",
        bicycle_parking="amenity:(bicycle_parking|bicycle_rental)",
        charging_station="amenity:charging_station",
        parking_entrance="amenity:parking_entrance",
        atm="amenity:atm",
        toilets="amenity:toilets",
        vending_machine="amenity:vending_machine",
        fountain="amenity:fountain",
        waste_basket="amenity:(waste_basket|waste_disposal)",
        bench="amenity:bench",
        post_box="amenity:post_box",
        artwork="tourism:artwork",
        recycling="amenity:recycling",
        give_way="highway:give_way",
        clock="amenity:clock",
        fire_hydrant="emergency:fire_hydrant",
        pole="man_made:(flagpole|utility_pole)",
        street_cabinet="man_made:street_cabinet",
    )
    # + ways: kerb


class Groups:
    areas = list(Patterns.areas)
    ways = list(Patterns.ways)
    nodes = list(Patterns.nodes)


def group_elements(osm: OSMData):
    elem2group = {
        "area": {},
        "way": {},
        "node": {},
    }

    for node in filter(filter_node, osm.nodes.values()):
        label = parse_node(node.tags)
        if label is None:
            continue
        group = match_to_group(label, Patterns.nodes)
        if group is None:
            group = match_to_group(label, Patterns.ways)
        if group is None:
            continue  # missing
        elem2group["node"][node.id_] = group

    for way in filter(filter_way, osm.ways.values()):
        label = parse_way(way.tags)
        if label is None:
            continue
        group = match_to_group(label, Patterns.ways)
        if group is None:
            group = match_to_group(label, Patterns.nodes)
        if group is None:
            continue  # missing
        elem2group["way"][way.id_] = group

    for area in filter(filter_area, osm.ways.values()):
        label = parse_area(area.tags)
        if label is None:
            continue
        group = match_to_group(label, Patterns.areas)
        if group is None:
            group = match_to_group(label, Patterns.ways)
        if group is None:
            group = match_to_group(label, Patterns.nodes)
        if group is None:
            continue  # missing
        elem2group["area"][area.id_] = group

    return elem2group
osm/raster.py
ADDED
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Dict, List

import cv2
import numpy as np
import torch

from utils.geo import BoundaryBox
from .data import MapArea, MapLine, MapNode
from .parser import Groups


class Canvas:
    def __init__(self, bbox: BoundaryBox, ppm: float):
        self.bbox = bbox
        self.ppm = ppm
        self.scaling = bbox.size * ppm
        self.w, self.h = np.ceil(self.scaling).astype(int)
        self.clear()

    def clear(self):
        self.raster = np.zeros((self.h, self.w), np.uint8)

    def to_uv(self, xy: np.ndarray):
        xy = self.bbox.normalize(xy)
        xy[..., 1] = 1 - xy[..., 1]
        s = self.scaling
        if isinstance(xy, torch.Tensor):
            s = torch.from_numpy(s).to(xy)
        return xy * s - 0.5

    def to_xy(self, uv: np.ndarray):
        s = self.scaling
        if isinstance(uv, torch.Tensor):
            s = torch.from_numpy(s).to(uv)
        xy = (uv + 0.5) / s
        xy[..., 1] = 1 - xy[..., 1]
        return self.bbox.unnormalize(xy)

    def draw_polygon(self, xy: np.ndarray):
        uv = self.to_uv(xy)
        cv2.fillPoly(self.raster, uv[None].astype(np.int32), 255)

    def draw_multipolygon(self, xys: List[np.ndarray]):
        uvs = [self.to_uv(xy).round().astype(np.int32) for xy in xys]
        cv2.fillPoly(self.raster, uvs, 255)

    def draw_line(self, xy: np.ndarray, width: float = 1):
        uv = self.to_uv(xy)
        cv2.polylines(
            self.raster, uv[None].round().astype(np.int32), False, 255, thickness=width
        )

    def draw_cell(self, xy: np.ndarray):
        if not self.bbox.contains(xy):
            return
        uv = self.to_uv(xy)
        self.raster[tuple(uv.round().astype(int).T[::-1])] = 255


def render_raster_masks(
    nodes: List[MapNode],
    lines: List[MapLine],
    areas: List[MapArea],
    canvas: Canvas,
) -> Dict[str, np.ndarray]:
    all_groups = Groups.areas + Groups.ways + Groups.nodes
    masks = {k: np.zeros((canvas.h, canvas.w), np.uint8) for k in all_groups}

    for area in areas:
        canvas.raster = masks[area.group]
        outlines = area.outers + area.inners
        canvas.draw_multipolygon(outlines)
        if area.group == "building":
            canvas.raster = masks["building_outline"]
            for line in outlines:
                canvas.draw_line(line)

    for line in lines:
        canvas.raster = masks[line.group]
        canvas.draw_line(line.xy)

    for node in nodes:
        canvas.raster = masks[node.group]
        canvas.draw_cell(node.xy)

    return masks


def mask_to_idx(group2mask: Dict[str, np.ndarray], groups: List[str]) -> np.ndarray:
    masks = np.stack([group2mask[k] for k in groups]) > 0
    void = ~np.any(masks, 0)
    idx = np.argmax(masks, 0)
    idx = np.where(void, np.zeros_like(idx), idx + 1)  # add background
    return idx


def render_raster_map(masks: Dict[str, np.ndarray]) -> np.ndarray:
    areas = mask_to_idx(masks, Groups.areas)
    ways = mask_to_idx(masks, Groups.ways)
    nodes = mask_to_idx(masks, Groups.nodes)
    return np.stack([areas, ways, nodes])
osm/reader.py
ADDED
@@ -0,0 +1,310 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional

from lxml import etree
import numpy as np

from utils.geo import BoundaryBox, Projection

METERS_PATTERN: re.Pattern = re.compile("^(?P<value>\\d*\\.?\\d*)\\s*m$")
KILOMETERS_PATTERN: re.Pattern = re.compile("^(?P<value>\\d*\\.?\\d*)\\s*km$")
MILES_PATTERN: re.Pattern = re.compile("^(?P<value>\\d*\\.?\\d*)\\s*mi$")


def parse_float(string: str) -> Optional[float]:
    """Parse string representation of a float or integer value."""
    try:
        return float(string)
    except (TypeError, ValueError):
        return None


@dataclass(eq=False)
class OSMElement:
    """
    Something with tags (string to string mapping).
    """

    id_: int
    tags: Dict[str, str]

    def get_float(self, key: str) -> Optional[float]:
        """Parse float from tag value."""
        if key in self.tags:
            return parse_float(self.tags[key])
        return None

    def get_length(self, key: str) -> Optional[float]:
        """Get length in meters."""
        if key not in self.tags:
            return None

        value: str = self.tags[key]

        float_value: float = parse_float(value)
        if float_value is not None:
            return float_value

        for pattern, ratio in [
            (METERS_PATTERN, 1.0),
            (KILOMETERS_PATTERN, 1000.0),
            (MILES_PATTERN, 1609.344),
        ]:
            matcher: re.Match = pattern.match(value)
            if matcher:
                float_value: float = parse_float(matcher.group("value"))
                if float_value is not None:
                    return float_value * ratio

        return None

    def __hash__(self) -> int:
        return self.id_


@dataclass(eq=False)
class OSMNode(OSMElement):
    """
    OpenStreetMap node.

    See https://wiki.openstreetmap.org/wiki/Node
    """

    geo: np.ndarray
    visible: Optional[str] = None
    xy: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, structure: Dict[str, Any]) -> "OSMNode":
        """
        Parse node from Overpass-like structure.

        :param structure: input structure
        """
        return cls(
            structure["id"],
            structure.get("tags", {}),
            geo=np.array((structure["lat"], structure["lon"])),
            visible=structure.get("visible"),
        )


@dataclass(eq=False)
class OSMWay(OSMElement):
    """
    OpenStreetMap way.

    See https://wiki.openstreetmap.org/wiki/Way
    """

    nodes: Optional[List[OSMNode]] = field(default_factory=list)
    visible: Optional[str] = None

    @classmethod
    def from_dict(
        cls, structure: Dict[str, Any], nodes: Dict[int, OSMNode]
    ) -> "OSMWay":
        """
        Parse way from Overpass-like structure.

        :param structure: input structure
        :param nodes: node structure
        """
        return cls(
            structure["id"],
            structure.get("tags", {}),
            [nodes[x] for x in structure["nodes"]],
            visible=structure.get("visible"),
        )

    def is_cycle(self) -> bool:
        """Is way a cycle way or an area boundary."""
        return self.nodes[0] == self.nodes[-1]

    def __repr__(self) -> str:
        return f"Way <{self.id_}> {self.nodes}"


@dataclass
class OSMMember:
    """
    Member of OpenStreetMap relation.
    """

    type_: str
    ref: int
    role: str


@dataclass(eq=False)
class OSMRelation(OSMElement):
    """
    OpenStreetMap relation.

    See https://wiki.openstreetmap.org/wiki/Relation
    """

    members: Optional[List[OSMMember]]
    visible: Optional[str] = None

    @classmethod
    def from_dict(cls, structure: Dict[str, Any]) -> "OSMRelation":
        """
        Parse relation from Overpass-like structure.

        :param structure: input structure
        """
        return cls(
            structure["id"],
            structure["tags"],
            [OSMMember(x["type"], x["ref"], x["role"]) for x in structure["members"]],
            visible=structure.get("visible"),
        )


class OSMData:
    """
    The whole OpenStreetMap information about nodes, ways, and relations.
    """

    def __init__(self) -> None:
        self.nodes: Dict[int, OSMNode] = {}
        self.ways: Dict[int, OSMWay] = {}
        self.relations: Dict[int, OSMRelation] = {}
        self.box: BoundaryBox = None

    @classmethod
    def from_dict(cls, structure: Dict[str, Any]):
        data = cls()
        bounds = structure.get("bounds")
        if bounds is not None:
            data.box = BoundaryBox(
                np.array([bounds["minlat"], bounds["minlon"]]),
                np.array([bounds["maxlat"], bounds["maxlon"]]),
            )

        for element in structure["elements"]:
            if element["type"] == "node":
                node = OSMNode.from_dict(element)
                data.add_node(node)
        for element in structure["elements"]:
            if element["type"] == "way":
                way = OSMWay.from_dict(element, data.nodes)
                data.add_way(way)
        for element in structure["elements"]:
            if element["type"] == "relation":
                relation = OSMRelation.from_dict(element)
                data.add_relation(relation)

        return data

    @classmethod
    def from_json(cls, path: Path):
        with path.open(encoding='utf-8') as fid:
            structure = json.load(fid)
        return cls.from_dict(structure)

    @classmethod
    def from_xml(cls, path: Path):
        root = etree.parse(str(path)).getroot()
        structure = {"elements": []}
        from tqdm import tqdm

        for elem in tqdm(root):
            if elem.tag == "bounds":
                structure["bounds"] = {
                    k: float(elem.attrib[k])
                    for k in ("minlon", "minlat", "maxlon", "maxlat")
                }
            elif elem.tag in {"node", "way", "relation"}:
                if elem.tag == "node":
                    item = {
                        "id": int(elem.attrib["id"]),
                        "lat": float(elem.attrib["lat"]),
                        "lon": float(elem.attrib["lon"]),
                        "visible": elem.attrib.get("visible"),
                        "tags": {
                            x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag"
                        },
                    }
                elif elem.tag == "way":
                    item = {
                        "id": int(elem.attrib["id"]),
                        "visible": elem.attrib.get("visible"),
                        "tags": {
                            x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag"
                        },
                        "nodes": [int(x.attrib["ref"]) for x in elem if x.tag == "nd"],
                    }
                elif elem.tag == "relation":
                    item = {
                        "id": int(elem.attrib["id"]),
                        "visible": elem.attrib.get("visible"),
                        "tags": {
                            x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag"
                        },
                        "members": [
                            {
                                "type": x.attrib["type"],
                                "ref": int(x.attrib["ref"]),
                                "role": x.attrib["role"],
                            }
                            for x in elem
                            if x.tag == "member"
                        ],
                    }
                item["type"] = elem.tag
                structure["elements"].append(item)
            elem.clear()
        del root
        return cls.from_dict(structure)

    @classmethod
    def from_file(cls, path: Path):
        ext = path.suffix
        if ext == ".json":
            return cls.from_json(path)
        elif ext in {".osm", ".xml"}:
            return cls.from_xml(path)
        else:
            raise ValueError(f"Unknown extension for {path}")

    def add_node(self, node: OSMNode):
        """Add node and update map parameters."""
        if node.id_ in self.nodes:
            raise ValueError(f"Node with duplicate id {node.id_}.")
        self.nodes[node.id_] = node

    def add_way(self, way: OSMWay):
        """Add way and update map parameters."""
        if way.id_ in self.ways:
            raise ValueError(f"Way with duplicate id {way.id_}.")
        self.ways[way.id_] = way

    def add_relation(self, relation: OSMRelation):
        """Add relation and update map parameters."""
        if relation.id_ in self.relations:
            raise ValueError(f"Relation with duplicate id {relation.id_}.")
        self.relations[relation.id_] = relation

    def add_xy_to_nodes(self, proj: Projection):
        nodes = list(self.nodes.values())
        if len(nodes) == 0:
            return
        geos = np.stack([n.geo for n in nodes], 0)
        if proj.bounds is not None:
            # For some reasons few nodes are sometimes very far off the initial bbox.
            valid = proj.bounds.contains(geos)
            if valid.mean() < 0.9:
                print("Many nodes are out of the projection bounds.")
            xys = np.zeros_like(geos)
            xys[valid] = proj.project(geos[valid])
        else:
            xys = proj.project(geos)
        for xy, node in zip(xys, nodes):
            node.xy = xy
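A usage sketch (not in the commit); the file path is hypothetical, and the Projection constructor arguments follow their use elsewhere in this repo:

from pathlib import Path

osm = OSMData.from_file(Path("toronto.json"))  # hypothetical cached extract
print(len(osm.nodes), len(osm.ways), len(osm.relations))
# proj = Projection(lat, lon)   # hypothetical reference point
# osm.add_xy_to_nodes(proj)     # attaches local metric xy to every node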
osm/tiling.py
ADDED
@@ -0,0 +1,310 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import io
import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
from PIL import Image
import rtree

from utils.geo import BoundaryBox, Projection
from .data import MapData
from .download import get_osm
from .parser import Groups
from .raster import Canvas, render_raster_map, render_raster_masks
from .reader import OSMData, OSMNode, OSMWay


class MapIndex:
    def __init__(
        self,
        data: MapData,
    ):
        self.index_nodes = rtree.index.Index()
        for i, node in data.nodes.items():
            self.index_nodes.insert(i, tuple(node.xy) * 2)

        self.index_lines = rtree.index.Index()
        for i, line in data.lines.items():
            bbox = tuple(np.r_[line.xy.min(0), line.xy.max(0)])
            self.index_lines.insert(i, bbox)

        self.index_areas = rtree.index.Index()
        for i, area in data.areas.items():
            xy = np.concatenate(area.outers + area.inners)
            bbox = tuple(np.r_[xy.min(0), xy.max(0)])
            self.index_areas.insert(i, bbox)

        self.data = data

    def query(self, bbox: BoundaryBox) -> Tuple[List[OSMNode], List[OSMWay]]:
        query = tuple(np.r_[bbox.min_, bbox.max_])
        ret = []
        for x in ["nodes", "lines", "areas"]:
            ids = getattr(self, "index_" + x).intersection(query)
            ret.append([getattr(self.data, x)[i] for i in ids])
        return tuple(ret)


def bbox_to_slice(bbox: BoundaryBox, canvas: Canvas):
    uv_min = np.ceil(canvas.to_uv(bbox.min_)).astype(int)
    uv_max = np.ceil(canvas.to_uv(bbox.max_)).astype(int)
    slice_ = (slice(uv_max[1], uv_min[1]), slice(uv_min[0], uv_max[0]))
    return slice_


def round_bbox(bbox: BoundaryBox, origin: np.ndarray, ppm: int):
    bbox = bbox.translate(-origin)
    bbox = BoundaryBox(np.round(bbox.min_ * ppm) / ppm, np.round(bbox.max_ * ppm) / ppm)
    return bbox.translate(origin)


class MapTileManager:
    def __init__(
        self,
        osmpath: Path,
    ):
        self.osm = OSMData.from_file(osmpath)

    # @classmethod
    def from_bbox(
        self,
        projection: Projection,
        bbox: BoundaryBox,
        ppm: int,
        tile_size: int = 128,
    ):
        # bbox_osm = projection.unproject(bbox)
        # if path is not None and path.is_file():
        #     print(OSMData.from_file)
        #     osm = OSMData.from_file(path)
        #     if osm.box is not None:
        #         assert osm.box.contains(bbox_osm)
        # else:
        #     osm = OSMData.from_dict(get_osm(bbox_osm, path))

        self.osm.add_xy_to_nodes(projection)
        map_data = MapData.from_osm(self.osm)
        map_index = MapIndex(map_data)

        bounds_x, bounds_y = [
            np.r_[np.arange(min_, max_, tile_size), max_]
            for min_, max_ in zip(bbox.min_, bbox.max_)
        ]
        bbox_tiles = {}
        for i, xmin in enumerate(bounds_x[:-1]):
            for j, ymin in enumerate(bounds_y[:-1]):
                bbox_tiles[i, j] = BoundaryBox(
                    [xmin, ymin], [bounds_x[i + 1], bounds_y[j + 1]]
                )

        tiles = {}
        for ij, bbox_tile in bbox_tiles.items():
            canvas = Canvas(bbox_tile, ppm)
            nodes, lines, areas = map_index.query(bbox_tile)
            masks = render_raster_masks(nodes, lines, areas, canvas)
            canvas.raster = render_raster_map(masks)
            tiles[ij] = canvas

        groups = {k: v for k, v in vars(Groups).items() if not k.startswith("__")}

        self.origin = bbox.min_
        self.bbox = bbox
        self.tiles = tiles
        self.tile_size = tile_size
        self.ppm = ppm
        self.projection = projection
        self.groups = groups
        self.map_data = map_data

        return self.query(bbox)
        # return cls(tiles, bbox, tile_size, ppm, projection, groups, map_data)

    def query(self, bbox: BoundaryBox) -> Canvas:
        bbox = round_bbox(bbox, self.bbox.min_, self.ppm)
        canvas = Canvas(bbox, self.ppm)
        raster = np.zeros((3, canvas.h, canvas.w), np.uint8)

        bbox_all = bbox & self.bbox
        ij_min = np.floor((bbox_all.min_ - self.origin) / self.tile_size).astype(int)
        ij_max = np.ceil((bbox_all.max_ - self.origin) / self.tile_size).astype(int) - 1
        for i in range(ij_min[0], ij_max[0] + 1):
            for j in range(ij_min[1], ij_max[1] + 1):
                tile = self.tiles[i, j]
                bbox_select = tile.bbox & bbox
                slice_query = bbox_to_slice(bbox_select, canvas)
                slice_tile = bbox_to_slice(bbox_select, tile)
                raster[(slice(None),) + slice_query] = tile.raster[
                    (slice(None),) + slice_tile
                ]
        canvas.raster = raster
        return canvas

    def save(self, path: Path):
        dump = {
            "bbox": self.bbox.format(),
            "tile_size": self.tile_size,
            "ppm": self.ppm,
            "groups": self.groups,
            "tiles_bbox": {},
            "tiles_raster": {},
        }
        if self.projection is not None:
            dump["ref_latlonalt"] = self.projection.latlonalt
        for ij, canvas in self.tiles.items():
            dump["tiles_bbox"][ij] = canvas.bbox.format()
            raster_bytes = io.BytesIO()
            raster = Image.fromarray(canvas.raster.transpose(1, 2, 0).astype(np.uint8))
            raster.save(raster_bytes, format="PNG")
            dump["tiles_raster"][ij] = raster_bytes
        with open(path, "wb") as fp:
            pickle.dump(dump, fp)

    @classmethod
    def load(cls, path: Path):
        with path.open("rb") as fp:
            dump = pickle.load(fp)
        tiles = {}
        for ij, bbox in dump["tiles_bbox"].items():
            tiles[ij] = Canvas(BoundaryBox.from_string(bbox), dump["ppm"])
            raster = np.asarray(Image.open(dump["tiles_raster"][ij]))
            tiles[ij].raster = raster.transpose(2, 0, 1).copy()
        projection = Projection(*dump["ref_latlonalt"])
        return cls(
            tiles,
            BoundaryBox.from_string(dump["bbox"]),
            dump["tile_size"],
            dump["ppm"],
            projection,
            dump["groups"],
        )


class TileManager:
    def __init__(
        self,
        tiles: Dict,
        bbox: BoundaryBox,
        tile_size: int,
        ppm: int,
        projection: Projection,
        groups: Dict[str, List[str]],
        map_data: Optional[MapData] = None,
    ):
        self.origin = bbox.min_
        self.bbox = bbox
        self.tiles = tiles
        self.tile_size = tile_size
        self.ppm = ppm
        self.projection = projection
        self.groups = groups
        self.map_data = map_data
        assert np.all(tiles[0, 0].bbox.min_ == self.origin)
        for tile in tiles.values():
            assert bbox.contains(tile.bbox)

    @classmethod
    def from_bbox(
        cls,
        projection: Projection,
        bbox: BoundaryBox,
        ppm: int,
        path: Optional[Path] = None,
        tile_size: int = 128,
    ):
        bbox_osm = projection.unproject(bbox)
        if path is not None and path.is_file():
            print(OSMData.from_file)
            osm = OSMData.from_file(path)
            if osm.box is not None:
                assert osm.box.contains(bbox_osm)
        else:
            osm = OSMData.from_dict(get_osm(bbox_osm, path))

        osm.add_xy_to_nodes(projection)
        map_data = MapData.from_osm(osm)
        map_index = MapIndex(map_data)

        bounds_x, bounds_y = [
            np.r_[np.arange(min_, max_, tile_size), max_]
            for min_, max_ in zip(bbox.min_, bbox.max_)
|
233 |
+
]
|
234 |
+
bbox_tiles = {}
|
235 |
+
for i, xmin in enumerate(bounds_x[:-1]):
|
236 |
+
for j, ymin in enumerate(bounds_y[:-1]):
|
237 |
+
bbox_tiles[i, j] = BoundaryBox(
|
238 |
+
[xmin, ymin], [bounds_x[i + 1], bounds_y[j + 1]]
|
239 |
+
)
|
240 |
+
|
241 |
+
tiles = {}
|
242 |
+
for ij, bbox_tile in bbox_tiles.items():
|
243 |
+
canvas = Canvas(bbox_tile, ppm)
|
244 |
+
nodes, lines, areas = map_index.query(bbox_tile)
|
245 |
+
masks = render_raster_masks(nodes, lines, areas, canvas)
|
246 |
+
canvas.raster = render_raster_map(masks)
|
247 |
+
tiles[ij] = canvas
|
248 |
+
|
249 |
+
groups = {k: v for k, v in vars(Groups).items() if not k.startswith("__")}
|
250 |
+
|
251 |
+
return cls(tiles, bbox, tile_size, ppm, projection, groups, map_data)
|
252 |
+
|
253 |
+
def query(self, bbox: BoundaryBox) -> Canvas:
|
254 |
+
bbox = round_bbox(bbox, self.bbox.min_, self.ppm)
|
255 |
+
canvas = Canvas(bbox, self.ppm)
|
256 |
+
raster = np.zeros((3, canvas.h, canvas.w), np.uint8)
|
257 |
+
|
258 |
+
bbox_all = bbox & self.bbox
|
259 |
+
ij_min = np.floor((bbox_all.min_ - self.origin) / self.tile_size).astype(int)
|
260 |
+
ij_max = np.ceil((bbox_all.max_ - self.origin) / self.tile_size).astype(int) - 1
|
261 |
+
for i in range(ij_min[0], ij_max[0] + 1):
|
262 |
+
for j in range(ij_min[1], ij_max[1] + 1):
|
263 |
+
tile = self.tiles[i, j]
|
264 |
+
bbox_select = tile.bbox & bbox
|
265 |
+
slice_query = bbox_to_slice(bbox_select, canvas)
|
266 |
+
slice_tile = bbox_to_slice(bbox_select, tile)
|
267 |
+
raster[(slice(None),) + slice_query] = tile.raster[
|
268 |
+
(slice(None),) + slice_tile
|
269 |
+
]
|
270 |
+
canvas.raster = raster
|
271 |
+
return canvas
|
272 |
+
|
273 |
+
def save(self, path: Path):
|
274 |
+
dump = {
|
275 |
+
"bbox": self.bbox.format(),
|
276 |
+
"tile_size": self.tile_size,
|
277 |
+
"ppm": self.ppm,
|
278 |
+
"groups": self.groups,
|
279 |
+
"tiles_bbox": {},
|
280 |
+
"tiles_raster": {},
|
281 |
+
}
|
282 |
+
if self.projection is not None:
|
283 |
+
dump["ref_latlonalt"] = self.projection.latlonalt
|
284 |
+
for ij, canvas in self.tiles.items():
|
285 |
+
dump["tiles_bbox"][ij] = canvas.bbox.format()
|
286 |
+
raster_bytes = io.BytesIO()
|
287 |
+
raster = Image.fromarray(canvas.raster.transpose(1, 2, 0).astype(np.uint8))
|
288 |
+
raster.save(raster_bytes, format="PNG")
|
289 |
+
dump["tiles_raster"][ij] = raster_bytes
|
290 |
+
with open(path, "wb") as fp:
|
291 |
+
pickle.dump(dump, fp)
|
292 |
+
|
293 |
+
@classmethod
|
294 |
+
def load(cls, path: Path):
|
295 |
+
with path.open("rb") as fp:
|
296 |
+
dump = pickle.load(fp)
|
297 |
+
tiles = {}
|
298 |
+
for ij, bbox in dump["tiles_bbox"].items():
|
299 |
+
tiles[ij] = Canvas(BoundaryBox.from_string(bbox), dump["ppm"])
|
300 |
+
raster = np.asarray(Image.open(dump["tiles_raster"][ij]))
|
301 |
+
tiles[ij].raster = raster.transpose(2, 0, 1).copy()
|
302 |
+
projection = Projection(*dump["ref_latlonalt"])
|
303 |
+
return cls(
|
304 |
+
tiles,
|
305 |
+
BoundaryBox.from_string(dump["bbox"]),
|
306 |
+
dump["tile_size"],
|
307 |
+
dump["ppm"],
|
308 |
+
projection,
|
309 |
+
dump["groups"],
|
310 |
+
)
|
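A minimal usage sketch of the tiling API above (not part of the commit; the reference coordinates are illustrative and the 1 px/m resolution matches `pixel_per_meter: 1` in the config):

# Sketch: rasterize OSM around a reference point, then crop a local map patch.
import numpy as np

from osm.tiling import TileManager
from utils.geo import BoundaryBox, Projection

projection = Projection(43.6532, -79.3832)      # local metric frame (Toronto, illustrative)
center = projection.project(np.array([43.6532, -79.3832]))  # -> [0, 0] meters
bbox = BoundaryBox(center - 512, center + 512)  # 1024 m x 1024 m working area

# With path=None the OSM data is downloaded for the unprojected bbox.
tiler = TileManager.from_bbox(projection, bbox, ppm=1)
canvas = tiler.query(BoundaryBox(center - 64, center + 64))
print(canvas.raster.shape)  # (3, 128, 128): area, way, and node class rasters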
osm/viz.py
ADDED
@@ -0,0 +1,159 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import PIL.Image

from utils.viz_2d import add_text
from .parser import Groups


class GeoPlotter:
    def __init__(self, zoom=12, **kwargs):
        self.fig = go.Figure()
        self.fig.update_layout(
            mapbox_style="open-street-map",
            autosize=True,
            mapbox_zoom=zoom,
            margin={"r": 0, "t": 0, "l": 0, "b": 0},
            showlegend=True,
            **kwargs,
        )

    def points(self, latlons, color, text=None, name=None, size=5, **kwargs):
        latlons = np.asarray(latlons)
        self.fig.add_trace(
            go.Scattermapbox(
                lat=latlons[..., 0],
                lon=latlons[..., 1],
                mode="markers",
                text=text,
                marker_color=color,
                marker_size=size,
                name=name,
                **kwargs,
            )
        )
        center = latlons.reshape(-1, 2).mean(0)
        self.fig.update_layout(
            mapbox_center=dict(zip(("lat", "lon"), center)),
        )

    def bbox(self, bbox, color, name=None, **kwargs):
        corners = np.stack(
            [bbox.min_, bbox.left_top, bbox.max_, bbox.right_bottom, bbox.min_]
        )
        self.fig.add_trace(
            go.Scattermapbox(
                lat=corners[:, 0],
                lon=corners[:, 1],
                mode="lines",
                marker_color=color,
                name=name,
                **kwargs,
            )
        )
        self.fig.update_layout(
            mapbox_center=dict(zip(("lat", "lon"), bbox.center)),
        )

    def raster(self, raster, bbox, below="traces", **kwargs):
        if not np.issubdtype(raster.dtype, np.integer):
            raster = (raster * 255).astype(np.uint8)
        raster = PIL.Image.fromarray(raster)
        corners = np.stack(
            [
                bbox.min_,
                bbox.left_top,
                bbox.max_,
                bbox.right_bottom,
            ]
        )[::-1, ::-1]
        layers = [*self.fig.layout.mapbox.layers]
        layers.append(
            dict(
                sourcetype="image",
                source=raster,
                coordinates=corners,
                below=below,
                **kwargs,
            )
        )
        self.fig.layout.mapbox.layers = layers


map_colors = {
    "building": (84, 155, 255),
    "parking": (255, 229, 145),
    "playground": (150, 133, 125),
    "grass": (188, 255, 143),
    "park": (0, 158, 16),
    "forest": (0, 92, 9),
    "water": (184, 213, 255),
    "fence": (238, 0, 255),
    "wall": (0, 0, 0),
    "hedge": (107, 68, 48),
    "kerb": (255, 234, 0),
    "building_outline": (0, 0, 255),
    "cycleway": (0, 251, 255),
    "path": (8, 237, 0),
    "road": (255, 0, 0),
    "tree_row": (0, 92, 9),
    "busway": (255, 128, 0),
    "void": [int(255 * 0.9)] * 3,
}


class Colormap:
    colors_areas = np.stack([map_colors[k] for k in ["void"] + Groups.areas])
    colors_ways = np.stack([map_colors[k] for k in ["void"] + Groups.ways])

    @classmethod
    def apply(cls, rasters):
        return (
            np.where(
                rasters[1, ..., None] > 0,
                cls.colors_ways[rasters[1]],
                cls.colors_areas[rasters[0]],
            )
            / 255.0
        )

    @classmethod
    def add_colorbar(cls):
        ax2 = plt.gcf().add_axes([1, 0.1, 0.02, 0.8])
        color_list = np.r_[cls.colors_areas[1:], cls.colors_ways[1:]] / 255.0
        cmap = mpl.colors.ListedColormap(color_list[::-1])
        ticks = np.linspace(0, 1, len(color_list), endpoint=False)
        ticks += 1 / len(color_list) / 2
        cb = mpl.colorbar.ColorbarBase(
            ax2,
            cmap=cmap,
            orientation="vertical",
            ticks=ticks,
        )
        cb.set_ticklabels((Groups.areas + Groups.ways)[::-1])
        ax2.tick_params(labelsize=15)


def plot_nodes(idx, raster, fontsize=8, size=15):
    ax = plt.gcf().axes[idx]
    ax.autoscale(enable=False)
    nodes_xy = np.stack(np.where(raster > 0)[::-1], -1)
    nodes_val = raster[tuple(nodes_xy.T[::-1])] - 1
    ax.scatter(*nodes_xy.T, c="k", s=size)
    for xy, val in zip(nodes_xy, nodes_val):
        group = Groups.nodes[val]
        add_text(
            idx,
            group,
            xy + 2,
            lcolor=None,
            fs=fontsize,
            color="k",
            normalized=False,
            ha="center",
        )
    plt.show()
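A short sketch of how these helpers are typically combined (not part of the commit; the raster below is a toy stand-in for a `Canvas.raster` produced by the tiling code):

# Sketch: colorize a 3-channel semantic raster (areas, ways, nodes) for display.
import matplotlib.pyplot as plt
import numpy as np

from osm.viz import Colormap

raster = np.zeros((3, 128, 128), np.uint8)  # class indices per channel, 0 = void
raster[0, :, :64] = 1    # an "area" class on the left half
raster[1, 60:68, :] = 3  # a horizontal "way" band, drawn on top of areas
rgb = Colormap.apply(raster)  # (128, 128, 3) floats in [0, 1]

plt.imshow(rgb)
Colormap.add_colorbar()
plt.show()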
requirements.txt
ADDED
@@ -0,0 +1,20 @@
torch
torchvision
numpy
opencv-python
Pillow
tqdm>=4.36.0
matplotlib
plotly
scipy
omegaconf
pytorch-lightning
torchmetrics
jupyter
lxml
rtree
scikit-learn
geopy
exifread
gradio_client
urllib3>=2
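Note that the list is unpinned apart from lower bounds on tqdm and urllib3, so `pip install -r requirements.txt` resolves whatever versions are current at install time.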
train.py
ADDED
@@ -0,0 +1,217 @@
import os.path as osp
import warnings

warnings.filterwarnings("ignore")
from pathlib import Path

import hydra
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.utilities import rank_zero_only

from dataset import UavMapDatasetModule
from logger import logger, pl_logger, EXPERIMENTS_PATH
from models.maplocnet import MapLocNet
from module import GenericModule


class CleanProgressBar(pl.callbacks.TQDMProgressBar):
    def get_metrics(self, trainer, model):
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)  # don't show the version number
        items.pop("loss", None)
        return items


class SeedingCallback(pl.callbacks.Callback):
    def on_epoch_start_(self, trainer, module):
        seed = module.cfg.experiment.seed
        is_overfit = module.cfg.training.trainer.get("overfit_batches", 0) > 0
        if trainer.training and not is_overfit:
            seed = seed + trainer.current_epoch

        # Temporarily disable the logging (does not seem to work?)
        pl_logger.disabled = True
        try:
            pl.seed_everything(seed, workers=True)
        finally:
            pl_logger.disabled = False

    def on_train_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)

    def on_validation_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)

    def on_test_epoch_start(self, *args, **kwargs):
        self.on_epoch_start_(*args, **kwargs)


class ConsoleLogger(pl.callbacks.Callback):
    @rank_zero_only
    def on_train_epoch_start(self, trainer, module):
        logger.info(
            "New training epoch %d for experiment '%s'.",
            module.current_epoch,
            module.cfg.experiment.name,
        )

    # @rank_zero_only
    # def on_validation_epoch_end(self, trainer, module):
    #     results = {
    #         **dict(module.metrics_val.items()),
    #         **dict(module.losses_val.items()),
    #     }
    #     results = [f"{k} {v.compute():.3E}" for k, v in results.items()]
    #     logger.info(f'[Validation] {{{", ".join(results)}}}')


def find_last_checkpoint_path(experiment_dir):
    cls = pl.callbacks.ModelCheckpoint
    path = osp.join(experiment_dir, cls.CHECKPOINT_NAME_LAST + cls.FILE_EXTENSION)
    if osp.exists(path):
        return path
    else:
        return None


def prepare_experiment_dir(experiment_dir, cfg, rank):
    config_path = osp.join(experiment_dir, "config.yaml")
    last_checkpoint_path = find_last_checkpoint_path(experiment_dir)
    if last_checkpoint_path is not None:
        if rank == 0:
            logger.info(
                "Resuming the training from checkpoint %s", last_checkpoint_path
            )
        if osp.exists(config_path):
            with open(config_path, "r") as fp:
                cfg_prev = OmegaConf.create(fp.read())
            compare_keys = ["experiment", "data", "model", "training"]
            if OmegaConf.masked_copy(cfg, compare_keys) != OmegaConf.masked_copy(
                cfg_prev, compare_keys
            ):
                raise ValueError(
                    "Attempting to resume training with a different config: "
                    f"{OmegaConf.masked_copy(cfg, compare_keys)} vs "
                    f"{OmegaConf.masked_copy(cfg_prev, compare_keys)}"
                )
    if rank == 0:
        Path(experiment_dir).mkdir(exist_ok=True, parents=True)
        with open(config_path, "w") as fp:
            OmegaConf.save(cfg, fp)
    return last_checkpoint_path


def train(cfg: DictConfig) -> None:
    torch.set_float32_matmul_precision("medium")
    OmegaConf.resolve(cfg)
    rank = rank_zero_only.rank

    if rank == 0:
        logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg))
    if cfg.experiment.gpus in (None, 0):
        logger.warning("Will train on CPU...")
        cfg.experiment.gpus = 0
    elif not torch.cuda.is_available():
        raise ValueError("Requested GPU but no NVIDIA drivers found.")
    pl.seed_everything(cfg.experiment.seed, workers=True)

    init_checkpoint_path = cfg.training.get("finetune_from_checkpoint")
    if init_checkpoint_path is not None:
        logger.info("Initializing the model from checkpoint %s.", init_checkpoint_path)
        model = GenericModule.load_from_checkpoint(
            init_checkpoint_path, strict=True, find_best=False, cfg=cfg
        )
    else:
        model = GenericModule(cfg)
    if rank == 0:
        logger.info("Network:\n%s", model.model)

    experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name)
    last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank)
    checkpointing_epoch = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_epochs=1,
        save_on_train_epoch_end=True,
        verbose=True,
        **cfg.training.checkpointing,
    )
    checkpointing_step = pl.callbacks.ModelCheckpoint(
        dirpath=experiment_dir,
        filename="checkpoint-step-{step}-{loss/total/val:02f}",
        auto_insert_metric_name=False,
        save_last=True,
        every_n_train_steps=1000,
        verbose=True,
        **cfg.training.checkpointing,
    )
    checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing"

    # Create the EarlyStopping callback (currently not registered, see `callbacks`).
    early_stopping_callback = EarlyStopping(
        monitor=cfg.training.checkpointing.monitor, patience=5
    )

    strategy = None
    if cfg.experiment.gpus > 1:
        strategy = pl.strategies.DDPStrategy(find_unused_parameters=False)
        for split in ["train", "val"]:
            cfg.data[split].batch_size = (
                cfg.data[split].batch_size // cfg.experiment.gpus
            )
            cfg.data[split].num_workers = int(
                (cfg.data[split].num_workers + cfg.experiment.gpus - 1)
                / cfg.experiment.gpus
            )

    datamodule = UavMapDatasetModule(cfg.data)

    tb_args = {"name": cfg.experiment.name, "version": ""}
    tb = pl.loggers.TensorBoardLogger(EXPERIMENTS_PATH, **tb_args)

    callbacks = [
        checkpointing_epoch,
        checkpointing_step,
        # early_stopping_callback,
        pl.callbacks.LearningRateMonitor(),
        SeedingCallback(),
        CleanProgressBar(),
        ConsoleLogger(),
    ]
    if cfg.experiment.gpus > 0:
        callbacks.append(pl.callbacks.DeviceStatsMonitor())

    trainer = pl.Trainer(
        default_root_dir=experiment_dir,
        detect_anomaly=False,
        enable_model_summary=True,
        sync_batchnorm=True,
        enable_checkpointing=True,
        logger=tb,
        callbacks=callbacks,
        strategy=strategy,
        check_val_every_n_epoch=1,
        accelerator="gpu",
        num_nodes=1,
        **cfg.training.trainer,
    )
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path)


@hydra.main(
    config_path=osp.join(osp.dirname(__file__), "conf"), config_name="maplocnet.yaml"
)
def main(cfg: DictConfig) -> None:
    OmegaConf.save(config=cfg, f="maplocnet.yaml")
    train(cfg)


if __name__ == "__main__":
    main()
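Since the entry point is a Hydra app, any config key can be overridden from the command line (e.g. `python train.py experiment.gpus=1`). For debugging, `train()` can also be driven programmatically; a minimal sketch, assuming `conf/maplocnet.yaml` defines the `experiment` and `training` keys that `train()` reads:

# Sketch: bypass the Hydra CLI and call train() directly for a smoke run.
from omegaconf import OmegaConf

from train import train

cfg = OmegaConf.load("conf/maplocnet.yaml")
cfg.experiment.name = "maplocnet_smoke"  # hypothetical run name
cfg.experiment.gpus = 1
cfg.training.trainer.max_epochs = 1      # keep the run short
train(cfg)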
train.sh
ADDED
@@ -0,0 +1 @@
nohup python train.py > logs/train0907.log 2>&1 &
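The script detaches training from the terminal with nohup, redirects both stdout and stderr into logs/train0907.log, and the trailing & returns the shell immediately; the logs/ directory must exist before launching.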
utils/exif.py
ADDED
@@ -0,0 +1,356 @@
"""Copied from opensfm.exif to minimize hard dependencies."""
from pathlib import Path
import json
import datetime
import logging
from codecs import encode, decode
from typing import Any, Dict, Optional, Tuple

import exifread

logger: logging.Logger = logging.getLogger(__name__)

inch_in_mm = 25.4
cm_in_mm = 10
um_in_mm = 0.001
default_projection = "perspective"
maximum_altitude = 1e4


def sensor_data():
    with (Path(__file__).parent / "sensor_data.json").open() as fid:
        data = json.load(fid)
    return {k.lower(): v for k, v in data.items()}


def eval_frac(value) -> Optional[float]:
    try:
        return float(value.num) / float(value.den)
    except ZeroDivisionError:
        return None


def gps_to_decimal(values, reference) -> Optional[float]:
    sign = 1 if reference in "NE" else -1
    degrees = eval_frac(values[0])
    minutes = eval_frac(values[1])
    seconds = eval_frac(values[2])
    if degrees is not None and minutes is not None and seconds is not None:
        return sign * (degrees + minutes / 60 + seconds / 3600)
    return None


def get_tag_as_float(tags, key, index: int = 0) -> Optional[float]:
    if key in tags:
        val = tags[key].values[index]
        if isinstance(val, exifread.utils.Ratio):
            ret_val = eval_frac(val)
            if ret_val is None:
                logger.error(
                    'The rational "{2}" of tag "{0:s}" at index {1:d} c'
                    "aused a division by zero error".format(key, index, val)
                )
            return ret_val
        else:
            return float(val)
    else:
        return None


def compute_focal(
    focal_35: Optional[float], focal: Optional[float], sensor_width, sensor_string
) -> Tuple[float, float]:
    if focal_35 is not None and focal_35 > 0:
        focal_ratio = focal_35 / 36.0  # 35mm film produces 36x24mm pictures.
    else:
        if not sensor_width:
            sensor_width = sensor_data().get(sensor_string, None)
        if sensor_width and focal:
            focal_ratio = focal / sensor_width
            focal_35 = 36.0 * focal_ratio
        else:
            focal_35 = 0.0
            focal_ratio = 0.0
    return focal_35, focal_ratio


def sensor_string(make: str, model: str) -> str:
    if make != "unknown":
        # remove duplicate 'make' information in 'model'
        model = model.replace(make, "")
    return (make.strip() + " " + model.strip()).strip().lower()


def unescape_string(s) -> str:
    return decode(encode(s, "latin-1", "backslashreplace"), "unicode-escape")


class EXIF:
    def __init__(
        self, fileobj, image_size_loader, use_exif_size=True, name=None
    ) -> None:
        self.image_size_loader = image_size_loader
        self.use_exif_size = use_exif_size
        self.fileobj = fileobj
        self.tags = exifread.process_file(fileobj, details=False)
        fileobj.seek(0)
        self.fileobj_name = self.fileobj.name if name is None else name

    def extract_image_size(self) -> Tuple[int, int]:
        if (
            self.use_exif_size
            and "EXIF ExifImageWidth" in self.tags
            and "EXIF ExifImageLength" in self.tags
        ):
            width, height = (
                int(self.tags["EXIF ExifImageWidth"].values[0]),
                int(self.tags["EXIF ExifImageLength"].values[0]),
            )
        elif (
            self.use_exif_size
            and "Image ImageWidth" in self.tags
            and "Image ImageLength" in self.tags
        ):
            width, height = (
                int(self.tags["Image ImageWidth"].values[0]),
                int(self.tags["Image ImageLength"].values[0]),
            )
        else:
            height, width = self.image_size_loader()
        return width, height

    def _decode_make_model(self, value) -> str:
        """Python 2/3 compatible decoding of make/model field."""
        if hasattr(value, "decode"):
            try:
                return value.decode("utf-8")
            except UnicodeDecodeError:
                return "unknown"
        else:
            return value

    def extract_make(self) -> str:
        # Camera make and model
        if "EXIF LensMake" in self.tags:
            make = self.tags["EXIF LensMake"].values
        elif "Image Make" in self.tags:
            make = self.tags["Image Make"].values
        else:
            make = "unknown"
        return self._decode_make_model(make)

    def extract_model(self) -> str:
        if "EXIF LensModel" in self.tags:
            model = self.tags["EXIF LensModel"].values
        elif "Image Model" in self.tags:
            model = self.tags["Image Model"].values
        else:
            model = "unknown"
        return self._decode_make_model(model)

    def extract_focal(self) -> Tuple[float, float]:
        make, model = self.extract_make(), self.extract_model()
        focal_35, focal_ratio = compute_focal(
            get_tag_as_float(self.tags, "EXIF FocalLengthIn35mmFilm"),
            get_tag_as_float(self.tags, "EXIF FocalLength"),
            self.extract_sensor_width(),
            sensor_string(make, model),
        )
        return focal_35, focal_ratio

    def extract_sensor_width(self) -> Optional[float]:
        """Compute the sensor width from the image width and focal plane resolution."""
        if (
            "EXIF FocalPlaneResolutionUnit" not in self.tags
            or "EXIF FocalPlaneXResolution" not in self.tags
        ):
            return None
        resolution_unit = self.tags["EXIF FocalPlaneResolutionUnit"].values[0]
        mm_per_unit = self.get_mm_per_unit(resolution_unit)
        if not mm_per_unit:
            return None
        pixels_per_unit = get_tag_as_float(self.tags, "EXIF FocalPlaneXResolution")
        if pixels_per_unit is None:
            return None
        if pixels_per_unit <= 0.0:
            pixels_per_unit = get_tag_as_float(self.tags, "EXIF FocalPlaneYResolution")
            if pixels_per_unit is None or pixels_per_unit <= 0.0:
                return None
        units_per_pixel = 1 / pixels_per_unit
        width_in_pixels = self.extract_image_size()[0]
        return width_in_pixels * units_per_pixel * mm_per_unit

    def get_mm_per_unit(self, resolution_unit) -> Optional[float]:
        """Length of a resolution unit in millimeters.

        Uses the values from the EXIF specs in
        https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/EXIF.html

        Args:
            resolution_unit: the resolution unit value given in the EXIF
        """
        if resolution_unit == 2:  # inch
            return inch_in_mm
        elif resolution_unit == 3:  # cm
            return cm_in_mm
        elif resolution_unit == 4:  # mm
            return 1
        elif resolution_unit == 5:  # um
            return um_in_mm
        else:
            logger.warning(
                "Unknown EXIF resolution unit value: {}".format(resolution_unit)
            )
            return None

    def extract_orientation(self) -> int:
        orientation = 1
        if "Image Orientation" in self.tags:
            value = self.tags.get("Image Orientation").values[0]
            if isinstance(value, int) and value != 0:
                orientation = value
        return orientation

    def extract_ref_lon_lat(self) -> Tuple[str, str]:
        if "GPS GPSLatitudeRef" in self.tags:
            reflat = self.tags["GPS GPSLatitudeRef"].values
        else:
            reflat = "N"
        if "GPS GPSLongitudeRef" in self.tags:
            reflon = self.tags["GPS GPSLongitudeRef"].values
        else:
            reflon = "E"
        return reflon, reflat

    def extract_lon_lat(self) -> Tuple[Optional[float], Optional[float]]:
        if "GPS GPSLatitude" in self.tags:
            reflon, reflat = self.extract_ref_lon_lat()
            lat = gps_to_decimal(self.tags["GPS GPSLatitude"].values, reflat)
            lon = gps_to_decimal(self.tags["GPS GPSLongitude"].values, reflon)
        else:
            lon, lat = None, None
        return lon, lat

    def extract_altitude(self) -> Optional[float]:
        if "GPS GPSAltitude" in self.tags:
            alt_value = self.tags["GPS GPSAltitude"].values[0]
            if isinstance(alt_value, exifread.utils.Ratio):
                altitude = eval_frac(alt_value)
            elif isinstance(alt_value, int):
                altitude = float(alt_value)
            else:
                altitude = None

            # GPSAltitudeRef == 1 means the altitude is below sea level and
            # should be negative, see http://www.exif.org/Exif2-2.PDF#page=53
            if (
                "GPS GPSAltitudeRef" in self.tags
                and self.tags["GPS GPSAltitudeRef"].values[0] == 1
                and altitude is not None
            ):
                altitude = -altitude
        else:
            altitude = None
        return altitude

    def extract_dop(self) -> Optional[float]:
        if "GPS GPSDOP" in self.tags:
            return eval_frac(self.tags["GPS GPSDOP"].values[0])
        return None

    def extract_geo(self) -> Dict[str, Any]:
        altitude = self.extract_altitude()
        dop = self.extract_dop()
        lon, lat = self.extract_lon_lat()
        d = {}

        if lon is not None and lat is not None:
            d["latitude"] = lat
            d["longitude"] = lon
        if altitude is not None:
            d["altitude"] = min([maximum_altitude, altitude])
        if dop is not None:
            d["dop"] = dop
        return d

    def extract_capture_time(self) -> float:
        if (
            "GPS GPSDate" in self.tags
            and "GPS GPSTimeStamp" in self.tags  # Actually GPSDateStamp
        ):
            try:
                hours_f = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 0)
                minutes_f = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 1)
                if hours_f is None or minutes_f is None:
                    raise TypeError
                hours = int(hours_f)
                minutes = int(minutes_f)
                seconds = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 2)
                gps_timestamp_string = "{0:s} {1:02d}:{2:02d}:{3:02f}".format(
                    self.tags["GPS GPSDate"].values, hours, minutes, seconds
                )
                return (
                    datetime.datetime.strptime(
                        gps_timestamp_string, "%Y:%m:%d %H:%M:%S.%f"
                    )
                    - datetime.datetime(1970, 1, 1)
                ).total_seconds()
            except (TypeError, ValueError):
                logger.info(
                    'The GPS time stamp in image file "{0:s}" is invalid. '
                    "Falling back to DateTime*".format(self.fileobj_name)
                )

        time_strings = [
            ("EXIF DateTimeOriginal", "EXIF SubSecTimeOriginal", "EXIF Tag 0x9011"),
            ("EXIF DateTimeDigitized", "EXIF SubSecTimeDigitized", "EXIF Tag 0x9012"),
            ("Image DateTime", "Image SubSecTime", "Image Tag 0x9010"),
        ]
        for datetime_tag, subsec_tag, offset_tag in time_strings:
            if datetime_tag in self.tags:
                date_time = self.tags[datetime_tag].values
                if subsec_tag in self.tags:
                    subsec_time = self.tags[subsec_tag].values
                else:
                    subsec_time = "0"
                try:
                    s = "{0:s}.{1:s}".format(date_time, subsec_time)
                    d = datetime.datetime.strptime(s, "%Y:%m:%d %H:%M:%S.%f")
                except ValueError:
                    logger.debug(
                        'The "{1:s}" time stamp or "{2:s}" tag is invalid in '
                        'image file "{0:s}"'.format(
                            self.fileobj_name, datetime_tag, subsec_tag
                        )
                    )
                    continue
                # Test for OffsetTimeOriginal | OffsetTimeDigitized | OffsetTime
                if offset_tag in self.tags:
                    offset_time = self.tags[offset_tag].values
                    try:
                        d += datetime.timedelta(
                            hours=-int(offset_time[0:3]), minutes=int(offset_time[4:6])
                        )
                    except (TypeError, ValueError):
                        logger.debug(
                            'The "{0:s}" time zone offset in image file "{1:s}"'
                            " is invalid".format(offset_tag, self.fileobj_name)
                        )
                        logger.debug(
                            'Naively assuming UTC on "{0:s}" in image file '
                            '"{1:s}"'.format(datetime_tag, self.fileobj_name)
                        )
                else:
                    logger.debug(
                        "No GPS time stamp and no time zone offset in image "
                        'file "{0:s}"'.format(self.fileobj_name)
                    )
                    logger.debug(
                        'Naively assuming UTC on "{0:s}" in image file "{1:s}"'.format(
                            datetime_tag, self.fileobj_name
                        )
                    )
                return (d - datetime.datetime(1970, 1, 1)).total_seconds()
        logger.info(
            'Image file "{0:s}" has no valid time stamp'.format(self.fileobj_name)
        )
        return 0.0
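A usage sketch (not part of the commit; "photo.jpg" is a placeholder path, and the size loader is only called when the EXIF tags lack image dimensions):

# Sketch: pull GPS and camera metadata from one image via the EXIF helper.
import PIL.Image

from utils.exif import EXIF


def image_size_loader():
    with PIL.Image.open("photo.jpg") as im:
        return im.height, im.width  # (height, width), the fallback order expected above


with open("photo.jpg", "rb") as fid:
    exif = EXIF(fid, image_size_loader)
    print(exif.extract_geo())           # e.g. {'latitude': ..., 'longitude': ..., 'dop': ...}
    print(exif.extract_focal())         # (focal_35mm, focal_ratio)
    print(exif.extract_capture_time())  # POSIX seconds, 0.0 if no valid timestamp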
utils/geo.py
ADDED
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Union

import numpy as np
import torch

from .geo_opensfm import TopocentricConverter


class BoundaryBox:
    def __init__(self, min_: np.ndarray, max_: np.ndarray):
        self.min_ = np.asarray(min_)
        self.max_ = np.asarray(max_)
        assert np.all(self.min_ <= self.max_)

    @classmethod
    def from_string(cls, string: str):
        return cls(*np.split(np.array(string.split(","), float), 2))

    @property
    def left_top(self) -> np.ndarray:
        return np.stack([self.min_[..., 0], self.max_[..., 1]], -1)

    @property
    def right_bottom(self) -> np.ndarray:
        return np.stack([self.max_[..., 0], self.min_[..., 1]], -1)

    @property
    def center(self) -> np.ndarray:
        return (self.min_ + self.max_) / 2

    @property
    def size(self) -> np.ndarray:
        return self.max_ - self.min_

    def translate(self, t: float):
        return self.__class__(self.min_ + t, self.max_ + t)

    def contains(self, xy: Union[np.ndarray, "BoundaryBox"]):
        if isinstance(xy, self.__class__):
            return self.contains(xy.min_) and self.contains(xy.max_)
        return np.all((xy >= self.min_) & (xy <= self.max_), -1)

    def normalize(self, xy):
        min_, max_ = self.min_, self.max_
        if isinstance(xy, torch.Tensor):
            min_ = torch.from_numpy(min_).to(xy)
            max_ = torch.from_numpy(max_).to(xy)
        return (xy - min_) / (max_ - min_)

    def unnormalize(self, xy):
        min_, max_ = self.min_, self.max_
        if isinstance(xy, torch.Tensor):
            min_ = torch.from_numpy(min_).to(xy)
            max_ = torch.from_numpy(max_).to(xy)
        return xy * (max_ - min_) + min_

    def format(self) -> str:
        return ",".join(np.r_[self.min_, self.max_].astype(str))

    def __add__(self, x):
        if isinstance(x, (int, float)):
            return self.__class__(self.min_ - x, self.max_ + x)
        else:
            raise TypeError(f"Cannot add {self.__class__.__name__} to {type(x)}.")

    def __and__(self, other):
        return self.__class__(
            np.maximum(self.min_, other.min_), np.minimum(self.max_, other.max_)
        )

    def __repr__(self):
        return self.format()


class Projection:
    def __init__(self, lat, lon, alt=0, max_extent=25e3):
        # The approximation error is |L - radius * tan(L / radius)|
        # and is around 13cm for L=25km.
        self.latlonalt = (lat, lon, alt)
        self.converter = TopocentricConverter(lat, lon, alt)
        min_ = self.converter.to_lla(*(-max_extent,) * 2, 0)[:2]
        max_ = self.converter.to_lla(*(max_extent,) * 2, 0)[:2]
        self.bounds = BoundaryBox(min_, max_)

    @classmethod
    def from_points(cls, all_latlon):
        assert all_latlon.shape[-1] == 2
        all_latlon = all_latlon.reshape(-1, 2)
        latlon_mid = (all_latlon.min(0) + all_latlon.max(0)) / 2
        return cls(*latlon_mid)

    def check_bbox(self, bbox: BoundaryBox):
        if self.bounds is not None and not self.bounds.contains(bbox):
            raise ValueError(
                f"Bbox {bbox.format()} is not contained in "
                f"projection with bounds {self.bounds.format()}."
            )

    def project(self, geo, return_z=False):
        if isinstance(geo, BoundaryBox):
            return BoundaryBox(*self.project(np.stack([geo.min_, geo.max_])))
        geo = np.asarray(geo)
        assert geo.shape[-1] in (2, 3)
        if self.bounds is not None:
            if not np.all(self.bounds.contains(geo[..., :2])):
                raise ValueError(
                    f"Points {geo} are out of the valid bounds "
                    f"{self.bounds.format()}."
                )
        lat, lon = geo[..., 0], geo[..., 1]
        if geo.shape[-1] == 3:
            alt = geo[..., -1]
        else:
            alt = np.zeros_like(lat)
        x, y, z = self.converter.to_topocentric(lat, lon, alt)
        return np.stack([x, y] + ([z] if return_z else []), -1)

    def unproject(self, xy, return_z=False):
        if isinstance(xy, BoundaryBox):
            return BoundaryBox(*self.unproject(np.stack([xy.min_, xy.max_])))
        xy = np.asarray(xy)
        x, y = xy[..., 0], xy[..., 1]
        if xy.shape[-1] == 3:
            z = xy[..., -1]
        else:
            z = np.zeros_like(x)
        lat, lon, alt = self.converter.to_lla(x, y, z)
        return np.stack([lat, lon] + ([alt] if return_z else []), -1)
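A round-trip sketch of the projection utilities (not part of the commit; coordinates are illustrative):

# Sketch: map WGS84 coordinates into the local metric frame and back.
import numpy as np

from utils.geo import BoundaryBox, Projection

proj = Projection(48.8566, 2.3522)              # reference at central Paris
xy = proj.project(np.array([48.8576, 2.3542]))  # meters east/north of the reference
latlon = proj.unproject(xy)                     # back to (lat, lon)

bbox = BoundaryBox(xy - 64, xy + 64)            # 128 m box centered on the point
print(bbox.size, bbox.contains(xy))             # [128. 128.] True
bbox_geo = proj.unproject(bbox)                 # the same box as lat/lon corners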