import os

import torch
import torch.nn as nn
from torch.nn import init

from util.util import to_device
from .networks import *
from params import *
class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output
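
# Hedged usage sketch (not part of the original file): illustrates the tensor
# shapes BidirectionalLSTM expects and produces. All sizes below are example
# values, not values taken from params.
def _example_bidirectional_lstm():
    rnn = BidirectionalLSTM(nIn=512, nHidden=256, nOut=80)  # 80 is an assumed output vocabulary size
    x = torch.randn(16, 4, 512)  # [T, b, nIn] = [sequence length, batch, features]
    y = rnn(x)                   # -> [16, 4, 80], i.e. [T, b, nOut]
    return y.shape
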
class CRNN(nn.Module):

    def __init__(self, leakyRelu=False):
        super(CRNN, self).__init__()
        self.name = 'OCR'
        # assert opt.imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]
        cnn = nn.Sequential()
        nh = 256
        dealwith_lossnone = False  # whether to replace all nan/inf values in gradients with zero
        def convRelu(i, batchNormalization=False):
            nIn = 1 if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        if resolution == 63:
            cnn.add_module('pooling{0}'.format(3),
                           nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(5)
        cnn.add_module('pooling{0}'.format(4),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.use_rnn = False
        if self.use_rnn:
            self.rnn = nn.Sequential(
                BidirectionalLSTM(512, nh, nh),
                # NOTE: the original call omitted the output size; VOCAB_SIZE
                # (from params) is assumed here to mirror the linear branch below.
                BidirectionalLSTM(nh, nh, VOCAB_SIZE))
        else:
            self.linear = nn.Linear(512, VOCAB_SIZE)

        # replace all nan/inf values in gradients with zero
        if dealwith_lossnone:
            self.register_backward_hook(self.backward_hook)

        self.device = torch.device('cuda:{}'.format(0))
        self.init = 'N02'
        # Initialize weights
        self = init_weights(self, self.init)
    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        if self.use_rnn:
            # rnn features
            output = self.rnn(conv)
        else:
            output = self.linear(conv)
        return output

    def backward_hook(self, module, grad_input, grad_output):
        for g in grad_input:
            # replace all nan values in gradients with zero (g != g is true only for nan)
            if g is not None:
                g[g != g] = 0
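
# Hedged usage sketch (illustrative, not from the original file): a grayscale
# image batch is collapsed to height 1 by the conv stack and emitted as
# per-column class scores suitable for CTC decoding. An input height of 32
# reaches h == 1 only when the extra pooling branch (resolution == 63) is off;
# the width 128 and VOCAB_SIZE come from params and are assumptions here.
def _example_crnn_forward():
    model = CRNN()
    images = torch.randn(4, 1, 32, 128)  # [b, c=1, H=32, W=128]
    logits = model(images)               # -> [w', b, VOCAB_SIZE], one score vector per image column
    return logits.shape
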
class OCRLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether to ignore character case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1
    def encode(self, text):
        """Support batch or single str.

        Args:
            text (iterable of bytes): texts to convert; each item is a utf-8 encoded byte string.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
        # Legacy implementation that handled plain strings directly, kept for reference:
        '''
        if isinstance(text, str):
            text = [
                self.dict[char.lower() if self._ignore_case else char]
                for char in text
            ]
            length = [len(text)]
        elif isinstance(text, collections.Iterable):
            length = [len(s) for s in text]
            text = ''.join(text)
            text, _ = self.encode(text)
        return (torch.IntTensor(text), torch.IntTensor(length))
        '''
        length = []
        result = []
        for item in text:
            # each item is expected to be a utf-8 encoded byte string
            item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))
    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the text and its declared length do not match.

        Returns:
            text (str or list of str): decoded texts.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(
                t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    # CTC decoding: skip blanks (0) and collapse repeated labels
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
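
# Hedged round-trip sketch (not part of the original file): encode expects an
# iterable of byte strings; decode removes the CTC blank (index 0) and collapses
# repeated labels. The alphabet below is a toy example, not the one from params.
def _example_ocr_converter():
    converter = OCRLabelConverter('abc')
    flat, lengths = converter.encode([b'abc', b'cab'])
    # flat is a flat IntTensor of label indices; lengths holds each word's length
    texts = converter.decode(flat, lengths)  # -> ['abc', 'cab']
    return texts
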
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether to ignore character case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1
    def encode(self, text):
        """Support batch or single str.

        Args:
            text (iterable of bytes): texts to convert; each item is a utf-8 encoded byte string.

        Returns:
            torch.LongTensor [n, max_len]: encoded texts, zero-padded to the longest text.
            torch.IntTensor [n]: length of each text.
        """
        # Legacy implementation that handled plain strings directly, kept for reference:
        '''
        if isinstance(text, str):
            text = [
                self.dict[char.lower() if self._ignore_case else char]
                for char in text
            ]
            length = [len(text)]
        elif isinstance(text, collections.Iterable):
            length = [len(s) for s in text]
            text = ''.join(text)
            text, _ = self.encode(text)
        return (torch.IntTensor(text), torch.IntTensor(length))
        '''
        length = []
        result = []
        results = []
        for item in text:
            # each item is expected to be a utf-8 encoded byte string
            item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
            results.append(result)
            result = []
        return (torch.nn.utils.rnn.pad_sequence([torch.LongTensor(text) for text in results], batch_first=True),
                torch.IntTensor(length))
    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the text and its declared length do not match.

        Returns:
            text (str or list of str): decoded texts.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(
                t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    # CTC decoding: skip blanks (0) and collapse repeated labels
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
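
# Hedged sketch (not part of the original file): unlike OCRLabelConverter.encode,
# this encode returns a padded [n, max_len] LongTensor, which suits losses that
# expect batched, equal-length targets. The alphabet is a toy example.
def _example_str_converter():
    converter = strLabelConverter('abc')
    padded, lengths = converter.encode([b'ab', b'abc'])
    # padded -> [[1, 2, 0], [1, 2, 3]] with 0 as padding; lengths -> [2, 3]
    return padded, lengths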