|
""" |
|
This file implements the training process and all the summaries |
|
""" |
|
import os |
|
import numpy as np |
|
import cv2 |
|
import torch |
|
from torch.nn.functional import pixel_shuffle, softmax |
|
from torch.utils.data import DataLoader |
|
import torch.utils.data.dataloader as torch_loader |
|
from tensorboardX import SummaryWriter |
|
|
|
from .dataset.dataset_util import get_dataset |
|
from .model.model_util import get_model |
|
from .model.loss import TotalLoss, get_loss_and_weights |
|
from .model.metrics import AverageMeter, Metrics, super_nms |
|
from .model.lr_scheduler import get_lr_scheduler |
|
from .misc.train_utils import ( |
|
convert_image, |
|
get_latest_checkpoint, |
|
remove_old_checkpoints, |
|
) |
|
|
|
|
|
def customized_collate_fn(batch):
    """Collate a list of samples into a single batch dict.

    Dense, fixed-size fields are stacked with PyTorch's default collate,
    while variable-length fields are passed through as plain Python lists
    (one entry per sample).
    """
    stacked_keys = ["image", "junction_map", "heatmap", "valid_mask"]
    passthrough_keys = ["junctions", "line_map"]

    collated = {
        key: torch_loader.default_collate([sample[key] for sample in batch])
        for key in stacked_keys
    }
    for key in passthrough_keys:
        collated[key] = [sample[key] for sample in batch]

    return collated
|
|
|
|
|
def restore_weights(model, state_dict, strict=True):
    """Restore weights in compatible mode.

    First tries a direct `load_state_dict`. If that fails because the
    checkpoint's parameter names do not match the model's (e.g. renamed
    modules), falls back to a non-strict load and then pairs the model's
    missing keys with the checkpoint's unexpected keys by position.

    Args:
        model: the torch.nn.Module to load weights into (modified in place).
        state_dict: checkpoint state dict to restore from.
        strict: passed through to the first `load_state_dict` attempt.

    Returns:
        The same `model`, with weights restored.
    """
    try:
        model.load_state_dict(state_dict, strict=strict)
    # A key mismatch under strict loading raises RuntimeError; catching only
    # that (instead of a bare `except:`) avoids swallowing KeyboardInterrupt
    # and unrelated errors.
    except RuntimeError:
        err = model.load_state_dict(state_dict, strict=False)

        missing_keys = err.missing_keys
        unexpected_keys = err.unexpected_keys

        model_dict = model.state_dict()
        # BatchNorm "num_batches_tracked" buffers are skipped because they
        # have no counterpart to remap. NOTE: this positional pairing assumes
        # missing and unexpected keys appear in the same relative order —
        # it only works for pure renames, not structural changes.
        dict_keys = [k for k in unexpected_keys if "tracked" not in k]
        for idx, key in enumerate(missing_keys):
            model_dict[key] = state_dict[dict_keys[idx]]
        model.load_state_dict(model_dict)

    return model
