|
|
|
|
|
|
|
|
|
|
|
import bisect |
|
import time |
|
from collections import OrderedDict |
|
from typing import Dict, Optional |
|
|
|
try:
    import torch

    def type_as(a, b):
        """Return ``a`` converted with ``a.to(b)`` when both are torch tensors.

        If either argument is not a torch tensor, ``a`` is handed back
        unchanged.
        """
        both_are_tensors = torch.is_tensor(a) and torch.is_tensor(b)
        return a.to(b) if both_are_tensors else a

except ImportError:
    # torch is optional: remember that it is missing and fall back to a
    # pass-through implementation.
    torch = None

    def type_as(a, b):
        """Return ``a`` unchanged (torch could not be imported)."""
        return a
|
|
|
|
|
try: |
|
import numpy as np |
|
except ImportError: |
|
np = None |
|
|
|
|
|
class Meter:
    """Base class for Meters."""

    def __init__(self):
        """The base class keeps no state of its own."""

    def state_dict(self):
        """Return the serialisable state; the base implementation has none."""
        return dict()

    def load_state_dict(self, state_dict):
        """Accept ``state_dict`` and ignore it (nothing to restore here)."""
        return None

    def reset(self):
        """Not implemented at the base level; always raises."""
        raise NotImplementedError()

    @property
    def smoothed_value(self) -> float:
        """Smoothed value used for logging."""
        raise NotImplementedError()
|
|
|
|
|
def safe_round(number, ndigits):
    """Round ``number`` to ``ndigits`` when it is, or wraps, a single scalar.

    Args:
        number: value to round.
        ndigits: forwarded to the built-in ``round``.

    Returns:
        The rounded value, or ``number`` itself when none of the cases
        below applies.
    """
    # Objects implementing the rounding protocol go straight to round().
    if hasattr(number, "__round__"):
        return round(number, ndigits)

    # Single-element torch tensor: unwrap with .item() and try again.
    if torch is not None and torch.is_tensor(number) and number.numel() == 1:
        return safe_round(number.item(), ndigits)

    # Zero-dimensional value exposing .item(): unwrap and try again.
    if np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
        return safe_round(number.item(), ndigits)

    # Nothing matched; hand the input back untouched.
    return number
