|
import numpy as np |
|
import copy |
|
import cv2 |
|
import h5py |
|
import math |
|
from tqdm import tqdm |
|
import torch |
|
from torch.nn.functional import pixel_shuffle, softmax |
|
from torch.utils.data import DataLoader |
|
from kornia.geometry import warp_perspective |
|
|
|
from .dataset.dataset_util import get_dataset |
|
from .model.model_util import get_model |
|
from .misc.train_utils import get_latest_checkpoint |
|
from .train import convert_junc_predictions |
|
from .dataset.transforms.homographic_transforms import sample_homography |
|
|
|
|
|
def restore_weights(model, state_dict):
    """Restore weights in compatible mode.

    First tries a strict ``load_state_dict``. If that raises ``RuntimeError``
    (e.g. because parameter names differ), the entries whose names match are
    loaded non-strictly. Each remaining missing key of ``model`` is then
    filled, in order, from the checkpoint's unexpected keys. Any unexpected
    key containing "tracked" is skipped.

    Arguments:
        model: The ``torch.nn.Module`` to load weights into.
        state_dict: The checkpoint state dict.

    Returns:
        The same ``model`` with weights loaded.

    Raises:
        RuntimeError: If the checkpoint does not provide enough unmatched
            entries to fill every missing key, or if loading still fails
            after remapping (e.g. on a shape mismatch).
    """
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Load whatever matches by name; the returned report lists the rest.
        err = model.load_state_dict(state_dict, strict=False)
        missing_keys = err.missing_keys

        # Candidate source keys, computed once. Keys containing "tracked"
        # (presumably BatchNorm ``num_batches_tracked`` buffers) are skipped.
        candidate_keys = [k for k in err.unexpected_keys if "tracked" not in k]
        if len(candidate_keys) < len(missing_keys):
            raise RuntimeError(
                "Cannot restore weights: %d missing keys but only %d unmatched "
                "checkpoint entries." % (len(missing_keys), len(candidate_keys))
            )

        # Positional remapping: assumes both key lists follow the same
        # parameter order — TODO confirm for new architectures.
        model_dict = model.state_dict()
        for key, src_key in zip(missing_keys, candidate_keys):
            model_dict[key] = state_dict[src_key]
        model.load_state_dict(model_dict)

    return model
|
|
|
|
|
def get_padded_filename(num_pad, idx):
    """Return ``idx`` as a decimal string left-filled with "0" up to ``num_pad`` characters.

    Strings that are already ``num_pad`` characters or longer are returned unchanged.
    """
    digits = "%d" % (idx,)
    return digits.rjust(num_pad, "0")