|
|
|
|
|
def train_net(args, dataset_cfg, model_cfg, output_path):
    """Main training function.

    Builds the dataloaders, model, optimizer and LR scheduler (optionally
    restoring all of them from a checkpoint or initializing the model from
    pretrained weights), then alternates training and validation epochs,
    writing tensorboard summaries and saving a checkpoint after each epoch.

    Args:
        args: parsed CLI arguments (resume/pretrained flags and paths).
        dataset_cfg: dataset configuration dict.
        model_cfg: model/training configuration dict.
        output_path: directory for checkpoints and tensorboard logs.
    """
    # Default to static loss weighting when the config does not specify one.
    if model_cfg.get("weighting_policy") is None:
        model_cfg["weighting_policy"] = "static"

    train_cfg = model_cfg["train"]
    test_cfg = model_cfg["test"]

    # Initialize the datasets and dataloaders.
    print("\t Initializing dataset...")
    train_dataset, train_collate_fn = get_dataset("train", dataset_cfg)
    test_dataset, test_collate_fn = get_dataset("test", dataset_cfg)

    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg["batch_size"],
        num_workers=8,
        shuffle=True,
        pin_memory=True,
        collate_fn=train_collate_fn,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=test_cfg.get("batch_size", 1),
        num_workers=test_cfg.get("num_workers", 1),
        shuffle=False,
        pin_memory=False,
        collate_fn=test_collate_fn,
    )
    print("\t Successfully intialized dataloaders.")

    loss_funcs, loss_weights = get_loss_and_weights(model_cfg)

    # Build the model, optionally restoring resume/pretrained weights.
    model = get_model(model_cfg, loss_weights)
    checkpoint = None
    if args.resume:
        checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
        model = restore_weights(model, checkpoint["model_state_dict"])
    elif args.pretrained:
        print("\t [Debug] Loading pretrained weights...")
        checkpoint = get_latest_checkpoint(
            args.pretrained_path, args.checkpoint_name
        )
        # Non-strict load: pretrained checkpoints may not cover all modules.
        model = restore_weights(model, checkpoint["model_state_dict"], strict=False)
        print("\t [Debug] Finished loading pretrained weights!")
    model = model.cuda()

    optimizer = _make_optimizer(model, model_cfg["learning_rate"])
    # Restore the optimizer state BEFORE building the scheduler so the
    # scheduler picks up the restored learning rates.
    if args.resume:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    scheduler = get_lr_scheduler(
        lr_decay=model_cfg.get("lr_decay", False),
        lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
        optimizer=optimizer,
    )
    if args.resume:
        # Older checkpoints may predate scheduler support.
        if (scheduler is not None) and (
            checkpoint.get("scheduler_state_dict", None) is not None
        ):
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
    else:
        start_epoch = 0

    print("\t Successfully initialized model")

    # Loss and metric functions.
    policy = model_cfg.get("weighting_policy", "static")
    loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda()
    if "descriptor_decoder" in model_cfg:
        metric_func = Metrics(
            model_cfg["detection_thresh"],
            model_cfg["prob_thresh"],
            model_cfg["descriptor_loss_cfg"]["grid_size"],
            desc_metric_lst="all",
        )
    else:
        metric_func = Metrics(
            model_cfg["detection_thresh"],
            model_cfg["prob_thresh"],
            model_cfg["grid_size"],
        )

    # Tensorboard writer.
    logdir = os.path.join(output_path, "log")
    writer = SummaryWriter(logdir=logdir)

    for epoch in range(start_epoch, model_cfg["epochs"]):
        # Log the learning rate in effect for this epoch.
        current_lr = optimizer.state_dict()["param_groups"][0]["lr"]
        writer.add_scalar("LR/lr", current_lr, epoch)

        print("\n\n================== Training ====================")
        train_single_epoch(
            model=model,
            model_cfg=model_cfg,
            optimizer=optimizer,
            loss_func=loss_func,
            metric_func=metric_func,
            train_loader=train_loader,
            writer=writer,
            epoch=epoch,
        )

        print("\n\n================== Validation ==================")
        validate(
            model=model,
            model_cfg=model_cfg,
            loss_func=loss_func,
            metric_func=metric_func,
            val_loader=test_loader,
            writer=writer,
            epoch=epoch,
        )

        if scheduler is not None:
            scheduler.step()

        # Save an end-of-epoch checkpoint.
        file_name = os.path.join(output_path, "checkpoint-epoch%03d-end.tar" % (epoch))
        print("[Info] Saving checkpoint %s ..." % file_name)
        save_dict = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "model_cfg": model_cfg,
        }
        if scheduler is not None:
            save_dict.update({"scheduler_state_dict": scheduler.state_dict()})
        torch.save(save_dict, file_name)

        # Keep only the most recent checkpoints on disk.
        remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15))


