import time
from heapq import nlargest

import torch

from .base import Tokenizer
from .helper import get_stats, merge_batch_get_stats


MANA_SPECIAL_TOKENS = {
    '<|end|>': 265712,
    '<|user|>': 265713,
    '<|assistant|>': 265714,
    '<|system|>': 265715
}


class ManaTokenizer(Tokenizer):
    def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
        """
        - pattern: optional string to override the default (GPT-4 split pattern)
        - multiprocess, store_dict, stop_list_size, freq_cutoff: forwarded unchanged
          to the base Tokenizer

        Special tokens are registered from MANA_SPECIAL_TOKENS,
        e.g. {'<|end|>': 265712}.
        """
        super().__init__(pattern, multiprocess, store_dict, stop_list_size, freq_cutoff)
        self.register_special_tokens(MANA_SPECIAL_TOKENS)
        self.load("mana_tokenizer/mana.model")
        self.padding_side = "right"
        self.pad_token_id = self.special_tokens.get('<|end|>')

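    # Example (illustrative sketch): a freshly constructed tokenizer pads on the
    # right with the <|end|> token:
    #
    #   tok = ManaTokenizer()
    #   tok.pad_token_id   # 265712
    #   tok.padding_side   # "right"
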
    @property
    def tokens(self):
        """Property to retrieve token IDs for the most recently encoded text."""
        return self._tokens

    @property
    def attention_masks(self):
        """Property to retrieve attention masks for the most recently encoded text."""
        return self._attention_masks

    def encode(self, text, allowed_special="none_raise"):
        """Override encode to also build an attention mask."""
        encoded_ids = super().encode(text, allowed_special=allowed_special)
        self._tokens = encoded_ids
        self._attention_masks = torch.ones(len(encoded_ids), dtype=torch.int32)
        return self

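    # Example (illustrative sketch, continuing the construction sketch above):
    # encode() returns the tokenizer itself, so results are read back through
    # the properties above:
    #
    #   enc = tok.encode("hello world")
    #   enc.tokens            # token IDs from the base encoder
    #   enc.attention_masks   # torch.int32 tensor of ones, same length
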
    def batch_encode(self, texts, padding=True):
        """
        Encode a list of texts with dynamic padding and attention masks.
        Pads on self.padding_side ("left" or "right") and masks out pad tokens.

        Parameters:
            texts (list of str): List of texts to encode.
            padding (bool): If True, pad sequences to the max length in the batch.

        Returns:
            dict: A dictionary containing input_ids and attention_mask tensors.
        """
        # Encode each text once and pair it with an all-ones attention mask.
        encoded_texts = []
        for text in texts:
            ids = self.encode(text).tokens
            encoded_texts.append({"input_ids": ids, "attention_mask": [1] * len(ids)})

        max_len = max(len(t["input_ids"]) for t in encoded_texts) if padding else None

        input_ids = []
        attention_masks = []
        for encoding in encoded_texts:
            ids = encoding["input_ids"]
            attn_mask = encoding["attention_mask"]
            if padding and len(ids) < max_len:
                pad_len = max_len - len(ids)
                if self.padding_side == "left":
                    ids = [self.pad_token_id] * pad_len + ids
                    attn_mask = [0] * pad_len + attn_mask
                else:
                    ids = ids + [self.pad_token_id] * pad_len
                    attn_mask = attn_mask + [0] * pad_len
            input_ids.append(ids)
            attention_masks.append(attn_mask)

        # Note: with padding=False, all sequences must already share a length
        # for the tensor conversion below to succeed.
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_masks = torch.tensor(attention_masks, dtype=torch.long)

        return {"input_ids": input_ids, "attention_mask": attention_masks}

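    # Example (illustrative sketch): shorter texts are padded up to the longest
    # sequence in the batch and masked out:
    #
    #   batch = tok.batch_encode(["hi", "a longer sentence"])
    #   batch["input_ids"].shape       # torch.Size([2, max_len])
    #   batch["attention_mask"][0]     # trailing zeros where padding was added
    #                                  # (padding_side is "right" by default)
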
    def get_vocab(self):
        """Function to return the vocabulary dictionary."""
        return self.vocab

    @property
    def vocab_size(self):
        """Property to return the vocabulary size."""
        return len(self.vocab)

    def train(self, data, vocab_size, cap_divisor=2, max_batch_size=0, verbose=False):
        """
        Grow the vocabulary to vocab_size by merging frequent pairs in batches.

        - data: training text, passed to self._import_data
        - vocab_size: target vocabulary size
        - cap_divisor: limits how many pairs are tried per batch
          (merges_remaining // cap_divisor)
        - max_batch_size: hard cap on pairs merged per batch (< 1 means no cap)
        - verbose: print per-batch progress and an estimated end time
        """
        t0 = time.time()
        ids = self._import_data(data)
        t1 = time.time()
        print(f'Time spent loading data: {t1-t0:.2f}')

        merges = self.merges
        vocab = self.vocab
        batch_count = 0
        curr_vocab_size = len(vocab)
        num_merges = vocab_size - curr_vocab_size
        merges_remaining = num_merges
        if max_batch_size < 1:
            max_batch_size = num_merges
        stats = get_stats(ids)
        start_time = time.time()
        while merges_remaining > 0:
            seen_first = set()
            seen_last = set()
            pairs_to_merge = {}
            num_pairs_to_search = min(merges_remaining // cap_divisor, len(vocab), max_batch_size) or 1
            top_pairs = nlargest(num_pairs_to_search, stats, key=stats.get)
            # Select the most frequent pairs for this batch, skipping any pair
            # that shares a token with a pair already selected, so the batch's
            # merges cannot interfere with each other.
            for first, last in top_pairs:
                if first in seen_last or last in seen_first:
                    seen_first.add(first)
                    seen_last.add(last)
                    continue
                seen_first.add(first)
                seen_last.add(last)
                pairs_to_merge[(first, last)] = curr_vocab_size
                vocab[curr_vocab_size] = vocab[first] + vocab[last]
                curr_vocab_size += 1
            merges_remaining -= len(pairs_to_merge)
            merges.update(pairs_to_merge)
            batch_count += 1
            if merges_remaining:
                stats = merge_batch_get_stats(ids, pairs_to_merge)
            if verbose:
                t2 = time.time()
                time_taken = t2 - start_time
                # Estimate the remaining time from the average time per merge so far.
                merges_done = num_merges - merges_remaining
                avg_time_per_merge = time_taken / max(merges_done, 1)
                estimated_remaining_time = avg_time_per_merge * merges_remaining
                estimated_end_time = time.strftime("%H:%M:%S", time.localtime(time.time() + estimated_remaining_time))
                print(f"Batch {batch_count} merged {len(pairs_to_merge)} pairs in {t2-t1:.2f} sec. "
                      f"Merges remaining: {merges_remaining}. Estimated end time: {estimated_end_time}")
                t1 = t2

        self.merges = merges
        self.vocab = vocab
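

# Usage sketch (assumptions: the package is importable as `mana_tokenizer` and
# exports ManaTokenizer, mana_tokenizer/mana.model is present, and the corpus
# passed to train() is in a format self._import_data accepts):
#
#   from mana_tokenizer import ManaTokenizer
#
#   tok = ManaTokenizer()
#   single = tok.encode("hello world")
#   print(single.tokens, single.attention_masks)
#
#   batch = tok.batch_encode(["hello", "hello world"], padding=True)
#   print(batch["input_ids"].shape, batch["attention_mask"].shape)
#
#   # Growing the vocabulary; the corpus and target size are illustrative.
#   # tok.train(corpus, vocab_size=300000, verbose=True)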