OssamaLafhel commited on
Commit
fb48465
1 Parent(s): 3e86372

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -1
handler.py CHANGED
@@ -161,8 +161,11 @@ class EndpointHandler:
161
  # load the model
162
  tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
163
  model = GPTJForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)
164
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
165
  model.to(device)
 
166
  # create inference pipeline
167
  self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
168
 
 
161
  # load the model
162
  tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
163
  model = GPTJForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)
164
+
165
+ # check for GPU
166
+ device = 0 if torch.cuda.is_available() else -1
167
  model.to(device)
168
+
169
  # create inference pipeline
170
  self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
171