import torch
from .base import Tokenizer
from .helper import get_stats, merge_batch_get_stats
from heapq import nlargest
import time

MANA_SPECIAL_TOKENS = {
    '<|end|>': 265712,
    '<|user|>': 265713,
    '<|assistant|>': 265714,
    '<|system|>': 265715
}

class ManaTokenizer(Tokenizer):
    def __init__(self, pattern=None, multiprocess=True, store_dict=False, stop_list_size=0, freq_cutoff=1):
        """
        - pattern: optional string to override the default (GPT-4 split pattern)
        - special_tokens: str -> int dictionary of special tokens
          example: {'<|endoftext|>': 100257}
        """
        super().__init__(pattern, multiprocess, store_dict, stop_list_size, freq_cutoff)
        self.register_special_tokens(MANA_SPECIAL_TOKENS)
        self.load("mana_tokenizer/mana.model")
        self.padding_side = "right"
        self.pad_token_id = self.special_tokens.get('<|end|>')
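
    # Example (hypothetical usage sketch): the defaults load the bundled model and
    # register the Mana special tokens, so a plain constructor call is ready to use:
    #     tokenizer = ManaTokenizer()
    #     tokenizer.pad_token_id   # 265712, the id of '<|end|>'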
        
    @property
    def tokens(self):
        """Property to retrieve token IDs for a given text."""
        return self._tokens
    
    @property
    def attention_masks(self):
        """Property to retrieve attention masks for a given text."""
        return self._attention_masks
    
    def encode(self, text, allowed_special="none_raise"):
        """Override encode to include attention masks."""
        encoded_ids = super().encode(text, allowed_special=allowed_special)
        self._tokens = encoded_ids
        self._attention_masks = torch.ones(len(encoded_ids), dtype=torch.int32)
        return self
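
    # Example (hypothetical usage sketch): ``encode`` returns the tokenizer itself,
    # so the ids and mask are read back through the properties above:
    #     enc = tokenizer.encode("hello world")
    #     enc.tokens            # list[int] of token ids
    #     enc.attention_masks   # torch.int32 tensor of ones, same length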
    
    def batch_encode(self, texts, padding=True):
        """
        Encode a list of texts with dynamic padding and attention masks.
        Handles left padding and attention masking.
        
        Parameters:
            texts (list of str): List of texts to encode.
            padding (bool): If True, pad sequences to the max length in the batch.

        Returns:
            dict: A dictionary containing input_ids and attention_mask tensors.
        """
        # Ensure encode method returns a dict with 'input_ids' and 'attention_mask'
        encoded_texts = [{"input_ids": self.encode(text).tokens, "attention_mask": [1] * len(self.encode(text).tokens)}
                        for text in texts]
        
        max_len = max(len(t["input_ids"]) for t in encoded_texts) if padding else None

        # Pad shorter sequences on the configured side (self.padding_side)
        input_ids = []
        attention_masks = []
        for encoding in encoded_texts:
            ids = encoding["input_ids"]
            attn_mask = encoding["attention_mask"]
            if padding and len(ids) < max_len:
                pad_len = max_len - len(ids)
                if self.padding_side == "left":
                    ids = [self.pad_token_id] * pad_len + ids
                    attn_mask = [0] * pad_len + attn_mask
                else:
                    ids = ids + [self.pad_token_id] * pad_len
                    attn_mask = attn_mask + [0] * pad_len
            input_ids.append(ids)
            attention_masks.append(attn_mask)

        # Convert to tensors
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_masks = torch.tensor(attention_masks, dtype=torch.long)

        return {"input_ids": input_ids, "attention_mask": attention_masks}
    def get_vocab(self):
        """Function to return the vocabulary dictionary."""
        return self.vocab
    
    @property
    def vocab_size(self):
        """Property to return the vocabulary size."""
        return len(self.vocab)
    
    def train(self, data, vocab_size, cap_divisor=2, max_batch_size=0, verbose=False):
        """
        Train BPE merges on ``data`` until the vocabulary reaches ``vocab_size``.
        Each batch merges the most frequent non-conflicting pairs; ``cap_divisor``
        and ``max_batch_size`` bound how many pairs are merged per batch.
        """
        t0 = time.time()
        ids = self._import_data(data)   # [(bytes, int)] -> text chunks and their counts
        t1 = time.time()
        print(f'Time spent loading data: {t1-t0:.2f} sec')

        merges = self.merges   # {(int, int): int} -> token pair to new token
        vocab = self.vocab   # {int: bytes} -> token to its bytes representation
        batch_count = 0
        curr_vocab_size = len(vocab)
        num_merges = vocab_size - curr_vocab_size
        merges_remaining = num_merges
        if max_batch_size < 1:
            max_batch_size = num_merges
        stats = get_stats(ids)   # stats are later updated by merge_batch_get_stats
        start_time = time.time()
        while merges_remaining > 0:
            seen_first = set()   # tokens seen in the first position in pairs
            seen_last = set()   # tokens seen in the last position in pairs
            pairs_to_merge = {}
            num_pairs_to_search = min(merges_remaining//cap_divisor, len(vocab), max_batch_size) or 1
            top_pairs = nlargest(num_pairs_to_search, stats, key=stats.get)
            for first, last in top_pairs:  # pairs are (first, last) tuples
                if first in seen_last or last in seen_first:   # unsafe: overlaps a pair already scheduled in this batch
                    seen_first.add(first)
                    seen_last.add(last)
                    continue # skip this pair but keep looking for safe merges in top_pairs
                seen_first.add(first)
                seen_last.add(last)
                pairs_to_merge[(first, last)] = curr_vocab_size
                vocab[curr_vocab_size] = vocab[first] + vocab[last]
                curr_vocab_size += 1
            merges_remaining -= len(pairs_to_merge)
            merges.update(pairs_to_merge)  # save the merges
            batch_count += 1
            if merges_remaining:   # no need to merge last batch
                stats = merge_batch_get_stats(ids, pairs_to_merge)   # replace pairs_to_merge keys in ids with their values
            if verbose:
                t2 = time.time()
                time_taken = t2 - start_time
                merges_done = num_merges - merges_remaining
                avg_time_per_merge = time_taken / merges_done
                estimated_remaining_time = avg_time_per_merge * merges_remaining
                estimated_end_time = time.strftime("%H:%M:%S", time.localtime(time.time() + estimated_remaining_time))
                print(f"Batch {batch_count} merged {len(pairs_to_merge)} pairs in {t2-t1:.2f} sec. "
                    f"Merges remaining: {merges_remaining}. Estimated end time: {estimated_end_time}") 
                t1 = t2
                
        self.merges = merges # used in encode()
        self.vocab = vocab   # used in decode()
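

if __name__ == "__main__":
    # Minimal usage sketch (hypothetical, not part of the library): assumes the
    # bundled mana_tokenizer/mana.model file is available and that this module
    # is executed as part of the package (e.g. via ``python -m``), since the
    # relative imports above require a package context.
    tok = ManaTokenizer()
    print("vocab size:", tok.vocab_size)

    enc = tok.encode("Hello world")
    print("ids:", enc.tokens)
    print("attention mask:", enc.attention_masks)

    batch = tok.batch_encode(["short text", "a somewhat longer piece of text"])
    print("input_ids shape:", tuple(batch["input_ids"].shape))
    print("attention_mask:", batch["attention_mask"])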