File size: 15,075 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import numpy as np
import copy
import cv2
import h5py
import math
from tqdm import tqdm
import torch
from torch.nn.functional import pixel_shuffle, softmax
from torch.utils.data import DataLoader
from kornia.geometry import warp_perspective

from .dataset.dataset_util import get_dataset
from .model.model_util import get_model
from .misc.train_utils import get_latest_checkpoint
from .train import convert_junc_predictions
from .dataset.transforms.homographic_transforms import sample_homography


def restore_weights(model, state_dict):
    """ Restore weights in compatible mode.

    First tries a strict ``model.load_state_dict(state_dict)``. If that
    raises, the keys whose names match are loaded non-strictly, and every
    key the model is still missing is filled positionally with the next
    "unexpected" key of ``state_dict``. Unexpected keys containing "tracked"
    are skipped (presumably BatchNorm ``num_batches_tracked`` buffers).

    Arguments:
        model: The torch module to load the weights into.
        state_dict: Mapping from parameter/buffer names to tensors.
    Returns:
        The same ``model`` object, with the weights loaded.
    Raises:
        RuntimeError: If the weights cannot be loaded even after positional
            matching (e.g. fewer usable unexpected keys than missing keys,
            or a shape mismatch).
    """
    # Try to directly load state dict
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as strict_err:
        # Load whatever matches by name; the result reports what is left.
        err = model.load_state_dict(state_dict, strict=False)
        # Missing keys are those in model but not in state_dict
        missing_keys = err.missing_keys
        # Unexpected keys are those in state_dict but not in model. The
        # filtered list is loop-invariant, so build it once.
        candidate_keys = [key for key in err.unexpected_keys
                          if "tracked" not in key]
        if len(candidate_keys) < len(missing_keys):
            raise RuntimeError(
                "Cannot restore weights: %d missing keys but only %d "
                "candidate unexpected keys in the checkpoint."
                % (len(missing_keys), len(candidate_keys))) from strict_err

        # Load mismatched keys manually, matching them by position.
        model_dict = model.state_dict()
        for key, src_key in zip(missing_keys, candidate_keys):
            model_dict[key] = state_dict[src_key]
        model.load_state_dict(model_dict)
    return model


def get_padded_filename(num_pad, idx):
    """ Return ``idx`` in decimal, left-filled with "0" up to ``num_pad``
    characters. Strings already at least ``num_pad`` long are returned
    unpadded. """
    return ("%d" % idx).rjust(num_pad, "0")


def export_predictions(args, dataset_cfg, model_cfg, output_path,
                       export_dataset_mode):
    """ Export predictions.

    Runs the model (on CUDA) over the requested dataset split and writes one
    h5 group per sample, keyed by a zero-padded running index, to
    ``output_path + ".h5"``.

    Arguments:
        args: Parsed arguments; ``resume_path`` and ``checkpoint_name`` are
              used to locate the checkpoint.
        dataset_cfg: Dataset configuration passed to ``get_dataset``.
        model_cfg: Model configuration ("test", "grid_size" and
                   "detection_thresh" entries are read).
        output_path: Output path without extension.
        export_dataset_mode: Dataset split passed to ``get_dataset``.
    """
    # Get the test configuration
    test_cfg = model_cfg["test"]

    # Create the dataset and dataloader based on the export_dataset_mode
    print("\t Initializing dataset and dataloader")
    batch_size = 4
    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(export_dataset, batch_size=batch_size,
                               num_workers=test_cfg.get("num_workers", 4),
                               shuffle=False, pin_memory=False,
                               collate_fn=collate_fn)
    print("\t Successfully intialized dataset and dataloader.")

    # Initialize model and load the checkpoint
    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.cuda()
    model.eval()
    print("\t Successfully initialized model")

    # Width of the zero-padded group keys. It is loop-invariant, so compute
    # it once. The formula is unchanged to keep the key format stable;
    # max(..., 1) avoids math.log10(0) for an empty loader.
    num_pad = math.ceil(math.log10(max(len(export_loader), 1))) + 1

    # Start the export process
    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"
    filename_idx = 0
    with h5py.File(output_dataset_path, "w", libver="latest", swmr=True) as f:
        # Iterate through all the data in dataloader
        for data in tqdm(export_loader, ascii=True):
            input_images = data["image"].cuda()

            # Run the forward pass
            with torch.no_grad():
                outputs = model(input_images)

            # Convert predictions and ground truth to NHWC numpy arrays once
            # per batch (not once per sample).
            junc_np = convert_junc_predictions(
                outputs["junctions"], model_cfg["grid_size"],
                model_cfg["detection_thresh"], 300)
            image_np = input_images.cpu().numpy().transpose(0, 2, 3, 1)
            junc_map_np = data["junction_map"].numpy().transpose(0, 2, 3, 1)
            heatmap_np = softmax(outputs["heatmap"].detach(),
                                 dim=1).cpu().numpy().transpose(0, 2, 3, 1)
            heatmap_gt_np = data["heatmap"].numpy().transpose(0, 2, 3, 1)
            valid_mask_np = data["valid_mask"].numpy().transpose(0, 2, 3, 1)

            # Data entries to save
            for batch_idx in range(input_images.shape[0]):
                output_data = {
                    "image": image_np[batch_idx],
                    "junc_gt": junc_map_np[batch_idx],
                    "junc_pred": junc_np["junc_pred"][batch_idx],
                    "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype(np.float32),
                    "heatmap_gt": heatmap_gt_np[batch_idx],
                    "heatmap_pred": heatmap_np[batch_idx],
                    "valid_mask": valid_mask_np[batch_idx],
                    "junc_points": data["junctions"][batch_idx].numpy()[0].round().astype(np.int32),
                    "line_map": data["line_map"][batch_idx].numpy()[0].astype(np.int32)
                }

                # Save data to h5 dataset, one group per sample
                output_key = get_padded_filename(num_pad, filename_idx)
                f_group = f.create_group(output_key)
                for key, value in output_data.items():
                    f_group.create_dataset(key, data=value,
                                           compression="gzip")
                filename_idx += 1


