tiendung commited on
Commit
85e407f
1 Parent(s): 81adcf0
Files changed (2) hide show
  1. model_chat.py +1 -1
  2. translate.py +222 -0
model_chat.py CHANGED
@@ -2,7 +2,7 @@ import torch, sys
2
  import transformers
3
 
4
  try: model_path = sys.argv[1]
5
- except: model_path = "e2.0"
6
 
7
  print(f"Loading {model_path} ...")
8
 
 
2
  import transformers
3
 
4
  try: model_path = sys.argv[1]
5
+ except: model_path = "e3.0"
6
 
7
  print(f"Loading {model_path} ...")
8
 
translate.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import json, lzma, glob, sys, os, re, subprocess
3
+ from pprint import pprint
4
+
5
+ import torch, sys
6
+ import transformers
7
+
8
+ model_path = "e3.0"
9
+ print(f"Loading {model_path} ...")
10
+
11
+ model = transformers.AutoModelForCausalLM.from_pretrained(
12
+ model_path,
13
+ device_map = "auto",
14
+ torch_dtype = torch.bfloat16,
15
+ )
16
+ tokenizer = transformers.AutoTokenizer.from_pretrained(".")
17
+
18
+ from qwen_vocab import old2new, new2old
19
+ STOP_WORDS = "<|im_end|> <|endoftext|>".split()
20
+
21
+
22
+ def map_tids(map_dict, tids):
23
+ return [ map_dict[x] for x in tids if x in map_dict ]
24
+
25
+
26
+ class KeywordsStoppingCriteria(transformers.StoppingCriteria):
27
+ def __init__(self, str):
28
+ self.keyword_ids = tokenizer.encode(str)
29
+ self.keyword_ids = map_tids(old2new, self.keyword_ids)
30
+ self.keyword_len = len(self.keyword_ids)
31
+
32
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
33
+ last_token_ids = input_ids[0][-self.keyword_len:]
34
+ return last_token_ids.tolist() == self.keyword_ids
35
+
36
+ stop_criteria_list = transformers.StoppingCriteriaList(
37
+ [ KeywordsStoppingCriteria(x) for x in STOP_WORDS ]
38
+ )
39
+
40
+
41
+ def chat(prompt, temperature = 0.5):
42
+ prompt = f"<|im_start|>user\n{q}<|im_end|>\n<|im_start|>assistant"
43
+ old_tids = tokenizer.encode(prompt)
44
+
45
+ new_tids = map_tids(old2new, old_tids)
46
+ new_old_tids = map_tids(new2old, new_tids)
47
+
48
+ new_prompt = tokenizer.decode(new_old_tids)
49
+
50
+ if new_old_tids != old_tids:
51
+ print(f"!!! Cảnh báo sự trimm vocab làm mất thông tin !!!")
52
+ print(f"!!! old prompt: {prompt}")
53
+ print(f"!!! new prompt: {new_prompt}")
54
+
55
+ inputs = tokenizer(new_prompt, return_tensors="pt").to(model.device)
56
+
57
+ assert inputs["input_ids"][0].tolist() == new_old_tids
58
+
59
+ for i, x in enumerate(new_tids):
60
+ inputs["input_ids"][0][i] = x
61
+
62
+ with torch.no_grad():
63
+ output_ids = model.generate(
64
+ **inputs,
65
+ max_new_tokens=1024*4,
66
+ temperature=temperature,
67
+ top_p=1.0, top_k=30, do_sample=True,
68
+ repetition_penalty=1.1,
69
+ stopping_criteria=stop_criteria_list,
70
+ pad_token_id=tokenizer.pad_token_id,
71
+ )
72
+
73
+ answer_tids = output_ids[0][len(inputs["input_ids"][0]) : ] # bỏ đi prompt tokens
74
+ old_tids = map_tids(new2old, answer_tids.tolist())
75
+ return tokenizer.decode(old_tids).split("<|im_end|>")[0].strip()
76
+
77
+
78
+ envi = """
79
+ Không cần giải thích, giữ nguyên các từ viết tắt, các ký hiệu, và dịch đoạn văn sau sang tiếng Việt:
80
+
81
+ Ví dụ 1:
82
+ <|en|> Most languages have been developed using the same alphabet because of the popularity and prevalence of the latin-based English Alphabet. This alphabet is estimated to be used by around 2 billion people, and is used by many European, romance, African and Vietnamese languages.
83
+ <|vi|> Hầu hết các ngôn ngữ được phát triển sử dụng cùng một bảng chữ cái do sự phổ biến và thịnh hành của bảng chữ cái tiếng Anh dựa trên hệ Latin. Bảng chữ cái này ước tính được khoảng 2 tỷ người sử dụng[4], và được dùng trong nhiều ngôn ngữ châu Âu, ngôn ngữ lãng mạn, châu Phi và tiếng Việt.
84
+
85
+ Ví dụ 2:
86
+ <|en|> Do you have any fun expressions in your language to say you forget something? Share them in the comments below!
87
+ <|vi|> Bạn có câu nói vui nào trong ngôn ngữ của mình để diễn tả việc quên điều gì đó không? Hãy chia sẻ trong phần bình luận bên dưới!
88
+
89
+ Ví dụ 3:
90
+ <|en|> What is the scientific explanation for making us feel "cuteness" when we see something cute?
91
+ <|vi|> Giải thích khoa học về việc tại sao chúng ta cảm thấy "dễ thương" khi nhìn thấy thứ gì đó dễ thương là gì?
92
+
93
+ Không cần giải thích, giữ nguyên các từ viết tắt, các ký hiệu, và dịch đoạn văn sau sang tiếng Việt:
94
+ <|en|> {english}
95
+ <|vi|>
96
+ """.strip()
97
+
98
+
99
+ junks = """
100
+ Câu trả lời của tôi:
101
+ sang tiếng Việt:
102
+ sang tiếng Việt là:
103
+ dịch tiếng Việt:
104
+ dịch tiếng Việt là:
105
+ tiếng Việt như sau:
106
+ sang tiếng Việt sẽ là:
107
+ tiếng Việt của đoạn văn:
108
+ tiếng Việt của câu hỏi là:
109
+ tiếng Việt của câu trên là:
110
+ tiếng Việt của đoạn văn là:
111
+ tiếng Việt của đoạn văn trên:
112
+ tiếng Việt của đoạn văn như sau:
113
+ tiếng Việt của đoạn văn trên là:
114
+ dịch đoạn văn sau sang tiếng Việt:
115
+ tiếng Việt của đoạn văn bạn yêu cầu:
116
+ Bây giờ đến lượt bạn:
117
+ dịch sang tiếng Việt là
118
+ <|en|>
119
+ <|vi|>
120
+ """.strip().split("\n")
121
+
122
+ # print(junks)
123
+
124
+ def trans(prompt, temperinit = 0.2):
125
+ print("\n- - - - - -\n")
126
+ print(prompt, "\n==>\n" )
127
+
128
+ res = trans_(prompt, temperinit)
129
+
130
+ print(res, flush = True)
131
+ return res
132
+
133
+
134
+ def trans_(prompt, temperinit = 0.2):
135
+
136
+ if not isinstance(prompt, str):
137
+ return prompt
138
+
139
+ if len(prompt) < 8:
140
+ return prompt
141
+
142
+ trials = max_trials = 3
143
+ temperature = temperinit
144
+ temperdelta = 0.2
145
+
146
+ while trials > 0:
147
+ trials -= 1
148
+ n = max_trials - trials
149
+
150
+ if n > 1:
151
+ temperature += temperdelta
152
+ print(f"\033[91m{prompt}\033[0m => {x}") # Red then reset
153
+ print(f"\033[33mThử lại lần {n}\033[0m") # Yellow then reset
154
+
155
+ x = trans__(prompt, temperature = temperature).strip()
156
+
157
+ if x is not None and len(x) > 0:
158
+
159
+ for j in junks: # Loại bỏ những header thừa
160
+ x = x.split(j.strip())[-1].strip()
161
+
162
+ pp = prompt.lower()
163
+ if "tiếng việt" in pp or "vietnamese" in pp:
164
+ return x
165
+
166
+ xx = x.lower()
167
+ if "tiếng việt" not in xx:
168
+ return x
169
+
170
+
171
+ def trans__(prompt, temperature = 0.0):
172
+ # print("\n- - - - - -\n")
173
+ # print(prompt, "\n==>\n")
174
+
175
+ prompt = envi.format(english = prompt)
176
+ res = chat(prompt, temperature = temperature)
177
+
178
+ # print(res)
179
+ return res
180
+
181
+
182
+ # infile = args.input
183
+ infile = sys.argv[1]
184
+ outfile = infile.replace(".jsonl.xz", "__vi.jsonl")
185
+
186
+
187
+ if os.path.exists(outfile):
188
+ sources = [ json.loads(line)['source'] for line in open(outfile, "rt") ]
189
+ else:
190
+ sources = []
191
+
192
+ print(len(sources), sources[-1] if len(sources) > 0 else None)
193
+
194
+
195
+
196
+ for idx, line in enumerate(lzma.open(infile, "rt")):
197
+
198
+ source = f"{infile}:{idx}"
199
+ if source in sources: continue
200
+ print(source)
201
+
202
+ data = json.loads(line)
203
+
204
+ data["query"] = trans(data['query'])
205
+ if data["query"] is None: continue
206
+
207
+ for idx, x in enumerate( data["pos"] ):
208
+ data['pos'][idx] = trans(x)
209
+ if data['pos'][idx] is None: break
210
+
211
+ if data['pos'][idx] is None: continue
212
+
213
+
214
+ for idx, x in enumerate( data["neg"] ):
215
+ data['neg'][idx] = trans(x)
216
+ if data['neg'][idx] is None: break
217
+
218
+ if data['neg'][idx] is None: continue
219
+
220
+ with open(outfile, "at") as f:
221
+ data["source"] = source
222
+ f.write(json.dumps(data, ensure_ascii = False) + "\n")