Vincentqyw
fix: roma
358ab8f
raw
history blame
35.5 kB
"""
This file implements the training process and all the summaries
"""
import os
import numpy as np
import cv2
import torch
from torch.nn.functional import pixel_shuffle, softmax
from torch.utils.data import DataLoader
import torch.utils.data.dataloader as torch_loader
from tensorboardX import SummaryWriter
from .dataset.dataset_util import get_dataset
from .model.model_util import get_model
from .model.loss import TotalLoss, get_loss_and_weights
from .model.metrics import AverageMeter, Metrics, super_nms
from .model.lr_scheduler import get_lr_scheduler
from .misc.train_utils import (
convert_image,
get_latest_checkpoint,
remove_old_checkpoints,
)
def customized_collate_fn(batch):
    """Collate a list of sample dicts into one batch dict.

    The dense entries ("image", "junction_map", "heatmap", "valid_mask") are
    merged with PyTorch's ``default_collate``. The variable-size entries
    ("junctions", "line_map") are kept as plain Python lists holding one item
    per sample.

    Args:
        batch: list of per-sample dicts containing all of the keys above.

    Returns:
        dict mapping each key to its collated value.
    """
    stacked_keys = ("image", "junction_map", "heatmap", "valid_mask")
    ragged_keys = ("junctions", "line_map")
    collated = {
        name: torch_loader.default_collate([sample[name] for sample in batch])
        for name in stacked_keys
    }
    collated.update(
        {name: [sample[name] for sample in batch] for name in ragged_keys}
    )
    return collated
def restore_weights(model, state_dict, strict=True):
    """Restore weights in compatible mode.

    A direct ``model.load_state_dict(state_dict, strict=strict)`` is tried
    first. If that raises ``RuntimeError`` (what ``load_state_dict`` raises on
    key mismatches), the weights are loaded non-strictly. Each model key
    still missing is then filled, in order, from the checkpoint keys the
    model did not recognise. Keys containing "tracked" are skipped; these are
    presumably BatchNorm ``num_batches_tracked`` counters. The positional
    pairing assumes the checkpoint differs from the model only in key names
    (TODO confirm for new checkpoint formats).

    Args:
        model: ``torch.nn.Module`` to restore into.
        state_dict: mapping of parameter/buffer names to tensors.
        strict: forwarded to the first ``load_state_dict`` attempt.

    Returns:
        The same ``model`` instance with the weights loaded.

    Raises:
        RuntimeError: if fewer usable checkpoint keys than missing model keys
            are available for the positional mapping, or if a fallback
            ``load_state_dict`` call fails (e.g. on a shape mismatch).
    """
    try:
        model.load_state_dict(state_dict, strict=strict)
    except RuntimeError:
        # Version-compatibility fallback: load everything that matches by
        # name, then remap the remaining keys positionally.
        incompatible = model.load_state_dict(state_dict, strict=False)
        # missing_keys: in the model but not in state_dict.
        # unexpected_keys: in state_dict but not in the model.
        source_keys = [
            key for key in incompatible.unexpected_keys if "tracked" not in key
        ]
        if len(source_keys) < len(incompatible.missing_keys):
            raise RuntimeError(
                "Cannot restore weights: %d model keys are missing but only %d "
                "checkpoint keys are available to map onto them."
                % (len(incompatible.missing_keys), len(source_keys))
            )
        model_dict = model.state_dict()
        for target_key, source_key in zip(incompatible.missing_keys, source_keys):
            model_dict[target_key] = state_dict[source_key]
        model.load_state_dict(model_dict)
    return model
def _build_optimizer(model, model_cfg):
    """Create the AMSGrad Adam optimizer for ``model`` at ``model_cfg["learning_rate"]``."""
    learning_rate = model_cfg["learning_rate"]
    # "initial_lr" is stored on the param group, presumably so that the
    # scheduler from get_lr_scheduler can be (re)created mid-training.
    return torch.optim.Adam(
        [{"params": model.parameters(), "initial_lr": learning_rate}],
        learning_rate,
        amsgrad=True,
    )


def _build_scheduler(model_cfg, optimizer):
    """Return the LR scheduler configured in ``model_cfg`` (may be None)."""
    return get_lr_scheduler(
        lr_decay=model_cfg.get("lr_decay", False),
        lr_decay_cfg=model_cfg.get("lr_decay_cfg", None),
        optimizer=optimizer,
    )


