Spaces:
Runtime error
Runtime error
Upload text_tokenizer.py
Browse files- min_dalle/text_tokenizer.py +41 -0
min_dalle/text_tokenizer.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import inf
|
2 |
+
from typing import List, Tuple
|
3 |
+
from emoji import demojize
|
4 |
+
|
5 |
+
class TextTokenizer:
|
6 |
+
def __init__(self, vocab: dict, merges: List[str]):
|
7 |
+
self.token_from_subword = vocab
|
8 |
+
pairs = [tuple(pair.split()) for pair in merges]
|
9 |
+
self.rank_from_pair = dict(zip(pairs, range(len(pairs))))
|
10 |
+
|
11 |
+
def tokenize(self, text: str, is_verbose: bool = False) -> List[int]:
|
12 |
+
sep_token = self.token_from_subword['</s>']
|
13 |
+
cls_token = self.token_from_subword['<s>']
|
14 |
+
unk_token = self.token_from_subword['<unk>']
|
15 |
+
text = demojize(text, delimiters=['', ''])
|
16 |
+
text = text.lower().encode("ascii", errors="ignore").decode()
|
17 |
+
tokens = [
|
18 |
+
self.token_from_subword.get(subword, unk_token)
|
19 |
+
for word in text.split(" ") if len(word) > 0
|
20 |
+
for subword in self.get_byte_pair_encoding(word, is_verbose)
|
21 |
+
]
|
22 |
+
return [cls_token] + tokens + [sep_token]
|
23 |
+
|
24 |
+
def get_byte_pair_encoding(self, word: str, is_verbose: bool) -> List[str]:
|
25 |
+
def get_pair_rank(pair: Tuple[str, str]) -> int:
|
26 |
+
return self.rank_from_pair.get(pair, inf)
|
27 |
+
|
28 |
+
subwords = [chr(ord(" ") + 256)] + list(word)
|
29 |
+
while len(subwords) > 1:
|
30 |
+
pairs = list(zip(subwords[:-1], subwords[1:]))
|
31 |
+
pair_to_merge = min(pairs, key=get_pair_rank)
|
32 |
+
if pair_to_merge not in self.rank_from_pair: break
|
33 |
+
i = pairs.index(pair_to_merge)
|
34 |
+
subwords = (
|
35 |
+
(subwords[:i] if i > 0 else []) +
|
36 |
+
[subwords[i] + subwords[i + 1]] +
|
37 |
+
(subwords[i + 2:] if i + 2 < len(subwords) else [])
|
38 |
+
)
|
39 |
+
|
40 |
+
if is_verbose: print(subwords)
|
41 |
+
return subwords
|