import torch from .symbols import (RNN_BIG_CHARACTERS_MAPPING, DIACRITICS_LIST, ARABIC_LETTERS_LIST, RNN_REV_CLASSES_MAPPING, RNN_SMALL_CHARACTERS_MAPPING) def remove_diacritics(data, DIACRITICS_LIST): return data.translate(str.maketrans('', '', ''.join(DIACRITICS_LIST))) CHARACTERS_MAPPING = RNN_BIG_CHARACTERS_MAPPING # CHARACTERS_MAPPING = RNN_SMALL_CHARACTERS_MAPPING REV_CLASSES_MAPPING = RNN_REV_CLASSES_MAPPING def encode(input_text:str): x = [CHARACTERS_MAPPING['']] for idx, char in enumerate(input_text): if char in DIACRITICS_LIST: continue if char not in CHARACTERS_MAPPING: x.append(CHARACTERS_MAPPING['']) else: x.append(CHARACTERS_MAPPING[char]) x.append(CHARACTERS_MAPPING['']) return x def decode(probs, input_text:str): probs = probs[0][1:] output = '' for char, prediction in zip(remove_diacritics(input_text, DIACRITICS_LIST), probs): output += char if char not in ARABIC_LETTERS_LIST: continue prediction = torch.argmax(prediction).item() if '<' in REV_CLASSES_MAPPING[prediction]: continue output += REV_CLASSES_MAPPING[prediction] return output