mana_tokenizer / base.py
tspersian's picture
package
d786ff1
from collections import Counter
from functools import lru_cache
import requests
from datasets import IterableDataset, Dataset
from pyarrow import ChunkedArray
from joblib import Parallel, delayed, cpu_count
import time
import os
import regex as re
import csv
import time
from mana_tokenizer.helper import _process_string_scalar, render_token, merge
class Tokenizer:
"""Base class for Tokenizers"""
def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
# default: vocab size of 256 (all bytes), no merges, no patterns
MANA_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re|می|نمی|به|بی|در|باز|بر|فرا|هم|ور|وا|ف|ک|چ|ن|پ|ا|از|ای|ی|ها|ترین|تر|ات|ان|ت|ٔ|یی|‌ا)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
self.merges = {} # (int, int) -> int
self.pattern = "" # str
self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
self.vocab = self._build_vocab() # int -> bytes
self.pattern = MANA_SPLIT_PATTERN if pattern is None else pattern
self.compiled_pattern = re.compile(self.pattern)
self.multiprocess = multiprocess
if multiprocess:
self._cpus = cpu_count()
else:
self._cpus = 1
self.store_dict = store_dict
self.stop_list_size = stop_list_size
self.stop_words = {}
self.freq_cutoff = freq_cutoff
def _id_dict_to_list(self, ids):
if self.stop_list_size:
# get twice as many to be sure to be able to get X chunks of length > 1
top2X = ids.most_common(2*self.stop_list_size)
index = len(self.vocab)
stop_index = index + self.stop_list_size
stop_words = {}
for key, val in top2X:
if len(key) > 1: # and re.match(r'^ [A-Za-z\'’`]+$[A-Za-z]*', key):
stop_words[key] = index
self.vocab[index] = key.encode('utf-8')
index += 1
if index == stop_index:
break
self.stop_words = stop_words
if self.freq_cutoff > 1:
return [([*key.encode('utf-8')], val) for key, val in ids.items()
if (val >= self.freq_cutoff and key not in self.stop_words)]
else:
return [([*key.encode('utf-8')], val) for key, val in ids.items()
if key not in self.stop_words]
else: # self.stop_list_size == 0
if self.freq_cutoff > 1:
return [([*key.encode('utf-8')], val) for key, val in ids.items()
if val >= self.freq_cutoff]
else:
return [([*key.encode('utf-8')], val) for key, val in ids.items()]
def _import_data(self, data):
# determine if `data` is a text as a string, a path to a file, a url to
# a text document, a dictionary of datasets kwargs, or a list of any of
# the above. Return a list of 2-tuples of bytes objects and their counts.
ids = Counter()
if not isinstance(data, (list, tuple)):
data = (data,)
for item in data:
# convert to ChunkedArray, dict, or str of text to parse
if isinstance(item, Dataset):
item = item.data['text']
elif isinstance(item, str) and item.endswith('.csv'): # csv file from previous data load
with open(item, 'r') as f:
reader = csv.reader(f)
next(reader)
item = {k: int(v) for k, v in reader}
elif isinstance(item, str):
if item.startswith('https://') or item.startswith('http://'):
item = requests.get(item).text # if it's a url, assume it's to a text file
elif os.path.isfile(item) and item.endswith('.txt'):
with open(item, 'r', encoding='utf-8') as f:
item = f.read()
# process data
if isinstance(item, dict):
last_item = item.popitem()
if last_item[1] != 0:
print(f'Warning: the csv file or dictionary passed does not seem to have been made by this tokenizer.')
item[last_item[0]] = last_item[1]
elif last_item[0] != self.pattern:
print(f'Warning: the dictionary or csv file passed did not use the same split pattern.')
ids.update(item)
elif isinstance(item, str): # assume the string is the text itself
ids.update(re.findall(self.compiled_pattern, item))
elif isinstance(item, ChunkedArray):
batch_size = len(item) // (self._cpus*2) or 1
batches = [item[i:i + batch_size] for i in range(0, len(item), batch_size)]
print(f'Processing {len(batches)} batches of size {batch_size}')
results = Parallel(n_jobs=self._cpus)(delayed(_process_string_scalar)(batch, self.compiled_pattern) for batch in batches)
for result in results: # Aggregate results into one Counter
ids.update(result)
elif isinstance(item, IterableDataset):
print('Serially processing IterableDataset...')
for _dict in item:
ids.update(re.findall(self.compiled_pattern, _dict['text']))
if self.store_dict: # store dict compression of dataset to a csv file if requested
ids[self.pattern] = 0 # store the pattern used to split the text as the last key
formatted_time = time.strftime('%Y-%m-%d-%H_%M', time.localtime())
filename = f'{formatted_time}-dataset-dict.csv'
try:
with open(filename, 'w', newline='') as f:
writer = csv.writer(f)
writer.writerow(['text_chunk', 'count'])
for key, value in ids.items():
writer.writerow([key, value])
print(f"Stored dictionary of {len(ids)} keys to {filename}")
except:
print('Failed to store dictionary of dataset.')
del ids[self.pattern] # remove the pattern key from the ids dict
ids = self._id_dict_to_list(ids)
return ids
def train(self, text, vocab_size, verbose=False):
# Tokenizer can train a vocabulary of size vocab_size from text
raise NotImplementedError
def _build_vocab(self):
# vocab is simply and deterministically derived from merges
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in self.merges.items():
vocab[idx] = vocab[p0] + vocab[p1]
for special, idx in self.special_tokens.items():
vocab[idx] = special.encode("utf-8")
return vocab
def register_special_tokens(self, special_tokens):
# special_tokens is a dictionary of str -> int
# example: {"<|endoftext|>": 100257}
self.special_tokens = special_tokens
self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}
def save(self, file_prefix):
"""
Saves two files: file_prefix.vocab and file_prefix.model
This is inspired (but not equivalent to!) sentencepiece's model saving:
- model file is the critical one, intended for load() later
- vocab file is just a pretty printed version for human inspection only
"""
# write the model: to be used in load() later
model_file = file_prefix + ".model"
with open(model_file, 'w', encoding='utf-8') as f: # Added encoding='utf-8'
# write the version, pattern and merges, that's all that's needed
f.write("mana v1\n")
f.write(f"{self.pattern}\n")
# write the special tokens, first the number of them, then each one
f.write(f"{len(self.special_tokens)}\n")
for special, idx in self.special_tokens.items():
f.write(f"{special} {idx}\n")
# the merges dict
for key in self.merges:
if isinstance(key, tuple):
f.write(f"{key[0]} {key[1]}\n")
else:
f.write(f"{key}\n")
# write the vocab: for the human to look at
vocab_file = file_prefix + ".vocab"
inverted_merges = {idx: pair for pair, idx in self.merges.items()}
with open(vocab_file, "w", encoding="utf-8") as f: # Ensure this is also utf-8
for idx, token in self.vocab.items():
s = render_token(token)
# find the children of this token, if any
if idx in inverted_merges:
idx0, idx1 = inverted_merges[idx]
s0 = render_token(self.vocab[idx0])
s1 = render_token(self.vocab[idx1])
f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
else:
f.write(f"[{s}] {idx}\n")
def load(self, model_file):
"""Inverse of save() but only for the model file"""
assert model_file.endswith(".model")
# read the model file
merges = {}
special_tokens = {}
idx = 256
with open(model_file, 'r', encoding="utf-8") as f:
# read the version
version = f.readline().strip()
assert version == "mana v1"
# read the pattern
self.pattern = f.readline().strip()
# read the special tokens
num_special = int(f.readline().strip())
for _ in range(num_special):
special, special_idx = f.readline().strip().split()
special_tokens[special] = int(special_idx)
# read the merges
for line in f:
idx1, idx2 = map(int, line.split())
merges[(idx1, idx2)] = idx
idx += 1
self.merges = merges
self.special_tokens = special_tokens
self.vocab = self._build_vocab()
def decode(self, ids):
# given ids (list of integers), return Python string
part_bytes = [self.vocab[idx] if idx in self.vocab
else self.inverse_special_tokens[idx].encode("utf-8")
for idx in ids] # raises KeyError if any idx is not a valid token
text_bytes = b"".join(part_bytes)
text = text_bytes.decode("utf-8", errors="replace")
return text
@lru_cache(maxsize=131072)
def _encode_chunk(self, chunk):
if chunk in self.stop_words: # TODO: revisit this if statement
return [self.stop_words[chunk]]
# return the token chunk as a list of ints, similar to a bytes object
chunk = [*chunk.encode("utf-8")]
len_chunk = len(chunk)
while len_chunk >= 2:
# find the pair with the lowest merge index
low = 987654321
for i in range(len_chunk - 1):
current_pair = (chunk[i], chunk[i+1])
new_val = self.merges.get(current_pair, 987654321)
if new_val < low:
pair = current_pair
low = new_val
if low == 987654321: # no merges were found
break # nothing else can be merged
# otherwise let's merge the best pair (lowest merge index)
idx = self.merges[pair]
len_chunk = merge(chunk, pair, idx, len_chunk)
return chunk # list of ints
def encode_ordinary(self, text):
"""Encoding that ignores any special tokens."""
ids = []
for chunk in re.findall(self.compiled_pattern, text):
ids.extend(self._encode_chunk(chunk))
return ids
def encode(self, text, allowed_special="none_raise"):
"""
Unlike encode_ordinary, this function handles special tokens.
allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
if none_raise, then an error is raised if any special token is encountered in text
this is the default tiktoken behavior right now as well
any other behavior is either annoying, or a major footgun
"""
# decode the user desire w.r.t. handling of special tokens
special = None
if allowed_special == "all":
special = self.special_tokens
elif allowed_special == "none":
special = {}
elif allowed_special == "none_raise":
special = {}
assert all(token not in text for token in self.special_tokens)
elif isinstance(allowed_special, set):
special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
else:
raise ValueError(f"allowed_special={allowed_special} not understood")
if not special: # shortcut: if no special tokens, just use the ordinary encoding
return self.encode_ordinary(text)
# split on special tokens. Note that surrounding the pattern with ()
# makes it into a capturing group, so the special tokens will be included
special_pattern = f"({'|'.join([re.escape(k) for k in special])})"
special_chunks = re.split(special_pattern, text)
# now all the special characters are separated from the rest of the text
# all chunks of text are encoded separately, then results are joined
ids = []
for part in special_chunks:
special_token = special.get(part)
if special_token is None: # this is an ordinary sequence, encode it normally
ids.extend(self.encode_ordinary(part))
else: # this is a special token, encode it separately as a special case
ids.append(special_token)
return ids
def batch_encode(self, texts, allowed_special="none_raise"):
"""
Encode a list of texts in batch mode.
Each text will be encoded according to the handling of special tokens specified in allowed_special.
Parameters:
texts (list of str): List of texts to encode.
allowed_special (str|set): Special token handling mode.
Returns:
list of list of int: A list where each element is the encoded form of a text in `texts`.
"""
return [self.encode(text, allowed_special=allowed_special) for text in texts]