import re
import collections
from typing import List

from tqdm import tqdm


class HindiBPE:
    """Byte pair encoding (BPE) tokenizer for Hindi (Devanagari) text."""

    def __init__(self, max_vocab_size: int = 5000, target_compression: float = 3.2):
        self.max_vocab_size = max_vocab_size
        self.target_compression = target_compression
        self.vocab = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        self.bpe_ranks = {}
        self.cache = {}
        self.special_tokens = {"<PAD>", "<UNK>", "<BOS>", "<EOS>"}
        self.word_end_token = "▁"  # Special token marking word boundaries
        self.vocab[self.word_end_token] = len(self.vocab)
        self.inverse_vocab[self.vocab[self.word_end_token]] = self.word_end_token

    def _tokenize_word(self, word: str) -> List[str]:
        """Split a word into characters, keeping Devanagari combining marks attached to their base character."""
        if word in self.cache:
            return self.cache[word]
        # If the whole word is already in the vocabulary, keep it intact
        if word in self.vocab:
            self.cache[word] = [word]
            return [word]
        # Split into individual characters while preserving character combinations
        tokens = []
        i = 0
        while i < len(word):
            if re.match(r'[\u0900-\u097F]', word[i]):
                # Hindi base character: absorb any following combining marks
                token = word[i]
                i += 1
                while i < len(word) and re.match(r'[\u0900-\u0903\u093A-\u094F\u0962-\u0963]', word[i]):
                    token += word[i]
                    i += 1
                tokens.append(token)
            else:
                # Non-Hindi characters become single-character tokens
                token = word[i]
                i += 1
                tokens.append(token)
        self.cache[word] = tokens
        return tokens
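
    # Illustration (assuming the word is not already a whole-word vocabulary entry):
    # _tokenize_word keeps each Devanagari base character together with its
    # combining marks, e.g. "नमस्ते" -> ["न", "म", "स्", "ते"] and "हूं" -> ["हूं"].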

    def train_on_chunk(self, text: str, is_first_chunk: bool = False):
        """Train BPE merges on a chunk of text."""
        if not text.strip():
            return
        # Seed the vocabulary with common Hindi words and characters first
        common_words = ["है", "मैं", "हूं", "का", "की", "के", "में", "से", "को", "पर", "और", "हैं", "था", "थी", "थे",
                        "नमस्ते", "भारत", "हिंदी", "सीख", "रहा", "यह", "एक", "परीक्षण", "वाक्य", "विशाल", "देश",
                        "मुझे", "भाषा", "बहुत", "पसंद"]
        for word in common_words:
            if word not in self.vocab and len(self.vocab) < self.max_vocab_size:
                self.vocab[word] = len(self.vocab)
                self.inverse_vocab[self.vocab[word]] = word
        # First pass: collect word frequencies
        word_freqs = collections.Counter(text.split())
        # Add the most frequent whole words to the vocabulary (up to 10% of the vocab size)
        max_word_tokens = self.max_vocab_size // 10
        for word, freq in word_freqs.most_common(max_word_tokens):
            if len(word) > 1 and word not in self.vocab and len(self.vocab) < self.max_vocab_size:
                self.vocab[word] = len(self.vocab)
                self.inverse_vocab[self.vocab[word]] = word
        # Tokenize words and filter out empty ones
        words = [self._tokenize_word(word) for word in tqdm(text.split(), desc="Tokenizing words")]
        words = [word for word in words if word]
        if not words:  # No valid words found
            return
        # Initialize pair statistics
        print("Computing pair statistics...")
        pair_stats = collections.Counter()
        for word in words:
            if len(word) < 2:  # Skip single-token words
                continue
            # ''.join(word) reconstructs the original word, which is the key in word_freqs
            word_freq = word_freqs[''.join(word)]
            for i in range(len(word) - 1):
                pair = (word[i], word[i + 1])
                pair_stats[pair] += word_freq
        if not pair_stats:  # No valid pairs found
            return
        # Keep track of the best model seen so far
        best_vocab_size = len(self.vocab)
        best_compression = 0.0
        best_state = None
        # Training loop
        with tqdm(total=self.max_vocab_size - len(self.vocab), desc="Training BPE") as pbar:
            while len(self.vocab) < self.max_vocab_size and pair_stats:
                # Get the most frequent pair (ties broken deterministically by the pair itself)
                best_pair = max(pair_stats.items(), key=lambda x: (x[1], x[0]))[0]
                new_token = ''.join(best_pair)
                if new_token in self.vocab or len(self.vocab) >= self.max_vocab_size:
                    # Skip if the token already exists or the vocab is full
                    del pair_stats[best_pair]
                    continue
                # Add the merged token to the vocabulary
                token_id = len(self.vocab)
                self.vocab[new_token] = token_id
                self.inverse_vocab[token_id] = new_token
                self.bpe_ranks[best_pair] = len(self.bpe_ranks)
                # Apply the merge to every word
                new_words = []
                for word in words:
                    if len(word) < 2:  # Nothing to merge in single-token words
                        new_words.append(word)
                        continue
                    i = 0
                    new_word = []
                    while i < len(word):
                        if i < len(word) - 1 and word[i] == best_pair[0] and word[i + 1] == best_pair[1]:
                            new_word.append(new_token)
                            i += 2
                        else:
                            new_word.append(word[i])
                            i += 1
                    new_words.append(new_word)
                # Recompute pair statistics from the merged words
                pair_stats.clear()
                for word in new_words:
                    if len(word) < 2:
                        continue
                    word_freq = word_freqs[''.join(word)]
                    for i in range(len(word) - 1):
                        pair = (word[i], word[i + 1])
                        pair_stats[pair] += word_freq
                words = new_words
                # Check the compression ratio every 50 new tokens
                if len(self.vocab) % 50 == 0:
                    sample_text = ' '.join([''.join(w) for w in words[:2000]])
                    current_ratio = self.get_compression_ratio(sample_text)
                    print(f"\nVocab size: {len(self.vocab)}, Compression ratio: {current_ratio:.2f}")
                    # Remember this model if it meets the compression target
                    if current_ratio >= self.target_compression and len(self.vocab) < self.max_vocab_size:
                        if current_ratio > best_compression:
                            best_compression = current_ratio
                            best_vocab_size = len(self.vocab)
                            best_state = {
                                'vocab': self.vocab.copy(),
                                'inverse_vocab': self.inverse_vocab.copy(),
                                'bpe_ranks': self.bpe_ranks.copy()
                            }
                pbar.update(1)
                # Stop once the vocabulary limit has been reached
                if len(self.vocab) >= self.max_vocab_size:
                    break
        # Restore the best model if one was recorded
        if best_state is not None:
            print(f"\nRestoring best model (vocab size: {best_vocab_size}, compression: {best_compression:.2f})")
            self.vocab = best_state['vocab']
            self.inverse_vocab = best_state['inverse_vocab']
            self.bpe_ranks = best_state['bpe_ranks']
        # Report final metrics on the full chunk
        final_ratio = self.get_compression_ratio(text)
        print(f"\nFinal vocabulary size: {len(self.vocab)}")
        print(f"Final compression ratio: {final_ratio:.2f}")

    def encode(self, text: str) -> List[int]:
        """Encode text into a list of token ids."""
        if not text.strip():
            return []
        result = []
        words = text.split()
        for word_idx, word in enumerate(words):
            if not word.strip():
                continue
            # Use the whole word if it is in the vocabulary
            if word in self.vocab:
                result.append(self.vocab[word])
            else:
                # Start from character-level tokens
                tokens = self._tokenize_word(word)
                word_tokens = []
                # Greedily apply learned BPE merges, earliest-learned pair first
                while len(tokens) > 1:
                    pairs = [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
                    if not pairs:
                        break
                    # Find the lowest-ranked mergeable pair
                    best_pair = None
                    best_rank = float('inf')
                    best_idx = -1
                    for i, pair in enumerate(pairs):
                        rank = self.bpe_ranks.get(pair, float('inf'))
                        if rank < best_rank:
                            best_rank = rank
                            best_pair = pair
                            best_idx = i
                    if best_pair is None:  # No mergeable pairs left
                        break
                    # Merge the best pair
                    merged = ''.join(best_pair)
                    if merged not in self.vocab:  # Safety check: merged token missing from vocab
                        break
                    tokens = (
                        tokens[:best_idx] +
                        [merged] +
                        tokens[best_idx + 2:]
                    )
                # Convert tokens to ids
                for token in tokens:
                    if token in self.vocab:
                        word_tokens.append(self.vocab[token])
                    else:
                        # Unknown token: fall back to characters, then <UNK>
                        for char in token:
                            if char in self.vocab:
                                word_tokens.append(self.vocab[char])
                            else:
                                word_tokens.append(self.vocab["<UNK>"])
                result.extend(word_tokens)
            # Add a word boundary token after every word except the last
            if word_idx < len(words) - 1:
                result.append(self.vocab[self.word_end_token])
        return result

    def decode(self, ids: List[int]) -> str:
        """Decode a list of token ids back into text."""
        if not ids:
            return ""
        tokens = []
        current_word = []
        for token_id in ids:
            token = self.inverse_vocab.get(token_id, "<UNK>")
            # Skip special tokens except the word boundary
            if token in self.special_tokens and token != self.word_end_token:
                continue
            # A word boundary closes the current word
            if token == self.word_end_token:
                if current_word:
                    tokens.append(''.join(current_word))
                    current_word = []
            else:
                current_word.append(token)
        # Flush the last word, if any
        if current_word:
            tokens.append(''.join(current_word))
        # Join all words with spaces
        return ' '.join(tokens)

    def get_compression_ratio(self, text: str) -> float:
        """Return the compression ratio, measured as UTF-8 bytes per token id."""
        if not text:
            return 0.0
        original_size = len(text.encode('utf-8'))
        encoded = self.encode(text)
        if not encoded:
            return 0.0
        # Ratio of raw UTF-8 bytes to the number of token ids produced
        compressed_size = len(encoded)
        return original_size / compressed_size if compressed_size > 0 else 0.0
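

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: the sample sentences below are
    # assumed demo data, not part of this Space's training corpus.
    tokenizer = HindiBPE(max_vocab_size=5000, target_compression=3.2)
    sample_text = "नमस्ते भारत मैं हिंदी सीख रहा हूं यह एक परीक्षण वाक्य है मुझे हिंदी भाषा बहुत पसंद है"
    tokenizer.train_on_chunk(sample_text, is_first_chunk=True)

    # Round-trip a sentence through encode/decode and report the compression ratio
    ids = tokenizer.encode("मैं हिंदी सीख रहा हूं")
    print("Token ids:", ids)
    print("Decoded:", tokenizer.decode(ids))
    print("Compression ratio:", round(tokenizer.get_compression_ratio(sample_text), 2))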