import re
import collections
from typing import Dict, List, Tuple, Set
from tqdm import tqdm
from functools import lru_cache


class HindiBPE:
    def __init__(self, max_vocab_size: int = 5000, target_compression: float = 3.2):
        self.max_vocab_size = max_vocab_size
        self.target_compression = target_compression
        # Reserved special tokens: padding, unknown, start- and end-of-sequence
        self.vocab = {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3}
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        self.bpe_ranks = {}
        self.cache = {}
        self.special_tokens = {"<pad>", "<unk>", "<s>", "</s>"}
        self.word_end_token = "▁"  # Special token to mark word boundaries
        self.vocab[self.word_end_token] = len(self.vocab)
        self.inverse_vocab[self.vocab[self.word_end_token]] = self.word_end_token

    def _tokenize_word(self, word: str) -> List[str]:
        """Tokenize a word into characters, handling Hindi characters properly."""
        if word in self.cache:
            return self.cache[word]
        # First check if the whole word is in the vocabulary
        if word in self.vocab:
            self.cache[word] = [word]
            return [word]
        # Split into individual characters while preserving character combinations
        tokens = []
        i = 0
        while i < len(word):
            if re.match(r'[\u0900-\u097F]', word[i]):
                # Hindi character: keep any following combining marks with it
                token = word[i]
                i += 1
                while i < len(word) and re.match(r'[\u0900-\u0903\u093A-\u094F\u0962-\u0963]', word[i]):
                    token += word[i]
                    i += 1
                tokens.append(token)
            else:
                # Non-Hindi characters become single-character tokens
                token = word[i]
                i += 1
                tokens.append(token)
        self.cache[word] = tokens
        return tokens

    def train_on_chunk(self, text: str, is_first_chunk: bool = False):
        """Train BPE on a chunk of text."""
        if not text.strip():
            return

        # Seed the vocabulary with common Hindi words first
        common_words = [
            "है", "मैं", "हूं", "का", "की", "के", "में", "से", "को", "पर",
            "और", "हैं", "था", "थी", "थे", "नमस्ते", "भारत", "हिंदी", "सीख", "रहा",
            "यह", "एक", "परीक्षण", "वाक्य", "विशाल", "देश", "मुझे", "भाषा", "बहुत", "पसंद",
        ]
        for word in common_words:
            if word not in self.vocab and len(self.vocab) < self.max_vocab_size:
                self.vocab[word] = len(self.vocab)
                self.inverse_vocab[self.vocab[word]] = word

        # First pass: collect word frequencies
        word_freqs = collections.Counter(text.split())

        # Add most frequent whole words to the vocabulary (up to 10% of vocab size)
        max_word_tokens = self.max_vocab_size // 10
        for word, freq in word_freqs.most_common(max_word_tokens):
            if len(word) > 1 and word not in self.vocab and len(self.vocab) < self.max_vocab_size:
                self.vocab[word] = len(self.vocab)
                self.inverse_vocab[self.vocab[word]] = word

        # Tokenize words and filter out empty ones
        words = [self._tokenize_word(word) for word in tqdm(text.split(), desc="Tokenizing words")]
        words = [word for word in words if word]
        if not words:  # No valid words found
            return

        # Initialize pair statistics
        print("Computing pair statistics...")
        pair_stats = collections.Counter()
        for word in words:
            if len(word) < 2:  # Skip single-character words
                continue
            # The token list concatenates back to the original word
            word_freq = word_freqs[''.join(word)]
            for i in range(len(word) - 1):
                pair = (word[i], word[i + 1])
                pair_stats[pair] += word_freq
        if not pair_stats:  # No valid pairs found
            return

        # Keep track of the best model seen so far
        best_vocab_size = len(self.vocab)
        best_compression = 0.0
        best_state = None

        # Training loop
        with tqdm(total=self.max_vocab_size - len(self.vocab), desc="Training BPE") as pbar:
            while len(self.vocab) < self.max_vocab_size and pair_stats:
                # Get the most frequent pair (ties broken lexicographically)
                best_pair = max(pair_stats.items(), key=lambda x: (x[1], x[0]))[0]
                new_token = ''.join(best_pair)
                if new_token in self.vocab or len(self.vocab) >= self.max_vocab_size:
                    # Skip if the token already exists or the vocab is full
                    del pair_stats[best_pair]
                    continue

                # Add the merged token to the vocabulary
                token_id = len(self.vocab)
                self.vocab[new_token] = token_id
                self.inverse_vocab[token_id] = new_token
                self.bpe_ranks[best_pair] = len(self.bpe_ranks)

                # Apply the merge to every word
                new_words = []
                for word in words:
                    if len(word) < 2:  # Skip single-character words
                        new_words.append(word)
                        continue
                    i = 0
                    new_word = []
                    while i < len(word):
                        if i < len(word) - 1 and word[i] == best_pair[0] and word[i + 1] == best_pair[1]:
                            new_word.append(new_token)
                            i += 2
                        else:
                            new_word.append(word[i])
                            i += 1
                    new_words.append(new_word)

                # Recompute pair statistics
                pair_stats.clear()
                for word in new_words:
                    if len(word) < 2:  # Skip single-character words
                        continue
                    word_freq = word_freqs[''.join(word)]
                    for i in range(len(word) - 1):
                        pair = (word[i], word[i + 1])
                        pair_stats[pair] += word_freq
                words = new_words

                # Calculate the compression ratio every 50 tokens
                if len(self.vocab) % 50 == 0:
                    sample_text = ' '.join([''.join(w) for w in words[:2000]])
                    current_ratio = self.get_compression_ratio(sample_text)
                    print(f"\nVocab size: {len(self.vocab)}, Compression ratio: {current_ratio:.2f}")
                    # Remember the best model that meets the compression target
                    if current_ratio >= self.target_compression and len(self.vocab) < self.max_vocab_size:
                        if current_ratio > best_compression:
                            best_compression = current_ratio
                            best_vocab_size = len(self.vocab)
                            best_state = {
                                'vocab': self.vocab.copy(),
                                'inverse_vocab': self.inverse_vocab.copy(),
                                'bpe_ranks': self.bpe_ranks.copy(),
                            }

                pbar.update(1)

                # Stop once the vocabulary is full
                if len(self.vocab) >= self.max_vocab_size:
                    break

        # Restore the best model if one was saved
        if best_state is not None:
            print(f"\nRestoring best model (vocab size: {best_vocab_size}, compression: {best_compression:.2f})")
            self.vocab = best_state['vocab']
            self.inverse_vocab = best_state['inverse_vocab']
            self.bpe_ranks = best_state['bpe_ranks']

        # Calculate final metrics on the full text
        final_ratio = self.get_compression_ratio(text)
        print(f"\nFinal vocabulary size: {len(self.vocab)}")
        print(f"Final compression ratio: {final_ratio:.2f}")

    def encode(self, text: str) -> List[int]:
        """Encode text to token ids."""
        if not text.strip():
            return []
        result = []
        words = text.split()
        for i, word in enumerate(words):
            if not word.strip():
                continue
            # Check whether the whole word is in the vocabulary
            if word in self.vocab:
                result.append(self.vocab[word])
            else:
                # Start with character-level tokens
                tokens = self._tokenize_word(word)
                word_tokens = []
                # Greedily apply the learned BPE merges
                while len(tokens) > 1:
                    pairs = [(tokens[j], tokens[j + 1]) for j in range(len(tokens) - 1)]
                    if not pairs:
                        break
                    # Find the highest-ranked (earliest learned) pair;
                    # use j so the outer word index i is not clobbered
                    best_pair = None
                    best_rank = float('inf')
                    best_idx = -1
                    for j, pair in enumerate(pairs):
                        rank = self.bpe_ranks.get(pair, float('inf'))
                        if rank < best_rank:
                            best_rank = rank
                            best_pair = pair
                            best_idx = j
                    if best_pair is None:  # No mergeable pairs found
                        break
                    # Merge the best pair
                    merged = ''.join(best_pair)
                    if merged not in self.vocab:  # Stop if the merged token is not in the vocab
                        break
                    tokens = tokens[:best_idx] + [merged] + tokens[best_idx + 2:]
                # Convert tokens to ids
                for token in tokens:
                    if token in self.vocab:
                        word_tokens.append(self.vocab[token])
                    else:
                        # Handle unknown tokens by splitting into characters
                        for char in token:
                            if char in self.vocab:
                                word_tokens.append(self.vocab[char])
                            else:
                                word_tokens.append(self.vocab["<unk>"])
                result.extend(word_tokens)
            # Add a word boundary token after every word except the last
            if i < len(words) - 1:
                result.append(self.vocab[self.word_end_token])
        return result

    def decode(self, ids: List[int]) -> str:
        """Decode token ids back to text."""
        if not ids:
            return ""
        tokens = []
        current_word = []
        for token_id in ids:
            token = self.inverse_vocab.get(token_id, "<unk>")
            # Skip special tokens except the word boundary
            if token in self.special_tokens and token != self.word_end_token:
                continue
            if token == self.word_end_token:
                # Word boundary: flush the current word
                if current_word:
                    tokens.append(''.join(current_word))
                    current_word = []
            else:
                current_word.append(token)
        # Add the last word if there is one
        if current_word:
            tokens.append(''.join(current_word))
        # Join all words with spaces
        return ' '.join(tokens)

    def get_compression_ratio(self, text: str) -> float:
        """Calculate the compression ratio."""
        if not text:
            return 0.0
        original_size = len(text.encode('utf-8'))
        encoded = self.encode(text)
        if not encoded:
            return 0.0
        # Each token id counts as one unit, so the ratio is UTF-8 bytes per token
        compressed_size = len(encoded)
        return original_size / compressed_size if compressed_size > 0 else 0.0
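

# Minimal usage sketch (an assumed entry point, not part of the trainer itself):
# it trains on a tiny in-memory sample built from words the tokenizer already
# seeds, then round-trips a sentence through encode/decode.
if __name__ == "__main__":
    sample_text = (
        "नमस्ते भारत मैं हिंदी सीख रहा हूं "
        "यह एक परीक्षण वाक्य है "
        "भारत एक विशाल देश है "
        "मुझे हिंदी भाषा बहुत पसंद है"
    )
    bpe = HindiBPE(max_vocab_size=5000, target_compression=3.2)
    bpe.train_on_chunk(sample_text, is_first_chunk=True)

    ids = bpe.encode("मुझे हिंदी बहुत पसंद है")
    print("Token ids:", ids)
    print("Decoded:", bpe.decode(ids))
    print("Compression ratio:", f"{bpe.get_compression_ratio(sample_text):.2f}")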