Spaces:
Sleeping
Sleeping
File size: 1,434 Bytes
7694c84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from torch.utils.tensorboard import SummaryWriter
from utils.plotting import get_alignment_figure, get_specs_figure
class TBLogger(SummaryWriter):
    """TensorBoard ``SummaryWriter`` with helpers for training logs.

    Adds convenience methods that log training scalars, parameter
    histograms, and alignment/spectrogram figures under fixed tag names.
    """

    def __init__(self, log_dir):
        """Create a writer that stores its event files under ``log_dir``."""
        super().__init__(log_dir)

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor`` from autograd and return it as a CPU numpy array."""
        return tensor.detach().cpu().numpy()

    def add_training_data(self, meta, grad_norm,
                          learning_rate, tb_step: int):
        """Log per-step training scalars.

        Args:
            meta: Mapping of metric name -> value; each entry is logged under
                the tag ``train/<name>``. Values that expose ``.item()``
                (torch tensors, numpy scalars) are converted to Python
                numbers; anything else (e.g. a plain float) is logged as-is.
            grad_norm: Gradient norm, logged as ``train/grad_norm``.
            learning_rate: Current learning rate, logged as
                ``train/learning_rate``.
            tb_step: Global step associated with all values.
        """
        for name, value in meta.items():
            # Accept plain Python numbers as well as tensors / numpy scalars;
            # calling ``.item()`` unconditionally crashes on a bare float.
            scalar = value.item() if hasattr(value, "item") else value
            self.add_scalar(f'train/{name}', scalar, tb_step)
        self.add_scalar("train/grad_norm", grad_norm, tb_step)
        self.add_scalar("train/learning_rate", learning_rate, tb_step)

    def add_parameters(self, model, tb_step: int):
        """Log a histogram of every named parameter of ``model``.

        Dots in parameter names are replaced with ``/`` so the histogram
        tags nest by module path in TensorBoard.
        """
        for name, param in model.named_parameters():
            tag = name.replace('.', '/')
            self.add_histogram(tag, self._to_numpy(param), tb_step)

    def add_sample(self, alignment, mel_pred,
                   mel_targ, mel_infer, len_targ,
                   tb_step: int):
        """Log an alignment plot and a three-way spectrogram comparison.

        Args:
            alignment: Alignment tensor; it is transposed (``.T``) before
                being passed to ``get_alignment_figure``.
            mel_pred: Predicted mel tensor; sliced to ``[:, :len_targ]``.
            mel_targ: Target mel tensor; sliced to ``[:, :len_targ]``.
            mel_infer: Mel tensor from inference; plotted without trimming.
            len_targ: Cut-off index along dim 1 for ``mel_pred`` and
                ``mel_targ`` (presumably the valid target frame count, with
                time on dim 1 — confirm against callers).
            tb_step: Global step associated with both figures.
        """
        self.add_figure(
            "alignment",
            get_alignment_figure(self._to_numpy(alignment).T),
            tb_step)

        # Trim prediction and target to the same length so the panels align.
        spectrograms = [
            self._to_numpy(mel_infer),
            self._to_numpy(mel_pred[:, :len_targ]),
            self._to_numpy(mel_targ[:, :len_targ]),
        ]
        titles = ['Frames (inferred)', 'Frames (predicted)', 'Frames (target)']
        self.add_figure(
            "spectrograms",
            get_specs_figure(spectrograms, titles),
            tb_step)
|