wangerniu committed
Commit 629144d · 1 Parent(s): 5de8ec7
This view is limited to 50 files because it contains too many changes. See raw diff.
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .idea
+ temp
+ temp.py
+ weight
conf/maplocnet.yaml ADDED
@@ -0,0 +1,100 @@
+ data:
+   root: '/root/DATASET/UAV2MAP/UAV/'
+   train_citys:
+     - Paris
+     - Berlin
+     - London
+     - Tokyo
+     - NewYork
+   val_citys:
+     - Toronto
+   image_size: 256
+   train:
+     batch_size: 12
+     num_workers: 4
+   val:
+     batch_size: ${..train.batch_size}
+     num_workers: ${.batch_size}
+   num_classes:
+     areas: 7
+     ways: 10
+     nodes: 33
+   pixel_per_meter: 1
+   crop_size_meters: 64
+   max_init_error: 48
+   add_map_mask: true
+   resize_image: 512
+   pad_to_square: true
+   rectify_pitch: true
+   augmentation:
+     rot90: true
+     flip: true
+     image:
+       apply: true
+       brightness: 0.5
+       contrast: 0.4
+       saturation: 0.4
+       hue: 0.159  # = 0.5/3.14; the original key had a stray quote and YAML does not evaluate arithmetic
+ model:
+   image_size: ${data.image_size}
+   latent_dim: 128
+   val_citys: ${data.val_citys}
+   image_encoder:
+     name: feature_extractor_v2
+     backbone:
+       encoder: resnet50
+       pretrained: true
+       output_dim: 8
+       num_downsample: null
+       remove_stride_from_first_conv: false
+   name: orienternet
+   matching_dim: 8
+   z_max: 32
+   x_max: 32
+   pixel_per_meter: 1
+   num_scale_bins: 33
+   num_rotations: 64
+   map_encoder:
+     embedding_dim: 16
+     output_dim: 8
+     num_classes:
+       areas: 7
+       ways: 10
+       nodes: 33
+     backbone:
+       encoder: vgg19
+       pretrained: false
+       output_scales:
+         - 0
+       num_downsample: 3
+       decoder:
+         - 128
+         - 64
+         - 64
+       padding: replicate
+     unary_prior: false
+   bev_net:
+     num_blocks: 4
+     latent_dim: 128
+     output_dim: 8
+     confidence: true
+ experiment:
+   name: maplocanet_0906_diffhight
+   gpus: 6
+   seed: 0
+ training:
+   lr: 0.0001
+   lr_scheduler: null
+   finetune_from_checkpoint: null
+   trainer:
+     val_check_interval: 1000
+     log_every_n_steps: 100
+     # limit_val_batches: 1000
+     max_steps: 200000
+     devices: ${experiment.gpus}
+   checkpointing:
+     monitor: "loss/total/val"
+     save_top_k: 10
+     mode: min
+     # filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
dataset/UAV/dataset.py ADDED
@@ -0,0 +1,116 @@
+ import torch
+ from torch.utils.data import Dataset
+ import os
+ import cv2
+ # @Time    : 2023-02-13 22:56
+ # @Author  : Wang Zhen
+ # @Email   : frozenzhencola@163.com
+ # @File    : SatelliteTool.py
+ # @Project : TGRS_seqmatch_2023_1
+ import numpy as np
+ import random
+ from utils.geo import BoundaryBox, Projection
+ from osm.tiling import TileManager, MapTileManager
+ from pathlib import Path
+ from torchvision import transforms
+ from torch.utils.data import DataLoader
+
+
+ class UavMapPair(Dataset):
+     def __init__(
+         self,
+         root: Path,
+         city: str,
+         training: bool,
+         transform
+     ):
+         super().__init__()
+
+         # self.root = root
+         # city = 'Manhattan'
+         # root = '/root/DATASET/CrossModel/'
+         # root = Path(root)
+         self.uav_image_path = root / city / 'uav'
+         self.map_path = root / city / 'map'
+         self.map_vis = root / city / 'map_vis'
+         info_path = root / city / 'info.csv'
+
+         self.info = np.loadtxt(str(info_path), dtype=str, delimiter=",", skiprows=1)
+
+         self.transform = transform
+         self.training = training
+
+     def random_center_crop(self, image):
+         height, width = image.shape[:2]
+
+         # Draw a random crop size
+         crop_size = random.randint(min(height, width) // 2, min(height, width))
+
+         # Compute the top-left corner of the crop
+         start_x = (width - crop_size) // 2
+         start_y = (height - crop_size) // 2
+
+         # Crop the image
+         cropped_image = image[start_y:start_y + crop_size, start_x:start_x + crop_size]
+
+         return cropped_image
+
+     def __getitem__(self, index: int):
+         id, uav_name, map_name, \
+             uav_long, uav_lat, \
+             map_long, map_lat, \
+             tile_size_meters, pixel_per_meter, \
+             u, v, yaw, dis = self.info[index]
+
+         uav_image = cv2.imread(str(self.uav_image_path / uav_name))
+         if self.training:
+             uav_image = self.random_center_crop(uav_image)
+         uav_image = cv2.cvtColor(uav_image, cv2.COLOR_BGR2RGB)
+         if self.transform:
+             uav_image = self.transform(uav_image)
+         map = np.load(str(self.map_path / map_name))
+
+         return {
+             'map': torch.from_numpy(np.ascontiguousarray(map)).long(),
+             'image': torch.tensor(uav_image),
+             'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float(),
+             'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float(),
+             "uv": torch.tensor([float(u), float(v)]).float(),
+         }
+
+     def __len__(self):
+         return len(self.info)
+
+
+ if __name__ == '__main__':
+     root = Path('/root/DATASET/OrienterNet/UavMap/')
+     city = 'NewYork'
+
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Resize(256),
+         transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+     ])
+
+     dataset = UavMapPair(
+         root=root,
+         city=city,
+         training=False,  # required by the constructor; the argument was missing in the original call
+         transform=transform
+     )
+     datasetloder = DataLoader(dataset, batch_size=3)
+     for batch, i in enumerate(datasetloder):
+         pass
+         # Convert the PyTorch tensor back to a PIL image
+         # pil_image = Image.fromarray(i['uav_image'][0].permute(1, 2, 0).byte().numpy())
+
+         # Display the image: convert the PyTorch tensor to a NumPy array
+         # numpy_array = i['uav_image'][0].numpy()
+         #
+         # plt.imshow(numpy_array.transpose(1, 2, 0))
+         # plt.axis('off')
+         # plt.show()
+         #
+         # map_viz, label = Colormap.apply(i['map'][0])
+         # map_viz = map_viz * 255
+         # map_viz = map_viz.astype(np.uint8)
+         # plot_images([map_viz], titles=["OpenStreetMap raster"])
dataset/UAV/prepara_dataset.py ADDED
@@ -0,0 +1,270 @@
+ import torch
+ from torch.utils.data import Dataset
+ import os
+ import cv2
+ # @Time    : 2023-02-13 22:56
+ # @Author  : Wang Zhen
+ # @Email   : frozenzhencola@163.com
+ # @File    : SatelliteTool.py
+ # @Project : TGRS_seqmatch_2023_1
+ import numpy as np
+ import random
+ from utils.geo import BoundaryBox, Projection
+ from osm.tiling import TileManager, MapTileManager
+ from pathlib import Path
+ from torchvision import transforms
+ from tqdm import tqdm
+ import time
+ import math
+ from geopy import Point, distance
+ from osm.viz import Colormap, plot_nodes
+
+
+ def generate_random_coordinate(latitude, longitude, dis):
+     # Draw a random bearing
+     random_angle = random.uniform(0, 360)
+     # print("random_angle", random_angle)
+     # Compute the latitude/longitude of the destination point
+     start_point = Point(latitude, longitude)
+     destination = distance.distance(kilometers=dis / 1000).destination(start_point, random_angle)
+
+     return destination.latitude, destination.longitude
+
+
+ def rotate_corp(src, angle):
+     # Height, width and number of channels of the source image
+     rows, cols, channel = src.shape
+
+     # Rotate around the image center
+     # Arguments: rotation center, rotation angle, scale
+     M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
+     # rows, cols = 700, 700
+     # Adapt the output size to the rotated image
+     cos = np.abs(M[0, 0])
+     sin = np.abs(M[0, 1])
+     new_w = rows * sin + cols * cos
+     new_h = rows * cos + cols * sin
+     M[0, 2] += (new_w - cols) * 0.5
+     M[1, 2] += (new_h - rows) * 0.5
+     w = int(np.round(new_w))
+     h = int(np.round(new_h))
+     rotated = cv2.warpAffine(src, M, (w, h))
+
+     # rotated = cv2.warpAffine(src, M, (cols, rows))
+
+     c = int(w / 2)
+     w = int(rows * math.sqrt(2) / 4)
+     rotated2 = rotated[c - w:c + w, c - w:c + w, :]
+     return rotated2
+
+
+ class SatelliteGeoTools:
+     """
+     Reads the .tfw file of a satellite image and converts between
+     pixel coordinates, Mercator coordinates and GPS coordinates.
+     """
+     def __init__(self, tfw_path):
+         self.SatelliteParameter = self.Parsetfw(tfw_path)
+
+     def Parsetfw(self, tfw_path):
+         info = []
+         f = open(tfw_path)
+         for _ in range(6):
+             line = f.readline()
+             line = line.strip('\n')
+             info.append(float(line))
+         f.close()
+         return info
+
+     def Pix2Geo(self, x, y):
+         A, D, B, E, C, F = self.SatelliteParameter
+         x1 = A * x + B * y + C
+         y1 = D * x + E * y + F
+         # print(x1, y1)
+         s_long, s_lat = self.MercatorTolonlat(x1, y1)
+         return s_long, s_lat
+
+     def Geo2Pix(self, lon, lat):
+         """
+         https://baike.baidu.com/item/TFW%E6%A0%BC%E5%BC%8F/6273151?fr=aladdin
+         x' = Ax + By + C
+         y' = Dx + Ey + F
+         :return:
+         """
+         x1, y1 = self.LonlatToMercator(lon, lat)
+         A, D, B, E, C, F = self.SatelliteParameter
+         M = np.array([[A, B, C],
+                       [D, E, F],
+                       [0, 0, 1]])
+         M_INV = np.linalg.inv(M)
+         XY = np.matmul(M_INV, np.array([x1, y1, 1]).T)
+         return int(XY[0]), int(XY[1])
+
+     def MercatorTolonlat(self, mx, my):
+         x = mx / 20037508.3427892 * 180
+         y = my / 20037508.3427892 * 180
+         # y = 180/math.pi*(2*math.atan(math.exp(y*math.pi/180))-math.pi/2)
+         y = 180.0 / np.pi * (2.0 * np.arctan(np.exp(y * np.pi / 180.0)) - np.pi / 2.0)
+         return x, y
+
+     def LonlatToMercator(self, lon, lat):
+         x = lon * 20037508.342789 / 180
+         y = np.log(np.tan((90 + lat) * np.pi / 360)) / (np.pi / 180)
+         y = y * 20037508.34789 / 180
+         return x, y
+
+
+ def geodistance(lng1, lat1, lng2, lat2):
+     lng1, lat1, lng2, lat2 = map(np.radians, [lng1, lat1, lng2, lat2])
+     dlon = lng2 - lng1
+     dlat = lat2 - lat1
+     a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
+     distance = 2 * np.arcsin(np.sqrt(a)) * 6371 * 1000  # mean Earth radius: 6371 km
+     return distance
+
+
+ class PreparaDataset:
+     def __init__(
+         self,
+         root: Path,
+         city: str,
+         patch_size: int,
+         tile_size_meters: float
+     ):
+         super().__init__()
+
+         # self.root = root
+         # city = 'Manhattan'
+         # root = '/root/DATASET/CrossModel/'
+         imagepath = root / city / '{}.tif'.format(city)
+         tfwpath = root / city / '{}.tfw'.format(city)
+
+         self.osmpath = root / city / '{}.osm'.format(city)
+
+         self.TileManager = MapTileManager(self.osmpath)
+         image = cv2.imread(str(imagepath))
+         self.image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+         self.ST = SatelliteGeoTools(str(tfwpath))
+
+         self.patch_size = patch_size
+         self.tile_size_meters = tile_size_meters
+
+     def get_osm(self, prior_latlon, uav_latlon):
+         latlon = np.array(prior_latlon)
+         proj = Projection(*latlon)
+         center = proj.project(latlon)
+
+         uav_latlon = np.array(uav_latlon)
+
+         XY = proj.project(uav_latlon)
+         # tile_size_meters = 128
+         bbox = BoundaryBox(center, center) + self.tile_size_meters
+         # bbox = BoundaryBox(center, center)
+         # Query OpenStreetMap for this area
+         self.pixel_per_meter = 1
+         start_time = time.time()
+         canvas = self.TileManager.from_bbox(proj, bbox, self.pixel_per_meter)
+         end_time = time.time()
+         execution_time = end_time - start_time
+         # print("query time:", execution_time, "s")
+         # canvas = tiler.query(bbox)
+         XY = [XY[0] + self.tile_size_meters, -XY[1] + self.tile_size_meters]
+         return canvas, XY
+
+     def random_corp(self):
+         # Pick the top-left corner of the crop, keeping a 1000-pixel margin from the image borders
+         x = random.randint(1000, self.image.shape[1] - self.patch_size - 1000)
+         y = random.randint(1000, self.image.shape[0] - self.patch_size - 1000)
+         x1 = x + self.patch_size
+         y1 = y + self.patch_size
+         return x, x1, y, y1
+
+     def generate(self):
+         x, x1, y, y1 = self.random_corp()
+         uav_center_x, uav_center_y = int((x + x1) // 2), int((y + y1) // 2)
+         uav_center_long, uav_center_lat = self.ST.Pix2Geo(uav_center_x, uav_center_y)
+         # print(uav_center_long, uav_center_lat)
+         self.image_patch = self.image[y:y1, x:x1]
+
+         map_center_lat, map_center_long = generate_random_coordinate(uav_center_lat, uav_center_long, self.tile_size_meters)
+         map, XY = self.get_osm([map_center_lat, map_center_long], [uav_center_lat, uav_center_long])
+
+         yaw = np.random.random() * 360
+         self.image_patch = rotate_corp(self.image_patch, yaw)
+         # return self.image_patch, self.osm_patch
+         # XY = [X + self.tile_size_meters
+         return {
+             'uav_image': self.image_patch,
+             'uav_long_lat': [uav_center_long, uav_center_lat],
+             'map_long_lat': [map_center_long, map_center_lat],
+             'tile_size_meters': map.raster.shape[1],
+             'pixel_per_meter': self.pixel_per_meter,
+             'yaw': yaw,
+             'map': map.raster,
+             "uv": XY
+         }
+
+
+ if __name__ == '__main__':
+     import argparse
+
+     parser = argparse.ArgumentParser(description='manual to this script')
+     parser.add_argument('--city', type=str, default=None, required=True)
+     parser.add_argument('--num', type=int, default=10000)
+     args = parser.parse_args()
+
+     root = Path('/root/DATASET/OrienterNet/UavMap/')
+     city = args.city
+     dataset = PreparaDataset(
+         root=root,
+         city=city,
+         patch_size=512,
+         tile_size_meters=128,
+     )
+
+     uav_path = root / city / 'uav'
+     if not uav_path.exists():
+         uav_path.mkdir(parents=True)
+
+     map_path = root / city / 'map'
+     if not map_path.exists():
+         map_path.mkdir(parents=True)
+
+     map_vis_path = root / city / 'map_vis'
+     if not map_vis_path.exists():
+         map_vis_path.mkdir(parents=True)
+
+     info_path = root / city / 'info.csv'
+
+     # num = 1000
+     num = args.num
+     info = [['id', 'uav_name', 'map_name', 'uav_long', 'uav_lat', 'map_long', 'map_lat', 'tile_size_meters', 'pixel_per_meter', 'u', 'v', 'yaw']]
+     # info = []
+     for i in tqdm(range(num)):
+         data = dataset.generate()
+         # print(str(uav_path / "{:05d}.jpg".format(i)))
+
+         cv2.imwrite(str(uav_path / "{:05d}.jpg".format(i)), cv2.cvtColor(data['uav_image'], cv2.COLOR_RGB2BGR))
+
+         np.save(str(map_path / "{:05d}.npy".format(i)), data['map'])
+
+         map_viz, label = Colormap.apply(data['map'])
+         map_viz = map_viz * 255
+         map_viz = map_viz.astype(np.uint8)
+         cv2.imwrite(str(map_vis_path / "{:05d}.jpg".format(i)), cv2.cvtColor(map_viz, cv2.COLOR_RGB2BGR))
+
+         uav_center_long, uav_center_lat = data['uav_long_lat']
+         map_center_long, map_center_lat = data['map_long_lat']
+         info.append([
+             i,
+             "{:05d}.jpg".format(i),
+             "{:05d}.npy".format(i),
+             uav_center_long,
+             uav_center_lat,
+             map_center_long,
+             map_center_lat,
+             data["tile_size_meters"],
+             data["pixel_per_meter"],
+             data['uv'][0],
+             data['uv'][1],
+             data['yaw']
+         ])
+     # print(info)
+     np.savetxt(info_path, info, delimiter=',', fmt="%s")
dataset/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # from .UAV.dataset import UavMapPair
+ from .dataset import UavMapDatasetModule
+
+ # modules = {"UAV": UavMapPair}
dataset/dataset.py ADDED
@@ -0,0 +1,93 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ from copy import deepcopy
+ from pathlib import Path
+ from typing import Any, Dict, List
+ # from logger import logger
+ import numpy as np
+ # import torch
+ # import torch.utils.data as torchdata
+ # import torchvision.transforms as tvf
+ from omegaconf import DictConfig, OmegaConf
+ import pytorch_lightning as pl
+ from dataset.UAV.dataset import UavMapPair
+ # from torch.utils.data import Dataset, DataLoader
+ # from torchvision import transforms
+ from torch.utils.data import Dataset, ConcatDataset
+ from torch.utils.data import Dataset, DataLoader, random_split
+ import torchvision.transforms as tvf
+
+
+ # Custom data module, inheriting from pl.LightningDataModule
+ class UavMapDatasetModule(pl.LightningDataModule):
+
+     def __init__(self, cfg: Dict[str, Any]):
+         super().__init__()
+
+         # default_cfg = OmegaConf.create(self.default_cfg)
+         # OmegaConf.set_struct(default_cfg, True)  # cannot add new keys
+         # self.cfg = OmegaConf.merge(default_cfg, cfg)
+         self.cfg = cfg
+         # self.transform = tvf.Compose([
+         #     tvf.ToTensor(),
+         #     tvf.Resize(self.cfg.image_size),
+         #     tvf.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+         # ])
+
+         tfs = []
+         tfs.append(tvf.ToTensor())
+         tfs.append(tvf.Resize(self.cfg.image_size))
+         # Compose over a copy of the list; otherwise appending ColorJitter below
+         # would also mutate the transform list held by the validation pipeline.
+         self.val_tfs = tvf.Compose(list(tfs))
+
+         # transforms.Resize(self.cfg.image_size),
+         if cfg.augmentation.image.apply:
+             args = OmegaConf.masked_copy(
+                 cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
+             )
+             tfs.append(tvf.ColorJitter(**args))
+         self.train_tfs = tvf.Compose(tfs)
+
+         # self.train_tfs = self.transform
+         # self.val_tfs = self.transform
+         self.init()
+
+     def init(self):
+         self.train_dataset = ConcatDataset([
+             UavMapPair(root=Path(self.cfg.root), city=city, training=True, transform=self.train_tfs)
+             for city in self.cfg.train_citys
+         ])
+
+         self.val_dataset = ConcatDataset([
+             UavMapPair(root=Path(self.cfg.root), city=city, training=False, transform=self.val_tfs)
+             for city in self.cfg.val_citys
+         ])
+
+         # self.val_datasets = {
+         #     city: UavMapPair(root=Path(self.cfg.root), city=city, transform=self.val_tfs)
+         #     for city in self.cfg.val_citys
+         # }
+         # logger.info("train data len:{},val data len:{}".format(len(self.train_dataset), len(self.val_dataset)))
+         # # Split ratio
+         # train_ratio = 0.8  # fraction of samples used for training
+         # # Number of samples in each split
+         # train_size = int(len(self.dataset) * train_ratio)
+         # val_size = len(self.dataset) - train_size
+         # self.train_dataset, self.val_dataset = random_split(self.dataset, [train_size, val_size])
+
+     def train_dataloader(self):
+         train_loader = DataLoader(self.train_dataset,
+                                   batch_size=self.cfg.train.batch_size,
+                                   num_workers=self.cfg.train.num_workers,
+                                   shuffle=True, pin_memory=True)
+         return train_loader
+
+     def val_dataloader(self):
+         val_loader = DataLoader(self.val_dataset,
+                                 batch_size=self.cfg.val.batch_size,
+                                 num_workers=self.cfg.val.num_workers,
+                                 shuffle=True, pin_memory=True)
+         #
+         # my_dict = {k: v for k, v in self.val_datasets}
+         # val_loaders = {city: DataLoader(dataset,
+         #                                 batch_size=self.cfg.val.batch_size,
+         #                                 num_workers=self.cfg.val.num_workers,
+         #                                 shuffle=False, pin_memory=True) for city, dataset in self.val_datasets.items()}
+         return val_loader
dataset/image.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ from typing import Callable, Optional, Union, Sequence
+
+ import numpy as np
+ import torch
+ import torchvision.transforms.functional as tvf
+ import collections
+ from scipy.spatial.transform import Rotation
+
+ from utils.geometry import from_homogeneous, to_homogeneous
+ from utils.wrappers import Camera
+
+
+ def rectify_image(
+     image: torch.Tensor,
+     cam: Camera,
+     roll: float,
+     pitch: Optional[float] = None,
+     valid: Optional[torch.Tensor] = None,
+ ):
+     *_, h, w = image.shape
+     grid = torch.meshgrid(
+         [torch.arange(w, device=image.device), torch.arange(h, device=image.device)],
+         indexing="xy",
+     )
+     grid = torch.stack(grid, -1).to(image.dtype)
+
+     if pitch is not None:
+         args = ("ZX", (roll, pitch))
+     else:
+         args = ("Z", roll)
+     R = Rotation.from_euler(*args, degrees=True).as_matrix()
+     R = torch.from_numpy(R).to(image)
+
+     grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T
+     grid_rect = cam.denormalize(from_homogeneous(grid_rect))
+     grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1
+     rectified = torch.nn.functional.grid_sample(
+         image[None],
+         grid_norm[None],
+         align_corners=False,
+         mode="bilinear",
+     ).squeeze(0)
+     if valid is None:
+         valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
+     else:
+         valid = (
+             torch.nn.functional.grid_sample(
+                 valid[None, None].float(),
+                 grid_norm[None],
+                 align_corners=False,
+                 mode="nearest",
+             )[0, 0]
+             > 0
+         )
+     return rectified, valid
+
+
+ def resize_image(
+     image: torch.Tensor,
+     size: Union[int, Sequence, np.ndarray],
+     fn: Optional[Callable] = None,
+     camera: Optional[Camera] = None,
+     valid: np.ndarray = None,
+ ):
+     """Resize an image to a fixed size, or according to max or min edge."""
+     *_, h, w = image.shape
+     if fn is not None:
+         assert isinstance(size, int)
+         scale = size / fn(h, w)
+         h_new, w_new = int(round(h * scale)), int(round(w * scale))
+         scale = (scale, scale)
+     else:
+         if isinstance(size, (collections.abc.Sequence, np.ndarray)):
+             w_new, h_new = size
+         elif isinstance(size, int):
+             w_new = h_new = size
+         else:
+             raise ValueError(f"Incorrect new size: {size}")
+         scale = (w_new / w, h_new / h)
+     if (w, h) != (w_new, h_new):
+         mode = tvf.InterpolationMode.BILINEAR
+         image = tvf.resize(image, (h_new, w_new), interpolation=mode, antialias=True)
+         image.clip_(0, 1)
+         if camera is not None:
+             camera = camera.scale(scale)
+         if valid is not None:
+             valid = tvf.resize(
+                 valid.unsqueeze(0),
+                 (h_new, w_new),
+                 interpolation=tvf.InterpolationMode.NEAREST,
+             ).squeeze(0)
+     ret = [image, scale]
+     if camera is not None:
+         ret.append(camera)
+     if valid is not None:
+         ret.append(valid)
+     return ret
+
+
+ def pad_image(
+     image: torch.Tensor,
+     size: Union[int, Sequence, np.ndarray],
+     camera: Optional[Camera] = None,
+     valid: torch.Tensor = None,
+     crop_and_center: bool = False,
+ ):
+     if isinstance(size, int):
+         w_new = h_new = size
+     elif isinstance(size, (collections.abc.Sequence, np.ndarray)):
+         w_new, h_new = size
+     else:
+         raise ValueError(f"Incorrect new size: {size}")
+     *c, h, w = image.shape
+     if crop_and_center:
+         diff = np.array([w - w_new, h - h_new])
+         left, top = left_top = np.round(diff / 2).astype(int)
+         right, bottom = diff - left_top
+     else:
+         assert h <= h_new
+         assert w <= w_new
+         top = bottom = left = right = 0
+     slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)]
+     slice_in = np.s_[
+         ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0)
+     ]
+     if (w, h) == (w_new, h_new):
+         out = image
+     else:
+         out = torch.zeros((*c, h_new, w_new), dtype=image.dtype)
+         out[slice_out] = image[slice_in]
+         if camera is not None:
+             camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new))
+     out_valid = torch.zeros((h_new, w_new), dtype=torch.bool)
+     out_valid[slice_out] = True if valid is None else valid[slice_in]
+     if camera is not None:
+         return out, out_valid, camera
+     else:
+         return out, out_valid
dataset/torch.py ADDED
@@ -0,0 +1,111 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ import collections
+ import os
+
+ import torch
+ from torch.utils.data import get_worker_info
+ from torch.utils.data._utils.collate import (
+     default_collate_err_msg_format,
+     np_str_obj_array_pattern,
+ )
+ from lightning_fabric.utilities.seed import pl_worker_init_function
+ from lightning_utilities.core.apply_func import apply_to_collection
+ from lightning_fabric.utilities.apply_func import move_data_to_device
+
+
+ def collate(batch):
+     """Difference with PyTorch default_collate: it can stack other tensor-like objects.
+     Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+     https://github.com/cvg/pixloc
+     Released under the Apache License 2.0
+     """
+     if not isinstance(batch, list):  # no batching
+         return batch
+     elem = batch[0]
+     elem_type = type(elem)
+     if isinstance(elem, torch.Tensor):
+         out = None
+         if torch.utils.data.get_worker_info() is not None:
+             # If we're in a background process, concatenate directly into a
+             # shared memory tensor to avoid an extra copy
+             numel = sum(x.numel() for x in batch)
+             storage = elem.storage()._new_shared(numel, device=elem.device)
+             out = elem.new(storage).resize_(len(batch), *list(elem.size()))
+         return torch.stack(batch, 0, out=out)
+     elif (
+         elem_type.__module__ == "numpy"
+         and elem_type.__name__ != "str_"
+         and elem_type.__name__ != "string_"
+     ):
+         if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
+             # array of string classes and object
+             if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+                 raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+             return collate([torch.as_tensor(b) for b in batch])
+         elif elem.shape == ():  # scalars
+             return torch.as_tensor(batch)
+     elif isinstance(elem, float):
+         return torch.tensor(batch, dtype=torch.float64)
+     elif isinstance(elem, int):
+         return torch.tensor(batch)
+     elif isinstance(elem, (str, bytes)):
+         return batch
+     elif isinstance(elem, collections.abc.Mapping):
+         return {key: collate([d[key] for d in batch]) for key in elem}
+     elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
+         return elem_type(*(collate(samples) for samples in zip(*batch)))
+     elif isinstance(elem, collections.abc.Sequence):
+         # check to make sure that the elements in batch have consistent size
+         it = iter(batch)
+         elem_size = len(next(it))
+         if not all(len(elem) == elem_size for elem in it):
+             raise RuntimeError("each element in list of batch should be of equal size")
+         transposed = zip(*batch)
+         return [collate(samples) for samples in transposed]
+     else:
+         # try to stack anyway in case the object implements stacking.
+         try:
+             return torch.stack(batch, 0)
+         except TypeError as e:
+             if "expected Tensor as element" in str(e):
+                 return batch
+             else:
+                 raise e
+
+
+ def set_num_threads(nt):
+     """Force numpy and other libraries to use a limited number of threads."""
+     try:
+         import mkl
+     except ImportError:
+         pass
+     else:
+         mkl.set_num_threads(nt)
+     torch.set_num_threads(1)
+     os.environ["IPC_ENABLE"] = "1"
+     for o in [
+         "OPENBLAS_NUM_THREADS",
+         "NUMEXPR_NUM_THREADS",
+         "OMP_NUM_THREADS",
+         "MKL_NUM_THREADS",
+     ]:
+         os.environ[o] = str(nt)
+
+
+ def worker_init_fn(i):
+     info = get_worker_info()
+     pl_worker_init_function(info.id)
+     num_threads = info.dataset.cfg.get("num_threads")
+     if num_threads is not None:
+         set_num_threads(num_threads)
+
+
+ def unbatch_to_device(data, device="cpu"):
+     data = move_data_to_device(data, device)
+     data = apply_to_collection(data, torch.Tensor, lambda x: x.squeeze(0))
+     data = apply_to_collection(
+         data, list, lambda x: x[0] if len(x) == 1 and isinstance(x[0], str) else x
+     )
+     return data
demo.py ADDED
@@ -0,0 +1,354 @@
+ import matplotlib.pyplot as plt
+ # from demo import Demo, read_input_image, read_input_image_test
+ from evaluation.viz import plot_example_single
+ from dataset.torch import unbatch_to_device
+ from typing import Optional, Tuple
+ import cv2
+ import torch
+ import numpy as np
+ import time
+ from logger import logger
+ from evaluation.run import resolve_checkpoint_path, pretrained_models
+ from models.maplocnet import MapLocNet
+ from models.voting import fuse_gps, argmax_xyr
+ # from data.image import resize_image, pad_image, rectify_image
+ from osm.raster import Canvas
+ from utils.wrappers import Camera
+ from utils.io import read_image
+ from utils.geo import BoundaryBox, Projection
+ from utils.exif import EXIF
+ import requests
+ from pathlib import Path
+ from dataset.image import resize_image, pad_image, rectify_image
+ # from maploc.demo import Demo, read_input_image
+ from dataset import UavMapDatasetModule
+ import torchvision.transforms as tvf
+ from sklearn.decomposition import PCA
+ from PIL import Image
+ import pyproj  # used by xyz_to_latlon below; commented out in the original
+ # Query OpenStreetMap for this area
+ from osm.tiling import TileManager
+ from utils.viz_localization import (
+     likelihood_overlay,
+     plot_dense_rotations,
+     add_circle_inset,
+ )
+ # Show the inputs to the model: image and raster map
+ from osm.viz import Colormap, plot_nodes
+ from utils.viz_2d import plot_images
+ from utils.viz_2d import features_to_RGB
+ import random
+ from geopy.distance import geodesic
+
+
+ def vis_image_feature(F):
+     def normalize(x):
+         return x / np.linalg.norm(x, axis=-1, keepdims=True)
+
+     # F = neural_map.numpy()
+     F = F[:, 0:180, 0:180]
+     flatten = []
+     c, h, w = F.shape
+     print(F.shape)
+     F = np.rollaxis(F, 0, 3)
+     F_flat = F.reshape(-1, c)
+     flatten.append(F_flat)
+     flatten = normalize(flatten)[0]
+
+     flatten = np.nan_to_num(flatten, nan=0)
+     pca = PCA(n_components=3)
+
+     print(flatten.shape)
+     flatten = pca.fit_transform(flatten)
+     flatten = (normalize(flatten) + 1) / 2
+
+     # h, w = F.shape[-2:]
+     F_rgb, flatten = np.split(flatten, [h * w], axis=0)
+     F_rgb = F_rgb.reshape((h, w, 3))
+     return F_rgb
+
+
+ def distance(lat1, lon1, lat2, lon2):
+     point1 = (lat1, lon1)
+     point2 = (lat2, lon2)
+     distance_km = geodesic(point1, point2).meters
+     return distance_km
+
+ # # Example
+ # lat1, lon1 = 39.9, 116.4  # Beijing
+ # lat2, lon2 = 31.2, 121.5  # Shanghai
+ # distance_km = distance(lat1, lon1, lat2, lon2)
+ # print(distance_km)
+
+
+ def show_result(map_vis_image, pre_uv, pre_yaw):
+     # Create a gray mask of the same size as the original image
+     gray_mask = np.zeros_like(map_vis_image)
+     gray_mask.fill(128)  # fill with gray
+
+     # Blend the gray mask with the original image
+     image = cv2.addWeighted(map_vis_image, 1, gray_mask, 0, 0)
+     # Draw the ground truth
+
+     # Draw the prediction
+     u, v = pre_uv
+     x1, y1 = int(u), int(v)  # replace with the actual start point
+     angle = pre_yaw - 90  # replace with the actual arrow angle
+     # Compute the end point of the arrow
+     length = 20
+     x2 = int(x1 + length * np.cos(np.radians(angle)))
+     y2 = int(y1 + length * np.sin(np.radians(angle)))
+     # Draw the arrow on the image
+     cv2.arrowedLine(image, (x1, y1), (x2, y2), (0, 0, 0), 2, 5, 0, 0.3)
+     # cv2.circle(image, (x1, y1), radius=2, color=(255, 0, 255), thickness=-1)
+     return image
+
+
+ def xyz_to_latlon(x, y, z):
+     # WGS84 projection
+     wgs84 = pyproj.CRS('EPSG:4326')
+
+     # Geocentric XYZ projection
+     xyz = pyproj.CRS(f'+proj=geocent +datum=WGS84 +units=m +no_defs')
+
+     # Coordinate transformer
+     transformer = pyproj.Transformer.from_crs(xyz, wgs84)
+
+     # Transform the coordinates
+     lon, lat, _ = transformer.transform(x, y, z)
+
+     return lat, lon
+
+
+ class Demo:
+     def __init__(
+         self,
+         experiment_or_path: Optional[str] = "OrienterNet_MGL",
+         device=None,
+         **kwargs
+     ):
+         if experiment_or_path in pretrained_models:
+             experiment_or_path, _ = pretrained_models[experiment_or_path]
+         path = resolve_checkpoint_path(experiment_or_path)
+         ckpt = torch.load(path, map_location=(lambda storage, loc: storage))
+         config = ckpt["hyper_parameters"]
+         config.model.update(kwargs)
+         config.model.image_encoder.backbone.pretrained = False
+
+         model = MapLocNet(config.model).eval()
+         state = {k[len("model."):]: v for k, v in ckpt["state_dict"].items()}
+         model.load_state_dict(state, strict=True)
+         if device is None:
+             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         model = model.to(device)
+
+         self.model = model
+         self.config = config
+         self.device = device
+
+     def prepare_data(
+         self,
+         image: np.ndarray,
+         camera: Camera,
+         canvas: Canvas,
+         roll_pitch: Optional[Tuple[float]] = None,
+     ):
+         image = torch.from_numpy(image).permute(2, 0, 1).float().div_(255)
+
+         return {
+             'map': torch.from_numpy(canvas.raster).long(),
+             'image': image,
+             # 'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float().unsqueeze(0),
+             # 'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float().unsqueeze(0),
+             # "uv": torch.tensor([float(u), float(v)]).float().unsqueeze(0),
+         }
+         # return dict(
+         #     image=image,
+         #     map=torch.from_numpy(canvas.raster).long(),
+         #     camera=camera.float(),
+         #     valid=valid,
+         # )
+
+     def localize(self, image: np.ndarray, camera: Camera, canvas: Canvas, **kwargs):
+         data = self.prepare_data(image, camera, canvas, **kwargs)
+         data_ = {k: v.to(self.device)[None] for k, v in data.items()}
+         # data_np = {k: v.cpu().numpy()[None] for k, v in data.items()}
+         # logger.info(data_)
+         # np.save(data_np, 'data_.npy')
+         start = time.time()
+         with torch.no_grad():
+             pred = self.model(data_)
+
+         end = time.time()
+         xy_gps = canvas.bbox.center
+         uv_gps = torch.from_numpy(canvas.to_uv(xy_gps))
+
+         lp_xyr = pred["log_probs"].squeeze(0)
+         # tile_size = canvas.bbox.size.min() / 2
+         # sigma = tile_size - 20  # 20 meters margin
+         # lp_xyr = fuse_gps(
+         #     lp_xyr,
+         #     uv_gps.to(lp_xyr),
+         #     self.config.model.pixel_per_meter,
+         #     sigma=sigma,
+         # )
+         xyr = argmax_xyr(lp_xyr).cpu()
+
+         prob = lp_xyr.exp().cpu()
+         neural_map = pred["map"]["map_features"][0].squeeze(0).cpu()
+         print('total time:', end - start)
+         return xyr[:2], xyr[2], prob, neural_map, data["image"], data_, pred
+
+
+ def load_test_data(
+     root: Path,
+     city: str,
+     index: int,
+ ):
+     uav_image_path = root / city / 'uav'
+     map_path = root / city / 'map'
+     map_vis = root / city / 'map_vis'
+     info_path = root / city / 'info.csv'
+     osm_path = root / city / '{}.osm'.format(city)
+
+     info = np.loadtxt(str(info_path), dtype=str, delimiter=",", skiprows=1)
+
+     id, uav_name, map_name, \
+         uav_long, uav_lat, \
+         map_long, map_lat, \
+         tile_size_meters, pixel_per_meter, \
+         u, v, yaw, dis = info[index]
+     print(info[index])
+     uav_image_rgb = cv2.imread(str(uav_image_path / uav_name))
+     uav_image_rgb = cv2.cvtColor(uav_image_rgb, cv2.COLOR_BGR2RGB)
+
+     # w, h, c = uav_image_rgb.shape
+     # # Crop region
+     # x = w // 2  # start x
+     # y = h // 2  # start y
+     # w = 150  # width
+     # h = 150  # height
+     # # Crop the image
+     # uav_image_rgb = uav_image_rgb[y-h:y+h, x-w:x+w]
+
+     map_vis_image = cv2.imread(str(map_vis / uav_name))
+     map_vis_image = cv2.cvtColor(map_vis_image, cv2.COLOR_BGR2RGB)
+
+     map = np.load(str(map_path / map_name))
+
+     tfs = []
+     tfs.append(tvf.ToTensor())
+     tfs.append(tvf.Resize(256))
+     val_tfs = tvf.Compose(tfs)
+
+     uav_image = val_tfs(uav_image_rgb)
+     # print(id, uav_name, map_name,
+     #       uav_long, uav_lat,
+     #       map_long, map_lat,
+     #       tile_size_meters, pixel_per_meter,
+     #       u, v, yaw, dis)
+     uav_path = str(uav_image_path / uav_name)
+     return {
+         'map': torch.from_numpy(np.ascontiguousarray(map)).long().unsqueeze(0),
+         'image': torch.tensor(uav_image).unsqueeze(0),
+         'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float().unsqueeze(0),
+         'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float().unsqueeze(0),
+         "uv": torch.tensor([float(u), float(v)]).float().unsqueeze(0),
+     }, uav_image_rgb, map_vis_image, uav_path, [float(map_lat), float(map_long)]
+
+
+ def crop_image(image, width, height):
+     # Top-left corner of the crop
+     x = int((image.shape[1] - width) / 2)
+     y = int((image.shape[0] - height) / 2)
+
+     # Crop the image
+     cropped_image = image[y:y + height, x:x + width]
+     return cropped_image
+
+
+ def crop_square(image):
+     # Width and height of the image
+     height, width = image.shape[:2]
+
+     # Length of the shorter side
+     min_length = min(height, width)
+
+     # Coordinates of the crop region
+     top = (height - min_length) // 2
+     bottom = top + min_length
+     left = (width - min_length) // 2
+     right = left + min_length
+
+     # Crop the image to a square
+     cropped_image = image[top:bottom, left:right]
+
+     return cropped_image
+
+
+ def read_input_image_test(
+     image,
+     prior_latlon,
+     tile_size_meters,
+ ):
+     # image = read_image(image_path)
+     # # Crop the image
+     # # Width and height of the crop
+     # width = 1080 * 2
+     # height = 1080 * 2
+     # image = crop_square(image)
+     # # print("input image:", image.shape)
+     # image = crop_image(image, width, height)
+     # # print("crop_image:", image.shape)
+     image = cv2.resize(image, (256, 256))
+     roll_pitch = None
+
+     latlon = None
+     if prior_latlon is not None:
+         latlon = prior_latlon
+         logger.info("Using prior latlon %s.", prior_latlon)
+
+     if latlon is None:
+         # Fallback: read the prior location from the EXIF metadata of `image_path`
+         # (a module-level variable when this file is run as a script).
+         with open(image_path, "rb") as fid:
+             exif = EXIF(fid, lambda: image.shape[:2])
+             geo = exif.extract_geo()
+             if geo:
+                 alt = geo.get("altitude", 0)  # read if available
+                 latlon = (geo["latitude"], geo["longitude"], alt)
+                 logger.info("Using prior location from EXIF.")
+                 # print(latlon)
+             else:
+                 logger.info("Could not find any prior location in the image EXIF metadata.")
+
+     latlon = np.array(latlon)
+
+     proj = Projection(*latlon)
+     center = proj.project(latlon)
+     bbox = BoundaryBox(center, center) + float(tile_size_meters)
+     camera = None
+     image = cv2.resize(image, (256, 256))
+     return image, camera, roll_pitch, proj, bbox, latlon
+
+
+ if __name__ == '__main__':
+     experiment_or_path = "weight/last-step-checkpointing.ckpt"
+     # experiment_or_path = "experiments/maplocanet_0906_diffhight/last-step-checkpointing.ckpt"
+     image_path = 'images/00000.jpg'
+     prior_latlon = (37.75704325989902, -122.435941445631)
+     tile_size_meters = 128
+     demo = Demo(experiment_or_path=experiment_or_path, num_rotations=128, device='cpu')
+     # Read the query image first; the original passed the path where an array is expected.
+     image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+     image, camera, gravity, proj, bbox, true_prior_latlon = read_input_image_test(
+         image,
+         prior_latlon=prior_latlon,
+         tile_size_meters=tile_size_meters,  # try 64, 256, etc.
+     )
+     tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10, ppm=1, tile_size=tile_size_meters)
+     # tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10, ppm=1, path=root/city/'{}.osm'.format(city), tile_size=1)
+     canvas = tiler.query(bbox)
+     uv, yaw, prob, neural_map, image_rectified, data_, pred = demo.localize(
+         image, camera, canvas)
+     prior_latlon_pred = proj.unproject(canvas.to_xy(uv))
+     pass
evaluation/kitti.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ import argparse
+ from pathlib import Path
+ from typing import Optional, Tuple
+
+ from omegaconf import OmegaConf, DictConfig
+
+ from .. import logger
+ from ..data import KittiDataModule
+ from .run import evaluate
+
+
+ default_cfg_single = OmegaConf.create({})
+ # For the sequential evaluation, we need to center the map around the GT location,
+ # since random offsets would accumulate and leave only the GT location with a valid mask.
+ # This should not have much impact on the results.
+ default_cfg_sequential = OmegaConf.create(
+     {
+         "data": {
+             "mask_radius": KittiDataModule.default_cfg["max_init_error"],
+             "prior_range_rotation": KittiDataModule.default_cfg[
+                 "max_init_error_rotation"
+             ]
+             + 1,
+             "max_init_error": 0,
+             "max_init_error_rotation": 0,
+         },
+         "chunking": {
+             "max_length": 100,  # about 10s?
+         },
+     }
+ )
+
+
+ def run(
+     split: str,
+     experiment: str,
+     cfg: Optional[DictConfig] = None,
+     sequential: bool = False,
+     thresholds: Tuple[int] = (1, 3, 5),
+     **kwargs,
+ ):
+     cfg = cfg or {}
+     if isinstance(cfg, dict):
+         cfg = OmegaConf.create(cfg)
+     default = default_cfg_sequential if sequential else default_cfg_single
+     cfg = OmegaConf.merge(default, cfg)
+     dataset = KittiDataModule(cfg.get("data", {}))
+
+     metrics = evaluate(
+         experiment,
+         cfg,
+         dataset,
+         split=split,
+         sequential=sequential,
+         viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
+         **kwargs,
+     )
+
+     keys = ["directional_error", "yaw_max_error"]
+     if sequential:
+         keys += ["directional_seq_error", "yaw_seq_error"]
+     for k in keys:
+         rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
+         logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
+     return metrics
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--experiment", type=str, required=True)
+     parser.add_argument(
+         "--split", type=str, default="test", choices=["test", "val", "train"]
+     )
+     parser.add_argument("--sequential", action="store_true")
+     parser.add_argument("--output_dir", type=Path)
+     parser.add_argument("--num", type=int)
+     parser.add_argument("dotlist", nargs="*")
+     args = parser.parse_args()
+     cfg = OmegaConf.from_cli(args.dotlist)
+     run(
+         args.split,
+         args.experiment,
+         cfg,
+         args.sequential,
+         output_dir=args.output_dir,
+         num=args.num,
+     )
evaluation/mapillary.py ADDED
@@ -0,0 +1,111 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ import argparse
+ from pathlib import Path
+ from typing import Optional, Tuple
+
+ from omegaconf import OmegaConf, DictConfig
+
+ from .. import logger
+ from ..conf import data as conf_data_dir
+ from ..data import MapillaryDataModule
+ from .run import evaluate
+
+
+ split_overrides = {
+     "val": {
+         "scenes": [
+             "sanfrancisco_soma",
+             "sanfrancisco_hayes",
+             "amsterdam",
+             "berlin",
+             "lemans",
+             "montrouge",
+             "toulouse",
+             "nantes",
+             "vilnius",
+             "avignon",
+             "helsinki",
+             "milan",
+             "paris",
+         ],
+     },
+ }
+ data_cfg_train = OmegaConf.load(Path(conf_data_dir.__file__).parent / "mapillary.yaml")
+ data_cfg = OmegaConf.merge(
+     data_cfg_train,
+     {
+         "return_gps": True,
+         "add_map_mask": True,
+         "max_init_error": 32,
+         "loading": {"val": {"batch_size": 1, "num_workers": 0}},
+     },
+ )
+ default_cfg_single = OmegaConf.create({"data": data_cfg})
+ default_cfg_sequential = OmegaConf.create(
+     {
+         **default_cfg_single,
+         "chunking": {
+             "max_length": 10,
+         },
+     }
+ )
+
+
+ def run(
+     split: str,
+     experiment: str,
+     cfg: Optional[DictConfig] = None,
+     sequential: bool = False,
+     thresholds: Tuple[int] = (1, 3, 5),
+     **kwargs,
+ ):
+     cfg = cfg or {}
+     if isinstance(cfg, dict):
+         cfg = OmegaConf.create(cfg)
+     default = default_cfg_sequential if sequential else default_cfg_single
+     default = OmegaConf.merge(default, split_overrides[split])
+     cfg = OmegaConf.merge(default, cfg)
+     dataset = MapillaryDataModule(cfg.get("data", {}))
+
+     metrics = evaluate(experiment, cfg, dataset, split, sequential=sequential, **kwargs)
+
+     keys = [
+         "xy_max_error",
+         "xy_gps_error",
+         "yaw_max_error",
+     ]
+     if sequential:
+         keys += [
+             "xy_seq_error",
+             "xy_gps_seq_error",
+             "yaw_seq_error",
+             "yaw_gps_seq_error",
+         ]
+     for k in keys:
+         if k not in metrics:
+             logger.warning("Key %s not in metrics.", k)
+             continue
+         rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
+         logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
+     return metrics
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--experiment", type=str, required=True)
+     parser.add_argument("--split", type=str, default="val", choices=["val"])
+     parser.add_argument("--sequential", action="store_true")
+     parser.add_argument("--output_dir", type=Path)
+     parser.add_argument("--num", type=int)
+     parser.add_argument("dotlist", nargs="*")
+     args = parser.parse_args()
+     cfg = OmegaConf.from_cli(args.dotlist)
+     run(
+         args.split,
+         args.experiment,
+         cfg,
+         args.sequential,
+         output_dir=args.output_dir,
+         num=args.num,
+     )
evaluation/run.py ADDED
@@ -0,0 +1,252 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+
+ import functools
+ from itertools import islice
+ from typing import Callable, Dict, Optional, Tuple
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from omegaconf import DictConfig, OmegaConf
+ from torchmetrics import MetricCollection
+ from pytorch_lightning import seed_everything
+ from tqdm import tqdm
+
+ from logger import logger, EXPERIMENTS_PATH
+ from dataset.torch import collate, unbatch_to_device
+ from models.voting import argmax_xyr, fuse_gps
+ from models.metrics import AngleError, LateralLongitudinalError, Location2DError
+ # from models.sequential import GPSAligner, RigidAligner  # still required by evaluate_sequential below
+ from module import GenericModule
+ from utils.io import download_file, DATA_URL
+ from evaluation.viz import plot_example_single, plot_example_sequential
+ from evaluation.utils import write_dump
+
+
+ pretrained_models = dict(
+     OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)),
+ )
+
+
+ def resolve_checkpoint_path(experiment_or_path: str) -> Path:
+     path = Path(experiment_or_path)
+     if not path.exists():
+         # provided name of experiment
+         path = Path(EXPERIMENTS_PATH, *experiment_or_path.split("/"))
+         if not path.exists():
+             if experiment_or_path in set(p for p, _ in pretrained_models.values()):
+                 download_file(f"{DATA_URL}/{experiment_or_path}", path)
+             else:
+                 raise FileNotFoundError(path)
+     if path.is_file():
+         return path
+     # provided only the experiment name
+     maybe_path = path / "last-step-v1.ckpt"
+     if not maybe_path.exists():
+         maybe_path = path / "last.ckpt"
+     if not maybe_path.exists():
+         raise FileNotFoundError(f"Could not find any checkpoint in {path}.")
+     return maybe_path
+
+
+ @torch.no_grad()
+ def evaluate_single_image(
+     dataloader: torch.utils.data.DataLoader,
+     model: GenericModule,
+     num: Optional[int] = None,
+     callback: Optional[Callable] = None,
+     progress: bool = True,
+     mask_index: Optional[Tuple[int]] = None,
+     has_gps: bool = False,
+ ):
+     ppm = model.model.conf.pixel_per_meter
+     metrics = MetricCollection(model.model.metrics())
+     metrics["directional_error"] = LateralLongitudinalError(ppm)
+     if has_gps:
+         metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
+         metrics["xy_fused_error"] = Location2DError("uv_fused", ppm)
+         metrics["yaw_fused_error"] = AngleError("yaw_fused")
+     metrics = metrics.to(model.device)
+
+     for i, batch_ in enumerate(
+         islice(tqdm(dataloader, total=num, disable=not progress), num)
+     ):
+         batch = model.transfer_batch_to_device(batch_, model.device, i)
+         # Ablation: mask semantic classes
+         if mask_index is not None:
+             mask = batch["map"][0, mask_index[0]] == (mask_index[1] + 1)
+             batch["map"][0, mask_index[0]][mask] = 0
+         pred = model(batch)
+
+         if has_gps:
+             (uv_gps,) = pred["uv_gps"] = batch["uv_gps"]
+             pred["log_probs_fused"] = fuse_gps(
+                 pred["log_probs"], uv_gps, ppm, sigma=batch["accuracy_gps"]
+             )
+             uvt_fused = argmax_xyr(pred["log_probs_fused"])
+             pred["uv_fused"] = uvt_fused[..., :2]
+             pred["yaw_fused"] = uvt_fused[..., -1]
+             del uv_gps, uvt_fused
+
+         results = metrics(pred, batch)
+         if callback is not None:
+             callback(
+                 i, model, unbatch_to_device(pred), unbatch_to_device(batch_), results
+             )
+         del batch_, batch, pred, results
+
+     return metrics.cpu()
+
+
+ @torch.no_grad()
+ def evaluate_sequential(
+     dataset: torch.utils.data.Dataset,
+     chunk2idx: Dict,
+     model: GenericModule,
+     num: Optional[int] = None,
+     shuffle: bool = False,
+     callback: Optional[Callable] = None,
+     progress: bool = True,
+     num_rotations: int = 512,
+     mask_index: Optional[Tuple[int]] = None,
+     has_gps: bool = True,
+ ):
+     chunk_keys = list(chunk2idx)
+     if shuffle:
+         chunk_keys = [chunk_keys[i] for i in torch.randperm(len(chunk_keys))]
+     if num is not None:
+         chunk_keys = chunk_keys[:num]
+     lengths = [len(chunk2idx[k]) for k in chunk_keys]
+     logger.info(
+         "Min/max/med lengths: %d/%d/%d, total number of images: %d",
+         min(lengths),
+         np.median(lengths),
+         max(lengths),
+         sum(lengths),
+     )
+     viz = callback is not None
+
+     metrics = MetricCollection(model.model.metrics())
+     ppm = model.model.conf.pixel_per_meter
+     metrics["directional_error"] = LateralLongitudinalError(ppm)
+     metrics["xy_seq_error"] = Location2DError("uv_seq", ppm)
+     metrics["yaw_seq_error"] = AngleError("yaw_seq")
+     metrics["directional_seq_error"] = LateralLongitudinalError(ppm, key="uv_seq")
+     if has_gps:
+         metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
+         metrics["xy_gps_seq_error"] = Location2DError("uv_gps_seq", ppm)
+         metrics["yaw_gps_seq_error"] = AngleError("yaw_gps_seq")
+     metrics = metrics.to(model.device)
+
+     keys_save = ["uvr_max", "uv_max", "yaw_max", "uv_expectation"]
+     if has_gps:
+         keys_save.append("uv_gps")
+     if viz:
+         keys_save.append("log_probs")
+
+     for chunk_index, key in enumerate(tqdm(chunk_keys, disable=not progress)):
+         indices = chunk2idx[key]
+         aligner = RigidAligner(track_priors=viz, num_rotations=num_rotations)
+         if has_gps:
+             aligner_gps = GPSAligner(track_priors=viz, num_rotations=num_rotations)
+         batches = []
+         preds = []
+         for i in indices:
+             data = dataset[i]
+             data = model.transfer_batch_to_device(data, model.device, 0)
+             pred = model(collate([data]))
+
+             canvas = data["canvas"]
+             data["xy_geo"] = xy = canvas.to_xy(data["uv"].double())
+             data["yaw"] = yaw = data["roll_pitch_yaw"][-1].double()
+             aligner.update(pred["log_probs"][0], canvas, xy, yaw)
+
+             if has_gps:
+                 (uv_gps) = pred["uv_gps"] = data["uv_gps"][None]
+                 xy_gps = canvas.to_xy(uv_gps.double())
+                 aligner_gps.update(xy_gps, data["accuracy_gps"], canvas, xy, yaw)
+
+             if not viz:
+                 data.pop("image")
+                 data.pop("map")
+             batches.append(data)
+             preds.append({k: pred[k][0] for k in keys_save})
+             del pred
+
+         xy_gt = torch.stack([b["xy_geo"] for b in batches])
+         yaw_gt = torch.stack([b["yaw"] for b in batches])
+         aligner.compute()
+         xy_seq, yaw_seq = aligner.transform(xy_gt, yaw_gt)
+         if has_gps:
+             aligner_gps.compute()
+             xy_gps_seq, yaw_gps_seq = aligner_gps.transform(xy_gt, yaw_gt)
+         results = []
+         for i in range(len(indices)):
+             preds[i]["uv_seq"] = batches[i]["canvas"].to_uv(xy_seq[i]).float()
+             preds[i]["yaw_seq"] = yaw_seq[i].float()
+             if has_gps:
+                 preds[i]["uv_gps_seq"] = (
+                     batches[i]["canvas"].to_uv(xy_gps_seq[i]).float()
+                 )
+                 preds[i]["yaw_gps_seq"] = yaw_gps_seq[i].float()
+             results.append(metrics(preds[i], batches[i]))
+         if viz:
+             callback(chunk_index, model, batches, preds, results, aligner)
+         del aligner, preds, batches, results
+     return metrics.cpu()
+
+
+ def evaluate(
+     experiment: str,
+     cfg: DictConfig,
+     dataset,
+     split: str,
+     sequential: bool = False,
+     output_dir: Optional[Path] = None,
+     callback: Optional[Callable] = None,
+     num_workers: int = 1,
+     viz_kwargs=None,
+     **kwargs,
+ ):
+     if experiment in pretrained_models:
+         experiment, cfg_override = pretrained_models[experiment]
+         cfg = OmegaConf.merge(OmegaConf.create(dict(model=cfg_override)), cfg)
+
+     logger.info("Evaluating model %s with config %s", experiment, cfg)
+     checkpoint_path = resolve_checkpoint_path(experiment)
+     model = GenericModule.load_from_checkpoint(
+         checkpoint_path, cfg=cfg, find_best=not experiment.endswith(".ckpt")
+     )
+     model = model.eval()
+     if torch.cuda.is_available():
+         model = model.cuda()
+
+     dataset.prepare_data()
+     dataset.setup()
+
+     if output_dir is not None:
+         output_dir.mkdir(exist_ok=True, parents=True)
+         if callback is None:
+             if sequential:
+                 callback = plot_example_sequential
+             else:
+                 callback = plot_example_single
+             callback = functools.partial(
+                 callback, out_dir=output_dir, **(viz_kwargs or {})
+             )
+         kwargs = {**kwargs, "callback": callback}
+
+     seed_everything(dataset.cfg.seed)
+     if sequential:
+         dset, chunk2idx = dataset.sequence_dataset(split, **cfg.chunking)
+         metrics = evaluate_sequential(dset, chunk2idx, model, **kwargs)
+     else:
+         loader = dataset.dataloader(split, shuffle=True, num_workers=num_workers)
+         metrics = evaluate_single_image(loader, model, **kwargs)
+
+     results = metrics.compute()
+     logger.info("All results: %s", results)
+     if output_dir is not None:
+         write_dump(output_dir, experiment, cfg, results, metrics)
+         logger.info("Outputs have been written to %s.", output_dir)
+     return metrics
evaluation/utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import numpy as np
4
+ from omegaconf import OmegaConf
5
+
6
+ from utils.io import write_json
7
+
8
+
9
+ def compute_recall(errors):
10
+ num_elements = len(errors)
11
+ sort_idx = np.argsort(errors)
12
+ errors = np.array(errors.copy())[sort_idx]
13
+ recall = (np.arange(num_elements) + 1) / num_elements
14
+ recall = np.r_[0, recall]
15
+ errors = np.r_[0, errors]
16
+ return errors, recall
17
+
18
+
19
+ def compute_auc(errors, recall, thresholds):
20
+ aucs = []
21
+ for t in thresholds:
22
+ last_index = np.searchsorted(errors, t, side="right")
23
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
24
+ e = np.r_[errors[:last_index], t]
25
+ auc = np.trapz(r, x=e) / t
26
+ aucs.append(auc * 100)
27
+ return aucs
28
+
29
+
30
+ def write_dump(output_dir, experiment, cfg, results, metrics):
31
+ dump = {
32
+ "experiment": experiment,
33
+ "cfg": OmegaConf.to_container(cfg),
34
+ "results": results,
35
+ "errors": {},
36
+ }
37
+ for k, m in metrics.items():
38
+ if hasattr(m, "get_errors"):
39
+ dump["errors"][k] = m.get_errors().numpy()
40
+ write_json(output_dir / "log.json", dump)
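The two helpers above turn a list of per-sample localization errors into a cumulative recall curve and a normalized area under that curve. A minimal, self-contained check (the error values are made up for illustration):

import numpy as np

errors = [0.4, 0.9, 2.5, 4.0, 12.0]             # per-sample errors in meters
errors_sorted, recall = compute_recall(errors)  # recall[i] = fraction of samples with error <= errors_sorted[i]
aucs = compute_auc(errors_sorted, recall, thresholds=[1.0, 3.0, 5.0])
print(aucs)  # recall-AUC in percent at 1 m / 3 m / 5 m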
evaluation/viz.py ADDED
@@ -0,0 +1,178 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import numpy as np
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+
7
+ from utils.io import write_torch_image
8
+ from utils.viz_2d import plot_images, features_to_RGB, save_plot
9
+ from utils.viz_localization import (
10
+ likelihood_overlay,
11
+ plot_pose,
12
+ plot_dense_rotations,
13
+ add_circle_inset,
14
+ )
15
+ from osm.viz import Colormap, plot_nodes
16
+
17
+
18
+ def plot_example_single(
19
+ idx,
20
+ model,
21
+ pred,
22
+ data,
23
+ results,
24
+ plot_bev=True,
25
+ out_dir=None,
26
+ fig_for_paper=False,
27
+ show_gps=False,
28
+ show_fused=False,
29
+ show_dir_error=False,
30
+ show_masked_prob=False,
31
+ ):
32
+ scene, name, rasters, uv_gt = (data[k] for k in ("scene", "name", "map", "uv"))
33
+ uv_gps = data.get("uv_gps")
34
+ yaw_gt = data["roll_pitch_yaw"][-1].numpy()
35
+ image = data["image"].permute(1, 2, 0)
36
+ if "valid" in data:
37
+ image = image.masked_fill(~data["valid"].unsqueeze(-1), 0.3)
38
+
39
+ lp_uvt = lp_uv = pred["log_probs"]
40
+ if show_fused and "log_probs_fused" in pred:
41
+ lp_uvt = lp_uv = pred["log_probs_fused"]
42
+ elif not show_masked_prob and "scores_unmasked" in pred:
43
+ lp_uvt = lp_uv = pred["scores_unmasked"]
44
+ has_rotation = lp_uvt.ndim == 3
45
+ if has_rotation:
46
+ lp_uv = lp_uvt.max(-1).values
47
+ if lp_uv.min() > -np.inf:
48
+ lp_uv = lp_uv.clip(min=np.percentile(lp_uv, 1))
49
+ prob = lp_uv.exp()
50
+ uv_p, yaw_p = pred["uv_max"], pred.get("yaw_max")
51
+ if show_fused and "uv_fused" in pred:
52
+ uv_p, yaw_p = pred["uv_fused"], pred.get("yaw_fused")
53
+ feats_map = pred["map"]["map_features"][0]
54
+ (feats_map_rgb,) = features_to_RGB(feats_map.numpy())
55
+
56
+ text1 = rf'$\Delta xy$: {results["xy_max_error"]:.1f}m'
57
+ if has_rotation:
58
+ text1 += rf', $\Delta\theta$: {results["yaw_max_error"]:.1f}°'
59
+ if show_fused and "xy_fused_error" in results:
60
+ text1 += rf', $\Delta xy_{{fused}}$: {results["xy_fused_error"]:.1f}m'
61
+ text1 += rf', $\Delta\theta_{{fused}}$: {results["yaw_fused_error"]:.1f}°'
62
+ if show_dir_error and "directional_error" in results:
63
+ err_lat, err_lon = results["directional_error"]
64
+ text1 += rf", $\Delta$lateral/longitudinal={err_lat:.1f}m/{err_lon:.1f}m"
65
+ if "xy_gps_error" in results:
66
+ text1 += rf', $\Delta xy_{{GPS}}$: {results["xy_gps_error"]:.1f}m'
67
+
68
+ map_viz = Colormap.apply(rasters)
69
+ overlay = likelihood_overlay(prob.numpy(), map_viz.mean(-1, keepdims=True))
70
+ plot_images(
71
+ [image, map_viz, overlay, feats_map_rgb],
72
+ titles=[text1, "map", "likelihood", "neural map"],
73
+ dpi=75,
74
+ cmaps="jet",
75
+ )
76
+ fig = plt.gcf()
77
+ axes = fig.axes
78
+ axes[1].images[0].set_interpolation("none")
79
+ axes[2].images[0].set_interpolation("none")
80
+ Colormap.add_colorbar()
81
+ plot_nodes(1, rasters[2])
82
+
83
+ if show_gps and uv_gps is not None:
84
+ plot_pose([1], uv_gps, c="blue")
85
+ plot_pose([1], uv_gt, yaw_gt, c="red")
86
+ plot_pose([1], uv_p, yaw_p, c="k")
87
+ plot_dense_rotations(2, lp_uvt.exp())
88
+ inset_center = pred["uv_max"] if results["xy_max_error"] < 5 else uv_gt
89
+ axins = add_circle_inset(axes[2], inset_center)
90
+ axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50, zorder=15)
91
+ axes[0].text(
92
+ 0.003,
93
+ 0.003,
94
+ f"{scene}/{name}",
95
+ transform=axes[0].transAxes,
96
+ fontsize=3,
97
+ va="bottom",
98
+ ha="left",
99
+ color="w",
100
+ )
101
+ plt.show()
102
+ if out_dir is not None:
103
+ name_ = name.replace("/", "_")
104
+ p = str(out_dir / f"{scene}_{name_}_{{}}.pdf")
105
+ save_plot(p.format("pred"))
106
+ plt.close()
107
+
108
+ if fig_for_paper:
109
+ # !cp ../datasets/MGL/{scene}/images/{name}.jpg {out_dir}/{scene}_{name}.jpg
110
+ plot_images([map_viz])
111
+ plt.gca().images[0].set_interpolation("none")
112
+ plot_nodes(0, rasters[2])
113
+ plot_pose([0], uv_gt, yaw_gt, c="red")
114
+ plot_pose([0], pred["uv_max"], pred["yaw_max"], c="k")
115
+ save_plot(p.format("map"))
116
+ plt.close()
117
+ plot_images([lp_uv], cmaps="jet")
118
+ plot_dense_rotations(0, lp_uvt.exp())
119
+ save_plot(p.format("loglikelihood"), dpi=100)
120
+ plt.close()
121
+ plot_images([overlay])
122
+ plt.gca().images[0].set_interpolation("none")
123
+ axins = add_circle_inset(plt.gca(), inset_center)
124
+ axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50)
125
+ save_plot(p.format("likelihood"))
126
+ plt.close()
127
+ write_torch_image(
128
+ p.format("neuralmap").replace("pdf", "jpg"), feats_map_rgb
129
+ )
130
+ write_torch_image(p.format("image").replace("pdf", "jpg"), image.numpy())
131
+
132
+ if not plot_bev:
133
+ return
134
+
135
+ feats_q = pred["features_bev"]
136
+ mask_bev = pred["valid_bev"]
137
+ prior = None
138
+ if "log_prior" in pred["map"]:
139
+ prior = pred["map"]["log_prior"][0].sigmoid()
140
+ if "bev" in pred and "confidence" in pred["bev"]:
141
+ conf_q = pred["bev"]["confidence"]
142
+ else:
143
+ conf_q = torch.norm(feats_q, dim=0)
144
+ conf_q = conf_q.masked_fill(~mask_bev, np.nan)
145
+ (feats_q_rgb,) = features_to_RGB(feats_q.numpy(), masks=[mask_bev.numpy()])
146
+ # feats_map_rgb, feats_q_rgb, = features_to_RGB(
147
+ # feats_map.numpy(), feats_q.numpy(), masks=[None, mask_bev])
148
+ norm_map = torch.norm(feats_map, dim=0)
149
+
150
+ plot_images(
151
+ [conf_q, feats_q_rgb, norm_map] + ([] if prior is None else [prior]),
152
+ titles=["BEV confidence", "BEV features", "map norm"]
153
+ + ([] if prior is None else ["map prior"]),
154
+ dpi=50,
155
+ cmaps="jet",
156
+ )
157
+ plt.show()
158
+
159
+ if out_dir is not None:
160
+ save_plot(p.format("bev"))
161
+ plt.close()
162
+
163
+
164
+ def plot_example_sequential(
165
+ idx,
166
+ model,
167
+ pred,
168
+ data,
169
+ results,
170
+ plot_bev=True,
171
+ out_dir=None,
172
+ fig_for_paper=False,
173
+ show_gps=False,
174
+ show_fused=False,
175
+ show_dir_error=False,
176
+ show_masked_prob=False,
177
+ ):
178
+ return
flagged/inp/10d2e4a8712491181c2f48b61f5003b216d2b9f9/tmp48n9eoyh.png ADDED
flagged/inp/e1b18d44d9e381d586209f73a015fed7f688822b/tmp86ith_2q.png ADDED
flagged/log.csv ADDED
@@ -0,0 +1,3 @@
1
+ inp,longitude,latitude,Area,output,flag,username,timestamp
2
+ E:\MapLocNetDemo\Demo\flagged\inp\10d2e4a8712491181c2f48b61f5003b216d2b9f9\tmp48n9eoyh.png,70.1,40,256,E:\MapLocNetDemo\Demo\flagged\output\tmp59657zop.json,,,2023-09-22 10:07:17.488625
3
+ E:\MapLocNetDemo\Demo\flagged\inp\e1b18d44d9e381d586209f73a015fed7f688822b\tmp86ith_2q.png,70.1,40,256,E:\MapLocNetDemo\Demo\flagged\output\tmpbs17s28d.json,,,2023-09-22 10:07:21.485967
flagged/output/tmp59657zop.json ADDED
@@ -0,0 +1 @@
1
+ {"label": "bull mastiff\n", "confidences": [{"label": "bull mastiff\n", "confidence": 0.24759389460086823}, {"label": "pug\n", "confidence": 0.0916372761130333}, {"label": "Great Dane\n", "confidence": 0.08652031421661377}]}
flagged/output/tmpbs17s28d.json ADDED
@@ -0,0 +1 @@
1
+ {"label": "bull mastiff\n", "confidences": [{"label": "bull mastiff\n", "confidence": 0.24759389460086823}, {"label": "pug\n", "confidence": 0.0916372761130333}, {"label": "Great Dane\n", "confidence": 0.08652031421661377}]}
images/00000.jpg ADDED
images/00011.jpg ADDED
images/00022.jpg ADDED
images/00033.jpg ADDED
images/cat_dog.png ADDED
label.txt ADDED
@@ -0,0 +1,1000 @@
1
+ tench
2
+ goldfish
3
+ great white shark
4
+ tiger shark
5
+ hammerhead
6
+ electric ray
7
+ stingray
8
+ cock
9
+ hen
10
+ ostrich
11
+ brambling
12
+ goldfinch
13
+ house finch
14
+ junco
15
+ indigo bunting
16
+ robin
17
+ bulbul
18
+ jay
19
+ magpie
20
+ chickadee
21
+ water ouzel
22
+ kite
23
+ bald eagle
24
+ vulture
25
+ great grey owl
26
+ European fire salamander
27
+ common newt
28
+ eft
29
+ spotted salamander
30
+ axolotl
31
+ bullfrog
32
+ tree frog
33
+ tailed frog
34
+ loggerhead
35
+ leatherback turtle
36
+ mud turtle
37
+ terrapin
38
+ box turtle
39
+ banded gecko
40
+ common iguana
41
+ American chameleon
42
+ whiptail
43
+ agama
44
+ frilled lizard
45
+ alligator lizard
46
+ Gila monster
47
+ green lizard
48
+ African chameleon
49
+ Komodo dragon
50
+ African crocodile
51
+ American alligator
52
+ triceratops
53
+ thunder snake
54
+ ringneck snake
55
+ hognose snake
56
+ green snake
57
+ king snake
58
+ garter snake
59
+ water snake
60
+ vine snake
61
+ night snake
62
+ boa constrictor
63
+ rock python
64
+ Indian cobra
65
+ green mamba
66
+ sea snake
67
+ horned viper
68
+ diamondback
69
+ sidewinder
70
+ trilobite
71
+ harvestman
72
+ scorpion
73
+ black and gold garden spider
74
+ barn spider
75
+ garden spider
76
+ black widow
77
+ tarantula
78
+ wolf spider
79
+ tick
80
+ centipede
81
+ black grouse
82
+ ptarmigan
83
+ ruffed grouse
84
+ prairie chicken
85
+ peacock
86
+ quail
87
+ partridge
88
+ African grey
89
+ macaw
90
+ sulphur-crested cockatoo
91
+ lorikeet
92
+ coucal
93
+ bee eater
94
+ hornbill
95
+ hummingbird
96
+ jacamar
97
+ toucan
98
+ drake
99
+ red-breasted merganser
100
+ goose
101
+ black swan
102
+ tusker
103
+ echidna
104
+ platypus
105
+ wallaby
106
+ koala
107
+ wombat
108
+ jellyfish
109
+ sea anemone
110
+ brain coral
111
+ flatworm
112
+ nematode
113
+ conch
114
+ snail
115
+ slug
116
+ sea slug
117
+ chiton
118
+ chambered nautilus
119
+ Dungeness crab
120
+ rock crab
121
+ fiddler crab
122
+ king crab
123
+ American lobster
124
+ spiny lobster
125
+ crayfish
126
+ hermit crab
127
+ isopod
128
+ white stork
129
+ black stork
130
+ spoonbill
131
+ flamingo
132
+ little blue heron
133
+ American egret
134
+ bittern
135
+ crane
136
+ limpkin
137
+ European gallinule
138
+ American coot
139
+ bustard
140
+ ruddy turnstone
141
+ red-backed sandpiper
142
+ redshank
143
+ dowitcher
144
+ oystercatcher
145
+ pelican
146
+ king penguin
147
+ albatross
148
+ grey whale
149
+ killer whale
150
+ dugong
151
+ sea lion
152
+ Chihuahua
153
+ Japanese spaniel
154
+ Maltese dog
155
+ Pekinese
156
+ Shih-Tzu
157
+ Blenheim spaniel
158
+ papillon
159
+ toy terrier
160
+ Rhodesian ridgeback
161
+ Afghan hound
162
+ basset
163
+ beagle
164
+ bloodhound
165
+ bluetick
166
+ black-and-tan coonhound
167
+ Walker hound
168
+ English foxhound
169
+ redbone
170
+ borzoi
171
+ Irish wolfhound
172
+ Italian greyhound
173
+ whippet
174
+ Ibizan hound
175
+ Norwegian elkhound
176
+ otterhound
177
+ Saluki
178
+ Scottish deerhound
179
+ Weimaraner
180
+ Staffordshire bullterrier
181
+ American Staffordshire terrier
182
+ Bedlington terrier
183
+ Border terrier
184
+ Kerry blue terrier
185
+ Irish terrier
186
+ Norfolk terrier
187
+ Norwich terrier
188
+ Yorkshire terrier
189
+ wire-haired fox terrier
190
+ Lakeland terrier
191
+ Sealyham terrier
192
+ Airedale
193
+ cairn
194
+ Australian terrier
195
+ Dandie Dinmont
196
+ Boston bull
197
+ miniature schnauzer
198
+ giant schnauzer
199
+ standard schnauzer
200
+ Scotch terrier
201
+ Tibetan terrier
202
+ silky terrier
203
+ soft-coated wheaten terrier
204
+ West Highland white terrier
205
+ Lhasa
206
+ flat-coated retriever
207
+ curly-coated retriever
208
+ golden retriever
209
+ Labrador retriever
210
+ Chesapeake Bay retriever
211
+ German short-haired pointer
212
+ vizsla
213
+ English setter
214
+ Irish setter
215
+ Gordon setter
216
+ Brittany spaniel
217
+ clumber
218
+ English springer
219
+ Welsh springer spaniel
220
+ cocker spaniel
221
+ Sussex spaniel
222
+ Irish water spaniel
223
+ kuvasz
224
+ schipperke
225
+ groenendael
226
+ malinois
227
+ briard
228
+ kelpie
229
+ komondor
230
+ Old English sheepdog
231
+ Shetland sheepdog
232
+ collie
233
+ Border collie
234
+ Bouvier des Flandres
235
+ Rottweiler
236
+ German shepherd
237
+ Doberman
238
+ miniature pinscher
239
+ Greater Swiss Mountain dog
240
+ Bernese mountain dog
241
+ Appenzeller
242
+ EntleBucher
243
+ boxer
244
+ bull mastiff
245
+ Tibetan mastiff
246
+ French bulldog
247
+ Great Dane
248
+ Saint Bernard
249
+ Eskimo dog
250
+ malamute
251
+ Siberian husky
252
+ dalmatian
253
+ affenpinscher
254
+ basenji
255
+ pug
256
+ Leonberg
257
+ Newfoundland
258
+ Great Pyrenees
259
+ Samoyed
260
+ Pomeranian
261
+ chow
262
+ keeshond
263
+ Brabancon griffon
264
+ Pembroke
265
+ Cardigan
266
+ toy poodle
267
+ miniature poodle
268
+ standard poodle
269
+ Mexican hairless
270
+ timber wolf
271
+ white wolf
272
+ red wolf
273
+ coyote
274
+ dingo
275
+ dhole
276
+ African hunting dog
277
+ hyena
278
+ red fox
279
+ kit fox
280
+ Arctic fox
281
+ grey fox
282
+ tabby
283
+ tiger cat
284
+ Persian cat
285
+ Siamese cat
286
+ Egyptian cat
287
+ cougar
288
+ lynx
289
+ leopard
290
+ snow leopard
291
+ jaguar
292
+ lion
293
+ tiger
294
+ cheetah
295
+ brown bear
296
+ American black bear
297
+ ice bear
298
+ sloth bear
299
+ mongoose
300
+ meerkat
301
+ tiger beetle
302
+ ladybug
303
+ ground beetle
304
+ long-horned beetle
305
+ leaf beetle
306
+ dung beetle
307
+ rhinoceros beetle
308
+ weevil
309
+ fly
310
+ bee
311
+ ant
312
+ grasshopper
313
+ cricket
314
+ walking stick
315
+ cockroach
316
+ mantis
317
+ cicada
318
+ leafhopper
319
+ lacewing
320
+ dragonfly
321
+ damselfly
322
+ admiral
323
+ ringlet
324
+ monarch
325
+ cabbage butterfly
326
+ sulphur butterfly
327
+ lycaenid
328
+ starfish
329
+ sea urchin
330
+ sea cucumber
331
+ wood rabbit
332
+ hare
333
+ Angora
334
+ hamster
335
+ porcupine
336
+ fox squirrel
337
+ marmot
338
+ beaver
339
+ guinea pig
340
+ sorrel
341
+ zebra
342
+ hog
343
+ wild boar
344
+ warthog
345
+ hippopotamus
346
+ ox
347
+ water buffalo
348
+ bison
349
+ ram
350
+ bighorn
351
+ ibex
352
+ hartebeest
353
+ impala
354
+ gazelle
355
+ Arabian camel
356
+ llama
357
+ weasel
358
+ mink
359
+ polecat
360
+ black-footed ferret
361
+ otter
362
+ skunk
363
+ badger
364
+ armadillo
365
+ three-toed sloth
366
+ orangutan
367
+ gorilla
368
+ chimpanzee
369
+ gibbon
370
+ siamang
371
+ guenon
372
+ patas
373
+ baboon
374
+ macaque
375
+ langur
376
+ colobus
377
+ proboscis monkey
378
+ marmoset
379
+ capuchin
380
+ howler monkey
381
+ titi
382
+ spider monkey
383
+ squirrel monkey
384
+ Madagascar cat
385
+ indri
386
+ Indian elephant
387
+ African elephant
388
+ lesser panda
389
+ giant panda
390
+ barracouta
391
+ eel
392
+ coho
393
+ rock beauty
394
+ anemone fish
395
+ sturgeon
396
+ gar
397
+ lionfish
398
+ puffer
399
+ abacus
400
+ abaya
401
+ academic gown
402
+ accordion
403
+ acoustic guitar
404
+ aircraft carrier
405
+ airliner
406
+ airship
407
+ altar
408
+ ambulance
409
+ amphibian
410
+ analog clock
411
+ apiary
412
+ apron
413
+ ashcan
414
+ assault rifle
415
+ backpack
416
+ bakery
417
+ balance beam
418
+ balloon
419
+ ballpoint
420
+ Band Aid
421
+ banjo
422
+ bannister
423
+ barbell
424
+ barber chair
425
+ barbershop
426
+ barn
427
+ barometer
428
+ barrel
429
+ barrow
430
+ baseball
431
+ basketball
432
+ bassinet
433
+ bassoon
434
+ bathing cap
435
+ bath towel
436
+ bathtub
437
+ beach wagon
438
+ beacon
439
+ beaker
440
+ bearskin
441
+ beer bottle
442
+ beer glass
443
+ bell cote
444
+ bib
445
+ bicycle-built-for-two
446
+ bikini
447
+ binder
448
+ binoculars
449
+ birdhouse
450
+ boathouse
451
+ bobsled
452
+ bolo tie
453
+ bonnet
454
+ bookcase
455
+ bookshop
456
+ bottlecap
457
+ bow
458
+ bow tie
459
+ brass
460
+ brassiere
461
+ breakwater
462
+ breastplate
463
+ broom
464
+ bucket
465
+ buckle
466
+ bulletproof vest
467
+ bullet train
468
+ butcher shop
469
+ cab
470
+ caldron
471
+ candle
472
+ cannon
473
+ canoe
474
+ can opener
475
+ cardigan
476
+ car mirror
477
+ carousel
478
+ carpenter's kit
479
+ carton
480
+ car wheel
481
+ cash machine
482
+ cassette
483
+ cassette player
484
+ castle
485
+ catamaran
486
+ CD player
487
+ cello
488
+ cellular telephone
489
+ chain
490
+ chainlink fence
491
+ chain mail
492
+ chain saw
493
+ chest
494
+ chiffonier
495
+ chime
496
+ china cabinet
497
+ Christmas stocking
498
+ church
499
+ cinema
500
+ cleaver
501
+ cliff dwelling
502
+ cloak
503
+ clog
504
+ cocktail shaker
505
+ coffee mug
506
+ coffeepot
507
+ coil
508
+ combination lock
509
+ computer keyboard
510
+ confectionery
511
+ container ship
512
+ convertible
513
+ corkscrew
514
+ cornet
515
+ cowboy boot
516
+ cowboy hat
517
+ cradle
518
+ crane
519
+ crash helmet
520
+ crate
521
+ crib
522
+ Crock Pot
523
+ croquet ball
524
+ crutch
525
+ cuirass
526
+ dam
527
+ desk
528
+ desktop computer
529
+ dial telephone
530
+ diaper
531
+ digital clock
532
+ digital watch
533
+ dining table
534
+ dishrag
535
+ dishwasher
536
+ disk brake
537
+ dock
538
+ dogsled
539
+ dome
540
+ doormat
541
+ drilling platform
542
+ drum
543
+ drumstick
544
+ dumbbell
545
+ Dutch oven
546
+ electric fan
547
+ electric guitar
548
+ electric locomotive
549
+ entertainment center
550
+ envelope
551
+ espresso maker
552
+ face powder
553
+ feather boa
554
+ file
555
+ fireboat
556
+ fire engine
557
+ fire screen
558
+ flagpole
559
+ flute
560
+ folding chair
561
+ football helmet
562
+ forklift
563
+ fountain
564
+ fountain pen
565
+ four-poster
566
+ freight car
567
+ French horn
568
+ frying pan
569
+ fur coat
570
+ garbage truck
571
+ gasmask
572
+ gas pump
573
+ goblet
574
+ go-kart
575
+ golf ball
576
+ golfcart
577
+ gondola
578
+ gong
579
+ gown
580
+ grand piano
581
+ greenhouse
582
+ grille
583
+ grocery store
584
+ guillotine
585
+ hair slide
586
+ hair spray
587
+ half track
588
+ hammer
589
+ hamper
590
+ hand blower
591
+ hand-held computer
592
+ handkerchief
593
+ hard disc
594
+ harmonica
595
+ harp
596
+ harvester
597
+ hatchet
598
+ holster
599
+ home theater
600
+ honeycomb
601
+ hook
602
+ hoopskirt
603
+ horizontal bar
604
+ horse cart
605
+ hourglass
606
+ iPod
607
+ iron
608
+ jack-o'-lantern
609
+ jean
610
+ jeep
611
+ jersey
612
+ jigsaw puzzle
613
+ jinrikisha
614
+ joystick
615
+ kimono
616
+ knee pad
617
+ knot
618
+ lab coat
619
+ ladle
620
+ lampshade
621
+ laptop
622
+ lawn mower
623
+ lens cap
624
+ letter opener
625
+ library
626
+ lifeboat
627
+ lighter
628
+ limousine
629
+ liner
630
+ lipstick
631
+ Loafer
632
+ lotion
633
+ loudspeaker
634
+ loupe
635
+ lumbermill
636
+ magnetic compass
637
+ mailbag
638
+ mailbox
639
+ maillot
640
+ maillot
641
+ manhole cover
642
+ maraca
643
+ marimba
644
+ mask
645
+ matchstick
646
+ maypole
647
+ maze
648
+ measuring cup
649
+ medicine chest
650
+ megalith
651
+ microphone
652
+ microwave
653
+ military uniform
654
+ milk can
655
+ minibus
656
+ miniskirt
657
+ minivan
658
+ missile
659
+ mitten
660
+ mixing bowl
661
+ mobile home
662
+ Model T
663
+ modem
664
+ monastery
665
+ monitor
666
+ moped
667
+ mortar
668
+ mortarboard
669
+ mosque
670
+ mosquito net
671
+ motor scooter
672
+ mountain bike
673
+ mountain tent
674
+ mouse
675
+ mousetrap
676
+ moving van
677
+ muzzle
678
+ nail
679
+ neck brace
680
+ necklace
681
+ nipple
682
+ notebook
683
+ obelisk
684
+ oboe
685
+ ocarina
686
+ odometer
687
+ oil filter
688
+ organ
689
+ oscilloscope
690
+ overskirt
691
+ oxcart
692
+ oxygen mask
693
+ packet
694
+ paddle
695
+ paddlewheel
696
+ padlock
697
+ paintbrush
698
+ pajama
699
+ palace
700
+ panpipe
701
+ paper towel
702
+ parachute
703
+ parallel bars
704
+ park bench
705
+ parking meter
706
+ passenger car
707
+ patio
708
+ pay-phone
709
+ pedestal
710
+ pencil box
711
+ pencil sharpener
712
+ perfume
713
+ Petri dish
714
+ photocopier
715
+ pick
716
+ pickelhaube
717
+ picket fence
718
+ pickup
719
+ pier
720
+ piggy bank
721
+ pill bottle
722
+ pillow
723
+ ping-pong ball
724
+ pinwheel
725
+ pirate
726
+ pitcher
727
+ plane
728
+ planetarium
729
+ plastic bag
730
+ plate rack
731
+ plow
732
+ plunger
733
+ Polaroid camera
734
+ pole
735
+ police van
736
+ poncho
737
+ pool table
738
+ pop bottle
739
+ pot
740
+ potter's wheel
741
+ power drill
742
+ prayer rug
743
+ printer
744
+ prison
745
+ projectile
746
+ projector
747
+ puck
748
+ punching bag
749
+ purse
750
+ quill
751
+ quilt
752
+ racer
753
+ racket
754
+ radiator
755
+ radio
756
+ radio telescope
757
+ rain barrel
758
+ recreational vehicle
759
+ reel
760
+ reflex camera
761
+ refrigerator
762
+ remote control
763
+ restaurant
764
+ revolver
765
+ rifle
766
+ rocking chair
767
+ rotisserie
768
+ rubber eraser
769
+ rugby ball
770
+ rule
771
+ running shoe
772
+ safe
773
+ safety pin
774
+ saltshaker
775
+ sandal
776
+ sarong
777
+ sax
778
+ scabbard
779
+ scale
780
+ school bus
781
+ schooner
782
+ scoreboard
783
+ screen
784
+ screw
785
+ screwdriver
786
+ seat belt
787
+ sewing machine
788
+ shield
789
+ shoe shop
790
+ shoji
791
+ shopping basket
792
+ shopping cart
793
+ shovel
794
+ shower cap
795
+ shower curtain
796
+ ski
797
+ ski mask
798
+ sleeping bag
799
+ slide rule
800
+ sliding door
801
+ slot
802
+ snorkel
803
+ snowmobile
804
+ snowplow
805
+ soap dispenser
806
+ soccer ball
807
+ sock
808
+ solar dish
809
+ sombrero
810
+ soup bowl
811
+ space bar
812
+ space heater
813
+ space shuttle
814
+ spatula
815
+ speedboat
816
+ spider web
817
+ spindle
818
+ sports car
819
+ spotlight
820
+ stage
821
+ steam locomotive
822
+ steel arch bridge
823
+ steel drum
824
+ stethoscope
825
+ stole
826
+ stone wall
827
+ stopwatch
828
+ stove
829
+ strainer
830
+ streetcar
831
+ stretcher
832
+ studio couch
833
+ stupa
834
+ submarine
835
+ suit
836
+ sundial
837
+ sunglass
838
+ sunglasses
839
+ sunscreen
840
+ suspension bridge
841
+ swab
842
+ sweatshirt
843
+ swimming trunks
844
+ swing
845
+ switch
846
+ syringe
847
+ table lamp
848
+ tank
849
+ tape player
850
+ teapot
851
+ teddy
852
+ television
853
+ tennis ball
854
+ thatch
855
+ theater curtain
856
+ thimble
857
+ thresher
858
+ throne
859
+ tile roof
860
+ toaster
861
+ tobacco shop
862
+ toilet seat
863
+ torch
864
+ totem pole
865
+ tow truck
866
+ toyshop
867
+ tractor
868
+ trailer truck
869
+ tray
870
+ trench coat
871
+ tricycle
872
+ trimaran
873
+ tripod
874
+ triumphal arch
875
+ trolleybus
876
+ trombone
877
+ tub
878
+ turnstile
879
+ typewriter keyboard
880
+ umbrella
881
+ unicycle
882
+ upright
883
+ vacuum
884
+ vase
885
+ vault
886
+ velvet
887
+ vending machine
888
+ vestment
889
+ viaduct
890
+ violin
891
+ volleyball
892
+ waffle iron
893
+ wall clock
894
+ wallet
895
+ wardrobe
896
+ warplane
897
+ washbasin
898
+ washer
899
+ water bottle
900
+ water jug
901
+ water tower
902
+ whiskey jug
903
+ whistle
904
+ wig
905
+ window screen
906
+ window shade
907
+ Windsor tie
908
+ wine bottle
909
+ wing
910
+ wok
911
+ wooden spoon
912
+ wool
913
+ worm fence
914
+ wreck
915
+ yawl
916
+ yurt
917
+ web site
918
+ comic book
919
+ crossword puzzle
920
+ street sign
921
+ traffic light
922
+ book jacket
923
+ menu
924
+ plate
925
+ guacamole
926
+ consomme
927
+ hot pot
928
+ trifle
929
+ ice cream
930
+ ice lolly
931
+ French loaf
932
+ bagel
933
+ pretzel
934
+ cheeseburger
935
+ hotdog
936
+ mashed potato
937
+ head cabbage
938
+ broccoli
939
+ cauliflower
940
+ zucchini
941
+ spaghetti squash
942
+ acorn squash
943
+ butternut squash
944
+ cucumber
945
+ artichoke
946
+ bell pepper
947
+ cardoon
948
+ mushroom
949
+ Granny Smith
950
+ strawberry
951
+ orange
952
+ lemon
953
+ fig
954
+ pineapple
955
+ banana
956
+ jackfruit
957
+ custard apple
958
+ pomegranate
959
+ hay
960
+ carbonara
961
+ chocolate sauce
962
+ dough
963
+ meat loaf
964
+ pizza
965
+ potpie
966
+ burrito
967
+ red wine
968
+ espresso
969
+ cup
970
+ eggnog
971
+ alp
972
+ bubble
973
+ cliff
974
+ coral reef
975
+ geyser
976
+ lakeside
977
+ promontory
978
+ sandbar
979
+ seashore
980
+ valley
981
+ volcano
982
+ ballplayer
983
+ groom
984
+ scuba diver
985
+ rapeseed
986
+ daisy
987
+ yellow lady's slipper
988
+ corn
989
+ acorn
990
+ hip
991
+ buckeye
992
+ coral fungus
993
+ agaric
994
+ gyromitra
995
+ stinkhorn
996
+ earthstar
997
+ hen-of-the-woods
998
+ bolete
999
+ ear
1000
+ toilet tissue
logger.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from pathlib import Path
4
+ import logging
5
+
6
+ import pytorch_lightning # noqa: F401
7
+
8
+
9
+ formatter = logging.Formatter(
10
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
11
+ datefmt="%Y-%m-%d %H:%M:%S",
12
+ )
13
+ handler = logging.StreamHandler()
14
+ handler.setFormatter(formatter)
15
+ handler.setLevel(logging.INFO)
16
+
17
+ logger = logging.getLogger("maploc")
18
+ logger.setLevel(logging.INFO)
19
+ logger.addHandler(handler)
20
+ logger.propagate = False
21
+
22
+ pl_logger = logging.getLogger("pytorch_lightning")
23
+ if len(pl_logger.handlers):
24
+ pl_logger.handlers[0].setFormatter(formatter)
25
+
26
+ repo_dir = Path(__file__).parent
27
+ EXPERIMENTS_PATH = repo_dir / "experiments/"
28
+ DATASETS_PATH = repo_dir / "datasets/"
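Other modules can reuse this configuration through the standard logging hierarchy. A small usage sketch (the child logger name is illustrative and assumes the repository root is on sys.path):

import logging

import logger  # noqa: F401  # importing the module above installs the "maploc" handler

log = logging.getLogger("maploc.my_module")  # child loggers propagate up to "maploc"
log.info("this message uses the formatter configured above")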
main.py ADDED
@@ -0,0 +1,98 @@
1
+ import gradio as gr
2
+ import cv2
3
+ import gradio as gr
4
+ import torch
5
+ from torchvision import transforms
6
+ import requests
7
+ from PIL import Image
8
+ from demo import Demo,read_input_image_test,show_result,vis_image_feature
9
+ from osm.tiling import TileManager
10
+ from osm.viz import Colormap, plot_nodes
11
+ from utils.viz_2d import plot_images
12
+ import numpy as np
13
+ from utils.viz_2d import features_to_RGB
14
+ from utils.viz_localization import (
15
+ likelihood_overlay,
16
+ plot_dense_rotations,
17
+ add_circle_inset,
18
+ )
19
+ from osm.viz import GeoPlotter
20
+ import matplotlib.pyplot as plt
21
+ import random
22
+ from geopy.distance import geodesic
23
+
24
+ experiment_or_path = "weight/last-step-checkpointing.ckpt"
25
+ # experiment_or_path="experiments/maplocanet_0906_diffhight/last-step-checkpointing.ckpt"
26
+ image_path = 'images/00000.jpg'
27
+
28
+ # prior_latlon = (37.75704325989902, -122.435941445631)
29
+ # tile_size_meters = 128
30
+ model = Demo(experiment_or_path=experiment_or_path, num_rotations=128, device='cpu')
31
+
32
+ def demo_localize(image,long,lat,tile_size_meters):
33
+ # inp = Image.fromarray(inp.astype('uint8'), 'RGB')
34
+ # inp = transforms.ToTensor()(inp).unsqueeze(0)
35
+ prior_latlon=(lat,long)
36
+ image, camera, gravity, proj, bbox, true_prior_latlon = read_input_image_test(
37
+ image,
38
+ prior_latlon=prior_latlon,
39
+ tile_size_meters=tile_size_meters, # try 64, 256, etc.
40
+ )
41
+ tiler = TileManager.from_bbox(projection=proj, bbox=bbox, ppm=1, tile_size=tile_size_meters)
42
+ # tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10,ppm=1,path=root/city/'{}.osm'.format(city), tile_size=1)
43
+ canvas = tiler.query(bbox)
44
+ uv, yaw, prob, neural_map, image_rectified, data_, pred = model.localize(
45
+ image, camera, canvas)
46
+ prior_latlon_pred = proj.unproject(canvas.to_xy(uv))
47
+
48
+ map_viz = Colormap.apply(canvas.raster)
49
+ map_vis_image_result = map_viz * 255
50
+ map_vis_image_result =show_result(map_vis_image_result.astype(np.uint8), uv, yaw)
51
+ # map_vis_image_result = show_result(map_vis_image_result.astype(np.uint8), True_uv,
52
+ # uv,
53
+ # 90.0 - yaw_T,
54
+ # yaw)
55
+ # return prior_latlon_pred
56
+ uab_feature_rgb = vis_image_feature(pred['features_image'][0].cpu().numpy())
57
+ map_viz = cv2.resize(map_viz, (prob.numpy().shape[0], prob.numpy().shape[1]))
58
+ overlay = likelihood_overlay(prob.numpy().max(-1), map_viz.mean(-1, keepdims=True))
59
+ (neural_map_rgb,) = features_to_RGB(neural_map.numpy())
60
+ fig=plot_images([image, map_vis_image_result / 255, overlay, uab_feature_rgb, neural_map_rgb],
61
+ titles=["UAV image", "map","likelihood","UAV feature","map feature"])
62
+ # plot_images([overlay, neural_map_rgb], titles=["prediction", "neural map"])
63
+ # ax = plt.gcf().axes[2]
64
+ # ax.scatter(*canvas.to_uv(bbox.center), s=5, c="red")
65
+ # plot_dense_rotations(ax, prob, w=0.005, s=1 / 25)
66
+ # add_circle_inset(ax, uv)
67
+
68
+ # Plot as interactive figure
69
+ bbox_latlon = proj.unproject(canvas.bbox)
70
+ plot2 = GeoPlotter(zoom=16.5)
71
+ plot2.raster(map_viz, bbox_latlon, opacity=0.5)
72
+ plot2.raster(likelihood_overlay(prob.numpy().max(-1)), proj.unproject(bbox))
73
+ plot2.points(prior_latlon[:2], "red", name="location prior", size=10)
74
+ plot2.points(proj.unproject(canvas.to_xy(uv)), "black", name="argmax", size=10)
75
+ plot2.bbox(bbox_latlon, "blue", name="map tile")
76
+ # plot2.fig.show()
77
+ return fig,plot2.fig,str(prior_latlon_pred)
78
+ # model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
79
+ # Title
80
+ title = "MapLocNet"
81
+ # Description shown below the title; markdown is supported
82
+ description = "UAV Vision-based Geo-Localization Using Vectorized Maps"
83
+
84
+ # outputs = gr.outputs.Label(num_top_classes=3)
85
+ outputs = gr.Plot()
86
+ interface = gr.Interface(fn=demo_localize,
87
+ inputs=["image",
88
+ gr.Number(label="Prior location (longitude)"),
89
+ gr.Number(label="Prior location (latitude)"),
90
+ gr.Radio([64, 128, 256], label="Search radius (meters)", info="vectorized map size"),
91
+ # gr.inputs.RadioGroup(label="Search radius (meters)",["English", "French", "Spanish"]),
92
+ # gr.Slider(64, 512,label='Search radius (meters)')
93
+ ],
94
+ outputs=["plot","plot","text"],
95
+ title=title,
96
+ description=description,
97
+ examples=[['images/00000.jpg',-122.435941445631,37.75704325989902,128]])
98
+ interface.launch(share=True)
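The Gradio interface above only forwards its four inputs to demo_localize, so the same pipeline can also be driven without the UI. A rough sketch, assuming the checkpoint under weight/ is available and that the image is passed as an RGB array (as Gradio would provide it):

import cv2

uav_image = cv2.cvtColor(cv2.imread("images/00000.jpg"), cv2.COLOR_BGR2RGB)
fig, geo_fig, latlon_pred = demo_localize(
    uav_image,
    long=-122.435941445631,  # same location prior as the bundled example
    lat=37.75704325989902,
    tile_size_meters=128,
)
print(latlon_pred)  # predicted (latitude, longitude), returned as a string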
models/__init__.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
4
+ # https://github.com/cvg/pixloc
5
+ # Released under the Apache License 2.0
6
+
7
+ import inspect
8
+
9
+ from .base import BaseModel
10
+
11
+
12
+ def get_class(mod_name, base_path, BaseClass):
13
+ """Get the class object which inherits from BaseClass and is defined in
14
+ the module named mod_name, child of base_path.
15
+ """
16
+ mod_path = "{}.{}".format(base_path, mod_name)
17
+ mod = __import__(mod_path, fromlist=[""])
18
+ classes = inspect.getmembers(mod, inspect.isclass)
19
+ # Filter classes defined in the module
20
+ classes = [c for c in classes if c[1].__module__ == mod_path]
21
+ # Filter classes inherited from BaseModel
22
+ classes = [c for c in classes if issubclass(c[1], BaseClass)]
23
+ assert len(classes) == 1, classes
24
+ return classes[0][1]
25
+
26
+
27
+ def get_model(name):
28
+ if name == "localizer":
29
+ name = "localizer_basic"
30
+ elif name == "rotation_localizer":
31
+ name = "localizer_basic_rotation"
32
+ elif name == "bev_localizer":
33
+ name = "localizer_bev_plane"
34
+ return get_class(name, __name__, BaseModel)
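get_class/get_model resolve a model by its module name, so adding a new model only requires a new file under models/ that defines a single BaseModel subclass. A usage sketch (the module name refers to models/maplocnet.py added in this commit; its own imports such as models.voting must be importable):

from models import get_model

MapLocNet = get_model("maplocnet")  # imports models.maplocnet and returns its BaseModel subclass
# model = MapLocNet(cfg.model)      # then instantiate it with the OmegaConf model config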
models/base.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
4
+ # https://github.com/cvg/pixloc
5
+ # Released under the Apache License 2.0
6
+
7
+ """
8
+ Base class for trainable models.
9
+ """
10
+
11
+ from abc import ABCMeta, abstractmethod
12
+ from copy import copy
13
+
14
+ import omegaconf
15
+ from omegaconf import OmegaConf
16
+ from torch import nn
17
+
18
+
19
+ class BaseModel(nn.Module, metaclass=ABCMeta):
20
+ """
21
+ What the child model is expected to declare:
22
+ default_conf: dictionary of the default configuration of the model.
23
+ It recursively updates the default_conf of all parent classes, and
24
+ it is updated by the user-provided configuration passed to __init__.
25
+ Configurations can be nested.
26
+
27
+ required_data_keys: list of expected keys in the input data dictionary.
28
+
29
+ strict_conf (optional): boolean. If false, BaseModel does not raise
30
+ an error when the user provides an unknown configuration entry.
31
+
32
+ _init(self, conf): initialization method, where conf is the final
33
+ configuration object (also accessible with `self.conf`). Accessing
34
+ unknown configuration entries will raise an error.
35
+
36
+ _forward(self, data): method that returns a dictionary of batched
37
+ prediction tensors based on a dictionary of batched input data tensors.
38
+
39
+ loss(self, pred, data): method that returns a dictionary of losses,
40
+ computed from model predictions and input data. Each loss is a batch
41
+ of scalars, i.e. a torch.Tensor of shape (B,).
42
+ The total loss to be optimized has the key `'total'`.
43
+
44
+ metrics(self, pred, data): method that returns a dictionary of metrics,
45
+ each as a batch of scalars.
46
+ """
47
+
48
+ base_default_conf = {
49
+ "name": None,
50
+ "trainable": True, # if false: do not optimize this model's parameters
51
+ "freeze_batch_normalization": False, # use test-time statistics
52
+ }
53
+ default_conf = {}
54
+ required_data_keys = []
55
+ strict_conf = True
56
+
57
+ def __init__(self, conf):
58
+ """Perform some logic and call the _init method of the child model."""
59
+ super().__init__()
60
+ default_conf = OmegaConf.merge(
61
+ self.base_default_conf, OmegaConf.create(self.default_conf)
62
+ )
63
+ if self.strict_conf:
64
+ OmegaConf.set_struct(default_conf, True)
65
+
66
+ # fixme: backward compatibility
67
+ if "pad" in conf and "pad" not in default_conf: # backward compat.
68
+ with omegaconf.read_write(conf):
69
+ with omegaconf.open_dict(conf):
70
+ conf["interpolation"] = {"pad": conf.pop("pad")}
71
+
72
+ if isinstance(conf, dict):
73
+ conf = OmegaConf.create(conf)
74
+ self.conf = conf = OmegaConf.merge(default_conf, conf)
75
+ OmegaConf.set_readonly(conf, True)
76
+ OmegaConf.set_struct(conf, True)
77
+ self.required_data_keys = copy(self.required_data_keys)
78
+ self._init(conf)
79
+
80
+ if not conf.trainable:
81
+ for p in self.parameters():
82
+ p.requires_grad = False
83
+
84
+ def train(self, mode=True):
85
+ super().train(mode)
86
+
87
+ def freeze_bn(module):
88
+ if isinstance(module, nn.modules.batchnorm._BatchNorm):
89
+ module.eval()
90
+
91
+ if self.conf.freeze_batch_normalization:
92
+ self.apply(freeze_bn)
93
+
94
+ return self
95
+
96
+ def forward(self, data):
97
+ """Check the data and call the _forward method of the child model."""
98
+
99
+ def recursive_key_check(expected, given):
100
+ for key in expected:
101
+ assert key in given, f"Missing key {key} in data"
102
+ if isinstance(expected, dict):
103
+ recursive_key_check(expected[key], given[key])
104
+
105
+ recursive_key_check(self.required_data_keys, data)
106
+ return self._forward(data)
107
+
108
+ @abstractmethod
109
+ def _init(self, conf):
110
+ """To be implemented by the child class."""
111
+ raise NotImplementedError
112
+
113
+ @abstractmethod
114
+ def _forward(self, data):
115
+ """To be implemented by the child class."""
116
+ raise NotImplementedError
117
+
118
+ def loss(self, pred, data):
119
+ """To be implemented by the child class."""
120
+ raise NotImplementedError
121
+
122
+ def metrics(self):
123
+ return {} # no metrics
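To make the contract described in the docstring concrete, here is a minimal, purely illustrative child model (not part of this commit) that declares a default configuration, the required data keys, and the _init/_forward/loss hooks:

import torch

from models.base import BaseModel


class TinyRegressor(BaseModel):
    """Toy example of the BaseModel interface; not used anywhere in the repository."""

    default_conf = {"hidden_dim": 16}
    required_data_keys = ["image"]

    def _init(self, conf):
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(3, conf.hidden_dim, 3, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(conf.hidden_dim, 1, 1),
        )

    def _forward(self, data):
        return {"score_map": self.net(data["image"])}

    def loss(self, pred, data):
        err = (pred["score_map"] - data["target"]) ** 2
        return {"total": err.flatten(1).mean(-1)}  # one scalar per batch element


model = TinyRegressor({})  # missing entries are filled from default_conf
pred = model({"image": torch.zeros(2, 3, 32, 32)})  # required keys are checked in forward()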
models/feature_extractor.py ADDED
@@ -0,0 +1,231 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
4
+ # https://github.com/cvg/pixloc
5
+ # Released under the Apache License 2.0
6
+
7
+ """
8
+ Flexible UNet model which takes any Torchvision backbone as encoder.
9
+ Predicts multi-level feature and makes sure that they are well aligned.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchvision
15
+
16
+ from .base import BaseModel
17
+ from .utils import checkpointed
18
+
19
+
20
+ class DecoderBlock(nn.Module):
21
+ def __init__(
22
+ self, previous, skip, out, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
23
+ ):
24
+ super().__init__()
25
+
26
+ self.upsample = nn.Upsample(
27
+ scale_factor=2, mode="bilinear", align_corners=False
28
+ )
29
+
30
+ layers = []
31
+ for i in range(num_convs):
32
+ conv = nn.Conv2d(
33
+ previous + skip if i == 0 else out,
34
+ out,
35
+ kernel_size=3,
36
+ padding=1,
37
+ bias=norm is None,
38
+ padding_mode=padding,
39
+ )
40
+ layers.append(conv)
41
+ if norm is not None:
42
+ layers.append(norm(out))
43
+ layers.append(nn.ReLU(inplace=True))
44
+ self.layers = nn.Sequential(*layers)
45
+
46
+ def forward(self, previous, skip):
47
+ upsampled = self.upsample(previous)
48
+ # If the shape of the input map `skip` is not a multiple of 2,
49
+ # it will not match the shape of the upsampled map `upsampled`.
50
+ # If the downsampling uses ceil_mode=False, we need to crop `skip`.
51
+ # If it uses ceil_mode=True (not supported here), we should pad it.
52
+ _, _, hu, wu = upsampled.shape
53
+ _, _, hs, ws = skip.shape
54
+ assert (hu <= hs) and (wu <= ws), "Using ceil_mode=True in pooling?"
55
+ # assert (hu == hs) and (wu == ws), 'Careful about padding'
56
+ skip = skip[:, :, :hu, :wu]
57
+ return self.layers(torch.cat([upsampled, skip], dim=1))
58
+
59
+
60
+ class AdaptationBlock(nn.Sequential):
61
+ def __init__(self, inp, out):
62
+ conv = nn.Conv2d(inp, out, kernel_size=1, padding=0, bias=True)
63
+ super().__init__(conv)
64
+
65
+
66
+ class FeatureExtractor(BaseModel):
67
+ default_conf = {
68
+ "pretrained": True,
69
+ "input_dim": 3,
70
+ "output_scales": [0, 2, 4], # what scales to adapt and output
71
+ "output_dim": 128, # # of channels in output feature maps
72
+ "encoder": "vgg16", # string (torchvision net) or list of channels
73
+ "num_downsample": 4, # how many downsample blocks (if VGG-style net)
74
+ "decoder": [64, 64, 64, 64], # list of channels of decoder
75
+ "decoder_norm": "nn.BatchNorm2d", # normalization in decoder blocks
76
+ "do_average_pooling": False,
77
+ "checkpointed": False, # whether to use gradient checkpointing
78
+ "padding": "zeros",
79
+ }
80
+ mean = [0.485, 0.456, 0.406]
81
+ std = [0.229, 0.224, 0.225]
82
+
83
+ def build_encoder(self, conf):
84
+ assert isinstance(conf.encoder, str)
85
+ if conf.pretrained:
86
+ assert conf.input_dim == 3
87
+ Encoder = getattr(torchvision.models, conf.encoder)
88
+ encoder = Encoder(weights="DEFAULT" if conf.pretrained else None)
89
+ Block = checkpointed(torch.nn.Sequential, do=conf.checkpointed)
90
+ assert max(conf.output_scales) <= conf.num_downsample
91
+
92
+ if conf.encoder.startswith("vgg"):
93
+ # Parse the layers and pack them into downsampling blocks
94
+ # It's easy for VGG-style nets because of their linear structure.
95
+ # This does not handle strided convs and residual connections
96
+ skip_dims = []
97
+ previous_dim = None
98
+ blocks = [[]]
99
+ for i, layer in enumerate(encoder.features):
100
+ if isinstance(layer, torch.nn.Conv2d):
101
+ # Change the first conv layer if the input dim mismatches
102
+ if i == 0 and conf.input_dim != layer.in_channels:
103
+ args = {k: getattr(layer, k) for k in layer.__constants__}
104
+ args.pop("output_padding")
105
+ layer = torch.nn.Conv2d(
106
+ **{**args, "in_channels": conf.input_dim}
107
+ )
108
+ previous_dim = layer.out_channels
109
+ elif isinstance(layer, torch.nn.MaxPool2d):
110
+ assert previous_dim is not None
111
+ skip_dims.append(previous_dim)
112
+ if (conf.num_downsample + 1) == len(blocks):
113
+ break
114
+ blocks.append([]) # start a new block
115
+ if conf.do_average_pooling:
116
+ assert layer.dilation == 1
117
+ layer = torch.nn.AvgPool2d(
118
+ kernel_size=layer.kernel_size,
119
+ stride=layer.stride,
120
+ padding=layer.padding,
121
+ ceil_mode=layer.ceil_mode,
122
+ count_include_pad=False,
123
+ )
124
+ blocks[-1].append(layer)
125
+ encoder = [Block(*b) for b in blocks]
126
+ elif conf.encoder.startswith("resnet"):
127
+ # Manually define the ResNet blocks such that the downsampling comes first
128
+ assert conf.encoder[len("resnet") :] in ["18", "34", "50", "101"]
129
+ assert conf.input_dim == 3, "Unsupported for now."
130
+ block1 = torch.nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
131
+ block2 = torch.nn.Sequential(encoder.maxpool, encoder.layer1)
132
+ block3 = encoder.layer2
133
+ block4 = encoder.layer3
134
+ block5 = encoder.layer4
135
+ blocks = [block1, block2, block3, block4, block5]
136
+ # Extract the output dimension of each block
137
+ skip_dims = [encoder.conv1.out_channels]
138
+ for i in range(1, 5):
139
+ modules = getattr(encoder, f"layer{i}")[-1]._modules
140
+ conv = sorted(k for k in modules if k.startswith("conv"))[-1]
141
+ skip_dims.append(modules[conv].out_channels)
142
+ # Add a dummy block such that the first one does not downsample
143
+ encoder = [torch.nn.Identity()] + [Block(b) for b in blocks]
144
+ skip_dims = [3] + skip_dims
145
+ # Trim based on the requested encoder size
146
+ encoder = encoder[: conf.num_downsample + 1]
147
+ skip_dims = skip_dims[: conf.num_downsample + 1]
148
+ else:
149
+ raise NotImplementedError(conf.encoder)
150
+
151
+ assert (conf.num_downsample + 1) == len(encoder)
152
+ encoder = nn.ModuleList(encoder)
153
+
154
+ return encoder, skip_dims
155
+
156
+ def _init(self, conf):
157
+ # Encoder
158
+ self.encoder, skip_dims = self.build_encoder(conf)
159
+ self.skip_dims = skip_dims
160
+
161
+ def update_padding(module):
162
+ if isinstance(module, nn.Conv2d):
163
+ module.padding_mode = conf.padding
164
+
165
+ if conf.padding != "zeros":
166
+ self.encoder.apply(update_padding)
167
+
168
+ # Decoder
169
+ if conf.decoder is not None:
170
+ assert len(conf.decoder) == (len(skip_dims) - 1)
171
+ Block = checkpointed(DecoderBlock, do=conf.checkpointed)
172
+ norm = eval(conf.decoder_norm) if conf.decoder_norm else None # noqa
173
+
174
+ previous = skip_dims[-1]
175
+ decoder = []
176
+ for out, skip in zip(conf.decoder, skip_dims[:-1][::-1]):
177
+ decoder.append(
178
+ Block(previous, skip, out, norm=norm, padding=conf.padding)
179
+ )
180
+ previous = out
181
+ self.decoder = nn.ModuleList(decoder)
182
+
183
+ # Adaptation layers
184
+ adaptation = []
185
+ for idx, i in enumerate(conf.output_scales):
186
+ if conf.decoder is None or i == (len(self.encoder) - 1):
187
+ input_ = skip_dims[i]
188
+ else:
189
+ input_ = conf.decoder[-1 - i]
190
+
191
+ # out_dim can be an int (same for all scales) or a list (per scale)
192
+ dim = conf.output_dim
193
+ if not isinstance(dim, int):
194
+ dim = dim[idx]
195
+
196
+ block = AdaptationBlock(input_, dim)
197
+ adaptation.append(block)
198
+ self.adaptation = nn.ModuleList(adaptation)
199
+ self.scales = [2**s for s in conf.output_scales]
200
+
201
+ def _forward(self, data):
202
+ image = data["image"]
203
+ if self.conf.pretrained:
204
+ mean, std = image.new_tensor(self.mean), image.new_tensor(self.std)
205
+ image = (image - mean[:, None, None]) / std[:, None, None]
206
+
207
+ skip_features = []
208
+ features = image
209
+ for block in self.encoder:
210
+ features = block(features)
211
+ skip_features.append(features)
212
+
213
+ if self.conf.decoder:
214
+ pre_features = [skip_features[-1]]
215
+ for block, skip in zip(self.decoder, skip_features[:-1][::-1]):
216
+ pre_features.append(block(pre_features[-1], skip))
217
+ pre_features = pre_features[::-1] # fine to coarse
218
+ else:
219
+ pre_features = skip_features
220
+
221
+ out_features = []
222
+ for adapt, i in zip(self.adaptation, self.conf.output_scales):
223
+ out_features.append(adapt(pre_features[i]))
224
+ pred = {"feature_maps": out_features, "skip_features": skip_features}
225
+ return pred
226
+
227
+ def loss(self, pred, data):
228
+ raise NotImplementedError
229
+
230
+ def metrics(self, pred, data):
231
+ raise NotImplementedError
models/feature_extractor_v2.py ADDED
@@ -0,0 +1,192 @@
1
+ import logging
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchvision
7
+ from torchvision.models.feature_extraction import create_feature_extractor
8
+
9
+ from .base import BaseModel
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class DecoderBlock(nn.Module):
15
+ def __init__(
16
+ self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
17
+ ):
18
+ super().__init__()
19
+ layers = []
20
+ for i in range(num_convs):
21
+ conv = nn.Conv2d(
22
+ previous if i == 0 else out,
23
+ out,
24
+ kernel_size=ksize,
25
+ padding=ksize // 2,
26
+ bias=norm is None,
27
+ padding_mode=padding,
28
+ )
29
+ layers.append(conv)
30
+ if norm is not None:
31
+ layers.append(norm(out))
32
+ layers.append(nn.ReLU(inplace=True))
33
+ self.layers = nn.Sequential(*layers)
34
+
35
+ def forward(self, previous, skip):
36
+ _, _, hp, wp = previous.shape
37
+ _, _, hs, ws = skip.shape
38
+ scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
39
+ upsampled = nn.functional.interpolate(
40
+ previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
41
+ )
42
+ # If the shape of the input map `skip` is not a multiple of 2,
43
+ # it will not match the shape of the upsampled map `upsampled`.
44
+ # If the downsampling uses ceil_mode=False, we need to crop `skip`.
45
+ # If it uses ceil_mode=True (not supported here), we should pad it.
46
+ _, _, hu, wu = upsampled.shape
47
+ _, _, hs, ws = skip.shape
48
+ if (hu <= hs) and (wu <= ws):
49
+ skip = skip[:, :, :hu, :wu]
50
+ elif (hu >= hs) and (wu >= ws):
51
+ skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
52
+ else:
53
+ raise ValueError(
54
+ f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
55
+ )
56
+
57
+ return self.layers(skip) + upsampled
58
+
59
+
60
+ class FPN(nn.Module):
61
+ def __init__(self, in_channels_list, out_channels, **kw):
62
+ super().__init__()
63
+ self.first = nn.Conv2d(
64
+ in_channels_list[-1], out_channels, 1, padding=0, bias=True
65
+ )
66
+ self.blocks = nn.ModuleList(
67
+ [
68
+ DecoderBlock(c, out_channels, ksize=1, **kw)
69
+ for c in in_channels_list[::-1][1:]
70
+ ]
71
+ )
72
+ self.out = nn.Sequential(
73
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
74
+ nn.BatchNorm2d(out_channels),
75
+ nn.ReLU(inplace=True),
76
+ )
77
+
78
+ def forward(self, layers):
79
+ feats = None
80
+ for idx, x in enumerate(reversed(layers.values())):
81
+ if feats is None:
82
+ feats = self.first(x)
83
+ else:
84
+ feats = self.blocks[idx - 1](feats, x)
85
+ out = self.out(feats)
86
+ return out
87
+
88
+
89
+ def remove_conv_stride(conv):
90
+ conv_new = nn.Conv2d(
91
+ conv.in_channels,
92
+ conv.out_channels,
93
+ conv.kernel_size,
94
+ bias=conv.bias is not None,
95
+ stride=1,
96
+ padding=conv.padding,
97
+ )
98
+ conv_new.weight = conv.weight
99
+ conv_new.bias = conv.bias
100
+ return conv_new
101
+
102
+
103
+ class FeatureExtractor(BaseModel):
104
+ default_conf = {
105
+ "pretrained": True,
106
+ "input_dim": 3,
107
+ "output_dim": 128, # # of channels in output feature maps
108
+ "encoder": "resnet50", # torchvision net as string
109
+ "remove_stride_from_first_conv": False,
110
+ "num_downsample": None, # how many downsample blocks
111
+ "decoder_norm": "nn.BatchNorm2d", # normalization in decoder blocks
112
+ "do_average_pooling": False,
113
+ "checkpointed": False, # whether to use gradient checkpointing
114
+ }
115
+ mean = [0.485, 0.456, 0.406]
116
+ std = [0.229, 0.224, 0.225]
117
+
118
+ def build_encoder(self, conf):
119
+ assert isinstance(conf.encoder, str)
120
+ if conf.pretrained:
121
+ assert conf.input_dim == 3
122
+ Encoder = getattr(torchvision.models, conf.encoder)
123
+
124
+ kw = {}
125
+ if conf.encoder.startswith("resnet"):
126
+ layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
127
+ kw["replace_stride_with_dilation"] = [False, False, False]
128
+ elif conf.encoder == "vgg13":
129
+ layers = [
130
+ "features.3",
131
+ "features.8",
132
+ "features.13",
133
+ "features.18",
134
+ "features.23",
135
+ ]
136
+ elif conf.encoder == "vgg16":
137
+ layers = [
138
+ "features.3",
139
+ "features.8",
140
+ "features.15",
141
+ "features.22",
142
+ "features.29",
143
+ ]
144
+ else:
145
+ raise NotImplementedError(conf.encoder)
146
+
147
+ if conf.num_downsample is not None:
148
+ layers = layers[: conf.num_downsample]
149
+ encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
150
+ encoder = create_feature_extractor(encoder, return_nodes=layers)
151
+ if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
152
+ encoder.conv1 = remove_conv_stride(encoder.conv1)
153
+
154
+ if conf.do_average_pooling:
155
+ raise NotImplementedError
156
+ if conf.checkpointed:
157
+ raise NotImplementedError
158
+
159
+ return encoder, layers
160
+
161
+ def _init(self, conf):
162
+ # Preprocessing
163
+ self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
164
+ self.register_buffer("std_", torch.tensor(self.std), persistent=False)
165
+
166
+ # Encoder
167
+ self.encoder, self.layers = self.build_encoder(conf)
168
+ s = 128
169
+ inp = torch.zeros(1, 3, s, s)
170
+ features = list(self.encoder(inp).values())
171
+ self.skip_dims = [x.shape[1] for x in features]
172
+ self.layer_strides = [s / f.shape[-1] for f in features]
173
+ self.scales = [self.layer_strides[0]]
174
+
175
+ # Decoder
176
+ norm = eval(conf.decoder_norm) if conf.decoder_norm else None # noqa
177
+ self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)
178
+
179
+ logger.debug(
180
+ "Built feature extractor with layers {name:dim:stride}:\n"
181
+ f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
182
+ f"and output scales {self.scales}."
183
+ )
184
+
185
+ def _forward(self, data):
186
+ image = data["image"]
187
+ image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
188
+
189
+ skip_features = self.encoder(image)
190
+ output = self.decoder(skip_features)
191
+ pred = {"feature_maps": [output], "skip_features": skip_features}
192
+ return pred
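A quick shape check of the FPN-style extractor above, mirroring the image_encoder.backbone settings from conf/maplocnet.yaml (pretrained weights are disabled here so the snippet runs without downloads):

import torch

from models.feature_extractor_v2 import FeatureExtractor

extractor = FeatureExtractor({"encoder": "resnet50", "pretrained": False, "output_dim": 8})
pred = extractor({"image": torch.zeros(1, 3, 256, 256)})
print(pred["feature_maps"][0].shape)  # stride-2 output: torch.Size([1, 8, 128, 128])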
models/map_encoder.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .base import BaseModel
7
+ from .feature_extractor import FeatureExtractor
8
+
9
+
10
+ class MapEncoder(BaseModel):
11
+ default_conf = {
12
+ "embedding_dim": "???",
13
+ "output_dim": None,
14
+ "num_classes": "???",
15
+ "backbone": "???",
16
+ "unary_prior": False,
17
+ }
18
+
19
+ def _init(self, conf):
20
+ self.embeddings = torch.nn.ModuleDict(
21
+ {
22
+ k: torch.nn.Embedding(n + 1, conf.embedding_dim)
23
+ for k, n in conf.num_classes.items()
24
+ }
25
+ )
26
+ # num_classes: {'areas': 7, 'ways': 10, 'nodes': 33}
27
+ input_dim = len(conf.num_classes) * conf.embedding_dim
28
+ output_dim = conf.output_dim
29
+ if output_dim is None:
30
+ output_dim = conf.backbone.output_dim
31
+ if conf.unary_prior:
32
+ output_dim += 1
33
+ if conf.backbone is None:
34
+ self.encoder = nn.Conv2d(input_dim, output_dim, 1)
35
+ elif conf.backbone == "simple":
36
+ self.encoder = nn.Sequential(
37
+ nn.Conv2d(input_dim, 128, 3, padding=1),
38
+ nn.ReLU(inplace=True),
39
+ nn.Conv2d(128, 128, 3, padding=1),
40
+ nn.ReLU(inplace=True),
41
+ nn.Conv2d(128, output_dim, 3, padding=1),
42
+ )
43
+ else:
44
+ self.encoder = FeatureExtractor(
45
+ {
46
+ **conf.backbone,
47
+ "input_dim": input_dim,
48
+ "output_dim": output_dim,
49
+ }
50
+ )
51
+
52
+ def _forward(self, data):
53
+ embeddings = [
54
+ self.embeddings[k](data["map"][:, i])
55
+ for i, k in enumerate(("areas", "ways", "nodes"))
56
+ ]
57
+ embeddings = torch.cat(embeddings, dim=-1).permute(0, 3, 1, 2)
58
+ if isinstance(self.encoder, BaseModel):
59
+ features = self.encoder({"image": embeddings})["feature_maps"]
60
+ else:
61
+ features = [self.encoder(embeddings)]
62
+ pred = {}
63
+ if self.conf.unary_prior:
64
+ pred["log_prior"] = [f[:, -1] for f in features]
65
+ features = [f[:, :-1] for f in features]
66
+ pred["map_features"] = features
67
+ return pred
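For reference, the encoder above expects an integer raster with one channel per group (areas, ways, nodes); each channel is embedded separately and the embeddings are concatenated before the CNN. A small sketch using the class counts from conf/maplocnet.yaml and the lightweight "simple" backbone branch defined above:

import torch

from models.map_encoder import MapEncoder

encoder = MapEncoder(
    {
        "embedding_dim": 16,
        "output_dim": 8,
        "num_classes": {"areas": 7, "ways": 10, "nodes": 33},
        "backbone": "simple",
    }
)
raster = torch.randint(0, 7, (1, 3, 64, 64))  # [B, 3, H, W] class indices for areas/ways/nodes
pred = encoder({"map": raster})
print(pred["map_features"][0].shape)  # torch.Size([1, 8, 64, 64])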
models/maplocnet.py ADDED
@@ -0,0 +1,204 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn.functional import normalize
6
+
7
+ from . import get_model
8
+ from models.base import BaseModel
9
+ # from models.bev_net import BEVNet
10
+ # from models.bev_projection import CartesianProjection, PolarProjectionDepth
11
+ from models.voting import (
12
+ argmax_xyr,
13
+ conv2d_fft_batchwise,
14
+ expectation_xyr,
15
+ log_softmax_spatial,
16
+ mask_yaw_prior,
17
+ nll_loss_xyr,
18
+ nll_loss_xyr_smoothed,
19
+ TemplateSampler,
20
+ UAVTemplateSampler,
21
+ UAVTemplateSamplerFast
22
+ )
23
+ from .map_encoder import MapEncoder
24
+ from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall
25
+
26
+
27
+ class MapLocNet(BaseModel):
28
+ default_conf = {
29
+ "image_size": "???",
30
+ "val_citys":"???",
31
+ "image_encoder": "???",
32
+ "map_encoder": "???",
33
+ "bev_net": "???",
34
+ "latent_dim": "???",
35
+ "matching_dim": "???",
36
+ "scale_range": [0, 9],
37
+ "num_scale_bins": "???",
38
+ "z_min": None,
39
+ "z_max": "???",
40
+ "x_max": "???",
41
+ "pixel_per_meter": "???",
42
+ "num_rotations": "???",
43
+ "add_temperature": False,
44
+ "normalize_features": False,
45
+ "padding_matching": "replicate",
46
+ "apply_map_prior": True,
47
+ "do_label_smoothing": False,
48
+ "sigma_xy": 1,
49
+ "sigma_r": 2,
50
+ # deprecated
51
+ "depth_parameterization": "scale",
52
+ "norm_depth_scores": False,
53
+ "normalize_scores_by_dim": False,
54
+ "normalize_scores_by_num_valid": True,
55
+ "prior_renorm": True,
56
+ "retrieval_dim": None,
57
+ }
58
+
59
+ def _init(self, conf):
60
+ assert not self.conf.norm_depth_scores
61
+ assert self.conf.depth_parameterization == "scale"
62
+ assert not self.conf.normalize_scores_by_dim
63
+ assert self.conf.normalize_scores_by_num_valid
64
+ assert self.conf.prior_renorm
65
+
66
+ Encoder = get_model(conf.image_encoder.get("name", "feature_extractor_v2"))
67
+ self.image_encoder = Encoder(conf.image_encoder.backbone)
68
+ self.map_encoder = MapEncoder(conf.map_encoder)
69
+ # self.bev_net = None if conf.bev_net is None else BEVNet(conf.bev_net)
70
+
71
+ ppm = conf.pixel_per_meter
72
+ # self.projection_polar = PolarProjectionDepth(
73
+ # conf.z_max,
74
+ # ppm,
75
+ # conf.scale_range,
76
+ # conf.z_min,
77
+ # )
78
+ # self.projection_bev = CartesianProjection(
79
+ # conf.z_max, conf.x_max, ppm, conf.z_min
80
+ # )
81
+ # self.template_sampler = TemplateSampler(
82
+ # self.projection_bev.grid_xz, ppm, conf.num_rotations
83
+ # )
84
+ # self.template_sampler = UAVTemplateSamplerFast(conf.num_rotations,w=conf.image_size//2)
85
+ self.template_sampler = UAVTemplateSampler(conf.num_rotations)
86
+ # self.scale_classifier = torch.nn.Linear(conf.latent_dim, conf.num_scale_bins)
87
+ # if conf.bev_net is None:
88
+ # self.feature_projection = torch.nn.Linear(
89
+ # conf.latent_dim, conf.matching_dim
90
+ # )
91
+ if conf.add_temperature:
92
+ temperature = torch.nn.Parameter(torch.tensor(0.0))
93
+ self.register_parameter("temperature", temperature)
94
+
95
+ def exhaustive_voting(self, f_bev, f_map):
96
+ if self.conf.normalize_features:
97
+ f_bev = normalize(f_bev, dim=1)
98
+ f_map = normalize(f_map, dim=1)
99
+
100
+ # Build the templates and exhaustively match against the map.
101
+ # if confidence_bev is not None:
102
+ # f_bev = f_bev * confidence_bev.unsqueeze(1)
103
+ # f_bev = f_bev.masked_fill(~valid_bev.unsqueeze(1), 0.0)
104
+ # torch.save(f_bev, 'f_bev.pt')
105
+ # torch.save(f_map, 'f_map.pt')
106
+
107
+ templates = self.template_sampler(f_bev)  # (B, num_rotations, C, H, W)
108
+ # torch.save(templates, 'templates.pt')
109
+ with torch.autocast("cuda", enabled=False):
110
+ scores = conv2d_fft_batchwise(
111
+ f_map.float(),
112
+ templates.float(),
113
+ padding_mode=self.conf.padding_matching,
114
+ )
115
+ if self.conf.add_temperature:
116
+ scores = scores * torch.exp(self.temperature)
117
+
118
+ # Reweight the different rotations based on the number of valid pixels
119
+ # in each template. Axis-aligned rotation have the maximum number of valid pixels.
120
+ # valid_templates = self.template_sampler(valid_bev.float()[None]) > (1 - 1e-4)
121
+ # num_valid = valid_templates.float().sum((-3, -2, -1))
122
+ # scores = scores / num_valid[..., None, None]
123
+ return scores
124
+
125
+ def _forward(self, data):
126
+ pred = {}
127
+ pred_map = pred["map"] = self.map_encoder(data)
128
+ f_map = pred_map["map_features"][0]#[batch,8,256,256]
129
+
130
+ # Extract image features.
131
+ level = 0
132
+ f_image = self.image_encoder(data)["feature_maps"][level]#[batch,128,128,176]
133
+ # print("f_map:",f_map.shape)
134
+
135
+ scores = self.exhaustive_voting(f_image, f_map)  # f_image: (B, C, H, W), f_map: (B, C, H_map, W_map)
136
+ scores = scores.moveaxis(1, -1) # B,H,W,N
137
+ if "log_prior" in pred_map and self.conf.apply_map_prior:
138
+ scores = scores + pred_map["log_prior"][0].unsqueeze(-1)
139
+ # pred["scores_unmasked"] = scores.clone()
140
+ if "map_mask" in data:
141
+ scores.masked_fill_(~data["map_mask"][..., None], -np.inf)
142
+ if "yaw_prior" in data:
143
+ mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)
144
+ log_probs = log_softmax_spatial(scores)
145
+ # torch.save(scores, 'scores.pt')
146
+ with torch.no_grad():
147
+ uvr_max = argmax_xyr(scores).to(scores)
148
+ uvr_avg, _ = expectation_xyr(log_probs.exp())
149
+
150
+ return {
151
+ **pred,
152
+ "scores": scores,
153
+ "log_probs": log_probs,
154
+ "uvr_max": uvr_max,
155
+ "uv_max": uvr_max[..., :2],
156
+ "yaw_max": uvr_max[..., 2],
157
+ "uvr_expectation": uvr_avg,
158
+ "uv_expectation": uvr_avg[..., :2],
159
+ "yaw_expectation": uvr_avg[..., 2],
160
+ "features_image": f_image,
161
+ }
162
+
163
+ def loss(self, pred, data):
164
+ xy_gt = data["uv"]
165
+ yaw_gt = data["roll_pitch_yaw"][..., -1]
166
+ if self.conf.do_label_smoothing:
167
+ nll = nll_loss_xyr_smoothed(
168
+ pred["log_probs"],
169
+ xy_gt,
170
+ yaw_gt,
171
+ self.conf.sigma_xy / self.conf.pixel_per_meter,
172
+ self.conf.sigma_r,
173
+ mask=data.get("map_mask"),
174
+ )
175
+ else:
176
+ nll = nll_loss_xyr(pred["log_probs"], xy_gt, yaw_gt)
177
+ loss = {"total": nll, "nll": nll}
178
+ if self.training and self.conf.add_temperature:
179
+ loss["temperature"] = self.temperature.expand(len(nll))
180
+ return loss
181
+
182
+ def metrics(self):
183
+ return {
184
+ "xy_max_error": Location2DError("uv_max", self.conf.pixel_per_meter),
185
+ "xy_expectation_error": Location2DError(
186
+ "uv_expectation", self.conf.pixel_per_meter
187
+ ),
188
+ "yaw_max_error": AngleError("yaw_max"),
189
+ "xy_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
190
+ "xy_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
191
+ "xy_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
192
+
193
+ # "x_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
194
+ # "x_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
195
+ # "x_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
196
+ #
197
+ # "y_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
198
+ # "y_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
199
+ # "y_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
200
+
201
+ "yaw_recall_1°": AngleRecall(1.0, "yaw_max"),
202
+ "yaw_recall_3°": AngleRecall(3.0, "yaw_max"),
203
+ "yaw_recall_5°": AngleRecall(5.0, "yaw_max"),
204
+ }
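As a toy illustration of the read-out at the end of `_forward` (plain PyTorch mirroring what `log_softmax_spatial` and `argmax_xyr` do; shapes are assumed): the (B, H, W, N) score volume is normalized jointly over position and rotation, and the flat argmax is decoded back into (u, v, yaw).

import torch

B, H, W, N = 1, 8, 8, 64
scores = torch.randn(B, H, W, N)
log_probs = torch.log_softmax(scores.flatten(-3), dim=-1).reshape(scores.shape)
idx = scores.flatten(-3).argmax(-1)    # flat index over H * W * N
y = idx // (W * N)
x = (idx % (W * N)) // N
r = idx % N
yaw = r * 360.0 / N                    # rotation bin -> degrees
print(x.item(), y.item(), yaw.item())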
models/metrics.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import torch
4
+ import torchmetrics
5
+ from torchmetrics.utilities.data import dim_zero_cat
6
+
7
+ from .utils import deg2rad, rotmat2d
8
+
9
+
10
+ def location_error(uv, uv_gt, ppm=1):
11
+ return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm
12
+
13
+ def location_error_single(uv, uv_gt, ppm=1):
14
+ return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm
15
+
16
+ def angle_error(t, t_gt):
17
+ error = torch.abs(t % 360 - t_gt.to(t) % 360)
18
+ error = torch.minimum(error, 360 - error)
19
+ return error
20
+
21
+
22
+ class Location2DRecall(torchmetrics.MeanMetric):
23
+ def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs):
24
+ self.threshold = threshold
25
+ self.ppm = pixel_per_meter
26
+ self.key = key
27
+ super().__init__(*args, **kwargs)
28
+
29
+ def update(self, pred, data):
30
+ self.cuda()
31
+ error = location_error(pred[self.key], data["uv"], self.ppm)
32
+ # print(error,self.threshold)
33
+ super().update((error <= torch.tensor(self.threshold,device=error.device)).float())
34
+
35
+ class Location1DRecall(torchmetrics.MeanMetric):
36
+ def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs):
37
+ self.threshold = threshold
38
+ self.ppm = pixel_per_meter
39
+ self.key = key
40
+ super().__init__(*args, **kwargs)
41
+
42
+ def update(self, pred, data):
43
+ self.cuda()
44
+ error = location_error(pred[self.key], data["uv"], self.ppm)
45
+ # print(error,self.threshold)
46
+ super().update((error <= torch.tensor(self.threshold,device=error.device)).float())
47
+ class AngleRecall(torchmetrics.MeanMetric):
48
+ def __init__(self, threshold, key="yaw_max", *args, **kwargs):
49
+ self.threshold = threshold
50
+ self.key = key
51
+
52
+ super().__init__(*args, **kwargs)
53
+
54
+ def update(self, pred, data):
55
+ self.cuda()
56
+ error = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1])
57
+ super().update((error <= self.threshold).float())
58
+
59
+
60
+ class MeanMetricWithRecall(torchmetrics.Metric):
61
+ full_state_update = True
62
+
63
+ def __init__(self):
64
+ super().__init__()
65
+ self.add_state("value", default=[], dist_reduce_fx="cat")
66
+ def compute(self):
67
+ return dim_zero_cat(self.value).mean(0)
68
+
69
+ def get_errors(self):
70
+ return dim_zero_cat(self.value)
71
+
72
+ def recall(self, thresholds):
73
+ self.cuda()
74
+ error = self.get_errors()
75
+ thresholds = error.new_tensor(thresholds)
76
+ return (error.unsqueeze(-1) < thresholds).float().mean(0) * 100
77
+
78
+
79
+ class AngleError(MeanMetricWithRecall):
80
+ def __init__(self, key):
81
+ super().__init__()
82
+ self.key = key
83
+
84
+ def update(self, pred, data):
85
+ self.cuda()
86
+ value = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1])
87
+ if value.numel():
88
+ self.value.append(value)
89
+
90
+
91
+ class Location2DError(MeanMetricWithRecall):
92
+ def __init__(self, key, pixel_per_meter):
93
+ super().__init__()
94
+ self.key = key
95
+ self.ppm = pixel_per_meter
96
+
97
+ def update(self, pred, data):
98
+ self.cuda()
99
+ value = location_error(pred[self.key], data["uv"], self.ppm)
100
+ if value.numel():
101
+ self.value.append(value)
102
+
103
+
104
+ class LateralLongitudinalError(MeanMetricWithRecall):
105
+ def __init__(self, pixel_per_meter, key="uv_max"):
106
+ super().__init__()
107
+ self.ppm = pixel_per_meter
108
+ self.key = key
109
+
110
+ def update(self, pred, data):
111
+ self.cuda()
112
+ yaw = deg2rad(data["roll_pitch_yaw"][..., -1])
113
+ shift = (pred[self.key] - data["uv"]) * yaw.new_tensor([-1, 1])
114
+ shift = (rotmat2d(yaw) @ shift.unsqueeze(-1)).squeeze(-1)
115
+ error = torch.abs(shift) / self.ppm
116
+ value = error.view(-1, 2)
117
+ if value.numel():
118
+ self.value.append(value)
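A quick sanity check of the two error primitives defined above (assuming the repo root is on PYTHONPATH):

import torch
from models.metrics import angle_error, location_error

# angles wrap around 360 degrees, so 350 vs 10 is a 20 degree error
print(angle_error(torch.tensor([350.0]), torch.tensor([10.0])))  # tensor([20.])
# a (3, 4) pixel offset at 1 pixel/meter is a 5 m error
print(location_error(torch.tensor([[3.0, 4.0]]), torch.tensor([[0.0, 0.0]]), ppm=1))  # tensor([5.])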
models/utils.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+
8
+
9
+ def checkpointed(cls, do=True):
10
+ """Adapted from the DISK implementation of Michał Tyszkiewicz."""
11
+ assert issubclass(cls, torch.nn.Module)
12
+
13
+ class Checkpointed(cls):
14
+ def forward(self, *args, **kwargs):
15
+ super_fwd = super(Checkpointed, self).forward
16
+ if any((torch.is_tensor(a) and a.requires_grad) for a in args):
17
+ return torch.utils.checkpoint.checkpoint(super_fwd, *args, **kwargs)
18
+ else:
19
+ return super_fwd(*args, **kwargs)
20
+
21
+ return Checkpointed if do else cls
22
+
23
+
24
+ class GlobalPooling(torch.nn.Module):
25
+ def __init__(self, kind):
26
+ super().__init__()
27
+ if kind == "mean":
28
+ self.fn = torch.nn.Sequential(
29
+ torch.nn.Flatten(2), torch.nn.AdaptiveAvgPool1d(1), torch.nn.Flatten()
30
+ )
31
+ elif kind == "max":
32
+ self.fn = torch.nn.Sequential(
33
+ torch.nn.Flatten(2), torch.nn.AdaptiveMaxPool1d(1), torch.nn.Flatten()
34
+ )
35
+ else:
36
+ raise ValueError(f"Unknown pooling type {kind}.")
37
+
38
+ def forward(self, x):
39
+ return self.fn(x)
40
+
41
+
42
+ @torch.jit.script
43
+ def make_grid(
44
+ w: float,
45
+ h: float,
46
+ step_x: float = 1.0,
47
+ step_y: float = 1.0,
48
+ orig_x: float = 0,
49
+ orig_y: float = 0,
50
+ y_up: bool = False,
51
+ device: Optional[torch.device] = None,
52
+ ) -> torch.Tensor:
53
+ x, y = torch.meshgrid(
54
+ [
55
+ torch.arange(orig_x, w + orig_x, step_x, device=device),
56
+ torch.arange(orig_y, h + orig_y, step_y, device=device),
57
+ ],
58
+ indexing="xy",
59
+ )
60
+ if y_up:
61
+ y = y.flip(-2)
62
+ grid = torch.stack((x, y), -1)
63
+ return grid
64
+
65
+
66
+ @torch.jit.script
67
+ def rotmat2d(angle: torch.Tensor) -> torch.Tensor:
68
+ c = torch.cos(angle)
69
+ s = torch.sin(angle)
70
+ R = torch.stack([c, -s, s, c], -1).reshape(angle.shape + (2, 2))
71
+ return R
72
+
73
+
74
+ @torch.jit.script
75
+ def rotmat2d_grad(angle: torch.Tensor) -> torch.Tensor:
76
+ c = torch.cos(angle)
77
+ s = torch.sin(angle)
78
+ R = torch.stack([-s, -c, c, -s], -1).reshape(angle.shape + (2, 2))
79
+ return R
80
+
81
+
82
+ def deg2rad(x):
83
+ return x * math.pi / 180
84
+
85
+
86
+ def rad2deg(x):
87
+ return x * 180 / math.pi
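A small usage sketch of the geometry helpers above (assuming the repo root is on PYTHONPATH): make_grid builds an (H, W, 2) grid of (x, y) coordinates and rotmat2d returns a 2x2 rotation matrix.

import torch
from models.utils import make_grid, rotmat2d, deg2rad

grid = make_grid(4.0, 3.0)
print(grid.shape)                          # torch.Size([3, 4, 2])
R = rotmat2d(deg2rad(torch.tensor(90.0)))
print(R @ torch.tensor([1.0, 0.0]))        # ~tensor([0., 1.]): (1, 0) rotated by 90 degrees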
models/voting.py ADDED
@@ -0,0 +1,365 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.fft import irfftn, rfftn
8
+ from torch.nn.functional import grid_sample, log_softmax, pad
9
+
10
+ from .metrics import angle_error
11
+ from .utils import make_grid, rotmat2d
12
+ from torchvision.transforms.functional import rotate
13
+
14
+ class UAVTemplateSamplerFast(torch.nn.Module):
15
+ def __init__(self, num_rotations,w=128,optimize=True):
16
+ super().__init__()
17
+
18
+ h, w = w,w
19
+ grid_xy = make_grid(
20
+ w=w,
21
+ h=h,
22
+ step_x=1,
23
+ step_y=1,
24
+ orig_y=-h//2,
25
+ orig_x=-h//2,
26
+ y_up=True,
27
+ ).cuda()
28
+
29
+ if optimize:
30
+ assert (num_rotations % 4) == 0
31
+ angles = torch.arange(
32
+ 0, 90, 90 / (num_rotations // 4)
33
+ ).cuda()
34
+ else:
35
+ angles = torch.arange(
36
+ 0, 360, 360 / num_rotations, device=grid_xy.device  # grid_xz_bev is not defined in this class
37
+ )
38
+ rotmats = rotmat2d(angles / 180 * np.pi)
39
+ grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy)
40
+
41
+ grid_ij_rot = (grid_xy_rot - grid_xy[..., :1, :1, :]) * grid_xy.new_tensor(
42
+ [1, -1]
43
+ )
44
+ grid_ij_rot = grid_ij_rot
45
+ grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1
46
+
47
+ self.optimize = optimize
48
+ self.num_rots = num_rotations
49
+ self.register_buffer("angles", angles, persistent=False)
50
+ self.register_buffer("grid_norm", grid_norm, persistent=False)
51
+
52
+ def forward(self, image_bev):
53
+ grid = self.grid_norm
54
+ b, c = image_bev.shape[:2]
55
+ n, h, w = grid.shape[:3]
56
+ grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2)
57
+ image = (
58
+ image_bev[:, None]
59
+ .repeat_interleave(n, 1)
60
+ .reshape(b * n, *image_bev.shape[1:])
61
+ )
62
+ # print(image.shape,grid.shape,self.grid_norm.shape)
63
+ kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape(
64
+ b, n, c, h, w
65
+ )
66
+
67
+ if self.optimize: # we have computed only the first quadrant
68
+ kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)]
69
+ kernels = torch.cat([kernels] + kernels_quad234, 1)
70
+
71
+ return kernels
72
+ class UAVTemplateSampler(torch.nn.Module):
73
+ def __init__(self, num_rotations):
74
+ super().__init__()
75
+
76
+ self.num_rotations = num_rotations
77
+
78
+ def Template(self, input_features):
79
+ # number of rotation angles
80
+ num_angles = self.num_rotations
81
+ # add a new second dimension with one slot per rotation angle
82
+ input_shape = torch.tensor(input_features.shape)
83
+ output_shape = torch.cat((input_shape[:1], torch.tensor([num_angles]), input_shape[1:])).tolist()
84
+ expanded_features = torch.zeros(output_shape,device=input_features.device)
85
+
86
+ # generate the sequence of rotation angles
87
+ rotation_angles = torch.linspace(360, 0, num_angles + 1)[:-1]  # one angle per rotation bin (was hard-coded to 64)
88
+ # rotation_angles=torch.flip(rotation_angles, dims=[0])
89
+ # apply a different rotation angle to each copy of the features
90
+ rotated_features = []
91
+ # print(len(rotation_angles))
92
+ for i in range(len(rotation_angles)):
93
+ # print(rotation_angles[i].item())
94
+ rotated_feature = rotate(input_features, rotation_angles[i].item(), fill=0)
95
+ expanded_features[:, i, :, :, :] = rotated_feature
96
+
97
+ # the rotated copies together form the final output volume
98
+ # output_features = torch.stack(rotated_features, dim=1)
99
+
100
+ # expected output shape
101
+ # output_size = [3, num_angles, 8, 128, 128]
102
+ return expanded_features  # (B, num_angles, C, H, W)
103
+ def forward(self, image_bev):
104
+
105
+ kernels=self.Template(image_bev)
106
+
107
+ return kernels
108
+ class TemplateSampler(torch.nn.Module):
109
+ def __init__(self, grid_xz_bev, ppm, num_rotations, optimize=True):
110
+ super().__init__()
111
+
112
+ Δ = 1 / ppm
113
+ h, w = grid_xz_bev.shape[:2]
114
+ ksize = max(w, h * 2 + 1)
115
+ radius = ksize * Δ
116
+ grid_xy = make_grid(
117
+ radius,
118
+ radius,
119
+ step_x=Δ,
120
+ step_y=Δ,
121
+ orig_y=(Δ - radius) / 2,
122
+ orig_x=(Δ - radius) / 2,
123
+ y_up=True,
124
+ )
125
+
126
+ if optimize:
127
+ assert (num_rotations % 4) == 0
128
+ angles = torch.arange(
129
+ 0, 90, 90 / (num_rotations // 4), device=grid_xz_bev.device
130
+ )
131
+ else:
132
+ angles = torch.arange(
133
+ 0, 360, 360 / num_rotations, device=grid_xz_bev.device
134
+ )
135
+ rotmats = rotmat2d(angles / 180 * np.pi)
136
+ grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy)
137
+
138
+ grid_ij_rot = (grid_xy_rot - grid_xz_bev[..., :1, :1, :]) * grid_xy.new_tensor(
139
+ [1, -1]
140
+ )
141
+ grid_ij_rot = grid_ij_rot / Δ
142
+ grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1
143
+
144
+ self.optimize = optimize
145
+ self.num_rots = num_rotations
146
+ self.register_buffer("angles", angles, persistent=False)
147
+ self.register_buffer("grid_norm", grid_norm, persistent=False)
148
+
149
+ def forward(self, image_bev):
150
+ grid = self.grid_norm
151
+ b, c = image_bev.shape[:2]
152
+ n, h, w = grid.shape[:3]
153
+ grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2)
154
+ image = (
155
+ image_bev[:, None]
156
+ .repeat_interleave(n, 1)
157
+ .reshape(b * n, *image_bev.shape[1:])
158
+ )
159
+ kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape(
160
+ b, n, c, h, w
161
+ )
162
+
163
+ if self.optimize: # we have computed only the first quadrant
164
+ kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)]
165
+ kernels = torch.cat([kernels] + kernels_quad234, 1)
166
+
167
+ return kernels
168
+
169
+
170
+ def conv2d_fft_batchwise(signal, kernel, padding="same", padding_mode="constant"):
171
+ if padding == "same":
172
+ padding = [i // 2 for i in kernel.shape[-2:]]
173
+ padding_signal = [p for p in padding[::-1] for _ in range(2)]
174
+ signal = pad(signal, padding_signal, mode=padding_mode)
175
+ assert signal.size(-1) % 2 == 0
176
+
177
+ padding_kernel = [
178
+ pad for i in [1, 2] for pad in [0, signal.size(-i) - kernel.size(-i)]
179
+ ]
180
+ kernel_padded = pad(kernel, padding_kernel)
181
+
182
+ signal_fr = rfftn(signal, dim=(-1, -2))
183
+ kernel_fr = rfftn(kernel_padded, dim=(-1, -2))
184
+
185
+ kernel_fr.imag *= -1 # flip the kernel
186
+ output_fr = torch.einsum("bc...,bdc...->bd...", signal_fr, kernel_fr)
187
+ output = irfftn(output_fr, dim=(-1, -2))
188
+
189
+ crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [
190
+ slice(0, (signal.size(i) - kernel.size(i) + 1)) for i in [-2, -1]
191
+ ]
192
+ output = output[crop_slices].contiguous()
193
+
194
+ return output
195
+
196
+
197
+ class SparseMapSampler(torch.nn.Module):
198
+ def __init__(self, num_rotations):
199
+ super().__init__()
200
+ angles = torch.arange(0, 360, 360 / num_rotations)  # use the constructor argument; this module has no self.conf
201
+ rotmats = rotmat2d(angles / 180 * np.pi)
202
+ self.num_rotations = num_rotations
203
+ self.register_buffer("rotmats", rotmats, persistent=False)
204
+
205
+ def forward(self, image_map, p2d_bev):
206
+ h, w = image_map.shape[-2:]
207
+ locations = make_grid(w, h, device=p2d_bev.device)
208
+ p2d_candidates = torch.einsum(
209
+ "kji,...i,->...kj", self.rotmats.to(p2d_bev), p2d_bev
210
+ )
211
+ p2d_candidates = p2d_candidates[..., None, None, :, :] + locations.unsqueeze(-1)
212
+ # ... x N x W x H x K x 2
213
+
214
+ p2d_norm = (p2d_candidates / (image_map.new_tensor([w, h]) - 1)) * 2 - 1
215
+ valid = torch.all((p2d_norm >= -1) & (p2d_norm <= 1), -1)
216
+ value = grid_sample(
217
+ image_map, p2d_norm.flatten(-4, -2), align_corners=True, mode="bilinear"
218
+ )
219
+ value = value.reshape(image_map.shape[:2] + valid.shape[-4:])
220
+ return valid, value
221
+
222
+
223
+ def sample_xyr(volume, xy_grid, angle_grid, nearest_for_inf=False):
224
+ # (B, C, H, W, N) to (B, C, H, W, N+1)
225
+ volume_padded = pad(volume, [0, 1, 0, 0, 0, 0], mode="circular")
226
+
227
+ size = xy_grid.new_tensor(volume.shape[-3:-1][::-1])
228
+ xy_norm = xy_grid / (size - 1) # align_corners=True
229
+ angle_norm = (angle_grid / 360) % 1
230
+ grid = torch.concat([angle_norm.unsqueeze(-1), xy_norm], -1)
231
+ grid_norm = grid * 2 - 1
232
+
233
+ valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
234
+ value = grid_sample(volume_padded, grid_norm, align_corners=True, mode="bilinear")
235
+
236
+ # if one of the values used for linear interpolation is infinite,
237
+ # we fallback to nearest to avoid propagating inf
238
+ if nearest_for_inf:
239
+ value_nearest = grid_sample(
240
+ volume_padded, grid_norm, align_corners=True, mode="nearest"
241
+ )
242
+ value = torch.where(~torch.isfinite(value) & valid, value_nearest, value)
243
+
244
+ return value, valid
245
+
246
+
247
+ def nll_loss_xyr(log_probs, xy, angle):
248
+ log_prob, _ = sample_xyr(
249
+ log_probs.unsqueeze(1), xy[:, None, None, None], angle[:, None, None, None]
250
+ )
251
+ nll = -log_prob.reshape(-1) # remove C,H,W,N
252
+ return nll
253
+
254
+
255
+ def nll_loss_xyr_smoothed(log_probs, xy, angle, sigma_xy, sigma_r, mask=None):
256
+ *_, nx, ny, nr = log_probs.shape
257
+ grid_x = torch.arange(nx, device=log_probs.device, dtype=torch.float)
258
+ dx = (grid_x - xy[..., None, 0]) / sigma_xy
259
+ grid_y = torch.arange(ny, device=log_probs.device, dtype=torch.float)
260
+ dy = (grid_y - xy[..., None, 1]) / sigma_xy
261
+ dr = (
262
+ torch.arange(0, 360, 360 / nr, device=log_probs.device, dtype=torch.float)
263
+ - angle[..., None]
264
+ ) % 360
265
+ dr = torch.minimum(dr, 360 - dr) / sigma_r
266
+ diff = (
267
+ dx[..., None, :, None] ** 2
268
+ + dy[..., :, None, None] ** 2
269
+ + dr[..., None, None, :] ** 2
270
+ )
271
+ pdf = torch.exp(-diff / 2)
272
+ if mask is not None:
273
+ pdf.masked_fill_(~mask[..., None], 0)
274
+ log_probs = log_probs.masked_fill(~mask[..., None], 0)
275
+ pdf /= pdf.sum((-1, -2, -3), keepdim=True)
276
+ return -torch.sum(pdf * log_probs.to(torch.float), dim=(-1, -2, -3))
277
+
278
+
279
+ def log_softmax_spatial(x, dims=3):
280
+ return log_softmax(x.flatten(-dims), dim=-1).reshape(x.shape)
281
+
282
+
283
+ @torch.jit.script
284
+ def argmax_xy(scores: torch.Tensor) -> torch.Tensor:
285
+ indices = scores.flatten(-2).max(-1).indices
286
+ width = scores.shape[-1]
287
+ x = indices % width
288
+ y = torch.div(indices, width, rounding_mode="floor")
289
+ return torch.stack((x, y), -1)
290
+
291
+
292
+ @torch.jit.script
293
+ def expectation_xy(prob: torch.Tensor) -> torch.Tensor:
294
+ h, w = prob.shape[-2:]
295
+ grid = make_grid(float(w), float(h), device=prob.device).to(prob)
296
+ return torch.einsum("...hw,hwd->...d", prob, grid)
297
+
298
+
299
+ @torch.jit.script
300
+ def expectation_xyr(
301
+ prob: torch.Tensor, covariance: bool = False
302
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
303
+ h, w, num_rotations = prob.shape[-3:]
304
+ x, y = torch.meshgrid(
305
+ [
306
+ torch.arange(w, device=prob.device, dtype=prob.dtype),
307
+ torch.arange(h, device=prob.device, dtype=prob.dtype),
308
+ ],
309
+ indexing="xy",
310
+ )
311
+ grid_xy = torch.stack((x, y), -1)
312
+ xy_mean = torch.einsum("...hwn,hwd->...d", prob, grid_xy)
313
+
314
+ angles = torch.arange(0, 1, 1 / num_rotations, device=prob.device, dtype=prob.dtype)
315
+ angles = angles * 2 * np.pi
316
+ grid_cs = torch.stack([torch.cos(angles), torch.sin(angles)], -1)
317
+ cs_mean = torch.einsum("...hwn,nd->...d", prob, grid_cs)
318
+ angle = torch.atan2(cs_mean[..., 1], cs_mean[..., 0])
319
+ angle = (angle * 180 / np.pi) % 360
320
+
321
+ if covariance:
322
+ xy_cov = torch.einsum("...hwn,...hwd,...hwk->...dk", prob, grid_xy, grid_xy)
323
+ xy_cov = xy_cov - torch.einsum("...d,...k->...dk", xy_mean, xy_mean)
324
+ else:
325
+ xy_cov = None
326
+
327
+ xyr_mean = torch.cat((xy_mean, angle.unsqueeze(-1)), -1)
328
+ return xyr_mean, xy_cov
329
+
330
+
331
+ @torch.jit.script
332
+ def argmax_xyr(scores: torch.Tensor) -> torch.Tensor:
333
+ indices = scores.flatten(-3).max(-1).indices
334
+ width, num_rotations = scores.shape[-2:]
335
+ wr = width * num_rotations
336
+ y = torch.div(indices, wr, rounding_mode="floor")
337
+ x = torch.div(indices % wr, num_rotations, rounding_mode="floor")
338
+ angle_index = indices % num_rotations
339
+ angle = angle_index * 360 / num_rotations
340
+ xyr = torch.stack((x, y, angle), -1)
341
+ return xyr
342
+
343
+
344
+ @torch.jit.script
345
+ def mask_yaw_prior(
346
+ scores: torch.Tensor, yaw_prior: torch.Tensor, num_rotations: int
347
+ ) -> torch.Tensor:
348
+ step = 360 / num_rotations
349
+ step_2 = step / 2
350
+ angles = torch.arange(step_2, 360 + step_2, step, device=scores.device)
351
+ yaw_init, yaw_range = yaw_prior.chunk(2, dim=-1)
352
+ rot_mask = angle_error(angles, yaw_init) < yaw_range
353
+ return scores.masked_fill_(~rot_mask[:, None, None], -np.inf)
354
+
355
+
356
+ def fuse_gps(log_prob, uv_gps, ppm, sigma=10, gaussian=False):
357
+ grid = make_grid(*log_prob.shape[-3:-1][::-1]).to(log_prob)
358
+ dist = torch.sum((grid - uv_gps) ** 2, -1)
359
+ sigma_pixel = sigma * ppm
360
+ if gaussian:
361
+ gps_log_prob = -1 / 2 * dist / sigma_pixel**2
362
+ else:
363
+ gps_log_prob = torch.where(dist < sigma_pixel**2, 1, -np.inf)
364
+ log_prob_fused = log_softmax_spatial(log_prob + gps_log_prob.unsqueeze(-1))
365
+ return log_prob_fused
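A toy run (assumed shapes, CPU) of the rotation-template sampler defined above: each feature map is rotated into num_rotations copies, giving a (B, N, C, H, W) template stack that is then correlated with the map features. Assumes the repo root is on PYTHONPATH and torchvision is installed.

import torch
from models.voting import UAVTemplateSampler

sampler = UAVTemplateSampler(num_rotations=64)
f_uav = torch.rand(2, 8, 32, 32)       # (B, C, H, W) image features
templates = sampler(f_uav)
print(templates.shape)                  # torch.Size([2, 64, 8, 32, 32])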
module.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from pathlib import Path
4
+
5
+ import pytorch_lightning as pl
6
+ import torch
7
+ from omegaconf import DictConfig, OmegaConf, open_dict
8
+ from torchmetrics import MeanMetric, MetricCollection
9
+
10
+ import logger
11
+ from models import get_model
12
+
13
+
14
+ class AverageKeyMeter(MeanMetric):
15
+ def __init__(self, key, *args, **kwargs):
16
+ self.key = key
17
+ super().__init__(*args, **kwargs)
18
+
19
+ def update(self, dict):
20
+ value = dict[self.key]
21
+ value = value[torch.isfinite(value)]
22
+ return super().update(value)
23
+
24
+
25
+ class GenericModule(pl.LightningModule):
26
+ def __init__(self, cfg):
27
+ super().__init__()
28
+ name = cfg.model.get("name")
29
+ name = "orienternet" if name in ("localizer_bev_depth", None) else name
30
+ self.model = get_model(name)(cfg.model)
31
+ self.cfg = cfg
32
+ self.save_hyperparameters(cfg)
33
+
34
+
35
+
36
+ self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
37
+ self.losses_val = None # we do not know the loss keys in advance
38
+
39
+ # self.citys = self.cfg.data.val_citys
40
+ # for i in range(len(self.citys)):
41
+ # city=self.citys[i]
42
+ # setattr(self, "metric_vals_{}".format(i), MetricCollection(self.model.metrics(), prefix="val_{}/".format(city)))
43
+ # self.losse_vals = [None for city in self.cfg.data.val_citys]
44
+
45
+
46
+ def forward(self, batch):
47
+ return self.model(batch)
48
+
49
+ def training_step(self, batch):
50
+ pred = self(batch)
51
+ losses = self.model.loss(pred, batch)
52
+ self.log_dict(
53
+ {f"loss/{k}/train": v.mean() for k, v in losses.items()},
54
+ prog_bar=True,
55
+ rank_zero_only=True,
56
+ )
57
+ return losses["total"].mean()
58
+
59
+ # def validation_step(self, batch, batch_idx,dataloader_idx):
60
+ # city=self.citys[dataloader_idx]
61
+ #
62
+ # pred = self(batch)
63
+ # losses = self.model.loss(pred, batch)
64
+ #
65
+ # if hasattr(self,"losse_val_{}".format(dataloader_idx)) is False:
66
+ # setattr(self,"losse_val_{}".format(dataloader_idx),MetricCollection(
67
+ # {k: AverageKeyMeter(k).to(self.device) for k in losses},
68
+ # prefix="loss_{}/".format(city),
69
+ # postfix="/val_{}".format(city),
70
+ # ))
71
+ #
72
+ # # print(pred, batch)
73
+ # getattr(self,"metric_vals_{}".format(dataloader_idx))(pred, batch)
74
+ # self.log_dict(getattr(self,"metric_vals_{}".format(dataloader_idx))(pred, batch), sync_dist=True)
75
+ #
76
+ # getattr(self,"losse_val_{}".format(dataloader_idx)).update(losses)
77
+ # # print(getattr(self,"losse_val_{}".format(dataloader_idx)))
78
+ # self.log_dict(getattr(self,"losse_val_{}".format(dataloader_idx)).compute(), sync_dist=True)
79
+ def validation_step(self, batch, batch_idx):
80
+ pred = self(batch)
81
+ losses = self.model.loss(pred, batch)
82
+ if self.losses_val is None:
83
+ self.losses_val = MetricCollection(
84
+ {k: AverageKeyMeter(k).to(self.device) for k in losses},
85
+ prefix="loss/",
86
+ postfix="/val",
87
+ )
88
+ self.metrics_val(pred, batch)
89
+ self.log_dict(self.metrics_val, sync_dist=True)
90
+ self.losses_val.update(losses)
91
+ self.log_dict(self.losses_val, sync_dist=True)
92
+
93
+ def validation_epoch_start(self, batch):
94
+ self.losses_val = None
95
+ # self.losse_val = [None for city in self.cfg.data.val_citys]
96
+
97
+ def configure_optimizers(self):
98
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
99
+ ret = {"optimizer": optimizer}
100
+ cfg_scheduler = self.cfg.training.get("lr_scheduler")
101
+ if cfg_scheduler is not None:
102
+ scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
103
+ optimizer=optimizer, **cfg_scheduler.get("args", {})
104
+ )
105
+ ret["lr_scheduler"] = {
106
+ "scheduler": scheduler,
107
+ "interval": "epoch",
108
+ "frequency": 1,
109
+ "monitor": "loss/total/val",
110
+ "strict": True,
111
+ "name": "learning_rate",
112
+ }
113
+ return ret
114
+
115
+ @classmethod
116
+ def load_from_checkpoint(
117
+ cls,
118
+ checkpoint_path,
119
+ map_location=None,
120
+ hparams_file=None,
121
+ strict=True,
122
+ cfg=None,
123
+ find_best=False,
124
+ ):
125
+ assert hparams_file is None, "hparams are not supported."
126
+
127
+ checkpoint = torch.load(
128
+ checkpoint_path, map_location=map_location or (lambda storage, loc: storage)
129
+ )
130
+ if find_best:
131
+ best_score, best_name = None, None
132
+ modes = {"min": torch.lt, "max": torch.gt}
133
+ for key, state in checkpoint["callbacks"].items():
134
+ if not key.startswith("ModelCheckpoint"):
135
+ continue
136
+ mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
137
+ if best_score is None or modes[mode](
138
+ state["best_model_score"], best_score
139
+ ):
140
+ best_score = state["best_model_score"]
141
+ best_name = Path(state["best_model_path"]).name
142
+ logger.info("Loading best checkpoint %s", best_name)
143
+ if best_name != checkpoint_path:
144
+ return cls.load_from_checkpoint(
145
+ Path(checkpoint_path).parent / best_name,
146
+ map_location,
147
+ hparams_file,
148
+ strict,
149
+ cfg,
150
+ find_best=False,
151
+ )
152
+
153
+ logger.info(
154
+ "Using checkpoint %s from epoch %d and step %d.",
155
+ checkpoint_path.name,
156
+ checkpoint["epoch"],
157
+ checkpoint["global_step"],
158
+ )
159
+ cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
160
+ if list(cfg_ckpt.keys()) == ["cfg"]: # backward compatibility
161
+ cfg_ckpt = cfg_ckpt["cfg"]
162
+ cfg_ckpt = OmegaConf.create(cfg_ckpt)
163
+
164
+ if cfg is None:
165
+ cfg = {}
166
+ if not isinstance(cfg, DictConfig):
167
+ cfg = OmegaConf.create(cfg)
168
+ with open_dict(cfg_ckpt):
169
+ cfg = OmegaConf.merge(cfg_ckpt, cfg)
170
+
171
+ return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)
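A minimal check of AverageKeyMeter outside the Lightning loop (assuming the repo's logger module and models package import cleanly): it averages a single key of a loss dict and ignores non-finite entries.

import torch
from module import AverageKeyMeter

meter = AverageKeyMeter("nll")
meter.update({"nll": torch.tensor([1.0, 3.0, float("inf")])})
print(meter.compute())  # tensor(2.)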
osm/analysis.py ADDED
@@ -0,0 +1,182 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from collections import Counter, defaultdict
4
+ from typing import Dict
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import plotly.graph_objects as go
9
+
10
+ from .parser import (
11
+ filter_area,
12
+ filter_node,
13
+ filter_way,
14
+ match_to_group,
15
+ parse_area,
16
+ parse_node,
17
+ parse_way,
18
+ Patterns,
19
+ )
20
+ from .reader import OSMData
21
+
22
+
23
+ def recover_hierarchy(counter: Counter) -> Dict:
24
+ """Recover a two-level hierarchy from the flat group labels."""
25
+ groups = defaultdict(dict)
26
+ for k, v in sorted(counter.items(), key=lambda x: -x[1]):
27
+ if ":" in k:
28
+ prefix, group = k.split(":")
29
+ if prefix in groups and isinstance(groups[prefix], int):
+ # promote a previously-seen flat count into a nested dict, keeping it under its own key
+ groups[prefix] = {prefix: groups[prefix]}
33
+ groups[prefix][group] = v
34
+ else:
35
+ groups[k] = v
36
+ return dict(groups)
37
+
38
+
39
+ def bar_autolabel(rects, fontsize):
40
+ """Attach a text label above each bar in *rects*, displaying its height."""
41
+ for rect in rects:
42
+ width = rect.get_width()
43
+ plt.gca().annotate(
44
+ f"{width}",
45
+ xy=(width, rect.get_y() + rect.get_height() / 2),
46
+ xytext=(3, 0), # 3 points vertical offset
47
+ textcoords="offset points",
48
+ ha="left",
49
+ va="center",
50
+ fontsize=fontsize,
51
+ )
52
+
53
+
54
+ def plot_histogram(counts, fontsize, dpi):
55
+ fig, ax = plt.subplots(dpi=dpi, figsize=(8, 20))
56
+
57
+ labels = []
58
+ for k, v in counts.items():
59
+ if isinstance(v, dict):
60
+ labels += list(v.keys())
61
+ v = list(v.values())
62
+ else:
63
+ labels.append(k)
64
+ v = [v]
65
+ bars = plt.barh(
66
+ len(labels) + -len(v) + np.arange(len(v)), v, height=0.9, label=k
67
+ )
68
+ bar_autolabel(bars, fontsize)
69
+
70
+ ax.set_yticklabels(labels, fontsize=fontsize)
71
+ ax.axes.xaxis.set_ticklabels([])
72
+ ax.xaxis.tick_top()
73
+ ax.invert_yaxis()
74
+ plt.yticks(np.arange(len(labels)))
75
+ plt.xscale("log")
76
+ plt.legend(ncol=len(counts), loc="upper center")
77
+
78
+
79
+ def count_elements(elems: Dict[int, str], filter_fn, parse_fn) -> Dict:
80
+ """Count the number of elements in each group."""
81
+ counts = Counter()
82
+ for elem in filter(filter_fn, elems.values()):
83
+ group = parse_fn(elem.tags)
84
+ if group is None:
85
+ continue
86
+ counts[group] += 1
87
+ counts = recover_hierarchy(counts)
88
+ return counts
89
+
90
+
91
+ def plot_osm_histograms(osm: OSMData, fontsize=8, dpi=150):
92
+ counts = count_elements(osm.nodes, filter_node, parse_node)
93
+ plot_histogram(counts, fontsize, dpi)
94
+ plt.title("nodes")
95
+
96
+ counts = count_elements(osm.ways, filter_way, parse_way)
97
+ plot_histogram(counts, fontsize, dpi)
98
+ plt.title("ways")
99
+
100
+ counts = count_elements(osm.ways, filter_area, parse_area)
101
+ plot_histogram(counts, fontsize, dpi)
102
+ plt.title("areas")
103
+
104
+
105
+ def plot_sankey_hierarchy(osm: OSMData):
106
+ triplets = []
107
+ for node in filter(filter_node, osm.nodes.values()):
108
+ label = parse_node(node.tags)
109
+ if label is None:
110
+ continue
111
+ group = match_to_group(label, Patterns.nodes)
112
+ if group is None:
113
+ group = match_to_group(label, Patterns.ways)
114
+ if group is None:
115
+ group = "null"
116
+ if ":" in label:
117
+ key, tag = label.split(":")
118
+ if tag == "yes":
119
+ tag = key
120
+ else:
121
+ key = tag = label
122
+ triplets.append((key, tag, group))
123
+ keys, tags, groups = list(zip(*triplets))
124
+ counts_key_tag = Counter(zip(keys, tags))
125
+ counts_key_tag_group = Counter(triplets)
126
+
127
+ key2tags = defaultdict(set)
128
+ for k, t in zip(keys, tags):
129
+ key2tags[k].add(t)
130
+ key2tags = {k: sorted(t) for k, t in key2tags.items()}
131
+ keytag2group = dict(zip(zip(keys, tags), groups))
132
+ key_names = sorted(set(keys))
133
+ tag_names = [(k, t) for k in key_names for t in key2tags[k]]
134
+
135
+ group_names = []
136
+ for k in key_names:
137
+ for t in key2tags[k]:
138
+ g = keytag2group[k, t]
139
+ if g not in group_names and g != "null":
140
+ group_names.append(g)
141
+ group_names += ["null"]
142
+
143
+ key2idx = dict(zip(key_names, range(len(key_names))))
144
+ tag2idx = {kt: i + len(key2idx) for i, kt in enumerate(tag_names)}
145
+ group2idx = {n: i + len(key2idx) + len(tag2idx) for i, n in enumerate(group_names)}
146
+
147
+ key_counts = Counter(keys)
148
+ key_text = [f"{k} {key_counts[k]}" for k in key_names]
149
+ tag_counts = Counter(list(zip(keys, tags)))
150
+ tag_text = [f"{t} {tag_counts[k, t]}" for k, t in tag_names]
151
+ group_counts = Counter(groups)
152
+ group_text = [f"{k} {group_counts[k]}" for k in group_names]
153
+
154
+ fig = go.Figure(
155
+ data=[
156
+ go.Sankey(
157
+ orientation="h",
158
+ node=dict(
159
+ pad=15,
160
+ thickness=20,
161
+ line=dict(color="black", width=0.5),
162
+ label=key_text + tag_text + group_text,
163
+ x=[0] * len(key_names)
164
+ + [1] * len(tag_names)
165
+ + [2] * len(group_names),
166
+ color="blue",
167
+ ),
168
+ arrangement="fixed",
169
+ link=dict(
170
+ source=[key2idx[k] for k, _ in counts_key_tag]
171
+ + [tag2idx[k, t] for k, t, _ in counts_key_tag_group],
172
+ target=[tag2idx[k, t] for k, t in counts_key_tag]
173
+ + [group2idx[g] for _, _, g in counts_key_tag_group],
174
+ value=list(counts_key_tag.values())
175
+ + list(counts_key_tag_group.values()),
176
+ ),
177
+ )
178
+ ]
179
+ )
180
+ fig.update_layout(autosize=False, width=800, height=2000, font_size=10)
181
+ fig.show()
182
+ return fig
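A toy run of recover_hierarchy (with the flat-count promotion shown above, and assuming matplotlib/plotly are installed so osm.analysis imports cleanly): flat "prefix:group" counts become a two-level dictionary.

from collections import Counter
from osm.analysis import recover_hierarchy

counts = Counter({"building": 5, "building:house": 3, "highway:bus_stop": 2})
print(recover_hierarchy(counts))
# {'building': {'building': 5, 'house': 3}, 'highway': {'bus_stop': 2}}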
osm/data.py ADDED
@@ -0,0 +1,230 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import logging
4
+ from dataclasses import dataclass, field
5
+ from typing import Dict, List, Optional, Set, Tuple
6
+
7
+ import numpy as np
8
+
9
+ from .parser import (
10
+ filter_area,
11
+ filter_node,
12
+ filter_way,
13
+ match_to_group,
14
+ parse_area,
15
+ parse_node,
16
+ parse_way,
17
+ Patterns,
18
+ )
19
+ from .reader import OSMData, OSMNode, OSMRelation, OSMWay
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def glue(ways: List[OSMWay]) -> List[List[OSMNode]]:
26
+ result: List[List[OSMNode]] = []
27
+ to_process: Set[Tuple[OSMNode]] = set()
28
+
29
+ for way in ways:
30
+ if way.is_cycle():
31
+ result.append(way.nodes)
32
+ else:
33
+ to_process.add(tuple(way.nodes))
34
+
35
+ while to_process:
36
+ nodes: List[OSMNode] = list(to_process.pop())
37
+ glued: Optional[List[OSMNode]] = None
38
+ other_nodes: Optional[Tuple[OSMNode]] = None
39
+
40
+ for other_nodes in to_process:
41
+ glued = try_to_glue(nodes, list(other_nodes))
42
+ if glued is not None:
43
+ break
44
+
45
+ if glued is not None:
46
+ to_process.remove(other_nodes)
47
+ if is_cycle(glued):
48
+ result.append(glued)
49
+ else:
50
+ to_process.add(tuple(glued))
51
+ else:
52
+ result.append(nodes)
53
+
54
+ return result
55
+
56
+
57
+ def is_cycle(nodes: List[OSMNode]) -> bool:
58
+ """Is way a cycle way or an area boundary."""
59
+ return nodes[0] == nodes[-1]
60
+
61
+
62
+ def try_to_glue(nodes: List[OSMNode], other: List[OSMNode]) -> Optional[List[OSMNode]]:
63
+ """Create new combined way if ways share endpoints."""
64
+ if nodes[0] == other[0]:
65
+ return list(reversed(other[1:])) + nodes
66
+ if nodes[0] == other[-1]:
67
+ return other[:-1] + nodes
68
+ if nodes[-1] == other[-1]:
69
+ return nodes + list(reversed(other[:-1]))
70
+ if nodes[-1] == other[0]:
71
+ return nodes + other[1:]
72
+ return None
73
+
74
+
75
+ def multipolygon_from_relation(rel: OSMRelation, osm: OSMData):
76
+ inner_ways = []
77
+ outer_ways = []
78
+ for member in rel.members:
79
+ if member.type_ == "way":
80
+ if member.role == "inner":
81
+ if member.ref in osm.ways:
82
+ inner_ways.append(osm.ways[member.ref])
83
+ elif member.role == "outer":
84
+ if member.ref in osm.ways:
85
+ outer_ways.append(osm.ways[member.ref])
86
+ else:
87
+ logger.warning(f'Unknown member role "{member.role}".')
88
+ if outer_ways:
89
+ inners_path = glue(inner_ways)
90
+ outers_path = glue(outer_ways)
91
+ return inners_path, outers_path
92
+
93
+
94
+ @dataclass
95
+ class MapElement:
96
+ id_: int
97
+ label: str
98
+ group: str
99
+ tags: Optional[Dict[str, str]]
100
+
101
+
102
+ @dataclass
103
+ class MapNode(MapElement):
104
+ xy: np.ndarray
105
+
106
+ @classmethod
107
+ def from_osm(cls, node: OSMNode, label: str, group: str):
108
+ return cls(
109
+ node.id_,
110
+ label,
111
+ group,
112
+ node.tags,
113
+ xy=node.xy,
114
+ )
115
+
116
+
117
+ @dataclass
118
+ class MapLine(MapElement):
119
+ xy: np.ndarray
120
+
121
+ @classmethod
122
+ def from_osm(cls, way: OSMWay, label: str, group: str):
123
+ xy = np.stack([n.xy for n in way.nodes])
124
+ return cls(
125
+ way.id_,
126
+ label,
127
+ group,
128
+ way.tags,
129
+ xy=xy,
130
+ )
131
+
132
+
133
+ @dataclass
134
+ class MapArea(MapElement):
135
+ outers: List[np.ndarray]
136
+ inners: List[np.ndarray] = field(default_factory=list)
137
+
138
+ @classmethod
139
+ def from_relation(cls, rel: OSMRelation, label: str, group: str, osm: OSMData):
140
+ outers_inners = multipolygon_from_relation(rel, osm)
141
+ if outers_inners is None:
142
+ return None
143
+ outers, inners = outers_inners
144
+ outers = [np.stack([n.xy for n in way]) for way in outers]
145
+ inners = [np.stack([n.xy for n in way]) for way in inners]
146
+ return cls(
147
+ rel.id_,
148
+ label,
149
+ group,
150
+ rel.tags,
151
+ outers=outers,
152
+ inners=inners,
153
+ )
154
+
155
+ @classmethod
156
+ def from_way(cls, way: OSMWay, label: str, group: str):
157
+ xy = np.stack([n.xy for n in way.nodes])
158
+ return cls(
159
+ way.id_,
160
+ label,
161
+ group,
162
+ way.tags,
163
+ outers=[xy],
164
+ )
165
+
166
+
167
+ class MapData:
168
+ def __init__(self):
169
+ self.nodes: Dict[int, MapNode] = {}
170
+ self.lines: Dict[int, MapLine] = {}
171
+ self.areas: Dict[int, MapArea] = {}
172
+
173
+ @classmethod
174
+ def from_osm(cls, osm: OSMData):
175
+ self = cls()
176
+
177
+ for node in filter(filter_node, osm.nodes.values()):
178
+ label = parse_node(node.tags)
179
+ if label is None:
180
+ continue
181
+ group = match_to_group(label, Patterns.nodes)
182
+ if group is None:
183
+ group = match_to_group(label, Patterns.ways)
184
+ if group is None:
185
+ continue # missing
186
+ self.nodes[node.id_] = MapNode.from_osm(node, label, group)
187
+
188
+ for way in filter(filter_way, osm.ways.values()):
189
+ label = parse_way(way.tags)
190
+ if label is None:
191
+ continue
192
+ group = match_to_group(label, Patterns.ways)
193
+ if group is None:
194
+ group = match_to_group(label, Patterns.nodes)
195
+ if group is None:
196
+ continue # missing
197
+ self.lines[way.id_] = MapLine.from_osm(way, label, group)
198
+
199
+ for area in filter(filter_area, osm.ways.values()):
200
+ label = parse_area(area.tags)
201
+ if label is None:
202
+ continue
203
+ group = match_to_group(label, Patterns.areas)
204
+ if group is None:
205
+ group = match_to_group(label, Patterns.ways)
206
+ if group is None:
207
+ group = match_to_group(label, Patterns.nodes)
208
+ if group is None:
209
+ continue # missing
210
+ self.areas[area.id_] = MapArea.from_way(area, label, group)
211
+
212
+ for rel in osm.relations.values():
213
+ if rel.tags.get("type") != "multipolygon":
214
+ continue
215
+ label = parse_area(rel.tags)
216
+ if label is None:
217
+ continue
218
+ group = match_to_group(label, Patterns.areas)
219
+ if group is None:
220
+ group = match_to_group(label, Patterns.ways)
221
+ if group is None:
222
+ group = match_to_group(label, Patterns.nodes)
223
+ if group is None:
224
+ continue # missing
225
+ area = MapArea.from_relation(rel, label, group, osm)
226
+ assert rel.id_ not in self.areas # not sure if there can be collision
227
+ if area is not None:
228
+ self.areas[rel.id_] = area
229
+
230
+ return self
osm/download.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Dict, Optional
6
+
7
+ import urllib3
8
+
9
+
10
+ from utils.geo import BoundaryBox
11
+ import urllib.request
12
+ import requests
13
+
14
+ def get_osm(
15
+ boundary_box: BoundaryBox,
16
+ cache_path: Optional[Path] = None,
17
+ overwrite: bool = False,
18
+ ) -> str:
19
+ if not overwrite and cache_path is not None and cache_path.is_file():
20
+ with cache_path.open() as fp:
21
+ return json.load(fp)
22
+
23
+ (bottom, left), (top, right) = boundary_box.min_, boundary_box.max_
24
+ content: bytes = get_web_data(
25
+ # "https://api.openstreetmap.org/api/0.6/map.json",
26
+ "https://openstreetmap.erniubot.live/api/0.6/map.json",
27
+ # 'https://overpass-api.de/api/map',
28
+ # 'http://localhost:29505/api/map',
29
+ # "https://lz4.overpass-api.de/api/interpreter",
30
+ {"bbox": f"{left},{bottom},{right},{top}"},
31
+ )
32
+
33
+ content_str = content.decode("utf-8")
34
+ if content_str.startswith("You requested too many nodes"):
35
+ raise ValueError(content_str)
36
+
37
+ if cache_path is not None:
38
+ with cache_path.open("bw+") as fp:
39
+ fp.write(content)
40
+ return json.loads(content_str)
42
+
43
+
44
+ def get_web_data(address: str, parameters: Dict[str, str]) -> bytes:
45
+ # logger.info("Getting %s...", address)
46
+ # proxy_address = "http://107.173.122.186:3128"
47
+ #
48
+ # # proxy server address and port
49
+ # proxies = {
50
+ # 'http': proxy_address,
51
+ # 'https': proxy_address
52
+ # }
53
+
54
+ # send the GET request and return the response body
55
+ # response = requests.get(address, params=parameters, timeout=100, proxies=proxies)
56
+ print('url:',address)
57
+ response = requests.get(address, params=parameters, timeout=100)
58
+ return response.content
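+ # NOTE: the second get_web_data definition below shadows this one at import time, so only the retrying variant is actually used.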
59
+ def get_web_data(address: str, parameters: Dict[str, str]) -> bytes:
60
+ # logger.info("Getting %s...", address)
61
+ while True:
62
+ try:
63
+ # proxy_address = "http://107.173.122.186:3128"
64
+ #
65
+ # # proxy server address and port
66
+ # proxies = {
67
+ # 'http': proxy_address,
68
+ # 'https': proxy_address
69
+ # }
70
+ # # send the GET request and return the response data
71
+ response = requests.get(address, params=parameters, timeout=100)
72
+ request = requests.Request('GET', address, params=parameters)
73
+ prepared_request = request.prepare()
74
+ # recover the full request URL (for debugging)
75
+ full_url = prepared_request.url
76
+ break
77
+
78
+ except Exception as e:
79
+ # print the error and retry
80
+ print(f"发生错误: {e}")
81
+ print("重试...")
82
+
83
+ return response.content
84
+ # def get_web_data_2(address: str, parameters: Dict[str, str]) -> bytes:
85
+ # # logger.info("Getting %s...", address)
86
+ # proxy_address="http://107.173.122.186:3128"
87
+ # http = urllib3.PoolManager(proxy_url=proxy_address)
88
+ # result = http.request("GET", address, parameters, timeout=100)
89
+ # return result.data
90
+ #
91
+ #
92
+ # def get_web_data_1(address: str, parameters: Dict[str, str]) -> bytes:
93
+ #
94
+ # # 设置代理服务器地址和端口
95
+ # proxy_address = "http://107.173.122.186:3128"
96
+ #
97
+ # # 创建ProxyHandler对象
98
+ # proxy_handler = urllib.request.ProxyHandler({'http': proxy_address})
99
+ #
100
+ # # 构建查询字符串
101
+ # query_string = urllib.parse.urlencode(parameters)
102
+ #
103
+ # # 构建完整的URL
104
+ # url = address + '?' + query_string
105
+ # print(url)
106
+ # # 创建OpenerDirector对象,并将ProxyHandler对象作为参数传递
107
+ # opener = urllib.request.build_opener(proxy_handler)
108
+ #
109
+ # # 使用OpenerDirector对象发送请求
110
+ # response = opener.open(url)
111
+ #
112
+ # # 发送GET请求
113
+ # # response = urllib.request.urlopen(url, timeout=100)
114
+ #
115
+ # # 读取响应内容
116
+ # data = response.read()
117
+ # print()
118
+ # return data
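Assumed usage of get_web_data (network access required; the bbox string is "left,bottom,right,top" in degrees, as assembled by get_osm above):

from osm.download import get_web_data

content = get_web_data(
    "https://api.openstreetmap.org/api/0.6/map.json",
    {"bbox": "2.3500,48.8560,2.3520,48.8575"},
)
print(len(content))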
osm/parser.py ADDED
@@ -0,0 +1,255 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import logging
4
+ import re
5
+ from typing import List
6
+
7
+ from .reader import OSMData, OSMElement, OSMNode, OSMWay
8
+
9
+ IGNORE_TAGS = {"source", "phone", "entrance", "inscription", "note", "name"}
10
+
11
+
12
+ def parse_levels(string: str) -> List[float]:
13
+ """Parse string representation of level sequence value."""
14
+ try:
15
+ cleaned = string.replace(",", ";").replace(" ", "")
16
+ return list(map(float, cleaned.split(";")))
17
+ except ValueError:
18
+ logging.debug("Cannot parse level description from `%s`.", string)
19
+ return []
20
+
21
+
22
+ def filter_level(elem: OSMElement):
23
+ level = elem.tags.get("level")
24
+ if level is not None:
25
+ levels = parse_levels(level)
26
+ # In the US, ground floor levels are sometimes marked as level=1
27
+ # so let's be conservative and include it.
28
+ if not (0 in levels or 1 in levels):
29
+ return False
30
+ layer = elem.tags.get("layer")
31
+ if layer is not None:
32
+ layer = parse_levels(layer)
33
+ if len(layer) > 0 and max(layer) < 0:
34
+ return False
35
+ return (
36
+ elem.tags.get("location") != "underground"
37
+ and elem.tags.get("parking") != "underground"
38
+ )
39
+
40
+
41
+ def filter_node(node: OSMNode):
42
+ return len(node.tags.keys() - IGNORE_TAGS) > 0 and filter_level(node)
43
+
44
+
45
+ def is_area(way: OSMWay):
46
+ if way.nodes[0] != way.nodes[-1]:
47
+ return False
48
+ if way.tags.get("area") == "no":
49
+ return False
50
+ filters = [
51
+ "area",
52
+ "building",
53
+ "amenity",
54
+ "indoor",
55
+ "landuse",
56
+ "landcover",
57
+ "leisure",
58
+ "public_transport",
59
+ "shop",
60
+ ]
61
+ for f in filters:
62
+ if f in way.tags and way.tags.get(f) != "no":
63
+ return True
64
+ if way.tags.get("natural") in {"wood", "grassland", "water"}:
65
+ return True
66
+ return False
67
+
68
+
69
+ def filter_area(way: OSMWay):
70
+ return len(way.tags.keys() - IGNORE_TAGS) > 0 and is_area(way) and filter_level(way)
71
+
72
+
73
+ def filter_way(way: OSMWay):
74
+ return not filter_area(way) and way.tags != {} and filter_level(way)
75
+
76
+
77
+ def parse_node(tags):
78
+ keys = tags.keys()
79
+ for key in [
80
+ "amenity",
81
+ "natural",
82
+ "highway",
83
+ "barrier",
84
+ "shop",
85
+ "tourism",
86
+ "public_transport",
87
+ "emergency",
88
+ "man_made",
89
+ ]:
90
+ if key in keys:
91
+ if "disused" in tags[key]:
92
+ continue
93
+ return f"{key}:{tags[key]}"
94
+ return None
95
+
96
+
97
+ def parse_area(tags):
98
+ if "building" in tags:
99
+ group = "building"
100
+ kind = tags["building"]
101
+ if kind == "yes":
102
+ for key in ["amenity", "tourism"]:
103
+ if key in tags:
104
+ kind = tags[key]
105
+ break
106
+ if kind != "yes":
107
+ group += f":{kind}"
108
+ return group
109
+ if "area:highway" in tags:
110
+ return f'highway:{tags["area:highway"]}'
111
+ for key in [
112
+ "amenity",
113
+ "landcover",
114
+ "leisure",
115
+ "shop",
116
+ "highway",
117
+ "tourism",
118
+ "natural",
119
+ "waterway",
120
+ "landuse",
121
+ ]:
122
+ if key in tags:
123
+ return f"{key}:{tags[key]}"
124
+ return None
125
+
126
+
127
+ def parse_way(tags):
128
+ keys = tags.keys()
129
+ for key in ["highway", "barrier", "natural"]:
130
+ if key in keys:
131
+ return f"{key}:{tags[key]}"
132
+ return None
133
+
134
+
135
+ def match_to_group(label, patterns):
136
+ for group, pattern in patterns.items():
137
+ if re.match(pattern, label):
138
+ return group
139
+ return None
140
+
141
+
142
+ class Patterns:
143
+ areas = dict(
144
+ building="building($|:.*?)*",
145
+ parking="amenity:parking",
146
+ playground="leisure:(playground|pitch)",
147
+ grass="(landuse:grass|landcover:grass|landuse:meadow|landuse:flowerbed|natural:grassland)",
148
+ park="leisure:(park|garden|dog_park)",
149
+ forest="(landuse:forest|natural:wood)",
150
+ water="(natural:water|waterway:*)",
151
+ )
152
+ # + ways: road, path
153
+ # + node: fountain, bicycle_parking
154
+
155
+ ways = dict(
156
+ fence="barrier:(fence|yes)",
157
+ wall="barrier:(wall|retaining_wall)",
158
+ hedge="barrier:hedge",
159
+ kerb="barrier:kerb",
160
+ building_outline="building($|:.*?)*",
161
+ cycleway="highway:cycleway",
162
+ path="highway:(pedestrian|footway|steps|path|corridor)",
163
+ road="highway:(motorway|trunk|primary|secondary|tertiary|service|construction|track|unclassified|residential|.*_link)",
164
+ busway="highway:busway",
165
+ tree_row="natural:tree_row", # maybe merge with node?
166
+ )
167
+ # + nodes: bollard
168
+
169
+ nodes = dict(
170
+ tree="natural:tree",
171
+ stone="(natural:stone|barrier:block)",
172
+ crossing="highway:crossing",
173
+ lamp="highway:street_lamp",
174
+ traffic_signal="highway:traffic_signals",
175
+ bus_stop="highway:bus_stop",
176
+ stop_sign="highway:stop",
177
+ junction="highway:motorway_junction",
178
+ bus_stop_position="public_transport:stop_position",
179
+ gate="barrier:(gate|lift_gate|swing_gate|cycle_barrier)",
180
+ bollard="barrier:bollard",
181
+ shop="(shop.*?|amenity:(bank|post_office))",
182
+ restaurant="amenity:(restaurant|fast_food)",
183
+ bar="amenity:(cafe|bar|pub|biergarten)",
184
+ pharmacy="amenity:pharmacy",
185
+ fuel="amenity:fuel",
186
+ bicycle_parking="amenity:(bicycle_parking|bicycle_rental)",
187
+ charging_station="amenity:charging_station",
188
+ parking_entrance="amenity:parking_entrance",
189
+ atm="amenity:atm",
190
+ toilets="amenity:toilets",
191
+ vending_machine="amenity:vending_machine",
192
+ fountain="amenity:fountain",
193
+ waste_basket="amenity:(waste_basket|waste_disposal)",
194
+ bench="amenity:bench",
195
+ post_box="amenity:post_box",
196
+ artwork="tourism:artwork",
197
+ recycling="amenity:recycling",
198
+ give_way="highway:give_way",
199
+ clock="amenity:clock",
200
+ fire_hydrant="emergency:fire_hydrant",
201
+ pole="man_made:(flagpole|utility_pole)",
202
+ street_cabinet="man_made:street_cabinet",
203
+ )
204
+ # + ways: kerb
205
+
206
+
207
+ class Groups:
208
+ areas = list(Patterns.areas)
209
+ ways = list(Patterns.ways)
210
+ nodes = list(Patterns.nodes)
211
+
212
+
213
+ def group_elements(osm: OSMData):
214
+ elem2group = {
215
+ "area": {},
216
+ "way": {},
217
+ "node": {},
218
+ }
219
+
220
+ for node in filter(filter_node, osm.nodes.values()):
221
+ label = parse_node(node.tags)
222
+ if label is None:
223
+ continue
224
+ group = match_to_group(label, Patterns.nodes)
225
+ if group is None:
226
+ group = match_to_group(label, Patterns.ways)
227
+ if group is None:
228
+ continue # missing
229
+ elem2group["node"][node.id_] = group
230
+
231
+ for way in filter(filter_way, osm.ways.values()):
232
+ label = parse_way(way.tags)
233
+ if label is None:
234
+ continue
235
+ group = match_to_group(label, Patterns.ways)
236
+ if group is None:
237
+ group = match_to_group(label, Patterns.nodes)
238
+ if group is None:
239
+ continue # missing
240
+ elem2group["way"][way.id_] = group
241
+
242
+ for area in filter(filter_area, osm.ways.values()):
243
+ label = parse_area(area.tags)
244
+ if label is None:
245
+ continue
246
+ group = match_to_group(label, Patterns.areas)
247
+ if group is None:
248
+ group = match_to_group(label, Patterns.ways)
249
+ if group is None:
250
+ group = match_to_group(label, Patterns.nodes)
251
+ if group is None:
252
+ continue # missing
253
+ elem2group["area"][area.id_] = group
254
+
255
+ return elem2group
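
A minimal usage sketch for the tag-parsing helpers above: load an OSM extract and group its elements into the area/way/node classes defined by Patterns. The file path is a placeholder; osm.reader and osm.parser are the modules added in this commit.

    from pathlib import Path

    from osm.parser import Groups, group_elements
    from osm.reader import OSMData

    # Load an exported extract (.osm/.xml or Overpass .json); the path is hypothetical.
    osm = OSMData.from_file(Path("map.osm"))
    elem2group = group_elements(osm)
    print(f"{len(elem2group['area'])} areas, {len(elem2group['way'])} ways, "
          f"{len(elem2group['node'])} nodes matched one of "
          f"{len(Groups.areas) + len(Groups.ways) + len(Groups.nodes)} groups")
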
osm/raster.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Dict, List
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+
9
+ from utils.geo import BoundaryBox
10
+ from .data import MapArea, MapLine, MapNode
11
+ from .parser import Groups
12
+
13
+
14
+ class Canvas:
15
+ def __init__(self, bbox: BoundaryBox, ppm: float):
16
+ self.bbox = bbox
17
+ self.ppm = ppm
18
+ self.scaling = bbox.size * ppm
19
+ self.w, self.h = np.ceil(self.scaling).astype(int)
20
+ self.clear()
21
+
22
+ def clear(self):
23
+ self.raster = np.zeros((self.h, self.w), np.uint8)
24
+
25
+ def to_uv(self, xy: np.ndarray):
26
+ xy = self.bbox.normalize(xy)
27
+ xy[..., 1] = 1 - xy[..., 1]
28
+ s = self.scaling
29
+ if isinstance(xy, torch.Tensor):
30
+ s = torch.from_numpy(s).to(xy)
31
+ return xy * s - 0.5
32
+
33
+ def to_xy(self, uv: np.ndarray):
34
+ s = self.scaling
35
+ if isinstance(uv, torch.Tensor):
36
+ s = torch.from_numpy(s).to(uv)
37
+ xy = (uv + 0.5) / s
38
+ xy[..., 1] = 1 - xy[..., 1]
39
+ return self.bbox.unnormalize(xy)
40
+
41
+ def draw_polygon(self, xy: np.ndarray):
42
+ uv = self.to_uv(xy)
43
+ cv2.fillPoly(self.raster, uv[None].astype(np.int32), 255)
44
+
45
+ def draw_multipolygon(self, xys: List[np.ndarray]):
46
+ uvs = [self.to_uv(xy).round().astype(np.int32) for xy in xys]
47
+ cv2.fillPoly(self.raster, uvs, 255)
48
+
49
+ def draw_line(self, xy: np.ndarray, width: float = 1):
50
+ uv = self.to_uv(xy)
51
+ cv2.polylines(
52
+ self.raster, uv[None].round().astype(np.int32), False, 255, thickness=width
53
+ )
54
+
55
+ def draw_cell(self, xy: np.ndarray):
56
+ if not self.bbox.contains(xy):
57
+ return
58
+ uv = self.to_uv(xy)
59
+ self.raster[tuple(uv.round().astype(int).T[::-1])] = 255
60
+
61
+
62
+ def render_raster_masks(
63
+ nodes: List[MapNode],
64
+ lines: List[MapLine],
65
+ areas: List[MapArea],
66
+ canvas: Canvas,
67
+ ) -> Dict[str, np.ndarray]:
68
+ all_groups = Groups.areas + Groups.ways + Groups.nodes
69
+ masks = {k: np.zeros((canvas.h, canvas.w), np.uint8) for k in all_groups}
70
+
71
+ for area in areas:
72
+ canvas.raster = masks[area.group]
73
+ outlines = area.outers + area.inners
74
+ canvas.draw_multipolygon(outlines)
75
+ if area.group == "building":
76
+ canvas.raster = masks["building_outline"]
77
+ for line in outlines:
78
+ canvas.draw_line(line)
79
+
80
+ for line in lines:
81
+ canvas.raster = masks[line.group]
82
+ canvas.draw_line(line.xy)
83
+
84
+ for node in nodes:
85
+ canvas.raster = masks[node.group]
86
+ canvas.draw_cell(node.xy)
87
+
88
+ return masks
89
+
90
+
91
+ def mask_to_idx(group2mask: Dict[str, np.ndarray], groups: List[str]) -> np.ndarray:
92
+ masks = np.stack([group2mask[k] for k in groups]) > 0
93
+ void = ~np.any(masks, 0)
94
+ idx = np.argmax(masks, 0)
95
+ idx = np.where(void, np.zeros_like(idx), idx + 1) # add background
96
+ return idx
97
+
98
+
99
+ def render_raster_map(masks: Dict[str, np.ndarray]) -> np.ndarray:
100
+ areas = mask_to_idx(masks, Groups.areas)
101
+ ways = mask_to_idx(masks, Groups.ways)
102
+ nodes = mask_to_idx(masks, Groups.nodes)
103
+ return np.stack([areas, ways, nodes])
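
A minimal sketch of the rasterization pipeline above, assuming an OSMData extract has already been projected to metric coordinates and converted to MapData (osm/data.py); the path, reference coordinates, and box size are placeholders.

    from pathlib import Path

    from osm.data import MapData
    from osm.raster import Canvas, render_raster_map, render_raster_masks
    from osm.reader import OSMData
    from utils.geo import BoundaryBox, Projection

    proj = Projection(48.8566, 2.3522)            # reference latitude/longitude (example)
    osm = OSMData.from_file(Path("map.osm"))      # hypothetical extract
    osm.add_xy_to_nodes(proj)
    data = MapData.from_osm(osm)

    bbox = BoundaryBox([-64, -64], [64, 64])      # 128 m x 128 m around the reference
    canvas = Canvas(bbox, ppm=2)                  # 2 pixels per meter
    masks = render_raster_masks(
        list(data.nodes.values()), list(data.lines.values()), list(data.areas.values()), canvas
    )
    raster = render_raster_map(masks)             # (3, H, W) indices: areas, ways, nodes
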
osm/reader.py ADDED
@@ -0,0 +1,310 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import json
4
+ import re
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+ from typing import Any, Dict, List, Optional
8
+
9
+ from lxml import etree
10
+ import numpy as np
11
+
12
+ from utils.geo import BoundaryBox, Projection
13
+
14
+ METERS_PATTERN: re.Pattern = re.compile("^(?P<value>\\d*\\.?\\d*)\\s*m$")
15
+ KILOMETERS_PATTERN: re.Pattern = re.compile("^(?P<value>\\d*\\.?\\d*)\\s*km$")
16
+ MILES_PATTERN: re.Pattern = re.compile("^(?P<value>\\d*\\.?\\d*)\\s*mi$")
17
+
18
+
19
+ def parse_float(string: str) -> Optional[float]:
20
+ """Parse string representation of a float or integer value."""
21
+ try:
22
+ return float(string)
23
+ except (TypeError, ValueError):
24
+ return None
25
+
26
+
27
+ @dataclass(eq=False)
28
+ class OSMElement:
29
+ """
30
+ Something with tags (string to string mapping).
31
+ """
32
+
33
+ id_: int
34
+ tags: Dict[str, str]
35
+
36
+ def get_float(self, key: str) -> Optional[float]:
37
+ """Parse float from tag value."""
38
+ if key in self.tags:
39
+ return parse_float(self.tags[key])
40
+ return None
41
+
42
+ def get_length(self, key: str) -> Optional[float]:
43
+ """Get length in meters."""
44
+ if key not in self.tags:
45
+ return None
46
+
47
+ value: str = self.tags[key]
48
+
49
+ float_value: float = parse_float(value)
50
+ if float_value is not None:
51
+ return float_value
52
+
53
+ for pattern, ratio in [
54
+ (METERS_PATTERN, 1.0),
55
+ (KILOMETERS_PATTERN, 1000.0),
56
+ (MILES_PATTERN, 1609.344),
57
+ ]:
58
+ matcher: re.Match = pattern.match(value)
59
+ if matcher:
60
+ float_value: float = parse_float(matcher.group("value"))
61
+ if float_value is not None:
62
+ return float_value * ratio
63
+
64
+ return None
65
+
66
+ def __hash__(self) -> int:
67
+ return self.id_
68
+
69
+
70
+ @dataclass(eq=False)
71
+ class OSMNode(OSMElement):
72
+ """
73
+ OpenStreetMap node.
74
+
75
+ See https://wiki.openstreetmap.org/wiki/Node
76
+ """
77
+
78
+ geo: np.ndarray
79
+ visible: Optional[str] = None
80
+ xy: Optional[np.ndarray] = None
81
+
82
+ @classmethod
83
+ def from_dict(cls, structure: Dict[str, Any]) -> "OSMNode":
84
+ """
85
+ Parse node from Overpass-like structure.
86
+
87
+ :param structure: input structure
88
+ """
89
+ return cls(
90
+ structure["id"],
91
+ structure.get("tags", {}),
92
+ geo=np.array((structure["lat"], structure["lon"])),
93
+ visible=structure.get("visible"),
94
+ )
95
+
96
+
97
+ @dataclass(eq=False)
98
+ class OSMWay(OSMElement):
99
+ """
100
+ OpenStreetMap way.
101
+
102
+ See https://wiki.openstreetmap.org/wiki/Way
103
+ """
104
+
105
+ nodes: Optional[List[OSMNode]] = field(default_factory=list)
106
+ visible: Optional[str] = None
107
+
108
+ @classmethod
109
+ def from_dict(
110
+ cls, structure: Dict[str, Any], nodes: Dict[int, OSMNode]
111
+ ) -> "OSMWay":
112
+ """
113
+ Parse way from Overpass-like structure.
114
+
115
+ :param structure: input structure
116
+ :param nodes: node structure
117
+ """
118
+ return cls(
119
+ structure["id"],
120
+ structure.get("tags", {}),
121
+ [nodes[x] for x in structure["nodes"]],
122
+ visible=structure.get("visible"),
123
+ )
124
+
125
+ def is_cycle(self) -> bool:
126
+ """Is way a cycle way or an area boundary."""
127
+ return self.nodes[0] == self.nodes[-1]
128
+
129
+ def __repr__(self) -> str:
130
+ return f"Way <{self.id_}> {self.nodes}"
131
+
132
+
133
+ @dataclass
134
+ class OSMMember:
135
+ """
136
+ Member of OpenStreetMap relation.
137
+ """
138
+
139
+ type_: str
140
+ ref: int
141
+ role: str
142
+
143
+
144
+ @dataclass(eq=False)
145
+ class OSMRelation(OSMElement):
146
+ """
147
+ OpenStreetMap relation.
148
+
149
+ See https://wiki.openstreetmap.org/wiki/Relation
150
+ """
151
+
152
+ members: Optional[List[OSMMember]]
153
+ visible: Optional[str] = None
154
+
155
+ @classmethod
156
+ def from_dict(cls, structure: Dict[str, Any]) -> "OSMRelation":
157
+ """
158
+ Parse relation from Overpass-like structure.
159
+
160
+ :param structure: input structure
161
+ """
162
+ return cls(
163
+ structure["id"],
164
+ structure["tags"],
165
+ [OSMMember(x["type"], x["ref"], x["role"]) for x in structure["members"]],
166
+ visible=structure.get("visible"),
167
+ )
168
+
169
+
170
+ class OSMData:
171
+ """
172
+ The whole OpenStreetMap information about nodes, ways, and relations.
173
+ """
174
+
175
+ def __init__(self) -> None:
176
+ self.nodes: Dict[int, OSMNode] = {}
177
+ self.ways: Dict[int, OSMWay] = {}
178
+ self.relations: Dict[int, OSMRelation] = {}
179
+ self.box: BoundaryBox = None
180
+
181
+ @classmethod
182
+ def from_dict(cls, structure: Dict[str, Any]):
183
+ data = cls()
184
+ bounds = structure.get("bounds")
185
+ if bounds is not None:
186
+ data.box = BoundaryBox(
187
+ np.array([bounds["minlat"], bounds["minlon"]]),
188
+ np.array([bounds["maxlat"], bounds["maxlon"]]),
189
+ )
190
+
191
+ for element in structure["elements"]:
192
+ if element["type"] == "node":
193
+ node = OSMNode.from_dict(element)
194
+ data.add_node(node)
195
+ for element in structure["elements"]:
196
+ if element["type"] == "way":
197
+ way = OSMWay.from_dict(element, data.nodes)
198
+ data.add_way(way)
199
+ for element in structure["elements"]:
200
+ if element["type"] == "relation":
201
+ relation = OSMRelation.from_dict(element)
202
+ data.add_relation(relation)
203
+
204
+ return data
205
+
206
+ @classmethod
207
+ def from_json(cls, path: Path):
208
+ with path.open(encoding='utf-8') as fid:
209
+ structure = json.load(fid)
210
+ return cls.from_dict(structure)
211
+
212
+ @classmethod
213
+ def from_xml(cls, path: Path):
214
+ root = etree.parse(str(path)).getroot()
215
+ structure = {"elements": []}
216
+ from tqdm import tqdm
217
+
218
+ for elem in tqdm(root):
219
+ if elem.tag == "bounds":
220
+ structure["bounds"] = {
221
+ k: float(elem.attrib[k])
222
+ for k in ("minlon", "minlat", "maxlon", "maxlat")
223
+ }
224
+ elif elem.tag in {"node", "way", "relation"}:
225
+ if elem.tag == "node":
226
+ item = {
227
+ "id": int(elem.attrib["id"]),
228
+ "lat": float(elem.attrib["lat"]),
229
+ "lon": float(elem.attrib["lon"]),
230
+ "visible": elem.attrib.get("visible"),
231
+ "tags": {
232
+ x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag"
233
+ },
234
+ }
235
+ elif elem.tag == "way":
236
+ item = {
237
+ "id": int(elem.attrib["id"]),
238
+ "visible": elem.attrib.get("visible"),
239
+ "tags": {
240
+ x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag"
241
+ },
242
+ "nodes": [int(x.attrib["ref"]) for x in elem if x.tag == "nd"],
243
+ }
244
+ elif elem.tag == "relation":
245
+ item = {
246
+ "id": int(elem.attrib["id"]),
247
+ "visible": elem.attrib.get("visible"),
248
+ "tags": {
249
+ x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag"
250
+ },
251
+ "members": [
252
+ {
253
+ "type": x.attrib["type"],
254
+ "ref": int(x.attrib["ref"]),
255
+ "role": x.attrib["role"],
256
+ }
257
+ for x in elem
258
+ if x.tag == "member"
259
+ ],
260
+ }
261
+ item["type"] = elem.tag
262
+ structure["elements"].append(item)
263
+ elem.clear()
264
+ del root
265
+ return cls.from_dict(structure)
266
+
267
+ @classmethod
268
+ def from_file(cls, path: Path):
269
+ ext = path.suffix
270
+ if ext == ".json":
271
+ return cls.from_json(path)
272
+ elif ext in {".osm", ".xml"}:
273
+ return cls.from_xml(path)
274
+ else:
275
+ raise ValueError(f"Unknown extension for {path}")
276
+
277
+ def add_node(self, node: OSMNode):
278
+ """Add node and update map parameters."""
279
+ if node.id_ in self.nodes:
280
+ raise ValueError(f"Node with duplicate id {node.id_}.")
281
+ self.nodes[node.id_] = node
282
+
283
+ def add_way(self, way: OSMWay):
284
+ """Add way and update map parameters."""
285
+ if way.id_ in self.ways:
286
+ raise ValueError(f"Way with duplicate id {way.id_}.")
287
+ self.ways[way.id_] = way
288
+
289
+ def add_relation(self, relation: OSMRelation):
290
+ """Add relation and update map parameters."""
291
+ if relation.id_ in self.relations:
292
+ raise ValueError(f"Relation with duplicate id {relation.id_}.")
293
+ self.relations[relation.id_] = relation
294
+
295
+ def add_xy_to_nodes(self, proj: Projection):
296
+ nodes = list(self.nodes.values())
297
+ if len(nodes) == 0:
298
+ return
299
+ geos = np.stack([n.geo for n in nodes], 0)
300
+ if proj.bounds is not None:
301
+ # For some reason, a few nodes are sometimes very far outside the initial bbox.
302
+ valid = proj.bounds.contains(geos)
303
+ if valid.mean() < 0.9:
304
+ print("Many nodes are out of the projection bounds.")
305
+ xys = np.zeros_like(geos)
306
+ xys[valid] = proj.project(geos[valid])
307
+ else:
308
+ xys = proj.project(geos)
309
+ for xy, node in zip(xys, nodes):
310
+ node.xy = xy
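
A short sketch of how OSMData is typically used downstream (see osm/tiling.py): load an extract, build a metric projection, and attach x/y coordinates to every node. The file path and fallback coordinates are placeholders.

    from pathlib import Path

    from osm.reader import OSMData
    from utils.geo import Projection

    osm = OSMData.from_file(Path("map.osm"))      # .json, .osm or .xml
    print(len(osm.nodes), "nodes,", len(osm.ways), "ways,", len(osm.relations), "relations")

    # Center the projection on the extract's bounds if available.
    proj = Projection(*osm.box.center) if osm.box is not None else Projection(48.85, 2.35)
    osm.add_xy_to_nodes(proj)                     # fills node.xy with metric offsets
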
osm/tiling.py ADDED
@@ -0,0 +1,310 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import io
4
+ import pickle
5
+ from pathlib import Path
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+ import rtree
11
+
12
+ from utils.geo import BoundaryBox, Projection
13
+ from .data import MapData
14
+ from .download import get_osm
15
+ from .parser import Groups
16
+ from .raster import Canvas, render_raster_map, render_raster_masks
17
+ from .reader import OSMData, OSMNode, OSMWay
18
+
19
+
20
+ class MapIndex:
21
+ def __init__(
22
+ self,
23
+ data: MapData,
24
+ ):
25
+ self.index_nodes = rtree.index.Index()
26
+ for i, node in data.nodes.items():
27
+ self.index_nodes.insert(i, tuple(node.xy) * 2)
28
+
29
+ self.index_lines = rtree.index.Index()
30
+ for i, line in data.lines.items():
31
+ bbox = tuple(np.r_[line.xy.min(0), line.xy.max(0)])
32
+ self.index_lines.insert(i, bbox)
33
+
34
+ self.index_areas = rtree.index.Index()
35
+ for i, area in data.areas.items():
36
+ xy = np.concatenate(area.outers + area.inners)
37
+ bbox = tuple(np.r_[xy.min(0), xy.max(0)])
38
+ self.index_areas.insert(i, bbox)
39
+
40
+ self.data = data
41
+
42
+ def query(self, bbox: BoundaryBox) -> Tuple[List[OSMNode], List[OSMWay]]:
43
+ query = tuple(np.r_[bbox.min_, bbox.max_])
44
+ ret = []
45
+ for x in ["nodes", "lines", "areas"]:
46
+ ids = getattr(self, "index_" + x).intersection(query)
47
+ ret.append([getattr(self.data, x)[i] for i in ids])
48
+ return tuple(ret)
49
+
50
+
51
+ def bbox_to_slice(bbox: BoundaryBox, canvas: Canvas):
52
+ uv_min = np.ceil(canvas.to_uv(bbox.min_)).astype(int)
53
+ uv_max = np.ceil(canvas.to_uv(bbox.max_)).astype(int)
54
+ slice_ = (slice(uv_max[1], uv_min[1]), slice(uv_min[0], uv_max[0]))
55
+ return slice_
56
+
57
+
58
+ def round_bbox(bbox: BoundaryBox, origin: np.ndarray, ppm: int):
59
+ bbox = bbox.translate(-origin)
60
+ bbox = BoundaryBox(np.round(bbox.min_ * ppm) / ppm, np.round(bbox.max_ * ppm) / ppm)
61
+ return bbox.translate(origin)
62
+
63
+ class MapTileManager:
64
+ def __init__(
65
+ self,
66
+ osmpath: Path,
67
+ ):
68
+
69
+ self.osm = OSMData.from_file(osmpath)
70
+
71
+
72
+ # @classmethod
73
+ def from_bbox(
74
+ self,
75
+ projection: Projection,
76
+ bbox: BoundaryBox,
77
+ ppm: int,
78
+ tile_size: int = 128,
79
+ ):
80
+ # bbox_osm = projection.unproject(bbox)
81
+ # if path is not None and path.is_file():
82
+ # print(OSMData.from_file)
83
+ # osm = OSMData.from_file(path)
84
+ # if osm.box is not None:
85
+ # assert osm.box.contains(bbox_osm)
86
+ # else:
87
+ # osm = OSMData.from_dict(get_osm(bbox_osm, path))
88
+
89
+ self.osm.add_xy_to_nodes(projection)
90
+ map_data = MapData.from_osm(self.osm)
91
+ map_index = MapIndex(map_data)
92
+
93
+ bounds_x, bounds_y = [
94
+ np.r_[np.arange(min_, max_, tile_size), max_]
95
+ for min_, max_ in zip(bbox.min_, bbox.max_)
96
+ ]
97
+ bbox_tiles = {}
98
+ for i, xmin in enumerate(bounds_x[:-1]):
99
+ for j, ymin in enumerate(bounds_y[:-1]):
100
+ bbox_tiles[i, j] = BoundaryBox(
101
+ [xmin, ymin], [bounds_x[i + 1], bounds_y[j + 1]]
102
+ )
103
+
104
+ tiles = {}
105
+ for ij, bbox_tile in bbox_tiles.items():
106
+ canvas = Canvas(bbox_tile, ppm)
107
+ nodes, lines, areas = map_index.query(bbox_tile)
108
+ masks = render_raster_masks(nodes, lines, areas, canvas)
109
+ canvas.raster = render_raster_map(masks)
110
+ tiles[ij] = canvas
111
+
112
+ groups = {k: v for k, v in vars(Groups).items() if not k.startswith("__")}
113
+
114
+ self.origin = bbox.min_
115
+ self.bbox = bbox
116
+ self.tiles = tiles
117
+ self.tile_size = tile_size
118
+ self.ppm = ppm
119
+ self.projection = projection
120
+ self.groups = groups
121
+ self.map_data = map_data
122
+
123
+ return self.query(bbox)
124
+ # return cls(tiles, bbox, tile_size, ppm, projection, groups, map_data)
125
+
126
+ def query(self, bbox: BoundaryBox) -> Canvas:
127
+ bbox = round_bbox(bbox, self.bbox.min_, self.ppm)
128
+ canvas = Canvas(bbox, self.ppm)
129
+ raster = np.zeros((3, canvas.h, canvas.w), np.uint8)
130
+
131
+ bbox_all = bbox & self.bbox
132
+ ij_min = np.floor((bbox_all.min_ - self.origin) / self.tile_size).astype(int)
133
+ ij_max = np.ceil((bbox_all.max_ - self.origin) / self.tile_size).astype(int) - 1
134
+ for i in range(ij_min[0], ij_max[0] + 1):
135
+ for j in range(ij_min[1], ij_max[1] + 1):
136
+ tile = self.tiles[i, j]
137
+ bbox_select = tile.bbox & bbox
138
+ slice_query = bbox_to_slice(bbox_select, canvas)
139
+ slice_tile = bbox_to_slice(bbox_select, tile)
140
+ raster[(slice(None),) + slice_query] = tile.raster[
141
+ (slice(None),) + slice_tile
142
+ ]
143
+ canvas.raster = raster
144
+ return canvas
145
+
146
+ def save(self, path: Path):
147
+ dump = {
148
+ "bbox": self.bbox.format(),
149
+ "tile_size": self.tile_size,
150
+ "ppm": self.ppm,
151
+ "groups": self.groups,
152
+ "tiles_bbox": {},
153
+ "tiles_raster": {},
154
+ }
155
+ if self.projection is not None:
156
+ dump["ref_latlonalt"] = self.projection.latlonalt
157
+ for ij, canvas in self.tiles.items():
158
+ dump["tiles_bbox"][ij] = canvas.bbox.format()
159
+ raster_bytes = io.BytesIO()
160
+ raster = Image.fromarray(canvas.raster.transpose(1, 2, 0).astype(np.uint8))
161
+ raster.save(raster_bytes, format="PNG")
162
+ dump["tiles_raster"][ij] = raster_bytes
163
+ with open(path, "wb") as fp:
164
+ pickle.dump(dump, fp)
165
+
166
+ @classmethod
167
+ def load(cls, path: Path):
168
+ with path.open("rb") as fp:
169
+ dump = pickle.load(fp)
170
+ tiles = {}
171
+ for ij, bbox in dump["tiles_bbox"].items():
172
+ tiles[ij] = Canvas(BoundaryBox.from_string(bbox), dump["ppm"])
173
+ raster = np.asarray(Image.open(dump["tiles_raster"][ij]))
174
+ tiles[ij].raster = raster.transpose(2, 0, 1).copy()
175
+ projection = Projection(*dump["ref_latlonalt"])
176
+ return cls(
177
+ tiles,
178
+ BoundaryBox.from_string(dump["bbox"]),
179
+ dump["tile_size"],
180
+ dump["ppm"],
181
+ projection,
182
+ dump["groups"],
183
+ )
184
+
185
+ class TileManager:
186
+ def __init__(
187
+ self,
188
+ tiles: Dict,
189
+ bbox: BoundaryBox,
190
+ tile_size: int,
191
+ ppm: int,
192
+ projection: Projection,
193
+ groups: Dict[str, List[str]],
194
+ map_data: Optional[MapData] = None,
195
+ ):
196
+ self.origin = bbox.min_
197
+ self.bbox = bbox
198
+ self.tiles = tiles
199
+ self.tile_size = tile_size
200
+ self.ppm = ppm
201
+ self.projection = projection
202
+ self.groups = groups
203
+ self.map_data = map_data
204
+ assert np.all(tiles[0, 0].bbox.min_ == self.origin)
205
+ for tile in tiles.values():
206
+ assert bbox.contains(tile.bbox)
207
+
208
+ @classmethod
209
+ def from_bbox(
210
+ cls,
211
+ projection: Projection,
212
+ bbox: BoundaryBox,
213
+ ppm: int,
214
+ path: Optional[Path] = None,
215
+ tile_size: int = 128,
216
+ ):
217
+ bbox_osm = projection.unproject(bbox)
218
+ if path is not None and path.is_file():
219
+ print(OSMData.from_file)
220
+ osm = OSMData.from_file(path)
221
+ if osm.box is not None:
222
+ assert osm.box.contains(bbox_osm)
223
+ else:
224
+ osm = OSMData.from_dict(get_osm(bbox_osm, path))
225
+
226
+ osm.add_xy_to_nodes(projection)
227
+ map_data = MapData.from_osm(osm)
228
+ map_index = MapIndex(map_data)
229
+
230
+ bounds_x, bounds_y = [
231
+ np.r_[np.arange(min_, max_, tile_size), max_]
232
+ for min_, max_ in zip(bbox.min_, bbox.max_)
233
+ ]
234
+ bbox_tiles = {}
235
+ for i, xmin in enumerate(bounds_x[:-1]):
236
+ for j, ymin in enumerate(bounds_y[:-1]):
237
+ bbox_tiles[i, j] = BoundaryBox(
238
+ [xmin, ymin], [bounds_x[i + 1], bounds_y[j + 1]]
239
+ )
240
+
241
+ tiles = {}
242
+ for ij, bbox_tile in bbox_tiles.items():
243
+ canvas = Canvas(bbox_tile, ppm)
244
+ nodes, lines, areas = map_index.query(bbox_tile)
245
+ masks = render_raster_masks(nodes, lines, areas, canvas)
246
+ canvas.raster = render_raster_map(masks)
247
+ tiles[ij] = canvas
248
+
249
+ groups = {k: v for k, v in vars(Groups).items() if not k.startswith("__")}
250
+
251
+ return cls(tiles, bbox, tile_size, ppm, projection, groups, map_data)
252
+
253
+ def query(self, bbox: BoundaryBox) -> Canvas:
254
+ bbox = round_bbox(bbox, self.bbox.min_, self.ppm)
255
+ canvas = Canvas(bbox, self.ppm)
256
+ raster = np.zeros((3, canvas.h, canvas.w), np.uint8)
257
+
258
+ bbox_all = bbox & self.bbox
259
+ ij_min = np.floor((bbox_all.min_ - self.origin) / self.tile_size).astype(int)
260
+ ij_max = np.ceil((bbox_all.max_ - self.origin) / self.tile_size).astype(int) - 1
261
+ for i in range(ij_min[0], ij_max[0] + 1):
262
+ for j in range(ij_min[1], ij_max[1] + 1):
263
+ tile = self.tiles[i, j]
264
+ bbox_select = tile.bbox & bbox
265
+ slice_query = bbox_to_slice(bbox_select, canvas)
266
+ slice_tile = bbox_to_slice(bbox_select, tile)
267
+ raster[(slice(None),) + slice_query] = tile.raster[
268
+ (slice(None),) + slice_tile
269
+ ]
270
+ canvas.raster = raster
271
+ return canvas
272
+
273
+ def save(self, path: Path):
274
+ dump = {
275
+ "bbox": self.bbox.format(),
276
+ "tile_size": self.tile_size,
277
+ "ppm": self.ppm,
278
+ "groups": self.groups,
279
+ "tiles_bbox": {},
280
+ "tiles_raster": {},
281
+ }
282
+ if self.projection is not None:
283
+ dump["ref_latlonalt"] = self.projection.latlonalt
284
+ for ij, canvas in self.tiles.items():
285
+ dump["tiles_bbox"][ij] = canvas.bbox.format()
286
+ raster_bytes = io.BytesIO()
287
+ raster = Image.fromarray(canvas.raster.transpose(1, 2, 0).astype(np.uint8))
288
+ raster.save(raster_bytes, format="PNG")
289
+ dump["tiles_raster"][ij] = raster_bytes
290
+ with open(path, "wb") as fp:
291
+ pickle.dump(dump, fp)
292
+
293
+ @classmethod
294
+ def load(cls, path: Path):
295
+ with path.open("rb") as fp:
296
+ dump = pickle.load(fp)
297
+ tiles = {}
298
+ for ij, bbox in dump["tiles_bbox"].items():
299
+ tiles[ij] = Canvas(BoundaryBox.from_string(bbox), dump["ppm"])
300
+ raster = np.asarray(Image.open(dump["tiles_raster"][ij]))
301
+ tiles[ij].raster = raster.transpose(2, 0, 1).copy()
302
+ projection = Projection(*dump["ref_latlonalt"])
303
+ return cls(
304
+ tiles,
305
+ BoundaryBox.from_string(dump["bbox"]),
306
+ dump["tile_size"],
307
+ dump["ppm"],
308
+ projection,
309
+ dump["groups"],
310
+ )
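
A minimal sketch of building and querying tiles with TileManager, assuming a local OSM extract whose bounds cover the queried box; the coordinates, paths, and pixel density are placeholders.

    from pathlib import Path

    from osm.tiling import TileManager
    from utils.geo import BoundaryBox, Projection

    proj = Projection(48.8566, 2.3522)                             # example reference point
    bbox_geo = BoundaryBox([48.8550, 2.3500], [48.8580, 2.3550])   # lat/lon box
    bbox = proj.project(bbox_geo)                                  # metric box around the reference

    tiles = TileManager.from_bbox(proj, bbox, ppm=1, path=Path("map.osm"))
    canvas = tiles.query(bbox)            # canvas.raster: (3, H, W) area/way/node indices
    tiles.save(Path("tiles.pkl"))         # rasters are stored as PNG bytes inside a pickle
    tiles = TileManager.load(Path("tiles.pkl"))
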
osm/viz.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ import matplotlib as mpl
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import plotly.graph_objects as go
7
+ import PIL.Image
8
+
9
+ from utils.viz_2d import add_text
10
+ from .parser import Groups
11
+
12
+
13
+ class GeoPlotter:
14
+ def __init__(self, zoom=12, **kwargs):
15
+ self.fig = go.Figure()
16
+ self.fig.update_layout(
17
+ mapbox_style="open-street-map",
18
+ autosize=True,
19
+ mapbox_zoom=zoom,
20
+ margin={"r": 0, "t": 0, "l": 0, "b": 0},
21
+ showlegend=True,
22
+ **kwargs,
23
+ )
24
+
25
+ def points(self, latlons, color, text=None, name=None, size=5, **kwargs):
26
+ latlons = np.asarray(latlons)
27
+ self.fig.add_trace(
28
+ go.Scattermapbox(
29
+ lat=latlons[..., 0],
30
+ lon=latlons[..., 1],
31
+ mode="markers",
32
+ text=text,
33
+ marker_color=color,
34
+ marker_size=size,
35
+ name=name,
36
+ **kwargs,
37
+ )
38
+ )
39
+ center = latlons.reshape(-1, 2).mean(0)
40
+ self.fig.update_layout(
41
+ mapbox_center=dict(zip(("lat", "lon"), center)),
42
+ )
43
+
44
+ def bbox(self, bbox, color, name=None, **kwargs):
45
+ corners = np.stack(
46
+ [bbox.min_, bbox.left_top, bbox.max_, bbox.right_bottom, bbox.min_]
47
+ )
48
+ self.fig.add_trace(
49
+ go.Scattermapbox(
50
+ lat=corners[:, 0],
51
+ lon=corners[:, 1],
52
+ mode="lines",
53
+ marker_color=color,
54
+ name=name,
55
+ **kwargs,
56
+ )
57
+ )
58
+ self.fig.update_layout(
59
+ mapbox_center=dict(zip(("lat", "lon"), bbox.center)),
60
+ )
61
+
62
+ def raster(self, raster, bbox, below="traces", **kwargs):
63
+ if not np.issubdtype(raster.dtype, np.integer):
64
+ raster = (raster * 255).astype(np.uint8)
65
+ raster = PIL.Image.fromarray(raster)
66
+ corners = np.stack(
67
+ [
68
+ bbox.min_,
69
+ bbox.left_top,
70
+ bbox.max_,
71
+ bbox.right_bottom,
72
+ ]
73
+ )[::-1, ::-1]
74
+ layers = [*self.fig.layout.mapbox.layers]
75
+ layers.append(
76
+ dict(
77
+ sourcetype="image",
78
+ source=raster,
79
+ coordinates=corners,
80
+ below=below,
81
+ **kwargs,
82
+ )
83
+ )
84
+ self.fig.layout.mapbox.layers = layers
85
+
86
+
87
+ map_colors = {
88
+ "building": (84, 155, 255),
89
+ "parking": (255, 229, 145),
90
+ "playground": (150, 133, 125),
91
+ "grass": (188, 255, 143),
92
+ "park": (0, 158, 16),
93
+ "forest": (0, 92, 9),
94
+ "water": (184, 213, 255),
95
+ "fence": (238, 0, 255),
96
+ "wall": (0, 0, 0),
97
+ "hedge": (107, 68, 48),
98
+ "kerb": (255, 234, 0),
99
+ "building_outline": (0, 0, 255),
100
+ "cycleway": (0, 251, 255),
101
+ "path": (8, 237, 0),
102
+ "road": (255, 0, 0),
103
+ "tree_row": (0, 92, 9),
104
+ "busway": (255, 128, 0),
105
+ "void": [int(255 * 0.9)] * 3,
106
+ }
107
+
108
+
109
+ class Colormap:
110
+ colors_areas = np.stack([map_colors[k] for k in ["void"] + Groups.areas])
111
+ colors_ways = np.stack([map_colors[k] for k in ["void"] + Groups.ways])
112
+
113
+ @classmethod
114
+ def apply(cls, rasters):
115
+ return (
116
+ np.where(
117
+ rasters[1, ..., None] > 0,
118
+ cls.colors_ways[rasters[1]],
119
+ cls.colors_areas[rasters[0]],
120
+ )
121
+ / 255.0
122
+ )
123
+
124
+ @classmethod
125
+ def add_colorbar(cls):
126
+ ax2 = plt.gcf().add_axes([1, 0.1, 0.02, 0.8])
127
+ color_list = np.r_[cls.colors_areas[1:], cls.colors_ways[1:]] / 255.0
128
+ cmap = mpl.colors.ListedColormap(color_list[::-1])
129
+ ticks = np.linspace(0, 1, len(color_list), endpoint=False)
130
+ ticks += 1 / len(color_list) / 2
131
+ cb = mpl.colorbar.ColorbarBase(
132
+ ax2,
133
+ cmap=cmap,
134
+ orientation="vertical",
135
+ ticks=ticks,
136
+ )
137
+ cb.set_ticklabels((Groups.areas + Groups.ways)[::-1])
138
+ ax2.tick_params(labelsize=15)
139
+
140
+
141
+ def plot_nodes(idx, raster, fontsize=8, size=15):
142
+ ax = plt.gcf().axes[idx]
143
+ ax.autoscale(enable=False)
144
+ nodes_xy = np.stack(np.where(raster > 0)[::-1], -1)
145
+ nodes_val = raster[tuple(nodes_xy.T[::-1])] - 1
146
+ ax.scatter(*nodes_xy.T, c="k", s=size)
147
+ for xy, val in zip(nodes_xy, nodes_val):
148
+ group = Groups.nodes[val]
149
+ add_text(
150
+ idx,
151
+ group,
152
+ xy + 2,
153
+ lcolor=None,
154
+ fs=fontsize,
155
+ color="k",
156
+ normalized=False,
157
+ ha="center",
158
+ )
159
+ plt.show()
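
A small sketch of turning a queried raster into an RGB preview with the helpers above; `canvas` is assumed to come from TileManager.query (osm/tiling.py).

    import matplotlib.pyplot as plt

    from osm.viz import Colormap, plot_nodes

    rgb = Colormap.apply(canvas.raster)   # (H, W, 3) float image in [0, 1]
    plt.imshow(rgb)
    Colormap.add_colorbar()
    plot_nodes(0, canvas.raster[2])       # annotate node classes on the first axes
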
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ torch
2
+ torchvision
3
+ numpy
4
+ opencv-python
5
+ Pillow
6
+ tqdm>=4.36.0
7
+ matplotlib
8
+ plotly
9
+ scipy
10
+ omegaconf
11
+ pytorch-lightning
12
+ torchmetrics
13
+ jupyter
14
+ lxml
15
+ rtree
16
+ scikit-learn
17
+ geopy
18
+ exifread
19
+ gradio_client
20
+ urllib3>=2
train.py ADDED
@@ -0,0 +1,217 @@
1
+ import os.path as osp
2
+ import warnings
3
+ warnings.filterwarnings('ignore')
4
+ from typing import Optional
5
+ from pathlib import Path
6
+ from models.maplocnet import MapLocNet
7
+ import hydra
8
+ import pytorch_lightning as pl
9
+ import torch
10
+ from omegaconf import DictConfig, OmegaConf
11
+ from pytorch_lightning.utilities import rank_zero_only
12
+ from module import GenericModule
13
+ from logger import logger, pl_logger, EXPERIMENTS_PATH
15
+ from dataset import UavMapDatasetModule
16
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
17
+ # print(osp.join(osp.dirname(__file__), "conf"))
18
+
19
+
20
+ class CleanProgressBar(pl.callbacks.TQDMProgressBar):
21
+ def get_metrics(self, trainer, model):
22
+ items = super().get_metrics(trainer, model)
23
+ items.pop("v_num", None) # don't show the version number
24
+ items.pop("loss", None)
25
+ return items
26
+
27
+
28
+ class SeedingCallback(pl.callbacks.Callback):
29
+ def on_epoch_start_(self, trainer, module):
30
+ seed = module.cfg.experiment.seed
31
+ is_overfit = module.cfg.training.trainer.get("overfit_batches", 0) > 0
32
+ if trainer.training and not is_overfit:
33
+ seed = seed + trainer.current_epoch
34
+
35
+ # Temporarily disable the logging (does not seem to work?)
36
+ pl_logger.disabled = True
37
+ try:
38
+ pl.seed_everything(seed, workers=True)
39
+ finally:
40
+ pl_logger.disabled = False
41
+
42
+ def on_train_epoch_start(self, *args, **kwargs):
43
+ self.on_epoch_start_(*args, **kwargs)
44
+
45
+ def on_validation_epoch_start(self, *args, **kwargs):
46
+ self.on_epoch_start_(*args, **kwargs)
47
+
48
+ def on_test_epoch_start(self, *args, **kwargs):
49
+ self.on_epoch_start_(*args, **kwargs)
50
+
51
+
52
+ class ConsoleLogger(pl.callbacks.Callback):
53
+ @rank_zero_only
54
+ def on_train_epoch_start(self, trainer, module):
55
+ logger.info(
56
+ "New training epoch %d for experiment '%s'.",
57
+ module.current_epoch,
58
+ module.cfg.experiment.name,
59
+ )
60
+
61
+ # @rank_zero_only
62
+ # def on_validation_epoch_end(self, trainer, module):
63
+ # results = {
64
+ # **dict(module.metrics_val.items()),
65
+ # **dict(module.losses_val.items()),
66
+ # }
67
+ # results = [f"{k} {v.compute():.3E}" for k, v in results.items()]
68
+ # logger.info(f'[Validation] {{{", ".join(results)}}}')
69
+
70
+
71
+ def find_last_checkpoint_path(experiment_dir):
72
+ cls = pl.callbacks.ModelCheckpoint
73
+ path = osp.join(experiment_dir, cls.CHECKPOINT_NAME_LAST + cls.FILE_EXTENSION)
74
+ if osp.exists(path):
75
+ return path
76
+ else:
77
+ return None
78
+
79
+
80
+ def prepare_experiment_dir(experiment_dir, cfg, rank):
81
+ config_path = osp.join(experiment_dir, "config.yaml")
82
+ last_checkpoint_path = find_last_checkpoint_path(experiment_dir)
83
+ if last_checkpoint_path is not None:
84
+ if rank == 0:
85
+ logger.info(
86
+ "Resuming the training from checkpoint %s", last_checkpoint_path
87
+ )
88
+ if osp.exists(config_path):
89
+ with open(config_path, "r") as fp:
90
+ cfg_prev = OmegaConf.create(fp.read())
91
+ compare_keys = ["experiment", "data", "model", "training"]
92
+ if OmegaConf.masked_copy(cfg, compare_keys) != OmegaConf.masked_copy(
93
+ cfg_prev, compare_keys
94
+ ):
95
+ raise ValueError(
96
+ "Attempting to resume training with a different config: "
97
+ f"{OmegaConf.masked_copy(cfg, compare_keys)} vs "
98
+ f"{OmegaConf.masked_copy(cfg_prev, compare_keys)}"
99
+ )
100
+ if rank == 0:
101
+ Path(experiment_dir).mkdir(exist_ok=True, parents=True)
102
+ with open(config_path, "w") as fp:
103
+ OmegaConf.save(cfg, fp)
104
+ return last_checkpoint_path
105
+
106
+
107
+ def train(cfg: DictConfig) -> None:
108
+ torch.set_float32_matmul_precision("medium")
109
+ OmegaConf.resolve(cfg)
110
+ rank = rank_zero_only.rank
111
+
112
+ if rank == 0:
113
+ logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg))
114
+ if cfg.experiment.gpus in (None, 0):
115
+ logger.warning("Will train on CPU...")
116
+ cfg.experiment.gpus = 0
117
+ elif not torch.cuda.is_available():
118
+ raise ValueError("Requested GPU but no NVIDIA drivers found.")
119
+ pl.seed_everything(cfg.experiment.seed, workers=True)
120
+
121
+ init_checkpoint_path = cfg.training.get("finetune_from_checkpoint")
122
+ if init_checkpoint_path is not None:
123
+ logger.info("Initializing the model from checkpoint %s.", init_checkpoint_path)
124
+ model = GenericModule.load_from_checkpoint(
125
+ init_checkpoint_path, strict=True, find_best=False, cfg=cfg
126
+ )
127
+ else:
128
+ model = GenericModule(cfg)
129
+ if rank == 0:
130
+ logger.info("Network:\n%s", model.model)
131
+
132
+ experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name)
133
+ last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank)
134
+ checkpointing_epoch = pl.callbacks.ModelCheckpoint(
135
+ dirpath=experiment_dir,
136
+ filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}",
137
+ auto_insert_metric_name=False,
138
+ save_last=True,
139
+ every_n_epochs=1,
140
+ save_on_train_epoch_end=True,
141
+ verbose=True,
142
+ **cfg.training.checkpointing,
143
+ )
144
+ checkpointing_step = pl.callbacks.ModelCheckpoint(
145
+ dirpath=experiment_dir,
146
+ filename="checkpoint-step-{step}-{loss/total/val:02f}",
147
+ auto_insert_metric_name=False,
148
+ save_last=True,
149
+ every_n_train_steps=1000,
150
+ verbose=True,
151
+ **cfg.training.checkpointing,
152
+ )
153
+ checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing"
154
+
155
+ # Create the EarlyStopping callback
156
+ early_stopping_callback = EarlyStopping(monitor=cfg.training.checkpointing.monitor, patience=5)
157
+
158
+ strategy = None
159
+ if cfg.experiment.gpus > 1:
160
+ strategy = pl.strategies.DDPStrategy(find_unused_parameters=False)
161
+ for split in ["train", "val"]:
162
+ cfg.data[split].batch_size = (
163
+ cfg.data[split].batch_size // cfg.experiment.gpus
164
+ )
165
+ cfg.data[split].num_workers = int(
166
+ (cfg.data[split].num_workers + cfg.experiment.gpus - 1)
167
+ / cfg.experiment.gpus
168
+ )
169
+
170
+ # data = data_modules[cfg.data.get("name", "mapillary")](cfg.data)
171
+
172
+ datamodule = UavMapDatasetModule(cfg.data)
173
+
174
+ tb_args = {"name": cfg.experiment.name, "version": ""}
175
+ tb = pl.loggers.TensorBoardLogger(EXPERIMENTS_PATH, **tb_args)
176
+
177
+ callbacks = [
178
+ checkpointing_epoch,
179
+ checkpointing_step,
180
+ # early_stopping_callback,
181
+ pl.callbacks.LearningRateMonitor(),
182
+ SeedingCallback(),
183
+ CleanProgressBar(),
184
+ ConsoleLogger(),
185
+ ]
186
+ if cfg.experiment.gpus > 0:
187
+ callbacks.append(pl.callbacks.DeviceStatsMonitor())
188
+
189
+ trainer = pl.Trainer(
190
+ default_root_dir=experiment_dir,
191
+ detect_anomaly=False,
192
+ # strategy=ddp_find_unused_parameters_true,
193
+ enable_model_summary=True,
194
+ sync_batchnorm=True,
195
+ enable_checkpointing=True,
196
+ logger=tb,
197
+ callbacks=callbacks,
198
+ strategy=strategy,
199
+ check_val_every_n_epoch=1,
200
+ accelerator="gpu",
201
+ num_nodes=1,
202
+ **cfg.training.trainer,
203
+ )
204
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path)
205
+
206
+
207
+ @hydra.main(
208
+ config_path=osp.join(osp.dirname(__file__), "conf"), config_name="maplocnet.yaml"
209
+ )
210
+ def main(cfg: DictConfig) -> None:
211
+ OmegaConf.save(config=cfg, f='maplocnet.yaml')
212
+ train(cfg)
213
+
214
+
215
+ if __name__ == "__main__":
216
+ main()
217
+
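
Besides the CLI entry point above, the training function can be driven programmatically, e.g. for quick debugging. This is a sketch using Hydra's compose API; it assumes the working directory is the repository root and that the override keys exist in conf/maplocnet.yaml.

    from hydra import compose, initialize

    from train import train

    with initialize(config_path="conf"):
        cfg = compose(
            config_name="maplocnet",
            overrides=["experiment.gpus=1", "data.train.batch_size=4"],
        )
    train(cfg)
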
train.sh ADDED
@@ -0,0 +1 @@
1
+ nohup python train.py > logs/train0907.log 2>&1 &
utils/exif.py ADDED
@@ -0,0 +1,356 @@
1
+ """Copied from opensfm.exif to minimize hard dependencies."""
2
+ from pathlib import Path
3
+ import json
4
+ import datetime
5
+ import logging
6
+ from codecs import encode, decode
7
+ from typing import Any, Dict, Optional, Tuple
8
+
9
+ import exifread
10
+
11
+ logger: logging.Logger = logging.getLogger(__name__)
12
+
13
+ inch_in_mm = 25.4
14
+ cm_in_mm = 10
15
+ um_in_mm = 0.001
16
+ default_projection = "perspective"
17
+ maximum_altitude = 1e4
18
+
19
+
20
+ def sensor_data():
21
+ with (Path(__file__).parent / "sensor_data.json").open() as fid:
22
+ data = json.load(fid)
23
+ return {k.lower(): v for k, v in data.items()}
24
+
25
+
26
+ def eval_frac(value) -> Optional[float]:
27
+ try:
28
+ return float(value.num) / float(value.den)
29
+ except ZeroDivisionError:
30
+ return None
31
+
32
+
33
+ def gps_to_decimal(values, reference) -> Optional[float]:
34
+ sign = 1 if reference in "NE" else -1
35
+ degrees = eval_frac(values[0])
36
+ minutes = eval_frac(values[1])
37
+ seconds = eval_frac(values[2])
38
+ if degrees is not None and minutes is not None and seconds is not None:
39
+ return sign * (degrees + minutes / 60 + seconds / 3600)
40
+ return None
41
+
42
+
43
+ def get_tag_as_float(tags, key, index: int = 0) -> Optional[float]:
44
+ if key in tags:
45
+ val = tags[key].values[index]
46
+ if isinstance(val, exifread.utils.Ratio):
47
+ ret_val = eval_frac(val)
48
+ if ret_val is None:
49
+ logger.error(
50
+ 'The rational "{2}" of tag "{0:s}" at index {1:d} c'
51
+ "aused a division by zero error".format(key, index, val)
52
+ )
53
+ return ret_val
54
+ else:
55
+ return float(val)
56
+ else:
57
+ return None
58
+
59
+
60
+ def compute_focal(
61
+ focal_35: Optional[float], focal: Optional[float], sensor_width, sensor_string
62
+ ) -> Tuple[float, float]:
63
+ if focal_35 is not None and focal_35 > 0:
64
+ focal_ratio = focal_35 / 36.0 # 35mm film produces 36x24mm pictures.
65
+ else:
66
+ if not sensor_width:
67
+ sensor_width = sensor_data().get(sensor_string, None)
68
+ if sensor_width and focal:
69
+ focal_ratio = focal / sensor_width
70
+ focal_35 = 36.0 * focal_ratio
71
+ else:
72
+ focal_35 = 0.0
73
+ focal_ratio = 0.0
74
+ return focal_35, focal_ratio
75
+
76
+
77
+ def sensor_string(make: str, model: str) -> str:
78
+ if make != "unknown":
79
+ # remove duplicate 'make' information in 'model'
80
+ model = model.replace(make, "")
81
+ return (make.strip() + " " + model.strip()).strip().lower()
82
+
83
+
84
+ def unescape_string(s) -> str:
85
+ return decode(encode(s, "latin-1", "backslashreplace"), "unicode-escape")
86
+
87
+
88
+ class EXIF:
89
+ def __init__(
90
+ self, fileobj, image_size_loader, use_exif_size=True, name=None
91
+ ) -> None:
92
+ self.image_size_loader = image_size_loader
93
+ self.use_exif_size = use_exif_size
94
+ self.fileobj = fileobj
95
+ self.tags = exifread.process_file(fileobj, details=False)
96
+ fileobj.seek(0)
97
+ self.fileobj_name = self.fileobj.name if name is None else name
98
+
99
+ def extract_image_size(self) -> Tuple[int, int]:
100
+ if (
101
+ self.use_exif_size
102
+ and "EXIF ExifImageWidth" in self.tags
103
+ and "EXIF ExifImageLength" in self.tags
104
+ ):
105
+ width, height = (
106
+ int(self.tags["EXIF ExifImageWidth"].values[0]),
107
+ int(self.tags["EXIF ExifImageLength"].values[0]),
108
+ )
109
+ elif (
110
+ self.use_exif_size
111
+ and "Image ImageWidth" in self.tags
112
+ and "Image ImageLength" in self.tags
113
+ ):
114
+ width, height = (
115
+ int(self.tags["Image ImageWidth"].values[0]),
116
+ int(self.tags["Image ImageLength"].values[0]),
117
+ )
118
+ else:
119
+ height, width = self.image_size_loader()
120
+ return width, height
121
+
122
+ def _decode_make_model(self, value) -> str:
123
+ """Python 2/3 compatible decoding of make/model field."""
124
+ if hasattr(value, "decode"):
125
+ try:
126
+ return value.decode("utf-8")
127
+ except UnicodeDecodeError:
128
+ return "unknown"
129
+ else:
130
+ return value
131
+
132
+ def extract_make(self) -> str:
133
+ # Camera make and model
134
+ if "EXIF LensMake" in self.tags:
135
+ make = self.tags["EXIF LensMake"].values
136
+ elif "Image Make" in self.tags:
137
+ make = self.tags["Image Make"].values
138
+ else:
139
+ make = "unknown"
140
+ return self._decode_make_model(make)
141
+
142
+ def extract_model(self) -> str:
143
+ if "EXIF LensModel" in self.tags:
144
+ model = self.tags["EXIF LensModel"].values
145
+ elif "Image Model" in self.tags:
146
+ model = self.tags["Image Model"].values
147
+ else:
148
+ model = "unknown"
149
+ return self._decode_make_model(model)
150
+
151
+ def extract_focal(self) -> Tuple[float, float]:
152
+ make, model = self.extract_make(), self.extract_model()
153
+ focal_35, focal_ratio = compute_focal(
154
+ get_tag_as_float(self.tags, "EXIF FocalLengthIn35mmFilm"),
155
+ get_tag_as_float(self.tags, "EXIF FocalLength"),
156
+ self.extract_sensor_width(),
157
+ sensor_string(make, model),
158
+ )
159
+ return focal_35, focal_ratio
160
+
161
+ def extract_sensor_width(self) -> Optional[float]:
162
+ """Compute sensor with from width and resolution."""
163
+ if (
164
+ "EXIF FocalPlaneResolutionUnit" not in self.tags
165
+ or "EXIF FocalPlaneXResolution" not in self.tags
166
+ ):
167
+ return None
168
+ resolution_unit = self.tags["EXIF FocalPlaneResolutionUnit"].values[0]
169
+ mm_per_unit = self.get_mm_per_unit(resolution_unit)
170
+ if not mm_per_unit:
171
+ return None
172
+ pixels_per_unit = get_tag_as_float(self.tags, "EXIF FocalPlaneXResolution")
173
+ if pixels_per_unit is None:
174
+ return None
175
+ if pixels_per_unit <= 0.0:
176
+ pixels_per_unit = get_tag_as_float(self.tags, "EXIF FocalPlaneYResolution")
177
+ if pixels_per_unit is None or pixels_per_unit <= 0.0:
178
+ return None
179
+ units_per_pixel = 1 / pixels_per_unit
180
+ width_in_pixels = self.extract_image_size()[0]
181
+ return width_in_pixels * units_per_pixel * mm_per_unit
182
+
183
+ def get_mm_per_unit(self, resolution_unit) -> Optional[float]:
184
+ """Length of a resolution unit in millimeters.
185
+
186
+ Uses the values from the EXIF specs in
187
+ https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/EXIF.html
188
+
189
+ Args:
190
+ resolution_unit: the resolution unit value given in the EXIF
191
+ """
192
+ if resolution_unit == 2: # inch
193
+ return inch_in_mm
194
+ elif resolution_unit == 3: # cm
195
+ return cm_in_mm
196
+ elif resolution_unit == 4: # mm
197
+ return 1
198
+ elif resolution_unit == 5: # um
199
+ return um_in_mm
200
+ else:
201
+ logger.warning(
202
+ "Unknown EXIF resolution unit value: {}".format(resolution_unit)
203
+ )
204
+ return None
205
+
206
+ def extract_orientation(self) -> int:
207
+ orientation = 1
208
+ if "Image Orientation" in self.tags:
209
+ value = self.tags.get("Image Orientation").values[0]
210
+ if type(value) == int and value != 0:
211
+ orientation = value
212
+ return orientation
213
+
214
+ def extract_ref_lon_lat(self) -> Tuple[str, str]:
215
+ if "GPS GPSLatitudeRef" in self.tags:
216
+ reflat = self.tags["GPS GPSLatitudeRef"].values
217
+ else:
218
+ reflat = "N"
219
+ if "GPS GPSLongitudeRef" in self.tags:
220
+ reflon = self.tags["GPS GPSLongitudeRef"].values
221
+ else:
222
+ reflon = "E"
223
+ return reflon, reflat
224
+
225
+ def extract_lon_lat(self) -> Tuple[Optional[float], Optional[float]]:
226
+ if "GPS GPSLatitude" in self.tags:
227
+ reflon, reflat = self.extract_ref_lon_lat()
228
+ lat = gps_to_decimal(self.tags["GPS GPSLatitude"].values, reflat)
229
+ lon = gps_to_decimal(self.tags["GPS GPSLongitude"].values, reflon)
230
+ else:
231
+ lon, lat = None, None
232
+ return lon, lat
233
+
234
+ def extract_altitude(self) -> Optional[float]:
235
+ if "GPS GPSAltitude" in self.tags:
236
+ alt_value = self.tags["GPS GPSAltitude"].values[0]
237
+ if isinstance(alt_value, exifread.utils.Ratio):
238
+ altitude = eval_frac(alt_value)
239
+ elif isinstance(alt_value, int):
240
+ altitude = float(alt_value)
241
+ else:
242
+ altitude = None
243
+
244
+ # Check if GPSAltitudeRef is equal to 1, which means GPSAltitude should be negative, reference: http://www.exif.org/Exif2-2.PDF#page=53
245
+ if (
246
+ "GPS GPSAltitudeRef" in self.tags
247
+ and self.tags["GPS GPSAltitudeRef"].values[0] == 1
248
+ and altitude is not None
249
+ ):
250
+ altitude = -altitude
251
+ else:
252
+ altitude = None
253
+ return altitude
254
+
255
+ def extract_dop(self) -> Optional[float]:
256
+ if "GPS GPSDOP" in self.tags:
257
+ return eval_frac(self.tags["GPS GPSDOP"].values[0])
258
+ return None
259
+
260
+ def extract_geo(self) -> Dict[str, Any]:
261
+ altitude = self.extract_altitude()
262
+ dop = self.extract_dop()
263
+ lon, lat = self.extract_lon_lat()
264
+ d = {}
265
+
266
+ if lon is not None and lat is not None:
267
+ d["latitude"] = lat
268
+ d["longitude"] = lon
269
+ if altitude is not None:
270
+ d["altitude"] = min([maximum_altitude, altitude])
271
+ if dop is not None:
272
+ d["dop"] = dop
273
+ return d
274
+
275
+ def extract_capture_time(self) -> float:
276
+ if (
277
+ "GPS GPSDate" in self.tags
278
+ and "GPS GPSTimeStamp" in self.tags # Actually GPSDateStamp
279
+ ):
280
+ try:
281
+ hours_f = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 0)
282
+ minutes_f = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 1)
283
+ if hours_f is None or minutes_f is None:
284
+ raise TypeError
285
+ hours = int(hours_f)
286
+ minutes = int(minutes_f)
287
+ seconds = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 2)
288
+ gps_timestamp_string = "{0:s} {1:02d}:{2:02d}:{3:02f}".format(
289
+ self.tags["GPS GPSDate"].values, hours, minutes, seconds
290
+ )
291
+ return (
292
+ datetime.datetime.strptime(
293
+ gps_timestamp_string, "%Y:%m:%d %H:%M:%S.%f"
294
+ )
295
+ - datetime.datetime(1970, 1, 1)
296
+ ).total_seconds()
297
+ except (TypeError, ValueError):
298
+ logger.info(
299
+ 'The GPS time stamp in image file "{0:s}" is invalid. '
300
+ "Falling back to DateTime*".format(self.fileobj_name)
301
+ )
302
+
303
+ time_strings = [
304
+ ("EXIF DateTimeOriginal", "EXIF SubSecTimeOriginal", "EXIF Tag 0x9011"),
305
+ ("EXIF DateTimeDigitized", "EXIF SubSecTimeDigitized", "EXIF Tag 0x9012"),
306
+ ("Image DateTime", "Image SubSecTime", "Image Tag 0x9010"),
307
+ ]
308
+ for datetime_tag, subsec_tag, offset_tag in time_strings:
309
+ if datetime_tag in self.tags:
310
+ date_time = self.tags[datetime_tag].values
311
+ if subsec_tag in self.tags:
312
+ subsec_time = self.tags[subsec_tag].values
313
+ else:
314
+ subsec_time = "0"
315
+ try:
316
+ s = "{0:s}.{1:s}".format(date_time, subsec_time)
317
+ d = datetime.datetime.strptime(s, "%Y:%m:%d %H:%M:%S.%f")
318
+ except ValueError:
319
+ logger.debug(
320
+ 'The "{1:s}" time stamp or "{2:s}" tag is invalid in '
321
+ 'image file "{0:s}"'.format(
322
+ self.fileobj_name, datetime_tag, subsec_tag
323
+ )
324
+ )
325
+ continue
326
+ # Test for OffsetTimeOriginal | OffsetTimeDigitized | OffsetTime
327
+ if offset_tag in self.tags:
328
+ offset_time = self.tags[offset_tag].values
329
+ try:
330
+ d += datetime.timedelta(
331
+ hours=-int(offset_time[0:3]), minutes=int(offset_time[4:6])
332
+ )
333
+ except (TypeError, ValueError):
334
+ logger.debug(
335
+ 'The "{0:s}" time zone offset in image file "{1:s}"'
336
+ " is invalid".format(offset_tag, self.fileobj_name)
337
+ )
338
+ logger.debug(
339
+ 'Naively assuming UTC on "{0:s}" in image file '
340
+ '"{1:s}"'.format(datetime_tag, self.fileobj_name)
341
+ )
342
+ else:
343
+ logger.debug(
344
+ "No GPS time stamp and no time zone offset in image "
345
+ 'file "{0:s}"'.format(self.fileobj_name)
346
+ )
347
+ logger.debug(
348
+ 'Naively assuming UTC on "{0:s}" in image file "{1:s}"'.format(
349
+ datetime_tag, self.fileobj_name
350
+ )
351
+ )
352
+ return (d - datetime.datetime(1970, 1, 1)).total_seconds()
353
+ logger.info(
354
+ 'Image file "{0:s}" has no valid time stamp'.format(self.fileobj_name)
355
+ )
356
+ return 0.0
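
A brief sketch of reading geotags from a UAV image with the EXIF helper above; the image path is a placeholder, and the size loader must return (height, width).

    from PIL import Image

    from utils.exif import EXIF

    def load_size(path="frame.jpg"):
        with Image.open(path) as im:
            return im.height, im.width          # (height, width), used when EXIF lacks the size

    with open("frame.jpg", "rb") as fid:
        exif = EXIF(fid, load_size)
        geo = exif.extract_geo()                # latitude/longitude/altitude/dop when present
        focal_35, focal_ratio = exif.extract_focal()
        capture_time = exif.extract_capture_time()   # POSIX seconds, 0.0 if no valid stamp
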
utils/geo.py ADDED
@@ -0,0 +1,130 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from .geo_opensfm import TopocentricConverter
9
+
10
+
11
+ class BoundaryBox:
12
+ def __init__(self, min_: np.ndarray, max_: np.ndarray):
13
+ self.min_ = np.asarray(min_)
14
+ self.max_ = np.asarray(max_)
15
+ assert np.all(self.min_ <= self.max_)
16
+
17
+ @classmethod
18
+ def from_string(cls, string: str):
19
+ return cls(*np.split(np.array(string.split(","), float), 2))
20
+
21
+ @property
22
+ def left_top(self):
23
+ return np.stack([self.min_[..., 0], self.max_[..., 1]], -1)
24
+
25
+ @property
26
+ def right_bottom(self) -> (np.ndarray, np.ndarray):
27
+ return np.stack([self.max_[..., 0], self.min_[..., 1]], -1)
28
+
29
+ @property
30
+ def center(self) -> np.ndarray:
31
+ return (self.min_ + self.max_) / 2
32
+
33
+ @property
34
+ def size(self) -> np.ndarray:
35
+ return self.max_ - self.min_
36
+
37
+ def translate(self, t: float):
38
+ return self.__class__(self.min_ + t, self.max_ + t)
39
+
40
+ def contains(self, xy: Union[np.ndarray, "BoundaryBox"]):
41
+ if isinstance(xy, self.__class__):
42
+ return self.contains(xy.min_) and self.contains(xy.max_)
43
+ return np.all((xy >= self.min_) & (xy <= self.max_), -1)
44
+
45
+ def normalize(self, xy):
46
+ min_, max_ = self.min_, self.max_
47
+ if isinstance(xy, torch.Tensor):
48
+ min_ = torch.from_numpy(min_).to(xy)
49
+ max_ = torch.from_numpy(max_).to(xy)
50
+ return (xy - min_) / (max_ - min_)
51
+
52
+ def unnormalize(self, xy):
53
+ min_, max_ = self.min_, self.max_
54
+ if isinstance(xy, torch.Tensor):
55
+ min_ = torch.from_numpy(min_).to(xy)
56
+ max_ = torch.from_numpy(max_).to(xy)
57
+ return xy * (max_ - min_) + min_
58
+
59
+ def format(self) -> str:
60
+ return ",".join(np.r_[self.min_, self.max_].astype(str))
61
+
62
+ def __add__(self, x):
63
+ if isinstance(x, (int, float)):
64
+ return self.__class__(self.min_ - x, self.max_ + x)
65
+ else:
66
+ raise TypeError(f"Cannot add {self.__class__.__name__} to {type(x)}.")
67
+
68
+ def __and__(self, other):
69
+ return self.__class__(
70
+ np.maximum(self.min_, other.min_), np.minimum(self.max_, other.max_)
71
+ )
72
+
73
+ def __repr__(self):
74
+ return self.format()
75
+
76
+
77
+ class Projection:
78
+ def __init__(self, lat, lon, alt=0, max_extent=25e3):
79
+ # The approximation error is |L - radius * tan(L / radius)|
80
+ # and is around 13cm for L=25km.
81
+ self.latlonalt = (lat, lon, alt)
82
+ self.converter = TopocentricConverter(lat, lon, alt)
83
+ min_ = self.converter.to_lla(*(-max_extent,) * 2, 0)[:2]
84
+ max_ = self.converter.to_lla(*(max_extent,) * 2, 0)[:2]
85
+ self.bounds = BoundaryBox(min_, max_)
86
+
87
+ @classmethod
88
+ def from_points(cls, all_latlon):
89
+ assert all_latlon.shape[-1] == 2
90
+ all_latlon = all_latlon.reshape(-1, 2)
91
+ latlon_mid = (all_latlon.min(0) + all_latlon.max(0)) / 2
92
+ return cls(*latlon_mid)
93
+
94
+ def check_bbox(self, bbox: BoundaryBox):
95
+ if self.bounds is not None and not self.bounds.contains(bbox):
96
+ raise ValueError(
97
+ f"Bbox {bbox.format()} is not contained in "
98
+ f"projection with bounds {self.bounds.format()}."
99
+ )
100
+
101
+ def project(self, geo, return_z=False):
102
+ if isinstance(geo, BoundaryBox):
103
+ return BoundaryBox(*self.project(np.stack([geo.min_, geo.max_])))
104
+ geo = np.asarray(geo)
105
+ assert geo.shape[-1] in (2, 3)
106
+ if self.bounds is not None:
107
+ if not np.all(self.bounds.contains(geo[..., :2])):
108
+ raise ValueError(
109
+ f"Points {geo} are out of the valid bounds "
110
+ f"{self.bounds.format()}."
111
+ )
112
+ lat, lon = geo[..., 0], geo[..., 1]
113
+ if geo.shape[-1] == 3:
114
+ alt = geo[..., -1]
115
+ else:
116
+ alt = np.zeros_like(lat)
117
+ x, y, z = self.converter.to_topocentric(lat, lon, alt)
118
+ return np.stack([x, y] + ([z] if return_z else []), -1)
119
+
120
+ def unproject(self, xy, return_z=False):
121
+ if isinstance(xy, BoundaryBox):
122
+ return BoundaryBox(*self.unproject(np.stack([xy.min_, xy.max_])))
123
+ xy = np.asarray(xy)
124
+ x, y = xy[..., 0], xy[..., 1]
125
+ if xy.shape[-1] == 3:
126
+ z = xy[..., -1]
127
+ else:
128
+ z = np.zeros_like(x)
129
+ lat, lon, alt = self.converter.to_lla(x, y, z)
130
+ return np.stack([lat, lon] + ([alt] if return_z else []), -1)
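
A closing sketch of the two geo primitives above: project a lat/lon point into the local metric frame, build a square box around it, and map its center back. The reference coordinates are placeholders.

    import numpy as np

    from utils.geo import BoundaryBox, Projection

    proj = Projection(40.7128, -74.0060)               # reference latitude/longitude (example)
    xy = proj.project(np.array([40.7130, -74.0055]))   # approximate east/north offsets in meters
    bbox = BoundaryBox(xy - 64, xy + 64)               # 128 m box centered on the point
    assert bbox.contains(xy)
    latlon = proj.unproject(bbox.center)               # back to (lat, lon)
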