|
import os, sys, glob, json
|
|
from utils_lang import *
|
|
from transformers import AutoTokenizer
|
|
|
|
def get_kept_tids(tokenizer_path=".", kept_pattern="qwen__1000__20000/tokens_kept__*.jsonl"):
    """Return the sorted list of original token ids to keep in the pruned vocabulary.

    A token id is kept when any of the following holds:

    * it lies in the hard-coded id range 151643..151664 (inclusive);
    * its decoded text has ``vietnamese_syllable_ratio`` above 0.8;
    * its decoded text is at most 2 characters and ``canbe_vietnamese`` is true;
    * its decoded text is at most 2 characters and ``is_ascii`` is true;
    * it is the ``tid`` field of a line in one of the JSONL files matched by
      ``kept_pattern`` (each line is a JSON array ``[token, tid, count]``).

    Prints how many ids were kept by the Vietnamese checks, by the ASCII
    check, and the final vocabulary size.

    Args:
        tokenizer_path: path passed to ``AutoTokenizer.from_pretrained``
            (defaults to the current directory).
        kept_pattern: glob pattern of the extra "tokens kept" JSONL files.

    Returns:
        Sorted list of unique kept token ids.
    """
    # NOTE(review): presumably the ids of the tokenizer's added special tokens
    # (they start right at Qwen's base vocab_size); confirm against the config.
    kept_tids = set(range(151643, 151664 + 1))

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    canbe_vi_kept = 0
    is_ascii_kept = 0

    for tid in range(tokenizer.vocab_size):
        token = tokenizer.decode(tid)
        is_short = len(token) <= 2

        # Either Vietnamese check is enough; count the token once even when
        # both of them match.
        if vietnamese_syllable_ratio(token) > 0.8 or (is_short and canbe_vietnamese(token)):
            canbe_vi_kept += 1
            kept_tids.add(tid)

        if is_short and is_ascii(token):
            is_ascii_kept += 1
            kept_tids.add(tid)

    print(">>> canbe_vi_kept", canbe_vi_kept)
    print(">>> is_ascii_kept", is_ascii_kept)

    for filename in glob.glob(kept_pattern):
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(filename, "rt", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines instead of crashing json.loads
                _token, tid, _count = json.loads(line)
                kept_tids.add(tid)

    kept = sorted(kept_tids)
    print("new_qwen_vocab", len(kept))
    return kept
|
|
|
|
|
|
# Compute the pruned vocabulary once, at import time.
kept_tids = get_kept_tids()

# The new id of a kept token is its position in the sorted kept_tids list.
new2old = dict(enumerate(kept_tids))
old2new = {old_tid: new_tid for new_tid, old_tid in new2old.items()}

# Decoded texts that old2new_tid could not map and has already reported,
# so each one is printed only once.
STRANGE_TOKENS = set()
|
|
|
|
def old2new_tid(x, tokenizer):
    """Map a token id of the original tokenizer to its id in the pruned vocabulary.

    Args:
        x: token id in the original tokenizer's vocabulary.
        tokenizer: the original tokenizer; used to decode ``x`` and, when
            ``x`` was not kept, to re-encode the ASCII word found in it.

    Returns:
        ``old2new[x]`` when ``x`` was kept. Otherwise, if the decoded token
        passes ``contains_unwanted`` and contains exactly one run of ASCII
        letters that encodes to a single kept id, the new id of that id
        (e.g. a decoded ``' Hello'`` maps to the id of ``'Hello'`` when that
        is one kept token). ``None`` in every other case. Tokens that cannot
        be mapped (and are not rejected by ``contains_unwanted``) are
        reported once each via ``STRANGE_TOKENS``.
    """
    new_tid = old2new.get(x)
    if new_tid is not None:
        return new_tid

    # Explicit import instead of relying on `from utils_lang import *` to
    # re-export `re`; placed here so the fast path above stays lookup-only.
    import re

    token = tokenizer.decode(x)
    if contains_unwanted(token):
        return None

    words = re.findall(r'[a-z]+', token, flags = re.IGNORECASE)

    if len(words) > 1:
        print(">>>", words)

    if len(words) == 1:
        # add_special_tokens=False: a tokenizer that adds BOS/EOS would
        # otherwise never return a single id here.
        tids = tokenizer.encode(words[0], add_special_tokens=False)
        if len(tids) == 1 and tids[0] in old2new:
            return old2new[tids[0]]

    # Unmappable token: report each distinct text only once.
    if token not in STRANGE_TOKENS:
        msg = f">>> old2new_tid error: id {x}, token '{token}'"
        print(msg)
        STRANGE_TOKENS.add( token )

    return None
|
|
|
|
|
|
def pad_to_multiple(n, multiple=64):
    """Return the smallest multiple of ``multiple`` that is >= ``n``.

    Args:
        n: non-negative count to pad (e.g. a vocabulary size).
        multiple: positive step to round up to.

    Raises:
        ValueError: if ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    return (n + multiple - 1) // multiple * multiple


if __name__ == "__main__":
    n = len(kept_tids)
    # NOTE(review): assumes nn is the padded size of the new vocabulary, which
    # must be >= n. Rounding to the *nearest* multiple of 64 could fall below n
    # (e.g. 65 -> 64), so round up instead.
    nn = pad_to_multiple(n, 64)

    print("kept_tids", n)
    print(n, nn)
|
|
|