# File size: 1,546 bytes (revision 4bde5d3)
import importlib
import inspect
import sys
from abc import ABCMeta, abstractmethod
from copy import copy

from torch import nn
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Abstract base for modules configured by a dict and fed a dict of inputs.

    Subclasses may override ``default_conf`` and ``required_inputs`` and must
    implement ``_init`` and ``_forward``. This base class merges the user
    configuration over the defaults and checks that every required input key
    is present before dispatching to ``_forward``.
    """

    # Class-level defaults intended to be overridden by subclasses. They are
    # never mutated here: __init__ builds a new merged conf dict and a
    # per-instance copy of the required-input list.
    default_conf = {}
    required_inputs = []

    def __init__(self, conf=None):
        """Merge ``conf`` over ``default_conf`` and call the child's ``_init``.

        Args:
            conf: mapping of options overriding ``default_conf``. ``None``
                (the default) means "use ``default_conf`` unchanged".
        """
        super().__init__()
        if conf is None:
            conf = {}
        self.conf = conf = {**self.default_conf, **conf}
        # Per-instance copy so _init can extend it without touching the
        # list shared by every instance of the class.
        self.required_inputs = copy(self.required_inputs)
        self._init(conf)
        # Flush any stdout output buffered so far (e.g. printed by _init).
        sys.stdout.flush()

    def forward(self, data):
        """Check that ``data`` holds every required key, then call ``_forward``.

        Args:
            data: container of inputs; each entry of ``self.required_inputs``
                is tested with ``in``.

        Returns:
            Whatever the child's ``_forward`` returns.

        Raises:
            AssertionError: if a key from ``required_inputs`` is missing.
        """
        for key in self.required_inputs:
            if key not in data:
                # Explicit raise instead of ``assert`` so the check is not
                # stripped under ``python -O``; the exception type is kept as
                # AssertionError for compatibility with existing callers.
                raise AssertionError("Missing key {} in data".format(key))
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """Build the model from the merged ``conf``; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """Compute the outputs from ``data``; implemented by subclasses."""
        raise NotImplementedError
def dynamic_load(root, model):
    """Import ``<root>.<model>`` and return the single BaseModel subclass it defines.

    Args:
        root: parent package (any object with ``__name__``); its name is used
            as the import prefix.
        model: name of the submodule of ``root`` to import.

    Returns:
        The one class that is both defined in that submodule (not merely
        imported into it) and a subclass of ``BaseModel``.

    Raises:
        AssertionError: if the submodule defines zero or several such classes.
        Any error raised while importing the submodule (e.g.
        ModuleNotFoundError) propagates unchanged.
    """
    module_path = f"{root.__name__}.{model}"
    module = importlib.import_module(module_path)
    candidates = [
        cls
        for _, cls in inspect.getmembers(module, inspect.isclass)
        # Skip classes merely imported into the module (including BaseModel).
        if cls.__module__ == module_path and issubclass(cls, BaseModel)
    ]
    # Explicit check rather than ``assert`` so it survives ``python -O``,
    # where an empty match would surface as an IndexError and an ambiguous
    # module would silently return the alphabetically-first class.
    if len(candidates) != 1:
        raise AssertionError(
            f"Expected exactly one BaseModel subclass defined in {module_path}, "
            f"found {[c.__name__ for c in candidates]}"
        )
    return candidates[0]
|