Pierce Maloney commited on
Commit
b179918
1 Parent(s): 6d8b690

do_sample=True, and model.to('cuda')

Browse files
Files changed (1) hide show
  1. handler.py +3 -2
handler.py CHANGED
@@ -6,7 +6,7 @@ class EndpointHandler():
6
  def __init__(self, path=""):
7
  tokenizer = AutoTokenizer.from_pretrained(path)
8
  tokenizer.pad_token = tokenizer.eos_token
9
- self.model = AutoModelForCausalLM.from_pretrained(path)
10
  self.tokenizer = tokenizer
11
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
12
 
@@ -29,7 +29,7 @@ class EndpointHandler():
29
  bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
30
  bad_words_ids.extend(additional_bad_words_ids)
31
 
32
- input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
33
  max_generation_length = 75 # Desired number of tokens to generate
34
  max_input_length = 4092 - max_generation_length # Maximum input length to allow space for generation
35
 
@@ -45,6 +45,7 @@ class EndpointHandler():
45
  bad_words_ids=bad_words_ids,
46
  temperature=1,
47
  top_k=40,
 
48
  stopping_criteria=self.stopping_criteria,
49
  )
50
 
 
6
  def __init__(self, path=""):
7
  tokenizer = AutoTokenizer.from_pretrained(path)
8
  tokenizer.pad_token = tokenizer.eos_token
9
+ self.model = AutoModelForCausalLM.from_pretrained(path).to('cuda')
10
  self.tokenizer = tokenizer
11
  self.stopping_criteria = StoppingCriteriaList([StopAtPeriodCriteria(tokenizer)])
12
 
 
29
  bad_words_ids = [[3070], [313, 334], [10456], [13], [1976, 441, 29889], [2087, 29885, 4430, 29889], [4920, 441], [4920, 441, 29889], [4920, 18054, 29889]]
30
  bad_words_ids.extend(additional_bad_words_ids)
31
 
32
+ input_ids = self.tokenizer.encode(inputs, return_tensors="pt").to('cuda')
33
  max_generation_length = 75 # Desired number of tokens to generate
34
  max_input_length = 4092 - max_generation_length # Maximum input length to allow space for generation
35
 
 
45
  bad_words_ids=bad_words_ids,
46
  temperature=1,
47
  top_k=40,
48
+ do_sample=True,
49
  stopping_criteria=self.stopping_criteria,
50
  )
51