import regex as re

def get_stats(ids, counts=None):
    # count how often each consecutive pair of token ids occurs in `ids`;
    # optionally update an existing `counts` dictionary in place
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]):  # iterate over consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts
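
# Illustration only (not used by the module): get_stats counts consecutive pairs, e.g.
#   get_stats([1, 2, 3, 1, 2]) -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
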
def merge(ids, pair, idx):
    # in the list of token ids, replace every consecutive occurrence
    # of `pair` with the new token id `idx`
    newids = []
    i = 0
    while i < len(ids):
        # if we are not at the last position and the pair matches, merge it
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
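
# Illustration only: merge collapses every occurrence of a pair into one new id, e.g.
#   merge([1, 2, 3, 1, 2], pair=(1, 2), idx=4) -> [4, 3, 4]
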
def _encode_chunk(text_bytes, merges):
    # return the token ids
    # let's begin. first, convert all bytes to integers in range 0..255
    ids = list(text_bytes)
    while len(ids) >= 2:
        # find the pair with the lowest merge index
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        # subtle: if there are no more merges available, the key will
        # result in an inf for every single pair, and the min will be
        # just the first pair in the list, arbitrarily
        # we can detect this terminating case by a membership check
        if pair not in merges:
            break  # nothing else can be merged anymore
        # otherwise let's merge the best pair (lowest merge index)
        idx = merges[pair]
        ids = merge(ids, pair, idx)
    return ids
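
# Sketch of the expected `merges` format, inferred from how it is used above: it maps a
# pair of token ids to the id of the merged token, with ids assigned in the order the
# merges were learned. For example:
#   merges = {(104, 105): 256}                  # b"h", b"i" -> new token 256
#   _encode_chunk(b"hihi", merges) -> [256, 256]
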
def encode(text, regex_pat, merges):
    # split text into chunks of text by categories defined in regex pattern
    text_chunks = re.findall(regex_pat, text)
    # all chunks of text are encoded separately, then results are joined
    ids = []
    for chunk in text_chunks:
        chunk_bytes = chunk.encode("utf-8")  # raw bytes
        chunk_ids = _encode_chunk(chunk_bytes, merges)
        ids.extend(chunk_ids)
    return ids
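
# `regex_pat` is not defined in this file; any split pattern for the `regex` module works.
# One commonly used choice is the GPT-2 split pattern (as used by tiktoken/minbpe):
#   GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
#   ids = encode("hello world", GPT2_SPLIT_PATTERN, merges)
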
def decode(ids, vocab):
    # given ids (list of integers), return Python string
    part_bytes = []
    for idx in ids:
        if idx in vocab:
            part_bytes.append(vocab[idx])
        else:
            raise ValueError(f"invalid token id: {idx}")
    text_bytes = b"".join(part_bytes)
    text = text_bytes.decode("utf-8", errors="replace")
    return text
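
# Minimal round-trip sketch with toy, hand-built data (this file ships no merges or vocab;
# the data below is hypothetical). `vocab` is assumed to map token id -> bytes, with ids
# 0..255 being the raw bytes and each merged token concatenating the bytes of its parts.
if __name__ == "__main__":
    merges = {(104, 105): 256}  # toy merge table: b"h" + b"i" -> token 256
    # build the id -> bytes vocab; iterate merges in the order they were learned
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (p0, p1), idx in merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    ids = encode("hi hi", pattern, merges)
    print(ids)                   # [256, 32, 256]
    print(decode(ids, vocab))    # "hi hi"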