Spaces:
Running
Running
import torch | |
from torch import nn | |
from openrec.modeling.decoders import build_decoder | |
from openrec.modeling.encoders import build_encoder | |
from openrec.modeling.transforms import build_transform | |
__all__ = ['BaseRecognizer'] | |
class BaseRecognizer(nn.Module):
    """Recognizer assembled from optional transform, encoder and decoder stages.

    Each stage is built only when its config entry is present and not
    ``None``. The ``out_channels`` of a built transform/encoder is passed on
    as the ``in_channels`` of the next stage that gets built.
    """

    def __init__(self, config):
        """Build the configured stages.

        Args:
            config (dict): model configuration. Keys read here:
                ``in_channels`` (default ``3``), ``use_wd`` (default
                ``True``), and the optional sub-configs ``Transform``,
                ``Encoder`` and ``Decoder``. Each present sub-config is
                modified in place: its ``in_channels`` entry is overwritten
                before it is handed to the corresponding ``build_*`` function.
        """
        super().__init__()
        in_channels = config.get('in_channels', 3)
        self.use_wd = config.get('use_wd', True)

        # Build the transform (the original comment mentions TPS or None).
        if config.get('Transform') is None:
            self.use_transform = False
        else:
            self.use_transform = True
            config['Transform']['in_channels'] = in_channels
            self.transform = build_transform(config['Transform'])
            in_channels = self.transform.out_channels

        # Build the encoder (backbone).
        if config.get('Encoder') is None:
            self.use_encoder = False
        else:
            self.use_encoder = True
            config['Encoder']['in_channels'] = in_channels
            self.encoder = build_encoder(config['Encoder'])
            in_channels = self.encoder.out_channels

        # Build the decoder.
        if config.get('Decoder') is None:
            self.use_decoder = False
        else:
            self.use_decoder = True
            config['Decoder']['in_channels'] = in_channels
            self.decoder = build_decoder(config['Decoder'])

    def no_weight_decay(self):
        """Collect the names reported by the encoder's and decoder's
        ``no_weight_decay()`` methods.

        Returns:
            set: union of the names returned by whichever of ``self.encoder``
            and ``self.decoder`` exist and define ``no_weight_decay``. Empty
            when ``self.use_wd`` is false or neither stage reports any names.
            NOTE(review): presumably these are parameter names the optimizer
            setup excludes from weight decay -- confirm against the caller.
        """
        names = set()
        if not self.use_wd:
            return names
        # getattr with a default: a stage that was not configured has no
        # attribute at all, and plain attribute access would raise.
        for stage in (getattr(self, 'encoder', None),
                      getattr(self, 'decoder', None)):
            if stage is not None and hasattr(stage, 'no_weight_decay'):
                # Merge into our own set so the container returned by the
                # sub-module is never mutated, and so set-valued results
                # can be combined regardless of which stage provides them.
                names.update(stage.no_weight_decay())
        return names

    def forward(self, x, data=None):
        """Run ``x`` through every enabled stage in order.

        Args:
            x: input passed to the first enabled stage.
            data: forwarded unchanged to the decoder as the ``data`` keyword
                argument; unused when no decoder is configured.

        Returns:
            The output of the last enabled stage, or ``x`` itself when no
            stage is enabled.
        """
        if self.use_transform:
            x = self.transform(x)
        if self.use_encoder:
            x = self.encoder(x)
        if self.use_decoder:
            x = self.decoder(x, data=data)
        return x