|
import csv
import os
import time
from collections import Counter
from functools import lru_cache

import regex as re
import requests
from datasets import IterableDataset, Dataset
from joblib import Parallel, delayed, cpu_count
from pyarrow import ChunkedArray

from mana_tokenizer.helper import _process_string_scalar, render_token, merge
|
|
|
class Tokenizer: |
|
"""Base class for Tokenizers""" |
|
def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1): |
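        """
        pattern: optional regex used to pre-split text into chunks; if None, a
            GPT-4-style split pattern extended with common Persian affixes
            (MANA_SPLIT_PATTERN below) is used.
        multiprocess: if True, use all available CPU cores (via joblib) when
            counting chunks in pyarrow ChunkedArray inputs.
        store_dict: if True, dump the chunk-frequency dictionary built by
            _import_data to a timestamped CSV file for later reuse.
        stop_list_size: number of frequent multi-character chunks to promote
            directly into the vocab as single "stop word" tokens (0 disables).
        freq_cutoff: chunks seen fewer than this many times are dropped before
            training (1 keeps everything).
        """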
|
|
|
MANA_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re|می|نمی|به|بی|در|باز|بر|فرا|هم|ور|وا|ف|ک|چ|ن|پ|ا|از|ای|ی|ها|ترین|تر|ات|ان|ت|ٔ|یی|ا)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" |
|
        self.merges = {}          # (int, int) -> int; populated by load() or a subclass's train()
        self.special_tokens = {}  # str -> int
        self.vocab = self._build_vocab()  # int -> bytes
        self.pattern = MANA_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
|
        self.multiprocess = multiprocess
        self._cpus = cpu_count() if multiprocess else 1
|
self.store_dict = store_dict |
|
self.stop_list_size = stop_list_size |
|
self.stop_words = {} |
|
self.freq_cutoff = freq_cutoff |
|
|
|
def _id_dict_to_list(self, ids): |
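        """
        Convert a Counter of text chunks into a list of ([utf-8 byte values], count)
        pairs for training. If stop_list_size is set, the most frequent
        multi-character chunks are added to the vocab as dedicated "stop word"
        tokens and excluded from the list; chunks rarer than freq_cutoff are
        dropped as well.
        """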
|
if self.stop_list_size: |
|
|
|
            # take twice as many candidates as needed, since single-character
            # chunks are skipped below
            top2X = ids.most_common(2*self.stop_list_size)
|
index = len(self.vocab) |
|
stop_index = index + self.stop_list_size |
|
stop_words = {} |
|
for key, val in top2X: |
|
if len(key) > 1: |
|
stop_words[key] = index |
|
self.vocab[index] = key.encode('utf-8') |
|
index += 1 |
|
if index == stop_index: |
|
break |
|
self.stop_words = stop_words |
|
if self.freq_cutoff > 1: |
|
return [([*key.encode('utf-8')], val) for key, val in ids.items() |
|
if (val >= self.freq_cutoff and key not in self.stop_words)] |
|
else: |
|
return [([*key.encode('utf-8')], val) for key, val in ids.items() |
|
if key not in self.stop_words] |
|
else: |
|
if self.freq_cutoff > 1: |
|
return [([*key.encode('utf-8')], val) for key, val in ids.items() |
|
if val >= self.freq_cutoff] |
|
else: |
|
return [([*key.encode('utf-8')], val) for key, val in ids.items()] |
|
|
|
def _import_data(self, data): |
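        """
        Build a Counter of pre-split text chunks from one or more data sources and
        return it as a list of ([utf-8 byte values], count) pairs.

        Accepted sources (or a list/tuple mixing them): a plain string, an
        http(s) URL, a path to a .txt file, a path to a .csv frequency dump
        written with store_dict=True, a chunk-frequency dict, a datasets Dataset
        or IterableDataset with a 'text' column, or a pyarrow ChunkedArray.
        """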
|
|
|
|
|
|
|
ids = Counter() |
|
if not isinstance(data, (list, tuple)): |
|
data = (data,) |
|
for item in data: |
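            # first normalize the source into a plain string, a chunk-frequency
            # dict, or a pyarrow array; the chunk counter is updated further below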
|
|
|
if isinstance(item, Dataset): |
|
item = item.data['text'] |
|
            elif isinstance(item, str) and item.endswith('.csv'):
                # a chunk-frequency dump previously written with store_dict=True
                with open(item, 'r', encoding='utf-8', newline='') as f:
                    reader = csv.reader(f)
                    next(reader)  # skip the header row
                    item = {k: int(v) for k, v in reader}
|
elif isinstance(item, str): |
|
if item.startswith('https://') or item.startswith('http://'): |
|
item = requests.get(item).text |
|
elif os.path.isfile(item) and item.endswith('.txt'): |
|
with open(item, 'r', encoding='utf-8') as f: |
|
item = f.read() |
|
|
|
            if isinstance(item, dict):
                # dumps written by store_dict end with a marker entry: the split
                # pattern as the key and 0 as the count
                last_item = item.popitem()
                if last_item[1] != 0:
                    print('Warning: the csv file or dictionary passed does not seem to have been made by this tokenizer.')
                    item[last_item[0]] = last_item[1]  # not a marker entry, put it back
                elif last_item[0] != self.pattern:
                    print('Warning: the dictionary or csv file passed did not use the same split pattern.')
                ids.update(item)
|
elif isinstance(item, str): |
|
ids.update(re.findall(self.compiled_pattern, item)) |
|
elif isinstance(item, ChunkedArray): |
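                # split the Arrow column into roughly two batches per CPU and
                # count chunks in parallel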
|
batch_size = len(item) // (self._cpus*2) or 1 |
|
batches = [item[i:i + batch_size] for i in range(0, len(item), batch_size)] |
|
print(f'Processing {len(batches)} batches of size {batch_size}') |
|
results = Parallel(n_jobs=self._cpus)(delayed(_process_string_scalar)(batch, self.compiled_pattern) for batch in batches) |
|
for result in results: |
|
ids.update(result) |
|
elif isinstance(item, IterableDataset): |
|
print('Serially processing IterableDataset...') |
|
for _dict in item: |
|
ids.update(re.findall(self.compiled_pattern, _dict['text'])) |
|
|
|
        if self.store_dict:
            # add a marker entry (split pattern -> 0) so a reloaded dump can be validated
            ids[self.pattern] = 0
            formatted_time = time.strftime('%Y-%m-%d-%H_%M', time.localtime())
            filename = f'{formatted_time}-dataset-dict.csv'
            try:
                with open(filename, 'w', encoding='utf-8', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerow(['text_chunk', 'count'])
                    for key, value in ids.items():
                        writer.writerow([key, value])
                    print(f"Stored dictionary of {len(ids)} keys to {filename}")
            except Exception as e:
                print(f'Failed to store dictionary of dataset: {e}')
            del ids[self.pattern]
|
|
|
ids = self._id_dict_to_list(ids) |
|
return ids |
|
|
|
def train(self, text, vocab_size, verbose=False): |
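        """Train a vocabulary of size vocab_size from text; implemented by subclasses."""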
|
|
|
raise NotImplementedError |
|
|
|
def _build_vocab(self): |
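        """The vocab is deterministically derived from the merges and special tokens."""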
|
|
|
vocab = {idx: bytes([idx]) for idx in range(256)} |
|
for (p0, p1), idx in self.merges.items(): |
|
vocab[idx] = vocab[p0] + vocab[p1] |
|
for special, idx in self.special_tokens.items(): |
|
vocab[idx] = special.encode("utf-8") |
|
return vocab |
|
|
|
def register_special_tokens(self, special_tokens): |
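        """special_tokens is a dict of str -> int, mapping each special token to its id."""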
|
|
|
|
|
self.special_tokens = special_tokens |
|
self.inverse_special_tokens = {v: k for k, v in special_tokens.items()} |
|
|
|
def save(self, file_prefix): |
|
""" |
|
Saves two files: file_prefix.vocab and file_prefix.model |
|
        This is inspired by (but not equivalent to!) sentencepiece's model saving:
|
- model file is the critical one, intended for load() later |
|
- vocab file is just a pretty printed version for human inspection only |
|
""" |
|
|
|
model_file = file_prefix + ".model" |
|
with open(model_file, 'w', encoding='utf-8') as f: |
|
|
|
f.write("mana v1\n") |
|
f.write(f"{self.pattern}\n") |
|
|
|
f.write(f"{len(self.special_tokens)}\n") |
|
for special, idx in self.special_tokens.items(): |
|
f.write(f"{special} {idx}\n") |
|
|
|
for key in self.merges: |
|
if isinstance(key, tuple): |
|
f.write(f"{key[0]} {key[1]}\n") |
|
else: |
|
f.write(f"{key}\n") |
|
|
|
|
|
vocab_file = file_prefix + ".vocab" |
|
inverted_merges = {idx: pair for pair, idx in self.merges.items()} |
|
with open(vocab_file, "w", encoding="utf-8") as f: |
|
for idx, token in self.vocab.items(): |
|
s = render_token(token) |
|
|
|
if idx in inverted_merges: |
|
idx0, idx1 = inverted_merges[idx] |
|
s0 = render_token(self.vocab[idx0]) |
|
s1 = render_token(self.vocab[idx1]) |
|
f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n") |
|
else: |
|
f.write(f"[{s}] {idx}\n") |
|
|
|
def load(self, model_file): |
|
"""Inverse of save() but only for the model file""" |
|
assert model_file.endswith(".model") |
|
|
|
merges = {} |
|
special_tokens = {} |
|
idx = 256 |
|
with open(model_file, 'r', encoding="utf-8") as f: |
|
|
|
version = f.readline().strip() |
|
assert version == "mana v1" |
|
|
|
self.pattern = f.readline().strip() |
|
|
|
num_special = int(f.readline().strip()) |
|
for _ in range(num_special): |
|
special, special_idx = f.readline().strip().split() |
|
special_tokens[special] = int(special_idx) |
|
|
|
for line in f: |
|
idx1, idx2 = map(int, line.split()) |
|
merges[(idx1, idx2)] = idx |
|
idx += 1 |
|
self.merges = merges |
|
self.special_tokens = special_tokens |
|
self.vocab = self._build_vocab() |
|
|
|
def decode(self, ids): |
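        """Given a list of integer token ids, return the decoded Python string."""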
|
|
|
part_bytes = [self.vocab[idx] if idx in self.vocab |
|
else self.inverse_special_tokens[idx].encode("utf-8") |
|
for idx in ids] |
|
text_bytes = b"".join(part_bytes) |
|
text = text_bytes.decode("utf-8", errors="replace") |
|
return text |
|
|
|
@lru_cache(maxsize=131072) |
|
def _encode_chunk(self, chunk): |
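        """Encode a single pre-split text chunk (a string) into a list of token ids.
        Stop-word chunks map directly to their single token id; results are memoized
        by the lru_cache decorator above."""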
|
if chunk in self.stop_words: |
|
return [self.stop_words[chunk]] |
|
|
|
        chunk = [*chunk.encode("utf-8")]
        len_chunk = len(chunk)
        while len_chunk >= 2:
            # find the adjacent pair with the lowest merge rank (earliest learned
            # merge); 987654321 acts as an "infinity" sentinel for unmergeable pairs
            low = 987654321
            for i in range(len_chunk - 1):
                current_pair = (chunk[i], chunk[i+1])
                new_val = self.merges.get(current_pair, 987654321)
                if new_val < low:
                    pair = current_pair
                    low = new_val
            if low == 987654321:
                break  # nothing left to merge
            # merge the best pair in place; merge() returns the new effective length
            idx = self.merges[pair]
            len_chunk = merge(chunk, pair, idx, len_chunk)
        return chunk
|
|
|
def encode_ordinary(self, text): |
|
"""Encoding that ignores any special tokens.""" |
|
ids = [] |
|
for chunk in re.findall(self.compiled_pattern, text): |
|
ids.extend(self._encode_chunk(chunk)) |
|
return ids |
|
|
|
def encode(self, text, allowed_special="none_raise"): |
|
""" |
|
Unlike encode_ordinary, this function handles special tokens. |
|
allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens |
|
if none_raise, then an error is raised if any special token is encountered in text |
|
this is the default tiktoken behavior right now as well |
|
any other behavior is either annoying, or a major footgun |
|
""" |
|
|
|
special = None |
|
if allowed_special == "all": |
|
special = self.special_tokens |
|
elif allowed_special == "none": |
|
special = {} |
|
elif allowed_special == "none_raise": |
|
special = {} |
|
assert all(token not in text for token in self.special_tokens) |
|
elif isinstance(allowed_special, set): |
|
special = {k: v for k, v in self.special_tokens.items() if k in allowed_special} |
|
else: |
|
raise ValueError(f"allowed_special={allowed_special} not understood") |
|
if not special: |
|
return self.encode_ordinary(text) |
|
|
|
|
|
        # split the text on exact matches of the allowed special tokens; the
        # capturing group makes re.split keep the matched tokens in the output
        special_pattern = f"({'|'.join([re.escape(k) for k in special])})"
        special_chunks = re.split(special_pattern, text)
|
|
|
|
|
ids = [] |
|
for part in special_chunks: |
|
special_token = special.get(part) |
|
if special_token is None: |
|
ids.extend(self.encode_ordinary(part)) |
|
else: |
|
ids.append(special_token) |
|
return ids |
|
|
|
def batch_encode(self, texts, allowed_special="none_raise"): |
|
""" |
|
Encode a list of texts in batch mode. |
|
Each text will be encoded according to the handling of special tokens specified in allowed_special. |
|
|
|
Parameters: |
|
texts (list of str): List of texts to encode. |
|
allowed_special (str|set): Special token handling mode. |
|
|
|
Returns: |
|
list of list of int: A list where each element is the encoded form of a text in `texts`. |
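
        Example (assuming `tok` is a trained tokenizer instance):
            ids = tok.batch_encode(["سلام دنیا", "hello world"])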
|
""" |
|
return [self.encode(text, allowed_special=allowed_special) for text in texts] |
|
|