# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
import typing
from collections import Counter, OrderedDict
from typing import Any, Callable, List, Optional, Union

import numpy as np

try:
    from math import prod  # type: ignore
except ImportError:
    from numpy import prod as _prod  # type: ignore

    # Patch `numpy.prod` to avoid overflow on Windows by converting its
    # result from `np.int32` to `int`.
    def prod(*args, **kwargs):  # type: ignore
        return _prod(*args, **kwargs).item()

Handle = Callable[[List[Any], List[Any]], Union[typing.Counter[str], int]]


def get_shape(val: Any) -> Optional[List[int]]:
    """Get the shape from a jit value object.

    Args:
        val (torch._C.Value): jit value object.

    Returns:
        list(int): return a list of ints.
    """
    if val.isCompleteTensor():
        return val.type().sizes()
    else:
        return None  # type: ignore
""" | |
Below are flop/activation counters for various ops. | |
Every counter has the following signature: | |
Args: | |
inputs (list(torch._C.Value)): | |
The inputs of the op in the form of a list of jit object. | |
outputs (list(torch._C.Value)): | |
The outputs of the op in the form of a list of jit object. | |
Returns: | |
number: The number of flops/activations for the operation. | |
or Counter[str] | |
""" | |


def generic_activation_jit(op_name: Optional[str] = None) -> Handle:
    """This method returns a handle that counts the number of activations
    from the output shape for the specified operation.

    Args:
        op_name (str): The name of the operation. If given, the handle will
            return a counter using this name.

    Returns:
        Callable: An activation handle for the given operation.
    """

    def _generic_activation_jit(
            i: Any, outputs: List[Any]) -> Union[typing.Counter[str], int]:
        """This is a generic jit handle that counts the number of activations
        for any operation given the output shape."""
        out_shape = get_shape(outputs[0])
        ac_count = prod(out_shape)  # type: ignore
        if op_name is None:
            return ac_count  # type: ignore
        else:
            return Counter({op_name: ac_count})

    return _generic_activation_jit
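
# A minimal usage sketch (hypothetical shapes; `FakeValue` is an
# illustrative stand-in for torch._C.Value and is not part of this module):
#
#   class FakeValue:
#       def __init__(self, sizes):
#           self._sizes = sizes
#       def isCompleteTensor(self):
#           return True
#       def type(self):
#           return self
#       def sizes(self):
#           return self._sizes
#
#   handle = generic_activation_jit('conv')
#   handle([], [FakeValue([1, 64, 56, 56])])
#   # -> Counter({'conv': 200704}), i.e. 1 * 64 * 56 * 56 output elements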


def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
    """Count flops for fully connected layers."""
    # Count flop for nn.Linear
    # inputs is a list of length 3.
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]  # type: ignore
    assert len(input_shapes[1]) == 2, input_shapes[1]  # type: ignore
    batch_size, input_dim = input_shapes[0]  # type: ignore
    output_dim = input_shapes[1][1]  # type: ignore
    flops = batch_size * input_dim * output_dim
    return flops
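
# Worked example with hypothetical sizes: for nn.Linear(512, 1024) applied
# to a batch of 32, aten::addmm sees mat1 of shape [32, 512] and mat2 of
# shape [512, 1024], so the handle reports 32 * 512 * 1024 = 16_777_216
# multiply-accumulates (the bias addition is deliberately not counted).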


def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
    """Count flops for the aten::linear operator."""
    # Inputs is a list of length 3; unlike aten::addmm, it is the first
    # two elements that are relevant.
    input_shapes = [get_shape(v) for v in inputs[0:2]]
    # input_shapes[0]: [dim0, dim1, ..., input_feature_dim]
    # input_shapes[1]: [output_feature_dim, input_feature_dim]
    assert input_shapes[0][-1] == input_shapes[1][-1]  # type: ignore
    flops = prod(input_shapes[0]) * input_shapes[1][0]  # type: ignore
    return flops
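
# Unlike addmm, aten::linear keeps leading batch dimensions: for a
# hypothetical input of shape [8, 16, 512] and weight [1024, 512], the
# count is prod([8, 16, 512]) * 1024 = 8 * 16 * 512 * 1024 MACs.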


def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
    """Count flops for the bmm operation."""
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two tensors.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [get_shape(v) for v in inputs]
    n, c, t = input_shapes[0]  # type: ignore
    d = input_shapes[-1][-1]  # type: ignore
    flop = n * c * t * d
    return flop
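
# E.g. torch.bmm on hypothetical batches of shape [4, 16, 32] and
# [4, 32, 8]: each of the 4 matrix products needs 16 * 32 * 8 MACs, so the
# handle returns 4 * 16 * 32 * 8 = 16_384.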


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> Union[int, Any]:
    """Count flops for convolution. Note only multiplication is counted.
    Computation for addition and bias is ignored. Flops for a transposed
    convolution are calculated as:

    flops = batch_size * prod(w_shape) * prod(x_shape[2:]).

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flop = batch_size * prod(w_shape) * prod(conv_shape)
    return flop
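
# Worked example with hypothetical shapes: a 3x3 convolution mapping
# 3 channels to 64 over a 224x224 image with stride 1 and padding 1:
#
#   conv_flop_count([1, 3, 224, 224], [64, 3, 3, 3], [1, 64, 224, 224])
#   # == 1 * (64 * 3 * 3 * 3) * (224 * 224) == 86_704_128 MACs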


def conv_flop_jit(inputs: List[Any],
                  outputs: List[Any]) -> typing.Counter[str]:
    """Count flops for convolution."""
    # Inputs of Convolution should be a list of length 12 or 13.
    # They represent:
    # 0) input tensor, 1) convolution filter, 2) bias, 3) stride, 4) padding,
    # 5) dilation, 6) transposed, 7) out_pad, 8) groups, 9) benchmark_cudnn,
    # 10) deterministic_cudnn and 11) user_enabled_cudnn.
    # Starting with #40737 it will be 12) user_enabled_tf32.
    assert len(inputs) == 12 or len(inputs) == 13, len(inputs)
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w),
                                   get_shape(outputs[0]))
    transposed = inputs[6].toIValue()
    # Use a custom name instead of "_convolution".
    return Counter({
        'conv':
        conv_flop_count(
            x_shape,  # type: ignore
            w_shape,  # type: ignore
            out_shape,  # type: ignore
            transposed=transposed)  # type: ignore
    })


def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
    """Count flops for the einsum operation."""
    # Inputs of einsum should be a list of length 2+.
    # Inputs[0] stores the equation used for einsum.
    # Inputs[1] stores the list of input shapes.
    assert len(inputs) >= 2, len(inputs)
    equation = inputs[0].toIValue()
    # Get rid of white space in the equation string.
    equation = equation.replace(' ', '')
    input_shapes_jit = inputs[1].node().inputs()
    input_shapes = [get_shape(v) for v in input_shapes_jit]
    # Re-map equation so that the same equation with different alphabet
    # representations will look the same.
    letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys()
    mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)}
    equation = equation.translate(mapping)

    if equation == 'abc,abd->acd':
        n, c, t = input_shapes[0]  # type: ignore
        p = input_shapes[-1][-1]  # type: ignore
        flop = n * c * t * p
        return flop
    elif equation == 'abc,adc->adb':
        n, t, g = input_shapes[0]  # type: ignore
        c = input_shapes[-1][1]  # type: ignore
        flop = n * t * g * c
        return flop
    else:
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize='optimal')[1]
        for line in optim.split('\n'):
            if 'optimized flop' in line.lower():
                # Divided by 2 because we count MACs
                # (multiply-add counted as one flop).
                flop = float(np.floor(float(line.split(':')[-1]) / 2))
                return flop
        raise NotImplementedError('Unsupported einsum operation.')
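
# The remapping step canonicalizes letters by order of first appearance:
# e.g. a hypothetical 'nct,ncp->ntp' becomes 'abc,abd->acd' (n->a, c->b,
# t->c, p->d), so equivalent equations written with different letters all
# hit the same fast-path branch above.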


def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
    """Count flops for matmul."""
    # input_shapes is a list of length 2.
    input_shapes: list = [get_shape(v) for v in inputs]
    input1, input2 = input_shapes
    if len(input1) == 1:
        input1 = [1, input1[0]]
    if len(input2) == 1:
        input2 = [input2[0], 1]
    assert input1[-1] == input2[-2], input_shapes
    flop = prod(input1) * input2[-1]
    return flop
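
# E.g. hypothetical shapes [8, 128, 64] @ [64, 256] give
# prod([8, 128, 64]) * 256 = 8 * 128 * 64 * 256 MACs. 1-D operands are
# first promoted to matrices ([k] -> [1, k] or [k, 1]) so the same formula
# covers vector products as well.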


def norm_flop_counter(affine_arg_index: int) -> Handle:
    """
    Args:
        affine_arg_index: index of the affine argument in inputs
    """

    def norm_flop_jit(inputs: List[Any],
                      outputs: List[Any]) -> Union[int, Any]:
        """Count flops for norm layers."""
        # Inputs[0] contains the shape of the input.
        input_shape = get_shape(inputs[0])
        has_affine = get_shape(inputs[affine_arg_index]) is not None
        assert 2 <= len(input_shape) <= 5, input_shape  # type: ignore
        # 5 is just a rough estimate.
        flop = prod(input_shape) * (5 if has_affine else 4)  # type: ignore
        return flop

    return norm_flop_jit
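
# The per-element cost here is a rough model: about 4 flops per element to
# compute the statistics and normalize, plus 1 more when an affine
# scale/shift is applied. For a hypothetical [32, 64, 56, 56] input with
# affine parameters this yields 32 * 64 * 56 * 56 * 5 flops.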


def batchnorm_flop_jit(inputs: List[Any],
                       outputs: List[Any]) -> Union[int, Any]:
    training = inputs[5].toIValue()
    assert isinstance(training,
                      bool), 'Signature of aten::batch_norm has changed!'
    if training:
        return norm_flop_counter(1)(inputs, outputs)  # pyre-ignore
    has_affine = get_shape(inputs[1]) is not None
    input_shape = prod(get_shape(inputs[0]))  # type: ignore
    return input_shape * (2 if has_affine else 1)
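
# In eval mode the running statistics are fixed, so batch norm reduces to
# an elementwise transform; the counter models this as 1 flop per element
# without affine parameters and 2 with them, rather than the 4-5 per
# element charged during training.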


def elementwise_flop_counter(input_scale: float = 1,
                             output_scale: float = 0) -> Handle:
    """Count flops by:

    input_tensor.numel() * input_scale +
    output_tensor.numel() * output_scale

    Args:
        input_scale: scale of the input tensor (first argument)
        output_scale: scale of the output tensor (first element in outputs)
    """

    def elementwise_flop(inputs: List[Any],
                         outputs: List[Any]) -> Union[int, Any]:
        ret = 0
        if input_scale != 0:
            shape = get_shape(inputs[0])
            ret += input_scale * prod(shape)  # type: ignore
        if output_scale != 0:
            shape = get_shape(outputs[0])
            ret += output_scale * prod(shape)  # type: ignore
        return ret

    return elementwise_flop
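
# Typical usage (a sketch of assumed conventions, mirroring how fvcore
# registers such handles): elementwise_flop_counter(1, 0) charges one flop
# per input element, suitable for ops like aten::relu, while
# elementwise_flop_counter(0, 1) charges per output element, which suits
# ops such as upsampling whose cost tracks the output size.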