import logging
from pytorch_lightning.callbacks import Callback
import torch

log = logging.getLogger(__name__)


class FixNANinGrad(Callback):
    """Zero out NaN/Inf gradients before each optimizer step, and stop
    training once a monitored metric has been non-finite for several
    consecutive batches.

    Args:
        monitor: list of candidate metric names; the first one found in
            ``trainer.callback_metrics`` is checked for finiteness.
    """

    def __init__(self, monitor):
        super().__init__()
        self.monitor = monitor
        self.continuous_nan_batches = 0

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        has_nan = []
        is_inf = []
        for name, param in pl_module.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    has_nan.append(name)
                if torch.isinf(param.grad).any():
                    is_inf.append(name)
                # Sanitize in place so the optimizer never sees
                # non-finite gradients.
                torch.nan_to_num(param.grad, nan=0.0, posinf=0.0, neginf=0.0, out=param.grad)
        if has_nan:
            log.warning("Found NaN in gradients of: %s", has_nan)
        if is_inf:
            log.warning("Found Inf in gradients of: %s", is_inf)

    def on_train_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
    ) -> None:
        logs = trainer.callback_metrics
        # Use the first monitored metric that is present in the logs.
        i = 0
        found_metric = False
        while i < len(self.monitor) and not found_metric:
            if self.monitor[i] in logs:
                current = logs[self.monitor[i]].squeeze()
                found_metric = True
            else:
                i += 1
        if not found_metric:
            raise ValueError(
                f"None of the monitored metrics {self.monitor} were found in the logs."
            )

        if not torch.isfinite(current):
            self.continuous_nan_batches += 1
            # Five non-finite batches in a row: assume training has diverged.
            if self.continuous_nan_batches >= 5:
                trainer.should_stop = True
                log.info(f"Training interrupted because of NaN in {self.monitor}")
        else:
            self.continuous_nan_batches = 0
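

# A minimal usage sketch (an assumption, not part of the original file): a toy
# LightningModule trained for one epoch on random data, showing how the
# callback is attached to a Trainer. `TinyModule` and the "train_loss" metric
# name are illustrative; `monitor` must list metric names your module actually
# logs via `self.log(...)`.
if __name__ == "__main__":
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class TinyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # Log the metric the callback monitors for finiteness.
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    trainer = pl.Trainer(
        max_epochs=1,
        callbacks=[FixNANinGrad(monitor=["train_loss"])],
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(TinyModule(), DataLoader(dataset, batch_size=16))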