|
import importlib
import inspect
import sys
from abc import ABCMeta, abstractmethod
from copy import copy

from torch import nn
|
|
|
|
|
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Abstract base class for models configured by a dict-like ``conf``.

    Subclasses implement ``_init`` (build the model from the merged
    configuration) and ``_forward`` (run the model on a ``data`` container).
    They may override two class attributes:

    * ``default_conf``: default configuration values; entries passed to the
      constructor take precedence over these.
    * ``required_inputs``: keys that must be present in the ``data`` passed
      to ``forward``.
    """

    default_conf = {}
    required_inputs = []

    def __init__(self, conf):
        """Merge ``conf`` over ``default_conf`` and call the child's ``_init``.

        Args:
            conf: Mapping of configuration overrides. ``None`` is treated as
                "no overrides", so only ``default_conf`` is used.
        """
        super().__init__()
        if conf is None:
            conf = {}
        # Build a new dict so the class-level defaults are never mutated.
        self.conf = conf = {**self.default_conf, **conf}
        # Per-instance copy: changes made by one instance must not leak into
        # the list shared on the class.
        self.required_inputs = copy(self.required_inputs)
        self._init(conf)
        # NOTE(review): presumably flushes anything printed during _init so
        # it shows up promptly - confirm this is still needed.
        sys.stdout.flush()

    def forward(self, data):
        """Check that ``data`` has every required key, then call ``_forward``.

        Args:
            data: Container supporting ``in`` lookups (e.g. a dict) that
                holds the model inputs.

        Returns:
            Whatever the child's ``_forward`` returns.

        Raises:
            AssertionError: If a key listed in ``required_inputs`` is missing
                from ``data``.
        """
        for key in self.required_inputs:
            if key not in data:
                # Raised explicitly instead of using an ``assert`` statement
                # so that the check still runs under ``python -O``.
                raise AssertionError("Missing key {} in data".format(key))
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """To be implemented by the child class.

        Args:
            conf: The merged configuration (defaults plus overrides).
        """
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """To be implemented by the child class.

        Args:
            data: The inputs, already checked against ``required_inputs``.
        """
        raise NotImplementedError
|
|
|
|
|
def dynamic_load(root, model):
    """Import ``<root>.<model>`` and return the single model class it defines.

    Args:
        root: An imported package or module; only its ``__name__`` is used,
            as the prefix of the dotted path to import.
        model: Name of the submodule of ``root`` that defines the model.

    Returns:
        The one class that is defined in the imported module itself and is a
        subclass of ``BaseModel``.

    Raises:
        ImportError: If ``<root>.<model>`` cannot be imported.
        AssertionError: If the module does not define exactly one such class.
    """
    module_path = f"{root.__name__}.{model}"
    # import_module returns the leaf module directly, without the
    # ``__import__(..., fromlist=[""])`` workaround.
    module = importlib.import_module(module_path)

    candidates = [
        cls
        for _, cls in inspect.getmembers(module, inspect.isclass)
        # Keep only classes defined in this very module (not ones it merely
        # imported) that derive from BaseModel.
        if cls.__module__ == module_path and issubclass(cls, BaseModel)
    ]
    if len(candidates) != 1:
        # Raised explicitly instead of using an ``assert`` statement so that
        # the check still runs under ``python -O``.
        found = [cls.__name__ for cls in candidates]
        raise AssertionError(
            f"Expected exactly one BaseModel subclass defined in "
            f"{module_path}, found {found}"
        )
    return candidates[0]
|
|
|
|