File size: 755 Bytes
92f0e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from abc import ABC, abstractmethod
import torch
import pytorch_lightning as pl

class ValidationLoopHook(ABC):
    """
    Abstract interface for plugging custom logic into the validation loop.

    Per validation epoch, process() is called for every validation batch,
    trigger() is called once after the epoch has concluded, and reset() clears
    any state before the next validation epoch starts. Subclasses must
    implement all three methods; the base class cannot be instantiated.
    """

    @abstractmethod
    def process(
        self,
        batch: torch.Tensor,
        target_batch: torch.Tensor,
        logits_batch: torch.Tensor,
        prediction_batch: torch.Tensor,
    ) -> None:
        """
        Called for every validation batch to process its results.

        Args:
            batch: The validation input batch.
            target_batch: The targets corresponding to ``batch``.
            logits_batch: The model logits for ``batch``.
            prediction_batch: The predictions for ``batch`` (presumably derived
                from ``logits_batch`` by the caller -- confirm).
        """

    @abstractmethod
    def trigger(self, module: pl.LightningModule) -> None:
        """
        Called after the validation epoch has concluded to further interact
        with the module and/or log data.

        Args:
            module: The LightningModule running the validation loop.
        """

    @abstractmethod
    def reset(self) -> None:
        """
        Called to clean up state before the next validation epoch starts.

        NOTE(review): the previous docstring said this runs "right after
        build()", but this interface defines no build(); presumably it runs
        right after trigger() -- confirm against the caller.
        """