Vincentqyw
fix: roma
358ab8f
raw
history blame
15.3 kB
import numpy as np
import copy
import cv2
import h5py
import math
from tqdm import tqdm
import torch
from torch.nn.functional import pixel_shuffle, softmax
from torch.utils.data import DataLoader
from kornia.geometry import warp_perspective
from .dataset.dataset_util import get_dataset
from .model.model_util import get_model
from .misc.train_utils import get_latest_checkpoint
from .train import convert_junc_predictions
from .dataset.transforms.homographic_transforms import sample_homography
def restore_weights(model, state_dict):
    """Restore weights in compatible mode.

    A strict ``load_state_dict`` is attempted first. If it is rejected, the
    keys the model is missing are filled positionally from the checkpoint
    entries the model did not recognise, and the patched state dict is loaded.

    Arguments:
        model: Module exposing ``load_state_dict`` and ``state_dict``.
        state_dict: Mapping from parameter name to value to be loaded.
    Returns:
        The same ``model`` object, with the weights loaded.
    Raises:
        IndexError: If there are more missing keys than usable unexpected keys.
        RuntimeError: If the patched state dict still cannot be loaded.
    """
    # Try to directly load state dict
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Only a rejected state dict triggers the fallback; interrupts and
        # other unrelated failures are no longer swallowed here.
        incompatible = model.load_state_dict(state_dict, strict=False)
        # Missing keys are those in model but not in state_dict
        missing_keys = incompatible.missing_keys
        # Unexpected keys are those in state_dict but not in model. Keys
        # containing "tracked" are skipped (presumably BatchNorm's
        # num_batches_tracked buffers -- confirm against the checkpoints used).
        candidate_keys = [
            key for key in incompatible.unexpected_keys if "tracked" not in key
        ]

        # Load mismatched keys manually: the i-th missing key receives the
        # value stored under the i-th candidate key.
        model_dict = model.state_dict()
        for idx, key in enumerate(missing_keys):
            model_dict[key] = state_dict[candidate_keys[idx]]
        model.load_state_dict(model_dict)

    return model
def get_padded_filename(num_pad, idx):
    """Format ``idx`` as an integer string, left-padded with "0".

    Arguments:
        num_pad: Target width of the returned string.
        idx: Index to render with ``%d``.
    Returns:
        The rendered index prefixed with enough "0" characters to reach
        ``num_pad`` characters; unchanged if it is already that long.
    """
    rendered = "%d" % idx
    # rjust (rather than zfill) keeps the zeros as a plain prefix, even in
    # front of a minus sign.
    return rendered.rjust(num_pad, "0")
