SunderAli17 commited on
Commit
17d084f
1 Parent(s): 81dfef5

Create tokenizer.py

Browse files
Files changed (1) hide show
  1. evaclip/tokenizer.py +198 -0
evaclip/tokenizer.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP tokenizer
2
+ Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
3
+ """
4
+ import gzip
5
+ import html
6
+ import os
7
+ from functools import lru_cache
8
+ from typing import Union, List
9
+
10
+ import ftfy
11
+ import regex as re
12
+ import torch
13
+
14
+ # https://stackoverflow.com/q/62691279
15
+ import os
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
+
18
+
19
+ @lru_cache()
20
+ def default_bpe():
21
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
22
+
23
+
24
+ @lru_cache()
25
+ def bytes_to_unicode():
26
+ """
27
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
28
+ The reversible bpe codes work on unicode strings.
29
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
30
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
31
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
32
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
33
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
34
+ """
35
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
36
+ cs = bs[:]
37
+ n = 0
38
+ for b in range(2**8):
39
+ if b not in bs:
40
+ bs.append(b)
41
+ cs.append(2**8+n)
42
+ n += 1
43
+ cs = [chr(n) for n in cs]
44
+ return dict(zip(bs, cs))
45
+
46
+
47
+ def get_pairs(word):
48
+ """Return set of symbol pairs in a word.
49
+ Word is represented as tuple of symbols (symbols being variable-length strings).
50
+ """
51
+ pairs = set()
52
+ prev_char = word[0]
53
+ for char in word[1:]:
54
+ pairs.add((prev_char, char))
55
+ prev_char = char
56
+ return pairs
57
+
58
+
59
+ def basic_clean(text):
60
+ text = ftfy.fix_text(text)
61
+ text = html.unescape(html.unescape(text))
62
+ return text.strip()
63
+
64
+
65
+ def whitespace_clean(text):
66
+ text = re.sub(r'\s+', ' ', text)
67
+ text = text.strip()
68
+ return text
69
+
70
+
71
+ class SimpleTokenizer(object):
72
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
73
+ self.byte_encoder = bytes_to_unicode()
74
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
75
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
76
+ merges = merges[1:49152-256-2+1]
77
+ merges = [tuple(merge.split()) for merge in merges]
78
+ vocab = list(bytes_to_unicode().values())
79
+ vocab = vocab + [v+'</w>' for v in vocab]
80
+ for merge in merges:
81
+ vocab.append(''.join(merge))
82
+ if not special_tokens:
83
+ special_tokens = ['<start_of_text>', '<end_of_text>']
84
+ else:
85
+ special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
86
+ vocab.extend(special_tokens)
87
+ self.encoder = dict(zip(vocab, range(len(vocab))))
88
+ self.decoder = {v: k for k, v in self.encoder.items()}
89
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
90
+ self.cache = {t:t for t in special_tokens}
91
+ special = "|".join(special_tokens)
92
+ self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
93
+
94
+ self.vocab_size = len(self.encoder)
95
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
96
+
97
+ def bpe(self, token):
98
+ if token in self.cache:
99
+ return self.cache[token]
100
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
101
+ pairs = get_pairs(word)
102
+
103
+ if not pairs:
104
+ return token+'</w>'
105
+
106
+ while True:
107
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
108
+ if bigram not in self.bpe_ranks:
109
+ break
110
+ first, second = bigram
111
+ new_word = []
112
+ i = 0
113
+ while i < len(word):
114
+ try:
115
+ j = word.index(first, i)
116
+ new_word.extend(word[i:j])
117
+ i = j
118
+ except:
119
+ new_word.extend(word[i:])
120
+ break
121
+
122
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
123
+ new_word.append(first+second)
124
+ i += 2
125
+ else:
126
+ new_word.append(word[i])
127
+ i += 1
128
+ new_word = tuple(new_word)
129
+ word = new_word
130
+ if len(word) == 1:
131
+ break
132
+ else:
133
+ pairs = get_pairs(word)
134
+ word = ' '.join(word)
135
+ self.cache[token] = word
136
+ return word
137
+
138
+ def encode(self, text):
139
+ bpe_tokens = []
140
+ text = whitespace_clean(basic_clean(text)).lower()
141
+ for token in re.findall(self.pat, text):
142
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
143
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
144
+ return bpe_tokens
145
+
146
+ def decode(self, tokens):
147
+ text = ''.join([self.decoder[token] for token in tokens])
148
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
149
+ return text
150
+
151
+
152
+ _tokenizer = SimpleTokenizer()
153
+
154
+
155
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
156
+ """
157
+ Returns the tokenized representation of given input string(s)
158
+ Parameters
159
+ ----------
160
+ texts : Union[str, List[str]]
161
+ An input string or a list of input strings to tokenize
162
+ context_length : int
163
+ The context length to use; all CLIP models use 77 as the context length
164
+ Returns
165
+ -------
166
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
167
+ """
168
+ if isinstance(texts, str):
169
+ texts = [texts]
170
+
171
+ sot_token = _tokenizer.encoder["<start_of_text>"]
172
+ eot_token = _tokenizer.encoder["<end_of_text>"]
173
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
174
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
175
+
176
+ for i, tokens in enumerate(all_tokens):
177
+ if len(tokens) > context_length:
178
+ tokens = tokens[:context_length] # Truncate
179
+ tokens[-1] = eot_token
180
+ result[i, :len(tokens)] = torch.tensor(tokens)
181
+
182
+ return result
183
+
184
+
185
+ class HFTokenizer:
186
+ "HuggingFace tokenizer wrapper"
187
+ def __init__(self, tokenizer_name:str):
188
+ from transformers import AutoTokenizer
189
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
190
+
191
+ def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
192
+ # same cleaning as for default tokenizer, except lowercasing
193
+ # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
194
+ if isinstance(texts, str):
195
+ texts = [texts]
196
+ texts = [whitespace_clean(basic_clean(text)) for text in texts]
197
+ input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
198
+ return input_ids