"""
Base class for trainable models.
"""
from abc import ABCMeta, abstractmethod
import omegaconf
from omegaconf import OmegaConf
from torch import nn
from copy import copy


class MetaModel(ABCMeta):
    def __prepare__(name, bases, **kwds):
        total_conf = OmegaConf.create()
        for base in bases:
            for key in ('base_default_conf', 'default_conf'):
                update = getattr(base, key, {})
                if isinstance(update, dict):
                    update = OmegaConf.create(update)
                total_conf = OmegaConf.merge(total_conf, update)
        return dict(base_default_conf=total_conf)
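
# Illustration (a sketch for clarity, not part of the original file): with the
# metaclass above, default configurations accumulate along the inheritance
# chain. For hypothetical child models
#
#     class A(BaseModel):  default_conf = {'x': 1}
#     class B(A):          default_conf = {'y': 2}
#
# B is created with base_default_conf = {'name': None, 'trainable': True,
# 'freeze_batch_normalization': False, 'x': 1}, which BaseModel.__init__ then
# merges with B.default_conf and the user-provided configuration.
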
class BaseModel(nn.Module, metaclass=MetaModel):
    """
    What the child model is expected to declare:
        default_conf: dictionary of the default configuration of the model.
        It recursively updates the default_conf of all parent classes, and
        it is updated by the user-provided configuration passed to __init__.
        Configurations can be nested.

        required_data_keys: list of expected keys in the input data dictionary.

        strict_conf (optional): boolean. If false, BaseModel does not raise
        an error when the user provides an unknown configuration entry.

        _init(self, conf): initialization method, where conf is the final
        configuration object (also accessible with `self.conf`). Accessing
        unknown configuration entries will raise an error.

        _forward(self, data): method that returns a dictionary of batched
        prediction tensors based on a dictionary of batched input data tensors.

        loss(self, pred, data): method that returns a dictionary of losses,
        computed from model predictions and input data. Each loss is a batch
        of scalars, i.e. a torch.Tensor of shape (B,). The total loss to be
        optimized has the key `'total'`.

        metrics(self, pred, data): method that returns a dictionary of metrics,
        each as a batch of scalars.

    A minimal example child model is sketched at the end of this file.
    """
    default_conf = {
        'name': None,
        'trainable': True,  # if false: do not optimize the model parameters
        'freeze_batch_normalization': False,  # use test-time statistics
    }
    required_data_keys = []
    strict_conf = True
    def __init__(self, conf):
        """Perform some logic and call the _init method of the child model."""
        super().__init__()
        default_conf = OmegaConf.merge(
            self.base_default_conf, OmegaConf.create(self.default_conf))
        if self.strict_conf:
            OmegaConf.set_struct(default_conf, True)

        # fixme: backward compatibility
        if 'pad' in conf and 'pad' not in default_conf:  # backward compat.
            with omegaconf.read_write(conf):
                with omegaconf.open_dict(conf):
                    conf['interpolation'] = {'pad': conf.pop('pad')}

        if isinstance(conf, dict):
            conf = OmegaConf.create(conf)
        self.conf = conf = OmegaConf.merge(default_conf, conf)
        OmegaConf.set_readonly(conf, True)
        OmegaConf.set_struct(conf, True)
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

        if not conf.trainable:
            for p in self.parameters():
                p.requires_grad = False

    def train(self, mode=True):
        super().train(mode)

        def freeze_bn(module):
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                module.eval()
        if self.conf.freeze_batch_normalization:
            self.apply(freeze_bn)

        return self

    def forward(self, data):
        """Check the data and call the _forward method of the child model."""
        def recursive_key_check(expected, given):
            for key in expected:
                assert key in given, f'Missing key {key} in data'
                if isinstance(expected, dict):
                    recursive_key_check(expected[key], given[key])

        recursive_key_check(self.required_data_keys, data)
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def loss(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError

    @abstractmethod
    def metrics(self, pred, data):
        """To be implemented by the child class."""
        raise NotImplementedError
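
# Example (a minimal sketch, not part of the original file): a hypothetical
# child model that follows the contract documented in BaseModel. The class
# name, configuration entries, and data keys below are illustrative
# assumptions, not an existing model of this codebase.
class _ExampleRegressor(BaseModel):
    default_conf = {
        'in_dim': 2,
        'out_dim': 1,
    }
    required_data_keys = ['points']

    def _init(self, conf):
        # `conf` is the final configuration: base_default_conf, default_conf
        # and the user-provided entries, merged by BaseModel.__init__.
        self.linear = nn.Linear(conf.in_dim, conf.out_dim)

    def _forward(self, data):
        return {'prediction': self.linear(data['points'])}

    def loss(self, pred, data):
        # One scalar per batch element, keyed by 'total' as required.
        return {'total': ((pred['prediction'] - data['target'])**2).mean(-1)}

    def metrics(self, pred, data):
        err = ((pred['prediction'] - data['target'])**2).mean(-1)
        return {'rmse': err.sqrt()}


if __name__ == '__main__':
    # Minimal usage sketch: the user conf is merged with the defaults above;
    # trainable=False (inherited from BaseModel.default_conf) freezes weights.
    import torch
    model = _ExampleRegressor({'out_dim': 3, 'trainable': False})
    pred = model({'points': torch.rand(8, 2)})
    print(pred['prediction'].shape)  # torch.Size([8, 3])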