|
|
|
|
|
class AverageMeter(Meter):
    """Keeps the latest value together with a weighted running average.

    ``val`` holds the last non-``None`` value passed to :meth:`update`,
    while ``sum`` and ``count`` accumulate ``val * n`` and ``n``.
    ``round`` is the optional digit count applied by :attr:`smoothed_value`.
    """

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        """Forget the last value and zero both accumulators."""
        self.val, self.sum, self.count = None, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n``.

        A ``None`` value is ignored entirely. When ``n > 0`` does not
        hold, ``val`` is still stored but ``sum`` and ``count`` are left
        as they were.
        """
        if val is None:
            return
        self.val = val
        if not n > 0:
            return
        # type_as is applied to each accumulator before adding the operand.
        self.sum = type_as(self.sum, val) + (val * n)
        self.count = type_as(self.count, n) + n

    def state_dict(self):
        """Return ``val``, ``sum``, ``count`` and ``round`` as a dict."""
        return {
            name: getattr(self, name)
            for name in ("val", "sum", "count", "round")
        }

    def load_state_dict(self, state_dict):
        """Restore the meter; ``round`` becomes ``None`` when absent."""
        for name in ("val", "sum", "count"):
            setattr(self, name, state_dict[name])
        self.round = state_dict.get("round")

    @property
    def avg(self):
        """``sum / count`` once ``count`` is positive, otherwise ``val``."""
        if self.count > 0:
            return self.sum / self.count
        return self.val

    @property
    def smoothed_value(self) -> float:
        """The average, passed through ``safe_round`` when ``round`` is set."""
        average = self.avg
        if self.round is None or average is None:
            return average
        return safe_round(average, self.round)
|
|
|
|
|
class SumMeter(Meter):
    """Accumulates a running total of the values it is given.

    ``round`` is the optional digit count applied by :attr:`smoothed_value`.
    """

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        """Set the running total back to ``0``."""
        self.sum = 0

    def update(self, val):
        """Add ``val`` to the total; ``None`` is ignored."""
        if val is None:
            return
        # type_as is applied to the accumulator before adding the operand.
        self.sum = type_as(self.sum, val) + val

    def state_dict(self):
        """Return ``sum`` and ``round`` as a dict."""
        return {name: getattr(self, name) for name in ("sum", "round")}

    def load_state_dict(self, state_dict):
        """Restore the total; ``round`` becomes ``None`` when absent."""
        self.sum = state_dict["sum"]
        self.round = state_dict.get("round")

    @property
    def smoothed_value(self) -> float:
        """The total, passed through ``safe_round`` when ``round`` is set."""
        total = self.sum
        if self.round is None or total is None:
            return total
        return safe_round(total, self.round)
|
|
|
|
|
class TimeMeter(Meter):
    """Measures how many events occur per second of elapsed time.

    ``n`` is the accumulated event count, ``i`` counts calls to
    :meth:`update` since the last reset, ``init`` is an elapsed-time
    offset and ``start`` is the ``time.perf_counter()`` reading taken at
    the last reset.
    """

    def __init__(
        self,
        init: int = 0,
        n: int = 0,
        round: Optional[int] = None,
    ):
        self.round = round
        self.reset(init, n)

    def reset(self, init=0, n=0):
        """Restart the clock with elapsed-time offset ``init`` and count ``n``."""
        self.init, self.n, self.i = init, n, 0
        self.start = time.perf_counter()

    def update(self, val=1):
        """Add ``val`` events and bump the update counter."""
        # type_as is applied to the running count before adding the operand.
        self.n = type_as(self.n, val) + val
        self.i += 1

    def state_dict(self):
        """Return the elapsed time (under ``"init"``), ``n`` and ``round``."""
        state = {"init": self.elapsed_time}
        state["n"] = self.n
        state["round"] = self.round
        return state

    def load_state_dict(self, state_dict):
        """Restore the meter and restart its clock.

        When ``state_dict`` contains a ``"start"`` key, its ``"n"`` entry
        is not read and the count restarts from zero
        (NOTE(review): presumably a legacy state-dict layout -- confirm).
        """
        reset_args = {"init": state_dict["init"]}
        if "start" not in state_dict:
            reset_args["n"] = state_dict["n"]
        self.reset(**reset_args)
        self.round = state_dict.get("round")

    @property
    def avg(self):
        """Event count divided by the elapsed time."""
        seconds = self.elapsed_time
        return self.n / seconds

    @property
    def elapsed_time(self):
        """``init`` plus the time passed since the last reset."""
        since_reset = time.perf_counter() - self.start
        return self.init + since_reset

    @property
    def smoothed_value(self) -> float:
        """The rate, passed through ``safe_round`` when ``round`` is set."""
        rate = self.avg
        if self.round is None or rate is None:
            return rate
        return safe_round(rate, self.round)
|
|
|
|
|
class StopwatchMeter(Meter):
    """Accumulates the total and average duration of an event, in seconds.

    ``sum`` is the accumulated duration, ``n`` the accumulated count
    passed to :meth:`stop`, and ``start_time`` the ``time.perf_counter()``
    reading of the last :meth:`start` (``None`` until started).
    """

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.sum, self.n = 0, 0
        # The stopwatch is not running until start() (or reset()) is called.
        self.start_time = None

    def start(self):
        """Record the current time as the start of a measurement."""
        self.start_time = time.perf_counter()

    def stop(self, n=1, prehook=None):
        """Add the time since :meth:`start` to ``sum`` and ``n`` to the count.

        Does nothing if the stopwatch was never started. ``prehook``,
        when given, is called before the clock is read.
        """
        if self.start_time is None:
            return
        if prehook is not None:
            prehook()
        elapsed = time.perf_counter() - self.start_time
        self.sum = self.sum + elapsed
        self.n = type_as(self.n, n) + n

    def reset(self):
        """Zero the accumulators and start the stopwatch."""
        self.sum, self.n = 0, 0
        self.start()

    def state_dict(self):
        """Return ``sum``, ``n`` and ``round`` as a dict."""
        return {name: getattr(self, name) for name in ("sum", "n", "round")}

    def load_state_dict(self, state_dict):
        """Restore the accumulators; the stopwatch is left not running."""
        for name in ("sum", "n"):
            setattr(self, name, state_dict[name])
        self.start_time = None
        self.round = state_dict.get("round")

    @property
    def avg(self):
        """``sum / n`` once ``n`` is positive, otherwise ``sum``."""
        if self.n > 0:
            return self.sum / self.n
        return self.sum

    @property
    def elapsed_time(self):
        """Seconds since :meth:`start`, or ``0.0`` if never started."""
        began = self.start_time
        return 0.0 if began is None else time.perf_counter() - began

    @property
    def smoothed_value(self) -> float:
        """Average duration, or the running elapsed time while ``sum`` is not positive.

        The result is passed through ``safe_round`` when ``round`` is set.
        """
        if self.sum > 0:
            value = self.avg
        else:
            value = self.elapsed_time
        if self.round is None or value is None:
            return value
        return safe_round(value, self.round)
|
|
|
|
|
class MetersDict(OrderedDict):
    """A sorted dictionary of :class:`Meters`.

    Meters are sorted according to a priority that is given when the
    meter is first added to the dictionary. Items are assigned as
    ``(priority, meter)`` pairs; only the meter is stored as the value.
    """

    def __init__(self, *args, **kwargs):
        # ``priorities`` has to exist before OrderedDict.__init__ runs: the
        # base initialiser inserts any supplied items through our
        # __setitem__, which records each key in this list.
        self.priorities = []
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        """Insert ``value``, a ``(priority, meter)`` pair, under a new ``key``.

        Existing keys cannot be reassigned. After insertion the mapping
        is reordered so iteration follows ascending priority; equal
        priorities are ordered by the length of ``priorities`` at the
        time each key was inserted.
        """
        assert key not in self, "MetersDict doesn't support reassignment"
        priority, meter = value
        bisect.insort(self.priorities, (priority, len(self.priorities), key))
        super().__setitem__(key, meter)
        # Re-append every key in sorted order so the dict's iteration
        # order matches ``priorities``.
        for _, _, ordered_key in self.priorities:
            self.move_to_end(ordered_key)

    def add_meter(self, key, meter, priority):
        """Add ``meter`` under ``key`` with the given ``priority``."""
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        """Return ``(priority, key, class name, meter state)`` per meter.

        Entries follow priority order; derived meters are skipped.
        """
        entries = []
        for pri, _, key in self.priorities:
            meter = self[key]
            # Derived meters are left out of the saved state.
            if isinstance(meter, MetersDict._DerivedMeter):
                continue
            entries.append(
                (pri, key, meter.__class__.__name__, meter.state_dict())
            )
        return entries

    def load_state_dict(self, state_dict):
        """Replace all meters with those described by ``state_dict``.

        ``state_dict`` is an iterable of
        ``(priority, key, class name, meter state)`` tuples, as produced
        by :meth:`state_dict`.
        """
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            # NOTE(review): the class name comes from the state dict and is
            # looked up among this module's globals, then called with no
            # arguments. Only load state dicts from trusted sources.
            meter = globals()[meter_cls]()
            meter.load_state_dict(meter_state)
            self.add_meter(key, meter, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if isinstance(meter, MetersDict._DerivedMeter):
            # Derived meters compute their value from the whole dictionary.
            return meter.fn(self)
        return meter.smoothed_value

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values, skipping keys that start with ``_``."""
        return OrderedDict(
            (key, self.get_smoothed_value(key))
            for key in self
            if not key.startswith("_")
        )

    def reset(self):
        """Reset Meter instances (derived meters are skipped)."""
        for meter in self.values():
            if not isinstance(meter, MetersDict._DerivedMeter):
                meter.reset()

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters."""

        def __init__(self, fn):
            # ``fn`` is called with the owning MetersDict by
            # MetersDict.get_smoothed_value.
            self.fn = fn

        def reset(self):
            """No own state to reset."""
            pass
|
|