def train_net(args, dataset_cfg, model_cfg, output_path):
    """Main training function.

    Builds the train/test dataloaders, the model, the optimizer and the
    optional LR scheduler. These start fresh, from pretrained weights, or
    resumed from the latest checkpoint. The function then alternates training
    and validation for epochs ``start_epoch .. model_cfg["epochs"] - 1``.
    TensorBoard summaries go to ``<output_path>/log``. A checkpoint is saved
    to ``output_path`` after every epoch.

    Args:
        args: parsed options; reads ``resume``, ``resume_path``,
            ``pretrained``, ``pretrained_path`` and ``checkpoint_name``.
        dataset_cfg: dataset configuration forwarded to ``get_dataset``.
        model_cfg: model/training configuration dict. Mutated in place: a
            missing ``weighting_policy`` is set to ``"static"``. The optional
            ``model_cfg["train"]["num_workers"]`` (default 8) sets the
            training dataloader workers.
        output_path: directory that receives checkpoints and logs.
    """
    # Version compatibility: older configs have no weighting policy.
    if model_cfg.get("weighting_policy") is None:
        model_cfg["weighting_policy"] = "static"
    train_cfg = model_cfg["train"]
    test_cfg = model_cfg["test"]

    # Datasets and dataloaders
    print("\t Initializing dataset...")
    train_dataset, train_collate_fn = get_dataset("train", dataset_cfg)
    test_dataset, test_collate_fn = get_dataset("test", dataset_cfg)
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_cfg["batch_size"],
        num_workers=train_cfg.get("num_workers", 8),
        shuffle=True,
        pin_memory=True,
        collate_fn=train_collate_fn,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=test_cfg.get("batch_size", 1),
        num_workers=test_cfg.get("num_workers", 1),
        shuffle=False,
        pin_memory=False,
        collate_fn=test_collate_fn,
    )
    print("\t Successfully initialized dataloaders.")

    # The loss weights are needed to build the model.
    loss_funcs, loss_weights = get_loss_and_weights(model_cfg)

    if args.resume:
        # Restore model, optimizer and (if saved) scheduler state, then
        # continue from the epoch after the checkpointed one.
        checkpoint = get_latest_checkpoint(args.resume_path, args.checkpoint_name)
        model = get_model(model_cfg, loss_weights)
        model = restore_weights(model, checkpoint["model_state_dict"])
        model = model.cuda()
        optimizer = _build_optimizer(model, model_cfg)
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler = _build_scheduler(model_cfg, optimizer)
        # The scheduler may have been enabled after this checkpoint was saved.
        if (
            scheduler is not None
            and checkpoint.get("scheduler_state_dict") is not None
        ):
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
    else:
        model = get_model(model_cfg, loss_weights)
        if args.pretrained:
            print("\t [Debug] Loading pretrained weights...")
            checkpoint = get_latest_checkpoint(
                args.pretrained_path, args.checkpoint_name
            )
            # Non-strict, so e.g. an auto-weighting model can start from
            # weights trained without auto weighting.
            model = restore_weights(model, checkpoint["model_state_dict"], strict=False)
            print("\t [Debug] Finished loading pretrained weights!")
        model = model.cuda()
        optimizer = _build_optimizer(model, model_cfg)
        scheduler = _build_scheduler(model_cfg, optimizer)
        start_epoch = 0
    print("\t Successfully initialized model")

    # Total loss and metrics
    policy = model_cfg.get("weighting_policy", "static")
    loss_func = TotalLoss(loss_funcs, loss_weights, policy).cuda()
    if "descriptor_decoder" in model_cfg:
        metric_func = Metrics(
            model_cfg["detection_thresh"],
            model_cfg["prob_thresh"],
            model_cfg["descriptor_loss_cfg"]["grid_size"],
            desc_metric_lst="all",
        )
    else:
        metric_func = Metrics(
            model_cfg["detection_thresh"],
            model_cfg["prob_thresh"],
            model_cfg["grid_size"],
        )

    writer = SummaryWriter(logdir=os.path.join(output_path, "log"))
    try:
        for epoch in range(start_epoch, model_cfg["epochs"]):
            current_lr = optimizer.param_groups[0]["lr"]
            writer.add_scalar("LR/lr", current_lr, epoch)

            print("\n\n================== Training ====================")
            train_single_epoch(
                model=model,
                model_cfg=model_cfg,
                optimizer=optimizer,
                loss_func=loss_func,
                metric_func=metric_func,
                train_loader=train_loader,
                writer=writer,
                epoch=epoch,
            )

            print("\n\n================== Validation ==================")
            validate(
                model=model,
                model_cfg=model_cfg,
                loss_func=loss_func,
                metric_func=metric_func,
                val_loader=test_loader,
                writer=writer,
                epoch=epoch,
            )

            if scheduler is not None:
                scheduler.step()

            file_name = os.path.join(
                output_path, "checkpoint-epoch%03d-end.tar" % (epoch)
            )
            print("[Info] Saving checkpoint %s ..." % file_name)
            save_dict = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "model_cfg": model_cfg,
            }
            if scheduler is not None:
                save_dict["scheduler_state_dict"] = scheduler.state_dict()
            torch.save(save_dict, file_name)
            # Keep only the most recent checkpoints.
            remove_old_checkpoints(output_path, model_cfg.get("max_ckpt", 15))
    finally:
        # Flush pending events and release the event file even if training fails.
        writer.close()
