# mana_tokenizer/mana_tokenizer.py
import torch
from .base import Tokenizer
from .helper import get_stats, merge_batch_get_stats
from heapq import nlargest
import time
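
# Special tokens registered on top of the trained vocabulary in __init__.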
MANA_SPECIAL_TOKENS = {
'<|end|>': 265712,
'<|user|>': 265713,
'<|assistant|>': 265714,
'<|system|>': 265715
}
class ManaTokenizer(Tokenizer):
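    """Wrapper around the base Tokenizer that loads the pretrained Mana BPE model,
    registers chat special tokens, and adds batch encoding with padding and
    attention masks."""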
def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
- special_tokens: str -> int dictionary of special tokens
example: {'<|endoftext|>': 100257}
"""
super().__init__(pattern, multiprocess, store_dict, stop_list_size, freq_cutoff)
self.register_special_tokens(MANA_SPECIAL_TOKENS)
        self.load("mana_tokenizer/mana.model")  # pretrained merges/vocab; path is relative to the working directory
        self.padding_side = "right"  # side on which batch_encode pads ("right" or "left")
        self.pad_token_id = self.special_tokens.get('<|end|>')  # '<|end|>' doubles as the padding token
        # populated by encode(); exposed through the tokens / attention_masks properties
        self._tokens = None
        self._attention_masks = None
@property
def tokens(self):
"""Property to retrieve token IDs for a given text."""
return self._tokens
@property
def attention_masks(self):
"""Property to retrieve attention masks for a given text."""
return self._attention_masks
def encode(self, text, allowed_special="none_raise"):
"""Override encode to include attention masks."""
encoded_ids = super().encode(text, allowed_special=allowed_special)
self._tokens = encoded_ids
self._attention_masks = torch.ones(len(encoded_ids), dtype=torch.int32)
return self
def batch_encode(self, texts, padding=True):
"""
Encode a list of texts with dynamic padding and attention masks.
Handles left padding and attention masking.
Parameters:
texts (list of str): List of texts to encode.
padding (bool): If True, pad sequences to the max length in the batch.
Returns:
dict: A dictionary containing input_ids and attention_mask tensors.
"""
        # encode() stores its result on the instance, so encode each text once and reuse it
        encoded_texts = []
        for text in texts:
            ids = self.encode(text).tokens
            encoded_texts.append({"input_ids": ids, "attention_mask": [1] * len(ids)})
max_len = max(len(t["input_ids"]) for t in encoded_texts) if padding else None
        # Pad each sequence to max_len on the configured padding side
input_ids = []
attention_masks = []
for encoding in encoded_texts:
ids = encoding["input_ids"]
attn_mask = encoding["attention_mask"]
if padding and len(ids) < max_len:
pad_len = max_len - len(ids)
if self.padding_side == "left":
ids = [self.pad_token_id] * pad_len + ids
attn_mask = [0] * pad_len + attn_mask
else:
ids = ids + [self.pad_token_id] * pad_len
attn_mask = attn_mask + [0] * pad_len
input_ids.append(ids)
attention_masks.append(attn_mask)
# Convert to tensors
input_ids = torch.tensor(input_ids, dtype=torch.long)
attention_masks = torch.tensor(attention_masks, dtype=torch.long)
return {"input_ids": input_ids, "attention_mask": attention_masks}
def get_vocab(self):
"""Function to return the vocabulary dictionary."""
return self.vocab
@property
def vocab_size(self):
"""Property to return the vocabulary size."""
return len(self.vocab)
def train(self, data, vocab_size, cap_divisor=2, max_batch_size=0, verbose=False):
t0 = time.time()
ids = self._import_data(data) # [(bytes, int)] -> text chunks and their counts
t1 = time.time()
        print(f'Time spent loading data: {t1-t0:.2f} sec')
merges = self.merges # {(int, int): int} -> token pair to new token
vocab = self.vocab # {int: bytes} -> token to its bytes representation
batch_count = 0
curr_vocab_size = len(vocab)
num_merges = vocab_size - curr_vocab_size
merges_remaining = num_merges
if max_batch_size < 1:
max_batch_size = num_merges
stats = get_stats(ids) # stats are later updated by merge_batch_get_stats
start_time = time.time()
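        # Each pass merges a whole batch of the most frequent, mutually non-conflicting
        # pairs instead of a single pair, so pair statistics only need to be recomputed
        # once per batch.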
while merges_remaining > 0:
seen_first = set() # tokens seen in the first position in pairs
seen_last = set() # tokens seen in the last position in pairs
pairs_to_merge = {}
num_pairs_to_search = min(merges_remaining//cap_divisor, len(vocab), max_batch_size) or 1
top_pairs = nlargest(num_pairs_to_search, stats, key=stats.get)
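            # A pair is skipped in this batch if its first token already appeared as the
            # last token of a higher-ranked pair (or vice versa): overlapping pairs cannot
            # be merged in the same pass without invalidating each other's counts, so they
            # are deferred to a later batch.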
for first, last in top_pairs: # pairs are (first, last) tuples
if first in seen_last or last in seen_first: # unsafe merge
seen_first.add(first)
seen_last.add(last)
continue # skip this pair but keep looking for safe merges in top_pairs
seen_first.add(first)
seen_last.add(last)
pairs_to_merge[(first, last)] = curr_vocab_size
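                # the new token's byte string is the concatenation of its two parts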
vocab[curr_vocab_size] = vocab[first] + vocab[last]
curr_vocab_size += 1
merges_remaining -= len(pairs_to_merge)
merges.update(pairs_to_merge) # save the merges
batch_count += 1
if merges_remaining: # no need to merge last batch
stats = merge_batch_get_stats(ids, pairs_to_merge) # replace pairs_to_merge keys in ids with their values
if verbose:
t2 = time.time()
time_taken = t2 - start_time
                # estimate remaining time from the average time per completed merge
                avg_time_per_merge = time_taken / max(1, num_merges - merges_remaining)
                estimated_remaining_time = avg_time_per_merge * merges_remaining
estimated_end_time = time.strftime("%H:%M:%S", time.localtime(time.time() + estimated_remaining_time))
print(f"Batch {batch_count} merged {len(pairs_to_merge)} pairs in {t2-t1:.2f} sec. "
f"Merges remaining: {merges_remaining}. Estimated end time: {estimated_end_time}")
t1 = t2
self.merges = merges # used in encode()
self.vocab = vocab # used in decode()
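

# Minimal usage sketch. It assumes the pretrained "mana_tokenizer/mana.model" file loaded
# in __init__ is available relative to the working directory; the sample strings are
# illustrative only. Because this module uses relative imports, run it as a package
# module, e.g. `python -m mana_tokenizer.mana_tokenizer`.
if __name__ == "__main__":
    tokenizer = ManaTokenizer()

    # encode() returns the tokenizer itself; token ids and masks are read via properties
    encoded = tokenizer.encode("hello world")
    print(encoded.tokens, encoded.attention_masks)

    # batch_encode() pads to the longest sequence and returns stacked tensors
    batch = tokenizer.batch_encode(["hello", "hello world"])
    print(batch["input_ids"].shape, batch["attention_mask"].shape)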