# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing
from collections import defaultdict
from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union

import torch.nn as nn
from rich import box
from rich.console import Console
from rich.table import Table
from torch import Tensor

from .jit_analysis import JitModelAnalysis
from .jit_handles import (Handle, addmm_flop_jit, batchnorm_flop_jit,
                          bmm_flop_jit, conv_flop_jit, einsum_flop_jit,
                          elementwise_flop_counter, generic_activation_jit,
                          linear_flop_jit, matmul_flop_jit, norm_flop_counter)
# A dictionary that maps supported operations to their flop count jit handles.
_DEFAULT_SUPPORTED_FLOP_OPS: Dict[str, Handle] = {
    'aten::addmm': addmm_flop_jit,
    'aten::bmm': bmm_flop_jit,
    'aten::_convolution': conv_flop_jit,
    'aten::einsum': einsum_flop_jit,
    'aten::matmul': matmul_flop_jit,
    'aten::mm': matmul_flop_jit,
    'aten::linear': linear_flop_jit,
    # You might want to ignore batch norm flops because they can be fused
    # into preceding layers at inference time. To do so, use
    # ``set_op_handle('aten::batch_norm', None)``.
    'aten::batch_norm': batchnorm_flop_jit,
    'aten::group_norm': norm_flop_counter(2),
    'aten::layer_norm': norm_flop_counter(2),
    'aten::instance_norm': norm_flop_counter(1),
    'aten::upsample_nearest2d': elementwise_flop_counter(0, 1),
    'aten::upsample_bilinear2d': elementwise_flop_counter(0, 4),
    'aten::adaptive_avg_pool2d': elementwise_flop_counter(1, 0),
    'aten::grid_sampler': elementwise_flop_counter(0, 4),  # assume bilinear
}
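
# Handles can also be customized per analyzer instance. A hedged sketch (the
# extra op below is illustrative, not part of the defaults): count one flop
# per output element of an elementwise op, or stop counting batch norm flops
# when BN is fused at inference time:
#
#   analyzer = FlopAnalyzer(model, inputs)
#   analyzer.set_op_handle('aten::gelu', elementwise_flop_counter(0, 1))
#   analyzer.set_op_handle('aten::batch_norm', None)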

# A dictionary that maps supported operations to
# their activation count handles.
_DEFAULT_SUPPORTED_ACT_OPS: Dict[str, Handle] = {
    'aten::_convolution': generic_activation_jit('conv'),
    'aten::addmm': generic_activation_jit(),
    'aten::bmm': generic_activation_jit(),
    'aten::einsum': generic_activation_jit(),
    'aten::matmul': generic_activation_jit(),
    'aten::linear': generic_activation_jit(),
}
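
# Activation handles can be extended the same way. A hedged sketch, with an
# illustrative op name:
#
#   analyzer = ActivationAnalyzer(model, inputs)
#   analyzer.set_op_handle('aten::baddbmm', generic_activation_jit())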


class FlopAnalyzer(JitModelAnalysis):
    """Provides access to per-submodule model flop count obtained by tracing
    a model with pytorch's jit tracing functionality.

    By default, comes with standard flop counters for a few common operators.

    Note:
        - Flop is not a well-defined concept. We just produce our best
          estimate.
        - We count one fused multiply-add as one flop.

    Handles for additional operators may be added, or the default ones
    overwritten, using the ``.set_op_handle(name, func)`` method.
    See the method documentation for details.

    Flop counts can be obtained as:

    - ``.total(module_name="")``: total flop count for the module
    - ``.by_operator(module_name="")``: flop counts for the module, as a
      Counter over different operator types
    - ``.by_module()``: Counter of flop counts for all submodules
    - ``.by_module_and_operator()``: dictionary indexed by submodule name,
      mapping to Counters over different operator types

    An operator is treated as within a module if it is executed inside the
    module's ``__call__`` method. Note that this does not include calls to
    other methods of the module or explicit calls to ``module.forward(...)``.

    Modified from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py

    Args:
        model (nn.Module): The model to analyze.
        inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model.

    Examples:
        >>> import torch.nn as nn
        >>> import torch
        >>> class TestModel(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.fc = nn.Linear(in_features=1000, out_features=10)
        ...         self.conv = nn.Conv2d(
        ...             in_channels=3, out_channels=10, kernel_size=1
        ...         )
        ...         self.act = nn.ReLU()
        ...     def forward(self, x):
        ...         return self.fc(self.act(self.conv(x)).flatten(1))
        >>> model = TestModel()
        >>> inputs = (torch.randn((1, 3, 10, 10)), )
        >>> flops = FlopAnalyzer(model, inputs)
        >>> flops.total()
        13000
        >>> flops.total("fc")
        10000
        >>> flops.by_operator()
        Counter({"addmm" : 10000, "conv" : 3000})
        >>> flops.by_module()
        Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0})
        >>> flops.by_module_and_operator()
        {"" : Counter({"addmm" : 10000, "conv" : 3000}),
         "fc" : Counter({"addmm" : 10000}),
         "conv" : Counter({"conv" : 3000}),
         "act" : Counter()
        }
    """

    def __init__(
        self,
        model: nn.Module,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
    ) -> None:
        super().__init__(model=model, inputs=inputs)
        self.set_op_handle(**_DEFAULT_SUPPORTED_FLOP_OPS)

    __init__.__doc__ = JitModelAnalysis.__init__.__doc__


class ActivationAnalyzer(JitModelAnalysis):
    """Provides access to per-submodule model activation count obtained by
    tracing a model with pytorch's jit tracing functionality.

    By default, comes with standard activation counters for convolutional and
    dot-product operators. Handles for additional operators may be added, or
    the default ones overwritten, using the ``.set_op_handle(name, func)``
    method. See the method documentation for details.

    Activation counts can be obtained as:

    - ``.total(module_name="")``: total activation count for a module
    - ``.by_operator(module_name="")``: activation counts for the module,
      as a Counter over different operator types
    - ``.by_module()``: Counter of activation counts for all submodules
    - ``.by_module_and_operator()``: dictionary indexed by submodule name,
      mapping to Counters over different operator types

    An operator is treated as within a module if it is executed inside the
    module's ``__call__`` method. Note that this does not include calls to
    other methods of the module or explicit calls to ``module.forward(...)``.

    Modified from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py

    Args:
        model (nn.Module): The model to analyze.
        inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model.

    Examples:
        >>> import torch.nn as nn
        >>> import torch
        >>> class TestModel(nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.fc = nn.Linear(in_features=1000, out_features=10)
        ...         self.conv = nn.Conv2d(
        ...             in_channels=3, out_channels=10, kernel_size=1
        ...         )
        ...         self.act = nn.ReLU()
        ...     def forward(self, x):
        ...         return self.fc(self.act(self.conv(x)).flatten(1))
        >>> model = TestModel()
        >>> inputs = (torch.randn((1, 3, 10, 10)), )
        >>> acts = ActivationAnalyzer(model, inputs)
        >>> acts.total()
        1010
        >>> acts.total("fc")
        10
        >>> acts.by_operator()
        Counter({"conv" : 1000, "addmm" : 10})
        >>> acts.by_module()
        Counter({"" : 1010, "fc" : 10, "conv" : 1000, "act" : 0})
        >>> acts.by_module_and_operator()
        {"" : Counter({"conv" : 1000, "addmm" : 10}),
         "fc" : Counter({"addmm" : 10}),
         "conv" : Counter({"conv" : 1000}),
         "act" : Counter()
        }
    """

    def __init__(
        self,
        model: nn.Module,
        inputs: Union[Tensor, Tuple[Tensor, ...]],
    ) -> None:
        super().__init__(model=model, inputs=inputs)
        self.set_op_handle(**_DEFAULT_SUPPORTED_ACT_OPS)

    __init__.__doc__ = JitModelAnalysis.__init__.__doc__


def flop_count(
    model: nn.Module,
    inputs: Tuple[Any, ...],
    supported_ops: Optional[Dict[str, Handle]] = None,
) -> Tuple[DefaultDict[str, float], Counter[str]]:
    """Given a model and an input to the model, compute the per-operator
    Gflops of the given model.

    Adapted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py

    Args:
        model (nn.Module): The model to compute flop counts for.
        inputs (tuple): Inputs that are passed to `model` to count flops.
            Inputs need to be in a tuple.
        supported_ops (dict(str, Callable) or None): Provide additional
            handlers for extra ops, or overwrite the existing handlers for
            convolution, matmul and einsum. The key is the operator name and
            the value is a function that takes (inputs, outputs) of the op.
            We count one multiply-add as one flop.

    Returns:
        tuple[defaultdict, Counter]: A dictionary that records the number of
        Gflops for each operation and a Counter that records the number of
        unsupported operations.
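
    Example:
        >>> # A minimal sketch; ``TestModel`` here refers to the toy model
        >>> # defined in the ``FlopAnalyzer`` docstring above.
        >>> gflops, unsupported = flop_count(
        ...     TestModel(), (torch.randn(1, 3, 10, 10), ))
        >>> gflops['addmm']
        1e-05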
""" | |
if supported_ops is None: | |
supported_ops = {} | |
flop_counter = FlopAnalyzer(model, inputs).set_op_handle(**supported_ops) | |
giga_flops = defaultdict(float) | |
for op, flop in flop_counter.by_operator().items(): | |
giga_flops[op] = flop / 1e9 | |
return giga_flops, flop_counter.unsupported_ops() | |


def activation_count(
    model: nn.Module,
    inputs: Tuple[Any, ...],
    supported_ops: Optional[Dict[str, Handle]] = None,
) -> Tuple[DefaultDict[str, float], Counter[str]]:
    """Given a model and an input to the model, compute the total number of
    activations of the model.

    Adapted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py

    Args:
        model (nn.Module): The model to compute activation counts for.
        inputs (tuple): Inputs that are passed to `model` to count
            activations. Inputs need to be in a tuple.
        supported_ops (dict(str, Callable) or None): Provide additional
            handlers for extra ops, or overwrite the existing handlers for
            convolution and matmul. The key is the operator name and the
            value is a function that takes (inputs, outputs) of the op.

    Returns:
        tuple[defaultdict, Counter]: A dictionary that records the number of
        activations (in mega-activations) for each operation and a Counter
        that records the number of unsupported operations.
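
    Example:
        >>> # A minimal sketch; ``TestModel`` here refers to the toy model
        >>> # defined in the ``FlopAnalyzer`` docstring above.
        >>> mega_acts, unsupported = activation_count(
        ...     TestModel(), (torch.randn(1, 3, 10, 10), ))
        >>> mega_acts['conv']
        0.001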
""" | |
if supported_ops is None: | |
supported_ops = {} | |
act_counter = ActivationAnalyzer(model, | |
inputs).set_op_handle(**supported_ops) | |
mega_acts = defaultdict(float) | |
for op, act in act_counter.by_operator().items(): | |
mega_acts[op] = act / 1e6 | |
return mega_acts, act_counter.unsupported_ops() | |


def parameter_count(model: nn.Module) -> typing.DefaultDict[str, int]:
    """Count parameters of a model and its submodules.

    Adapted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py

    Args:
        model (nn.Module): The model to count parameters for.

    Returns:
        dict[str, int]: The key is either a parameter name or a module name.
        The value is the number of elements in the parameter, or in all
        parameters of the module. The key "" corresponds to the total
        number of parameters of the model.
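
    Example:
        >>> # a minimal sketch with a single linear layer:
        >>> # weight (3, 5) and bias (3,) give 18 parameters in total
        >>> counts = parameter_count(nn.Linear(5, 3))
        >>> counts['']
        18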
""" | |
count = defaultdict(int) # type: typing.DefaultDict[str, int] | |
for name, param in model.named_parameters(): | |
size = param.numel() | |
name = name.split('.') | |
for k in range(0, len(name) + 1): | |
prefix = '.'.join(name[:k]) | |
count[prefix] += size | |
return count | |


def parameter_count_table(model: nn.Module, max_depth: int = 3) -> str:
    """Format the parameter count of the model (and its submodules or
    parameters) as a table.

    Adapted from
    https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py

    Args:
        model (nn.Module): The model to count parameters for.
        max_depth (int): The maximum depth to recursively print submodules or
            parameters.

    Returns:
        str: The table to be printed.
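
    Example:
        >>> # a minimal sketch: render the table for a single linear layer
        >>> table_str = parameter_count_table(nn.Linear(5, 3))
        >>> print(table_str)  # doctest: +SKIP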
""" | |
count: typing.DefaultDict[str, int] = parameter_count(model) | |
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. | |
param_shape: typing.Dict[str, typing.Tuple] = { | |
k: tuple(v.shape) | |
for k, v in model.named_parameters() | |
} | |
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter. | |
rows: typing.List[typing.Tuple] = [] | |
def format_size(x: int) -> str: | |
if x > 1e8: | |
return f'{x / 1e9:.1f}G' | |
if x > 1e5: | |
return f'{x / 1e6:.1f}M' | |
if x > 1e2: | |
return f'{x / 1e3:.1f}K' | |
return str(x) | |
def fill(lvl: int, prefix: str) -> None: | |
if lvl >= max_depth: | |
return | |
for name, v in count.items(): | |
if name.count('.') == lvl and name.startswith(prefix): | |
indent = ' ' * (lvl + 1) | |
if name in param_shape: | |
rows.append( | |
(indent + name, indent + str(param_shape[name]))) | |
else: | |
rows.append((indent + name, indent + format_size(v))) | |
fill(lvl + 1, name + '.') | |
rows.append(('model', format_size(count.pop('')))) | |
fill(0, '') | |
table = Table( | |
title=f'parameter count of {model.__class__.__name__}', box=box.ASCII2) | |
table.add_column('name') | |
table.add_column('#elements or shape') | |
for row in rows: | |
table.add_row(*row) | |
console = Console() | |
with console.capture() as capture: | |
console.print(table, end='') | |
return capture.get() | |