Update README.md
Browse files
README.md
CHANGED
@@ -30,16 +30,17 @@ label2diacritic = {0: 'ู', 1: 'ู', 2: 'ู', 3: 'ู', 4: ''}
|
|
30 |
|
31 |
# Pre-change version of arabic2diacritics (the lines marked "-" belong to the removed side of this diff hunk).
# It tokenizes `text`, applies a sigmoid to the model logits, thresholds at 0.5, takes the first
# batch item, and joins the label2diacritic strings of every label flagged for each position.
# NOTE(review): `c` from zip(preds, text) is never appended in this version, so the returned
# string contains only the collected diacritic strings, not the input characters.
# NOTE(review): preds is not sliced here, so any positions the tokenizer adds around the text
# (presumably special tokens - confirm) are paired with characters of `text` by zip.
def arabic2diacritics(text, model, tokenizer):
|
32 |
tokens = tokenizer(text, return_tensors="pt")
|
33 |
-
preds = (model(**tokens).logits.sigmoid() > 0.5)[0]
|
34 |
new_text = []
|
35 |
for p, c in zip(preds, text):
|
|
|
36 |
for i in range(1, 5):
|
37 |
if p[i]:
|
38 |
new_text.append(label2diacritic[i])
|
39 |
# check shadda last
|
40 |
if p[0]:
|
41 |
new_text.append(label2diacritic[0])
|
42 |
-
|
43 |
new_text = "".join(new_text)
|
44 |
return new_text
|
45 |
|
|
|
30 |
|
31 |
def arabic2diacritics(text, model, tokenizer):
    """Return ``text`` with the predicted diacritics inserted after each character.

    The text is tokenized, passed through ``model``, and the logits are turned
    into independent yes/no decisions per label (sigmoid > 0.5). Predictions
    are paired with the characters of ``text`` one-to-one, so this assumes the
    tokenizer emits exactly one token per character plus one leading and one
    trailing special token (CLS and SEP) - TODO confirm for the tokenizer used.

    Args:
        text: Input string to diacritize.
        model: Callable accepting the tokenizer output as keyword arguments and
            returning an object with a ``logits`` tensor; the first axis is
            indexed as the batch and the last axis as the label (0..4).
        tokenizer: Callable accepting ``(text, return_tensors="pt")`` and
            returning a mapping that can be unpacked into ``model``.

    Returns:
        A new string: each covered character followed by the
        ``label2diacritic`` entries of its predicted labels, with label 0
        (shadda) appended after labels 1-4.
    """
    # Local import keeps the snippet self-contained; return_tensors="pt" above
    # already implies PyTorch is installed.
    import torch

    tokens = tokenizer(text, return_tensors="pt")

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**tokens).logits

    # First (only) batch item; drop the first and last positions (CLS and SEP).
    preds = (logits.sigmoid() > 0.5)[0][1:-1]

    new_text = []
    for p, c in zip(preds, text):
        new_text.append(c)
        for i in range(1, 5):
            if p[i]:
                new_text.append(label2diacritic[i])
        # check shadda last
        if p[0]:
            new_text.append(label2diacritic[0])

    # zip() stops at the shorter input; if fewer predictions than characters
    # came back, keep the uncovered tail unchanged instead of dropping it.
    new_text.append(text[len(preds):])

    return "".join(new_text)
|
46 |
|