import torch
from .base import Tokenizer
from .helper import get_stats, merge_batch_get_stats
from heapq import nlargest
import time
MANA_SPECIAL_TOKENS = {
'<|end|>': 265712,
'<|user|>': 265713,
'<|assistant|>': 265714,
'<|system|>': 265715
}
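# These string -> ID mappings are registered via register_special_tokens() in
# ManaTokenizer.__init__; the IDs are assumed to sit above the learned vocabulary
# loaded from mana.model.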
class ManaTokenizer(Tokenizer):
def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
"""
- pattern: optional string to override the default (GPT-4 split pattern)
- special_tokens: str -> int dictionary of special tokens
example: {'<|endoftext|>': 100257}
"""
super().__init__(pattern, multiprocess, store_dict, stop_list_size, freq_cutoff)
self.register_special_tokens(MANA_SPECIAL_TOKENS)
self.load("mana_tokenizer/mana.model")
self.padding_side = "right"
self.pad_token_id = self.special_tokens.get('<|end|>')
@property
def tokens(self):
"""Property to retrieve token IDs for a given text."""
return self._tokens
@property
def attention_masks(self):
"""Property to retrieve attention masks for a given text."""
return self._attention_masks
def encode(self, text, allowed_special="none_raise"):
"""Override encode to include attention masks."""
encoded_ids = super().encode(text, allowed_special=allowed_special)
self._tokens = encoded_ids
self._attention_masks = torch.ones(len(encoded_ids), dtype=torch.int32)
return self
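    # Illustrative usage (a sketch; assumes mana_tokenizer/mana.model is present):
    #   tok = ManaTokenizer()
    #   enc = tok.encode("hello world")   # returns the tokenizer itself
    #   enc.tokens                        # list of token IDs
    #   enc.attention_masks               # torch.ones(len(tokens), dtype=torch.int32)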
def batch_encode(self, texts, padding=True):
"""
Encode a list of texts with dynamic padding and attention masks.
        Pads on self.padding_side ("right" by default, or "left") and builds matching attention masks.
Parameters:
texts (list of str): List of texts to encode.
padding (bool): If True, pad sequences to the max length in the batch.
Returns:
dict: A dictionary containing input_ids and attention_mask tensors.
"""
        # Encode each text once and build a matching all-ones attention mask
        encoded_texts = []
        for text in texts:
            ids = self.encode(text).tokens
            encoded_texts.append({"input_ids": ids, "attention_mask": [1] * len(ids)})
max_len = max(len(t["input_ids"]) for t in encoded_texts) if padding else None
# Apply padding with left alignment
input_ids = []
attention_masks = []
for encoding in encoded_texts:
ids = encoding["input_ids"]
attn_mask = encoding["attention_mask"]
if padding and len(ids) < max_len:
pad_len = max_len - len(ids)
if self.padding_side == "left":
ids = [self.pad_token_id] * pad_len + ids
attn_mask = [0] * pad_len + attn_mask
else:
ids = ids + [self.pad_token_id] * pad_len
attn_mask = attn_mask + [0] * pad_len
input_ids.append(ids)
attention_masks.append(attn_mask)
# Convert to tensors
input_ids = torch.tensor(input_ids, dtype=torch.long)
attention_masks = torch.tensor(attention_masks, dtype=torch.long)
return {"input_ids": input_ids, "attention_mask": attention_masks}
def get_vocab(self):
"""Function to return the vocabulary dictionary."""
return self.vocab
@property
def vocab_size(self):
"""Property to return the vocabulary size."""
return len(self.vocab)
def train(self, data, vocab_size, cap_divisor=2, max_batch_size=0, verbose=False):
t0 = time.time()
ids = self._import_data(data) # [(bytes, int)] -> text chunks and their counts
t1 = time.time()
        print(f'Time spent loading data: {t1-t0:.2f} sec')
merges = self.merges # {(int, int): int} -> token pair to new token
vocab = self.vocab # {int: bytes} -> token to its bytes representation
batch_count = 0
curr_vocab_size = len(vocab)
num_merges = vocab_size - curr_vocab_size
merges_remaining = num_merges
if max_batch_size < 1:
max_batch_size = num_merges
stats = get_stats(ids) # stats are later updated by merge_batch_get_stats
start_time = time.time()
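        # Each pass merges a batch of non-conflicting top pairs at once instead of
        # a single pair per pass. A pair is skipped within a batch if its first
        # token already appeared as the last token of a selected pair (or vice
        # versa), since both merges could otherwise claim the same occurrence.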
while merges_remaining > 0:
seen_first = set() # tokens seen in the first position in pairs
seen_last = set() # tokens seen in the last position in pairs
pairs_to_merge = {}
num_pairs_to_search = min(merges_remaining//cap_divisor, len(vocab), max_batch_size) or 1
top_pairs = nlargest(num_pairs_to_search, stats, key=stats.get)
for first, last in top_pairs: # pairs are (first, last) tuples
if first in seen_last or last in seen_first: # unsafe merge
seen_first.add(first)
seen_last.add(last)
continue # skip this pair but keep looking for safe merges in top_pairs
seen_first.add(first)
seen_last.add(last)
pairs_to_merge[(first, last)] = curr_vocab_size
vocab[curr_vocab_size] = vocab[first] + vocab[last]
curr_vocab_size += 1
merges_remaining -= len(pairs_to_merge)
merges.update(pairs_to_merge) # save the merges
batch_count += 1
if merges_remaining: # no need to merge last batch
stats = merge_batch_get_stats(ids, pairs_to_merge) # replace pairs_to_merge keys in ids with their values
if verbose:
t2 = time.time()
time_taken = t2 - start_time
                avg_time_per_merge = time_taken / (num_merges - merges_remaining)
                estimated_remaining_time = avg_time_per_merge * merges_remaining
estimated_end_time = time.strftime("%H:%M:%S", time.localtime(time.time() + estimated_remaining_time))
print(f"Batch {batch_count} merged {len(pairs_to_merge)} pairs in {t2-t1:.2f} sec. "
f"Merges remaining: {merges_remaining}. Estimated end time: {estimated_end_time}")
t1 = t2
self.merges = merges # used in encode()
        self.vocab = vocab # used in decode()
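
# Minimal usage sketch (assumptions: run as a module, e.g. `python -m ...`, so the
# relative imports above resolve, and mana_tokenizer/mana.model is reachable from
# the current working directory).
if __name__ == "__main__":
    tokenizer = ManaTokenizer()
    single = tokenizer.encode("hello world")
    print(single.tokens, single.attention_masks)
    batch = tokenizer.batch_encode(["short text", "a somewhat longer piece of text"])
    print(batch["input_ids"].shape)    # (batch_size, max_len)
    print(batch["attention_mask"])     # 1 for real tokens, 0 for padding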