|
""" |
|
Convert the aggregated homography adaptation predictions into ground-truth (GT) labels (junctions and line maps), written to an HDF5 file.
|
""" |
|
import sys |
|
|
|
sys.path.append("../") |
|
import os |
|
import yaml |
|
import argparse |
|
import numpy as np |
|
import h5py |
|
import torch |
|
from tqdm import tqdm |
|
|
|
from config.project_config import Config as cfg |
|
from model.line_detection import LineSegmentDetectionModule |
|
from model.metrics import super_nms |
|
from misc.train_utils import parse_h5_data |
|
|
|
|
|
def _nonzero_row_col(junction_map):
    """Return an Nx2 ndarray of (row, col) indices of the non-zero entries.

    Only the first two index arrays returned by ``np.where`` are used, which
    matches a 2-D junction map.
    """
    indices = np.where(junction_map)
    return np.stack([indices[0], indices[1]], axis=-1)


def convert_raw_exported_predictions(
    input_data, grid_size=8, detect_thresh=1 / 65, topk=300
):
    """Convert the exported junction and heatmap predictions to a standard format.

    Two input layouts are supported:

    * Aggregated exports containing ``"junc_prob_mean"`` and
      ``"heatmap_prob_mean"``. Junctions are extracted from the mean junction
      probability map with ``super_nms``. The heatmap is squeezed, and an
      all-ones ``int32`` valid mask of the heatmap's shape is created.
    * Exports already containing ``"junc_pred_nms"``, ``"heatmap_pred"`` and
      ``"valid_mask"``. Only the junction map is converted to coordinates; the
      heatmap and valid mask are returned unchanged. ``grid_size``,
      ``detect_thresh`` and ``topk`` are ignored in this case.

    Arguments:
        input_data: the raw data (dict) decoded from the hdf5 dataset.
        grid_size: grid size forwarded to ``super_nms``.
        detect_thresh: detection threshold forwarded to ``super_nms``.
        topk: top-k limit forwarded to ``super_nms`` (presumably the maximum
            number of kept junctions -- confirm in ``super_nms``).

    Returns:
        dict containing:
            junctions_pred: Nx2 integer ndarray with the (row, col) positions
                of the non-zero entries of the NMS junction map.
            heatmap_pred: ndarray containing the predicted heatmap.
            valid_mask: ndarray containing the valid mask.
    """
    if "junc_prob_mean" in input_data and "heatmap_prob_mean" in input_data:
        # super_nms is called on the probability map with an extra leading axis;
        # the result is squeezed back before extracting coordinates.
        junc_prob = input_data["junc_prob_mean"][None, ...]
        junc_pred_nms = super_nms(junc_prob, grid_size, detect_thresh, topk)
        junc_points_pred = _nonzero_row_col(junc_pred_nms.squeeze())

        heatmap_pred = input_data["heatmap_prob_mean"].squeeze()
        # The aggregated export carries no mask, so every pixel is marked valid.
        valid_mask = np.ones(heatmap_pred.shape, dtype=np.int32)
    else:
        junc_points_pred = _nonzero_row_col(input_data["junc_pred_nms"])
        heatmap_pred = input_data["heatmap_pred"]
        valid_mask = input_data["valid_mask"]

    return {
        "junctions_pred": junc_points_pred,
        "heatmap_pred": heatmap_pred,
        "valid_mask": valid_mask,
    }
|
|
|
|
|
if __name__ == "__main__":
    # Command line: <input_dataset> <output_dataset> <config>.
    parser = argparse.ArgumentParser()
    parser.add_argument("input_dataset", type=str, help="Name of the exported dataset.")
    parser.add_argument("output_dataset", type=str, help="Name of the output dataset.")
    parser.add_argument("config", type=str, help="Path to the model config.")
    args = parser.parse_args()

    # Both datasets live under the configured export root.
    exported_dataset_path = os.path.join(cfg.export_dataroot, args.input_dataset)
    if not os.path.exists(exported_dataset_path):
        raise ValueError("Missing input dataset: " + exported_dataset_path)
    output_dataset_path = os.path.join(cfg.export_dataroot, args.output_dataset)

    # Line detection is run on the GPU.
    device = torch.device("cuda")

    # Load the model / line detector configuration.
    if not os.path.exists(args.config):
        raise ValueError("Missing config file: " + args.config)
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    model_cfg = config["model_cfg"]
    line_detector_cfg = config["line_detector_cfg"]

    line_detector = LineSegmentDetectionModule(**line_detector_cfg)

    # Context managers ensure both HDF5 files are closed, even on error.
    with h5py.File(exported_dataset_path, "r") as exported_dataset, h5py.File(
        output_dataset_path, "w"
    ) as output_dataset:
        for output_key in tqdm(list(exported_dataset.keys()), ascii=True):
            data = parse_h5_data(exported_dataset[output_key])

            # Normalize the exported predictions (NMS junctions + heatmap).
            converted_data = convert_raw_exported_predictions(
                data,
                grid_size=model_cfg["grid_size"],
                detect_thresh=model_cfg["detection_thresh"],
            )

            # Only the line map and refined junctions are stored as labels;
            # the returned heatmap is not needed.
            line_map_pred, junctions_pred, _ = line_detector.detect(
                converted_data["junctions_pred"],
                converted_data["heatmap_pred"],
                device=device,
            )
            # detect() outputs are handled whether they are torch tensors or
            # already numpy arrays.
            if isinstance(line_map_pred, torch.Tensor):
                line_map_pred = line_map_pred.cpu().numpy()
            if isinstance(junctions_pred, torch.Tensor):
                junctions_pred = junctions_pred.cpu().numpy()

            output_data = {"junctions": junctions_pred, "line_map": line_map_pred}

            # One HDF5 group per input sample, keyed like the input dataset.
            f_group = output_dataset.create_group(output_key)
            for key, value in output_data.items():
                f_group.create_dataset(key, data=value, compression="gzip")
|
|