File size: 1,434 Bytes
7694c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from torch.utils.tensorboard import SummaryWriter
from utils.plotting import get_alignment_figure, get_specs_figure


class TBLogger(SummaryWriter):
    """TensorBoard writer with convenience methods for training runs.

    Thin subclass of ``SummaryWriter`` that groups the scalar, histogram
    and figure logging calls used during training under three helpers.
    """

    def __init__(self, log_dir):
        """Create the writer.

        Args:
            log_dir: Forwarded unchanged to ``SummaryWriter.__init__``.
        """
        super().__init__(log_dir)

    def add_training_data(self, meta, grad_norm,
                          learning_rate, tb_step: int):
        """Log per-step training scalars under the ``train/`` prefix.

        Args:
            meta: Mapping of metric name to value. Each entry is written
                as ``train/<name>``. Values exposing ``.item()`` (e.g.
                0-dim tensors) are unwrapped first; plain Python numbers
                are logged as they are.
            grad_norm: Value written as ``train/grad_norm``.
            learning_rate: Value written as ``train/learning_rate``.
            tb_step: Global step passed to every ``add_scalar`` call.
        """
        for name, metric in meta.items():
            # Accept plain numbers as well as tensors: grad_norm and
            # learning_rate below are already logged without unwrapping,
            # so a float inside ``meta`` should not raise AttributeError.
            scalar = metric.item() if hasattr(metric, "item") else metric
            self.add_scalar(f'train/{name}', scalar, tb_step)
        self.add_scalar("train/grad_norm", grad_norm, tb_step)
        self.add_scalar("train/learning_rate", learning_rate, tb_step)

    def add_parameters(self, model, tb_step: int):
        """Log a histogram for every named parameter of ``model``.

        Args:
            model: Object providing ``named_parameters()`` that yields
                ``(name, tensor)`` pairs.
            tb_step: Global step passed to every ``add_histogram`` call.
        """
        for name, param in model.named_parameters():
            # Dots in the parameter name are turned into slashes to form
            # the histogram tag.
            tag = name.replace('.', '/')
            # ``detach()`` matches the tensor handling in ``add_sample``.
            self.add_histogram(tag, param.detach().cpu().numpy(), tb_step)

    def add_sample(self, alignment, mel_pred,
                   mel_targ, mel_infer, len_targ,
                   tb_step: int):
        """Log an alignment figure and a spectrogram comparison figure.

        Args:
            alignment: Tensor converted to NumPy and transposed before
                being handed to ``get_alignment_figure``.
            mel_pred: Predicted frames; cropped to the first ``len_targ``
                entries along its second axis.
            mel_targ: Target frames; cropped the same way as ``mel_pred``.
            mel_infer: Inferred frames; plotted without cropping.
            len_targ: Number of entries kept along the second axis of
                ``mel_pred`` and ``mel_targ``.
            tb_step: Global step passed to both ``add_figure`` calls.

        NOTE(review): the spectrogram tensors are presumably laid out as
        (n_mels, frames) -- confirm against the caller.
        """
        alignment_fig = get_alignment_figure(
            alignment.detach().cpu().numpy().T)
        self.add_figure("alignment", alignment_fig, tb_step)

        # The order of the spectrograms must match the order of the titles.
        spectrograms = [
            mel_infer.detach().cpu().numpy(),
            mel_pred[:, :len_targ].detach().cpu().numpy(),
            mel_targ[:, :len_targ].detach().cpu().numpy(),
        ]
        titles = ['Frames (inferred)', 'Frames (predicted)', 'Frames (target)']
        self.add_figure(
            "spectrograms",
            get_specs_figure(spectrograms, titles),
            tb_step)