def export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path,
                                export_dataset_mode, device):
    """ Export homography adaptation results.

    Runs ``homography_adaptation`` on every batch of the requested split and
    writes, per sample, one h5 group named after the sample's ``file_key``.
    The group contains the input image and the aggregated junction/heatmap
    maps, transposed from NCHW to HWC.

    Arguments:
        args: Parsed arguments; ``export_batch_size``, ``resume_path`` and
              ``checkpoint_name`` are read.
        dataset_cfg: Dataset configuration; must contain a
                     "homography_adaptation" entry.
        model_cfg: Model configuration ("test" and "grid_size" are read).
        output_path: Output path without extension; ".h5" is appended.
        export_dataset_mode: Either "train" or "test".
        device: Torch device for the model and the inputs.
    Raises:
        ValueError: If the mode is unsupported or the homography adaptation
                    config entry is missing.
    """
    # Check if the export_dataset_mode is supported
    supported_modes = ["train", "test"]
    if export_dataset_mode not in supported_modes:
        raise ValueError(
            "[Error] The specified export_dataset_mode is not supported.")

    # Get the test configuration
    test_cfg = model_cfg["test"]

    # Get the homography adaptation configurations
    homography_cfg = dataset_cfg.get("homography_adaptation", None)
    if homography_cfg is None:
        raise ValueError(
            "[Error] Empty homography_adaptation entry in config.")

    # Create the dataset and dataloader based on the export_dataset_mode
    print("\t Initializing dataset and dataloader")
    batch_size = args.export_batch_size

    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(export_dataset, batch_size=batch_size,
                               num_workers=test_cfg.get("num_workers", 4),
                               shuffle=False, pin_memory=False,
                               collate_fn=collate_fn)
    print("\t Successfully intialized dataset and dataloader.")

    # Initialize model and load the checkpoint
    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name,
                                       device)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.to(device).eval()
    print("\t Successfully initialized model")

    # (h5 dataset name, key in the homography_adaptation output) pairs, in
    # the order they are written. NOTE: "heatmap_cout" (sic) is kept verbatim
    # because readers of the exported file may look it up by that name.
    output_keys = (
        ("junc_prob_mean", "junc_probs_mean"),
        ("junc_prob_max", "junc_probs_max"),
        ("junc_count", "junc_counts"),
        ("heatmap_prob_mean", "heatmap_probs_mean"),
        ("heatmap_prob_max", "heatmap_probs_max"),
        ("heatmap_cout", "heatmap_counts"),
    )

    # Start the export process
    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"
    with h5py.File(output_dataset_path, "w", libver="latest") as f:
        f.swmr_mode = True
        for data in tqdm(export_loader, ascii=True):
            input_images = data["image"].to(device)
            file_keys = data["file_key"]

            # Run the homograpy adaptation
            outputs = homography_adaptation(input_images, model,
                                            model_cfg["grid_size"],
                                            homography_cfg)

            # Move each batch tensor to host memory (NHWC) once per batch,
            # instead of once per sample and per entry.
            images_np = input_images.cpu().numpy().transpose(0, 2, 3, 1)
            outputs_np = {
                save_name: outputs[out_name].cpu().numpy().transpose(0, 2, 3, 1)
                for save_name, out_name in output_keys}

            # Save the entries: one group per sample, keyed by its file key
            for batch_idx in range(input_images.shape[0]):
                f_group = f.create_group(file_keys[batch_idx])
                f_group.create_dataset("image", data=images_np[batch_idx],
                                       compression="gzip")
                for save_name, _ in output_keys:
                    f_group.create_dataset(
                        save_name, data=outputs_np[save_name][batch_idx],
                        compression="gzip")


