# LocoTrack: locotrack_pytorch/model_utils.py
# Author: Seokju Cho
# Commit f1586f7 ("initial commit"), 6.23 kB.
from typing import Sequence, Optional
import torch
import torch.nn.functional as F
from models.utils import convert_grid_coordinates
from data.evaluation_datasets import compute_tapvid_metrics
def huber_loss(tracks, target_points, occluded, delta=4.0, reduction_axes=(1, 2)):
    """Huber loss on the Euclidean distance between predicted and target points.

    Args:
        tracks: Predicted positions; the last axis holds the coordinates.
        target_points: Ground-truth positions, broadcastable against ``tracks``.
        occluded: Occlusion flags broadcastable to ``tracks.shape[:-1]``;
            occluded entries contribute zero loss.
        delta: Distance at which the loss switches from quadratic to linear.
        reduction_axes: Axes to average over. A falsy value (``None`` or ``()``)
            returns the unreduced per-point loss.

    Returns:
        The Huber loss, averaged over ``reduction_axes`` when given.
    """
    squared_dist = (tracks - target_points).pow(2).sum(dim=-1)
    # The tiny epsilon keeps the sqrt gradient finite when the distance is exactly zero.
    dist = (squared_dist + 1e-12).sqrt()
    quadratic_part = squared_dist / 2
    linear_part = delta * (dist.abs() - delta / 2)
    per_point = torch.where(dist < delta, quadratic_part, linear_part)
    per_point = per_point * (1.0 - occluded.float())
    if not reduction_axes:
        return per_point
    return per_point.mean(dim=reduction_axes)
def prob_loss(tracks, expd, target_points, occluded, expected_dist_thresh=8.0, reduction_axes=(1, 2)):
    """Classification loss for whether a prediction is farther than a pixel threshold.

    The binary target is 1 when the squared distance between ``tracks`` and
    ``target_points`` exceeds ``expected_dist_thresh ** 2`` and 0 otherwise.
    ``expd`` holds the logits for that event.

    Args:
        tracks: Predicted positions; the last axis holds the coordinates.
        expd: Logits of "prediction is farther than the threshold".
        target_points: Ground-truth positions, broadcastable against ``tracks``.
        occluded: Occlusion flags; occluded entries contribute zero loss.
        expected_dist_thresh: Distance threshold, in the same units as the points.
        reduction_axes: Axes to average over. A falsy value (``None`` or ``()``)
            returns the unreduced per-point loss.

    Returns:
        Per-point binary cross-entropy, averaged over ``reduction_axes`` when given.
    """
    squared_err = (tracks - target_points).pow(2).sum(dim=-1)
    too_far = squared_err.gt(expected_dist_thresh ** 2).float()
    bce = F.binary_cross_entropy_with_logits(expd, too_far, reduction='none')
    bce = bce * (1.0 - occluded.float())
    if not reduction_axes:
        return bce
    return bce.mean(dim=reduction_axes)
def tapnet_loss(points, occlusion, target_points, target_occ, shape, mask=None, expected_dist=None,
                position_loss_weight=0.05, expected_dist_thresh=6.0, huber_loss_delta=4.0,
                rebalance_factor=None, occlusion_loss_mask=None):
    """Position, occlusion and expected-distance losses for one set of predictions.

    Point coordinates are rescaled from the input-frame resolution to a
    256x256 grid before the position-based losses are computed.

    Args:
        points: Predicted tracks in (x, y) coordinates of the input frames.
        occlusion: Occlusion logits.
        target_points: Ground-truth tracks, same convention as ``points``.
        target_occ: Ground-truth occlusion flags.
        shape: Video shape. ``shape[3:1:-1]``, i.e. ``(shape[3], shape[2])``,
            is used as the (width, height) of the input frames.
        mask: Optional multiplicative weight applied to every per-point loss.
        expected_dist: Optional logits of "prediction is farther than the
            threshold". When ``None``, the probability loss is zero.
        position_loss_weight: Scale applied to the averaged Huber loss.
        expected_dist_thresh: Threshold passed to ``prob_loss`` (256x256 grid units).
        huber_loss_delta: Huber transition point (256x256 grid units).
        rebalance_factor: Optional ``r``. Each occlusion-loss term is weighted by
            ``1 + r`` for visible targets and by ``1`` for occluded ones.
        occlusion_loss_mask: Optional extra multiplicative mask for the occlusion loss.

    Returns:
        Tuple ``(loss_huber, loss_occ, loss_prob)`` of scalar tensors.
    """
    if mask is None:
        mask = torch.tensor(1.0)

    frame_wh = shape[3:1:-1]  # (width, height) of the input frames
    grid_wh = (256, 256)
    points = convert_grid_coordinates(points, frame_wh, grid_wh, coordinate_format='xy')
    target_points = convert_grid_coordinates(target_points, frame_wh, grid_wh, coordinate_format='xy')

    per_point_huber = huber_loss(points, target_points, target_occ, delta=huber_loss_delta, reduction_axes=None)
    loss_huber = (per_point_huber * mask).mean() * position_loss_weight

    if expected_dist is not None:
        # Detached: the probability loss must not back-propagate into the tracks.
        per_point_prob = prob_loss(points.detach(), expected_dist, target_points, target_occ,
                                   expected_dist_thresh, reduction_axes=None)
        loss_prob = (per_point_prob * mask).mean()
    else:
        loss_prob = torch.tensor(0.0)

    occ_target = target_occ.to(dtype=occlusion.dtype)
    per_point_occ = F.binary_cross_entropy_with_logits(occlusion, occ_target, reduction='none') * mask
    if rebalance_factor is not None:
        per_point_occ = per_point_occ * ((1 + rebalance_factor) - rebalance_factor * occ_target)
    if occlusion_loss_mask is not None:
        per_point_occ = per_point_occ * occlusion_loss_mask

    return loss_huber, per_point_occ.mean(), loss_prob
