Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
# Base Machine Learning Experiment
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import logging | |
logger = logging.getLogger("__main__") | |
logger.addHandler(logging.NullHandler()) | |
import wandb | |
class EarlyStopping:
    """Track a validation metric and signal when training should stop.

    The instance is called once per epoch with the current metric. It keeps
    the best metric seen so far, counts consecutive epochs without
    improvement, and sets ``early_stop`` to True once that count reaches
    ``patience``. A metric equal to the current best counts as an
    improvement and resets the counter.

    Args:
        patience (int): Number of consecutive non-improving epochs to wait
            before setting ``early_stop``.
        strategy (str, optional): Optimization strategy.
            Please select 'minimize' or 'maximize' for strategy
            (case-insensitive). Defaults to "minimize".

    Raises:
        ValueError: If ``strategy`` is neither 'minimize' nor 'maximize'.
    """

    # Closed set of accepted (lower-cased) strategies.
    _STRATEGIES = ("minimize", "maximize")

    def __init__(self, patience: int, strategy: str = "minimize"):
        strategy = strategy.lower()
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, which would let an invalid strategy through.
        if strategy not in self._STRATEGIES:
            raise ValueError("Please select 'minimize' or 'maximize' for strategy")

        self.patience = patience
        self.counter = 0  # consecutive epochs without improvement
        self.strategy = strategy
        self.best_metric = None  # None until the first call
        self.best_epoch = None
        self.early_stop = False

        logger.info(
            f"Using early stopping with a range of {self.patience} and {self.strategy} strategy"
        )

    def __call__(self, metric: float, epoch: int) -> bool:
        """Early stopping update call

        Args:
            metric (float): Metric for early stopping
            epoch (int): Current epoch

        Returns:
            bool: Returns true if the model is performing better than (or equal to)
                the current best model, otherwise false. The first call always
                returns true.
        """
        if self.best_metric is None:
            # Nothing to compare against yet: the first metric is the best one.
            improved = True
        elif self.strategy == "minimize":
            improved = self.best_metric >= metric
        else:  # "maximize" (validated in __init__)
            improved = self.best_metric <= metric

        if improved:
            self.best_metric = metric
            self.best_epoch = epoch
            self.counter = 0
            self._log_best_to_wandb(metric, epoch)
            return True

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False

    @staticmethod
    def _log_best_to_wandb(metric: float, epoch: int) -> None:
        """Write the best epoch/metric to the active wandb run summary, if any."""
        run = wandb.run
        # ``wandb.run`` is None when no run has been initialised; skip logging
        # rather than crash so early stopping also works without wandb.
        if run is None:
            return
        run.summary["Best-Epoch"] = epoch
        run.summary["Best-Metric"] = metric