Justcode commited on
Commit
18705d6
1 Parent(s): 8d455bd

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +3 -2
README.md CHANGED
@@ -85,11 +85,12 @@ res_prefix.append(EOS_TOKEN_ID)
85
  l_rp=len(res_prefix)
86
 
87
  tokenized=tokenizer.encode(plain_text,add_special_tokens=False,truncation=True,max_length=self.max_seq_length-2-l_rp)
88
-
89
  tokenized+=res_prefix
 
 
90
 
91
  # Generate answer
92
- pred_ids = model.generate(input_ids=tokenized,max_new_token=self.max_target_length,do_sample=True,top_p=0.9)
93
  pred_tokens=tokenizer.batch_decode(pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
94
  res=pred_tokens.replace('<extra_id_0>','').replace('有答案:','')
95
  ```
 
85
  l_rp=len(res_prefix)
86
 
87
  tokenized=tokenizer.encode(plain_text,add_special_tokens=False,truncation=True,max_length=self.max_seq_length-2-l_rp)
 
88
  tokenized+=res_prefix
89
+ batch=[tokenized]*2
90
+ input_ids=torch.tensor(np.array(batch),dtype=torch.long)
91
 
92
  # Generate answer
93
+ pred_ids = model.generate(input_ids=input_ids,max_new_token=self.max_target_length,do_sample=True,top_p=0.9)
94
  pred_tokens=tokenizer.batch_decode(pred_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
95
  res=pred_tokens.replace('<extra_id_0>','').replace('有答案:','')
96
  ```