def export_predictions(args, dataset_cfg, model_cfg, output_path, export_dataset_mode):
    """Export predictions.

    Runs the model over the requested dataset split and writes, for every
    sample, one HDF5 group (keyed by a zero-padded running index) holding the
    input image, the ground truth and the predictions.

    Arguments:
        args: Namespace providing ``resume_path`` and ``checkpoint_name``.
        dataset_cfg: Dataset configuration forwarded to ``get_dataset``.
        model_cfg: Model configuration; the "test", "grid_size" and
            "detection_thresh" entries are read here.
        output_path: Output file path without extension (".h5" is appended).
        export_dataset_mode: Dataset mode forwarded to ``get_dataset``.

    The model is moved to the GPU with ``.cuda()``, so CUDA is required.
    """
    # Get the test configuration
    test_cfg = model_cfg["test"]

    # Create the dataset and dataloader based on the export_dataset_mode
    print("\t Initializing dataset and dataloader")
    batch_size = 4
    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(
        export_dataset,
        batch_size=batch_size,
        num_workers=test_cfg.get("num_workers", 4),
        shuffle=False,
        pin_memory=False,
        collate_fn=collate_fn,
    )
    print("\t Successfully intialized dataset and dataloader.")

    # Initialize model and load the checkpoint
    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.cuda()
    model.eval()
    print("\t Successfully initialized model")

    # Start the export process
    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"
    filename_idx = 0
    # NOTE(review): h5py documents the `swmr` argument as only used when
    # mode="r" -- confirm it has the intended effect with mode="w".
    with h5py.File(output_dataset_path, "w", libver="latest", swmr=True) as f:
        # Iterate through all the data in dataloader
        for data in tqdm(export_loader, ascii=True):
            # Fetch the data
            junc_map = data["junction_map"]
            heatmap = data["heatmap"]
            valid_mask = data["valid_mask"]
            input_images = data["image"].cuda()

            # Run the forward pass
            with torch.no_grad():
                outputs = model(input_images)

            # Convert predictions. The literal 300 is forwarded unchanged;
            # presumably a cap on kept junctions -- confirm against
            # convert_junc_predictions.
            junc_np = convert_junc_predictions(
                outputs["junctions"],
                model_cfg["grid_size"],
                model_cfg["detection_thresh"],
                300,
            )

            # Convert every batched tensor once per batch; axis 1 is moved
            # to the end (presumably NCHW -> NHWC).
            junc_map_np = junc_map.numpy().transpose(0, 2, 3, 1)
            heatmap_np = (
                softmax(outputs["heatmap"].detach(), dim=1)
                .cpu()
                .numpy()
                .transpose(0, 2, 3, 1)
            )
            heatmap_gt_np = heatmap.numpy().transpose(0, 2, 3, 1)
            valid_mask_np = valid_mask.numpy().transpose(0, 2, 3, 1)
            # Previously redone for every sample, which copied the whole
            # batch from the GPU once per sample.
            images_np = input_images.cpu().numpy().transpose(0, 2, 3, 1)

            # Width of the zero-padded group names. Derived from the number
            # of batches; evaluated inside the loop so that an empty loader
            # never reaches log10(0).
            num_pad = math.ceil(math.log10(len(export_loader))) + 1

            # Data entries to save
            current_batch_size = input_images.shape[0]
            for batch_idx in range(current_batch_size):
                output_data = {
                    "image": images_np[batch_idx],
                    "junc_gt": junc_map_np[batch_idx],
                    "junc_pred": junc_np["junc_pred"][batch_idx],
                    "junc_pred_nms": junc_np["junc_pred_nms"][batch_idx].astype(
                        np.float32
                    ),
                    "heatmap_gt": heatmap_gt_np[batch_idx],
                    "heatmap_pred": heatmap_np[batch_idx],
                    "valid_mask": valid_mask_np[batch_idx],
                    "junc_points": data["junctions"][batch_idx]
                    .numpy()[0]
                    .round()
                    .astype(np.int32),
                    "line_map": data["line_map"][batch_idx].numpy()[0].astype(np.int32),
                }

                # Save data to h5 dataset
                output_key = get_padded_filename(num_pad, filename_idx)
                f_group = f.create_group(output_key)

                # Store data
                for key, value in output_data.items():
                    f_group.create_dataset(key, data=value, compression="gzip")
                filename_idx += 1
def export_homograpy_adaptation(
    args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device
):
    """Export homography adaptation results.

    Runs ``homography_adaptation`` on every batch of the requested dataset
    split and writes one HDF5 group per sample, keyed by the sample's
    "file_key" entry.

    Arguments:
        args: Namespace providing ``export_batch_size``, ``resume_path`` and
            ``checkpoint_name``.
        dataset_cfg: Dataset configuration; must contain a
            "homography_adaptation" entry.
        model_cfg: Model configuration; the "test" and "grid_size" entries
            are read here.
        output_path: Output file path without extension (".h5" is appended).
        export_dataset_mode: Either "train" or "test".
        device: Device the model and the input images are moved to.
    Raises:
        ValueError: If ``export_dataset_mode`` is unsupported or the
            "homography_adaptation" entry is missing from ``dataset_cfg``.
    """
    # Check if the export_dataset_mode is supported
    supported_modes = ["train", "test"]
    if export_dataset_mode not in supported_modes:
        raise ValueError("[Error] The specified export_dataset_mode is not supported.")

    # Get the test configuration
    test_cfg = model_cfg["test"]

    # Get the homography adaptation configurations
    homography_cfg = dataset_cfg.get("homography_adaptation", None)
    if homography_cfg is None:
        raise ValueError("[Error] Empty homography_adaptation entry in config.")

    # Create the dataset and dataloader based on the export_dataset_mode
    print("\t Initializing dataset and dataloader")
    batch_size = args.export_batch_size
    export_dataset, collate_fn = get_dataset(export_dataset_mode, dataset_cfg)
    export_loader = DataLoader(
        export_dataset,
        batch_size=batch_size,
        num_workers=test_cfg.get("num_workers", 4),
        shuffle=False,
        pin_memory=False,
        collate_fn=collate_fn,
    )
    print("\t Successfully intialized dataset and dataloader.")

    # Initialize model and load the checkpoint
    model = get_model(model_cfg, mode="test")
    checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name, device)
    model = restore_weights(model, checkpoint["model_state_dict"])
    model = model.to(device).eval()
    print("\t Successfully initialized model")

    # Start the export process
    print("[Info] Start exporting predictions")
    output_dataset_path = output_path + ".h5"
    with h5py.File(output_dataset_path, "w", libver="latest") as f:
        # NOTE(review): groups are created after SWMR mode is switched on;
        # confirm against the h5py SWMR guide that this is supported.
        f.swmr_mode = True
        for data in tqdm(export_loader, ascii=True):
            input_images = data["image"].to(device)
            file_keys = data["file_key"]
            current_batch_size = input_images.shape[0]

            # Run the homograpy adaptation
            outputs = homography_adaptation(
                input_images, model, model_cfg["grid_size"], homography_cfg
            )

            # Move every batched tensor to host memory once per batch
            # (previously repeated for each sample); axis 1 is moved to the
            # end. Dataset names are kept exactly as before.
            batch_tensors = {
                "image": input_images,
                "junc_prob_mean": outputs["junc_probs_mean"],
                "junc_prob_max": outputs["junc_probs_max"],
                "junc_count": outputs["junc_counts"],
                "heatmap_prob_mean": outputs["heatmap_probs_mean"],
                "heatmap_prob_max": outputs["heatmap_probs_max"],
                "heatmap_cout": outputs["heatmap_counts"],
            }
            batch_arrays = {
                key: tensor.cpu().numpy().transpose(0, 2, 3, 1)
                for key, tensor in batch_tensors.items()
            }

            # Save the entries
            for batch_idx in range(current_batch_size):
                # Create group and write data
                f_group = f.create_group(file_keys[batch_idx])
                for key, array in batch_arrays.items():
                    f_group.create_dataset(
                        key, data=array[batch_idx], compression="gzip"
                    )