def train_single_epoch(
    model, model_cfg, optimizer, loss_func, metric_func, train_loader, writer, epoch
):
    """Train for one epoch.

    Every iteration runs the forward pass, computes the losses with
    ``loss_func`` and takes one optimizer step. Some iterations also evaluate
    metrics: those whose index is a multiple of ``model_cfg["disp_freq"]`` or
    ``model_cfg["summary_freq"]``. On them the predictions are copied to the
    host, scored with ``metric_func`` and accumulated in an ``AverageMeter``.
    The first condition also prints progress; the second also writes
    TensorBoard summaries.

    When ``loss_func.compute_descriptors`` is set, each batch holds a
    reference/target image pair (``ref_*`` / ``target_*`` keys). The
    descriptor loss and the matching metrics are then included.

    Args:
        model: network returning a dict with "junctions", "heatmap" and,
            for descriptor training, "descriptors".
        model_cfg: config dict; reads "disp_freq", "summary_freq",
            "grid_size", "detection_thresh" and "epochs".
        optimizer: optimizer stepping ``model``'s parameters.
        loss_func: total loss module (``TotalLoss``).
        metric_func: metric accumulator (``Metrics``).
        train_loader: training dataloader.
        writer: summary writer for TensorBoard.
        epoch: current epoch index, used for the global step and the losses.
    """
    model.train()
    compute_descriptors = loss_func.compute_descriptors
    if compute_descriptors:
        average_meter = AverageMeter(is_training=True, desc_metric_lst="all")
    else:
        average_meter = AverageMeter(is_training=True)

    for idx, data in enumerate(train_loader):
        if compute_descriptors:
            junc_map = data["ref_junction_map"].cuda()
            junc_map2 = data["target_junction_map"].cuda()
            heatmap = data["ref_heatmap"].cuda()
            heatmap2 = data["target_heatmap"].cuda()
            line_points = data["ref_line_points"].cuda()
            line_points2 = data["target_line_points"].cuda()
            line_indices = data["ref_line_indices"].cuda()
            valid_mask = data["ref_valid_mask"].cuda()
            valid_mask2 = data["target_valid_mask"].cuda()
            input_images = data["ref_image"].cuda()
            input_images2 = data["target_image"].cuda()

            outputs = model(input_images)
            outputs2 = model(input_images2)
            losses = loss_func.forward_descriptors(
                outputs["junctions"],
                outputs2["junctions"],
                junc_map,
                junc_map2,
                outputs["heatmap"],
                outputs2["heatmap"],
                heatmap,
                heatmap2,
                line_points,
                line_points2,
                line_indices,
                outputs["descriptors"],
                outputs2["descriptors"],
                epoch,
                valid_mask,
                valid_mask2,
            )
        else:
            junc_map = data["junction_map"].cuda()
            heatmap = data["heatmap"].cuda()
            valid_mask = data["valid_mask"].cuda()
            input_images = data["image"].cuda()

            outputs = model(input_images)
            losses = loss_func(
                outputs["junctions"], junc_map, outputs["heatmap"], heatmap, valid_mask
            )
        total_loss = losses["total_loss"]

        # Update the model
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        global_step = epoch * len(train_loader) + idx

        do_display = (idx % model_cfg["disp_freq"]) == 0
        do_summary = (idx % model_cfg["summary_freq"]) == 0
        # Metrics need host copies of the predictions; only pay for them on
        # iterations that are displayed or summarised.
        if not (do_display or do_summary):
            continue

        ############## Measure the metric error #########################
        junc_np = convert_junc_predictions(
            outputs["junctions"],
            model_cfg["grid_size"],
            model_cfg["detection_thresh"],
            300,
        )
        junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)
        # Always fetch only one channel (compatible with L1, L2, and CE)
        if outputs["heatmap"].shape[1] == 2:
            heatmap_np = softmax(outputs["heatmap"].detach(), dim=1).cpu().numpy()
            heatmap_np = heatmap_np.transpose(0, 2, 3, 1)[:, :, :, 1:]
        else:
            heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
            heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)
        heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
        valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)

        if compute_descriptors:
            metric_func.evaluate(
                junc_np["junc_pred"],
                junc_np["junc_pred_nms"],
                junc_map_np,
                heatmap_np,
                heatmap_gt_np,
                valid_mask_np,
                line_points,
                line_points2,
                outputs["descriptors"],
                outputs2["descriptors"],
                line_indices,
            )
        else:
            metric_func.evaluate(
                junc_np["junc_pred"],
                junc_np["junc_pred_nms"],
                junc_map_np,
                heatmap_np,
                heatmap_gt_np,
                valid_mask_np,
            )

        # Each .item() is a device sync: fetch every scalar exactly once.
        junc_loss = losses["junc_loss"].item()
        heatmap_loss = losses["heatmap_loss"].item()
        total_loss_value = total_loss.item()
        loss_dict = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            "total_loss": total_loss_value,
        }
        if compute_descriptors:
            descriptor_loss = losses["descriptor_loss"].item()
            loss_dict["descriptor_loss"] = descriptor_loss
        average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0])

        # Display the progress
        if do_display:
            results = metric_func.metric_results
            average = average_meter.average()
            # Peak GPU memory allocated so far, in GB
            gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
            if compute_descriptors:
                print(
                    "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f), gpu_mem=%.4fGB"
                    % (
                        epoch,
                        model_cfg["epochs"],
                        idx,
                        len(train_loader),
                        total_loss_value,
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        descriptor_loss,
                        average["descriptor_loss"],
                        gpu_mem_usage,
                    )
                )
            else:
                print(
                    "Epoch [%d / %d] Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), gpu_mem=%.4fGB"
                    % (
                        epoch,
                        model_cfg["epochs"],
                        idx,
                        len(train_loader),
                        total_loss_value,
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        gpu_mem_usage,
                    )
                )
            for label, prec_key, rec_key in (
                ("Junction", "junc_precision", "junc_recall"),
                ("Junction nms", "junc_precision_nms", "junc_recall_nms"),
                ("Heatmap", "heatmap_precision", "heatmap_recall"),
            ):
                print(
                    "\t %s precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                    % (
                        label,
                        results[prec_key],
                        average[prec_key],
                        results[rec_key],
                        average[rec_key],
                    )
                )
            if compute_descriptors:
                print(
                    "\t Descriptors matching score=%.4f (%.4f)"
                    % (results["matching_score"], average["matching_score"])
                )

        # Record summaries
        if do_summary:
            results = metric_func.metric_results
            average = average_meter.average()
            scalar_summaries = {
                "junc_loss": junc_loss,
                "heatmap_loss": heatmap_loss,
                "total_loss": total_loss.detach().cpu().numpy(),
                "metrics": results,
                "average": average,
            }
            if compute_descriptors:
                scalar_summaries["descriptor_loss"] = descriptor_loss
                scalar_summaries["w_desc"] = losses["w_desc"]
            # Weighting terms are logged even for the static policy.
            scalar_summaries["w_junc"] = losses["w_junc"]
            scalar_summaries["w_heatmap"] = losses["w_heatmap"]
            scalar_summaries["reg_loss"] = losses["reg_loss"].item()

            num_images = 3
            junc_pred_binary = (
                junc_np["junc_pred"][:num_images, ...] > model_cfg["detection_thresh"]
            )
            junc_pred_nms_binary = (
                junc_np["junc_pred_nms"][:num_images, ...]
                > model_cfg["detection_thresh"]
            )
            image_summaries = {
                # Slice on the device first so only num_images are copied.
                "image": input_images[:num_images].cpu().numpy(),
                "valid_mask": valid_mask_np[:num_images, ...],
                "junc_map_pred": junc_pred_binary,
                "junc_map_pred_nms": junc_pred_nms_binary,
                "junc_map_gt": junc_map_np[:num_images, ...],
                "junc_prob_map": junc_np["junc_prob"][:num_images, ...],
                "heatmap_pred": heatmap_np[:num_images, ...],
                "heatmap_gt": heatmap_gt_np[:num_images, ...],
            }
            record_train_summaries(
                writer, global_step, scalars=scalar_summaries, images=image_summaries
            )
