Spaces:
Build error
Build error
# Copyright 2020 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" | |
Adapted from https://github.com/huggingface/transformers/blob/f93c90d21749b61bd89152a7fe99a839df29ed94/src/transformers/debug_utils.py | |
""" | |
import json | |
from transformers.utils import ExplicitEnum, is_torch_available, logging | |
from m4.training.utils import get_stats | |
if is_torch_available(): | |
import torch | |
# Module-level logger obtained from `transformers.utils.logging`.
# NOTE(review): not referenced by the code shown in this module; confirm no other module imports it before removing.
logger = logging.get_logger(__name__)
class ActivationTracker:
    """
    This debug class helps detect and understand where the model starts getting very large or very small, and more
    importantly `nan` or `inf` activation elements.

    A forward hook is plugged into every submodule of the model (the model itself included). The statistics of each
    module's tensor inputs and outputs are recorded into a list of dictionaries, `jsonl_stats` (one dictionary per
    module). Recording only happens while tracing is enabled (`activate_hooks`; disabled by default) and the tracker
    is in training mode (`is_train`); nothing is recorded in validation mode (`is_eval`).
    In practice, since this tracking requires additional computation, we only track activations every X steps.
    In the case of gradient accumulation, all the batches being accumulated are recorded and identified by the
    `batch_idx` key (see `fill_in_batch_idx`).

    Args:
        model (`nn.Module`):
            The model to debug.
        abort_after_batch_num (`int`, *optional*):
            If set, a `ValueError` is raised from the forward hook as soon as the number of forward passes of `model`
            exceeds this value.
    """

    def __init__(
        self,
        model,
        abort_after_batch_num=None,
    ):
        self.model = model
        self.is_validation = False
        self.abort_after_batch_num = abort_after_batch_num
        # `forward_hook` reads this flag on every call, so it must exist before the first forward pass even if
        # `activate_hooks`/`deactivate_hooks` has not been called yet. Tracing is opt-in.
        self.trace_activation = False
        self.jsonl_stats = []
        self.batch_number = 0
        self.detected_overflow = False

        self.analyse_model()
        self.register_forward_hook()

    def analyse_model(self):
        """Build the `module -> fully qualified name` mapping used to label the recorded statistics."""
        # extract the fully qualified module names, to be able to report at run time. e.g.:
        # encoder.block.2.layer.0.SelfAttention.o
        #
        # for shared weights only the first shared module name will be registered
        self.module_names = {m: name for name, m in self.model.named_modules()}

    def analyse_variable(self, var, ctx, current_module_stats):
        """
        Merge the statistics of `var` into `current_module_stats` and flag `nan`/`inf` values.

        Non-tensor values are ignored.

        Args:
            var: a value passed to or returned by a module; only tensors are analysed.
            ctx (`str`): label of `var` (e.g. `"input[0]"`), forwarded to `get_stats` and `detect_overflow`.
            current_module_stats (`dict`): statistics gathered so far for the current module; updated in place.

        Returns:
            `dict`: the same `current_module_stats` object.
        """
        if torch.is_tensor(var):
            current_module_stats.update(get_stats(var, ctx))
            if detect_overflow(var, ctx):
                # Sticky flag: `forward_hook` raises as soon as it sees it set.
                self.detected_overflow = True
        return current_module_stats

    def create_frame(self, module, input, output):
        """
        Gather the statistics of one module call and append them to `jsonl_stats`.

        Tensors are looked up in `input` (one level of tuple) and in `output` (up to two levels of nested tuples);
        other values are skipped. Nothing is appended when no statistics were gathered.
        """
        module_name = self.module_names[module]
        module_type = type(module).__name__
        current_module_stats = {}

        # inputs
        if isinstance(input, tuple):
            for i, x in enumerate(input):
                current_module_stats = self.analyse_variable(x, f"input[{i}]", current_module_stats)
        else:
            current_module_stats = self.analyse_variable(input, "input", current_module_stats)

        # outputs
        if isinstance(output, tuple):
            for i, x in enumerate(output):
                # possibly a tuple of tuples
                if isinstance(x, tuple):
                    for j, y in enumerate(x):
                        current_module_stats = self.analyse_variable(y, f"output[{i}][{j}]", current_module_stats)
                else:
                    current_module_stats = self.analyse_variable(x, f"output[{i}]", current_module_stats)
        else:
            current_module_stats = self.analyse_variable(output, "output", current_module_stats)

        if not current_module_stats:
            return

        # When gradient checkpointing is active, the forward hook is called twice for some (not all) modules, which
        # would lead to double (repeated) entries in the list. This is a hack to avoid these double entries.
        # NOTE(review): the check spans *all* records, including those already tagged with a `batch_idx` by
        # `fill_in_batch_idx`. If several accumulated micro-batches are recorded before `reset_jsonl_stats`, the
        # entries of later micro-batches for the same module would be skipped — confirm against the caller.
        already_recorded = any(
            record["name"] == module_name and record["type"] == module_type for record in self.jsonl_stats
        )
        if not already_recorded:
            self.jsonl_stats.append(
                {
                    "name": module_name,
                    "type": module_type,
                    **current_module_stats,
                }
            )

    def register_forward_hook(self):
        """Attach `forward_hook` to every submodule of the model (the model included)."""
        self.model.apply(self._register_forward_hook)

    def _register_forward_hook(self, module):
        # Called by `nn.Module.apply` once per module.
        module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        """
        Forward hook: count batches, record statistics when tracing, and enforce the abort conditions.

        Raises:
            ValueError: if a `nan`/`inf` activation has been detected, or if the number of forward passes of the
                model exceeds `abort_after_batch_num`.
        """
        # - input is a tuple of packed inputs (could be non-Tensors)
        # - output could be a Tensor or a tuple of Tensors and non-Tensors

        # Forward hooks run after a module's `forward` returns, so the root model's hook is the last one of each
        # forward pass: use it to count completed batches.
        if module is self.model:
            self.batch_number += 1

        if self.trace_activation and not self.is_validation:
            self.create_frame(module, input, output)

        if self.detected_overflow:
            # now we can abort, as it's pointless to continue running
            raise ValueError(
                "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
                "Please scroll up above this traceback to see the activation values prior to this event."
            )

        # abort after certain batch if requested to do so
        if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
            raise ValueError(
                f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
                f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
            )

    def fill_in_batch_idx(self, batch_idx):
        """
        Tag every record that has no `batch_idx` yet with `batch_idx`.

        Raises:
            ValueError: if an already-tagged record carries a `batch_idx` greater than `batch_idx`.
        """
        for record in self.jsonl_stats:
            if "batch_idx" not in record:
                record["batch_idx"] = batch_idx
            elif record["batch_idx"] > batch_idx:
                raise ValueError("`batch_idx` should be increasing")

    def dump_stats(self, log_activations_filename, curr_opt_step):
        """
        Append every record to `log_activations_filename`, one JSON object per line.

        Each record is tagged in place with `"step": curr_opt_step` before being written. Records are not cleared
        here; `reset_jsonl_stats` does that.
        """
        with open(log_activations_filename, "a", encoding="utf-8") as file:
            for record in self.jsonl_stats:
                record["step"] = curr_opt_step
                file.write(json.dumps(record) + "\n")

    def reset_jsonl_stats(self):
        """Drop all recorded statistics."""
        self.jsonl_stats = []

    def activate_hooks(self):
        """Enable recording in `forward_hook`."""
        self.trace_activation = True

    def deactivate_hooks(self):
        """Disable recording in `forward_hook` (batch counting and abort checks still run)."""
        self.trace_activation = False

    def is_eval(self):
        """Switch to validation mode: nothing is recorded even if tracing is active."""
        self.is_validation = True

    def is_train(self):
        """Switch back to training mode: recording follows `activate_hooks`/`deactivate_hooks`."""
        self.is_validation = False
def detect_overflow(
    var,
    ctx,
    *,
    report_large=False,
    large_thresholds=(100, 1000, 10000),
    report_min_max=False,
    report_stats=False,
):
    """
    Report whether the tensor contains any `nan` or `inf` entries.

    This is useful for detecting overflows/underflows and best to call right after the function that did some math that
    modified the tensor in question. A message prefixed with `ctx` is printed for each kind of non-finite value found.

    A few extra diagnostics can be switched on through the keyword-only flags below; they are all off by default.

    Args:
        var (`torch.Tensor`): the tensor variable to check.
        ctx (`str`): the message to print as a context.
        report_large (`bool`, *optional*, defaults to `False`):
            For each value of `large_thresholds`, print how many elements have an absolute value greater than or
            equal to it (thresholds reached by no element are not printed).
        large_thresholds (iterable of numbers, *optional*, defaults to `(100, 1000, 10000)`):
            Magnitudes used by `report_large`.
        report_min_max (`bool`, *optional*, defaults to `False`):
            Print the minimum and maximum of `var`.
        report_stats (`bool`, *optional*, defaults to `False`):
            Print minimum, maximum, variance and mean of `var` followed by `ctx` (`Tensor.var` requires a
            floating-point tensor).

    Return:
        `True` if `inf` or `nan` was detected, `False` otherwise
    """
    detected = False
    if torch.isnan(var).any().item():
        detected = True
        print(f"{ctx} has nans")
    if torch.isinf(var).any().item():
        detected = True
        print(f"{ctx} has infs")

    # Optional helpers to monitor large elements.
    if report_large:
        abs_var = var.abs()
        for threshold in large_thresholds:
            count = torch.ge(abs_var, threshold).sum().item()
            if count > 0:
                print(f"{ctx}: n{threshold}={count}")

    if report_min_max:
        print(f"min={var.min():9.2e} max={var.max():9.2e}")

    if report_stats:
        print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")

    return detected
class DebugOption(ExplicitEnum):
    """
    Named debug modes, each identified by its string value.

    NOTE(review): these members are not referenced by the code shown in this module; presumably they are matched
    against a user-supplied debug setting elsewhere — confirm against the callers.
    """

    UNDERFLOW_OVERFLOW = "underflow_overflow"
    TPU_METRICS_DEBUG = "tpu_metrics_debug"