def homography_adaptation(input_images, model, grid_size, homography_cfg):
    """The homography adaptation process.

    For ``num_iter`` random homographies, the images are warped, passed
    through the model, and the junction / heatmap probabilities are warped
    back and accumulated together with per-pixel observation counts.

    Arguments:
        input_images: The images to be evaluated.
        model: The pytorch model in evaluation mode.
        grid_size: Grid size of the junction decoder.
        homography_cfg: Homography adaptation configurations. The entries
            "num_iter", "valid_border_margin", "homographies" and
            "min_counts" are read.
    Returns:
        A dict with "junc_probs_mean", "junc_probs_max", "junc_counts",
        "heatmap_probs_mean", "heatmap_probs_max" and "heatmap_counts",
        each a tensor of shape [batch_size, 1, H, W] on the model's device.
    """
    # Get the device of the current model
    device = next(model.parameters()).device

    # Define some constants and placeholder
    batch_size, _, H, W = input_images.shape
    num_iter = homography_cfg["num_iter"]
    junc_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
    junc_counts = torch.zeros([batch_size, 1, H, W], device=device)
    heatmap_probs = torch.zeros([batch_size, num_iter, H, W], device=device)
    heatmap_counts = torch.zeros([batch_size, 1, H, W], device=device)
    margin = homography_cfg["valid_border_margin"]
    # All-ones mask reused by every warp below (never modified in place)
    full_mask = torch.ones([batch_size, 1, H, W], device=device)

    # Keep a config with no artifacts
    homography_cfg_no_artifacts = copy.copy(homography_cfg["homographies"])
    homography_cfg_no_artifacts["allow_artifacts"] = False

    for idx in range(num_iter):
        # The first num_iter // 5 + 1 iterations use the artifact-free
        # config, so that at least 20% of the homographies have no artifact.
        if idx <= num_iter // 5:
            sampling_cfg = homography_cfg_no_artifacts
        else:
            sampling_cfg = homography_cfg["homographies"]
        H_mat_lst = [
            sample_homography([H, W], **sampling_cfg)[0][None]
            for _ in range(batch_size)
        ]

        H_mats = np.concatenate(H_mat_lst, axis=0)
        H_tensor = torch.tensor(H_mats, dtype=torch.float, device=device)
        H_inv_tensor = torch.inverse(H_tensor)

        # Perform the homography warp
        images_warped = warp_perspective(
            input_images, H_tensor, (H, W), flags="bilinear"
        )

        # Warp the mask. The junction and heatmap branches use the same
        # warped all-ones mask, so it is computed once.
        masks_junc_warped = warp_perspective(
            full_mask, H_tensor, (H, W), flags="nearest"
        )
        masks_heatmap_warped = masks_junc_warped

        # Run the network forward pass
        with torch.no_grad():
            outputs = model(images_warped)

        # Unwarp and mask the junction prediction. The last softmax channel
        # is dropped before pixel_shuffle.
        junc_prob_warped = pixel_shuffle(
            softmax(outputs["junctions"], dim=1)[:, :-1, :, :], grid_size
        )
        junc_prob = warp_perspective(
            junc_prob_warped, H_inv_tensor, (H, W), flags="bilinear"
        )

        # Create the out of boundary mask
        out_boundary_mask = warp_perspective(
            full_mask, H_inv_tensor, (H, W), flags="nearest"
        )
        out_boundary_mask = adjust_border(out_boundary_mask, device, margin)

        junc_prob = junc_prob * out_boundary_mask
        junc_count = warp_perspective(
            masks_junc_warped * out_boundary_mask, H_inv_tensor, (H, W), flags="nearest"
        )

        # Unwarp the mask and heatmap prediction
        # Always fetch only one channel
        if outputs["heatmap"].shape[1] == 2:
            # Convert to single channel directly from here
            heatmap_prob_warped = softmax(outputs["heatmap"], dim=1)[:, 1:, :, :]
        else:
            heatmap_prob_warped = torch.sigmoid(outputs["heatmap"])

        heatmap_prob_warped = heatmap_prob_warped * masks_heatmap_warped
        heatmap_prob = warp_perspective(
            heatmap_prob_warped, H_inv_tensor, (H, W), flags="bilinear"
        )
        heatmap_count = warp_perspective(
            masks_heatmap_warped, H_inv_tensor, (H, W), flags="nearest"
        )

        # Record the results
        junc_probs[:, idx : idx + 1, :, :] = junc_prob
        heatmap_probs[:, idx : idx + 1, :, :] = heatmap_prob
        junc_counts += junc_count
        heatmap_counts += heatmap_count

    # Perform the accumulation operation
    if homography_cfg["min_counts"] > 0:
        min_counts = homography_cfg["min_counts"]
        junc_count_mask = junc_counts < min_counts
        heatmap_count_mask = heatmap_counts < min_counts
        junc_counts[junc_count_mask] = 0
        heatmap_counts[heatmap_count_mask] = 0
    else:
        # Build the empty masks with torch so they live on the same device as
        # the tensors they index; np.zeros_like cannot convert CUDA tensors.
        # NOTE(review): pixels whose count stays 0 are not masked in this
        # branch, so the divisions below can yield inf/NaN for them --
        # confirm whether that is intended.
        junc_count_mask = torch.zeros_like(junc_counts, dtype=torch.bool)
        heatmap_count_mask = torch.zeros_like(heatmap_counts, dtype=torch.bool)

    # Compute the mean accumulation
    junc_probs_mean = torch.sum(junc_probs, dim=1, keepdim=True) / junc_counts
    junc_probs_mean[junc_count_mask] = 0.0
    heatmap_probs_mean = torch.sum(heatmap_probs, dim=1, keepdim=True) / heatmap_counts
    heatmap_probs_mean[heatmap_count_mask] = 0.0

    # Compute the max accumulation
    junc_probs_max = torch.max(junc_probs, dim=1, keepdim=True)[0]
    junc_probs_max[junc_count_mask] = 0.0
    heatmap_probs_max = torch.max(heatmap_probs, dim=1, keepdim=True)[0]
    heatmap_probs_max[heatmap_count_mask] = 0.0

    return {
        "junc_probs_mean": junc_probs_mean,
        "junc_probs_max": junc_probs_max,
        "junc_counts": junc_counts,
        "heatmap_probs_mean": heatmap_probs_mean,
        "heatmap_probs_max": heatmap_probs_max,
        "heatmap_counts": heatmap_counts,
    }
