|
|
|
import json, lzma, glob, sys, os, re, subprocess |
|
from pprint import pprint |
|
|
|
import torch, sys |
|
import transformers |
|
|
|
model_path = "e3.0" |
|
print(f"Loading {model_path} ...") |
|
|
|
model = transformers.AutoModelForCausalLM.from_pretrained( |
|
model_path, |
|
device_map = "auto", |
|
torch_dtype = torch.bfloat16, |
|
) |
|
tokenizer = transformers.AutoTokenizer.from_pretrained(".") |
|
|
|
from qwen_vocab import old2new, new2old |
|
STOP_WORDS = "<|im_end|> <|endoftext|>".split() |
|
|
|
|
|
def map_tids(map_dict, tids): |
|
return [ map_dict[x] for x in tids if x in map_dict ] |
|
|
|
|
|
class KeywordsStoppingCriteria(transformers.StoppingCriteria): |
|
def __init__(self, str): |
|
self.keyword_ids = tokenizer.encode(str) |
|
self.keyword_ids = map_tids(old2new, self.keyword_ids) |
|
self.keyword_len = len(self.keyword_ids) |
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: |
|
last_token_ids = input_ids[0][-self.keyword_len:] |
|
return last_token_ids.tolist() == self.keyword_ids |
|
|
|
stop_criteria_list = transformers.StoppingCriteriaList( |
|
[ KeywordsStoppingCriteria(x) for x in STOP_WORDS ] |
|
) |
|
|
|
|
|
def chat(q, temperature = 0.5): |
|
prompt = f"<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant" |
|
old_tids = tokenizer.encode(prompt) |
|
|
|
new_tids = map_tids(old2new, old_tids) |
|
new_old_tids = map_tids(new2old, new_tids) |
|
|
|
new_prompt = tokenizer.decode(new_old_tids) |
|
|
|
|
|
|
|
|
|
|
|
|
|
inputs = tokenizer(new_prompt, return_tensors="pt").to(model.device) |
|
|
|
assert inputs["input_ids"][0].tolist() == new_old_tids |
|
|
|
for i, x in enumerate(new_tids): |
|
inputs["input_ids"][0][i] = x |
|
|
|
with torch.no_grad(): |
|
output_ids = model.generate( |
|
**inputs, |
|
max_new_tokens=1024*4, |
|
temperature=0.1, |
|
top_p=1.0, top_k=30, do_sample=True, |
|
repetition_penalty=1.3, |
|
stopping_criteria=stop_criteria_list, |
|
pad_token_id=tokenizer.pad_token_id, |
|
) |
|
|
|
answer_tids = output_ids[0][len(inputs["input_ids"][0]) : ] |
|
old_tids = map_tids(new2old, answer_tids.tolist()) |
|
return tokenizer.decode(old_tids).split("<|im_end|>")[0].strip() |
|
|
|
|
|
envi = """ |
|
Không cần giải thích, giữ nguyên các từ viết tắt, các ký hiệu, và dịch đoạn văn sau sang tiếng Việt: |
|
|
|
Ví dụ 1: |
|
<|en|> Most languages have been developed using the same alphabet because of the popularity and prevalence of the latin-based English Alphabet. This alphabet is estimated to be used by around 2 billion people, and is used by many European, romance, African and Vietnamese languages. |
|
<|vi|> Hầu hết các ngôn ngữ được phát triển sử dụng cùng một bảng chữ cái do sự phổ biến và thịnh hành của bảng chữ cái tiếng Anh dựa trên hệ Latin. Bảng chữ cái này ước tính được khoảng 2 tỷ người sử dụng[4], và được dùng trong nhiều ngôn ngữ châu Âu, ngôn ngữ lãng mạn, châu Phi và tiếng Việt. |
|
|
|
Ví dụ 2: |
|
<|en|> Do you have any fun expressions in your language to say you forget something? Share them in the comments below! |
|
<|vi|> Bạn có câu nói vui nào trong ngôn ngữ của mình để diễn tả việc quên điều gì đó không? Hãy chia sẻ trong phần bình luận bên dưới! |
|
|
|
Ví dụ 3: |
|
<|en|> What is the scientific explanation for making us feel "cuteness" when we see something cute? |
|
<|vi|> Giải thích khoa học về việc tại sao chúng ta cảm thấy "dễ thương" khi nhìn thấy thứ gì đó dễ thương là gì? |
|
|
|
Không cần giải thích, giữ nguyên các từ viết tắt, các ký hiệu, và dịch đoạn văn sau sang tiếng Việt: |
|
<|en|> {english} |
|
<|vi|> |
|
""".strip() |
|
|
|
|
|
junks = """ |
|
Câu trả lời của tôi: |
|
sang tiếng Việt: |
|
sang tiếng Việt là: |
|
dịch tiếng Việt: |
|
dịch tiếng Việt là: |
|
tiếng Việt như sau: |
|
sang tiếng Việt sẽ là: |
|
tiếng Việt của đoạn văn: |
|
tiếng Việt của câu hỏi là: |
|
tiếng Việt của câu trên là: |
|
tiếng Việt của đoạn văn là: |
|
tiếng Việt của đoạn văn trên: |
|
tiếng Việt của đoạn văn như sau: |
|
tiếng Việt của đoạn văn trên là: |
|
dịch đoạn văn sau sang tiếng Việt: |
|
tiếng Việt của đoạn văn bạn yêu cầu: |
|
Bây giờ đến lượt bạn: |
|
dịch sang tiếng Việt là |
|
<|en|> |
|
<|vi|> |
|
""".strip().split("\n") |
|
|
|
|
|
|
|
def trans(prompt, temperinit = 0.2): |
|
print("\n- - - - - -\n") |
|
print(prompt, "\n==>\n" ) |
|
|
|
res = trans_(prompt, temperinit) |
|
|
|
print(res, flush = True) |
|
return res |
|
|
|
|
|
def trans_(prompt, temperinit = 0.2): |
|
|
|
if not isinstance(prompt, str): |
|
return prompt |
|
|
|
if len(prompt) < 8: |
|
return prompt |
|
|
|
trials = max_trials = 3 |
|
temperature = temperinit |
|
temperdelta = 0.2 |
|
|
|
while trials > 0: |
|
trials -= 1 |
|
n = max_trials - trials |
|
|
|
if n > 1: |
|
temperature += temperdelta |
|
print(f"\033[91m{prompt}\033[0m => {x}") |
|
print(f"\033[33mThử lại lần {n}\033[0m") |
|
|
|
x = trans__(prompt, temperature = temperature).strip() |
|
|
|
if x is not None and len(x) > 0: |
|
|
|
for j in junks: |
|
x = x.split(j.strip())[-1].strip() |
|
|
|
pp = prompt.lower() |
|
if "tiếng việt" in pp or "vietnamese" in pp: |
|
return x |
|
|
|
xx = x.lower() |
|
if "tiếng việt" not in xx: |
|
return x |
|
|
|
|
|
def trans__(prompt, temperature = 0.0): |
|
|
|
|
|
|
|
prompt = envi.format(english = prompt) |
|
res = chat(prompt, temperature = temperature) |
|
|
|
|
|
return res |
|
|
|
|
|
|
|
infile = sys.argv[1] |
|
outfile = infile.replace(".jsonl.xz", "__vi.jsonl") |
|
|
|
|
|
if os.path.exists(outfile): |
|
sources = [ json.loads(line)['source'] for line in open(outfile, "rt") ] |
|
else: |
|
sources = [] |
|
|
|
print(len(sources), sources[-1] if len(sources) > 0 else None) |
|
|
|
|
|
|
|
for idx, line in enumerate(lzma.open(infile, "rt")): |
|
|
|
source = f"{infile}:{idx}" |
|
if source in sources: continue |
|
print(source) |
|
|
|
data = json.loads(line) |
|
|
|
data["query"] = trans(data['query']) |
|
if data["query"] is None: continue |
|
|
|
for idx, x in enumerate( data["pos"] ): |
|
data['pos'][idx] = trans(x) |
|
if data['pos'][idx] is None: break |
|
|
|
if data['pos'][idx] is None: continue |
|
|
|
|
|
for idx, x in enumerate( data["neg"] ): |
|
data['neg'][idx] = trans(x) |
|
if data['neg'][idx] is None: break |
|
|
|
if data['neg'][idx] is None: continue |
|
|
|
with open(outfile, "at") as f: |
|
data["source"] = source |
|
f.write(json.dumps(data, ensure_ascii = False) + "\n") |
|
|