File size: 6,017 Bytes
b213d84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
# Copyright (c) Facebook, Inc. and its affiliates.
# -*- coding: utf-8 -*-
import typing
from typing import Any, List
import fvcore
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
from torch import nn
from detectron2.export import TracingAdapter
__all__ = [
"activation_count_operators",
"flop_count_operators",
"parameter_count_table",
"parameter_count",
"FlopCountAnalysis",
]
FLOPS_MODE = "flops"
ACTIVATIONS_MODE = "activations"
# Some extra ops to ignore from counting, including elementwise and reduction ops
_IGNORED_OPS = {
"aten::add",
"aten::add_",
"aten::argmax",
"aten::argsort",
"aten::batch_norm",
"aten::constant_pad_nd",
"aten::div",
"aten::div_",
"aten::exp",
"aten::log2",
"aten::max_pool2d",
"aten::meshgrid",
"aten::mul",
"aten::mul_",
"aten::neg",
"aten::nonzero_numpy",
"aten::reciprocal",
"aten::repeat_interleave",
"aten::rsub",
"aten::sigmoid",
"aten::sigmoid_",
"aten::softmax",
"aten::sort",
"aten::sqrt",
"aten::sub",
"torchvision::nms", # TODO estimate flop for nms
}
class FlopCountAnalysis(fvcore.nn.FlopCountAnalysis):
"""
Same as :class:`fvcore.nn.FlopCountAnalysis`, but supports detectron2 models.
"""
def __init__(self, model, inputs):
"""
Args:
model (nn.Module):
inputs (Any): inputs of the given model. Does not have to be tuple of tensors.
"""
wrapper = TracingAdapter(model, inputs, allow_non_tensor=True)
super().__init__(wrapper, wrapper.flattened_inputs)
self.set_op_handle(**{k: None for k in _IGNORED_OPS})
def flop_count_operators(model: nn.Module, inputs: list) -> typing.DefaultDict[str, float]:
"""
Implement operator-level flops counting using jit.
This is a wrapper of :func:`fvcore.nn.flop_count` and adds supports for standard
detection models in detectron2.
Please use :class:`FlopCountAnalysis` for more advanced functionalities.
Note:
The function runs the input through the model to compute flops.
The flops of a detection model is often input-dependent, for example,
the flops of box & mask head depends on the number of proposals &
the number of detected objects.
Therefore, the flops counting using a single input may not accurately
reflect the computation cost of a model. It's recommended to average
across a number of inputs.
Args:
model: a detectron2 model that takes `list[dict]` as input.
inputs (list[dict]): inputs to model, in detectron2's standard format.
Only "image" key will be used.
supported_ops (dict[str, Handle]): see documentation of :func:`fvcore.nn.flop_count`
Returns:
Counter: Gflop count per operator
"""
old_train = model.training
model.eval()
ret = FlopCountAnalysis(model, inputs).by_operator()
model.train(old_train)
return {k: v / 1e9 for k, v in ret.items()}
def activation_count_operators(
model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
"""
Implement operator-level activations counting using jit.
This is a wrapper of fvcore.nn.activation_count, that supports standard detection models
in detectron2.
Note:
The function runs the input through the model to compute activations.
The activations of a detection model is often input-dependent, for example,
the activations of box & mask head depends on the number of proposals &
the number of detected objects.
Args:
model: a detectron2 model that takes `list[dict]` as input.
inputs (list[dict]): inputs to model, in detectron2's standard format.
Only "image" key will be used.
Returns:
Counter: activation count per operator
"""
return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)
def _wrapper_count_operators(
model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:
# ignore some ops
supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
supported_ops.update(kwargs.pop("supported_ops", {}))
kwargs["supported_ops"] = supported_ops
assert len(inputs) == 1, "Please use batch size=1"
tensor_input = inputs[0]["image"]
inputs = [{"image": tensor_input}] # remove other keys, in case there are any
old_train = model.training
if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
model = model.module
wrapper = TracingAdapter(model, inputs)
wrapper.eval()
if mode == FLOPS_MODE:
ret = flop_count(wrapper, (tensor_input,), **kwargs)
elif mode == ACTIVATIONS_MODE:
ret = activation_count(wrapper, (tensor_input,), **kwargs)
else:
raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
# compatible with change in fvcore
if isinstance(ret, tuple):
ret = ret[0]
model.train(old_train)
return ret
def find_unused_parameters(model: nn.Module, inputs: Any) -> List[str]:
"""
Given a model, find parameters that do not contribute
to the loss.
Args:
model: a model in training mode that returns losses
inputs: argument or a tuple of arguments. Inputs of the model
Returns:
list[str]: the name of unused parameters
"""
assert model.training
for _, prm in model.named_parameters():
prm.grad = None
if isinstance(inputs, tuple):
losses = model(*inputs)
else:
losses = model(inputs)
if isinstance(losses, dict):
losses = sum(losses.values())
losses.backward()
unused: List[str] = []
for name, prm in model.named_parameters():
if prm.grad is None:
unused.append(name)
prm.grad = None
return unused
|