def _make_optimizer(model, learning_rate):
    """Build the Adam optimizer, recording `initial_lr` so that LR
    schedulers can be resumed with `last_epoch` set."""
    return torch.optim.Adam(
        [{"params": model.parameters(), "initial_lr": learning_rate}],
        learning_rate,
        amsgrad=True,
    )
|
|
|
|
|
def train_single_epoch(
    model, model_cfg, optimizer, loss_func, metric_func, train_loader, writer, epoch
):
    """Train for one epoch.

    Runs one full pass over `train_loader`, optimizing the model, and on
    display / summary iterations evaluates metrics, prints progress and
    writes tensorboard summaries.
    """

    model.train()

    # When the loss includes a descriptor term, each batch carries a pair of
    # views ("ref" / "target") plus line-point correspondences, and the
    # metrics additionally track descriptor matching scores.
    compute_descriptors = loss_func.compute_descriptors
    if compute_descriptors:
        average_meter = AverageMeter(is_training=True, desc_metric_lst="all")
    else:
        average_meter = AverageMeter(is_training=True)

    for idx, data in enumerate(train_loader):
        if compute_descriptors:
            # Two augmented views of the same sample.
            junc_map = data["ref_junction_map"].cuda()
            junc_map2 = data["target_junction_map"].cuda()
            heatmap = data["ref_heatmap"].cuda()
            heatmap2 = data["target_heatmap"].cuda()
            line_points = data["ref_line_points"].cuda()
            line_points2 = data["target_line_points"].cuda()
            line_indices = data["ref_line_indices"].cuda()
            valid_mask = data["ref_valid_mask"].cuda()
            valid_mask2 = data["target_valid_mask"].cuda()
            input_images = data["ref_image"].cuda()
            input_images2 = data["target_image"].cuda()

            # Forward both views through the shared model.
            outputs = model(input_images)
            outputs2 = model(input_images2)

            losses = loss_func.forward_descriptors(
                outputs["junctions"],
                outputs2["junctions"],
                junc_map,
                junc_map2,
                outputs["heatmap"],
                outputs2["heatmap"],
                heatmap,
                heatmap2,
                line_points,
                line_points2,
                line_indices,
                outputs["descriptors"],
                outputs2["descriptors"],
                epoch,
                valid_mask,
                valid_mask2,
            )
        else:
            # Detector-only training: one view per sample.
            junc_map = data["junction_map"].cuda()
            heatmap = data["heatmap"].cuda()
            valid_mask = data["valid_mask"].cuda()
            input_images = data["image"].cuda()

            outputs = model(input_images)

            losses = loss_func(
                outputs["junctions"], junc_map, outputs["heatmap"], heatmap, valid_mask
            )

        total_loss = losses["total_loss"]

        # Standard optimization step.
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Global step for tensorboard scalars.
        global_step = epoch * len(train_loader) + idx

        # Metrics and numpy conversions are only computed on display /
        # summary iterations; both reporting paths below reuse them.
        if ((idx % model_cfg["disp_freq"]) == 0) or (
            (idx % model_cfg["summary_freq"]) == 0
        ):
            junc_np = convert_junc_predictions(
                outputs["junctions"],
                model_cfg["grid_size"],
                model_cfg["detection_thresh"],
                300,
            )
            # NCHW -> NHWC for the numpy-based metric evaluation.
            junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)

            # 2-channel heatmaps are softmaxed (keep the positive channel);
            # single-channel heatmaps go through a sigmoid instead.
            if outputs["heatmap"].shape[1] == 2:
                heatmap_np = softmax(outputs["heatmap"].detach(), dim=1).cpu().numpy()
                heatmap_np = heatmap_np.transpose(0, 2, 3, 1)[:, :, :, 1:]
            else:
                heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
                heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)

            heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
            valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)

            # Evaluate detection (and optionally descriptor) metrics.
            if compute_descriptors:
                metric_func.evaluate(
                    junc_np["junc_pred"],
                    junc_np["junc_pred_nms"],
                    junc_map_np,
                    heatmap_np,
                    heatmap_gt_np,
                    valid_mask_np,
                    line_points,
                    line_points2,
                    outputs["descriptors"],
                    outputs2["descriptors"],
                    line_indices,
                )
            else:
                metric_func.evaluate(
                    junc_np["junc_pred"],
                    junc_np["junc_pred_nms"],
                    junc_map_np,
                    heatmap_np,
                    heatmap_gt_np,
                    valid_mask_np,
                )

            junc_loss = losses["junc_loss"].item()
            heatmap_loss = losses["heatmap_loss"].item()
            loss_dict = {
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "total_loss": total_loss.item(),
            }
            if compute_descriptors:
                descriptor_loss = losses["descriptor_loss"].item()
                loss_dict["descriptor_loss"] = losses["descriptor_loss"].item()

            # NOTE(review): the running averages are only updated on
            # display/summary iterations, so they cover a subsample of the
            # batches — confirm this is intended.
            average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0])

        # Console progress report.
        if (idx % model_cfg["disp_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()

            # Peak GPU memory in GB.
            gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
            if compute_descriptors:
                print(
                    "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f), gpu_mem=%.4fGB"
                    % (
                        epoch,
                        model_cfg["epochs"],
                        idx,
                        len(train_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        descriptor_loss,
                        average["descriptor_loss"],
                        gpu_mem_usage,
                    )
                )
            else:
                print(
                    "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), gpu_mem=%.4fGB"
                    % (
                        epoch,
                        model_cfg["epochs"],
                        idx,
                        len(train_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        gpu_mem_usage,
                    )
                )
            print(
                "\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision"],
                    average["junc_precision"],
                    results["junc_recall"],
                    average["junc_recall"],
                )
            )
            print(
                "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision_nms"],
                    average["junc_precision_nms"],
                    results["junc_recall_nms"],
                    average["junc_recall_nms"],
                )
            )
            print(
                "\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["heatmap_precision"],
                    average["heatmap_precision"],
                    results["heatmap_recall"],
                    average["heatmap_recall"],
                )
            )
            if compute_descriptors:
                print(
                    "\t Descriptors matching score=%.4f (%.4f)"
                    % (results["matching_score"], average["matching_score"])
                )

        # Tensorboard summaries: scalars plus qualitative image plots.
        if (idx % model_cfg["summary_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()

            scalar_summaries = {
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "total_loss": total_loss.detach().cpu().numpy(),
                "metrics": results,
                "average": average,
            }

            if compute_descriptors:
                scalar_summaries["descriptor_loss"] = descriptor_loss
                scalar_summaries["w_desc"] = losses["w_desc"]

            # Current loss-weighting values and the regularization term.
            scalar_summaries["w_junc"] = losses["w_junc"]
            scalar_summaries["w_heatmap"] = losses["w_heatmap"]
            scalar_summaries["reg_loss"] = losses["reg_loss"].item()

            # Only plot the first few samples of the batch.
            num_images = 3
            junc_pred_binary = (
                junc_np["junc_pred"][:num_images, ...] > model_cfg["detection_thresh"]
            )
            junc_pred_nms_binary = (
                junc_np["junc_pred_nms"][:num_images, ...]
                > model_cfg["detection_thresh"]
            )
            image_summaries = {
                "image": input_images.cpu().numpy()[:num_images, ...],
                "valid_mask": valid_mask_np[:num_images, ...],
                "junc_map_pred": junc_pred_binary,
                "junc_map_pred_nms": junc_pred_nms_binary,
                "junc_map_gt": junc_map_np[:num_images, ...],
                "junc_prob_map": junc_np["junc_prob"][:num_images, ...],
                "heatmap_pred": heatmap_np[:num_images, ...],
                "heatmap_gt": heatmap_gt_np[:num_images, ...],
            }

            record_train_summaries(
                writer, global_step, scalars=scalar_summaries, images=image_summaries
            )
|
|
|
|
|
def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch):
    """Validation.

    Runs one pass over `val_loader` with gradients disabled for the forward
    passes, accumulates per-epoch average losses/metrics, and records them
    to tensorboard at the end.
    """

    model.eval()

    # Mirror of the training setup: descriptor training uses paired views
    # and extra matching metrics.
    compute_descriptors = loss_func.compute_descriptors
    if compute_descriptors:
        average_meter = AverageMeter(is_training=True, desc_metric_lst="all")
    else:
        average_meter = AverageMeter(is_training=True)

    for idx, data in enumerate(val_loader):
        if compute_descriptors:
            # Two augmented views of the same sample.
            junc_map = data["ref_junction_map"].cuda()
            junc_map2 = data["target_junction_map"].cuda()
            heatmap = data["ref_heatmap"].cuda()
            heatmap2 = data["target_heatmap"].cuda()
            line_points = data["ref_line_points"].cuda()
            line_points2 = data["target_line_points"].cuda()
            line_indices = data["ref_line_indices"].cuda()
            valid_mask = data["ref_valid_mask"].cuda()
            valid_mask2 = data["target_valid_mask"].cuda()
            input_images = data["ref_image"].cuda()
            input_images2 = data["target_image"].cuda()

            # No gradients needed for validation forward passes.
            with torch.no_grad():
                outputs = model(input_images)
                outputs2 = model(input_images2)

            losses = loss_func.forward_descriptors(
                outputs["junctions"],
                outputs2["junctions"],
                junc_map,
                junc_map2,
                outputs["heatmap"],
                outputs2["heatmap"],
                heatmap,
                heatmap2,
                line_points,
                line_points2,
                line_indices,
                outputs["descriptors"],
                outputs2["descriptors"],
                epoch,
                valid_mask,
                valid_mask2,
            )
        else:
            # Detector-only validation: one view per sample.
            junc_map = data["junction_map"].cuda()
            heatmap = data["heatmap"].cuda()
            valid_mask = data["valid_mask"].cuda()
            input_images = data["image"].cuda()

            with torch.no_grad():
                outputs = model(input_images)

            losses = loss_func(
                outputs["junctions"],
                junc_map,
                outputs["heatmap"],
                heatmap,
                valid_mask,
            )
        total_loss = losses["total_loss"]

        # Convert predictions to numpy (NHWC) for metric evaluation.
        junc_np = convert_junc_predictions(
            outputs["junctions"],
            model_cfg["grid_size"],
            model_cfg["detection_thresh"],
            300,
        )
        junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)

        # 2-channel heatmaps are softmaxed (keep the positive channel);
        # single-channel heatmaps go through a sigmoid instead.
        if outputs["heatmap"].shape[1] == 2:
            heatmap_np = (
                softmax(outputs["heatmap"].detach(), dim=1)
                .cpu()
                .numpy()
                .transpose(0, 2, 3, 1)
            )
            heatmap_np = heatmap_np[:, :, :, 1:]
        else:
            heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
            heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)

        heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
        valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)

        # Evaluate detection (and optionally descriptor) metrics.
        if compute_descriptors:
            metric_func.evaluate(
                junc_np["junc_pred"],
                junc_np["junc_pred_nms"],
                junc_map_np,
                heatmap_np,
                heatmap_gt_np,
                valid_mask_np,
                line_points,
                line_points2,
                outputs["descriptors"],
                outputs2["descriptors"],
                line_indices,
            )
        else:
            metric_func.evaluate(
                junc_np["junc_pred"],
                junc_np["junc_pred_nms"],
                junc_map_np,
                heatmap_np,
                heatmap_gt_np,
                valid_mask_np,
            )

        # Unlike training, the running averages are updated every batch.
        junc_loss = losses["junc_loss"].item()
        heatmap_loss = losses["heatmap_loss"].item()
        loss_dict = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            "total_loss": total_loss.item(),
        }
        if compute_descriptors:
            descriptor_loss = losses["descriptor_loss"].item()
            loss_dict["descriptor_loss"] = losses["descriptor_loss"].item()
        average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0])

        # Console progress report.
        if (idx % model_cfg["disp_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()
            if compute_descriptors:
                print(
                    "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f)"
                    % (
                        idx,
                        len(val_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        descriptor_loss,
                        average["descriptor_loss"],
                    )
                )
            else:
                print(
                    "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)"
                    % (
                        idx,
                        len(val_loader),
                        total_loss.item(),
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                    )
                )
            print(
                "\t Junction precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision"],
                    average["junc_precision"],
                    results["junc_recall"],
                    average["junc_recall"],
                )
            )
            print(
                "\t Junction nms precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["junc_precision_nms"],
                    average["junc_precision_nms"],
                    results["junc_recall_nms"],
                    average["junc_recall_nms"],
                )
            )
            print(
                "\t Heatmap precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                % (
                    results["heatmap_precision"],
                    average["heatmap_precision"],
                    results["heatmap_recall"],
                    average["heatmap_recall"],
                )
            )
            if compute_descriptors:
                print(
                    "\t Descriptors matching score=%.4f (%.4f)"
                    % (results["matching_score"], average["matching_score"])
                )

    # Record the per-epoch averages to tensorboard.
    average = average_meter.average()
    scalar_summaries = {"average": average}

    record_test_summaries(writer, epoch, scalar_summaries)
