|
import glob
import json
import os
import re
import sys

from transformers import AutoTokenizer

from utils_lang import *
|
|
|
def get_kept_tids():
    """Collect the sorted list of original-vocabulary token ids to keep.

    A token id is kept when at least one of the following holds:

    * it lies in the hard-coded id range 151643..151664 (inclusive);
    * its decoded text has ``vietnamese_syllable_ratio(text) > 0.8``;
    * its decoded text is at most 2 characters long and passes
      ``canbe_vietnamese`` or ``is_ascii``;
    * it appears in one of the ``tokens_kept__*.jsonl`` files under
      ``data/qwen__1000__20000/``.

    The tokenizer is loaded from the current working directory (``"."``).
    Summary counts are printed to stdout as a side effect.

    Returns:
        list[int]: the kept token ids, de-duplicated, in ascending order.
    """
    # Always keep this id range.
    # NOTE(review): presumably the tokenizer's special/added tokens, which sit
    # above ``vocab_size`` and are therefore not visited by the loop below --
    # confirm against the tokenizer config.
    kept_tids = set(range(151643, 151664 + 1))

    tokenizer = AutoTokenizer.from_pretrained(".")

    # These counters count rule *hits*, not distinct tokens: a token matching
    # both Vietnamese rules bumps ``canbe_vi_kept`` twice.
    canbe_vi_kept = 0
    is_ascii_kept = 0

    for tid in range(tokenizer.vocab_size):
        token = tokenizer.decode(tid)
        is_short = len(token) <= 2

        if vietnamese_syllable_ratio(token) > 0.8:
            canbe_vi_kept += 1
            kept_tids.add(tid)

        if is_short and canbe_vietnamese(token):
            canbe_vi_kept += 1
            kept_tids.add(tid)

        if is_short and is_ascii(token):
            is_ascii_kept += 1
            kept_tids.add(tid)

    print(">>> canbe_vi_kept", canbe_vi_kept)
    print(">>> is_ascii_kept", is_ascii_kept)

    for filename in glob.glob("data/qwen__1000__20000/tokens_kept__*.jsonl"):
        # The context manager closes each file promptly (the bare ``open`` in
        # a ``for`` leaked the handle); the explicit encoding keeps decoding
        # of token text independent of the platform locale.
        with open(filename, "rt", encoding="utf-8") as handle:
            for line in handle:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                # Each record is a 3-item JSON array: token text, id, count.
                _token, tid, _count = json.loads(line)
                kept_tids.add(tid)

    kept_tids = sorted(kept_tids)

    print("new_qwen_vocab", len(kept_tids))
    return kept_tids
|
|
|
|
|
# Compute the reduced vocabulary once, when the module is loaded.
kept_tids = get_kept_tids()

# new2old: index in ``kept_tids`` (the new token id) -> original token id.
new2old = dict(enumerate(kept_tids))

# old2new: the inverse lookup, original token id -> new token id.
old2new = {old_tid: new_tid for new_tid, old_tid in new2old.items()}
|
|
|
|
|
# Decoded text of every token that could not be remapped; consulted so that
# each such token is reported on stdout only once.
STRANGE_TOKENS = set()

# Maximal runs of ASCII letters, matched case-insensitively.
_ASCII_WORD_RE = re.compile(r'[a-z]+', flags=re.IGNORECASE)


def old2new_tid(x, tokenizer):
    """Map an original-vocabulary token id to its id in the reduced vocabulary.

    Lookup order:

    1. If ``x`` is a kept id, return its new id directly.
    2. Otherwise decode ``x``; if ``contains_unwanted`` flags the text,
       return ``None``.
    3. If the text contains exactly one run of ASCII letters, re-encode that
       run alone; when it encodes to a single kept id, return that id's
       new id.
    4. Otherwise report the token once (tracked in ``STRANGE_TOKENS``) and
       return ``None``.

    Args:
        x: Token id in the original vocabulary.
        tokenizer: Object providing ``decode(token_id)`` and ``encode(text)``.

    Returns:
        The new token id, or ``None`` when the token is unwanted or cannot
        be remapped.
    """
    if x in old2new:
        return old2new[x]

    token = tokenizer.decode(x)
    if contains_unwanted(token):
        return None

    words = _ASCII_WORD_RE.findall(token)

    # Diagnostic output only: tokens with several letter runs are not remapped.
    if len(words) > 1:
        print(">>>", words)

    if len(words) == 1:
        tids = tokenizer.encode(words[0])
        if len(tids) == 1 and tids[0] in old2new:
            return old2new[tids[0]]

    # Report each unmappable token a single time to keep the log readable.
    if token not in STRANGE_TOKENS:
        print(f">>> old2new_tid error: id {x}, token '{token}'")
        STRANGE_TOKENS.add(token)

    return None
|
|
|
|
|
if __name__ == "__main__":
    # Report the reduced vocabulary size next to the nearest multiple of 64.
    n = len(kept_tids)

    # NOTE(review): round() picks the *nearest* multiple, so ``nn`` can be
    # smaller than ``n``; confirm whether rounding up was intended.
    nn = 64 * round(n / 64)

    print("kept_tids", n)
    print(n, nn)
|
|