minchul's picture
Upload directory
ea689fe verified
from ..base import BaseModel
from .model import IR_101, IR_50, IR_18
from torchvision import transforms
class IResNetModel(BaseModel):
"""
A class representing a model for IResNet architectures. It supports creating
models with specific configurations such as IR_50 and IR_101.
Attributes:
net (torch.nn.Module): The IResNet network (either IR_50 or IR_101).
config (object): The configuration object with model specifications.
"""
def __init__(self, net, config):
super(IResNetModel, self).__init__(config)
self.net = net
self.config = config
@classmethod
def from_config(cls, config):
if config.name == 'ir50':
net = IR_50(input_size=(112,112), output_dim=config.output_dim)
elif config.name == 'ir101':
net = IR_101(input_size=(112,112), output_dim=config.output_dim)
elif config.name == 'ir18':
net = IR_18(input_size=(112,112), output_dim=config.output_dim)
else:
raise NotImplementedError
model = cls(net, config)
model.eval()
return model
def forward(self, x):
if self.input_color_flip:
x = x.flip(1)
return self.net(x)
def make_train_transform(self):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
return transform
def make_test_transform(self):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
return transform
def load_model(model_config):
model = IResNetModel.from_config(model_config)
return model