import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt

import deepafx_st.utils as utils


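class LogParametersCallback(pl.callbacks.Callback):
    """Log a Markdown table of denormalized effect parameters as text during validation."""

    def __init__(self, num_examples=4):
        super().__init__()
        # Maximum number of examples from the first validation batch to log.
        self.num_examples = num_examples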

    def on_validation_epoch_start(self, trainer, pl_module):
        """At the start of validation init storage for parameters."""
        self.params = []

    def on_validation_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
        dataloader_idx=0,
    ):
        """Called when a validation batch ends.

        Parameters are logged only for the first batch of each validation
        epoch, for at most ``self.num_examples`` examples from that batch.
        """
        if outputs is not None and batch_idx == 0:
            # Never request more examples than the batch actually contains.
            examples = min(self.num_examples, outputs["x"].shape[0])
            for n in range(examples):
                self.log_parameters(
                    outputs,
                    n,
                    pl_module.processor.ports,
                    trainer.global_step,
                    trainer.logger,
                    log=True,
                )

    def on_validation_epoch_end(self, trainer, pl_module):
        pass

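    def log_parameters(self, outputs, batch_idx, ports, global_step, logger, log=True):
        """Build a Markdown table of one example's parameters and log it as text.

        Note that ``batch_idx`` here is the index of the example *within* the
        batch, not the index of the validation batch.
        """
        p = outputs["p"][batch_idx, ...]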

        table = ""

        # table += f"""## {plugin["name"]}\n"""
        table += "| Index | Name | Value | Units | Min | Max | Default | Raw Value | \n"
        table += "|-------|------|------:|:------|----:|----:|--------:|----------:| \n"

        start_idx = 0
        # Walk every port in the nested port lists, in order; `start_idx` is the
        # position of the matching normalized value in the flat vector `p`.
        for port_list in ports:
            for port in port_list:
                param_max = port["max"]
                param_min = port["min"]
                param_name = port["name"]
                param_default = port["default"]
                param_units = port["units"]

                param_val = p[start_idx]
                # Map the normalized prediction back to the parameter's native range.
                denorm_val = utils.denormalize(param_val, param_max, param_min)

                # Append one row; large-magnitude values get fewer decimal places.
                table += f"| {start_idx + 1} | {param_name} "
                if np.abs(denorm_val) > 10:
                    table += f"| {denorm_val:0.1f} "
                    table += f"| {param_units} "
                    table += f"| {param_min:0.1f} | {param_max:0.1f} "
                    table += f"| {param_default:0.1f} "
                else:
                    table += f"| {denorm_val:0.3f} "
                    table += f"| {param_units} "
                    table += f"| {param_min:0.3f} | {param_max:0.3f} "
                    table += f"| {param_default:0.3f} "

                table += f"| {np.squeeze(param_val):0.2f} | \n"
                start_idx += 1

        table += "\n\n"

        if log:
            logger.experiment.add_text(f"params/{batch_idx+1}", table, global_step)
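

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, not part of the original callback).
# It mirrors how `log_parameters` turns one normalized parameter into a row of
# the Markdown table. ASSUMPTION: `utils.denormalize` is a linear min-max map
# from [0, 1] to [min, max]. Its implementation is not shown in this file, so
# `_sketch_denormalize` is only a stand-in for illustration.
# ---------------------------------------------------------------------------
def _sketch_denormalize(p_norm, p_max, p_min):
    """Hypothetical linear map from [0, 1] to [p_min, p_max]."""
    return p_norm * (p_max - p_min) + p_min


def _sketch_param_row(idx, port, p_norm):
    """Format one table row the way `log_parameters` does (hypothetical helper).

    `port` is a dict with the same keys `log_parameters` reads:
    "name", "min", "max", "default", "units".
    """
    denorm = _sketch_denormalize(p_norm, port["max"], port["min"])
    # Same rule as the callback: one decimal above magnitude 10, else three.
    fmt = "0.1f" if abs(denorm) > 10 else "0.3f"
    return (
        f"| {idx + 1} | {port['name']} "
        f"| {denorm:{fmt}} | {port['units']} "
        f"| {port['min']:{fmt}} | {port['max']:{fmt}} "
        f"| {port['default']:{fmt}} "
        f"| {p_norm:0.2f} | \n"
    )


if __name__ == "__main__":
    # Made-up port description, used only to show the row layout.
    gain_port = {"name": "Gain", "min": -24.0, "max": 24.0, "default": 0.0, "units": "dB"}
    # Under the linear assumption, 0.75 maps to 12.0 dB, so one decimal is used.
    print(_sketch_param_row(0, gain_port, 0.75), end="")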