""" Minimal (byte-level) Byte Pair Encoding tokenizer. Algorithmically follows along the GPT tokenizer: https://github.com/openai/gpt-2/blob/master/src/encoder.py But: - Does not handle the regular expression splitting pattern. - Does not handle any special tokens. """ import copy from .base import Tokenizer, get_stats, merge # class BasicTokenizer(Tokenizer): # # def __init__(self): # super().__init__() # # def train(self, text, vocab_size, verbose=False): # assert vocab_size >= 256 # num_merges = vocab_size - 256 # # # input text preprocessing # text_bytes = text.encode("utf-8") # raw bytes # ids = list(text_bytes) # list of integers in range 0..255 # # # iteratively merge the most common pairs to create new tokens # merges = {} # (int, int) -> int # vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes # for i in range(num_merges): # # count up the number of times every consecutive pair appears # stats = get_stats(ids) # # find the pair with the highest count # pair = max(stats, key=stats.get) # # mint a new token: assign it the next available id # idx = 256 + i # # replace all occurrences of pair in ids with idx # ids = merge(ids, pair, idx) # # save the merge # merges[pair] = idx # vocab[idx] = vocab[pair[0]] + vocab[pair[1]] # # prints # if verbose: # print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") # # # save class variables # self.merges = merges # used in encode() # self.vocab = vocab # used in decode() # # def decode(self, ids): # # given ids (list of integers), return Python string # text_bytes = b"".join(self.vocab[idx] for idx in ids) # text = text_bytes.decode("utf-8", errors="replace") # return text # # def encode(self, text): # # given a string text, return the token ids # text_bytes = text.encode("utf-8") # raw bytes # ids = list(text_bytes) # list of integers in range 0..255 # while len(ids) >= 2: # # find the pair with the lowest merge index # stats = get_stats(ids) # pair = min(stats, key=lambda p: self.merges.get(p, float("inf"))) # # subtle: if there are no more merges available, the key will # # result in an inf for every single pair, and the min will be # # just the first pair in the list, arbitrarily # # we can detect this terminating case by a membership check # if pair not in self.merges: # break # nothing else can be merged anymore # # otherwise let's merge the best pair (lowest merge index) # idx = self.merges[pair] # ids = merge(ids, pair, idx) # return ids class BasicTokenizer(Tokenizer): def __init__(self): super().__init__() self.merge_counter = 0 def train(self, text, vocab_size, verbose=False): # left assert in place just to introduce consistency and a hard check of the increase in vocab size and number of merges assert vocab_size >= 256 num_merges = vocab_size - 256 current_batch_merge_counter = 0 # in case not all exact `num_merges` happen # input text preprocessing text_bytes = text.encode("utf-8") # encode to get all waw bytes ids = list(text_bytes) # represent the bytes in ints # use same merge dict if exists self.merges = {} if self.merges is None else self.merges # to hold all merges (int, int) -> int # Use same vocab for this Tokenizer object if it exists # Tokenizer vocab: int -> bytes self.vocab = {idx: bytes([idx]) for idx in range(256)} if self.vocab is None else self.vocab # iteratively merge the MOST COMMON pair from the text for i in range(num_merges): # get count of pairs stats = get_stats(ids) # find the pair with the highest count # pair = max(stats, key=stats.get) # tmp_stats 
            # find the most frequent pair in this batch
            pair = max(stats, key=stats.get)

            # if this pair was merged in a previous batch, apply the existing merge
            # to `ids` (no new token, nothing added to merges or vocab) and look for
            # the next candidate pair
            no_new_pair = False
            while pair in self.merges:
                already_merged_idx = self.merges[pair]
                ids = merge(ids, pair, already_merged_idx)
                stats = get_stats(ids)
                if stats and len(ids) >= 2:
                    pair = max(stats, key=stats.get)
                else:
                    # no new merges found in this incoming data batch
                    print("\nstopping merges as no new byte pair was found in the current batch")
                    no_new_pair = True
                    break
            if no_new_pair:
                break  # stop training on this batch

            # this most frequent pair has not been merged in any batch yet:
            # mint a new token, continuing from however many tokens this tokenizer already has
            idx = len(self.vocab)
            # count the new token so it can be added to self.merge_counter later
            current_batch_merge_counter += 1
            # replace all occurrences of `pair` in `ids` with the new `idx` token,
            # and record it in merges & vocab
            ids = merge(ids, pair, idx)
            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]

            if verbose:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({self.vocab[idx]}) had {stats[pair]} count")

        self.merge_counter += current_batch_merge_counter

    def decode(self, ids):
        # given ids (list of integers), return the Python string
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def encode(self, text):
        # given a string text, return the token ids
        text_bytes = text.encode("utf-8")
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # if no merges apply, i.e. the pair is not in the merges dict, the key
            # will result in an `inf` for every single pair and the min will be just
            # the first pair in the list; we can detect this terminating case by a
            # membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
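

if __name__ == "__main__":
    # Minimal usage sketch of the incremental trainer. The sample texts and the
    # vocab_size below are illustrative assumptions, not values from the project.
    tokenizer = BasicTokenizer()
    tokenizer.train("aaabdaaabac", vocab_size=259, verbose=True)  # first batch: up to 3 new merges
    tokenizer.train("aaabdaaabac banana band", vocab_size=259)    # second batch reuses earlier merges
    ids = tokenizer.encode("aaab banana")
    print(ids)
    print(tokenizer.decode(ids))  # round-trips back to "aaab banana"
    print("total new tokens minted:", tokenizer.merge_counter)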