|
|
|
|
|
def convert_junc_predictions(predictions, grid_size, detect_thresh=1 / 65, topk=300):
    """Convert torch junction predictions to numpy arrays for evaluation.

    Args:
        predictions: raw per-cell junction logits (N, grid_size^2 + 1, H, W),
            where the last channel is the "no junction" dustbin.
        grid_size: upsampling factor used by pixel_shuffle.
        detect_thresh: detection threshold passed to the NMS.
        topk: maximum number of detections kept by the NMS.

    Returns:
        Dict with the full-resolution prediction map, its NMS-ed version,
        and the coarse per-cell junction probability map.
    """
    prob_full = softmax(predictions.detach(), dim=1).cpu()
    # Drop the dustbin channel before upsampling to full resolution.
    prob_wo_dustbin = prob_full[:, :-1, :, :]

    # Coarse probability of "any junction in this cell".
    junc_prob_np = prob_full.numpy().transpose(0, 2, 3, 1)[:, :, :, :-1]
    junc_prob_np = junc_prob_np.sum(axis=-1)

    junc_pred_np = (
        pixel_shuffle(prob_wo_dustbin, grid_size).cpu().numpy().transpose(0, 2, 3, 1)
    )
    junc_pred_np_nms = super_nms(junc_pred_np, grid_size, detect_thresh, topk)
    junc_pred_np = junc_pred_np.squeeze(-1)

    return {
        "junc_pred": junc_pred_np,
        "junc_pred_nms": junc_pred_np_nms,
        "junc_prob": junc_prob_np,
    }
