File size: 755 Bytes
92f0e98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from abc import ABC, abstractmethod
import torch
import pytorch_lightning as pl
class ValidationLoopHook(ABC):
    """Interface for objects that observe the validation loop.

    Lifecycle, as described by the abstract methods below:

    1. ``process`` is called once for every validation batch.
    2. ``trigger`` is called after the validation epoch has concluded.
    3. ``reset`` is called to clean up before the next validation epoch.

    Subclasses must implement all three methods; the class cannot be
    instantiated otherwise.
    """

    @abstractmethod
    def process(
        self,
        batch: torch.Tensor,
        target_batch: torch.Tensor,
        logits_batch: torch.Tensor,
        prediction_batch: torch.Tensor,
    ) -> None:
        """Process the results of a single validation batch.

        Called for every validation batch.

        Args:
            batch: The validation input batch.
            target_batch: Targets for ``batch`` (presumably ground-truth
                labels; confirm against the caller).
            logits_batch: Model logits for ``batch``.
            prediction_batch: Predictions for ``batch`` (presumably derived
                from ``logits_batch``; confirm against the caller).
        """

    @abstractmethod
    def trigger(self, module: "pl.LightningModule") -> None:
        """Act on the results gathered during the validation epoch.

        Called after the validation epoch has concluded, to further interact
        with the module and/or log data.

        Args:
            module: The Lightning module being validated. The annotation is
                quoted so it is not evaluated when the class body runs.
        """

    @abstractmethod
    def reset(self) -> None:
        """Clean up state before the next validation epoch starts.

        NOTE(review): the original docstring said this is called "right after
        build()", but no ``build()`` exists on this interface; presumably
        ``trigger()`` was meant. Confirm against the loop driving these hooks.
        """