File size: 5,390 Bytes
45e669a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os
from typing import Union
import torch
from torch import device
from .utils import get_parameter_device, get_parameter_dtype, save_state_dict_and_config, load_state_dict_from_path
class BaseModel(torch.nn.Module):
    """
    A base model class that provides a template for implementing models. It includes methods for
    loading, saving, and managing model configurations and states. This class is designed to be
    extended by specific model implementations.

    Attributes:
        config (object): Configuration object containing model settings (may be None).
        input_color_flip (bool): Whether to flip the color channels from BGR to RGB. True only
            when the config declared ``color_space == 'BGR'``.
    """

    def __init__(self, config=None):
        """
        Initializes the BaseModel class.

        Parameters:
            config (object, optional): Configuration object containing model settings. If it
                declares ``color_space == 'BGR'``, ``input_color_flip`` is set to True, the
                original value is remembered in ``self._config_color_space`` and
                ``config.color_space`` is rewritten to ``'RGB'`` in place (the passed object
                is mutated). A None config, or one without ``color_space``, means no flip.
        """
        super().__init__()
        self.config = config
        # `config` is optional, so do not dereference it unconditionally (the old code raised
        # AttributeError for the default `config=None`).
        if getattr(config, 'color_space', None) == 'BGR':
            self.input_color_flip = True
            self._config_color_space = 'BGR'
            # Mutates the caller's config so it describes the color space after the flip.
            self.config.color_space = 'RGB'
        else:
            self.input_color_flip = False

    def forward(self, x):
        """
        Forward pass of the model. Needs to be implemented in subclass.

        Parameters:
            x (torch.Tensor): Input tensor.

        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """
        raise NotImplementedError('forward must be implemented in subclass')

    @classmethod
    def from_config(cls, config) -> "BaseModel":
        """
        Creates an instance of this class from a configuration object. Needs to be implemented in subclass.

        Parameters:
            config (object): Configuration object.

        Returns:
            BaseModel: An instance of the subclass.

        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """
        raise NotImplementedError('from_config must be implemented in subclass')

    def make_train_transform(self):
        """
        Creates training data transformations. Needs to be implemented in subclass.

        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """
        raise NotImplementedError('make_train_transform must be implemented in subclass')

    def make_test_transform(self):
        """
        Creates testing data transformations. Needs to be implemented in subclass.

        Raises:
            NotImplementedError: If the subclass does not implement this method.
        """
        raise NotImplementedError('make_test_transform must be implemented in subclass')

    def save_pretrained(
        self,
        save_dir: Union[str, os.PathLike],
        name: str = 'model.pt',
        rank: int = 0,
    ):
        """
        Saves the model's state_dict and configuration to the specified directory.

        Parameters:
            save_dir (Union[str, os.PathLike]): The directory to save the model.
            name (str, optional): The name of the file to save the model as. Default is 'model.pt'.
            rank (int, optional): The rank of the process (used in distributed training). Default is 0.
                Only rank 0 writes, so other ranks do nothing.
        """
        save_path = os.path.join(save_dir, name)
        if rank == 0:
            save_state_dict_and_config(self.state_dict(), self.config, save_path)

    @staticmethod
    def _remap_vit_keys(state_dict):
        """
        Return a copy of ``state_dict`` whose keys starting with ``'net.'`` are moved under
        ``'net.vit.'``.

        Only the leading prefix is rewritten (a key such as ``'net.x.subnet.w'`` becomes
        ``'net.vit.x.subnet.w'``, not ``'net.vit.x.subnet.vit.w'``). Keys already under
        ``'net.vit.'`` and keys not starting with ``'net.'`` are kept unchanged.
        """
        old_prefix, new_prefix = 'net.', 'net.vit.'
        remapped = {}
        for key, value in state_dict.items():
            if key.startswith(old_prefix) and not key.startswith(new_prefix):
                key = new_prefix + key[len(old_prefix):]
            remapped[key] = value
        return remapped

    def load_state_dict_from_path(self, pretrained_model_path):
        """
        Loads weights from a checkpoint path into this model with ``strict=False``.

        Parameters:
            pretrained_model_path (Union[str, os.PathLike]): Path to the checkpoint file.

        Returns:
            The object returned by ``torch.nn.Module.load_state_dict`` (missing and
            unexpected keys). The original implementation returned None, so callers that
            ignore the return value are unaffected.
        """
        # Module-level helper imported from .utils (not this method, which shadows its name
        # only inside the class namespace).
        state_dict = load_state_dict_from_path(pretrained_model_path)
        own_keys = list(self.state_dict().keys())
        # Heuristic kept from the original: when this model nests its weights under `net.vit`
        # and the checkpoint path mentions `pretrained_models`, the checkpoint keys are
        # presumably stored under `net.` without the `vit` level. NOTE(review): confirm.
        # `own_keys` guards against an empty model; `str()` lets os.PathLike paths work.
        if own_keys and 'net.vit' in own_keys[-1] and 'pretrained_models' in str(pretrained_model_path):
            state_dict = self._remap_vit_keys(state_dict)
        n_compatible = len(set(state_dict).intersection(own_keys))
        print('compatible keys in state_dict', n_compatible, '/', len(state_dict))
        result = self.load_state_dict(state_dict, strict=False)
        print(result)
        print(f"Loaded pretrained model from {pretrained_model_path}")
        return result

    @property
    def device(self) -> device:
        """
        Returns the device of the model's parameters.

        Returns:
            device: The device the model is on.
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        Returns the data type of the model's parameters.

        Returns:
            torch.dtype: The data type of the model.
        """
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Returns the number of parameters in the model, optionally filtering only trainable parameters.

        Parameters:
            only_trainable (bool, optional): Whether to count only trainable parameters
                (``requires_grad=True``). Default is False.

        Returns:
            int: The number of parameter elements.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

    def has_trainable_params(self):
        """
        Checks if the model has any trainable parameters.

        Returns:
            bool: True if any parameter has ``requires_grad=True``, False otherwise.
        """
        return any(p.requires_grad for p in self.parameters())
|