ArchitSharma's picture
Upload 16 files
c716076
raw
history blame
936 Bytes
from fastai.basic_train import Learner, LearnerCallback
from fastai.vision.gan import GANLearner
class GANSaveCallback(LearnerCallback):
    """A `LearnerCallback` that periodically saves `learn_gen` during training.

    Every `save_iters` training iterations (skipping iteration 0),
    `learn_gen.save` is called with the name
    ``'{filename}_{epoch}_{iteration}'``.

    Args:
        learn: the `GANLearner` this callback is attached to.
        learn_gen: the `Learner` whose `save` method writes the checkpoint
            (presumably the generator's learner — confirm against the caller).
        filename: prefix for the checkpoint names passed to `learn_gen.save`.
        save_iters: save interval, in training iterations; must be >= 1.

    Raises:
        ValueError: if `save_iters` is less than 1.
    """

    def __init__(
        self,
        learn: GANLearner,
        learn_gen: Learner,
        filename: str,
        save_iters: int = 1000,
    ):
        # Validate before super().__init__ so an unusable callback is never
        # attached to `learn`; a zero interval would otherwise only surface
        # as a ZeroDivisionError on the first batch.
        if save_iters < 1:
            raise ValueError(
                'save_iters must be >= 1, got {}'.format(save_iters)
            )
        super().__init__(learn)
        self.learn_gen = learn_gen
        self.filename = filename
        self.save_iters = save_iters

    def on_batch_end(
        self, iteration: int, epoch: int, train: bool = True, **kwargs
    ) -> None:
        """Save `learn_gen` when `iteration` is a positive multiple of `save_iters`.

        Args:
            iteration: iteration counter supplied by the callback handler.
            epoch: current epoch, used only in the checkpoint name.
            train: False for batches of a validation pass; those are skipped.
            **kwargs: remaining callback-handler state, ignored.
        """
        # NOTE(review): fastai v1's CallbackHandler also calls on_batch_end for
        # validation batches without advancing `iteration`; without this guard
        # every validation batch would re-save the same checkpoint. Confirm
        # against the installed fastai version.
        if not train or iteration == 0:
            return
        if iteration % self.save_iters == 0:
            self._save_gen_learner(iteration=iteration, epoch=epoch)

    def _save_gen_learner(self, iteration: int, epoch: int):
        """Call `learn_gen.save` with the name ``'{filename}_{epoch}_{iteration}'``."""
        filename = '{}_{}_{}'.format(self.filename, epoch, iteration)
        self.learn_gen.save(filename)