Pierce Maloney committed

Commit 355a0ec
Parent(s): a36be93

bugfix

handler.py CHANGED (+3 -5)
@@ -53,13 +53,11 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, Stopping
     # prediction = [{"generated_text": generated_text, "generated_ids": generated_ids[0][input_ids.shape[1]:].tolist()}]
     # return prediction

-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-
 class EndpointHandler():
     def __init__(self, path=""):
         self.model_path = path
         tokenizer = AutoTokenizer.from_pretrained(path)
-        tokenizer.pad_token =
+        tokenizer.pad_token = tokenizer.eos_token
         self.tokenizer = tokenizer
         # Initialize the pipeline for text generation
         self.text_generation_pipeline = pipeline("text-generation", model=path, tokenizer=self.tokenizer, device=0) # device=0 for CUDA
@@ -82,10 +80,10 @@
         # Generate text using the pipeline
         generation_kwargs = {
             "max_length": 75, # Adjust as needed
-            "temperature":
+            "temperature": 0.7,
             "top_k": 40,
             "bad_words_ids": bad_words_ids,
-            "pad_token_id": self.tokenizer.eos_token_id # Ensure padding with EOS token
+            # "pad_token_id": self.tokenizer.eos_token_id # Ensure padding with EOS token
         }
         generated_outputs = self.text_generation_pipeline(inputs, **generation_kwargs)
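For context, a minimal sketch of the handler as it stands after this commit. Only the lines inside the two hunks come from the diff; the __call__ wrapper, the request parsing, and the example bad_words_ids are assumptions added to make the sketch self-contained and runnable, not the author's actual code.

from transformers import AutoTokenizer, pipeline

class EndpointHandler():
    def __init__(self, path=""):
        self.model_path = path
        tokenizer = AutoTokenizer.from_pretrained(path)
        # GPT-style checkpoints ship without a pad token; reusing EOS as the
        # pad token (the line this commit fixes) lets generation pad cleanly.
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer
        # Initialize the pipeline for text generation (device=0 for CUDA)
        self.text_generation_pipeline = pipeline(
            "text-generation", model=path, tokenizer=self.tokenizer, device=0
        )

    def __call__(self, data):
        # Assumption: Inference Endpoints handlers receive a dict with "inputs".
        inputs = data.get("inputs", "")
        # Hypothetical list; the real bad_words_ids are built outside the hunks.
        bad_words_ids = self.tokenizer(
            ["badword"], add_special_tokens=False
        ).input_ids
        generation_kwargs = {
            "max_length": 75,  # Adjust as needed
            "temperature": 0.7,
            "top_k": 40,
            "bad_words_ids": bad_words_ids,
            # "pad_token_id" is commented out by this commit: setting
            # tokenizer.pad_token above already covers padding.
        }
        return self.text_generation_pipeline(inputs, **generation_kwargs)

Locally this can be exercised the way the endpoint runtime would call it, e.g. handler = EndpointHandler(path="gpt2") followed by handler({"inputs": "Once upon a time"}), where gpt2 is a stand-in checkpoint. Note that temperature and top_k only take effect when sampling is enabled; the diff shows no do_sample flag, so greedy decoding would ignore both unless it is set elsewhere.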