File size: 1,272 Bytes
7694c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from .symbols import (RNN_BIG_CHARACTERS_MAPPING, 
DIACRITICS_LIST, ARABIC_LETTERS_LIST, RNN_REV_CLASSES_MAPPING, RNN_SMALL_CHARACTERS_MAPPING)


def remove_diacritics(data, DIACRITICS_LIST):
    """Return *data* with every character listed in *DIACRITICS_LIST* removed.

    Args:
        data: The string to clean.
        DIACRITICS_LIST: Iterable of strings; every character of every entry
            is deleted from ``data``.

    Returns:
        A new string without the listed characters.
    """
    # Flatten the entries into one string of characters to delete, then let
    # str.translate strip them in a single pass.
    chars_to_strip = ''.join(DIACRITICS_LIST)
    deletion_table = str.maketrans('', '', chars_to_strip)
    return data.translate(deletion_table)

# Active character vocabulary: encode() looks up each input character (and the
# '<SOS>', '<EOS>' and '<UNK>' markers) in this mapping.
CHARACTERS_MAPPING = RNN_BIG_CHARACTERS_MAPPING
# Alternative vocabulary; swap with the line above to use the small mapping.
# CHARACTERS_MAPPING = RNN_SMALL_CHARACTERS_MAPPING
# Maps a predicted class index to the string decode() appends to its output.
REV_CLASSES_MAPPING = RNN_REV_CLASSES_MAPPING


def encode(input_text: str) -> list:
    """Convert *input_text* into a sequence of vocabulary entries.

    The result starts with the ``'<SOS>'`` entry and ends with the ``'<EOS>'``
    entry of ``CHARACTERS_MAPPING``. Characters found in ``DIACRITICS_LIST``
    are skipped; characters missing from ``CHARACTERS_MAPPING`` are replaced
    by its ``'<UNK>'`` entry.

    Args:
        input_text: Text to encode; it may already contain diacritics.

    Returns:
        A list of ``CHARACTERS_MAPPING`` values (presumably integer ids --
        confirm against ``.symbols``), one per kept character plus the two
        boundary markers.
    """
    encoded = [CHARACTERS_MAPPING['<SOS>']]

    for char in input_text:
        # Diacritics are what the model predicts, so they are never inputs.
        if char in DIACRITICS_LIST:
            continue
        if char in CHARACTERS_MAPPING:
            encoded.append(CHARACTERS_MAPPING[char])
        else:
            encoded.append(CHARACTERS_MAPPING['<UNK>'])

    encoded.append(CHARACTERS_MAPPING['<EOS>'])
    return encoded

def decode(probs, input_text: str):
    """Rebuild *input_text* with the diacritics predicted in *probs*.

    Args:
        probs: Model output, read as ``probs[0][1:]``. The first entry of
            ``probs[0]`` is dropped (presumably the ``'<SOS>'`` position that
            ``encode`` prepends -- confirm against the model) and the rest are
            paired in order with the undiacritized characters of
            ``input_text``.
        input_text: The text that was encoded; any diacritics it already
            contains are stripped before pairing.

    Returns:
        The undiacritized text where each character found in
        ``ARABIC_LETTERS_LIST`` is followed by the ``REV_CLASSES_MAPPING``
        string of its highest-scoring class, unless that string contains
        ``'<'``.
    """
    char_scores = probs[0][1:]
    stripped_text = remove_diacritics(input_text, DIACRITICS_LIST)

    # Collect the pieces and join once instead of growing a string with `+=`.
    pieces = []
    # NOTE(review): zip stops at the shorter input, so characters beyond the
    # number of score rows are dropped from the output -- confirm intended.
    for char, scores in zip(stripped_text, char_scores):
        pieces.append(char)

        # Only Arabic letters receive a diacritic.
        if char not in ARABIC_LETTERS_LIST:
            continue

        label = REV_CLASSES_MAPPING[torch.argmax(scores).item()]

        # Labels containing '<' are skipped rather than emitted.
        if '<' in label:
            continue

        pieces.append(label)

    return ''.join(pieces)