tiendung's picture
update
c7adb99
raw
history blame
7.05 kB
#!/usr/bin/env python3
import json, lzma, glob, sys, os, re, subprocess
from pprint import pprint
import torch, sys
import transformers
model_path = "e3.0"
print(f"Loading {model_path} ...")
model = transformers.AutoModelForCausalLM.from_pretrained(
model_path,
device_map = "auto",
torch_dtype = torch.bfloat16,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(".")
from qwen_vocab import old2new, new2old
STOP_WORDS = "<|im_end|> <|endoftext|>".split()
def map_tids(map_dict, tids):
return [ map_dict[x] for x in tids if x in map_dict ]
class KeywordsStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, str):
self.keyword_ids = tokenizer.encode(str)
self.keyword_ids = map_tids(old2new, self.keyword_ids)
self.keyword_len = len(self.keyword_ids)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
last_token_ids = input_ids[0][-self.keyword_len:]
return last_token_ids.tolist() == self.keyword_ids
stop_criteria_list = transformers.StoppingCriteriaList(
[ KeywordsStoppingCriteria(x) for x in STOP_WORDS ]
)
def chat(q, temperature = 0.5):
prompt = f"<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant"
old_tids = tokenizer.encode(prompt)
new_tids = map_tids(old2new, old_tids)
new_old_tids = map_tids(new2old, new_tids)
new_prompt = tokenizer.decode(new_old_tids)
# if new_old_tids != old_tids:
# print(f"!!! Cảnh báo sự trimm vocab làm mất thông tin !!!")
# print(f"!!! old prompt: {prompt}")
# print(f"!!! new prompt: {new_prompt}")
inputs = tokenizer(new_prompt, return_tensors="pt").to(model.device)
assert inputs["input_ids"][0].tolist() == new_old_tids
for i, x in enumerate(new_tids):
inputs["input_ids"][0][i] = x
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=1024*4,
temperature=0.1,
top_p=1.0, top_k=30, do_sample=True,
repetition_penalty=1.3,
stopping_criteria=stop_criteria_list,
pad_token_id=tokenizer.pad_token_id,
)
answer_tids = output_ids[0][len(inputs["input_ids"][0]) : ] # bỏ đi prompt tokens
old_tids = map_tids(new2old, answer_tids.tolist())
return tokenizer.decode(old_tids).split("<|im_end|>")[0].strip()
envi = """
Không cần giải thích, giữ nguyên các từ viết tắt, các ký hiệu, và dịch đoạn văn sau sang tiếng Việt:
Ví dụ 1:
<|en|> Most languages have been developed using the same alphabet because of the popularity and prevalence of the latin-based English Alphabet. This alphabet is estimated to be used by around 2 billion people, and is used by many European, romance, African and Vietnamese languages.
<|vi|> Hầu hết các ngôn ngữ được phát triển sử dụng cùng một bảng chữ cái do sự phổ biến và thịnh hành của bảng chữ cái tiếng Anh dựa trên hệ Latin. Bảng chữ cái này ước tính được khoảng 2 tỷ người sử dụng[4], và được dùng trong nhiều ngôn ngữ châu Âu, ngôn ngữ lãng mạn, châu Phi và tiếng Việt.
Ví dụ 2:
<|en|> Do you have any fun expressions in your language to say you forget something? Share them in the comments below!
<|vi|> Bạn có câu nói vui nào trong ngôn ngữ của mình để diễn tả việc quên điều gì đó không? Hãy chia sẻ trong phần bình luận bên dưới!
Ví dụ 3:
<|en|> What is the scientific explanation for making us feel "cuteness" when we see something cute?
<|vi|> Giải thích khoa học về việc tại sao chúng ta cảm thấy "dễ thương" khi nhìn thấy thứ gì đó dễ thương là gì?
Không cần giải thích, giữ nguyên các từ viết tắt, các ký hiệu, và dịch đoạn văn sau sang tiếng Việt:
<|en|> {english}
<|vi|>
""".strip()
junks = """
Câu trả lời của tôi:
sang tiếng Việt:
sang tiếng Việt là:
dịch tiếng Việt:
dịch tiếng Việt là:
tiếng Việt như sau:
sang tiếng Việt sẽ là:
tiếng Việt của đoạn văn:
tiếng Việt của câu hỏi là:
tiếng Việt của câu trên là:
tiếng Việt của đoạn văn là:
tiếng Việt của đoạn văn trên:
tiếng Việt của đoạn văn như sau:
tiếng Việt của đoạn văn trên là:
dịch đoạn văn sau sang tiếng Việt:
tiếng Việt của đoạn văn bạn yêu cầu:
Bây giờ đến lượt bạn:
dịch sang tiếng Việt là
<|en|>
<|vi|>
""".strip().split("\n")
# print(junks)
def trans(prompt, temperinit = 0.2):
print("\n- - - - - -\n")
print(prompt, "\n==>\n" )
res = trans_(prompt, temperinit)
print(res, flush = True)
return res
def trans_(prompt, temperinit = 0.2):
if not isinstance(prompt, str):
return prompt
if len(prompt) < 8:
return prompt
trials = max_trials = 3
temperature = temperinit
temperdelta = 0.2
while trials > 0:
trials -= 1
n = max_trials - trials
if n > 1:
temperature += temperdelta
print(f"\033[91m{prompt}\033[0m => {x}") # Red then reset
print(f"\033[33mThử lại lần {n}\033[0m") # Yellow then reset
x = trans__(prompt, temperature = temperature).strip()
if x is not None and len(x) > 0:
for j in junks: # Loại bỏ những header thừa
x = x.split(j.strip())[-1].strip()
pp = prompt.lower()
if "tiếng việt" in pp or "vietnamese" in pp:
return x
xx = x.lower()
if "tiếng việt" not in xx:
return x
def trans__(prompt, temperature = 0.0):
# print("\n- - - - - -\n")
# print(prompt, "\n==>\n")
prompt = envi.format(english = prompt)
res = chat(prompt, temperature = temperature)
# print(res)
return res
# infile = args.input
infile = sys.argv[1]
outfile = infile.replace(".jsonl.xz", "__vi.jsonl")
if os.path.exists(outfile):
sources = [ json.loads(line)['source'] for line in open(outfile, "rt") ]
else:
sources = []
print(len(sources), sources[-1] if len(sources) > 0 else None)
for idx, line in enumerate(lzma.open(infile, "rt")):
source = f"{infile}:{idx}"
if source in sources: continue
print(source)
data = json.loads(line)
data["query"] = trans(data['query'])
if data["query"] is None: continue
for idx, x in enumerate( data["pos"] ):
data['pos'][idx] = trans(x)
if data['pos'][idx] is None: break
if data['pos'][idx] is None: continue
for idx, x in enumerate( data["neg"] ):
data['neg'][idx] = trans(x)
if data['neg'][idx] is None: break
if data['neg'][idx] is None: continue
with open(outfile, "at") as f:
data["source"] = source
f.write(json.dumps(data, ensure_ascii = False) + "\n")