# Copied from rut5compressed/util.py of rut5compressed repository.

import logging
import re
from functools import wraps
from re import Pattern
from typing import Callable, Dict, Optional, Tuple

import numpy as np
import torch as T

from .modules import SVDCompressedLinear


def map_module(root: T.nn.Module,
               func: Callable[[T.nn.Module, str], T.nn.Module],
               patt: Optional[str] = None) -> T.nn.Module:
    """Function ``map_module`` applies a function to each leaf of module tree
    which matches to a specified pattern.

    Parameters
    ----------
    root : torch.nn.Module
        Module to modify.
    func : callable
        Function applied to every module in the tree whose path matches the
        pattern. It takes the module and its path and returns a module.
    patt : str, optional
        Regular expression used to filter modules by their path in the module
        tree. By default every module matches.

    Returns
    -------
    torch.nn.Module
        Module modified in-place.
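
    Examples
    --------
    A minimal sketch: print the path of every submodule of a toy model while
    leaving the model unchanged (the tiny ``Sequential`` below is purely
    illustrative).

    >>> net = T.nn.Sequential(T.nn.Linear(8, 4), T.nn.ReLU())
    >>> _ = map_module(net, lambda mod, path: print(path) or mod)
    /0
    /1
    /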
    """
    @wraps(func)
    def func_safe(*args, **kwargs):
        node = func(*args, **kwargs)
        if not isinstance(node, T.nn.Module):
            raise ValueError('Mapped result must be of torch.nn.Module type '
                             f'but given {type(node)}.')
        return node

    return _map_module(root, func_safe, re.compile(patt or r'.*'), '')


def _map_module(root: T.nn.Module,
                func: Callable[[T.nn.Module, str], T.nn.Module], patt: Pattern,
                path: str) -> T.nn.Module:
    for name, child in root.named_children():
        node = _map_module(child, func, patt, f'{path}/{name}')
        if node != child:
            setattr(root, name, node)
    if patt.match(path or '/'):
        root = func(root, path or '/')
    return root


def convert_linear(module: T.nn.Linear, ctor, **kwargs) -> T.nn.Module:
    """Function convert_linear takes module and returns linear module with
    approximate matmul. Non-linear modules are returned intact.
    """
    if not isinstance(module, T.nn.Linear):
        return module
    raise NotImplementedError


def numel(module: T.nn.Module):
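    """Function ``numel`` estimates the number of elements stored by a
    module: the elements of all parameters and buffers, adjusted for pruning
    masks and for quantized linear layers whose packed weights are not
    exposed as parameters.
    """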
    value = sum(x.numel() for x in module.parameters()) + \
            sum(x.numel() for x in module.buffers())

    def account_prunned(module: T.nn.Module, path: str):
        nonlocal value
        for name, attr in vars(module).items():
            if not name.endswith('_mask') or not isinstance(attr, T.Tensor):
                continue

            weight_name = name[:-5]
            if not hasattr(module, weight_name):
                continue

            weight = getattr(module, weight_name)
            value -= weight.numel() - int(attr.sum().item())
            value += attr.numel()
        return module

    def account_quantized(module: T.nn.Module, path: str):
        nonlocal value
        if isinstance(module, T.nn.quantized.Linear):
            value += module.weight().numel()
            if module.bias() is not None:
                value += module.bias().numel()
        return module

    def account_rest(module: T.nn.Module, path: str):
        account_prunned(module, path)
        account_quantized(module, path)
        return module

    map_module(module, account_rest)
    return value


def sizeof(module: T.nn.Module):
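    """Function ``sizeof`` estimates the size of a module in bytes: the bytes
    occupied by all parameters and buffers, adjusted for pruning masks and
    for quantized linear layers whose packed weights are not exposed as
    parameters.
    """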
    value = sum(x.numel() * x.element_size() for x in module.parameters()) + \
            sum(x.numel() * x.element_size() for x in module.buffers())

    def account_prunned(module: T.nn.Module, path: str):
        nonlocal value
        for name, attr in vars(module).items():
            if not name.endswith('_mask') or not isinstance(attr, T.Tensor):
                continue

            weight_name = name[:-5]
            if not hasattr(module, weight_name):
                continue

            weight = getattr(module, weight_name)
            pruned = weight.numel() - int(attr.sum().item())
            value -= pruned * weight.element_size()
            value += attr.numel() * attr.element_size()
        return module

    def account_quantized(module: T.nn.Module, path: str):
        nonlocal value
        if isinstance(module, T.nn.quantized.Linear):
            value += module.weight().numel() * module.weight().element_size()
            if (bias := module.bias()) is not None:
                value += bias.numel() * bias.element_size()
        return module

    def account_rest(module: T.nn.Module, path: str):
        account_prunned(module, path)
        account_quantized(module, path)
        return module

    map_module(module, account_rest)
    return value


def flatten_module(module: T.nn.Module, regexp=None) -> Dict[str, T.nn.Module]:
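    """Function ``flatten_module`` returns a mapping from path in the module
    tree (paths look like ``/child/grandchild``; the root is ``/``) to the
    corresponding submodule, optionally filtered by ``regexp``.
    """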
    modules = {}
    map_module(module, lambda x, y: modules.update(**{y: x}) or x, regexp)
    return modules


def print_flatten(module: T.nn.Module):
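    """Function ``print_flatten`` prints a table with the index, path, and
    class name of every submodule of a module.
    """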
    paths = []
    path_len = 0
    names = []
    name_len = 0
    indx_len = 0

    def func(module, path):
        nonlocal path_len, name_len, indx_len
        paths.append(path)
        path_len = max(path_len, len(path))
        name = module.__class__.__name__
        names.append(name)
        name_len = max(name_len, len(name))
        indx_len += 1
        return module

    map_module(module, func)

    indx_len = int(np.ceil(np.log10(indx_len)))
    fmt = f'{{indx:>{indx_len}s}} {{path:{path_len}s}} {{name:{name_len}s}}'
    print(fmt.format(indx='#', path='Path', name='Layer'))
    print('-' * (indx_len + path_len + name_len + 2))
    for i, (path, name) in enumerate(zip(paths, names)):
        print(fmt.format(indx=str(i), path=path, name=name))


def compress_linear_svd(module: T.nn.Module, path: str,
                        rank: Optional[int] = None) -> T.nn.Module:
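    """Function ``compress_linear_svd`` replaces a linear layer with an
    ``SVDCompressedLinear`` of the given rank. If no rank is given, it is
    chosen so that the SVD factors hold roughly as many elements as the
    original weight matrix. Non-linear modules are returned intact.
    """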
    if not isinstance(module, T.nn.Linear):
        return module

    # Do not factorize if the rank equals the size of the smallest dimension:
    # the factorization would not reduce the number of parameters.
    norows, nocols = module.weight.shape
    if rank == min(norows, nocols):
        return module

    # If no rank is given, choose the rank at which the number of elements in
    # the SVD factors is approximately equal to the number of elements in the
    # original matrix.
    if rank is None:
        ratio = norows * nocols / (norows + nocols)
        rank = int(np.floor(ratio))
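        # For example (purely illustrative numbers), a 512 x 2048 weight has
        # 1048576 elements and 512 * 2048 / (512 + 2048) = 409.6, so rank 409
        # gives factors with 409 * (512 + 2048) = 1047040 elements in total.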

    return SVDCompressedLinear.from_linear(module, rank)


def compress_linear_tt(module: T.nn.Module, path: str,
                       shape: Tuple[Tuple[int], Tuple[int]],
                       rank: int) -> T.nn.Module:
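    """Function ``compress_linear_tt`` replaces a linear layer with a
    TT-compressed one. ``shape`` holds factorizations of the input and output
    dimensions (their products must equal ``in_features`` and
    ``out_features``, in either order) and ``rank`` is the TT-rank.
    Non-linear modules are returned intact.
    """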
    if not isinstance(module, T.nn.Linear):
        return module

    # TODO(@not-found): We need a proper compression config.
    inp_size = np.prod(shape[0])
    out_size = np.prod(shape[1])
    if inp_size == module.in_features and out_size == module.out_features:
        pass
    elif inp_size == module.out_features and out_size == module.in_features:
        shape = (shape[1], shape[0])
    else:
        raise ValueError(
            'Input and output features do not match the compression shape: '
            f'{shape[0]} vs {module.in_features} and {shape[1]} vs '
            f'{module.out_features}.')

    logging.info('apply tt compression to layer %s', path)
    return TTCompressedLinear.from_linear(module, shape, rank)  # noqa: F821


def compress(module: T.nn.Module, rank: int) -> T.nn.Module:
    """Function compress substitutes in-place linear layer of T5 model with
    linear layer which weight matrix is factorized with SVD.

    :param module: Model to compress.
    :param rank: Desired rank of compressed layer.
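
    A minimal usage sketch (the ``transformers`` checkpoint name below is
    illustrative only)::

        from transformers import T5ForConditionalGeneration

        model = T5ForConditionalGeneration.from_pretrained('t5-small')
        model = compress(model, rank=64)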
    """
    return map_module(
        root=module,
        func=lambda x, y: compress_linear_svd(x, y, rank),
        patt=r'.*/DenseReluDense/w.*')  # TODO(@not-found): Remove?