def validate(model, model_cfg, loss_func, metric_func, val_loader, writer, epoch):
    """Validation.

    Runs ``model`` in eval mode over ``val_loader`` without gradients.
    Metrics and losses are evaluated on every batch and accumulated in an
    ``AverageMeter``. Progress is printed every ``model_cfg["disp_freq"]``
    iterations. The epoch averages are written with
    ``record_test_summaries``.

    Args:
        model: network returning a dict with "junctions", "heatmap" and,
            for descriptor training, "descriptors".
        model_cfg: config dict; reads "disp_freq", "grid_size" and
            "detection_thresh".
        loss_func: total loss module (``TotalLoss``).
        metric_func: metric accumulator (``Metrics``).
        val_loader: validation dataloader.
        writer: summary writer for TensorBoard.
        epoch: epoch index, used as the summary step and passed to the losses.
    """
    model.eval()
    compute_descriptors = loss_func.compute_descriptors
    # NOTE(review): is_training=True mirrors train_single_epoch; confirm
    # whether AverageMeter expects False during validation.
    if compute_descriptors:
        average_meter = AverageMeter(is_training=True, desc_metric_lst="all")
    else:
        average_meter = AverageMeter(is_training=True)

    for idx, data in enumerate(val_loader):
        # No gradients are needed here. The loss is computed under no_grad too
        # because loss_func may hold its own learnable weighting parameters,
        # which would otherwise build an unused autograd graph.
        with torch.no_grad():
            if compute_descriptors:
                junc_map = data["ref_junction_map"].cuda()
                junc_map2 = data["target_junction_map"].cuda()
                heatmap = data["ref_heatmap"].cuda()
                heatmap2 = data["target_heatmap"].cuda()
                line_points = data["ref_line_points"].cuda()
                line_points2 = data["target_line_points"].cuda()
                line_indices = data["ref_line_indices"].cuda()
                valid_mask = data["ref_valid_mask"].cuda()
                valid_mask2 = data["target_valid_mask"].cuda()
                input_images = data["ref_image"].cuda()
                input_images2 = data["target_image"].cuda()

                outputs = model(input_images)
                outputs2 = model(input_images2)
                losses = loss_func.forward_descriptors(
                    outputs["junctions"],
                    outputs2["junctions"],
                    junc_map,
                    junc_map2,
                    outputs["heatmap"],
                    outputs2["heatmap"],
                    heatmap,
                    heatmap2,
                    line_points,
                    line_points2,
                    line_indices,
                    outputs["descriptors"],
                    outputs2["descriptors"],
                    epoch,
                    valid_mask,
                    valid_mask2,
                )
            else:
                junc_map = data["junction_map"].cuda()
                heatmap = data["heatmap"].cuda()
                valid_mask = data["valid_mask"].cuda()
                input_images = data["image"].cuda()

                outputs = model(input_images)
                losses = loss_func(
                    outputs["junctions"],
                    junc_map,
                    outputs["heatmap"],
                    heatmap,
                    valid_mask,
                )
        total_loss = losses["total_loss"]

        ############## Measure the metric error #########################
        junc_np = convert_junc_predictions(
            outputs["junctions"],
            model_cfg["grid_size"],
            model_cfg["detection_thresh"],
            300,
        )
        junc_map_np = junc_map.cpu().numpy().transpose(0, 2, 3, 1)
        # Always fetch only one channel (compatible with L1, L2, and CE)
        if outputs["heatmap"].shape[1] == 2:
            heatmap_np = (
                softmax(outputs["heatmap"].detach(), dim=1)
                .cpu()
                .numpy()
                .transpose(0, 2, 3, 1)
            )
            heatmap_np = heatmap_np[:, :, :, 1:]
        else:
            heatmap_np = torch.sigmoid(outputs["heatmap"].detach())
            heatmap_np = heatmap_np.cpu().numpy().transpose(0, 2, 3, 1)
        heatmap_gt_np = heatmap.cpu().numpy().transpose(0, 2, 3, 1)
        valid_mask_np = valid_mask.cpu().numpy().transpose(0, 2, 3, 1)

        if compute_descriptors:
            metric_func.evaluate(
                junc_np["junc_pred"],
                junc_np["junc_pred_nms"],
                junc_map_np,
                heatmap_np,
                heatmap_gt_np,
                valid_mask_np,
                line_points,
                line_points2,
                outputs["descriptors"],
                outputs2["descriptors"],
                line_indices,
            )
        else:
            metric_func.evaluate(
                junc_np["junc_pred"],
                junc_np["junc_pred_nms"],
                junc_map_np,
                heatmap_np,
                heatmap_gt_np,
                valid_mask_np,
            )

        # Each .item() is a device sync: fetch every scalar exactly once.
        junc_loss = losses["junc_loss"].item()
        heatmap_loss = losses["heatmap_loss"].item()
        total_loss_value = total_loss.item()
        loss_dict = {
            "junc_loss": junc_loss,
            "heatmap_loss": heatmap_loss,
            "total_loss": total_loss_value,
        }
        if compute_descriptors:
            descriptor_loss = losses["descriptor_loss"].item()
            loss_dict["descriptor_loss"] = descriptor_loss
        average_meter.update(metric_func, loss_dict, num_samples=junc_map.shape[0])

        # Display the progress
        if (idx % model_cfg["disp_freq"]) == 0:
            results = metric_func.metric_results
            average = average_meter.average()
            if compute_descriptors:
                print(
                    "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f), descriptor_loss=%.4f (%.4f)"
                    % (
                        idx,
                        len(val_loader),
                        total_loss_value,
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                        descriptor_loss,
                        average["descriptor_loss"],
                    )
                )
            else:
                print(
                    "Iter [%d / %d] loss=%.4f (%.4f), junc_loss=%.4f (%.4f), heatmap_loss=%.4f (%.4f)"
                    % (
                        idx,
                        len(val_loader),
                        total_loss_value,
                        average["total_loss"],
                        junc_loss,
                        average["junc_loss"],
                        heatmap_loss,
                        average["heatmap_loss"],
                    )
                )
            for label, prec_key, rec_key in (
                ("Junction", "junc_precision", "junc_recall"),
                ("Junction nms", "junc_precision_nms", "junc_recall_nms"),
                ("Heatmap", "heatmap_precision", "heatmap_recall"),
            ):
                print(
                    "\t %s precision=%.4f (%.4f) / recall=%.4f (%.4f)"
                    % (
                        label,
                        results[prec_key],
                        average[prec_key],
                        results[rec_key],
                        average[rec_key],
                    )
                )
            if compute_descriptors:
                print(
                    "\t Descriptors matching score=%.4f (%.4f)"
                    % (results["matching_score"], average["matching_score"])
                )

    # Record the epoch-level validation summary
    average = average_meter.average()
    record_test_summaries(writer, epoch, {"average": average})
