# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Adapted from https://github.com/huggingface/transformers/blob/f93c90d21749b61bd89152a7fe99a839df29ed94/src/transformers/debug_utils.py
"""

import json

from transformers.utils import ExplicitEnum, is_torch_available, logging

from m4.training.utils import get_stats


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


class ActivationTracker:
    """
    This debug class helps detect and understand where the model starts getting very large or very small, and more
    importantly `nan` or `inf` activation elements.

    This class will plug hooks into the model and record the activation values of the model into a list of dictionaries: `jsonl_stats`.

    Recording is only active during training, not during validation, and only when `trace_activation` is set to True.
    In practice, since this tracking requires additional computation, we only track activations every X steps.

    In the case of gradient accumulation, every accumulated batch is recorded and identified by the `batch_idx` key.

    Args:
        model (`nn.Module`):
            The model to debug.
        abort_after_batch_num (`int`, *optional*):
            If set, abort after this batch number has finished.
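
    Example (a minimal sketch of the intended training-loop integration; the names `dataloader`,
    `optimizer`, `micro_batches` and the output file path are placeholders supplied by the caller,
    not part of this module):

    ```python
    tracker = ActivationTracker(model)

    for curr_opt_step, micro_batches in enumerate(dataloader):
        # tracking is costly, so only enable it every 100 optimizer steps (the "X" above)
        if curr_opt_step % 100 == 0:
            tracker.activate_hooks()
        else:
            tracker.deactivate_hooks()
        tracker.is_train()

        # with gradient accumulation, tag the stats of each accumulated batch with its index
        for batch_idx, batch in enumerate(micro_batches):
            loss = model(**batch).loss
            loss.backward()
            tracker.fill_in_batch_idx(batch_idx)

        optimizer.step()
        optimizer.zero_grad()

        # append this step's records to a jsonl file and clear the buffer for the next step
        tracker.dump_stats("activation_stats.jsonl", curr_opt_step)
        tracker.reset_jsonl_stats()
    ```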
    """

    def __init__(
        self,
        model,
        abort_after_batch_num=None,
    ):
        self.model = model
        self.is_validation = False
        # default to not tracing; `activate_hooks`/`deactivate_hooks` toggle this at every step,
        # and `forward_hook` reads it, so it must exist before the first forward pass
        self.trace_activation = False
        self.abort_after_batch_num = abort_after_batch_num

        self.jsonl_stats = []
        self.batch_number = 0
        self.detected_overflow = False
        self.analyse_model()

        self.register_forward_hook()

    def analyse_model(self):
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}

    def analyse_variable(self, var, ctx, current_module_stats):
        if torch.is_tensor(var):
            dict_stats = get_stats(var, ctx)
            current_module_stats.update(dict_stats)
            if detect_overflow(var, ctx):
                self.detected_overflow = True
        return current_module_stats

    def create_frame(self, module, input, output):
        module_name = f"{self.module_names[module]}"
        module_type = f"{module.__class__.__name__}"
        current_module_stats = {}

        # inputs
        if isinstance(input, tuple):
            for i, x in enumerate(input):
                current_module_stats = self.analyse_variable(x, f"input[{i}]", current_module_stats)
        else:
            current_module_stats = self.analyse_variable(input, "input", current_module_stats)

        # outputs
        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        current_module_stats = self.analyse_variable(y, f"output[{i}][{j}]", current_module_stats)
                else:
                    current_module_stats = self.analyse_variable(x, f"output[{i}]", current_module_stats)
        else:
            current_module_stats = self.analyse_variable(output, "output", current_module_stats)
        if current_module_stats:
            # When we activate gradient checkpointing, the forward hook will be called twice for some (not all) modules.
            # That will lead to double (repeated) entries in the list.
            # This is a hack to avoid these double entries.
            if (module_name, module_type) not in [(x["name"], x["type"]) for x in self.jsonl_stats]:
                self.jsonl_stats.append(
                    {
                        "name": module_name,
                        "type": module_type,
                        **current_module_stats,
                    }
                )

    def register_forward_hook(self):
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors

        trace_activation = self.trace_activation

        # count batch numbers - the hook registered on the top-level model fires after all of its
        # submodules, i.e. it is called last, so when it fires we know this batch has finished
        if module == self.model:
            self.batch_number += 1

        if trace_activation and not self.is_validation:
            self.create_frame(module, input, output)

        if self.detected_overflow:
            # now we can abort, as it's pointless to continue running
            raise ValueError(
                "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
                "Please scroll up above this traceback to see the activation values prior to this event."
            )

        # abort after certain batch if requested to do so
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
            raise ValueError(
                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
                f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
            )

    def fill_in_batch_idx(self, batch_idx):
        if not self.jsonl_stats:
            return
        for r in self.jsonl_stats:
            if "batch_idx" not in r:
                r["batch_idx"] = batch_idx
            else:
                if not (r["batch_idx"] <= batch_idx):
                    raise ValueError("`batch_idx` should be increasing")

    def dump_stats(self, log_activations_filename, curr_opt_step):
        with open(log_activations_filename, "a") as file:
            # append stats to file
            for r in self.jsonl_stats:
                r["step"] = curr_opt_step
                file.write(json.dumps(r) + "\n")

    def reset_jsonl_stats(self):
        self.jsonl_stats = []

    def activate_hooks(self):
        self.trace_activation = True

    def deactivate_hooks(self):
        self.trace_activation = False

    def is_eval(self):
        self.is_validation = True

    def is_train(self):
        self.is_validation = False


def detect_overflow(var, ctx):
    """
    Report whether the tensor contains any `nan` or `inf` entries.

    This is useful for detecting overflows/underflows and is best called right after the operation that modified the
    tensor in question.

    This function contains a few other helper features that you can enable and tweak directly if you want to track
    various other things.

    Args:
        var: the tensor variable to check
        ctx: the message to print as a context

    Return:
        `True` if `inf` or `nan` was detected, `False` otherwise
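
    Example (hypothetical call site; `logits` stands for any tensor produced earlier in the model's forward pass):

    ```python
    if detect_overflow(logits, "lm head logits"):
        # detect_overflow already printed a "has nans" / "has infs" message for this context;
        # decide here whether to abort or just log and continue
        raise RuntimeError("inf/nan detected in lm head logits")
    ```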
    """
    detected = False
    if torch.isnan(var).any().item():
        detected = True
        print(f"{ctx} has nans")
    if torch.isinf(var).any().item():
        detected = True
        print(f"{ctx} has infs")

    # if needed to monitor large elements as well, enable the following blocks by changing `if 0:` to `if 1:`
    if 0:  # and detected:
        n100 = var[torch.ge(var.abs(), 100)]
        if n100.numel() > 0:
            print(f"{ctx}:  n100={n100.numel()}")
        n1000 = var[torch.ge(var.abs(), 1000)]
        if n1000.numel() > 0:
            print(f"{ctx}: n1000={n1000.numel()}")
        n10000 = var[torch.ge(var.abs(), 10000)]
        if n10000.numel() > 0:
            print(f"{ctx}: n10000={n10000.numel()}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")

    if 0:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return detected


class DebugOption(ExplicitEnum):
    UNDERFLOW_OVERFLOW = "underflow_overflow"
    TPU_METRICS_DEBUG = "tpu_metrics_debug"