umuthopeyildirim committed
Commit bd86ed9
1 Parent(s): db5b5dc
here we go

Browse files
- .DS_Store +0 -0
- app.py +109 -0
- checkpoints/kittieigen_L.pth +3 -0
- checkpoints/nyu_L.pth +3 -0
- iebins/dataloaders/__init__.py +0 -0
- iebins/dataloaders/__pycache__/__init__.cpython-38.pyc +0 -0
- iebins/dataloaders/__pycache__/dataloader.cpython-38.pyc +0 -0
- iebins/dataloaders/__pycache__/dataloader_sun.cpython-38.pyc +0 -0
- iebins/dataloaders/dataloader.py +343 -0
- iebins/dataloaders/dataloader_sun.py +326 -0
- iebins/eval.py +177 -0
- iebins/eval_sun.py +179 -0
- iebins/inference_single_image.py +117 -0
- iebins/networks/NewCRFDepth.py +318 -0
- iebins/networks/__init__.py +0 -0
- iebins/networks/depth_update.py +39 -0
- iebins/networks/newcrf_layers.py +433 -0
- iebins/networks/newcrf_utils.py +264 -0
- iebins/networks/resize.py +51 -0
- iebins/networks/swin_transformer.py +620 -0
- iebins/networks/uper_crf_head.py +364 -0
- iebins/sum_depth.py +22 -0
- iebins/test.py +209 -0
- iebins/train.py +499 -0
- iebins/utils.py +356 -0
- iebins/utils/transfrom.py +250 -0
- requirements.txt +12 -0
.DS_Store
ADDED
Binary file (6.15 kB)
app.py
ADDED
@@ -0,0 +1,109 @@
+import gradio as gr
+import cv2
+import numpy as np
+import os
+from PIL import Image
+import spaces
+import torch
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+import tempfile
+from gradio_imageslider import ImageSlider
+
+from iebins.networks.NewCRFDepth import NewCRFDepth
+from iebins.utils.transfrom import Resize, NormalizeImage, PrepareForNet
+
+css = """
+#img-display-container {
+    max-height: 100vh;
+}
+#img-display-input {
+    max-height: 80vh;
+}
+#img-display-output {
+    max-height: 80vh;
+}
+"""
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+model = NewCRFDepth(version="large07", inv_depth=False,
+                    max_depth=10, pretrained=None).to(DEVICE).eval()
+model.load_state_dict(torch.load('checkpoints/nyu_L.pth'))
+
+title = "# IEBins: Iterative Elastic Bins for Monocular Depth Estimation"
+description = """Demo for **IEBins: Iterative Elastic Bins for Monocular Depth Estimation**.
+Please refer to the [paper](https://arxiv.org/abs/2309.14137), [github](https://github.com/ShuweiShao/IEBins), or [poster](https://nips.cc/media/PosterPDFs/NeurIPS%202023/70695.png?t=1701662442.5228624) for more details."""
+
+transform = Compose([
+    Resize(
+        width=518,
+        height=518,
+        resize_target=False,
+        keep_aspect_ratio=True,
+        ensure_multiple_of=14,
+        resize_method='lower_bound',
+        image_interpolation_method=cv2.INTER_CUBIC,
+    ),
+    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    PrepareForNet(),
+])
+
+
+@spaces.GPU
+@torch.no_grad()
+def predict_depth(model, image):
+    return model(image)
+
+
+with gr.Blocks(css=css) as demo:
+    gr.Markdown(title)
+    gr.Markdown(description)
+    gr.Markdown("### Depth Prediction demo")
+    gr.Markdown(
+        "You can slide the output to compare the depth prediction with input image")
+
+    with gr.Row():
+        input_image = gr.Image(label="Input Image",
+                               type='numpy', elem_id='img-display-input')
+        depth_image_slider = ImageSlider(
+            label="Depth Map with Slider View", elem_id='img-display-output', position=0.5,)
+    raw_file = gr.File(
+        label="16-bit raw depth (can be considered as disparity)")
+    submit = gr.Button("Submit")
+
+    def on_submit(image):
+        original_image = image.copy()
+
+        h, w = image.shape[:2]
+
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
+        image = transform({'image': image})['image']
+        image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
+
+        depth = predict_depth(model, image)
+        depth = F.interpolate(depth[None], (h, w),
+                              mode='bilinear', align_corners=False)[0, 0]
+
+        raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
+        tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
+        raw_depth.save(tmp.name)
+
+        depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+        depth = depth.cpu().numpy().astype(np.uint8)
+        colored_depth = cv2.applyColorMap(
+            depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
+
+        return [(original_image, colored_depth), tmp.name]
+
+    submit.click(on_submit, inputs=[input_image], outputs=[
+                 depth_image_slider, raw_file])
+
+    example_files = os.listdir('examples')
+    example_files.sort()
+    example_files = [os.path.join('examples', filename)
+                     for filename in example_files]
+    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[
+                           depth_image_slider, raw_file], fn=on_submit, cache_examples=False)
+
+
+if __name__ == '__main__':
+    demo.queue().launch()
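For reference, the same preprocessing and prediction path used by app.py can be exercised without the Gradio UI. The snippet below is a minimal sketch, not part of the commit: it assumes app.py is importable from the repository root (importing it builds the Blocks and loads the checkpoint as a side effect), and 'examples/demo.jpg' is a hypothetical test image. It simply mirrors the call pattern of on_submit/predict_depth above.

# Hedged local sketch reusing app.py's module-level objects.
import cv2
import torch
import torch.nn.functional as F

from app import model, transform, DEVICE  # assumption: run from the repo root

img = cv2.imread('examples/demo.jpg')             # hypothetical BGR test image
h, w = img.shape[:2]
x = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0  # same scaling as on_submit
x = transform({'image': x})['image']
x = torch.from_numpy(x).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    depth = model(x)                              # mirrors predict_depth(model, image)
depth = F.interpolate(depth[None], (h, w), mode='bilinear',
                      align_corners=False)[0, 0]
print('predicted depth range:', float(depth.min()), float(depth.max()))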
checkpoints/kittieigen_L.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf10549a615b19b96ffdddc82e639662c421fe0cd30008cc3cf3e7d4bffa5f55
+size 3276188594
checkpoints/nyu_L.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:81d95d5f26f5d01b7e8b060467eef77ea6efea4ddf100d60f5fad87e6c0daae7
+size 3276188594
iebins/dataloaders/__init__.py
ADDED
File without changes
iebins/dataloaders/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (173 Bytes)
iebins/dataloaders/__pycache__/dataloader.cpython-38.pyc
ADDED
Binary file (9.15 kB)
iebins/dataloaders/__pycache__/dataloader_sun.cpython-38.pyc
ADDED
Binary file (8.93 kB)
iebins/dataloaders/dataloader.py
ADDED
@@ -0,0 +1,343 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+import torch.utils.data.distributed
+from torchvision import transforms
+
+import numpy as np
+from PIL import Image
+import os
+import random
+import copy
+
+from utils import DistributedSamplerNoEvenlyDivisible
+
+
+def _is_pil_image(img):
+    return isinstance(img, Image.Image)
+
+
+def _is_numpy_image(img):
+    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
+
+
+def preprocessing_transforms(mode):
+    return transforms.Compose([
+        ToTensor(mode=mode)
+    ])
+
+
+class NewDataLoader(object):
+    def __init__(self, args, mode):
+        if mode == 'train':
+            self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+            if args.distributed:
+                self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples)
+            else:
+                self.train_sampler = None
+
+            self.data = DataLoader(self.training_samples, args.batch_size,
+                                   shuffle=(self.train_sampler is None),
+                                   num_workers=args.num_threads,
+                                   pin_memory=True,
+                                   sampler=self.train_sampler)
+
+        elif mode == 'online_eval':
+            self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+            if args.distributed:
+                # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False)
+                self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False)
+            else:
+                self.eval_sampler = None
+            self.data = DataLoader(self.testing_samples, 1,
+                                   shuffle=False,
+                                   num_workers=1,
+                                   pin_memory=True,
+                                   sampler=self.eval_sampler)
+
+        elif mode == 'test':
+            self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+            self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1)
+
+        else:
+            print('mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
+
+
+class DataLoadPreprocess(Dataset):
+    def __init__(self, args, mode, transform=None, is_for_online_eval=False):
+        self.args = args
+        if mode == 'online_eval':
+            with open(args.filenames_file_eval, 'r') as f:
+                self.filenames = f.readlines()
+        else:
+            with open(args.filenames_file, 'r') as f:
+                self.filenames = f.readlines()
+
+        self.mode = mode
+        self.transform = transform
+        self.to_tensor = ToTensor
+        self.is_for_online_eval = is_for_online_eval
+
+    def __getitem__(self, idx):
+        sample_path = self.filenames[idx]
+        # focal = float(sample_path.split()[2])
+        focal = 518.8579
+
+        if self.mode == 'train':
+            if self.args.dataset == 'kitti':
+                rgb_file = sample_path.split()[0]
+                depth_file = os.path.join(sample_path.split()[0].split('/')[0], sample_path.split()[1])
+                if self.args.use_right is True and random.random() > 0.5:
+                    rgb_file = rgb_file.replace('image_02', 'image_03')
+                    depth_file = depth_file.replace('image_02', 'image_03')
+            else:
+                rgb_file = sample_path.split()[0]
+                depth_file = sample_path.split()[1]
+
+            image_path = os.path.join(self.args.data_path, rgb_file)
+            depth_path = os.path.join(self.args.gt_path, depth_file)
+
+            image = Image.open(image_path)
+            depth_gt = Image.open(depth_path)
+
+            if self.args.do_kb_crop is True:
+                height = image.height
+                width = image.width
+                top_margin = int(height - 352)
+                left_margin = int((width - 1216) / 2)
+                depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
+                image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
+
+            # To avoid blank boundaries due to pixel registration
+            if self.args.dataset == 'nyu':
+                if self.args.input_height == 480:
+                    depth_gt = np.array(depth_gt)
+                    valid_mask = np.zeros_like(depth_gt)
+                    valid_mask[45:472, 43:608] = 1
+                    depth_gt[valid_mask==0] = 0
+                    depth_gt = Image.fromarray(depth_gt)
+                else:
+                    depth_gt = depth_gt.crop((43, 45, 608, 472))
+                    image = image.crop((43, 45, 608, 472))
+
+            if self.args.do_random_rotate is True:
+                random_angle = (random.random() - 0.5) * 2 * self.args.degree
+                image = self.rotate_image(image, random_angle)
+                depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST)
+
+            image = np.asarray(image, dtype=np.float32) / 255.0
+            depth_gt = np.asarray(depth_gt, dtype=np.float32)
+            depth_gt = np.expand_dims(depth_gt, axis=2)
+
+            if self.args.dataset == 'nyu':
+                depth_gt = depth_gt / 1000.0
+            else:
+                depth_gt = depth_gt / 256.0
+
+            if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width:
+                image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width)
+            image, depth_gt = self.train_preprocess(image, depth_gt)
+            # https://github.com/ShuweiShao/URCDC-Depth
+            image, depth_gt = self.Cut_Flip(image, depth_gt)
+            sample = {'image': image, 'depth': depth_gt, 'focal': focal}
+
+        else:
+            if self.mode == 'online_eval':
+                data_path = self.args.data_path_eval
+            else:
+                data_path = self.args.data_path
+
+            image_path = os.path.join(data_path, "./" + sample_path.split()[0])
+            image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
+
+            if self.mode == 'online_eval':
+                gt_path = self.args.gt_path_eval
+                depth_path = os.path.join(gt_path, "./" + sample_path.split()[1])
+                if self.args.dataset == 'kitti':
+                    depth_path = os.path.join(gt_path, sample_path.split()[0].split('/')[0], sample_path.split()[1])
+                has_valid_depth = False
+                try:
+                    depth_gt = Image.open(depth_path)
+                    has_valid_depth = True
+                except IOError:
+                    depth_gt = False
+                    # print('Missing gt for {}'.format(image_path))
+
+                if has_valid_depth:
+                    depth_gt = np.asarray(depth_gt, dtype=np.float32)
+                    depth_gt = np.expand_dims(depth_gt, axis=2)
+                    if self.args.dataset == 'nyu':
+                        depth_gt = depth_gt / 1000.0
+                    else:
+                        depth_gt = depth_gt / 256.0
+
+            if self.args.do_kb_crop is True:
+                height = image.shape[0]
+                width = image.shape[1]
+                top_margin = int(height - 352)
+                left_margin = int((width - 1216) / 2)
+                image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
+                if self.mode == 'online_eval' and has_valid_depth:
+                    depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
+
+            if self.mode == 'online_eval':
+                sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth}
+            else:
+                sample = {'image': image, 'focal': focal}
+
+        if self.transform:
+            sample = self.transform([sample, self.args.dataset])
+
+        return sample
+
+    def rotate_image(self, image, angle, flag=Image.BILINEAR):
+        result = image.rotate(angle, resample=flag)
+        return result
+
+    def random_crop(self, img, depth, height, width):
+        assert img.shape[0] >= height
+        assert img.shape[1] >= width
+        assert img.shape[0] == depth.shape[0]
+        assert img.shape[1] == depth.shape[1]
+        x = random.randint(0, img.shape[1] - width)
+        y = random.randint(0, img.shape[0] - height)
+        img = img[y:y + height, x:x + width, :]
+        depth = depth[y:y + height, x:x + width, :]
+        return img, depth
+
+    def train_preprocess(self, image, depth_gt):
+        # Random flipping
+        do_flip = random.random()
+        if do_flip > 0.5:
+            image = (image[:, ::-1, :]).copy()
+            depth_gt = (depth_gt[:, ::-1, :]).copy()
+
+        # Random gamma, brightness, color augmentation
+        do_augment = random.random()
+        if do_augment > 0.5:
+            image = self.augment_image(image)
+
+        return image, depth_gt
+
+    def augment_image(self, image):
+        # gamma augmentation
+        gamma = random.uniform(0.9, 1.1)
+        image_aug = image ** gamma
+
+        # brightness augmentation
+        if self.args.dataset == 'nyu':
+            brightness = random.uniform(0.75, 1.25)
+        else:
+            brightness = random.uniform(0.9, 1.1)
+        image_aug = image_aug * brightness
+
+        # color augmentation
+        colors = np.random.uniform(0.9, 1.1, size=3)
+        white = np.ones((image.shape[0], image.shape[1]))
+        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
+        image_aug *= color_image
+        image_aug = np.clip(image_aug, 0, 1)
+
+        return image_aug
+
+    def Cut_Flip(self, image, depth):
+
+        p = random.random()
+        if p < 0.5:
+            return image, depth
+        image_copy = copy.deepcopy(image)
+        depth_copy = copy.deepcopy(depth)
+        h, w, c = image.shape
+
+        N = 2
+        h_list = []
+        h_interval_list = []   # height interval
+        for i in range(N-1):
+            h_list.append(random.randint(int(0.2*h), int(0.8*h)))
+        h_list.append(h)
+        h_list.append(0)
+        h_list.sort()
+        h_list_inv = np.array([h]*(N+1))-np.array(h_list)
+        for i in range(len(h_list)-1):
+            h_interval_list.append(h_list[i+1]-h_list[i])
+        for i in range(N):
+            image[h_list[i]:h_list[i+1], :, :] = image_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
+            depth[h_list[i]:h_list[i+1], :, :] = depth_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
+
+        return image, depth
+
+
+    def __len__(self):
+        return len(self.filenames)
+
+
+class ToTensor(object):
+    def __init__(self, mode):
+        self.mode = mode
+        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    def __call__(self, sample_dataset):
+
+        sample = sample_dataset[0]
+        dataset = sample_dataset[1]
+
+        image, focal = sample['image'], sample['focal']
+        image = self.to_tensor(image)
+        image = self.normalize(image)
+
+        if dataset == 'kitti':
+            K_p = np.array([[716.88, 0, 596.5593, 0],
+                            [0, 716.88, 149.854, 0],
+                            [0, 0, 1, 0],
+                            [0, 0, 0, 1]], dtype=np.float32)
+            inv_K_p = np.linalg.pinv(K_p)
+            inv_K_p = torch.from_numpy(inv_K_p)
+
+        elif dataset == 'nyu':
+            K_p = np.array([[518.8579, 0, 325.5824, 0],
+                            [0, 518.8579, 253.7362, 0],
+                            [0, 0, 1, 0],
+                            [0, 0, 0, 1]], dtype=np.float32)
+            inv_K_p = np.linalg.pinv(K_p)
+            inv_K_p = torch.from_numpy(inv_K_p)
+
+        if self.mode == 'test':
+            return {'image': image, 'inv_K_p': inv_K_p, 'focal': focal}
+
+        depth = sample['depth']
+        if self.mode == 'train':
+            depth = self.to_tensor(depth)
+            return {'image': image, 'depth': depth, 'focal': focal}
+        else:
+            has_valid_depth = sample['has_valid_depth']
+            return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth}
+
+    def to_tensor(self, pic):
+        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
+            raise TypeError(
+                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
+
+        if isinstance(pic, np.ndarray):
+            img = torch.from_numpy(pic.transpose((2, 0, 1)))
+            return img
+
+        # handle PIL Image
+        if pic.mode == 'I':
+            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+        elif pic.mode == 'I;16':
+            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+        else:
+            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+        if pic.mode == 'YCbCr':
+            nchannel = 3
+        elif pic.mode == 'I;16':
+            nchannel = 1
+        else:
+            nchannel = len(pic.mode)
+        img = img.view(pic.size[1], pic.size[0], nchannel)
+
+        img = img.transpose(0, 1).transpose(0, 2).contiguous()
+        if isinstance(img, torch.ByteTensor):
+            return img.float()
+        else:
+            return img
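As an illustration only (not part of the commit), the loader above can be driven from a plain argparse.Namespace instead of the project's argument files. The sketch below assumes it is run from inside the iebins/ package directory; every path value is a hypothetical placeholder.

# Hedged sketch: build the NYU 'online_eval' loader defined above.
import argparse
from dataloaders.dataloader import NewDataLoader  # assumption: cwd is iebins/

args = argparse.Namespace(
    dataset='nyu',
    distributed=False,
    data_path_eval='/path/to/nyu/test/rgb',         # placeholder
    gt_path_eval='/path/to/nyu/test/depth',         # placeholder
    filenames_file_eval='/path/to/test_split.txt',  # placeholder
    input_height=480, input_width=640,
    do_kb_crop=False, do_random_rotate=False, degree=2.5, use_right=False,
    batch_size=1, num_threads=1,
)

loader = NewDataLoader(args, 'online_eval')
for sample in loader.data:
    if sample['has_valid_depth']:
        print(sample['image'].shape, sample['depth'].shape, float(sample['focal']))
    break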
iebins/dataloaders/dataloader_sun.py
ADDED
@@ -0,0 +1,326 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+import torch.utils.data.distributed
+from torchvision import transforms
+
+import numpy as np
+from PIL import Image
+import os
+import random
+import copy
+import cv2
+
+from utils import DistributedSamplerNoEvenlyDivisible
+
+
+def _is_pil_image(img):
+    return isinstance(img, Image.Image)
+
+
+def _is_numpy_image(img):
+    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
+
+
+def preprocessing_transforms(mode):
+    return transforms.Compose([
+        ToTensor(mode=mode)
+    ])
+
+
+class NewDataLoader(object):
+    def __init__(self, args, mode):
+        if mode == 'train':
+            self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+            if args.distributed:
+                self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples)
+            else:
+                self.train_sampler = None
+
+            self.data = DataLoader(self.training_samples, args.batch_size,
+                                   shuffle=(self.train_sampler is None),
+                                   num_workers=args.num_threads,
+                                   pin_memory=True,
+                                   sampler=self.train_sampler)
+
+        elif mode == 'online_eval':
+            self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+            if args.distributed:
+                # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False)
+                self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False)
+            else:
+                self.eval_sampler = None
+            self.data = DataLoader(self.testing_samples, 1,
+                                   shuffle=False,
+                                   num_workers=1,
+                                   pin_memory=True,
+                                   sampler=self.eval_sampler)
+
+        elif mode == 'test':
+            self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+            self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1)
+
+        else:
+            print('mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
+
+
+class DataLoadPreprocess(Dataset):
+    def __init__(self, args, mode, transform=None, is_for_online_eval=False):
+        self.args = args
+        if mode == 'online_eval':
+            with open(args.filenames_file_eval, 'r') as f:
+                self.filenames = f.readlines()
+        else:
+            with open(args.filenames_file, 'r') as f:
+                self.filenames = f.readlines()
+
+        self.mode = mode
+        self.transform = transform
+        self.to_tensor = ToTensor
+        self.is_for_online_eval = is_for_online_eval
+
+    def __getitem__(self, idx):
+        sample_path = self.filenames[idx]
+        # focal = float(sample_path.split()[2])
+        focal = 518.8579
+
+        if self.mode == 'train':
+            if self.args.dataset == 'kitti':
+                rgb_file = sample_path.split()[0]
+                depth_file = os.path.join(sample_path.split()[0].split('/')[0], sample_path.split()[1])
+                if self.args.use_right is True and random.random() > 0.5:
+                    rgb_file = rgb_file.replace('image_02', 'image_03')
+                    depth_file = depth_file.replace('image_02', 'image_03')
+            else:
+                rgb_file = sample_path.split()[0]
+                depth_file = sample_path.split()[1]
+
+            image_path = os.path.join(self.args.data_path, rgb_file)
+            depth_path = os.path.join(self.args.gt_path, depth_file)
+
+            image = Image.open(image_path)
+            depth_gt = Image.open(depth_path)
+
+            if self.args.do_kb_crop is True:
+                height = image.height
+                width = image.width
+                top_margin = int(height - 352)
+                left_margin = int((width - 1216) / 2)
+                depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
+                image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
+
+            # To avoid blank boundaries due to pixel registration
+            if self.args.dataset == 'nyu':
+                if self.args.input_height == 480:
+                    depth_gt = np.array(depth_gt)
+                    valid_mask = np.zeros_like(depth_gt)
+                    valid_mask[45:472, 43:608] = 1
+                    depth_gt[valid_mask==0] = 0
+                    depth_gt = Image.fromarray(depth_gt)
+                else:
+                    depth_gt = depth_gt.crop((43, 45, 608, 472))
+                    image = image.crop((43, 45, 608, 472))
+
+            if self.args.do_random_rotate is True:
+                random_angle = (random.random() - 0.5) * 2 * self.args.degree
+                image = self.rotate_image(image, random_angle)
+                depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST)
+
+            image = np.asarray(image, dtype=np.float32) / 255.0
+            depth_gt = np.asarray(depth_gt, dtype=np.float32)
+            depth_gt = np.expand_dims(depth_gt, axis=2)
+
+            if self.args.dataset == 'nyu':
+                depth_gt = depth_gt / 1000.0
+            else:
+                depth_gt = depth_gt / 256.0
+
+            if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width:
+                image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width)
+            image, depth_gt = self.train_preprocess(image, depth_gt)
+            image, depth_gt = self.Cut_Flip(image, depth_gt)
+            sample = {'image': image, 'depth': depth_gt, 'focal': focal}
+
+        else:
+            if self.mode == 'online_eval':
+                data_path = self.args.data_path_eval
+            else:
+                data_path = self.args.data_path
+
+            image_path = os.path.join(data_path, "./" + sample_path.split()[0])
+            image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
+            image = cv2.resize(image, (640, 480))
+
+            if self.mode == 'online_eval':
+                gt_path = self.args.gt_path_eval
+                depth_path = os.path.join(gt_path, "./" + sample_path.split()[1])
+                if self.args.dataset == 'kitti':
+                    depth_path = os.path.join(gt_path, sample_path.split()[0].split('/')[0], sample_path.split()[1])
+                has_valid_depth = False
+                try:
+                    depth_gt = Image.open(depth_path)
+                    has_valid_depth = True
+                except IOError:
+                    depth_gt = False
+                    # print('Missing gt for {}'.format(image_path))
+
+                if has_valid_depth:
+                    depth_gt = np.asarray(depth_gt, dtype=np.uint16)  # 2
+                    depth_gt = np.bitwise_or(np.right_shift(depth_gt, 3), np.left_shift(depth_gt, 16 - 3))  # 3
+                    depth_gt = np.expand_dims(depth_gt, axis=2)
+                    if self.args.dataset == 'nyu':
+                        depth_gt = depth_gt.astype(np.single) / 1000  # 4
+                        depth_gt = depth_gt.astype(np.float32)  # 5
+                    else:
+                        depth_gt = depth_gt / 256.0
+
+            if self.args.do_kb_crop is True:
+                height = image.shape[0]
+                width = image.shape[1]
+                top_margin = int(height - 352)
+                left_margin = int((width - 1216) / 2)
+                image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
+                if self.mode == 'online_eval' and has_valid_depth:
+                    depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
+
+            if self.mode == 'online_eval':
+                sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth}
+            else:
+                sample = {'image': image, 'focal': focal}
+
+        if self.transform:
+            sample = self.transform(sample)
+
+        return sample
+
+    def rotate_image(self, image, angle, flag=Image.BILINEAR):
+        result = image.rotate(angle, resample=flag)
+        return result
+
+    def random_crop(self, img, depth, height, width):
+        assert img.shape[0] >= height
+        assert img.shape[1] >= width
+        assert img.shape[0] == depth.shape[0]
+        assert img.shape[1] == depth.shape[1]
+        x = random.randint(0, img.shape[1] - width)
+        y = random.randint(0, img.shape[0] - height)
+        img = img[y:y + height, x:x + width, :]
+        depth = depth[y:y + height, x:x + width, :]
+        return img, depth
+
+    def train_preprocess(self, image, depth_gt):
+        # Random flipping
+        do_flip = random.random()
+        if do_flip > 0.5:
+            image = (image[:, ::-1, :]).copy()
+            depth_gt = (depth_gt[:, ::-1, :]).copy()
+
+        # Random gamma, brightness, color augmentation
+        do_augment = random.random()
+        if do_augment > 0.5:
+            image = self.augment_image(image)
+
+        return image, depth_gt
+
+    def augment_image(self, image):
+        # gamma augmentation
+        gamma = random.uniform(0.9, 1.1)
+        image_aug = image ** gamma
+
+        # brightness augmentation
+        if self.args.dataset == 'nyu':
+            brightness = random.uniform(0.75, 1.25)
+        else:
+            brightness = random.uniform(0.9, 1.1)
+        image_aug = image_aug * brightness
+
+        # color augmentation
+        colors = np.random.uniform(0.9, 1.1, size=3)
+        white = np.ones((image.shape[0], image.shape[1]))
+        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
+        image_aug *= color_image
+        image_aug = np.clip(image_aug, 0, 1)
+
+        return image_aug
+
+    def Cut_Flip(self, image, depth):
+
+        p = random.random()
+        if p < 0.5:
+            return image, depth
+        image_copy = copy.deepcopy(image)
+        depth_copy = copy.deepcopy(depth)
+        h, w, c = image.shape
+
+        N = 2
+        h_list = []
+        h_interval_list = []   # height interval
+        for i in range(N-1):
+            h_list.append(random.randint(int(0.2*h), int(0.8*h)))
+        h_list.append(h)
+        h_list.append(0)
+        h_list.sort()
+        h_list_inv = np.array([h]*(N+1))-np.array(h_list)
+        for i in range(len(h_list)-1):
+            h_interval_list.append(h_list[i+1]-h_list[i])
+        for i in range(N):
+            image[h_list[i]:h_list[i+1], :, :] = image_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
+            depth[h_list[i]:h_list[i+1], :, :] = depth_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
+
+        return image, depth
+
+
+    def __len__(self):
+        return len(self.filenames)
+
+
+class ToTensor(object):
+    def __init__(self, mode):
+        self.mode = mode
+        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    def __call__(self, sample):
+        image, focal = sample['image'], sample['focal']
+        image = self.to_tensor(image)
+        image = self.normalize(image)
+
+        if self.mode == 'test':
+            return {'image': image, 'focal': focal}
+
+        depth = sample['depth']
+        if self.mode == 'train':
+            depth = self.to_tensor(depth)
+            return {'image': image, 'depth': depth, 'focal': focal}
+        else:
+            has_valid_depth = sample['has_valid_depth']
+            return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth}
+
+    def to_tensor(self, pic):
+        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
+            raise TypeError(
+                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
+
+        if isinstance(pic, np.ndarray):
+            img = torch.from_numpy(pic.transpose((2, 0, 1)))
+            return img
+
+        # handle PIL Image
+        if pic.mode == 'I':
+            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+        elif pic.mode == 'I;16':
+            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+        else:
+            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+        if pic.mode == 'YCbCr':
+            nchannel = 3
+        elif pic.mode == 'I;16':
+            nchannel = 1
+        else:
+            nchannel = len(pic.mode)
+        img = img.view(pic.size[1], pic.size[0], nchannel)
+
+        img = img.transpose(0, 1).transpose(0, 2).contiguous()
+        if isinstance(img, torch.ByteTensor):
+            return img.float()
+        else:
+            return img
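A small worked example (synthetic values, not part of the commit) of the 16-bit depth decoding this SUN RGB-D loader applies: the raw PNG stores depth circularly rotated left by 3 bits, so the loader rotates it back by OR-ing a right shift with the wrapped-around bits, then divides by 1000 as it does for the 'nyu' setting.

# Hedged sketch mirroring the bit decoding in DataLoadPreprocess above.
import numpy as np

raw = np.array([[12345, 40000]], dtype=np.uint16)   # made-up raw PNG values
decoded = np.bitwise_or(np.right_shift(raw, 3), np.left_shift(raw, 16 - 3))
depth_m = decoded.astype(np.float32) / 1000.0       # same scaling as the loader
print(decoded, depth_m)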
iebins/eval.py
ADDED
@@ -0,0 +1,177 @@
+import torch
+import torch.backends.cudnn as cudnn
+
+import os, sys
+import argparse
+import numpy as np
+from tqdm import tqdm
+
+from utils import post_process_depth, flip_lr, compute_errors
+from networks.NewCRFDepth import NewCRFDepth
+
+
+def convert_arg_line_to_args(arg_line):
+    for arg in arg_line.split():
+        if not arg.strip():
+            continue
+        yield arg
+
+
+parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
+parser.convert_arg_line_to_args = convert_arg_line_to_args
+
+parser.add_argument('--model_name', type=str, help='model name', default='iebins')
+parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
+parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
+
+# Dataset
+parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
+parser.add_argument('--input_height', type=int, help='input height', default=480)
+parser.add_argument('--input_width', type=int, help='input width', default=640)
+parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
+
+# Preprocessing
+parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
+parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
+parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
+parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')
+
+# Eval
+parser.add_argument('--data_path_eval', type=str, help='path to the data for evaluation', required=False)
+parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for evaluation', required=False)
+parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for evaluation', required=False)
+parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
+parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
+parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
+parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
+
+
+if sys.argv.__len__() == 2:
+    arg_filename_with_prefix = '@' + sys.argv[1]
+    args = parser.parse_args([arg_filename_with_prefix])
+else:
+    args = parser.parse_args()
+
+if args.dataset == 'kitti' or args.dataset == 'nyu':
+    from dataloaders.dataloader import NewDataLoader
+
+
+def eval(model, dataloader_eval, post_process=False):
+    eval_measures = torch.zeros(10).cuda()
+
+    for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
+        with torch.no_grad():
+            image = torch.autograd.Variable(eval_sample_batched['image'].cuda())
+            gt_depth = eval_sample_batched['depth']
+            has_valid_depth = eval_sample_batched['has_valid_depth']
+            if not has_valid_depth:
+                # print('Invalid depth. continue.')
+                continue
+
+            pred_depths_r_list, _, _ = model(image)
+            if post_process:
+                image_flipped = flip_lr(image)
+                pred_depths_r_list_flipped, _, _ = model(image_flipped)
+                pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
+
+            pred_depth = pred_depth.cpu().numpy().squeeze()
+            gt_depth = gt_depth.cpu().numpy().squeeze()
+
+        if args.do_kb_crop:
+            height, width = gt_depth.shape
+            top_margin = int(height - 352)
+            left_margin = int((width - 1216) / 2)
+            pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
+            pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
+            pred_depth = pred_depth_uncropped
+
+        pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
+        pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
+        pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
+        pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
+
+        valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
+
+        if args.garg_crop or args.eigen_crop:
+            gt_height, gt_width = gt_depth.shape
+            eval_mask = np.zeros(valid_mask.shape)
+
+            if args.garg_crop:
+                eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
+
+            elif args.eigen_crop:
+                if args.dataset == 'kitti':
+                    eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
+                elif args.dataset == 'nyu':
+                    eval_mask[45:471, 41:601] = 1
+
+            valid_mask = np.logical_and(valid_mask, eval_mask)
+
+        measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
+
+        eval_measures[:9] += torch.tensor(measures).cuda()
+        eval_measures[9] += 1
+
+    eval_measures_cpu = eval_measures.cpu()
+    cnt = eval_measures_cpu[9].item()
+    eval_measures_cpu /= cnt
+    print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
+    print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
+                                                                                 'sq_rel', 'log_rms', 'd1', 'd2',
+                                                                                 'd3'))
+    for i in range(8):
+        print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
+    print('{:7.4f}'.format(eval_measures_cpu[8]))
+    return eval_measures_cpu
+
+
+def main_worker(args):
+
+    # CRF model
+    model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
+    model.train()
+
+    num_params = sum([np.prod(p.size()) for p in model.parameters()])
+    print("== Total number of parameters: {}".format(num_params))
+
+    num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
+    print("== Total number of learning parameters: {}".format(num_params_update))
+
+    model = torch.nn.DataParallel(model)
+    model.cuda()
+
+    print("== Model Initialized")
+
+    if args.checkpoint_path != '':
+        if os.path.isfile(args.checkpoint_path):
+            print("== Loading checkpoint '{}'".format(args.checkpoint_path))
+            checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
+            model.load_state_dict(checkpoint['model'])
+            print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
+            del checkpoint
+        else:
+            print("== No checkpoint found at '{}'".format(args.checkpoint_path))
+
+    cudnn.benchmark = True
+
+    dataloader_eval = NewDataLoader(args, 'online_eval')
+
+    # ===== Evaluation ======
+    model.eval()
+    with torch.no_grad():
+        eval_measures = eval(model, dataloader_eval, post_process=True)
+
+
+def main():
+    torch.cuda.empty_cache()
+    args.distributed = False
+    ngpus_per_node = torch.cuda.device_count()
+    if ngpus_per_node > 1:
+        print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
+        return -1
+
+    main_worker(args)
+
+
+if __name__ == '__main__':
+    main()
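eval.py reads all of its options from a single arguments file: when exactly one CLI argument is given, it prepends '@' and lets argparse's fromfile_prefix_chars expand the file, with convert_arg_line_to_args splitting each line into tokens. The self-contained sketch below (not part of the commit) demonstrates that mechanism with a throwaway temporary file standing in for a real arguments_eval_*.txt.

# Hedged sketch of the "@arguments-file" pattern used by eval.py.
import argparse
import os
import tempfile

parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = lambda arg_line: arg_line.split()
parser.add_argument('--dataset', default='nyu')
parser.add_argument('--max_depth_eval', type=float, default=80)
parser.add_argument('--eigen_crop', action='store_true')

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('--dataset nyu\n--max_depth_eval 10\n--eigen_crop\n')

# Equivalent to: python eval.py my_args.txt  (eval.py adds the '@' itself)
args = parser.parse_args(['@' + f.name])
print(args.dataset, args.max_depth_eval, args.eigen_crop)
os.remove(f.name)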
iebins/eval_sun.py
ADDED
@@ -0,0 +1,179 @@
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn.functional as F
+import os, sys
+import argparse
+import numpy as np
+from tqdm import tqdm
+
+from utils import post_process_depth, flip_lr, compute_errors
+from networks.NewCRFDepth import NewCRFDepth
+
+
+def convert_arg_line_to_args(arg_line):
+    for arg in arg_line.split():
+        if not arg.strip():
+            continue
+        yield arg
+
+
+parser = argparse.ArgumentParser(description='IEbins PyTorch implementation.', fromfile_prefix_chars='@')
+parser.convert_arg_line_to_args = convert_arg_line_to_args
+
+parser.add_argument('--model_name', type=str, help='model name', default='iebins')
+parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
+parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
+
+# Dataset
+parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
+parser.add_argument('--input_height', type=int, help='input height', default=480)
+parser.add_argument('--input_width', type=int, help='input width', default=640)
+parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
+
+# Preprocessing
+parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
+parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
+parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
+parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')
+
+# Eval
+parser.add_argument('--data_path_eval', type=str, help='path to the data for evaluation', required=False)
+parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for evaluation', required=False)
+parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for evaluation', required=False)
+parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
+parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
+parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
+parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
+
+
+if sys.argv.__len__() == 2:
+    arg_filename_with_prefix = '@' + sys.argv[1]
+    args = parser.parse_args([arg_filename_with_prefix])
+else:
+    args = parser.parse_args()
+
+if args.dataset == 'nyu':
+    from dataloaders.dataloader_sun import NewDataLoader
+
+
+def eval(model, dataloader_eval, post_process=False):
+    eval_measures = torch.zeros(10).cuda()
+    for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
+        with torch.no_grad():
+            image = torch.autograd.Variable(eval_sample_batched['image'].cuda())
+            gt_depth = eval_sample_batched['depth']
+            has_valid_depth = eval_sample_batched['has_valid_depth']
+            if not has_valid_depth:
+                # print('Invalid depth. continue.')
+                continue
+            _, hh, ww, _ = gt_depth.shape
+            pred_depths_r_list, _, _ = model(image)
+            if post_process:
+                image_flipped = flip_lr(image)
+                pred_depths_r_list_flipped, _, _ = model(image_flipped)
+                pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
+            pred_depth = F.interpolate(pred_depth, [hh, ww], mode="bilinear", align_corners=False)
+
+            pred_depth = pred_depth.cpu().numpy().squeeze()
+            gt_depth = gt_depth.cpu().numpy().squeeze()
+
+        if args.do_kb_crop:
+            height, width = gt_depth.shape
+            top_margin = int(height - 352)
+            left_margin = int((width - 1216) / 2)
+            pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
+            pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
+            pred_depth = pred_depth_uncropped
+
+        pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
+        pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
+        pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
+        pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
+        pred_depth[pred_depth > 8] = 8
+        gt_depth[gt_depth > 8] = 8
+
+        valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
+
+        if args.garg_crop or args.eigen_crop:
+            gt_height, gt_width = gt_depth.shape
+            eval_mask = np.zeros(valid_mask.shape)
+
+            if args.garg_crop:
+                eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
+
+            elif args.eigen_crop:
+                if args.dataset == 'kitti':
+                    eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
+                elif args.dataset == 'nyu':
+                    eval_mask[45:471, 41:601] = 1
+
+            valid_mask = np.logical_and(valid_mask, eval_mask)
+
+        measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
+
+        eval_measures[:9] += torch.tensor(measures).cuda()
+        eval_measures[9] += 1
+
+    eval_measures_cpu = eval_measures.cpu()
+    cnt = eval_measures_cpu[9].item()
+    eval_measures_cpu /= cnt
+    print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
+    print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
+                                                                                 'sq_rel', 'log_rms', 'd1', 'd2',
+                                                                                 'd3'))
+    for i in range(8):
+        print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
+    print('{:7.4f}'.format(eval_measures_cpu[8]))
+    return eval_measures_cpu
+
+
+def main_worker(args):
+
+    # CRF model
+    model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
+    model.train()
+
+    num_params = sum([np.prod(p.size()) for p in model.parameters()])
+    print("== Total number of parameters: {}".format(num_params))
+
+    num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
+    print("== Total number of learning parameters: {}".format(num_params_update))
+
+    model = torch.nn.DataParallel(model)
+    model.cuda()
+
+    print("== Model Initialized")
+
+    if args.checkpoint_path != '':
+        if os.path.isfile(args.checkpoint_path):
+            print("== Loading checkpoint '{}'".format(args.checkpoint_path))
+            checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
+            model.load_state_dict(checkpoint['model'])
+            print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
+            del checkpoint
+        else:
+            print("== No checkpoint found at '{}'".format(args.checkpoint_path))
+
+    cudnn.benchmark = True
+
+    dataloader_eval = NewDataLoader(args, 'online_eval')
+
+    # ===== Evaluation ======
+    model.eval()
+    with torch.no_grad():
+        eval_measures = eval(model, dataloader_eval, post_process=True)
+
+
+def main():
+    torch.cuda.empty_cache()
+    args.distributed = False
+    ngpus_per_node = torch.cuda.device_count()
+    if ngpus_per_node > 1:
+        print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
+        return -1
+
+    main_worker(args)
+
+
+if __name__ == '__main__':
+    main()
iebins/inference_single_image.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import torch.backends.cudnn as cudnn

import os, sys
import argparse
import numpy as np
from tqdm import tqdm

from utils import post_process_depth, flip_lr, compute_errors
from networks.NewCRFDepth import NewCRFDepth
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt


def convert_arg_line_to_args(arg_line):
    for arg in arg_line.split():
        if not arg.strip():
            continue
        yield arg


parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = convert_arg_line_to_args

parser.add_argument('--model_name', type=str, help='model name', default='iebins')
parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07')
parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
parser.add_argument('--image_path', type=str, help='path to the image for inference', required=False)
parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)


if sys.argv.__len__() == 2:
    arg_filename_with_prefix = '@' + sys.argv[1]
    args = parser.parse_args([arg_filename_with_prefix])
else:
    args = parser.parse_args()


def inference(model, post_process=False):

    image = np.asarray(Image.open(args.image_path), dtype=np.float32) / 255.0

    if args.dataset == 'kitti':
        height = image.shape[0]
        width = image.shape[1]
        top_margin = int(height - 352)
        left_margin = int((width - 1216) / 2)
        image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]

    image = torch.from_numpy(image.transpose((2, 0, 1)))
    image = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)

    with torch.no_grad():
        image = torch.autograd.Variable(image.unsqueeze(0).cuda())

        pred_depths_r_list, _, _ = model(image)
        if post_process:
            image_flipped = flip_lr(image)
            pred_depths_r_list_flipped, _, _ = model(image_flipped)
            pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])

        pred_depth = pred_depth.cpu().numpy().squeeze()

    if args.dataset == 'kitti':
        plt.imsave('depth.png', np.log10(pred_depth), cmap='magma')
    else:
        plt.imsave('depth.png', pred_depth, cmap='jet')


def main_worker(args):

    model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
    model.train()

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("== Total number of parameters: {}".format(num_params))

    num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
    print("== Total number of learning parameters: {}".format(num_params_update))

    model = torch.nn.DataParallel(model)
    model.cuda()

    print("== Model Initialized")

    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
            model.load_state_dict(checkpoint['model'])
            print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
            del checkpoint
        else:
            print("== No checkpoint found at '{}'".format(args.checkpoint_path))

    cudnn.benchmark = True

    # ===== Inference ======
    model.eval()
    with torch.no_grad():
        inference(model, post_process=True)


def main():
    torch.cuda.empty_cache()
    args.distributed = False
    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node > 1:
        print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
        return -1

    main_worker(args)


if __name__ == '__main__':
    main()
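A minimal driving sketch for the script above, not a definitive workflow: it fills sys.argv before import so the module-level parser picks up the flags, and it assumes the interpreter is started from inside iebins/ (the same import layout the script itself uses); the checkpoint and image paths are placeholders.

import sys

sys.argv = [
    'inference_single_image.py',
    '--encoder', 'large07',
    '--dataset', 'nyu',
    '--max_depth', '10',
    '--checkpoint_path', 'path/to/nyu_checkpoint.pth',  # placeholder path
    '--image_path', 'path/to/rgb_image.jpg',            # placeholder path
]

import inference_single_image as single_image  # argparse runs at import time

single_image.main()  # builds the model, loads the checkpoint, writes depth.png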
iebins/networks/NewCRFDepth.py
ADDED
@@ -0,0 +1,318 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .swin_transformer import SwinTransformer
from .newcrf_layers import NewCRF
from .uper_crf_head import PSP
from .depth_update import *
########################################################################################################################


class NewCRFDepth(nn.Module):
    """
    Depth network based on neural window FC-CRFs architecture.
    """
    def __init__(self, version=None, inv_depth=False, pretrained=None,
                 frozen_stages=-1, min_depth=0.1, max_depth=100.0, **kwargs):
        super().__init__()

        self.inv_depth = inv_depth
        self.with_auxiliary_head = False
        self.with_neck = False

        norm_cfg = dict(type='BN', requires_grad=True)

        window_size = int(version[-2:])

        if version[:-2] == 'base':
            embed_dim = 128
            depths = [2, 2, 18, 2]
            num_heads = [4, 8, 16, 32]
            in_channels = [128, 256, 512, 1024]
            self.update = BasicUpdateBlockDepth(hidden_dim=128, context_dim=128)
        elif version[:-2] == 'large':
            embed_dim = 192
            depths = [2, 2, 18, 2]
            num_heads = [6, 12, 24, 48]
            in_channels = [192, 384, 768, 1536]
            self.update = BasicUpdateBlockDepth(hidden_dim=128, context_dim=192)
        elif version[:-2] == 'tiny':
            embed_dim = 96
            depths = [2, 2, 6, 2]
            num_heads = [3, 6, 12, 24]
            in_channels = [96, 192, 384, 768]
            self.update = BasicUpdateBlockDepth(hidden_dim=128, context_dim=96)

        backbone_cfg = dict(
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads,
            window_size=window_size,
            ape=False,
            drop_path_rate=0.3,
            patch_norm=True,
            use_checkpoint=False,
            frozen_stages=frozen_stages
        )

        embed_dim = 512
        decoder_cfg = dict(
            in_channels=in_channels,
            in_index=[0, 1, 2, 3],
            pool_scales=(1, 2, 3, 6),
            channels=embed_dim,
            dropout_ratio=0.0,
            num_classes=32,
            norm_cfg=norm_cfg,
            align_corners=False
        )

        self.backbone = SwinTransformer(**backbone_cfg)
        v_dim = decoder_cfg['num_classes']*4
        win = 7
        crf_dims = [128, 256, 512, 1024]
        v_dims = [64, 128, 256, embed_dim]
        self.crf3 = NewCRF(input_dim=in_channels[3], embed_dim=crf_dims[3], window_size=win, v_dim=v_dims[3], num_heads=32)
        self.crf2 = NewCRF(input_dim=in_channels[2], embed_dim=crf_dims[2], window_size=win, v_dim=v_dims[2], num_heads=16)
        self.crf1 = NewCRF(input_dim=in_channels[1], embed_dim=crf_dims[1], window_size=win, v_dim=v_dims[1], num_heads=8)

        self.decoder = PSP(**decoder_cfg)
        self.disp_head1 = DispHead(input_dim=crf_dims[0])

        self.up_mode = 'bilinear'
        if self.up_mode == 'mask':
            self.mask_head = nn.Sequential(
                nn.Conv2d(v_dims[0], 64, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 16*9, 1, padding=0))

        self.min_depth = min_depth
        self.max_depth = max_depth
        self.depth_num = 16
        self.hidden_dim = 128
        self.project = Projection(v_dims[0], self.hidden_dim)

        self.init_weights(pretrained=pretrained)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone and heads.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        print(f'== Load encoder backbone from: {pretrained}')
        self.backbone.init_weights(pretrained=pretrained)
        self.decoder.init_weights()
        if self.with_auxiliary_head:
            if isinstance(self.auxiliary_head, nn.ModuleList):
                for aux_head in self.auxiliary_head:
                    aux_head.init_weights()
            else:
                self.auxiliary_head.init_weights()

    def upsample_mask(self, disp, mask):
        """ Upsample disp [H/4, W/4, 1] -> [H, W, 1] using convex combination """
        N, C, H, W = disp.shape
        mask = mask.view(N, 1, 9, 4, 4, H, W)
        mask = torch.softmax(mask, dim=2)

        up_disp = F.unfold(disp, kernel_size=3, padding=1)
        up_disp = up_disp.view(N, C, 9, 1, 1, H, W)

        up_disp = torch.sum(mask * up_disp, dim=2)
        up_disp = up_disp.permute(0, 1, 4, 2, 5, 3)
        return up_disp.reshape(N, C, 4*H, 4*W)

    def forward(self, imgs, epoch=1, step=100):

        feats = self.backbone(imgs)
        ppm_out = self.decoder(feats)

        e3 = self.crf3(feats[3], ppm_out)
        e3 = nn.PixelShuffle(2)(e3)
        e2 = self.crf2(feats[2], e3)
        e2 = nn.PixelShuffle(2)(e2)
        e1 = self.crf1(feats[1], e2)
        e1 = nn.PixelShuffle(2)(e1)

        # iterative bins
        if epoch == 0 and step < 80:
            max_tree_depth = 3
        else:
            max_tree_depth = 6

        if self.up_mode == 'mask':
            mask = self.mask_head(e1)

        b, c, h, w = e1.shape
        device = e1.device

        depth = torch.zeros([b, 1, h, w]).to(device)
        context = feats[0]
        gru_hidden = torch.tanh(self.project(e1))
        pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list = self.update(depth, context, gru_hidden, max_tree_depth, self.depth_num, self.min_depth, self.max_depth)

        if self.up_mode == 'mask':
            for i in range(len(pred_depths_r_list)):
                pred_depths_r_list[i] = self.upsample_mask(pred_depths_r_list[i], mask)
            for i in range(len(pred_depths_c_list)):
                pred_depths_c_list[i] = self.upsample_mask(pred_depths_c_list[i], mask.detach())
            for i in range(len(uncertainty_maps_list)):
                uncertainty_maps_list[i] = self.upsample_mask(uncertainty_maps_list[i], mask.detach())
        else:
            for i in range(len(pred_depths_r_list)):
                pred_depths_r_list[i] = upsample(pred_depths_r_list[i], scale_factor=4)
            for i in range(len(pred_depths_c_list)):
                pred_depths_c_list[i] = upsample(pred_depths_c_list[i], scale_factor=4)
            for i in range(len(uncertainty_maps_list)):
                uncertainty_maps_list[i] = upsample(uncertainty_maps_list[i], scale_factor=4)

        return pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list

class DispHead(nn.Module):
    def __init__(self, input_dim=100):
        super(DispHead, self).__init__()
        # self.norm1 = nn.BatchNorm2d(input_dim)
        self.conv1 = nn.Conv2d(input_dim, 1, 3, padding=1)
        # self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, scale):
        # x = self.relu(self.norm1(x))
        x = self.sigmoid(self.conv1(x))
        if scale > 1:
            x = upsample(x, scale_factor=scale)
        return x

class BasicUpdateBlockDepth(nn.Module):
    def __init__(self, hidden_dim=128, context_dim=192):
        super(BasicUpdateBlockDepth, self).__init__()

        self.encoder = ProjectionInputDepth(hidden_dim=hidden_dim, out_chs=hidden_dim * 2)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=self.encoder.out_chs+context_dim)
        self.p_head = PHead(hidden_dim, hidden_dim)

    def forward(self, depth, context, gru_hidden, seq_len, depth_num, min_depth, max_depth):

        pred_depths_r_list = []
        pred_depths_c_list = []
        uncertainty_maps_list = []

        b, _, h, w = depth.shape
        depth_range = max_depth - min_depth
        interval = depth_range / depth_num
        interval = interval * torch.ones_like(depth)
        interval = interval.repeat(1, depth_num, 1, 1)
        interval = torch.cat([torch.ones_like(depth) * min_depth, interval], 1)

        bin_edges = torch.cumsum(interval, 1)
        current_depths = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
        index_iter = 0

        for i in range(seq_len):
            input_features = self.encoder(current_depths.detach())
            input_c = torch.cat([input_features, context], dim=1)

            gru_hidden = self.gru(gru_hidden, input_c)
            pred_prob = self.p_head(gru_hidden)

            depth_r = (pred_prob * current_depths.detach()).sum(1, keepdim=True)
            pred_depths_r_list.append(depth_r)

            uncertainty_map = torch.sqrt((pred_prob * ((current_depths.detach() - depth_r.repeat(1, depth_num, 1, 1))**2)).sum(1, keepdim=True))
            uncertainty_maps_list.append(uncertainty_map)

            index_iter = index_iter + 1

            pred_label = get_label(torch.squeeze(depth_r, 1), bin_edges, depth_num).unsqueeze(1)
            depth_c = torch.gather(current_depths.detach(), 1, pred_label.detach())
            pred_depths_c_list.append(depth_c)

            label_target_bin_left = pred_label
            target_bin_left = torch.gather(bin_edges, 1, label_target_bin_left)
            label_target_bin_right = (pred_label.float() + 1).long()
            target_bin_right = torch.gather(bin_edges, 1, label_target_bin_right)

            bin_edges, current_depths = update_sample(bin_edges, target_bin_left, target_bin_right, depth_r.detach(), pred_label.detach(), depth_num, min_depth, max_depth, uncertainty_map)

        return pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list

class PHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=128):
        super(PHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 16, 3, padding=1)

    def forward(self, x):
        out = torch.softmax(self.conv2(F.relu(self.conv1(x))), 1)
        return out

class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=128+192):
        super(SepConvGRU, self).__init__()

        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))

    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))

        h = (1-z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
        h = (1-z) * h + z * q

        return h

class ProjectionInputDepth(nn.Module):
    def __init__(self, hidden_dim, out_chs):
        super().__init__()
        self.out_chs = out_chs
        self.convd1 = nn.Conv2d(16, hidden_dim, 7, padding=3)
        self.convd2 = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)
        self.convd3 = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)
        self.convd4 = nn.Conv2d(hidden_dim, out_chs, 3, padding=1)

    def forward(self, depth):
        d = F.relu(self.convd1(depth))
        d = F.relu(self.convd2(d))
        d = F.relu(self.convd3(d))
        d = F.relu(self.convd4(d))

        return d

class Projection(nn.Module):
    def __init__(self, in_chs, out_chs):
        super().__init__()
        self.conv = nn.Conv2d(in_chs, out_chs, 3, padding=1)

    def forward(self, x):
        out = self.conv(x)

        return out

def upsample(x, scale_factor=2, mode="bilinear", align_corners=False):
    """Upsample input tensor by a factor of 2
    """
    return F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=align_corners)

def upsample1(x, scale_factor=2, mode="bilinear"):
    """Upsample input tensor by a factor of 2
    """
    return F.interpolate(x, scale_factor=scale_factor, mode=mode)
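A small shape-check sketch for NewCRFDepth, assuming it is run from inside iebins/ (the same import style the scripts in this commit use); the sizes follow the 'large07' and max_depth=10 defaults of inference_single_image.py and a 480x640 dummy input.

import torch

from networks.NewCRFDepth import NewCRFDepth

model = NewCRFDepth(version='large07', inv_depth=False, max_depth=10, pretrained=None).eval()

x = torch.rand(1, 3, 480, 640)  # dummy RGB batch
with torch.no_grad():
    # three lists, one entry per iterative-bins refinement step
    pred_r, pred_c, uncertainty = model(x)

print(len(pred_r), pred_r[-1].shape)  # last regressed depth map, upsampled back to the input size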
iebins/networks/__init__.py
ADDED
File without changes
iebins/networks/depth_update.py
ADDED
@@ -0,0 +1,39 @@
import torch
import torch.nn.functional as F
import copy

def update_sample(bin_edges, target_bin_left, target_bin_right, depth_r, pred_label, depth_num, min_depth, max_depth, uncertainty_range):

    with torch.no_grad():
        b, _, h, w = bin_edges.shape

        mode = 'direct'
        if mode == 'direct':
            depth_range = uncertainty_range
            depth_start_update = torch.clamp_min(depth_r - 0.5 * depth_range, min_depth)
        else:
            depth_range = uncertainty_range + (target_bin_right - target_bin_left).abs()
            depth_start_update = torch.clamp_min(target_bin_left - 0.5 * uncertainty_range, min_depth)

        interval = depth_range / depth_num
        interval = interval.repeat(1, depth_num, 1, 1)
        interval = torch.cat([torch.ones([b, 1, h, w], device=bin_edges.device) * depth_start_update, interval], 1)

        bin_edges = torch.cumsum(interval, 1).clamp(min_depth, max_depth)
        curr_depth = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])

    return bin_edges.detach(), curr_depth.detach()

def get_label(gt_depth_img, bin_edges, depth_num):

    with torch.no_grad():
        gt_label = torch.zeros(gt_depth_img.size(), dtype=torch.int64, device=gt_depth_img.device)
        for i in range(depth_num):
            bin_mask = torch.ge(gt_depth_img, bin_edges[:, i])
            bin_mask = torch.logical_and(bin_mask,
                                         torch.lt(gt_depth_img, bin_edges[:, i + 1]))
            gt_label[bin_mask] = i

    return gt_label
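A toy numeric walk-through of the elastic-bins update, mirroring how BasicUpdateBlockDepth calls these helpers; the predicted depth and uncertainty values below are made up for illustration, and the import assumes the interpreter is started from inside iebins/.

import torch

from networks.depth_update import get_label, update_sample

b, h, w = 1, 1, 1
min_depth, max_depth, depth_num = 0.001, 10.0, 16

# uniform initial bins over [min_depth, max_depth]
interval = (max_depth - min_depth) / depth_num * torch.ones(b, depth_num, h, w)
interval = torch.cat([torch.full((b, 1, h, w), min_depth), interval], dim=1)
bin_edges = torch.cumsum(interval, dim=1)                      # 17 bin edges
current_depths = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])  # 16 bin centres

depth_r = torch.full((b, 1, h, w), 3.2)      # pretend prediction: 3.2 m
uncertainty = torch.full((b, 1, h, w), 0.8)  # pretend spread: +/- 0.4 m
label = get_label(depth_r.squeeze(1), bin_edges, depth_num).unsqueeze(1)
left = torch.gather(bin_edges, 1, label)
right = torch.gather(bin_edges, 1, label + 1)

new_edges, new_centres = update_sample(bin_edges, left, right, depth_r, label,
                                       depth_num, min_depth, max_depth, uncertainty)
print(new_edges[0, :, 0, 0])  # 16 finer bins now covering roughly [2.8, 3.6] instead of [0, 10]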
iebins/networks/newcrf_layers.py
ADDED
@@ -0,0 +1,433 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, v_dim, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(v_dim, v_dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, v, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qk = self.qk(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k = qk[0], qk[1]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # assert self.dim % v.shape[-1] == 0, "self.dim % v.shape[-1] != 0"
        # repeat_num = self.dim // v.shape[-1]
        # v = v.view(B_, N, self.num_heads // repeat_num, -1).transpose(1, 2).repeat(1, repeat_num, 1, 1)

        assert self.dim == v.shape[-1], "self.dim != v.shape[-1]"
        v = v.view(B_, N, self.num_heads, -1).transpose(1, 2)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CRFBlock(nn.Module):
    """ CRF Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, v_dim, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.v_dim = v_dim
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, v_dim=v_dim,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(v_dim)
        mlp_hidden_dim = int(v_dim * mlp_ratio)
        self.mlp = Mlp(in_features=v_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, v, mask_matrix):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            shifted_v = torch.roll(v, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            shifted_v = v
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
        v_windows = window_partition(shifted_v, self.window_size)  # nW*B, window_size, window_size, C
        v_windows = v_windows.view(-1, self.window_size * self.window_size, v_windows.shape[-1])  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, v_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.v_dim)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, self.v_dim)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class BasicCRFLayer(nn.Module):
    """ A basic NeWCRFs layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 v_dim,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            CRFBlock(
                dim=dim,
                num_heads=num_heads,
                v_dim=v_dim,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, v, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, v, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


class NewCRF(nn.Module):
    def __init__(self,
                 input_dim=96,
                 embed_dim=96,
                 v_dim=64,
                 window_size=7,
                 num_heads=4,
                 depth=2,
                 patch_size=4,
                 in_chans=3,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True):
        super().__init__()

        self.embed_dim = embed_dim
        self.patch_norm = patch_norm

        if input_dim != embed_dim:
            self.proj_x = nn.Conv2d(input_dim, embed_dim, 3, padding=1)
        else:
            self.proj_x = None

        if v_dim != embed_dim:
            self.proj_v = nn.Conv2d(v_dim, embed_dim, 3, padding=1)
        elif embed_dim % v_dim == 0:
            self.proj_v = None

        # For now, v_dim need to be equal to embed_dim, because the output of window-attn is the input of shift-window-attn
        v_dim = embed_dim
        assert v_dim == embed_dim

        self.crf_layer = BasicCRFLayer(
            dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            v_dim=v_dim,
            window_size=window_size,
            mlp_ratio=4.,
            qkv_bias=True,
            qk_scale=None,
            drop=0.,
            attn_drop=0.,
            drop_path=0.,
            norm_layer=norm_layer,
            downsample=None,
            use_checkpoint=False)

        layer = norm_layer(embed_dim)
        layer_name = 'norm_crf'
        self.add_module(layer_name, layer)


    def forward(self, x, v):
        if self.proj_x is not None:
            x = self.proj_x(x)
        if self.proj_v is not None:
            v = self.proj_v(v)

        Wh, Ww = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)
        v = v.transpose(1, 2).transpose(2, 3)

        x_out, H, W, x, Wh, Ww = self.crf_layer(x, v, Wh, Ww)
        norm_layer = getattr(self, f'norm_crf')
        x_out = norm_layer(x_out)
        out = x_out.view(-1, H, W, self.embed_dim).permute(0, 3, 1, 2).contiguous()

        return out
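A quick interface sketch for NewCRF, reusing the sizes NewCRFDepth passes to its crf3 stage for the 'large07' encoder; the 15x20 resolution is what a 480x640 input reaches at stride 32, and the import assumes the interpreter is started from inside iebins/.

import torch

from networks.newcrf_layers import NewCRF

crf = NewCRF(input_dim=1536, embed_dim=1024, v_dim=512, window_size=7, num_heads=32)

feat = torch.rand(1, 1536, 15, 20)  # stage-4 backbone feature (query/key source)
v = torch.rand(1, 512, 15, 20)      # decoder output used as the attention value
out = crf(feat, v)
print(out.shape)  # torch.Size([1, 1024, 15, 20])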
iebins/networks/newcrf_utils.py
ADDED
@@ -0,0 +1,264 @@
import warnings
import os
import os.path as osp
import pkgutil
import warnings
from collections import OrderedDict
from importlib import import_module

import torch
import torchvision
import torch.nn as nn
from torch.utils import model_zoo
from torch.nn import functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch import distributed as dist

TORCH_VERSION = torch.__version__


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > output_h:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    if isinstance(size, torch.Size):
        size = tuple(int(x) for x in size)
    return F.interpolate(input, size, scale_factor, mode, align_corners)


def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def is_module_wrapper(module):
    module_wrappers = (DataParallel, DistributedDataParallel)
    return isinstance(module, module_wrappers)


def get_dist_info():
    if TORCH_VERSION < '1.0':
        initialized = dist._initialized
    else:
        if dist.is_available():
            initialized = dist.is_initialized()
        else:
            initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def load_state_dict(module, state_dict, strict=False, logger=None):
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger to log the error
            message. If not specified, print function will be used.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # use _load_from_state_dict to enable checkpoint version control
    def load(module, prefix=''):
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = {} if metadata is None else metadata.get(
            prefix[:-1], {})
        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                                     all_missing_keys, unexpected_keys,
                                     err_msg)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(module)
    load = None  # break load->load reference cycle

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if 'num_batches_tracked' not in key
    ]

    if unexpected_keys:
        err_msg.append('unexpected key in source '
                       f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')

    rank, _ = get_dist_info()
    if len(err_msg) > 0 and rank == 0:
        err_msg.insert(
            0, 'The model and loaded state dict do not match exactly\n')
        err_msg = '\n'.join(err_msg)
        if strict:
            raise RuntimeError(err_msg)
        elif logger is not None:
            logger.warning(err_msg)
        else:
            print(err_msg)


def load_url_dist(url, model_dir=None):
    """In distributed setting, this function only download checkpoint at local
    rank 0."""
    rank, world_size = get_dist_info()
    rank = int(os.environ.get('LOCAL_RANK', rank))
    if rank == 0:
        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    if world_size > 1:
        torch.distributed.barrier()
        if rank > 0:
            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
    return checkpoint


def get_torchvision_models():
    model_urls = dict()
    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
        if ispkg:
            continue
        _zoo = import_module(f'torchvision.models.{name}')
        if hasattr(_zoo, 'model_urls'):
            _urls = getattr(_zoo, 'model_urls')
            model_urls.update(_urls)
    return model_urls


def _load_checkpoint(filename, map_location=None):
    """Load checkpoint from somewhere (modelzoo, file, url).

    Args:
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str | None): Same as :func:`torch.load`. Default: None.

    Returns:
        dict | OrderedDict: The loaded checkpoint. It can be either an
            OrderedDict storing model weights or a dict containing other
            information, which depends on the checkpoint.
    """
    if filename.startswith('modelzoo://'):
        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
                      'use "torchvision://" instead')
        model_urls = get_torchvision_models()
        model_name = filename[11:]
        checkpoint = load_url_dist(model_urls[model_name])
    else:
        if not osp.isfile(filename):
            raise IOError(f'{filename} is not a checkpoint file')
        checkpoint = torch.load(filename, map_location=map_location)
    return checkpoint


def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load checkpoint from a file or URI.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Whether to allow different params for the model and
            checkpoint.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.
    """
    checkpoint = _load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')
    # get state_dict from checkpoint
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    # strip prefix of state_dict
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    # for MoBY, load model of online branch
    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

    # reshape absolute position embedding
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = model.absolute_pos_embed.size()
        if N1 != N2 or C1 != C2 or L != H*W:
            logger.warning("Error in loading absolute_pos_embed, pass")
        else:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)

    # interpolate position bias table if needed
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        table_current = model.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {table_key}, pass")
        else:
            if L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                table_pretrained_resized = F.interpolate(
                    table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                    size=(S2, S2), mode='bicubic')
                state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)

    # load state_dict
    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
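A minimal sketch of load_checkpoint on a toy module, only to show the checkpoint layouts it accepts (a dict with 'state_dict' or 'model' keys, with an optional 'module.' prefix on parameter names); the /tmp path is a placeholder and the import assumes the interpreter is started from inside iebins/.

import torch
import torch.nn as nn

from networks.newcrf_utils import load_checkpoint

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 3, padding=1))
torch.save({'state_dict': net.state_dict()}, '/tmp/toy.pth')  # placeholder location

ckpt = load_checkpoint(net, '/tmp/toy.pth', map_location='cpu', strict=False)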
iebins/networks/resize.py
ADDED
@@ -0,0 +1,51 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
import torch.nn.functional as F


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=False):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > output_h:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    return F.interpolate(input, size, scale_factor, mode, align_corners)


class Upsample(nn.Module):

    def __init__(self,
                 size=None,
                 scale_factor=None,
                 mode='nearest',
                 align_corners=None):
        super(Upsample, self).__init__()
        self.size = size
        if isinstance(scale_factor, tuple):
            self.scale_factor = tuple(float(factor) for factor in scale_factor)
        else:
            self.scale_factor = float(scale_factor) if scale_factor else None
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        if not self.size:
            size = [int(t * self.scale_factor) for t in x.shape[-2:]]
        else:
            size = self.size
        return resize(x, size, None, self.mode, self.align_corners)
iebins/networks/swin_transformer.py
ADDED
@@ -0,0 +1,620 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import numpy as np
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from .newcrf_utils import load_checkpoint


class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    """ Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = None
        self.W = None

    def forward(self, x, mask_matrix):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
            mask_matrix: Attention mask for cyclic shift.
        """
        B, L, C = x.shape
        H, W = self.H, self.W
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = mask_matrix
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchMerging(nn.Module):
    """ Patch Merging Layer

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """

        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        for blk in self.blocks:
            blk.H, blk.W = H, W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W)
            Wh, Ww = (H + 1) // 2, (W + 1) // 2
            return x, H, W, x_down, Wh, Ww
        else:
            return x, H, W, x, H, W


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, H, W = x.size()
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is not None:
            Wh, Ww = x.size(2), x.size(3)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)

        return x


class SwinTransformer(nn.Module):
    """ Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
          https://arxiv.org/pdf/2103.14030

    Args:
        pretrain_img_size (int): Input image size for training the pretrained model,
            used in absolute position embedding. Default 224.
        patch_size (int | tuple(int)): Patch size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage.
        num_heads (tuple[int]): Number of attention head of each stage.
        window_size (int): Window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 pretrain_img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 depths=[2, 2, 6, 2],
                 num_heads=[3, 6, 12, 24],
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 out_indices=(0, 1, 2, 3),
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()

        self.pretrain_img_size = pretrain_img_size
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            pretrain_img_size = to_2tuple(pretrain_img_size)
            patch_size = to_2tuple(patch_size)
            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]

            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
        self.num_features = num_features

        # add a norm layer for each output
        for i_layer in out_indices:
            layer = norm_layer(num_features[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

        self._freeze_stages()

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1 and self.ape:
            self.absolute_pos_embed.requires_grad = False

        if self.frozen_stages >= 2:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages - 1):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        if isinstance(pretrained, str):
            self.apply(_init_weights)
            # logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False)
        elif pretrained is None:
            self.apply(_init_weights)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        x = self.patch_embed(x)

        Wh, Ww = x.size(2), x.size(3)
        if self.ape:
            # interpolate the position embedding to the corresponding size
            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
        else:
            x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        outs = []
        for i in range(self.num_layers):
            layer = self.layers[i]
            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)

            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x_out)

                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode while keeping the frozen stages frozen."""
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
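A minimal usage sketch for the backbone above (not part of the repository; the configuration values are the class defaults, i.e. Swin-T, and the 480x640 input matches the NYU resolution used elsewhere in this commit):

import torch

# Randomly initialised Swin-T backbone (no pretrained checkpoint loaded).
backbone = SwinTransformer(embed_dim=96,
                           depths=[2, 2, 6, 2],
                           num_heads=[3, 6, 12, 24],
                           window_size=7,
                           out_indices=(0, 1, 2, 3))
backbone.init_weights(pretrained=None)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 480, 640))

# Four feature maps at strides 4/8/16/32 with 96/192/384/768 channels:
# (1, 96, 120, 160), (1, 192, 60, 80), (1, 384, 30, 40), (1, 768, 15, 20)
for f in feats:
    print(f.shape)

These multi-scale outputs are what the decode heads in the next file consume.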
iebins/networks/uper_crf_head.py
ADDED
@@ -0,0 +1,364 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv.cnn import ConvModule
from .newcrf_utils import resize, normal_init


class PPM(nn.ModuleList):
    """Pooling Pyramid Module used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
        align_corners (bool): align_corners argument of F.interpolate.
    """

    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
                 act_cfg, align_corners):
        super(PPM, self).__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        for pool_scale in pool_scales:
            # == if batch size = 1, BN is not supported, change to GN
            if pool_scale == 1:
                norm_cfg = dict(type='GN', requires_grad=True, num_groups=256)
            self.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_scale),
                    ConvModule(
                        self.in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=self.act_cfg)))

    def forward(self, x):
        """Forward function."""
        ppm_outs = []
        for ppm in self:
            ppm_out = ppm(x)
            upsampled_ppm_out = resize(
                ppm_out,
                size=x.size()[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs


class BaseDecodeHead(nn.Module):
    """Base class for BaseDecodeHead.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one selected feature map is allowed.
            Default: None.
        loss_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 ignore_index=255,
                 sampler=None,
                 align_corners=False):
        super(BaseDecodeHead, self).__init__()
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        # self.loss_decode = build_loss(loss_decode)
        self.ignore_index = ignore_index
        self.align_corners = align_corners
        # if sampler is not None:
        #     self.sampler = build_pixel_sampler(sampler, context=self)
        # else:
        #     self.sampler = None

        # self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        # self.conv1 = nn.Conv2d(channels, num_classes, 3, padding=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.
        When input_transform is not None, in_channels and in_index must be a
        list or tuple with the same length.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one selected feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        # normal_init(self.conv_seg, mean=0, std=0.01)
        # normal_init(self.conv1, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.
        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)


class UPerHead(BaseDecodeHead):
    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(UPerHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels:  # skip the top layer
            l_conv = ConvModule(
                in_channels,
                self.channels,
                1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=True)
            fpn_conv = ConvModule(
                self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                inplace=True)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, inputs):
        """Forward function."""

        inputs = self._transform_inputs(inputs)

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # laterals.append(self.psp_forward(inputs))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] += resize(
                laterals[i],
                size=prev_shape,
                mode='bilinear',
                align_corners=self.align_corners)

        # build outputs
        fpn_outs = [
            self.fpn_convs[i](laterals[i])
            for i in range(used_backbone_levels - 1)
        ]
        # append psp feature
        fpn_outs.append(laterals[-1])

        return fpn_outs[0]


class PSP(BaseDecodeHead):
    """Unified Perceptual Parsing for Scene Understanding.

    This head is the implementation of `UPerNet
    <https://arxiv.org/abs/1807.10221>`_.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module applied on the last feature. Default: (1, 2, 3, 6).
    """

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSP, self).__init__(
            input_transform='multiple_select', **kwargs)
        # PSP Module
        self.psp_modules = PPM(
            pool_scales,
            self.in_channels[-1],
            self.channels,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.bottleneck = ConvModule(
            self.in_channels[-1] + len(pool_scales) * self.channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def psp_forward(self, inputs):
        """Forward function of PSP module."""
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)

        return self.psp_forward(inputs)
iebins/sum_depth.py
ADDED
@@ -0,0 +1,22 @@
import torch
import torch.nn as nn
import numpy as np

class Sum_depth(nn.Module):
    def __init__(self):
        super(Sum_depth, self).__init__()
        self.sum_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        sum_k = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])

        sum_k = torch.from_numpy(sum_k).float().view(1, 1, 3, 3)
        self.sum_conv.weight = nn.Parameter(sum_k)

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        out = self.sum_conv(x)
        out = out.contiguous().view(-1, 1, x.size(2), x.size(3))

        return out
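For intuition, a standalone sketch (not part of the repo): the module above is a frozen 3x3 box-sum, so every output pixel is the sum of its 3x3 neighbourhood in the input.

import torch

box_sum = Sum_depth()
depth = torch.ones(1, 1, 4, 4)   # dummy single-channel depth map
out = box_sum(depth)

print(out[0, 0])
# Corners see 4 ones, edges 6, interior pixels 9:
# tensor([[4., 6., 6., 4.],
#         [6., 9., 9., 6.],
#         [6., 9., 9., 6.],
#         [4., 6., 6., 4.]])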
iebins/test.py
ADDED
@@ -0,0 +1,209 @@
from __future__ import absolute_import, division, print_function

import torch
import torch.nn as nn
from torch.autograd import Variable

import os, sys, errno
import argparse
import time
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import open3d as o3d

from utils import post_process_depth, D_to_cloud, flip_lr, inv_normalize

from networks.NewCRFDepth import NewCRFDepth


def convert_arg_line_to_args(arg_line):
    for arg in arg_line.split():
        if not arg.strip():
            continue
        yield arg


parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = convert_arg_line_to_args

parser.add_argument('--model_name', type=str, help='model name', default='iebins')
parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
parser.add_argument('--data_path', type=str, help='path to the data', required=True)
parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
parser.add_argument('--input_height', type=int, help='input height', default=480)
parser.add_argument('--input_width', type=int, help='input width', default=640)
parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='')
parser.add_argument('--dataset', type=str, help='dataset to train on', default='nyu')
parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
parser.add_argument('--pred_clouds', help='if set, predict point clouds', action='store_true')
parser.add_argument('--save_viz', help='if set, save visualization of the outputs', action='store_true')

if sys.argv.__len__() == 2:
    arg_filename_with_prefix = '@' + sys.argv[1]
    args = parser.parse_args([arg_filename_with_prefix])
else:
    args = parser.parse_args()

if args.dataset == 'kitti' or args.dataset == 'nyu':
    from dataloaders.dataloader import NewDataLoader

model_dir = os.path.dirname(args.checkpoint_path)
sys.path.append(model_dir)


def get_num_lines(file_path):
    f = open(file_path, 'r')
    lines = f.readlines()
    f.close()
    return len(lines)


def test(params):
    """Test function."""
    args.mode = 'test'
    dataloader = NewDataLoader(args, 'test')

    model = NewCRFDepth(version='large07', inv_depth=False, max_depth=args.max_depth)
    model = torch.nn.DataParallel(model)

    checkpoint = torch.load(args.checkpoint_path)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    model.cuda()

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("Total number of parameters: {}".format(num_params))

    num_test_samples = get_num_lines(args.filenames_file)

    with open(args.filenames_file) as f:
        lines = f.readlines()

    print('now testing {} files with {}'.format(num_test_samples, args.checkpoint_path))

    pred_depths = []
    pred_clouds = []
    start_time = time.time()
    with torch.no_grad():
        for _, sample in enumerate(tqdm(dataloader.data)):
            image = Variable(sample['image'].cuda())
            inv_K_p = Variable(sample['inv_K_p'].cuda())
            b, _, h, w = image.shape
            depth_to_cloud = D_to_cloud(b, h, w).cuda()

            # Predict
            pred_depths_r_list, _, _ = model(image)
            post_process = True
            if post_process:
                image_flipped = flip_lr(image)
                pred_depths_r_list_flipped, _, _ = model(image_flipped)
                pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])

            if args.pred_clouds:
                if args.dataset == 'nyu':
                    color = inv_normalize(image[0, :, :, :]).permute(1, 2, 0)[45:472, 43:608, :].reshape(-1, 3).cpu().numpy()
                    points = depth_to_cloud(pred_depth, inv_K_p).reshape(1, h, w, 3)[:, 45:472, 43:608, :].reshape(1, -1, 3)
                    points = points.cpu().numpy().squeeze()
                else:
                    color = inv_normalize(image[0, :, :, :]).permute(1, 2, 0).reshape(-1, 3).cpu().numpy()
                    points = depth_to_cloud(pred_depth, inv_K_p)
                    points = points.cpu().numpy().squeeze()
                pc = o3d.geometry.PointCloud()
                pc.points = o3d.utility.Vector3dVector(points)
                pc.colors = o3d.utility.Vector3dVector(color)

                pred_clouds.append(pc)

            pred_depth = pred_depth.cpu().numpy().squeeze()

            if args.do_kb_crop:
                height, width = 352, 1216
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
                pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
                pred_depth = pred_depth_uncropped

            pred_depths.append(pred_depth)

    elapsed_time = time.time() - start_time
    print('Elapsed time: %s' % str(elapsed_time))
    print('Done.')

    save_name = 'models/result_' + args.model_name

    print('Saving result pngs..')
    if not os.path.exists(save_name):
        try:
            os.mkdir(save_name)
            os.mkdir(save_name + '/raw')
            os.mkdir(save_name + '/cmap')
            os.mkdir(save_name + '/rgb')
            os.mkdir(save_name + '/gt')
            os.mkdir(save_name + '/cloud')
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    for s in tqdm(range(num_test_samples)):
        if args.dataset == 'kitti':
            date_drive = lines[s].split('/')[1]
            filename_pred_png = save_name + '/raw/' + date_drive + '_' + lines[s].split()[0].split('/')[-1].replace(
                '.jpg', '.png')
            filename_pred_ply = save_name + '/cloud/' + date_drive + '_' + lines[s].split()[0].split('/')[-1][:-4] + '_' + 'iebins' + '.ply'
            filename_cmap_png = save_name + '/cmap/' + date_drive + '_' + lines[s].split()[0].split('/')[
                -1].replace('.jpg', '.png')
            filename_image_png = save_name + '/rgb/' + date_drive + '_' + lines[s].split()[0].split('/')[-1]
        elif args.dataset == 'kittipred':
            filename_pred_png = save_name + '/raw/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
            filename_cmap_png = save_name + '/cmap/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
            filename_image_png = save_name + '/rgb/' + lines[s].split()[0].split('/')[-1]
        else:
            scene_name = lines[s].split()[0].split('/')[0]
            filename_pred_png = save_name + '/raw/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace(
                '.jpg', '.png')
            filename_pred_ply = save_name + '/cloud/' + scene_name + '_' + lines[s].split()[0].split('/')[1][:-4] + '_' + 'iebins' + '.ply'
            filename_cmap_png = save_name + '/cmap/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1].replace(
                '.jpg', '.png')
            filename_gt_png = save_name + '/gt/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1].replace(
                '.jpg', '_gt.png')
            filename_image_png = save_name + '/rgb/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1]

        rgb_path = os.path.join(args.data_path, './' + lines[s].split()[0])
        image = cv2.imread(rgb_path)
        if args.dataset == 'nyu':
            gt_path = os.path.join(args.data_path, './' + lines[s].split()[1])
            gt = cv2.imread(gt_path, -1).astype(np.float32) / 1000.0  # Visualization purpose only
            gt[gt == 0] = np.amax(gt)

        pred_depth = pred_depths[s]

        if args.dataset == 'kitti' or args.dataset == 'kittipred':
            pred_depth_scaled = pred_depth * 256.0
        else:
            pred_depth_scaled = pred_depth * 1000.0

        pred_depth_scaled = pred_depth_scaled.astype(np.uint16)
        cv2.imwrite(filename_pred_png, pred_depth_scaled, [cv2.IMWRITE_PNG_COMPRESSION, 0])

        if args.save_viz:
            cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :])
            if args.dataset == 'nyu':
                plt.imsave(filename_gt_png, (10 - gt) / 10, cmap='jet')
                pred_depth_cropped = pred_depth[10:-1 - 9, 10:-1 - 9]
                plt.imsave(filename_cmap_png, (10 - pred_depth) / 10, cmap='jet')
            else:
                plt.imsave(filename_cmap_png, np.log10(pred_depth), cmap='magma')

        if args.pred_clouds:
            pred_cloud = pred_clouds[s]
            o3d.io.write_point_cloud(filename_pred_ply, pred_cloud)

    return


if __name__ == '__main__':
    test(args)
iebins/train.py
ADDED
@@ -0,0 +1,499 @@
import torch
import torch.nn as nn
import torch.nn.utils as utils
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp

import os, sys, time
from telnetlib import IP
import argparse
import numpy as np
from tqdm import tqdm

from tensorboardX import SummaryWriter

from utils import post_process_depth, flip_lr, silog_loss, compute_errors, eval_metrics, entropy_loss, colormap, \
    block_print, enable_print, normalize_result, inv_normalize, convert_arg_line_to_args, colormap_magma
from networks.NewCRFDepth import NewCRFDepth
from networks.depth_update import *
from datetime import datetime
from sum_depth import Sum_depth


parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = convert_arg_line_to_args

parser.add_argument('--mode', type=str, help='train or test', default='train')
parser.add_argument('--model_name', type=str, help='model name', default='iebins')
parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
parser.add_argument('--pretrain', type=str, help='path of pretrained encoder', default=None)

# Dataset
parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
parser.add_argument('--data_path', type=str, help='path to the data', required=True)
parser.add_argument('--gt_path', type=str, help='path to the groundtruth data', required=True)
parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
parser.add_argument('--input_height', type=int, help='input height', default=480)
parser.add_argument('--input_width', type=int, help='input width', default=640)
parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
parser.add_argument('--min_depth', type=float, help='minimum depth in estimation', default=0.1)

# Log and save
parser.add_argument('--log_directory', type=str, help='directory to save checkpoints and summaries', default='')
parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
parser.add_argument('--log_freq', type=int, help='Logging frequency in global steps', default=100)
parser.add_argument('--save_freq', type=int, help='Checkpoint saving frequency in global steps', default=5000)

# Training
parser.add_argument('--weight_decay', type=float, help='weight decay factor for optimization', default=1e-2)
parser.add_argument('--retrain', help='if used with checkpoint_path, will restart training from step zero', action='store_true')
parser.add_argument('--adam_eps', type=float, help='epsilon in Adam optimizer', default=1e-6)
parser.add_argument('--batch_size', type=int, help='batch size', default=4)
parser.add_argument('--num_epochs', type=int, help='number of epochs', default=50)
parser.add_argument('--learning_rate', type=float, help='initial learning rate', default=1e-4)
parser.add_argument('--end_learning_rate', type=float, help='end learning rate', default=-1)
parser.add_argument('--variance_focus', type=float, help='lambda in paper: [0, 1], higher values focus more on minimizing the variance of error', default=0.85)

# Preprocessing
parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')

# Multi-gpu training
parser.add_argument('--num_threads', type=int, help='number of threads to use for data loading', default=1)
parser.add_argument('--world_size', type=int, help='number of nodes for distributed training', default=1)
parser.add_argument('--rank', type=int, help='node rank for distributed training', default=0)
parser.add_argument('--dist_url', type=str, help='url used to set up distributed training', default='tcp://127.0.0.1:1234')
parser.add_argument('--dist_backend', type=str, help='distributed backend', default='nccl')
parser.add_argument('--gpu', type=int, help='GPU id to use.', default=None)
parser.add_argument('--multiprocessing_distributed', help='Use multi-processing distributed training to launch '
                                                          'N processes per node, which has N GPUs. This is the '
                                                          'fastest way to use PyTorch for either single node or '
                                                          'multi node data parallel training', action='store_true',)
# Online eval
parser.add_argument('--do_online_eval', help='if set, perform online eval in every eval_freq steps', action='store_true')
parser.add_argument('--data_path_eval', type=str, help='path to the data for online evaluation', required=False)
parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for online evaluation', required=False)
parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for online evaluation', required=False)
parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
parser.add_argument('--eval_freq', type=int, help='Online evaluation frequency in global steps', default=500)
parser.add_argument('--eval_summary_directory', type=str, help='output directory for eval summary,'
                                                               'if empty outputs to checkpoint folder', default='')

if sys.argv.__len__() == 2:
    arg_filename_with_prefix = '@' + sys.argv[1]
    args = parser.parse_args([arg_filename_with_prefix])
else:
    args = parser.parse_args()

if args.dataset == 'kitti' or args.dataset == 'nyu':
    from dataloaders.dataloader import NewDataLoader


def online_eval(model, dataloader_eval, gpu, epoch, ngpus, group, post_process=False):
    eval_measures = torch.zeros(10).cuda(device=gpu)
    for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
        with torch.no_grad():
            image = torch.autograd.Variable(eval_sample_batched['image'].cuda(gpu, non_blocking=True))
            gt_depth = eval_sample_batched['depth']
            has_valid_depth = eval_sample_batched['has_valid_depth']
            if not has_valid_depth:
                # print('Invalid depth. continue.')
                continue

            pred_depths_r_list, _, _ = model(image)
            if post_process:
                image_flipped = flip_lr(image)
                pred_depths_r_list_flipped, _, _ = model(image_flipped)
                pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])

            pred_depth = pred_depth.cpu().numpy().squeeze()
            gt_depth = gt_depth.cpu().numpy().squeeze()

        if args.do_kb_crop:
            height, width = gt_depth.shape
            top_margin = int(height - 352)
            left_margin = int((width - 1216) / 2)
            pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
            pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
            pred_depth = pred_depth_uncropped

        pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
        pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
        pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
        pred_depth[np.isnan(pred_depth)] = args.min_depth_eval

        valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)

        if args.garg_crop or args.eigen_crop:
            gt_height, gt_width = gt_depth.shape
            eval_mask = np.zeros(valid_mask.shape)

            if args.garg_crop:
                eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
87 |
+
|
88 |
+
if sys.argv.__len__() == 2:
|
89 |
+
arg_filename_with_prefix = '@' + sys.argv[1]
|
90 |
+
args = parser.parse_args([arg_filename_with_prefix])
|
91 |
+
else:
|
92 |
+
args = parser.parse_args()
|
93 |
+
|
94 |
+
if args.dataset == 'kitti' or args.dataset == 'nyu':
|
95 |
+
from dataloaders.dataloader import NewDataLoader
|
96 |
+
|
97 |
+
|
98 |
+
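The two branches above mean the script is normally launched with a single `@`-prefixed arguments file rather than a long command line. The parser construction itself sits above this excerpt, so the snippet below is only a minimal, hypothetical sketch of how that convention is usually wired up (via argparse's fromfile_prefix_chars together with the convert_arg_line_to_args helper from iebins/utils.py); the option and file names are illustrative, not taken from the repository.

# Minimal sketch (assumptions noted above); parser name and example file are hypothetical.
import argparse
import sys

def convert_arg_line_to_args(arg_line):
    # one whitespace-separated token per argument, so an args file can hold
    # "--batch_size 4" on a single line
    for arg in arg_line.split():
        if not arg.strip():
            continue
        yield arg

parser = argparse.ArgumentParser(description='IEBins training (sketch)', fromfile_prefix_chars='@')
parser.convert_arg_line_to_args = convert_arg_line_to_args
parser.add_argument('--batch_size', type=int, default=4)

if len(sys.argv) == 2:
    args = parser.parse_args(['@' + sys.argv[1]])   # e.g. python train.py configs/args_train.txt
else:
    args = parser.parse_args()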
def online_eval(model, dataloader_eval, gpu, epoch, ngpus, group, post_process=False):
    eval_measures = torch.zeros(10).cuda(device=gpu)
    for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
        with torch.no_grad():
            image = torch.autograd.Variable(eval_sample_batched['image'].cuda(gpu, non_blocking=True))
            gt_depth = eval_sample_batched['depth']
            has_valid_depth = eval_sample_batched['has_valid_depth']
            if not has_valid_depth:
                # print('Invalid depth. continue.')
                continue

            pred_depths_r_list, _, _ = model(image)
            if post_process:
                image_flipped = flip_lr(image)
                pred_depths_r_list_flipped, _, _ = model(image_flipped)
                pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])

        pred_depth = pred_depth.cpu().numpy().squeeze()
        gt_depth = gt_depth.cpu().numpy().squeeze()

        if args.do_kb_crop:
            height, width = gt_depth.shape
            top_margin = int(height - 352)
            left_margin = int((width - 1216) / 2)
            pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
            pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
            pred_depth = pred_depth_uncropped

        pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
        pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
        pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
        pred_depth[np.isnan(pred_depth)] = args.min_depth_eval

        valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)

        if args.garg_crop or args.eigen_crop:
            gt_height, gt_width = gt_depth.shape
            eval_mask = np.zeros(valid_mask.shape)

            if args.garg_crop:
                eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1

            elif args.eigen_crop:
                if args.dataset == 'kitti':
                    eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
                elif args.dataset == 'nyu':
                    eval_mask[45:471, 41:601] = 1

            valid_mask = np.logical_and(valid_mask, eval_mask)

        measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])

        eval_measures[:9] += torch.tensor(measures).cuda(device=gpu)
        eval_measures[9] += 1

    if args.multiprocessing_distributed:
        # group = dist.new_group([i for i in range(ngpus)])
        dist.all_reduce(tensor=eval_measures, op=dist.ReduceOp.SUM, group=group)

    if not args.multiprocessing_distributed or gpu == 0:
        eval_measures_cpu = eval_measures.cpu()
        cnt = eval_measures_cpu[9].item()
        eval_measures_cpu /= cnt
        print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
        print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
                                                                                     'sq_rel', 'log_rms', 'd1', 'd2', 'd3'))
        for i in range(8):
            print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
        print('{:7.4f}'.format(eval_measures_cpu[8]))
        return eval_measures_cpu

    return None

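For reference, the Garg (ECCV16) and Eigen (NIPS14) crops in online_eval above only score pixels inside fixed, image-relative windows. A small standalone sketch of the Garg crop, using a hypothetical KITTI-like resolution and random ground truth, looks like this:

# Standalone sketch of the Garg evaluation crop used above; the 375x1242
# resolution and the random depths are illustrative only.
import numpy as np

gt_height, gt_width = 375, 1242
eval_mask = np.zeros((gt_height, gt_width), dtype=bool)
eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
          int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = True

# only pixels that are both inside the crop and within the depth range are scored
gt_depth = np.random.uniform(1e-3, 80.0, size=(gt_height, gt_width))
valid_mask = np.logical_and(gt_depth > 1e-3, gt_depth < 80.0)
valid_mask = np.logical_and(valid_mask, eval_mask)
print(valid_mask.sum(), 'of', valid_mask.size, 'pixels evaluated')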
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("== Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)

    # model
    model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=args.pretrain)
    model.train()

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    print("== Total number of parameters: {}".format(num_params))

    num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
    print("== Total number of learning parameters: {}".format(num_params_update))

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.cuda()

    if args.distributed:
        print("== Model Initialized on GPU: {}".format(args.gpu))
    else:
        print("== Model Initialized")

    global_step = 0
    best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
    best_eval_measures_higher_better = torch.zeros(3).cpu()
    best_eval_steps = np.zeros(9, dtype=np.int32)

    # Training parameters
    optimizer = torch.optim.Adam([{'params': model.module.parameters()}],
                                 lr=args.learning_rate)

    model_just_loaded = False
    if args.checkpoint_path != '':
        if os.path.isfile(args.checkpoint_path):
            print("== Loading checkpoint '{}'".format(args.checkpoint_path))
            if args.gpu is None:
                checkpoint = torch.load(args.checkpoint_path)
            else:
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.checkpoint_path, map_location=loc)
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if not args.retrain:
                try:
                    global_step = checkpoint['global_step']
                    best_eval_measures_higher_better = checkpoint['best_eval_measures_higher_better'].cpu()
                    best_eval_measures_lower_better = checkpoint['best_eval_measures_lower_better'].cpu()
                    best_eval_steps = checkpoint['best_eval_steps']
                except KeyError:
                    print("Could not load values for online evaluation")

            print("== Loaded checkpoint '{}' (global_step {})".format(args.checkpoint_path, checkpoint['global_step']))
        else:
            print("== No checkpoint found at '{}'".format(args.checkpoint_path))
        model_just_loaded = True
        del checkpoint

    cudnn.benchmark = True

    dataloader = NewDataLoader(args, 'train')
    dataloader_eval = NewDataLoader(args, 'online_eval')

    # ===== Evaluation before training ======
    # model.eval()
    # with torch.no_grad():
    #     eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node, post_process=True)

    # Logging
    if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
        writer = SummaryWriter(args.log_directory + '/' + args.model_name + '/summaries', flush_secs=30)
        if args.do_online_eval:
            if args.eval_summary_directory != '':
                eval_summary_path = os.path.join(args.eval_summary_directory, args.model_name)
            else:
                eval_summary_path = os.path.join(args.log_directory, args.model_name, 'eval')
            eval_summary_writer = SummaryWriter(eval_summary_path, flush_secs=30)

    silog_criterion = silog_loss(variance_focus=args.variance_focus)
    sum_localdepth = Sum_depth().cuda(args.gpu)

    start_time = time.time()
    duration = 0

    num_log_images = args.batch_size
    end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate

    var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
    var_cnt = len(var_sum)
    var_sum = np.sum(var_sum)

    print("== Initial variables' sum: {:.3f}, avg: {:.3f}".format(var_sum, var_sum / var_cnt))

    steps_per_epoch = len(dataloader.data)
    num_total_steps = args.num_epochs * steps_per_epoch
    epoch = global_step // steps_per_epoch

    group = dist.new_group([i for i in range(ngpus_per_node)])
    while epoch < args.num_epochs:
        if args.distributed:
            dataloader.train_sampler.set_epoch(epoch)

        for step, sample_batched in enumerate(dataloader.data):
            optimizer.zero_grad()
            before_op_time = time.time()
            si_loss = 0

            image = torch.autograd.Variable(sample_batched['image'].cuda(args.gpu, non_blocking=True))
            depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(args.gpu, non_blocking=True))

            pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list = model(image, epoch, step)

            if args.dataset == 'nyu':
                mask = depth_gt > 0.1
            else:
                mask = depth_gt > 1.0

            max_tree_depth = len(pred_depths_r_list)
            for curr_tree_depth in range(max_tree_depth):
                si_loss += silog_criterion.forward(pred_depths_r_list[curr_tree_depth], depth_gt, mask.to(torch.bool))

            loss = si_loss

            loss.backward()
            for param_group in optimizer.param_groups:
                current_lr = (args.learning_rate - end_learning_rate) * (1 - global_step / num_total_steps) ** 0.9 + end_learning_rate
                param_group['lr'] = current_lr

            optimizer.step()

            if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
                print('[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'.format(epoch, step, steps_per_epoch, global_step, current_lr, loss))
                # if np.isnan(loss.cpu().item()):
                #     print('NaN in loss occurred. Aborting training.')
                #     return -1

            duration += time.time() - before_op_time
            if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
                var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
                var_cnt = len(var_sum)
                var_sum = np.sum(var_sum)
                examples_per_sec = args.batch_size / duration * args.log_freq
                duration = 0
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / global_step - 1.0) * time_sofar
                if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
                    print("{}".format(args.model_name))
                    print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                    print(print_string.format(args.gpu, examples_per_sec, loss, var_sum.item(), var_sum.item() / var_cnt, time_sofar, training_time_left))

                if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                            and args.rank % ngpus_per_node == 0):
                    writer.add_scalar('silog_loss', si_loss, global_step)
                    # writer.add_scalar('var_loss', var_loss, global_step)
                    writer.add_scalar('learning_rate', current_lr, global_step)
                    writer.add_scalar('var average', var_sum.item() / var_cnt, global_step)
                    depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e-3, depth_gt)
                    for i in range(num_log_images):
                        if args.dataset == 'nyu':
                            writer.add_image('depth_gt/image/{}'.format(i), colormap(depth_gt[i, :, :, :].data), global_step)
                            writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step)
                            writer.add_image('depth_r_est0/image/{}'.format(i), colormap(pred_depths_r_list[0][i, :, :, :].data), global_step)
                            writer.add_image('depth_r_est1/image/{}'.format(i), colormap(pred_depths_r_list[1][i, :, :, :].data), global_step)
                            writer.add_image('depth_r_est2/image/{}'.format(i), colormap(pred_depths_r_list[2][i, :, :, :].data), global_step)
                            writer.add_image('depth_r_est3/image/{}'.format(i), colormap(pred_depths_r_list[3][i, :, :, :].data), global_step)
                            writer.add_image('depth_r_est4/image/{}'.format(i), colormap(pred_depths_r_list[4][i, :, :, :].data), global_step)
                            writer.add_image('depth_r_est5/image/{}'.format(i), colormap(pred_depths_r_list[5][i, :, :, :].data), global_step)
                            writer.add_image('depth_c_est0/image/{}'.format(i), colormap(pred_depths_c_list[0][i, :, :, :].data), global_step)
                            writer.add_image('depth_c_est1/image/{}'.format(i), colormap(pred_depths_c_list[1][i, :, :, :].data), global_step)
                            writer.add_image('depth_c_est2/image/{}'.format(i), colormap(pred_depths_c_list[2][i, :, :, :].data), global_step)
                            writer.add_image('depth_c_est3/image/{}'.format(i), colormap(pred_depths_c_list[3][i, :, :, :].data), global_step)
                            writer.add_image('depth_c_est4/image/{}'.format(i), colormap(pred_depths_c_list[4][i, :, :, :].data), global_step)
                            writer.add_image('depth_c_est5/image/{}'.format(i), colormap(pred_depths_c_list[5][i, :, :, :].data), global_step)
                        else:
                            writer.add_image('depth_gt/image/{}'.format(i), colormap_magma(torch.log10(depth_gt[i, :, :, :].data)), global_step)
                            writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step)
                            writer.add_image('depth_r_est0/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[0][i, :, :, :].data)), global_step)
                            writer.add_image('depth_r_est1/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[1][i, :, :, :].data)), global_step)
                            writer.add_image('depth_r_est2/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[2][i, :, :, :].data)), global_step)
                            writer.add_image('depth_r_est3/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[3][i, :, :, :].data)), global_step)
                            writer.add_image('depth_r_est4/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[4][i, :, :, :].data)), global_step)
                            writer.add_image('depth_r_est5/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[5][i, :, :, :].data)), global_step)
                            writer.add_image('depth_c_est0/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[0][i, :, :, :].data)), global_step)
                            writer.add_image('depth_c_est1/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[1][i, :, :, :].data)), global_step)
                            writer.add_image('depth_c_est2/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[2][i, :, :, :].data)), global_step)
                            writer.add_image('depth_c_est3/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[3][i, :, :, :].data)), global_step)
                            writer.add_image('depth_c_est4/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[4][i, :, :, :].data)), global_step)
                            writer.add_image('depth_c_est5/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[5][i, :, :, :].data)), global_step)

                        writer.add_image('uncer_est0/image/{}'.format(i), colormap(uncertainty_maps_list[0][i, :, :, :].data), global_step)
                        writer.add_image('uncer_est1/image/{}'.format(i), colormap(uncertainty_maps_list[1][i, :, :, :].data), global_step)
                        writer.add_image('uncer_est2/image/{}'.format(i), colormap(uncertainty_maps_list[2][i, :, :, :].data), global_step)
                        writer.add_image('uncer_est3/image/{}'.format(i), colormap(uncertainty_maps_list[3][i, :, :, :].data), global_step)
                        writer.add_image('uncer_est4/image/{}'.format(i), colormap(uncertainty_maps_list[4][i, :, :, :].data), global_step)
                        writer.add_image('uncer_est5/image/{}'.format(i), colormap(uncertainty_maps_list[5][i, :, :, :].data), global_step)

            if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
                time.sleep(0.1)
                model.eval()
                with torch.no_grad():
                    eval_measures = online_eval(model, dataloader_eval, gpu, epoch, ngpus_per_node, group, post_process=True)
                if eval_measures is not None:
                    exp_name = '%s' % (datetime.now().strftime('%m%d'))
                    log_txt = os.path.join(args.log_directory + '/' + args.model_name, exp_name + '_logs.txt')
                    with open(log_txt, 'a') as txtfile:
                        txtfile.write(">>>>>>>>>>>>>>>>>>>>>>>>>Step:%d>>>>>>>>>>>>>>>>>>>>>>>>>\n" % (int(global_step)))
                        txtfile.write("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}\n".format('silog',
                                      'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3'))
                        txtfile.write("depth estimation\n")
                        line = ''
                        for i in range(9):
                            line += '{:7.4f}, '.format(eval_measures[i])
                        txtfile.write(line + '\n')

                    for i in range(9):
                        eval_summary_writer.add_scalar(eval_metrics[i], eval_measures[i].cpu(), int(global_step))
                        measure = eval_measures[i]
                        is_best = False
                        if i < 6 and measure < best_eval_measures_lower_better[i]:
                            old_best = best_eval_measures_lower_better[i].item()
                            best_eval_measures_lower_better[i] = measure.item()
                            is_best = True
                        elif i >= 6 and measure > best_eval_measures_higher_better[i - 6]:
                            old_best = best_eval_measures_higher_better[i - 6].item()
                            best_eval_measures_higher_better[i - 6] = measure.item()
                            is_best = True
                        if is_best:
                            old_best_step = best_eval_steps[i]
                            old_best_name = '/model-{}-best_{}_{:.5f}'.format(old_best_step, eval_metrics[i], old_best)
                            model_path = args.log_directory + '/' + args.model_name + old_best_name
                            if os.path.exists(model_path):
                                command = 'rm {}'.format(model_path)
                                os.system(command)
                            best_eval_steps[i] = global_step
                            model_save_name = '/model-{}-best_{}_{:.5f}'.format(global_step, eval_metrics[i], measure)
                            print('New best for {}. Saving model: {}'.format(eval_metrics[i], model_save_name))
                            checkpoint = {'global_step': global_step,
                                          'model': model.state_dict(),
                                          'optimizer': optimizer.state_dict(),
                                          'best_eval_measures_higher_better': best_eval_measures_higher_better,
                                          'best_eval_measures_lower_better': best_eval_measures_lower_better,
                                          'best_eval_steps': best_eval_steps
                                          }
                            torch.save(checkpoint, args.log_directory + '/' + args.model_name + model_save_name)
                    eval_summary_writer.flush()
                model.train()
                block_print()
                enable_print()

            model_just_loaded = False
            global_step += 1

        epoch += 1

    if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
        writer.close()
        if args.do_online_eval:
            eval_summary_writer.close()


def main():
    if args.mode != 'train':
        print('train.py is only for training.')
        return -1

    exp_name = '%s' % (datetime.now().strftime('%m%d'))
    args.log_directory = os.path.join(args.log_directory, exp_name)
    command = 'mkdir ' + os.path.join(args.log_directory, args.model_name)
    os.system(command)

    args_out_path = os.path.join(args.log_directory, args.model_name)
    command = 'cp ' + sys.argv[1] + ' ' + args_out_path
    os.system(command)

    save_files = True
    if save_files:
        aux_out_path = os.path.join(args.log_directory, args.model_name)
        networks_savepath = os.path.join(aux_out_path, 'networks')
        dataloaders_savepath = os.path.join(aux_out_path, 'dataloaders')
        command = 'cp iebins/train.py ' + aux_out_path
        os.system(command)
        command = 'mkdir -p ' + networks_savepath + ' && cp iebins/networks/*.py ' + networks_savepath
        os.system(command)
        command = 'mkdir -p ' + dataloaders_savepath + ' && cp iebins/dataloaders/*.py ' + dataloaders_savepath
        os.system(command)

    torch.cuda.empty_cache()
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node > 1 and not args.multiprocessing_distributed:
        print("This machine has more than 1 gpu. Please specify --multiprocessing_distributed, or set 'CUDA_VISIBLE_DEVICES=0'")
        return -1

    if args.do_online_eval:
        print("You have specified --do_online_eval.")
        print("This will evaluate the model every eval_freq {} steps and save best models for individual eval metrics."
              .format(args.eval_freq))

    if args.multiprocessing_distributed:
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)


if __name__ == '__main__':
    main()
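The training loop above recomputes the learning rate at every step with a power-0.9 polynomial decay from --learning_rate down to --end_learning_rate (or to one tenth of the initial rate when the latter is left at -1). A standalone sketch of that schedule, with made-up step counts:

# Sketch of the per-step polynomial decay used in the loop above; the step
# counts below are illustrative, not taken from a real run.
def poly_lr(base_lr, end_lr, global_step, num_total_steps, power=0.9):
    return (base_lr - end_lr) * (1 - global_step / num_total_steps) ** power + end_lr

base_lr, end_lr, total = 1e-4, 1e-5, 100000   # end_lr defaults to 0.1 * base_lr when --end_learning_rate is -1
for s in (0, 50000, 100000):
    print(s, poly_lr(base_lr, end_lr, s, total))
# 0 -> 1.0e-4, 50000 -> ~5.8e-5, 100000 -> 1.0e-5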
iebins/utils.py
ADDED
@@ -0,0 +1,356 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Sampler
from torchvision import transforms
import matplotlib.cm  # needed by colorize() below, which calls matplotlib.cm.get_cmap
import matplotlib.pyplot as plt
import os, sys
import numpy as np
import math


def convert_arg_line_to_args(arg_line):
    for arg in arg_line.split():
        if not arg.strip():
            continue
        yield arg


def block_print():
    sys.stdout = open(os.devnull, 'w')


def enable_print():
    sys.stdout = sys.__stdout__


def get_num_lines(file_path):
    f = open(file_path, 'r')
    lines = f.readlines()
    f.close()
    return len(lines)


def colorize(value, vmin=None, vmax=None, cmap='Greys'):
    value = value.cpu().numpy()[:, :, :]
    value = np.log10(value)

    vmin = value.min() if vmin is None else vmin
    vmax = value.max() if vmax is None else vmax

    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)
    else:
        value = value * 0.

    cmapper = matplotlib.cm.get_cmap(cmap)
    value = cmapper(value, bytes=True)

    img = value[:, :, :3]

    return img.transpose((2, 0, 1))


def normalize_result(value, vmin=None, vmax=None):
    value = value.cpu().numpy()[0, :, :]

    vmin = value.min() if vmin is None else vmin
    vmax = value.max() if vmax is None else vmax

    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)
    else:
        value = value * 0.

    return np.expand_dims(value, 0)


inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.225]
)


eval_metrics = ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3']


def compute_errors(gt, pred):
    thresh = np.maximum((gt / pred), (pred / gt))
    d1 = (thresh < 1.25).mean()
    d2 = (thresh < 1.25 ** 2).mean()
    d3 = (thresh < 1.25 ** 3).mean()

    rms = (gt - pred) ** 2
    rms = np.sqrt(rms.mean())

    log_rms = (np.log(gt) - np.log(pred)) ** 2
    log_rms = np.sqrt(log_rms.mean())

    abs_rel = np.mean(np.abs(gt - pred) / gt)
    sq_rel = np.mean(((gt - pred) ** 2) / gt)

    err = np.log(pred) - np.log(gt)
    silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100

    err = np.abs(np.log10(pred) - np.log10(gt))
    log10 = np.mean(err)

    return [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3]


class silog_loss(nn.Module):
    def __init__(self, variance_focus):
        super(silog_loss, self).__init__()
        self.variance_focus = variance_focus

    def forward(self, depth_est, depth_gt, mask):
        d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
        return torch.sqrt((d ** 2).mean() - self.variance_focus * (d.mean() ** 2)) * 10.0


def entropy_loss(preds, gt_label, mask):
    # preds: B, C, H, W
    # gt_label: B, H, W
    # mask: B, H, W
    mask = mask > 0.0  # B, H, W
    preds = preds.permute(0, 2, 3, 1)  # B, H, W, C
    preds_mask = preds[mask]  # N, C
    gt_label_mask = gt_label[mask]  # N
    loss = F.cross_entropy(preds_mask, gt_label_mask, reduction='mean')
    return loss


def colormap(inputs, normalize=True, torch_transpose=True):
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.detach().cpu().numpy()
    _DEPTH_COLORMAP = plt.get_cmap('jet', 256)  # for plotting
    vis = inputs
    if normalize:
        ma = float(vis.max())
        mi = float(vis.min())
        d = ma - mi if ma != mi else 1e5
        vis = (vis - mi) / d

    if vis.ndim == 4:
        vis = vis.transpose([0, 2, 3, 1])
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, 0, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 3:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 2:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[..., :3]
        if torch_transpose:
            vis = vis.transpose(2, 0, 1)

    return vis[0, :, :, :]


def colormap_magma(inputs, normalize=True, torch_transpose=True):
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.detach().cpu().numpy()
    _DEPTH_COLORMAP = plt.get_cmap('magma', 256)  # for plotting
    vis = inputs
    if normalize:
        ma = float(vis.max())
        mi = float(vis.min())
        d = ma - mi if ma != mi else 1e5
        vis = (vis - mi) / d

    if vis.ndim == 4:
        vis = vis.transpose([0, 2, 3, 1])
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, 0, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 3:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[:, :, :, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 2:
        vis = _DEPTH_COLORMAP(vis)
        vis = vis[..., :3]
        if torch_transpose:
            vis = vis.transpose(2, 0, 1)

    return vis[0, :, :, :]


def flip_lr(image):
    """
    Flip image horizontally

    Parameters
    ----------
    image : torch.Tensor [B,3,H,W]
        Image to be flipped

    Returns
    -------
    image_flipped : torch.Tensor [B,3,H,W]
        Flipped image
    """
    assert image.dim() == 4, 'You need to provide a [B,C,H,W] image to flip'
    return torch.flip(image, [3])


def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'):
    """
    Fuse inverse depth and flipped inverse depth maps

    Parameters
    ----------
    inv_depth : torch.Tensor [B,1,H,W]
        Inverse depth map
    inv_depth_hat : torch.Tensor [B,1,H,W]
        Flipped inverse depth map produced from a flipped image
    method : str
        Method that will be used to fuse the inverse depth maps

    Returns
    -------
    fused_inv_depth : torch.Tensor [B,1,H,W]
        Fused inverse depth map
    """
    if method == 'mean':
        return 0.5 * (inv_depth + inv_depth_hat)
    elif method == 'max':
        return torch.max(inv_depth, inv_depth_hat)
    elif method == 'min':
        return torch.min(inv_depth, inv_depth_hat)
    else:
        raise ValueError('Unknown post-process method {}'.format(method))


def post_process_depth(depth, depth_flipped, method='mean'):
    """
    Post-process an inverse and flipped inverse depth map

    Parameters
    ----------
    inv_depth : torch.Tensor [B,1,H,W]
        Inverse depth map
    inv_depth_flipped : torch.Tensor [B,1,H,W]
        Inverse depth map produced from a flipped image
    method : str
        Method that will be used to fuse the inverse depth maps

    Returns
    -------
    inv_depth_pp : torch.Tensor [B,1,H,W]
        Post-processed inverse depth map
    """
    B, C, H, W = depth.shape
    inv_depth_hat = flip_lr(depth_flipped)
    inv_depth_fused = fuse_inv_depth(depth, inv_depth_hat, method=method)
    xs = torch.linspace(0., 1., W, device=depth.device,
                        dtype=depth.dtype).repeat(B, C, H, 1)
    mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.)
    mask_hat = flip_lr(mask)
    return mask_hat * depth + mask * inv_depth_hat + \
        (1.0 - mask - mask_hat) * inv_depth_fused

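post_process_depth above does the usual flip-augmentation trick: it averages the prediction with the re-flipped prediction of the mirrored image, but near the left and right borders it trusts whichever of the two maps actually saw that border. The sketch below just prints the border ramp so the weighting is easy to inspect; the tensor size is arbitrary.

# Minimal sketch of the border weighting used by post_process_depth above,
# shown on a tiny 1x1x2x8 tensor.
import torch

B, C, H, W = 1, 1, 2, 8
xs = torch.linspace(0., 1., W).repeat(B, C, H, 1)
mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.)   # ~1 near the left edge, 0 elsewhere
mask_hat = torch.flip(mask, [3])                       # ~1 near the right edge
print(mask[0, 0, 0])
print((1.0 - mask - mask_hat)[0, 0, 0])                # weight given to the mean-fused map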
class DistributedSamplerNoEvenlyDivisible(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        shuffle (optional): If true (default), sampler will shuffle the indices
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
        rest = len(self.dataset) - num_samples * self.num_replicas
        if self.rank < rest:
            num_samples += 1
        self.num_samples = num_samples
        self.total_size = len(dataset)
        # self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        # indices += indices[:(self.total_size - len(indices))]
        # assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        self.num_samples = len(indices)
        # assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


class D_to_cloud(nn.Module):
    """Layer to transform depth into point cloud
    """

    def __init__(self, batch_size, height, width):
        super(D_to_cloud, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)  # 2, H, W
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), requires_grad=False)  # 2, H, W

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)  # B, 1, H*W

        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)  # 1, 2, L
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)  # B, 2, L
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1), requires_grad=False)  # B, 3, L

    def forward(self, depth, inv_K):
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points

        return cam_points.permute(0, 2, 1)
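As a quick sanity check of the metric definitions in compute_errors above, the snippet below recomputes the d1 threshold accuracy and the scale-invariant log error on synthetic ground truth and predictions (the numbers are made up):

# Synthetic sanity check of the threshold-accuracy and SILog formulas above.
import numpy as np

rng = np.random.default_rng(0)
gt = rng.uniform(0.5, 10.0, size=1000)
pred = gt * rng.uniform(0.9, 1.1, size=1000)      # predictions within 10% of ground truth

thresh = np.maximum(gt / pred, pred / gt)
print('d1 =', (thresh < 1.25).mean())             # 1.0 for this synthetic case

err = np.log(pred) - np.log(gt)
silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
print('silog =', silog)                           # small, since the errors are mild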
iebins/utils/transfrom.py
ADDED
@@ -0,0 +1,250 @@
import random
from PIL import Image, ImageOps, ImageFilter
import torch
from torchvision import transforms
import torch.nn.functional as F

import numpy as np
import cv2
import math


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.
    Args:
        sample (dict): sample
        size (tuple): image size
    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(
        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
    )

    sample["disparity"] = cv2.resize(
        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
    )
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height).
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.
        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of)
                 * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of)
                 * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as little as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            if "semseg_mask" in sample:
                # sample["semseg_mask"] = cv2.resize(
                #     sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
                # )
                sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[
                    None, None, ...], (height, width), mode='nearest').numpy()[0, 0]

            if "mask" in sample:
                sample["mask"] = cv2.resize(
                    sample["mask"].astype(np.float32),
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )
                # sample["mask"] = sample["mask"].astype(bool)

        # print(sample['image'].shape, sample['depth'].shape)
        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std.
    """

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        if "semseg_mask" in sample:
            sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
            sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])

        return sample
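The Resize transform above, when configured with keep_aspect_ratio=True and resize_method="lower_bound", scales the image so that neither side falls below the target and then snaps both sides to a multiple of ensure_multiple_of. The function below is a simplified, assumption-labeled variant of get_size (it always rounds up, whereas the class rounds to nearest and then enforces the minimum); the 640x480 input and multiple of 14 are example values.

# Simplified sketch of the "lower_bound" sizing rule in Resize.get_size above.
import numpy as np

def lower_bound_size(in_w, in_h, target_w, target_h, multiple_of):
    scale = max(target_w / in_w, target_h / in_h)             # fit the tighter side
    new_w = int(np.ceil(scale * in_w / multiple_of) * multiple_of)
    new_h = int(np.ceil(scale * in_h / multiple_of) * multiple_of)
    return new_w, new_h

print(lower_bound_size(640, 480, 518, 518, 14))               # -> (700, 518)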
requirements.txt
ADDED
@@ -0,0 +1,12 @@
pytorch=1.10.0
torchvision
cudatoolkit=11.1
matplotlib
tqdm
tensorboardX
timm
mmcv
open3d
gradio_imageslider
torch
opencv-python