"""Base Tokenizer class and BPE helper functions."""

import unicodedata
from collections import defaultdict
from itertools import pairwise  # requires Python 3.10+


def get_adjacent_pair_counts(ids) -> defaultdict:
    """Count how often each adjacent pair of ids occurs in the sequence."""
    counts = defaultdict(int)
    for pair in pairwise(ids):
        counts[pair] += 1
    return counts


def merge_pairs(ids, pair, idx):
    """Replace every occurrence of `pair` in `ids` with the new token id `idx`."""
    newids = []
    i = 0
    n = len(ids)
    while i < n:
        # if this position starts an occurrence of the pair, emit the merged id
        if i < n - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def replace_control_characters(s: str) -> str:
    """Escape control characters (Unicode category 'C'), e.g. '\\n' -> '\\u000a',
    so they don't distort rendered output."""
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != 'C':
            chars.append(ch)  # printable character, keep as-is
        else:
            chars.append(f'\\u{ord(ch):04x}')  # escape
    return ''.join(chars)


def render_token(t: bytes) -> str:
    """Pretty-print a token's bytes, escaping control characters."""
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s


class Tokenizer:
    """Base class for Tokenizers"""

    def __init__(self):
        self.merges = {}          # (int, int) -> int
        self.pattern = ''         # optional regex split pattern
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab()  # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError

    def encode(self, text):
        raise NotImplementedError

    def decode(self, ids):
        raise NotImplementedError

    def _build_vocab(self):
        # the vocab is derived deterministically from merges and special tokens
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode('utf-8')
        return vocab

    def save(self, file_prefix):
        """Save two files: file_prefix.model (read back by load) and
        file_prefix.vocab (human-readable only). Similar to sentencepiece
        model saving."""
        model_file = file_prefix + '.model'
        with open(model_file, 'w', encoding='utf-8') as f:
            f.write('xsbpe v1\n')
            f.write(f'{self.pattern}\n')
            f.write(f'{len(self.special_tokens)}\n')
            for special, idx in self.special_tokens.items():
                f.write(f'{special} {idx}\n')
            for idx1, idx2 in self.merges:
                f.write(f'{idx1} {idx2}\n')
        # vocab file meant for human inspection only
        vocab_file = file_prefix + '.vocab'
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, 'w', encoding='utf-8') as f:
            for idx, token in self.vocab.items():
                s = render_token(token)
                if idx in inverted_merges:
                    # this token was created by merging two children
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f'[{s0}][{s1}] -> [{s}] {idx}\n')
                else:
                    # a leaf token (one of the first 256 raw bytes)
                    f.write(f'[{s}] {idx}\n')

    def load(self, model_file):
        """Inverse of save(), but only for the .model file."""
        assert model_file.endswith('.model')
        merges = {}
        special_tokens = {}
        idx = 256  # merge ids start right after the 256 raw bytes
        with open(model_file, 'r', encoding='utf-8') as f:
            version = f.readline().strip()
            assert version == 'xsbpe v1'
            self.pattern = f.readline().strip()
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()
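

# --- Illustrative usage (a minimal sketch, not part of the module proper) ---
# One BPE training step built from the helpers above: count adjacent pairs,
# pick the most frequent, and merge it into a new token id. The sample text
# 'aaabdaaabac' is hypothetical.
def _demo_merge_step():
    ids = list('aaabdaaabac'.encode('utf-8'))
    counts = get_adjacent_pair_counts(ids)
    top_pair = max(counts, key=counts.get)  # most frequent adjacent pair
    ids = merge_pairs(ids, top_pair, 256)   # 256 is the first free token id
    print(f'merged {top_pair} -> 256: {ids}')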
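

# A second sketch: a save()/load() round-trip on the base class, which works
# without a subclass because only merges, pattern, and special tokens are
# serialized. The 'demo' file prefix and the merge table are made up for
# illustration.
def _demo_save_load():
    tok = Tokenizer()
    tok.merges = {(97, 97): 256, (256, 98): 257}  # hypothetical merges
    tok.special_tokens = {'<|endoftext|>': 258}
    tok.vocab = tok._build_vocab()
    tok.save('demo')  # writes demo.model and demo.vocab

    tok2 = Tokenizer()
    tok2.load('demo.model')
    assert tok2.merges == tok.merges
    assert tok2.special_tokens == tok.special_tokens
    assert tok2.vocab == tok.vocab


if __name__ == '__main__':
    _demo_merge_step()
    _demo_save_load()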