|
|
|
|
|
def export_predictions(args, dataset_cfg, model_cfg, output_path, export_dataset_mode):
    """Export predictions.

    Runs the model over a dataset split and writes one HDF5 group per sample.
    Each group is keyed by a zero-padded running index and stores the input
    image, ground-truth and predicted junction maps, predicted and
    ground-truth heatmaps, the valid mask, junction points and the line map.

    Arguments:
        args: Namespace providing ``resume_path`` and ``checkpoint_name``,
            used to locate the checkpoint.
        dataset_cfg: Dataset configuration passed to ``get_dataset``.
        model_cfg: Model configuration; reads "test", "grid_size" and
            "detection_thresh".
        output_path: Output path without extension; ".h5" is appended.
        export_dataset_mode: Dataset split passed to ``get_dataset``.
    """
    test_cfg = model_cfg["test"]

    print("\t Initializing dataset and dataloader")
    batch_size = 4
    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(
        export_dataset,
        batch_size=batch_size,
        num_workers=test_cfg.get("num_workers", 4),
        shuffle=False,
        pin_memory=False,
        collate_fn=collate_fn,
    )
    print("\t Successfully intialized dataset and dataloader.")

    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.cuda()
    model.eval()
    print("\t Successfully initialized model")

    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"

    # The key width depends only on the loader length, so it is computed once.
    # It is derived from the number of batches (not samples); the formula is
    # kept so existing output keys stay unchanged. max(..., 1) avoids
    # log10(0) for an empty loader.
    num_pad = math.ceil(math.log10(max(len(export_loader), 1))) + 1
    filename_idx = 0

    with h5py.File(output_dataset_path, "w", libver="latest", swmr=True) as f:
        for data in tqdm(export_loader, ascii=True):
            input_images = data["image"].cuda()

            with torch.no_grad():
                outputs = model(input_images)

            junc_np = convert_junc_predictions(
                outputs["junctions"],
                model_cfg["grid_size"],
                model_cfg["detection_thresh"],
                300,
            )

            # Convert every batch tensor to NHWC numpy once per batch instead of
            # once per sample.
            image_np = input_images.cpu().numpy().transpose(0, 2, 3, 1)
            junc_map_np = data["junction_map"].numpy().transpose(0, 2, 3, 1)
            heatmap_np = (
                softmax(outputs["heatmap"].detach(), dim=1)
                .cpu()
                .numpy()
                .transpose(0, 2, 3, 1)
            )
            heatmap_gt_np = data["heatmap"].numpy().transpose(0, 2, 3, 1)
            valid_mask_np = data["valid_mask"].numpy().transpose(0, 2, 3, 1)

            for batch_idx in range(input_images.shape[0]):
                output_data = {
                    "image": image_np[batch_idx],
                    "junc_gt": junc_map_np[batch_idx],
                    "junc_pred": junc_np["junc_pred"][batch_idx],
                    "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype(
                        np.float32
                    ),
                    "heatmap_gt": heatmap_gt_np[batch_idx],
                    "heatmap_pred": heatmap_np[batch_idx],
                    "valid_mask": valid_mask_np[batch_idx],
                    "junc_points": data["junctions"][batch_idx]
                    .numpy()[0]
                    .round()
                    .astype(np.int32),
                    "line_map": data["line_map"][batch_idx].numpy()[0].astype(np.int32),
                }

                output_key = get_padded_filename(num_pad, filename_idx)
                f_group = f.create_group(output_key)
                for key, value in output_data.items():
                    f_group.create_dataset(key, data=value, compression="gzip")
                filename_idx += 1
|
|
|
|
|
def export_homograpy_adaptation(
    args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device
):
    """Export homography adaptation results.

    Runs ``homography_adaptation`` on every batch of the requested split and
    writes one HDF5 group per sample, keyed by the sample's "file_key".

    Arguments:
        args: Namespace providing ``export_batch_size``, ``resume_path`` and
            ``checkpoint_name``.
        dataset_cfg: Dataset configuration; must contain a
            "homography_adaptation" entry.
        model_cfg: Model configuration; reads "test" and "grid_size".
        output_path: Output path without extension; ".h5" is appended.
        export_dataset_mode: Either "train" or "test".
        device: Device the model and images are moved to.

    Raises:
        ValueError: If ``export_dataset_mode`` is unsupported, or if the
            homography adaptation config is missing.
    """
    supported_modes = ["train", "test"]
    if export_dataset_mode not in supported_modes:
        raise ValueError("[Error] The specified export_dataset_mode is not supported.")

    test_cfg = model_cfg["test"]

    homography_cfg = dataset_cfg.get("homography_adaptation", None)
    if homography_cfg is None:
        raise ValueError("[Error] Empty homography_adaptation entry in config.")

    print("\t Initializing dataset and dataloader")
    batch_size = args.export_batch_size

    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(
        export_dataset,
        batch_size=batch_size,
        num_workers=test_cfg.get("num_workers", 4),
        shuffle=False,
        pin_memory=False,
        collate_fn=collate_fn,
    )
    print("\t Successfully intialized dataset and dataloader.")

    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name, device)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.to(device).eval()
    print("\t Successfully initialized model")

    # HDF5 dataset name -> key in the homography_adaptation output dict.
    # "heatmap_cout" (sic) is kept because it is the on-disk dataset name.
    output_key_map = {
        "junc_prob_mean": "junc_probs_mean",
        "junc_prob_max": "junc_probs_max",
        "junc_count": "junc_counts",
        "heatmap_prob_mean": "heatmap_probs_mean",
        "heatmap_prob_max": "heatmap_probs_max",
        "heatmap_cout": "heatmap_counts",
    }

    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"
    with h5py.File(output_dataset_path, "w", libver="latest") as f:
        f.swmr_mode = True
        for data in tqdm(export_loader, ascii=True):
            input_images = data["image"].to(device)
            file_keys = data["file_key"]

            outputs = homography_adaptation(
                input_images, model, model_cfg["grid_size"], homography_cfg
            )

            # Move each batch tensor to NHWC numpy once per batch rather than
            # once per sample (the original repeated 7 full-batch transfers
            # for every sample).
            batch_np = {"image": input_images.cpu().numpy().transpose(0, 2, 3, 1)}
            for out_name, src_name in output_key_map.items():
                batch_np[out_name] = (
                    outputs[src_name].cpu().numpy().transpose(0, 2, 3, 1)
                )

            for batch_idx in range(input_images.shape[0]):
                f_group = f.create_group(file_keys[batch_idx])
                for key, value in batch_np.items():
                    f_group.create_dataset(
                        key, data=value[batch_idx], compression="gzip"
                    )
