Pierce Maloney commited on
Commit
88fdb99
1 Parent(s): 2872c1c

returning ids

Browse files
Files changed (1) hide show
  1. handler.py +1 -2
handler.py CHANGED
@@ -37,9 +37,8 @@ class EndpointHandler():
37
 
38
  # Decode the generated ids to text
39
  # Exclude the input_ids length to get only the new tokens
40
- print("Generated IDs:", prediction_ids[0, input_ids.shape[1]:])
41
  prediction_text = self.tokenizer.decode(prediction_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
42
- return [{"generated_text": prediction_text}]
43
 
44
 
45
  class StopAtPeriodCriteria(StoppingCriteria):
 
37
 
38
  # Decode the generated ids to text
39
  # Exclude the input_ids length to get only the new tokens
 
40
  prediction_text = self.tokenizer.decode(prediction_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
41
+ return [{"generated_text": prediction_text, "ids": prediction_ids[0, input_ids.shape[1]:].tolist()}]
42
 
43
 
44
  class StopAtPeriodCriteria(StoppingCriteria):