def homography_adaptation(input_images, model, grid_size, homography_cfg):
    """ The homography adaptation process.

    Repeats the following for ``homography_cfg["num_iter"]`` iterations:
    1. Warp the images with randomly sampled homographies.
    2. Run the model on the warped images.
    3. Warp the junction and heatmap probabilities back to the input frame.
    The per-iteration maps are then aggregated into mean and max maps.

    Arguments:
        input_images: The images to be evaluated, shape (B, C, H, W).
        model: The pytorch model in evaluation mode.
        grid_size: Grid size of the junction decoder.
        homography_cfg: Homography adaptation configurations. Reads
            "num_iter", "valid_border_margin", "min_counts" and
            "homographies" (keyword arguments for sample_homography).
    Returns:
        A dict with "junc_probs_mean", "junc_probs_max", "junc_counts",
        "heatmap_probs_mean", "heatmap_probs_max" and "heatmap_counts",
        each of shape (B, 1, H, W). In the mean/max maps, a pixel is set to 0
        when its count is below "min_counts". When "min_counts" <= 0, this
        applies only to pixels whose count is 0; previously these became
        NaN (0/0).
    """
    # Get the device of the current model
    device = next(model.parameters()).device

    # Define some constants and placeholder
    batch_size, _, H, W = input_images.shape
    num_iter = homography_cfg["num_iter"]
    junc_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
    junc_counts = torch.zeros([batch_size, 1, H, W], device=device)
    heatmap_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
    heatmap_counts = torch.zeros([batch_size, 1, H, W], device=device)
    margin = homography_cfg["valid_border_margin"]

    # The all-ones mask is identical in every iteration; build it once.
    ones_mask = torch.ones([batch_size, 1, H, W], device=device)

    # Keep a config with no artifacts (shallow copy: don't mutate the input)
    homography_cfg_no_artifacts = copy.copy(homography_cfg["homographies"])
    homography_cfg_no_artifacts["allow_artifacts"] = False

    for idx in range(num_iter):
        # The first num_iter // 5 + 1 iterations (at least ~20%) use
        # homographies sampled without artifacts.
        if idx <= num_iter // 5:
            sample_cfg = homography_cfg_no_artifacts
        else:
            sample_cfg = homography_cfg["homographies"]
        H_mats = np.concatenate(
            [sample_homography([H, W], **sample_cfg)[0][None]
             for _ in range(batch_size)], axis=0)
        H_tensor = torch.tensor(H_mats, dtype=torch.float, device=device)
        H_inv_tensor = torch.inverse(H_tensor)

        # Perform the homography warp
        images_warped = warp_perspective(input_images, H_tensor, (H, W),
                                         flags="bilinear")

        # Valid region of the warped images. The junction and heatmap masks
        # were the very same warp, so it is computed once and shared.
        masks_warped = warp_perspective(ones_mask, H_tensor, (H, W),
                                        flags="nearest")

        # Run the network forward pass
        with torch.no_grad():
            outputs = model(images_warped)

        # Unwarp and mask the junction prediction (last channel dropped
        # before pixel_shuffle)
        junc_prob_warped = pixel_shuffle(softmax(
            outputs["junctions"], dim=1)[:, :-1, :, :], grid_size)
        junc_prob = warp_perspective(junc_prob_warped, H_inv_tensor,
                                     (H, W), flags="bilinear")

        # Create the out of boundary mask, eroded by the border margin
        out_boundary_mask = warp_perspective(
            ones_mask, H_inv_tensor, (H, W), flags="nearest")
        out_boundary_mask = adjust_border(out_boundary_mask, device, margin)

        junc_prob = junc_prob * out_boundary_mask
        junc_count = warp_perspective(masks_warped * out_boundary_mask,
                                      H_inv_tensor, (H, W), flags="nearest")

        # Unwarp the mask and heatmap prediction
        # Always fetch only one channel
        if outputs["heatmap"].shape[1] == 2:
            # Convert to single channel directly from here
            heatmap_prob_warped = softmax(outputs["heatmap"],
                                          dim=1)[:, 1:, :, :]
        else:
            heatmap_prob_warped = torch.sigmoid(outputs["heatmap"])

        heatmap_prob_warped = heatmap_prob_warped * masks_warped
        heatmap_prob = warp_perspective(heatmap_prob_warped, H_inv_tensor,
                                        (H, W), flags="bilinear")
        heatmap_count = warp_perspective(masks_warped, H_inv_tensor,
                                         (H, W), flags="nearest")

        # Record the results
        junc_probs[:, idx:idx+1, :, :] = junc_prob
        heatmap_probs[:, idx:idx+1, :, :] = heatmap_prob
        junc_counts += junc_count
        heatmap_counts += heatmap_count

    # Perform the accumulation operation
    min_counts = homography_cfg["min_counts"]
    if min_counts > 0:
        junc_count_mask = (junc_counts < min_counts)
        heatmap_count_mask = (heatmap_counts < min_counts)
        junc_counts[junc_count_mask] = 0
        heatmap_counts[heatmap_count_mask] = 0
    else:
        # Mask only never-observed pixels. Torch masks stay on the tensors'
        # device (np.zeros_like fails on CUDA tensors) and avoid 0/0 = NaN.
        junc_count_mask = (junc_counts == 0)
        heatmap_count_mask = (heatmap_counts == 0)

    # Compute the mean accumulation
    junc_probs_mean = torch.sum(junc_probs, dim=1, keepdim=True) / junc_counts
    junc_probs_mean[junc_count_mask] = 0.
    heatmap_probs_mean = (torch.sum(heatmap_probs, dim=1, keepdim=True)
                          / heatmap_counts)
    heatmap_probs_mean[heatmap_count_mask] = 0.

    # Compute the max accumulation
    junc_probs_max = torch.max(junc_probs, dim=1, keepdim=True)[0]
    junc_probs_max[junc_count_mask] = 0.
    heatmap_probs_max = torch.max(heatmap_probs, dim=1, keepdim=True)[0]
    heatmap_probs_max[heatmap_count_mask] = 0.

    return {"junc_probs_mean": junc_probs_mean,
            "junc_probs_max": junc_probs_max,
            "junc_counts": junc_counts,
            "heatmap_probs_mean": heatmap_probs_mean,
            "heatmap_probs_max": heatmap_probs_max,
            "heatmap_counts": heatmap_counts}


