OssamaLafhel committed
Commit fb48465
1 Parent(s): 3e86372
Update handler.py
handler.py: +4 -1
handler.py CHANGED
@@ -161,8 +161,11 @@ class EndpointHandler:
         # load the model
         tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
         model = GPTJForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)
-
+
+        # check for GPU
+        device = 0 if torch.cuda.is_available() else -1
         model.to(device)
+
         # create inference pipeline
         self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
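For context, a minimal, self-contained sketch of what the handler's initialization looks like after this change. The EndpointHandler.__init__(self, path="") signature, the imports, and the CPU guard around model.to() are assumptions for illustration and not part of the commit, which calls model.to(device) unconditionally:

import torch
import transformers
from transformers import GPTJForCausalLM, pipeline


class EndpointHandler:
    def __init__(self, path=""):
        # load the model
        tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
        model = GPTJForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)

        # check for GPU: the transformers pipeline takes a CUDA index (0, 1, ...)
        # or -1 for CPU
        device = 0 if torch.cuda.is_available() else -1

        # assumption: only move the model explicitly when a GPU index is available;
        # the commit itself calls model.to(device) unconditionally
        if device >= 0:
            model.to(device)

        # create inference pipeline
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)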