Update README.md
README.md CHANGED
@@ -85,11 +85,12 @@ res_prefix.append(EOS_TOKEN_ID)
 l_rp=len(res_prefix)
 
 tokenized=tokenizer.encode(plain_text,add_special_tokens=False,truncation=True,max_length=self.max_seq_length-2-l_rp)
-
 tokenized+=res_prefix
+batch=[tokenized]*2
+input_ids=torch.tensor(np.array(batch),dtype=torch.long)
 
 # Generate answer
-pred_ids = model.generate(input_ids=
+pred_ids = model.generate(input_ids=input_ids,max_new_tokens=self.max_target_length,do_sample=True,top_p=0.9)
 pred_tokens=tokenizer.batch_decode(pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 res=pred_tokens.replace('<extra_id_0>','').replace('有答案:','')
 ```
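For context: the hunk completes the previously truncated `model.generate(input_ids=` call. The tokenized prompt is duplicated into a two-row batch, converted to a `LongTensor`, and decoded with nucleus sampling (`do_sample=True`, `top_p=0.9`). Below is a minimal self-contained sketch of the patched snippet, assuming a T5-style seq2seq checkpoint; the checkpoint name, length limits, prompt text, and the `res_prefix` construction are assumptions, since only `res_prefix.append(EOS_TOKEN_ID)` and the strings stripped from the decoded output are visible in the hunk.

```python
# Self-contained sketch of the patched README snippet, assuming a T5-style
# seq2seq checkpoint. Checkpoint name, length limits, prompt, and the
# res_prefix construction are placeholders/assumptions, not repo values.
import numpy as np
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-t5-checkpoint")  # placeholder
model = AutoModelForSeq2SeqLM.from_pretrained("your-t5-checkpoint")
model.eval()

max_seq_length = 512     # assumed stand-in for self.max_seq_length
max_target_length = 64   # assumed stand-in for self.max_target_length

plain_text = "..."  # the text to answer over

# Only res_prefix.append(EOS_TOKEN_ID) is visible in the hunk; the sentinel
# token here is inferred from the '<extra_id_0>' stripped from the output.
res_prefix = tokenizer.encode("<extra_id_0>", add_special_tokens=False)
res_prefix.append(tokenizer.eos_token_id)
l_rp = len(res_prefix)

# Truncate the prompt, reserving room for the prefix and two special tokens.
tokenized = tokenizer.encode(
    plain_text,
    add_special_tokens=False,
    truncation=True,
    max_length=max_seq_length - 2 - l_rp,
)
tokenized += res_prefix

# Duplicate the prompt into a two-row batch, then lift it into a LongTensor
# for generate().
batch = [tokenized] * 2
input_ids = torch.tensor(np.array(batch), dtype=torch.long)

# Generate answer with nucleus sampling; only the first candidate is kept.
with torch.no_grad():
    pred_ids = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_target_length,
        do_sample=True,
        top_p=0.9,
    )
pred_tokens = tokenizer.batch_decode(
    pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
# Strip the sentinel and the literal answer prefix ("有答案:", "has answer:").
res = pred_tokens.replace("<extra_id_0>", "").replace("有答案:", "")
print(res)
```

Note that duplicating the prompt (`[tokenized] * 2`) only matters because sampling is enabled: each row is decoded independently, so the batch yields two candidate answers even though only the first is kept here.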