|
|
|
|
|
def record_train_summaries(writer, global_step, scalars, images):
    """Record training summaries.

    Writes GPU usage, instantaneous and running-average losses, loss
    weights, metrics, and qualitative image plots to tensorboard.

    Args:
        writer: tensorboard SummaryWriter.
        global_step: global iteration index for the scalar/image tags.
        scalars: dict with per-iteration losses, "metrics" (instantaneous)
            and "average" (running averages) sub-dicts.
        images: dict of numpy arrays to visualize (first few batch samples).
    """
    results = scalars["metrics"]
    average = scalars["average"]

    # Peak GPU memory in GB.
    gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
    writer.add_scalar("GPU/GPU_memory_usage", gpu_mem_usage, global_step)

    # Instantaneous losses.
    writer.add_scalar("Train_loss/junc_loss", scalars["junc_loss"], global_step)
    writer.add_scalar("Train_loss/heatmap_loss", scalars["heatmap_loss"], global_step)
    writer.add_scalar("Train_loss/total_loss", scalars["total_loss"], global_step)

    # Regularization term (present for dynamic weighting policies).
    if "reg_loss" in scalars.keys():
        writer.add_scalar("Train_loss/reg_loss", scalars["reg_loss"], global_step)

    # Descriptor loss (only in descriptor-training mode).
    if "descriptor_loss" in scalars.keys():
        key = "descriptor_loss"
        writer.add_scalar("Train_loss/%s" % (key), scalars[key], global_step)
        writer.add_scalar("Train_loss_average/%s" % (key), average[key], global_step)

    # Loss-weighting values (keys like "w_junc", "w_heatmap", "w_desc").
    for key in scalars.keys():
        if "w_" in key:
            writer.add_scalar("Train_weight/%s" % (key), scalars[key], global_step)

    # Running-average losses.
    writer.add_scalar("Train_loss_average/junc_loss", average["junc_loss"], global_step)
    writer.add_scalar(
        "Train_loss_average/heatmap_loss", average["heatmap_loss"], global_step
    )
    writer.add_scalar(
        "Train_loss_average/total_loss", average["total_loss"], global_step
    )

    if "descriptor_loss" in average.keys():
        writer.add_scalar(
            "Train_loss_average/descriptor_loss",
            average["descriptor_loss"],
            global_step,
        )

    # Instantaneous metrics.
    writer.add_scalar(
        "Train_metrics/junc_precision", results["junc_precision"], global_step
    )
    writer.add_scalar(
        "Train_metrics/junc_precision_nms", results["junc_precision_nms"], global_step
    )
    writer.add_scalar("Train_metrics/junc_recall", results["junc_recall"], global_step)
    writer.add_scalar(
        "Train_metrics/junc_recall_nms", results["junc_recall_nms"], global_step
    )
    writer.add_scalar(
        "Train_metrics/heatmap_precision", results["heatmap_precision"], global_step
    )
    writer.add_scalar(
        "Train_metrics/heatmap_recall", results["heatmap_recall"], global_step
    )

    if "matching_score" in results.keys():
        writer.add_scalar(
            "Train_metrics/matching_score", results["matching_score"], global_step
        )

    # Running-average metrics.
    writer.add_scalar(
        "Train_metrics_average/junc_precision", average["junc_precision"], global_step
    )
    writer.add_scalar(
        "Train_metrics_average/junc_precision_nms",
        average["junc_precision_nms"],
        global_step,
    )
    writer.add_scalar(
        "Train_metrics_average/junc_recall", average["junc_recall"], global_step
    )
    writer.add_scalar(
        "Train_metrics_average/junc_recall_nms", average["junc_recall_nms"], global_step
    )
    writer.add_scalar(
        "Train_metrics_average/heatmap_precision",
        average["heatmap_precision"],
        global_step,
    )
    writer.add_scalar(
        "Train_metrics_average/heatmap_recall", average["heatmap_recall"], global_step
    )

    if "matching_score" in average.keys():
        writer.add_scalar(
            "Train_metrics_average/matching_score",
            average["matching_score"],
            global_step,
        )

    # Qualitative image summaries: inputs, masks and heatmaps.
    image_tensor = convert_image(images["image"], 1)
    valid_masks = convert_image(images["valid_mask"], -1)
    writer.add_images("Train/images", image_tensor, global_step, dataformats="NCHW")
    writer.add_images("Train/valid_map", valid_masks, global_step, dataformats="NHWC")

    writer.add_images(
        "Train/heatmap_gt",
        convert_image(images["heatmap_gt"], -1),
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/heatmap_pred",
        convert_image(images["heatmap_pred"], -1),
        global_step,
        dataformats="NHWC",
    )

    # Junction detections drawn on top of the input images.
    junc_plots = plot_junction_detection(
        image_tensor,
        images["junc_map_pred"],
        images["junc_map_pred_nms"],
        images["junc_map_gt"],
    )
    writer.add_images(
        "Train/junc_gt",
        junc_plots["junc_gt_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_pred",
        junc_plots["junc_pred_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_pred_nms",
        junc_plots["junc_pred_nms_plot"] / 255.0,
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/junc_prob_map",
        convert_image(images["junc_prob_map"][..., None], axis=-1),
        global_step,
        dataformats="NHWC",
    )
|
|
|
|
|
def record_test_summaries(writer, epoch, scalars):
    """Write per-epoch validation averages to tensorboard.

    Args:
        writer: tensorboard SummaryWriter.
        epoch: epoch index used as the scalar step.
        scalars: dict with an "average" sub-dict of averaged losses/metrics.
    """
    average = scalars["average"]

    # Averaged losses.
    for name in ("junc_loss", "heatmap_loss", "total_loss"):
        writer.add_scalar("Val_loss/%s" % name, average[name], epoch)
    # Descriptor loss only exists in descriptor-training mode.
    if "descriptor_loss" in average.keys():
        writer.add_scalar(
            "Val_loss/descriptor_loss", average["descriptor_loss"], epoch
        )

    # Averaged detection metrics.
    for name in (
        "junc_precision",
        "junc_precision_nms",
        "junc_recall",
        "junc_recall_nms",
        "heatmap_precision",
        "heatmap_recall",
    ):
        writer.add_scalar("Val_metrics/%s" % name, average[name], epoch)

    if "matching_score" in average.keys():
        writer.add_scalar(
            "Val_metrics/matching_score", average["matching_score"], epoch
        )
|
|
|
|
|
def plot_junction_detection(
    image_tensor, junc_pred_tensor, junc_pred_nms_tensor, junc_gt_tensor
):
    """Plot the junction points on images.

    Args:
        image_tensor: batch of images, NCHW, float in [0, 1].
        junc_pred_tensor: binary prediction maps (nonzero = junction).
        junc_pred_nms_tensor: binary prediction maps after NMS.
        junc_gt_tensor: binary ground-truth junction maps.

    Returns:
        Dict of NHWC uint8 image batches with circles drawn at the
        ground-truth (red) and predicted (green) junction locations.
    """
    batch_size = image_tensor.shape[0]

    junc_pred_lst = []
    junc_pred_nms_lst = []
    junc_gt_lst = []
    for i in range(batch_size):
        # CHW float in [0, 1] -> HWC uint8 for OpenCV drawing.
        image = (image_tensor[i, :, :, :] * 255.0).astype(np.uint8).transpose(1, 2, 0)

        junc_gt_lst.append(
            _draw_junctions(image, junc_gt_tensor[i, ...], (255, 0, 0))[None, ...]
        )
        junc_pred_lst.append(
            _draw_junctions(image, junc_pred_tensor[i, ...], (0, 255, 0))[None, ...]
        )
        junc_pred_nms_lst.append(
            _draw_junctions(image, junc_pred_nms_tensor[i, ...], (0, 255, 0))[None, ...]
        )

    return {
        "junc_gt_plot": np.concatenate(junc_gt_lst, axis=0),
        "junc_pred_plot": np.concatenate(junc_pred_lst, axis=0),
        "junc_pred_nms_plot": np.concatenate(junc_pred_nms_lst, axis=0),
    }


def _draw_junctions(image, junc_map, color):
    """Return a copy of `image` with a circle at every nonzero pixel of
    `junc_map` (HW or HW1), drawn in the given BGR/RGB `color`."""
    coords = np.where(np.squeeze(junc_map) > 0)
    points = np.concatenate((coords[0][..., None], coords[1][..., None]), axis=1)
    plot = image.copy()
    for point_idx in range(points.shape[0]):
        # np.flip converts (row, col) to the (x, y) order OpenCV expects.
        cv2.circle(
            plot,
            tuple(np.flip(points[point_idx, :])),
            3,
            color=color,
            thickness=2,
        )
    return plot
|
|