# Page metadata captured together with this file (not part of the code):
# Realcat · commit 63f3cf2 "fix: eloftr" · raw · history blame · 1.52 kB
import importlib
import inspect
from abc import ABCMeta, abstractmethod
from copy import copy

from torch import nn
class BaseModel(nn.Module, metaclass=ABCMeta):
    """Abstract ``nn.Module`` configured by a mapping and fed a mapping of inputs.

    Subclasses override the class attributes ``default_conf`` and
    ``required_data_keys`` and implement the ``_init`` and ``_forward`` hooks.
    """

    # Default configuration. Entries passed to __init__ override these
    # (shallow, top-level merge).
    default_conf = {}
    # Keys that must be present in the ``data`` argument of ``forward``.
    required_data_keys = []

    def __init__(self, conf=None):
        """Merge ``conf`` over ``default_conf`` and call the child's ``_init``.

        Args:
            conf: Mapping of configuration overrides. ``None`` is treated as
                an empty mapping, so the defaults are used unchanged.
        """
        super().__init__()
        if conf is None:
            conf = {}
        # Shallow merge: nested dicts from default_conf are the same objects,
        # not copies.
        self.conf = conf = {**self.default_conf, **conf}
        # Per-instance copy so that mutating the list (e.g. inside _init)
        # does not alter the class-level list shared by all instances.
        self.required_data_keys = copy(self.required_data_keys)
        self._init(conf)

    def forward(self, data):
        """Check that ``data`` holds every required key, then call ``_forward``.

        Args:
            data: Container of inputs; membership is tested with ``in``.

        Returns:
            Whatever the child's ``_forward`` returns.

        Raises:
            AssertionError: If a key listed in ``required_data_keys`` is
                missing (the first missing key is reported).
        """
        for key in self.required_data_keys:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; AssertionError is kept for caller compatibility.
            if key not in data:
                raise AssertionError('Missing key {} in data'.format(key))
        return self._forward(data)

    @abstractmethod
    def _init(self, conf):
        """Child-class hook called from ``__init__`` with the merged config."""
        raise NotImplementedError

    @abstractmethod
    def _forward(self, data):
        """Child-class hook called from ``forward`` after the key check."""
        raise NotImplementedError
def dynamic_load(root, model):
    """Return the unique ``BaseModel`` subclass defined in ``<root>.<model>``.

    Args:
        root: Parent package, given as an imported module object or as its
            dotted name string.
        model: Name of the submodule of ``root`` to import.

    Returns:
        The single class that is defined in that submodule (not merely
        imported into it) and subclasses ``BaseModel``.

    Raises:
        ModuleNotFoundError: If ``<root>.<model>`` cannot be found.
        AssertionError: If the submodule defines zero or several such classes.
    """
    root_name = root if isinstance(root, str) else root.__name__
    module_path = f'{root_name}.{model}'
    module = importlib.import_module(module_path)
    candidates = [
        (name, cls)
        for name, cls in inspect.getmembers(module, inspect.isclass)
        # Keep only classes defined in this very module (imported ones carry
        # another __module__) that derive from BaseModel.
        if cls.__module__ == module_path and issubclass(cls, BaseModel)
    ]
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # AssertionError is kept for caller compatibility.
    if len(candidates) != 1:
        raise AssertionError(
            f'Expected exactly one BaseModel subclass defined in '
            f'{module_path}, found {len(candidates)}: '
            f'{[name for name, _ in candidates]}'
        )
    return candidates[0][1]