def convert_junc_predictions(predictions, grid_size, detect_thresh=1 / 65, topk=300):
    """Convert raw junction logits into the numpy maps used for evaluation.

    Args:
        predictions: 4-D junction logit tensor. Softmax is taken over dim 1,
            and the last channel of that dim is excluded from the junction
            probabilities (presumably a "no junction" bin; confirm against
            the model).
        grid_size: upscale factor passed to ``pixel_shuffle`` and ``super_nms``.
        detect_thresh: detection threshold forwarded to ``super_nms``.
        topk: maximum number of junctions forwarded to ``super_nms``.

    Returns:
        dict with
          - "junc_pred": pixel-shuffled probability map with its trailing
            channel axis squeezed,
          - "junc_pred_nms": ``super_nms`` applied to that map (before squeezing),
          - "junc_prob": per-cell sum of the junction-channel probabilities.
    """
    probs = softmax(predictions.detach(), dim=1).cpu()
    # probs already lives on the CPU, so no further device transfers are needed.
    cell_probs = probs[:, :-1, :, :]
    channels_last = probs.numpy().transpose(0, 2, 3, 1)
    junc_prob_np = np.sum(channels_last[..., :-1], axis=-1)

    full_res = pixel_shuffle(cell_probs, grid_size).numpy().transpose(0, 2, 3, 1)
    nms_map = super_nms(full_res, grid_size, detect_thresh, topk)
    return {
        "junc_pred": full_res.squeeze(-1),
        "junc_pred_nms": nms_map,
        "junc_prob": junc_prob_np,
    }