def adjust_border(input_masks, device, margin=3):
    """ Adjust the border of the counts and valid_mask.

    Each mask in the batch is eroded with an elliptical structuring element
    of size (2 * margin, 2 * margin). This shrinks the valid region away
    from its boundary.

    Arguments:
        input_masks: Tensor of shape (B, 1, H, W).
        device: Device of the returned tensor.
        margin: Erosion margin in pixels. A non-positive margin disables the
            erosion and returns a copy of the masks. Previously, OpenCV
            rejected the resulting 0x0 structuring element.
    Returns:
        Tensor of shape (B, 1, H, W) with the dtype of ``input_masks``.
    """
    dtype = input_masks.dtype
    if margin <= 0:
        # Nothing to erode; return a fresh tensor on the requested device.
        return input_masks.detach().clone().to(device)

    # Convert the mask to a numpy array of shape (B, H, W)
    masks_np = np.squeeze(input_masks.cpu().numpy(), axis=1)

    erosion_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,
                                               (margin*2, margin*2))

    # Erode all the masks and convert each back to a (1, H, W) tensor
    output_mask_lst = [
        torch.tensor(cv2.erode(mask, erosion_kernel),
                     dtype=dtype, device=device)[None]
        for mask in masks_np]

    # Concat back along the batch dimension.
    output_masks = torch.cat(output_mask_lst, dim=0)
    return output_masks.unsqueeze(dim=1)