import logging
from pytorch_lightning.callbacks import Callback
import torch
log = logging.getLogger(__name__)


class FixNANinGrad(Callback):
    """Replace NaN/Inf gradients in place and stop training after repeated non-finite metrics.

    Args:
        monitor: List of metric names, checked in order; the first one found in
            ``trainer.callback_metrics`` is the one watched for non-finite values.
    """

    def __init__(self, monitor):
        super().__init__()
        self.monitor = monitor
        self.continuous_nan_batches = 0

    def on_before_optimizer_step(self, trainer, pl_module, optimizer) -> None:
        has_nan = []
        is_inf = []
        for name, param in pl_module.named_parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any():
                    has_nan.append(name)
                if torch.isinf(param.grad).any():
                    is_inf.append(name)
                # Zero out NaN/Inf entries in place so the optimizer step stays finite.
                torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
        if len(has_nan) > 0:
            log.warning(f"Found NaN in {has_nan}")
        if len(is_inf) > 0:
            log.warning(f"Found Inf in {is_inf}")

    def on_train_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
    ) -> None:
        logs = trainer.callback_metrics
        i = 0
        found_metric = False
        # Pick the first monitored metric that has actually been logged.
        while i < len(self.monitor) and not found_metric:
            if self.monitor[i] in logs.keys():
                current = logs[self.monitor[i]].squeeze()
                found_metric = True
            else:
                i += 1
        if not found_metric:
            raise ValueError("Requested metric not found in logs")
        if not torch.isfinite(current):
            self.continuous_nan_batches += 1
            # Give training a few batches to recover before stopping it for good.
            if self.continuous_nan_batches >= 5:
                trainer.should_stop = True
                log.info(f"Training interrupted because of NaN in {self.monitor}")
        else:
            self.continuous_nan_batches = 0
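

# Usage sketch (illustrative, not part of the original file): attach the
# callback to a Trainer with the list of metric names to watch. The metric
# name "train/loss" and the model class MyLightningModule are assumed
# placeholders, not names from this repo.
#
#   import pytorch_lightning as pl
#
#   model = MyLightningModule()
#   trainer = pl.Trainer(
#       max_epochs=10,
#       callbacks=[FixNANinGrad(monitor=["train/loss"])],
#   )
#   trainer.fit(model)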