Spaces:
Running
Running
File size: 4,055 Bytes
29f689c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import re
import numpy as np
from tools.utils.logging import get_logger
class BaseRecLabelEncode(object):
    """Convert a text label into a list of character indices.

    The character set is either the built-in ``0-9a-z`` alphabet or is read
    from a dictionary file holding one entry per line. Subclasses can
    override :meth:`add_special_char` to add extra tokens to the character
    list before the character-to-index map is built.

    Args:
        max_text_length: Maximum number of encoded characters; labels that
            encode to more entries are rejected by :meth:`encode`.
        character_dict_path: Path of a UTF-8 dictionary file with one entry
            per line. When ``None`` the built-in alphabet is used and
            lower-casing is switched on regardless of ``lower``.
        use_space_char: Append ``' '`` to the entries read from the
            dictionary file. Only consulted when a dictionary file is given.
        lower: Lower-case labels before encoding.
    """

    # Characters whose consecutive runs keep their internal order in
    # ``label_reverse``: ASCII letters and digits, a few punctuation marks
    # and the Arabic-Indic digits U+0660-U+0669. The hyphen is escaped so it
    # is a literal '-'; unescaped it formed a range from '+' up to the first
    # Arabic-Indic digit, which swallowed every Arabic letter.
    _KEEP_ORDER_RE = re.compile(r'[a-zA-Z0-9 :*./%+\-\u0660-\u0669]')

    def __init__(
        self,
        max_text_length,
        character_dict_path=None,
        use_space_char=False,
        lower=False,
    ):
        self.max_text_len = max_text_length
        self.beg_str = 'sos'
        self.end_str = 'eos'
        self.lower = lower
        self.reverse = False
        if character_dict_path is None:
            logger = get_logger()
            logger.warning(
                'The character_dict_path is None, model can only recognize number and lower letters'
            )
            self.character_str = '0123456789abcdefghijklmnopqrstuvwxyz'
            dict_character = list(self.character_str)
            # The built-in alphabet has no upper-case letters.
            self.lower = True
        else:
            self.character_str = []
            # Read as bytes and strip only line terminators so that entries
            # made of (or containing) spaces survive.
            with open(character_dict_path, 'rb') as fin:
                for raw_line in fin:
                    self.character_str.append(
                        raw_line.decode('utf-8').strip('\r\n'))
            if use_space_char:
                self.character_str.append(' ')
            dict_character = list(self.character_str)
            # Dictionaries whose path mentions 'arabic' enable the reverse
            # flag (the visible subclass checks it before reversing labels).
            if 'arabic' in character_dict_path:
                self.reverse = True
        dict_character = self.add_special_char(dict_character)
        # For duplicated entries the last index wins.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    def label_reverse(self, text):
        """Reverse ``text`` while keeping "keep order" runs intact.

        Consecutive characters matched by ``_KEEP_ORDER_RE`` are grouped into
        one run whose internal order is preserved; every other character is
        its own piece. The pieces are then joined in reverse order.

        Args:
            text: The label string.

        Returns:
            The reordered string.
        """
        pieces = []
        run = ''
        for c in text:
            if self._KEEP_ORDER_RE.search(c):
                run += c
                continue
            if run:
                pieces.append(run)
                run = ''
            pieces.append(c)
        if run:
            pieces.append(run)
        return ''.join(reversed(pieces))

    def add_special_char(self, dict_character):
        """Hook for subclasses to add special tokens; returns the list as is."""
        return dict_character

    def encode(self, text):
        """Convert one text label into a list of character indices.

        Characters that are not in the dictionary are silently skipped.

        Args:
            text: The label string. It is lower-cased first when
                ``self.lower`` is set.

        Returns:
            A list of integer indices, or ``None`` when the label is empty
            (or ``None``), when none of its characters is in the dictionary,
            or when it encodes to more than ``max_text_len`` entries.
        """
        if not text:
            return None
        if self.lower:
            text = text.lower()
        text_list = [self.dict[char] for char in text if char in self.dict]
        if not text_list or len(text_list) > self.max_text_len:
            return None
        return text_list
class CTCLabelEncode(BaseRecLabelEncode):
    """Turn a sample's text label into fixed-length index arrays.

    A ``'blank'`` token is prepended to the character list, so it takes
    index 0; the same value 0 is used to pad encoded labels.

    Args:
        max_text_length: Length the encoded label is padded to.
        character_dict_path: Forwarded to the base class.
        use_space_char: Forwarded to the base class.
        **kwargs: Only ``is_reverse`` (default ``False``) is read.
    """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 **kwargs):
        super().__init__(
            max_text_length,
            character_dict_path=character_dict_path,
            use_space_char=use_space_char,
        )
        self.is_reverse = kwargs.get('is_reverse', False)

    def add_special_char(self, dict_character):
        """Return the character list with ``'blank'`` placed at index 0."""
        return ['blank', *dict_character]

    def __call__(self, data):
        """Encode ``data['label']`` in place.

        On success ``data`` gains/overwrites three entries:
        ``'length'`` (number of encoded characters, before padding),
        ``'label'`` (indices padded with 0 up to ``max_text_len``) and
        ``'label_ace'`` (how often each index occurs in the padded label,
        so padding is counted at index 0).

        Returns:
            The updated ``data`` dict, or ``None`` when the label cannot be
            encoded.
        """
        raw_label = data['label']
        # Labels are reversed only when both the dictionary-derived flag and
        # the caller-supplied flag are set (for arabic rec).
        if self.reverse and self.is_reverse:
            raw_label = self.label_reverse(raw_label)

        indices = self.encode(raw_label)
        if indices is None:
            return None

        n_chars = len(indices)
        data['length'] = np.array(n_chars)

        padded = indices + [0] * (self.max_text_len - n_chars)
        data['label'] = np.array(padded)

        occurrences = [0] * len(self.character)
        for idx in padded:
            occurrences[idx] += 1
        data['label_ace'] = np.array(occurrences)
        return data
|