import unicodedata

def get_stats(ids, counts=None):
    """
    Given a list of ints/ids, count the occurrences of each consecutive pair.
    Optionally updates an existing `counts` dict in place.
    Returns the counts dict.
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts
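
# Minimal usage sketch for get_stats: the toy ids below are arbitrary
# illustration values, not real byte/token ids from any training run.
if __name__ == "__main__":
    print(get_stats([1, 2, 3, 1, 2]))  # {(1, 2): 2, (2, 3): 1, (3, 1): 1}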

def merge(ids, pair_to_merge, idx_to_use):
    """
    Find every occurrence of `pair_to_merge` in the given list of ints/ids
    and replace it with the single id `idx_to_use`.
    Returns the updated list.
    """
    new_ids = []
    i = 0
    while i < len(ids):
        # check pair match AND that position i is NOT the last element
        if i < len(ids) - 1 and (pair_to_merge[0] == ids[i] and pair_to_merge[1] == ids[i + 1]):
            new_ids.append(idx_to_use)  # pair found, append the merged id to the new list
            i += 2  # skip two elements since the pair was consumed
        else:
            # no pair at this position, copy the current element over unchanged
            new_ids.append(ids[i])
            i += 1
    return new_ids
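
# Sketch of a single BPE training step built from the two helpers above:
# find the most frequent pair and merge it into a new id. The sample text
# and the new id 256 are illustrative choices, not values from this module.
if __name__ == "__main__":
    ids = list("aaabdaaabac".encode("utf-8"))
    stats = get_stats(ids)
    top_pair = max(stats, key=stats.get)  # most frequent consecutive pair
    ids = merge(ids, top_pair, 256)       # 256 = first id beyond the raw bytes
    print(top_pair, ids)                  # (97, 97) [256, 97, 98, 100, 256, 97, 98, 97, 99]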

# helper functions taken directly from Karpathy's BPE repo
def replace_control_characters(s: str) -> str:
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != "C":
            chars.append(ch)  # this character is ok, keep it as-is
        else:
            chars.append(f"\\u{ord(ch):04x}")  # escape control characters
    return "".join(chars)

def render_token(t: bytes) -> str:
    # pretty print a token, escaping control characters
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s
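
# Small illustration of the two helpers above: a token containing a newline
# (a control character) is rendered with an escaped \u000a instead of a raw
# line break. The byte string below is an arbitrary example value.
if __name__ == "__main__":
    print(render_token(b"hello\nworld"))  # hello\u000aworld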

# base Tokenizer class
class Tokenizer:
    """Base Tokenizer class; must be subclassed for use"""

    def __init__(self) -> None:
        # defaults -> no pattern used, no merges; vocab starts as the usual first 256 raw bytes
        self.merges = {}  # holds the learned merges, e.g. (101, 32) -> 256: if 101 ('e') followed by 32 (' ') had the max pair count, that pair is replaced by the next id in order
        self.pattern = ""  # optional regular expression pattern to pre-split raw text
        self.special_tokens = {}  # a mapping to hold any special tokens (empty here, used by subclasses), str -> int, e.g. {'<|endoftext|>': 90257}
        self.vocab = self._build_vocab()  # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        # Tokenizer can train a vocabulary of size vocab_size from text
        raise NotImplementedError

    def encode(self, text):
        # Tokenizer can encode a string into a list of integers
        raise NotImplementedError

    def decode(self, ids):
        # Tokenizer can decode a list of integers into a string
        raise NotImplementedError

    def _build_vocab(self):
        # vocab starts from the 256 raw byte values, then the merges are applied on top, in order
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (pos0, pos1), idx in self.merges.items():
            vocab[idx] = vocab[pos0] + vocab[pos1]
        # NOW add the special tokens defined in __init__()
        # NOTE: special tokens are stored as their UTF-8 encoded bytes
        for tok, idx in self.special_tokens.items():
            vocab[idx] = tok.encode("utf-8")
        return vocab

    # directly from the BPE repo
    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        This is inspired by (but not equivalent to!) sentencepiece's model saving:
        - the model file is the critical one, intended for load()
        - the vocab file is just a pretty-printed version for human inspection only
        """
        print("Saving tokenizer...")
        # write the model file: to be used in load() later
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            # write the version, pattern and merges, that's all that's needed
            f.write("base v1\n")
            f.write(f"{self.pattern}\n")
            # write the special tokens, first the number of them, then each one
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # write the merges dict (keys only; the new ids are implied by order, starting at 256)
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")
        # write the vocab file: for the human to look at
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # note: many tokens may be partial utf-8 sequences
                # and cannot be decoded into valid strings. Here we're using
                # errors='replace' to replace them with the replacement char �.
                # this also means that we couldn't possibly use .vocab in load()
                # because decoding in this way is a lossy operation!
                s = render_token(token)
                # find the children of this token, if any
                if idx in inverted_merges:
                    # if this token has children, render it nicely as a merge
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # otherwise this is a leaf token, just print it
                    # (these should just be the first 256 tokens, the raw bytes)
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of save(), but only for the model file"""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            print(version)
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges; new ids are assigned in order, starting at 256
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()
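
# End-to-end sketch of the save()/load() round trip using the base class
# directly (no subclass needed, since merges/special_tokens are set by hand).
# The merge pair, the special token id and the "demo" file prefix are
# arbitrary illustration values, not anything prescribed by this module.
if __name__ == "__main__":
    tok = Tokenizer()
    tok.merges = {(104, 105): 256}           # pretend we learned to merge 'h' + 'i'
    tok.special_tokens = {"<|endoftext|>": 257}
    tok.vocab = tok._build_vocab()
    tok.save("demo")                         # writes demo.model and demo.vocab

    restored = Tokenizer()
    restored.load("demo.model")
    print(restored.merges)                   # {(104, 105): 256}
    print(restored.vocab[256])               # b'hi'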