def record_train_summaries(writer, global_step, scalars, images):
    """Record training summaries.

    Every scalar and image is logged at ``global_step``.

    Args:
        writer: TensorBoard summary writer.
        global_step: step index used for every summary.
        scalars: dict with "junc_loss", "heatmap_loss", "total_loss",
            "metrics" (current metric results) and "average" (running
            averages). It may also hold "reg_loss", "descriptor_loss" and
            weighting terms whose key contains "w_".
        images: dict of arrays keyed "image", "valid_mask", "heatmap_gt",
            "heatmap_pred", "junc_map_pred", "junc_map_pred_nms",
            "junc_map_gt" and "junc_prob_map".
    """
    results = scalars["metrics"]
    average = scalars["average"]

    # Peak GPU memory allocated so far, in GB
    gpu_mem_usage = torch.cuda.max_memory_allocated() / (1024**3)
    writer.add_scalar("GPU/GPU_memory_usage", gpu_mem_usage, global_step)

    # Instantaneous losses; regularization and descriptor terms are optional.
    for key in ("junc_loss", "heatmap_loss", "total_loss"):
        writer.add_scalar("Train_loss/%s" % key, scalars[key], global_step)
    for key in ("reg_loss", "descriptor_loss"):
        if key in scalars:
            writer.add_scalar("Train_loss/%s" % key, scalars[key], global_step)

    # Loss weighting terms
    for key in scalars:
        if "w_" in key:
            writer.add_scalar("Train_weight/%s" % key, scalars[key], global_step)

    # Smoothed losses. The descriptor average is written once, and only if
    # present (it used to be logged twice per step under the same tag).
    for key in ("junc_loss", "heatmap_loss", "total_loss"):
        writer.add_scalar("Train_loss_average/%s" % key, average[key], global_step)
    if "descriptor_loss" in average:
        writer.add_scalar(
            "Train_loss_average/descriptor_loss",
            average["descriptor_loss"],
            global_step,
        )

    # Metrics: current values, then running averages.
    metric_keys = (
        "junc_precision",
        "junc_precision_nms",
        "junc_recall",
        "junc_recall_nms",
        "heatmap_precision",
        "heatmap_recall",
    )
    for key in metric_keys:
        writer.add_scalar("Train_metrics/%s" % key, results[key], global_step)
    if "matching_score" in results:
        writer.add_scalar(
            "Train_metrics/matching_score", results["matching_score"], global_step
        )
    for key in metric_keys:
        writer.add_scalar(
            "Train_metrics_average/%s" % key, average[key], global_step
        )
    if "matching_score" in average:
        writer.add_scalar(
            "Train_metrics_average/matching_score",
            average["matching_score"],
            global_step,
        )

    # Image part
    image_tensor = convert_image(images["image"], 1)
    valid_masks = convert_image(images["valid_mask"], -1)
    writer.add_images("Train/images", image_tensor, global_step, dataformats="NCHW")
    writer.add_images("Train/valid_map", valid_masks, global_step, dataformats="NHWC")

    # Heatmap part
    writer.add_images(
        "Train/heatmap_gt",
        convert_image(images["heatmap_gt"], -1),
        global_step,
        dataformats="NHWC",
    )
    writer.add_images(
        "Train/heatmap_pred",
        convert_image(images["heatmap_pred"], -1),
        global_step,
        dataformats="NHWC",
    )

    # Junction prediction part; the plots are uint8, rescaled to [0, 1].
    junc_plots = plot_junction_detection(
        image_tensor,
        images["junc_map_pred"],
        images["junc_map_pred_nms"],
        images["junc_map_gt"],
    )
    for tag, plot_key in (
        ("Train/junc_gt", "junc_gt_plot"),
        ("Train/junc_pred", "junc_pred_plot"),
        ("Train/junc_pred_nms", "junc_pred_nms_plot"),
    ):
        writer.add_images(
            tag, junc_plots[plot_key] / 255.0, global_step, dataformats="NHWC"
        )
    writer.add_images(
        "Train/junc_prob_map",
        convert_image(images["junc_prob_map"][..., None], axis=-1),
        global_step,
        dataformats="NHWC",
    )