|
|
|
|
|
def homography_adaptation(input_images, model, grid_size, homography_cfg):
    """The homography adaptation process.

    For ``num_iter`` iterations, every image in the batch is warped by a
    freshly sampled homography and the model is run on the warped batch.
    The junction and heatmap probabilities are warped back with the inverse
    homography. Results are then aggregated across iterations (mean and max)
    together with per-pixel coverage counts.

    Arguments:
        input_images: The images to be evaluated; unpacked as
            (batch, channel, H, W).
        model: The pytorch model in evaluation mode. Its outputs must contain
            "junctions" and "heatmap".
        grid_size: Grid size of the junction decoder (upscale factor for
            ``pixel_shuffle``).
        homography_cfg: Homography adaptation configurations. Reads
            "num_iter", "valid_border_margin", "homographies" and
            "min_counts".

    Returns:
        Dict with "junc_probs_mean", "junc_probs_max", "junc_counts",
        "heatmap_probs_mean", "heatmap_probs_max" and "heatmap_counts",
        each of shape (batch, 1, H, W) on the model's device.
    """
    device = next(model.parameters()).device

    batch_size, _, H, W = input_images.shape
    num_iter = homography_cfg["num_iter"]
    junc_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
    junc_counts = torch.zeros([batch_size, 1, H, W], device=device)
    heatmap_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
    heatmap_counts = torch.zeros([batch_size, 1, H, W], device=device)
    margin = homography_cfg["valid_border_margin"]

    # The first iterations (idx <= num_iter // 5) sample homographies with
    # "allow_artifacts" disabled.
    homography_cfg_no_artifacts = copy.copy(homography_cfg["homographies"])
    homography_cfg_no_artifacts["allow_artifacts"] = False

    # All-ones mask, created once; warp_perspective does not modify its input.
    ones_mask = torch.ones([batch_size, 1, H, W], device=device)

    for idx in range(num_iter):
        if idx <= num_iter // 5:
            sample_cfg = homography_cfg_no_artifacts
        else:
            sample_cfg = homography_cfg["homographies"]
        H_mat_lst = [
            sample_homography([H, W], **sample_cfg)[0][None]
            for _ in range(batch_size)
        ]

        H_mats = np.concatenate(H_mat_lst, axis=0)
        H_tensor = torch.tensor(H_mats, dtype=torch.float, device=device)
        H_inv_tensor = torch.inverse(H_tensor)

        images_warped = warp_perspective(
            input_images, H_tensor, (H, W), flags="bilinear"
        )

        # Valid region in the warped frame. The junction and heatmap masks
        # were identical computations in the original, so compute it once.
        masks_warped = warp_perspective(ones_mask, H_tensor, (H, W), flags="nearest")

        with torch.no_grad():
            outputs = model(images_warped)

        # Drop the last softmax channel, then unfold cells to full resolution.
        junc_prob_warped = pixel_shuffle(
            softmax(outputs["junctions"], dim=1)[:, :-1, :, :], grid_size
        )
        junc_prob = warp_perspective(
            junc_prob_warped, H_inv_tensor, (H, W), flags="bilinear"
        )

        # Region covered by the back-warp, eroded by `margin` at its border.
        out_boundary_mask = warp_perspective(
            ones_mask, H_inv_tensor, (H, W), flags="nearest"
        )
        out_boundary_mask = adjust_border(out_boundary_mask, device, margin)

        junc_prob = junc_prob * out_boundary_mask
        # NOTE(review): masks_warped is in the warped frame while
        # out_boundary_mask comes from the H_inv warp; confirm mixing them
        # here is intended (behavior kept from the original).
        junc_count = warp_perspective(
            masks_warped * out_boundary_mask, H_inv_tensor, (H, W), flags="nearest"
        )

        # Two-channel heatmaps use softmax (keep channel 1); otherwise sigmoid.
        if outputs["heatmap"].shape[1] == 2:
            heatmap_prob_warped = softmax(outputs["heatmap"], dim=1)[:, 1:, :, :]
        else:
            heatmap_prob_warped = torch.sigmoid(outputs["heatmap"])

        heatmap_prob_warped = heatmap_prob_warped * masks_warped
        heatmap_prob = warp_perspective(
            heatmap_prob_warped, H_inv_tensor, (H, W), flags="bilinear"
        )
        heatmap_count = warp_perspective(
            masks_warped, H_inv_tensor, (H, W), flags="nearest"
        )

        junc_probs[:, idx : idx + 1, :, :] = junc_prob
        heatmap_probs[:, idx : idx + 1, :, :] = heatmap_prob
        junc_counts += junc_count
        heatmap_counts += heatmap_count

    min_counts = homography_cfg["min_counts"]
    if min_counts > 0:
        junc_count_mask = junc_counts < min_counts
        heatmap_count_mask = heatmap_counts < min_counts
        junc_counts[junc_count_mask] = 0
        heatmap_counts[heatmap_count_mask] = 0
    else:
        # torch (not numpy) masks so this works for CUDA tensors and keeps
        # indexing tensor-on-tensor.
        junc_count_mask = torch.zeros_like(junc_counts, dtype=torch.bool)
        heatmap_count_mask = torch.zeros_like(heatmap_counts, dtype=torch.bool)

    # Pixels never covered (count 0) would give 0/0 = NaN or x/0 = inf in the
    # mean, so they are zeroed as well.
    junc_probs_mean = torch.sum(junc_probs, dim=1, keepdim=True) / junc_counts
    junc_probs_mean[junc_count_mask | (junc_counts == 0)] = 0.0
    heatmap_probs_mean = torch.sum(heatmap_probs, dim=1, keepdim=True) / heatmap_counts
    heatmap_probs_mean[heatmap_count_mask | (heatmap_counts == 0)] = 0.0

    junc_probs_max = torch.max(junc_probs, dim=1, keepdim=True)[0]
    junc_probs_max[junc_count_mask] = 0.0
    heatmap_probs_max = torch.max(heatmap_probs, dim=1, keepdim=True)[0]
    heatmap_probs_max[heatmap_count_mask] = 0.0

    return {
        "junc_probs_mean": junc_probs_mean,
        "junc_probs_max": junc_probs_max,
        "junc_counts": junc_counts,
        "heatmap_probs_mean": heatmap_probs_mean,
        "heatmap_probs_max": heatmap_probs_max,
        "heatmap_counts": heatmap_counts,
    }
|
|
|
|
|
def adjust_border(input_masks, device, margin=3):
    """Adjust the border of the counts and valid_mask by erosion.

    Each mask in the batch is eroded with an elliptical kernel of side
    ``2 * margin``.

    Arguments:
        input_masks: Tensor whose axis 1 has size 1; it is squeezed before
            erosion, so (batch, 1, H, W) is expected.
        device: Device on which the result is allocated.
        margin: Half the side length of the erosion kernel.

    Returns:
        Tensor of shape (batch, 1, H, W) with the input's dtype, on ``device``.
    """
    mask_dtype = input_masks.dtype
    masks_np = np.squeeze(input_masks.cpu().numpy(), axis=1)

    side = margin * 2
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (side, side))

    eroded = [
        torch.tensor(cv2.erode(single_mask, kernel), dtype=mask_dtype, device=device)
        for single_mask in masks_np
    ]

    # Stack along a new batch axis, then restore the channel axis.
    return torch.stack(eroded, dim=0).unsqueeze(dim=1)
|
|