def tapir_loss(
    batch,
    output,
    position_loss_weight=0.05,
    expected_dist_thresh=6.0,
):
    """Total TAPIR training loss over the refined and all unrefined predictions.

    Args:
        batch: Dict with ``'target_points'``, ``'occluded'`` and ``'video'``.
            Only ``batch['video'].shape`` is used from the video.
        output: Model output dict with ``'tracks'`` and ``'occlusion'``, and
            optionally ``'expected_dist'``. If ``'unrefined_tracks'`` is present,
            ``'unrefined_occlusion'`` is indexed in parallel with it.
            ``'unrefined_expected_dist'`` is used when present.
        position_loss_weight: Forwarded to ``tapnet_loss``.
        expected_dist_thresh: Forwarded to ``tapnet_loss``.

    Returns:
        ``(loss, loss_scalars)``. ``loss`` is the sum of every component loss.
        ``loss_scalars`` maps component names to their values, with unrefined
        entries suffixed by their index, plus ``'loss'`` for the total.
    """
    def _tapnet_terms(tracks, occlusion, expected_dist):
        # Every prediction set is scored against the same targets and settings.
        return tapnet_loss(
            tracks,
            occlusion,
            batch['target_points'],
            batch['occluded'],
            batch['video'].shape,
            expected_dist=expected_dist,
            position_loss_weight=position_loss_weight,
            expected_dist_thresh=expected_dist_thresh,
        )

    has_prob = 'expected_dist' in output
    loss_huber, loss_occ, loss_prob = _tapnet_terms(
        output['tracks'],
        output['occlusion'],
        output['expected_dist'] if has_prob else None,
    )
    loss = loss_huber + loss_occ + loss_prob
    loss_scalars = {'position_loss': loss_huber, 'occlusion_loss': loss_occ}
    if has_prob:
        loss_scalars['prob_loss'] = loss_prob

    if 'unrefined_tracks' in output:
        has_unrefined_prob = 'unrefined_expected_dist' in output
        for i, tracks in enumerate(output['unrefined_tracks']):
            loss_huber, loss_occ, loss_prob = _tapnet_terms(
                tracks,
                output['unrefined_occlusion'][i],
                output['unrefined_expected_dist'][i] if has_unrefined_prob else None,
            )
            loss = loss + loss_huber + loss_occ + loss_prob
            loss_scalars[f'position_loss_{i}'] = loss_huber
            loss_scalars[f'occlusion_loss_{i}'] = loss_occ
            if has_unrefined_prob:
                loss_scalars[f'prob_loss_{i}'] = loss_prob

    loss_scalars['loss'] = loss
    return loss, loss_scalars
def eval_batch(
    batch,
    output,
    eval_metrics_resolution=(256, 256),
    query_first=False,
):
    """Compute TAP-Vid metrics for one batch at a fixed evaluation resolution.

    All coordinates are rescaled from the video resolution to
    ``eval_metrics_resolution`` before scoring.

    Args:
        batch: Dict with ``'query_points'`` in (t, y, x) format,
            ``'target_points'`` in (x, y) format, ``'occluded'`` and ``'video'``.
            Only the video's shape is used; ``shape[2:4]`` is read as
            (height, width).
        output: Model output with ``'tracks'`` in (x, y) format,
            ``'occlusion'`` logits, and optionally ``'expected_dist'`` logits.
        eval_metrics_resolution: ``(height, width)`` to evaluate at. It is
            normalized to a tuple, so any sequence (e.g. a list from a config
            file) is accepted.
        query_first: Use the ``'first'`` query mode if True, otherwise ``'strided'``.

    Returns:
        Whatever ``compute_tapvid_metrics`` returns for this batch.
    """
    # `(1,) + eval_metrics_resolution` below requires a tuple; a list would raise TypeError.
    eval_hw = tuple(eval_metrics_resolution)
    eval_wh = eval_hw[::-1]  # (width, height)
    video_shape = batch['video'].shape
    video_wh = video_shape[3:1:-1]  # (width, height)

    query_points = convert_grid_coordinates(
        batch['query_points'],
        (1,) + tuple(video_shape[2:4]),  # (1, height, width)
        (1,) + eval_hw,  # (1, height, width)
        coordinate_format='tyx',
    )
    gt_target_points = convert_grid_coordinates(
        batch['target_points'], video_wh, eval_wh, coordinate_format='xy',
    )
    tracks = convert_grid_coordinates(
        output['tracks'], video_wh, eval_wh, coordinate_format='xy',
    )

    pred_occ = torch.sigmoid(output['occlusion'])
    if 'expected_dist' in output:
        # `expected_dist` is trained (see prob_loss) as the logit of "farther than
        # the threshold". A point therefore counts as visible only if it is
        # predicted unoccluded AND within the threshold.
        pred_occ = 1 - (1 - pred_occ) * (1 - torch.sigmoid(output['expected_dist']))
    pred_occ = pred_occ > 0.5  # threshold

    def _to_numpy(t):
        return t.detach().cpu().numpy()

    return compute_tapvid_metrics(
        query_points=_to_numpy(query_points),
        gt_occluded=_to_numpy(batch['occluded']),
        gt_tracks=_to_numpy(gt_target_points),
        pred_occluded=_to_numpy(pred_occ),
        pred_tracks=_to_numpy(tracks),
        query_mode='first' if query_first else 'strided',
    )