Tonic committed on
Commit
a103a7f
β€’
1 Parent(s): 44eb742

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -13,17 +13,24 @@ model.to(device)
13
 
14
  def historical_generation(prompt, max_new_tokens=600):
15
  prompt = f"### Text ###\n{prompt}"
16
- input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
 
 
17
 
18
  # Generate text
19
- output = model.generate(input_ids,
20
- max_new_tokens=max_new_tokens,
21
- pad_token_id=tokenizer.eos_token_id,
22
- top_k=50,
23
- temperature=0.3,
24
- top_p=0.95,
25
- do_sample=True,
26
- repetition_penalty=1.5)
 
 
 
 
 
27
 
28
  # Decode the generated text
29
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
13
 
14
  def historical_generation(prompt, max_new_tokens=600):
15
  prompt = f"### Text ###\n{prompt}"
16
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
17
+ input_ids = inputs["input_ids"].to(device)
18
+ attention_mask = inputs["attention_mask"].to(device)
19
 
20
  # Generate text
21
+ output = model.generate(
22
+ input_ids,
23
+ attention_mask=attention_mask,
24
+ max_new_tokens=max_new_tokens,
25
+ pad_token_id=tokenizer.eos_token_id,
26
+ top_k=50,
27
+ temperature=0.3,
28
+ top_p=0.95,
29
+ do_sample=True,
30
+ repetition_penalty=1.5,
31
+ bos_token_id=tokenizer.bos_token_id,
32
+ eos_token_id=tokenizer.eos_token_id
33
+ )
34
 
35
  # Decode the generated text
36
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)