def adjust_border(input_masks, device, margin=3):
    """Adjust the border of the counts and valid_mask.

    Each mask in the batch is eroded with an elliptical structuring element
    of size (2 * margin, 2 * margin).

    Arguments:
        input_masks: Batched mask tensor with a singleton axis 1 (the caller
            in this file passes [batch_size, 1, H, W]).
        device: Device on which the returned tensor is created.
        margin: Half-size of the erosion kernel. A margin of 0 means no
            erosion.
    Returns:
        Tensor with the dtype of ``input_masks``, on ``device``, holding the
        eroded masks with axis 1 restored.
    """
    if margin == 0:
        # A zero margin would request an empty (0x0) structuring element;
        # treat it as "nothing to erode" and hand back an independent copy.
        return input_masks.detach().clone().to(device)

    # Convert the mask to numpy array (OpenCV works on host arrays)
    dtype = input_masks.dtype
    masks_np = np.squeeze(input_masks.cpu().numpy(), axis=1)

    erosion_kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (margin * 2, margin * 2)
    )

    # Erode all the masks, keeping a leading axis so they can be concatenated
    eroded_masks = [
        torch.tensor(cv2.erode(mask, erosion_kernel), dtype=dtype, device=device)[None]
        for mask in masks_np
    ]

    # Concat back along the batch dimension and restore axis 1.
    return torch.cat(eroded_masks, dim=0).unsqueeze(dim=1)