jatingocodeo committed
Commit ca9ce93 · verified · 1 Parent(s): f081412

Upload 2 files

Files changed (2):
  1. src/hindi_bpe.py +299 -0
  2. src/train_bpe.py +100 -0
src/hindi_bpe.py ADDED
@@ -0,0 +1,299 @@
+ import re
+ import collections
+ from typing import List
+ from tqdm import tqdm
+
+ class HindiBPE:
+     def __init__(self, max_vocab_size: int = 5000, target_compression: float = 3.2):
+         self.max_vocab_size = max_vocab_size
+         self.target_compression = target_compression
+         self.vocab = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}
+         self.inverse_vocab = {v: k for k, v in self.vocab.items()}
+         self.bpe_ranks = {}
+         self.cache = {}
+         self.special_tokens = {"<PAD>", "<UNK>", "<BOS>", "<EOS>"}
+         self.word_end_token = "▁"  # Special token to mark word boundaries
+         self.vocab[self.word_end_token] = len(self.vocab)
+         self.inverse_vocab[self.vocab[self.word_end_token]] = self.word_end_token
+
+     def _tokenize_word(self, word: str) -> List[str]:
+         """Tokenize a word into characters, handling Hindi characters properly"""
+         if word in self.cache:
+             return self.cache[word]
+
+         # First check if the whole word is in vocabulary
+         if word in self.vocab:
+             self.cache[word] = [word]
+             return [word]
+
+         # Split into individual characters while preserving character combinations
+         tokens = []
+         i = 0
+         while i < len(word):
+             # Check for Hindi character followed by combining marks
+             if re.match(r'[\u0900-\u097F]', word[i]):
+                 token = word[i]
+                 i += 1
+                 # Add combining marks to the token
+                 while i < len(word) and re.match(r'[\u0900-\u0903\u093A-\u094F\u0962-\u0963]', word[i]):
+                     token += word[i]
+                     i += 1
+                 tokens.append(token)
+             else:
+                 # Handle non-Hindi characters
+                 token = word[i]
+                 i += 1
+                 tokens.append(token)
+
+         self.cache[word] = tokens
+         return tokens
+
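+     # Illustrative example (on a freshly constructed tokenizer, before
+     # training has added any whole words to the vocabulary):
+     #   _tokenize_word("नमस्ते") -> ['न', 'म', 'स्', 'ते']
+     # The virama (U+094D) and dependent vowel signs attach to the
+     # preceding consonant instead of becoming separate tokens.
+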
+     def train_on_chunk(self, text: str, is_first_chunk: bool = False):
+         """Train BPE on a chunk of text"""
+         if not text.strip():
+             return
+
+         # Add common Hindi words and characters to vocabulary first
+         common_words = ["है", "मैं", "हूं", "का", "की", "के", "में", "से", "को", "पर", "और", "हैं", "था", "थी", "थे",
+                         "नमस्ते", "भारत", "हिंदी", "सीख", "रहा", "यह", "एक", "परीक्षण", "वाक्य", "विशाल", "देश",
+                         "मुझे", "भाषा", "बहुत", "पसंद"]
+         for word in common_words:
+             if word not in self.vocab and len(self.vocab) < self.max_vocab_size:
+                 self.vocab[word] = len(self.vocab)
+                 self.inverse_vocab[self.vocab[word]] = word
+
+         # First pass: collect word frequencies
+         word_freqs = collections.Counter(text.split())
+
+         # Add most frequent whole words to vocabulary (up to 10% of vocab size)
+         max_word_tokens = self.max_vocab_size // 10
+         for word, freq in word_freqs.most_common(max_word_tokens):
+             if len(word) > 1 and word not in self.vocab and len(self.vocab) < self.max_vocab_size:
+                 self.vocab[word] = len(self.vocab)
+                 self.inverse_vocab[self.vocab[word]] = word
+
+         # Tokenize words and filter out empty ones
+         words = [self._tokenize_word(word) for word in tqdm(text.split(), desc="Tokenizing words")]
+         words = [word for word in words if word]  # Filter out empty words
+
+         if not words:  # If no valid words found
+             return
+
+         # Initialize pair statistics, weighting each pair by its word's frequency
+         print("Computing pair statistics...")
+         pair_stats = collections.Counter()
+         for word in words:
+             if len(word) < 2:  # Skip single-token words
+                 continue
+             word_freq = word_freqs[''.join(word)]  # Tokens concatenate back to the original word
+             for i in range(len(word) - 1):
+                 pair = (word[i], word[i+1])
+                 pair_stats[pair] += word_freq
+
+         if not pair_stats:  # If no valid pairs found
+             return
+
+         # Keep track of best model
+         best_vocab_size = len(self.vocab)
+         best_compression = 0.0
+         best_state = None
+
+         # Training loop
+         with tqdm(total=self.max_vocab_size - len(self.vocab), desc="Training BPE") as pbar:
+             while len(self.vocab) < self.max_vocab_size and pair_stats:
+                 # Get most frequent pair (ties broken lexicographically)
+                 best_pair = max(pair_stats.items(), key=lambda x: (x[1], x[0]))[0]
+                 new_token = ''.join(best_pair)
+
+                 if new_token in self.vocab or len(self.vocab) >= self.max_vocab_size:
+                     # Skip if token already exists or vocab is full
+                     del pair_stats[best_pair]
+                     continue
+
+                 # Add to vocabulary
+                 token_id = len(self.vocab)
+                 self.vocab[new_token] = token_id
+                 self.inverse_vocab[token_id] = new_token
+                 self.bpe_ranks[best_pair] = len(self.bpe_ranks)
+
+                 # Update words and pair statistics
+                 new_words = []
+                 for word in words:
+                     if len(word) < 2:  # Nothing to merge
+                         new_words.append(word)
+                         continue
+
+                     i = 0
+                     new_word = []
+                     while i < len(word):
+                         if i < len(word) - 1 and word[i] == best_pair[0] and word[i+1] == best_pair[1]:
+                             new_word.append(new_token)
+                             i += 2
+                         else:
+                             new_word.append(word[i])
+                             i += 1
+                     new_words.append(new_word)
+
+                 # Recompute pair statistics from scratch
+                 pair_stats.clear()
+                 for word in new_words:
+                     if len(word) < 2:
+                         continue
+                     word_freq = word_freqs[''.join(word)]  # Same frequency lookup as above
+                     for i in range(len(word) - 1):
+                         pair = (word[i], word[i+1])
+                         pair_stats[pair] += word_freq
+
+                 words = new_words
+
+                 # Calculate compression ratio every 50 tokens
+                 if len(self.vocab) % 50 == 0:
+                     sample_text = ' '.join([''.join(w) for w in words[:2000]])
+                     current_ratio = self.get_compression_ratio(sample_text)
+                     print(f"\nVocab size: {len(self.vocab)}, Compression ratio: {current_ratio:.2f}")
+
+                     # Update best model if we meet requirements
+                     if current_ratio >= self.target_compression and len(self.vocab) < self.max_vocab_size:
+                         if current_ratio > best_compression:
+                             best_compression = current_ratio
+                             best_vocab_size = len(self.vocab)
+                             best_state = {
+                                 'vocab': self.vocab.copy(),
+                                 'inverse_vocab': self.inverse_vocab.copy(),
+                                 'bpe_ranks': self.bpe_ranks.copy()
+                             }
+
+                 pbar.update(1)
+
+                 # Stop if we've exceeded vocab size
+                 if len(self.vocab) >= self.max_vocab_size:
+                     break
+
+         # Restore best model if found
+         if best_state is not None:
+             print(f"\nRestoring best model (vocab size: {best_vocab_size}, compression: {best_compression:.2f})")
+             self.vocab = best_state['vocab']
+             self.inverse_vocab = best_state['inverse_vocab']
+             self.bpe_ranks = best_state['bpe_ranks']
+
+         # Calculate final metrics on the full text
+         final_ratio = self.get_compression_ratio(text)
+         print(f"\nFinal vocabulary size: {len(self.vocab)}")
+         print(f"Final compression ratio: {final_ratio:.2f}")
+
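+     # Illustrative trace of the loop above (setting aside the whole-word
+     # seeding): for the corpus "नमक नमक नमन", the words tokenize to
+     # ['न', 'म', 'क'] (freq 2) and ['न', 'म', 'न'] (freq 1), giving pair
+     # counts ('न', 'म'): 3, ('म', 'क'): 2, ('म', 'न'): 1. The first merge
+     # adds 'नम', rewrites the words to ['नम', 'क'] and ['नम', 'न'], and the
+     # recomputed counts ('नम', 'क'): 2 and ('नम', 'न'): 1 pick the next merge.
+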
+     def encode(self, text: str) -> List[int]:
+         """Encode text to token ids"""
+         if not text.strip():
+             return []
+
+         result = []
+         words = text.split()
+
+         for i, word in enumerate(words):
+             if not word.strip():
+                 continue
+
+             # Check if the word is in vocabulary as a whole
+             if word in self.vocab:
+                 result.append(self.vocab[word])
+             else:
+                 # Start with character-level tokens
+                 tokens = self._tokenize_word(word)
+                 word_tokens = []
+
+                 # Try to merge tokens using learned BPE merges
+                 while len(tokens) > 1:
+                     pairs = [(tokens[k], tokens[k+1]) for k in range(len(tokens) - 1)]
+                     if not pairs:
+                         break
+
+                     # Find the highest ranked pair; use j, not i, so the
+                     # outer word index is not clobbered
+                     best_pair = None
+                     best_rank = float('inf')
+                     best_idx = -1
+
+                     for j, pair in enumerate(pairs):
+                         rank = self.bpe_ranks.get(pair, float('inf'))
+                         if rank < best_rank:
+                             best_rank = rank
+                             best_pair = pair
+                             best_idx = j
+
+                     if best_pair is None:  # No mergeable pairs found
+                         break
+
+                     # Merge the best pair
+                     merged = ''.join(best_pair)
+                     if merged not in self.vocab:  # Stop if merged token not in vocab
+                         break
+
+                     tokens = (
+                         tokens[:best_idx] +
+                         [merged] +
+                         tokens[best_idx + 2:]
+                     )
+
+                 # Convert tokens to ids
+                 for token in tokens:
+                     if token in self.vocab:
+                         word_tokens.append(self.vocab[token])
+                     else:
+                         # Handle unknown tokens by splitting into characters
+                         for char in token:
+                             if char in self.vocab:
+                                 word_tokens.append(self.vocab[char])
+                             else:
+                                 word_tokens.append(self.vocab["<UNK>"])
+
+                 result.extend(word_tokens)
+
+             # Add word boundary token except for the last word
+             if i < len(words) - 1:
+                 result.append(self.vocab[self.word_end_token])
+
+         return result
+
+     def decode(self, ids: List[int]) -> str:
+         """Decode token ids back to text"""
+         if not ids:
+             return ""
+
+         tokens = []
+         current_word = []
+
+         for token_id in ids:
+             token = self.inverse_vocab.get(token_id, "<UNK>")
+
+             # Skip special tokens except word boundary
+             if token in self.special_tokens and token != self.word_end_token:
+                 continue
+
+             # Handle word boundary
+             if token == self.word_end_token:
+                 if current_word:
+                     tokens.append(''.join(current_word))
+                     current_word = []
+             else:
+                 current_word.append(token)
+
+         # Add the last word if present
+         if current_word:
+             tokens.append(''.join(current_word))
+
+         # Join all words with spaces
+         return ' '.join(tokens)
+
+     def get_compression_ratio(self, text: str) -> float:
+         """Calculate the compression ratio as UTF-8 bytes of input per token id"""
+         if not text:
+             return 0.0
+         original_size = len(text.encode('utf-8'))
+         encoded = self.encode(text)
+         if not encoded:
+             return 0.0
+         # Each token id counts as one unit, so the ratio is bytes per token,
+         # not the size of any particular serialized encoding
+         compressed_size = len(encoded)
+         return original_size / compressed_size if compressed_size > 0 else 0.0
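Taken together, a minimal roundtrip sketch (hypothetical usage, not part of this commit; it relies on the common-word seeding in train_on_chunk so that every word below is already in the vocabulary):

    from hindi_bpe import HindiBPE

    bpe = HindiBPE()
    bpe.train_on_chunk("मुझे हिंदी भाषा बहुत पसंद है", is_first_chunk=True)
    ids = bpe.encode("मुझे हिंदी पसंद है")
    print(ids)              # word ids separated by the ▁ boundary id
    print(bpe.decode(ids))  # -> मुझे हिंदी पसंद है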
src/train_bpe.py ADDED
@@ -0,0 +1,100 @@
+ import os
+ from typing import Iterator
+
+ from hindi_bpe import HindiBPE
+ from tqdm import tqdm
+
+ def load_processed_data_in_chunks(file_path: str, max_sentences: int = 1_000_000) -> Iterator[str]:
+     """Yield the data in chunks of sentences, up to max_sentences in total"""
+     buffer = []
+     sentence_count = 0
+
+     with open(file_path, 'r', encoding='utf-8') as f:
+         for line in tqdm(f, desc="Reading sentences"):
+             if sentence_count >= max_sentences:
+                 break
+
+             line = line.strip()
+             if not line:
+                 continue
+
+             buffer.append(line)
+             sentence_count += 1
+
+             if len(buffer) >= 10000:  # Process in chunks of 10K sentences
+                 yield ' '.join(buffer)
+                 buffer = []
+
+     if buffer:  # Don't forget the last chunk
+         yield ' '.join(buffer)
+
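+ # Illustrative: each yielded chunk is one space-joined string of up to
+ # 10,000 sentences, so a 1M-sentence run produces at most 100 chunks:
+ #   for chunk in load_processed_data_in_chunks("../data/hi_processed.txt"):
+ #       ...  # roughly 10K sentences per chunk, fewer in the last one
+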
+ def main():
+     # Initialize paths
+     data_dir = os.path.join("..", "data")
+     processed_file = os.path.join(data_dir, "hi_processed.txt")
+
+     # Check if processed data exists
+     if not os.path.exists(processed_file):
+         print("Processed data not found. Please run download_data.py first.")
+         return
+
+     # Initialize BPE
+     print("Initializing BPE tokenizer...")
+     print("Training Parameters:")
+     print("1. Using first 1 million sentences")
+     print("2. Vocabulary size must be < 5000 tokens")
+     print("3. Compression ratio must be ≥ 3.2")
+     bpe = HindiBPE()
+
+     print("\nTraining BPE model...")
+     is_first_chunk = True
+     compression_ratio = 0.0  # Defined up front in case no chunks are yielded
+
+     for chunk in load_processed_data_in_chunks(processed_file):
+         if not chunk.strip():
+             continue
+
+         bpe.train_on_chunk(chunk, is_first_chunk=is_first_chunk)
+         is_first_chunk = False
+
+         # Check if we've met both requirements
+         test_text = chunk[:10000]  # Use a sample of text
+         compression_ratio = bpe.get_compression_ratio(test_text)
+         vocab_size = len(bpe.vocab)
+
+         print("\nCurrent status:")
+         print(f"Vocabulary size: {vocab_size} tokens")
+         print(f"Compression ratio: {compression_ratio:.2f}")
+
+         if compression_ratio >= 3.2:
+             if vocab_size < 5000:
+                 print("\nSuccess! Met all requirements:")
+                 print(f"1. Vocabulary size: {vocab_size} tokens (< 5000)")
+                 print(f"2. Compression ratio: {compression_ratio:.2f} (≥ 3.2)")
+                 break
+             else:
+                 print("\nWarning: Need to reduce vocabulary size while maintaining compression ratio")
+
+     print("\nFinal Results:")
+     print(f"Vocabulary size: {len(bpe.vocab)} tokens")
+     print(f"Compression ratio: {compression_ratio:.2f}")
+
+     # Test the model with various Hindi texts
+     test_cases = [
+         "नमस्ते भारत",
+         "मैं हिंदी सीख रहा हूं",
+         "यह एक परीक्षण वाक्य है",
+         "भारत एक विशाल देश है",
+         "मुझे हिंदी भाषा बहुत पसंद है"
+     ]
+
+     print("\nTesting encoding/decoding on multiple examples:")
+     for i, test_text in enumerate(test_cases, 1):
+         print(f"\nTest case {i}:")
+         print(f"Original: {test_text}")
+         encoded = bpe.encode(test_text)
+         print(f"Encoded: {encoded}")
+         decoded = bpe.decode(encoded)
+         print(f"Decoded: {decoded}")
+         print(f"Matches: {'✓' if decoded == test_text else '✗'}")
+
+ if __name__ == "__main__":
+     main()
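
Assuming the layout the script expects (hi_processed.txt under ../data, with hindi_bpe.py alongside this file), it would be run from the src directory:

    cd src && python train_bpe.py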