def record_test_summaries(writer, epoch, scalars):
    """Write epoch-level validation averages to the summary writer.

    Args:
        writer: object exposing ``add_scalar(tag, value, step)``.
        epoch: epoch index, used as the summary step.
        scalars: dict whose "average" entry maps loss and metric names to
            their averaged values. "descriptor_loss" and "matching_score"
            are logged only when present.
    """
    avg = scalars["average"]

    loss_names = ["junc_loss", "heatmap_loss", "total_loss"]
    if "descriptor_loss" in avg:
        loss_names.append("descriptor_loss")
    for name in loss_names:
        writer.add_scalar("Val_loss/" + name, avg[name], epoch)

    metric_names = [
        "junc_precision",
        "junc_precision_nms",
        "junc_recall",
        "junc_recall_nms",
        "heatmap_precision",
        "heatmap_recall",
    ]
    if "matching_score" in avg:
        metric_names.append("matching_score")
    for name in metric_names:
        writer.add_scalar("Val_metrics/" + name, avg[name], epoch)
def plot_junction_detection(
    image_tensor, junc_pred_tensor, junc_pred_nms_tensor, junc_gt_tensor
):
    """Draw ground-truth and predicted junctions on every image of a batch.

    Args:
        image_tensor: numpy array indexed as (N, C, H, W). Each image is
            scaled by 255 and cast to uint8 before drawing (values presumably
            in [0, 1]; confirm with the caller).
        junc_pred_tensor: per-image junction maps; pixels > 0 are circled.
        junc_pred_nms_tensor: per-image junction maps after NMS; pixels > 0
            are circled.
        junc_gt_tensor: per-image ground-truth maps; squeezed before
            thresholding at > 0.

    Returns:
        dict with "junc_gt_plot", "junc_pred_plot" and "junc_pred_nms_plot",
        each a uint8 array of shape (N, H, W, C). Circles have radius 3 and
        thickness 2; the color is (255, 0, 0) for ground truth and
        (0, 255, 0) for predictions.
    """

    def _draw_points(base_image, point_mask, color):
        # Circle every pixel of point_mask that is > 0 on a copy of base_image.
        canvas = base_image.copy()
        coords = np.where(point_mask > 0)
        for row, col in zip(coords[0], coords[1]):
            # cv2 expects the center as (x, y) == (col, row).
            cv2.circle(canvas, (int(col), int(row)), 3, color=color, thickness=2)
        return canvas

    gt_frames, pred_frames, nms_frames = [], [], []
    for sample_idx in range(image_tensor.shape[0]):
        base = (image_tensor[sample_idx] * 255.0).astype(np.uint8).transpose(1, 2, 0)
        gt_frames.append(
            _draw_points(base, junc_gt_tensor[sample_idx].squeeze(), (255, 0, 0))
        )
        pred_frames.append(
            _draw_points(base, junc_pred_tensor[sample_idx], (0, 255, 0))
        )
        nms_frames.append(
            _draw_points(base, junc_pred_nms_tensor[sample_idx], (0, 255, 0))
        )

    return {
        "junc_gt_plot": np.stack(gt_frames, axis=0),
        "junc_pred_plot": np.stack(pred_frames, axis=0),
        "junc_pred_nms_plot": np.stack(nms_frames, axis=0),
    }