Pierce Maloney
committed on
Commit
•
6d8b690
1
Parent(s):
4c4f932
truncating earliest tokens if longer than 4092
Browse files- handler.py +9 -3
handler.py
CHANGED
@@ -4,7 +4,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, Stopping
|
|
4 |
|
5 |
class EndpointHandler():
|
6 |
def __init__(self, path=""):
|
7 |
-
# Preload all the elements you are going to need at inference.
|
8 |
tokenizer = AutoTokenizer.from_pretrained(path)
|
9 |
tokenizer.pad_token = tokenizer.eos_token
|
10 |
self.model = AutoModelForCausalLM.from_pretrained(path)
|
@@ -31,11 +30,18 @@ class EndpointHandler():
|
|
31 |
bad_words_ids.extend(additional_bad_words_ids)
|
32 |
|
33 |
input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
|
|
|
|
|
34 |
|
35 |
-
#
|
|
|
|
|
|
|
|
|
|
|
36 |
generated_ids = self.model.generate(
|
37 |
input_ids,
|
38 |
-
max_length=
|
39 |
bad_words_ids=bad_words_ids,
|
40 |
temperature=1,
|
41 |
top_k=40,
|
|
|
4 |
|
5 |
class EndpointHandler():
|
6 |
def __init__(self, path=""):
|
|
|
7 |
tokenizer = AutoTokenizer.from_pretrained(path)
|
8 |
tokenizer.pad_token = tokenizer.eos_token
|
9 |
self.model = AutoModelForCausalLM.from_pretrained(path)
|
|
|
30 |
bad_words_ids.extend(additional_bad_words_ids)
|
31 |
|
32 |
input_ids = self.tokenizer.encode(inputs, return_tensors="pt")
|
33 |
+
max_generation_length = 75 # Desired number of tokens to generate
|
34 |
+
max_input_length = 4092 - max_generation_length # Maximum input length to allow space for generation
|
35 |
|
36 |
+
# Truncate input_ids to the most recent tokens that fit within the max_input_length
|
37 |
+
if input_ids.shape[1] > max_input_length:
|
38 |
+
input_ids = input_ids[:, -max_input_length:]
|
39 |
+
|
40 |
+
max_length = input_ids.shape[1] + max_generation_length
|
41 |
+
|
42 |
generated_ids = self.model.generate(
|
43 |
input_ids,
|
44 |
+
max_length=max_length, # 50 new tokens
|
45 |
bad_words_ids=bad_words_ids,
|
46 |
temperature=1,
|
47 |
top_k=40,
|