'''
This code is adapted from:
https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
'''
import numpy as np

from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode

class MGPLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index."""

    SPACE = '[s]'
    GO = '[GO]'
    list_token = [GO, SPACE]

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 only_char=False,
                 **kwargs):
        super(MGPLabelEncode,
              self).__init__(max_text_length, character_dict_path,
                             use_space_char)
        # character (str): set of the possible characters.
        # [GO] is the start token of the attention decoder; [s] is the
        # end-of-sentence token.
        self.batch_max_length = max_text_length + len(self.list_token)
        self.only_char = only_char
        if not only_char:
            # transformers==4.2.1
            from transformers import BertTokenizer, GPT2Tokenizer
            self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self.wp_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased')
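
    # __call__ produces up to three label granularities per sample, each of
    # length batch_max_length: character indices ('char_label'), GPT-2 BPE
    # subword ids ('bpe_label'), and BERT WordPiece ids ('wp_label').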
    def __call__(self, data):
        text = data['label']
        char_text, char_len = self.encode(text)
        if char_text is None:
            return None
        data['length'] = np.array(char_len)
        data['char_label'] = np.array(char_text)
        if self.only_char:
            return data
        bpe_text = self.bpe_encode(text)
        if bpe_text is None:
            return None
        wp_text = self.wp_encode(text)
        data['bpe_label'] = np.array(bpe_text)
        data['wp_label'] = wp_text
        return data

    def add_special_char(self, dict_character):
        # Prepend the special tokens so that [GO] maps to index 0 (also used
        # for padding) and [s] maps to index 1.
        dict_character = self.list_token + dict_character
        return dict_character

    def encode(self, text):
        """Convert a text label into a fixed-length list of character indices."""
        if len(text) == 0:
            return None, None
        if self.lower:
            text = text.lower()
        length = len(text)
        text = [self.GO] + list(text) + [self.SPACE]
        text_list = []
        for char in text:
            if char not in self.dict:
                # Drop characters that are not in the dictionary.
                continue
            text_list.append(self.dict[char])
        if len(text_list) == 0 or len(text_list) > self.batch_max_length:
            return None, None
        # Pad to batch_max_length with the [GO] index.
        text_list = text_list + [self.dict[self.GO]
                                 ] * (self.batch_max_length - len(text_list))
        return text_list, length
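
    # A worked sketch of encode(), assuming a dictionary where indices 0 and
    # 1 are taken by [GO] and [s] and, say, 'a' -> 2 and 'b' -> 3, with
    # max_text_length=5 (so batch_max_length=7):
    #   encode('ab') -> ([0, 2, 3, 1, 0, 0, 0], 2)
    # i.e. [GO], 'a', 'b', [s], then [GO]-index padding; the returned length
    # counts only the raw characters.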

    def bpe_encode(self, text):
        if len(text) == 0:
            return None
        token = self.bpe_tokenizer(text)['input_ids']
        # Ids 1 and 2 serve as the start and end tokens of the BPE label.
        text_list = [1] + token + [2]
        if len(text_list) == 0 or len(text_list) > self.batch_max_length:
            return None
        text_list = text_list + [self.dict[self.GO]
                                 ] * (self.batch_max_length - len(text_list))
        return text_list
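
    # Note that, unlike wp_encode below, an over-long BPE sequence is
    # rejected (None is returned) rather than truncated, and the padding
    # value is the [GO] index from the character dictionary.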
    def wp_encode(self, text):
        # The WordPiece tokenizer adds its own special tokens and handles
        # padding and truncation to batch_max_length, so the label comes
        # back already fixed-length as an ndarray.
        wp_target = self.wp_tokenizer([text],
                                      padding='max_length',
                                      max_length=self.batch_max_length,
                                      truncation=True,
                                      return_tensors='np')
        return wp_target['input_ids'][0]
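

if __name__ == '__main__':
    # A minimal usage sketch, not part of the original module. It assumes
    # that character_dict_path=None makes BaseRecLabelEncode fall back to a
    # default lowercase-alphanumeric character set (as in PaddleOCR-style
    # label encoders), and uses only_char=True so no tokenizer download is
    # needed.
    encoder = MGPLabelEncode(max_text_length=25, only_char=True)
    sample = encoder({'label': 'hello'})
    if sample is not None:
        print(sample['length'])            # 5
        print(sample['char_label'].shape)  # (27,) == max_text_length + 2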