File size: 1,532 Bytes
e7d3e35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
""" Trainer debug utils """


def dump_optim_states(self):
    """Print basic information about the state of the optimizer ``self.vl_optim``.

    For every inspected state entry this prints its index, its ``step`` value
    and whether the ``exp_avg`` / ``exp_avg_sq`` tensors consist only of zeros.

    NOTE(review): only the first ``len(param_groups)`` entries of
    ``self.vl_optim.state`` are inspected. This presumably relies on the
    optimizer keeping one state entry per param group (e.g. flattened
    parameters) -- confirm against the optimizer actually in use.
    """

    print("*** Optim States Dump:")
    param_groups_cnt = len(self.vl_optim.param_groups)
    # state dict has more than param_groups info, so extract only the param groups
    param_group_states = list(self.vl_optim.state.values())[:param_groups_cnt]
    for i, state in enumerate(param_group_states):
        # Use the tensor's own `.all()` reduction: the builtin `all()` iterates
        # over the first dimension only and fails on tensors that have more
        # than one dimension. `bool()` keeps the printed value a plain
        # True/False.
        exp_avg_all_zero = bool((state["exp_avg"] == 0).all())
        exp_avg_sq_all_zero = bool((state["exp_avg_sq"] == 0).all())

        print(f"param group: {i}")
        print(f"  step={state['step']}")
        print(f"  exp_avg    all_zero={exp_avg_all_zero}")
        print(f"  exp_avg_sq all_zero={exp_avg_sq_all_zero}")

    # can also dump LR state if need be
    # print(f"LR={self.vl_scheduler.get_last_lr()}")


def validate_optim_states_are_reset(self):
    """Check that the optimizer ``self.vl_optim`` looks new or fully reset.

    For a new or fully reset optimizer we expect ``step == 1`` and all-zero
    ``exp_avg`` and ``exp_avg_sq`` state tensors in every inspected entry.

    NOTE(review): only the first ``len(param_groups)`` entries of
    ``self.vl_optim.state`` are inspected. This presumably relies on the
    optimizer keeping one state entry per param group (e.g. flattened
    parameters) -- confirm against the optimizer actually in use.

    Raises:
        ValueError: if an inspected entry has ``step != 1`` or a non-zero
            value in ``exp_avg`` / ``exp_avg_sq``. The message names the
            entry index and the offending key.
    """

    param_groups_cnt = len(self.vl_optim.param_groups)
    param_group_states = list(self.vl_optim.state.values())[:param_groups_cnt]
    for i, state in enumerate(param_group_states):
        if state["step"] != 1:
            raise ValueError(f"optimizer reset didn't seem to work: state={i} step={state['step']}")
        for key in ("exp_avg", "exp_avg_sq"):
            # Use the tensor's own `.all()` reduction: the builtin `all()`
            # iterates over the first dimension only and fails on tensors
            # that have more than one dimension.
            if not bool((state[key] == 0).all()):
                raise ValueError(f"optimizer reset didn't seem to work: state={i} {key}={state[key]}")