Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,017 Bytes
b213d84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Any, List
import torch
from detectron2.config import CfgNode
from detectron2.structures import Instances
from .mask import MaskLoss
from .segm import SegmentationLoss
class MaskOrSegmentationLoss:
    """
    Cross-entropy loss on raw (unnormalized) coarse segmentation scores.

    The ground truth comes from one of two sources, selected by the config
    flag MODEL.ROI_DENSEPOSE_HEAD.COARSE_SEGM_TRAINED_BY_MASKS:
      * flag set   -> mask annotations (delegated to ``MaskLoss``)
      * flag unset -> coarse segmentation annotations
                      (delegated to ``SegmentationLoss``)
    """

    def __init__(self, cfg: CfgNode):
        """
        Build the underlying loss object(s) from configuration options.

        ``segm_loss`` is always constructed; ``mask_loss`` is constructed only
        when the config requests mask-based supervision.

        Args:
            cfg (CfgNode): configuration options
        """
        head_cfg = cfg.MODEL.ROI_DENSEPOSE_HEAD
        self.segm_trained_by_masks = head_cfg.COARSE_SEGM_TRAINED_BY_MASKS
        if self.segm_trained_by_masks:
            self.mask_loss = MaskLoss()
        self.segm_loss = SegmentationLoss(cfg)

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        packed_annotations: Any,
    ) -> torch.Tensor:
        """
        Evaluate the segmentation loss against whichever ground truth the
        config selected at construction time.

        Args:
            proposals_with_gt (list of Instances): detections with associated
                ground truth data
            densepose_predictor_outputs: dataclass instance holding predictor
                outputs; expected to expose
                * coarse_segm - coarse segmentation estimates, tensor of shape
                  [N, D, S, S]
            packed_annotations: packed annotations for efficient loss
                computation (only consumed by the coarse-segmentation path)
        Return:
            tensor: cross-entropy loss value for the raw unnormalized scores
        """
        # Guard clause: the coarse-segmentation path is the default; the
        # mask path does not need the packed annotations at all.
        if not self.segm_trained_by_masks:
            return self.segm_loss(
                proposals_with_gt, densepose_predictor_outputs, packed_annotations
            )
        return self.mask_loss(proposals_with_gt, densepose_predictor_outputs)

    def fake_value(self, densepose_predictor_outputs: Any) -> torch.Tensor:
        """
        Zero-valued stand-in loss for batches that lack usable ground truth.

        Multiplying the summed estimates by 0 keeps ``coarse_segm`` attached to
        the autograd graph. Every GPU therefore builds a comparable graph, and
        `DistributedDataParallel` can perform its reduction consistently.

        Args:
            densepose_predictor_outputs: DensePose predictor outputs; a
                dataclass instance expected to carry a `coarse_segm` attribute
        Return:
            Zero value loss with proper computation graph
        """
        # NOTE(review): if coarse_segm holds NaN/inf, `x * 0` yields NaN
        # rather than 0; upstream accepts this.
        coarse_segm = densepose_predictor_outputs.coarse_segm
        return coarse_segm.sum() * 0
|