fix device

#1
by ariG23498 - opened
Files changed (1)
  1. README.md +3 -5
README.md CHANGED
@@ -239,19 +239,17 @@ Then, copy the code snippet below to run the example.
 
 ```python
 from transformers import AutoModelForCausalLM, AutoTokenizer
-device = "auto"
 model_path = "ibm-granite/granite-3.0-3b-a800m-base"
 tokenizer = AutoTokenizer.from_pretrained(model_path)
 # drop device_map if running on CPU
-model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
+model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
 model.eval()
 # change input text as desired
 input_text = "Where is the MIT-IBM Watson AI Lab located?"
 # tokenize the text
-input_tokens = tokenizer(input_text, return_tensors="pt").to(device)
+input_tokens = tokenizer(input_text, return_tensors="pt").to(model.device)
 # generate output tokens
-output = model.generate(**input_tokens,
-max_length=4000)
+output = model.generate(**input_tokens, max_length=4000)
 # decode output tokens into text
 output = tokenizer.batch_decode(output)
 # print output
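
For reference, here is how the snippet reads with the change applied — a minimal runnable sketch, assuming `accelerate` is installed so `device_map="auto"` can place the weights (drop it on CPU, per the comment), and assuming a final `print(output)` that falls just past the end of the hunk. The bug being fixed: `"auto"` is only valid as a `device_map` value for `from_pretrained`, not as a torch device, so the old `.to(device)` call with `device = "auto"` raises a RuntimeError at tokenization time. Pointing the inputs at `model.device` instead keeps them wherever the weights actually landed.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "ibm-granite/granite-3.0-3b-a800m-base"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# drop device_map if running on CPU; "auto" needs the accelerate package
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
model.eval()
# change input text as desired
input_text = "Where is the MIT-IBM Watson AI Lab located?"
# tokenize the text and move it to the device the weights were placed on
input_tokens = tokenizer(input_text, return_tensors="pt").to(model.device)
# generate output tokens
output = model.generate(**input_tokens, max_length=4000)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# print output (this line sits past the end of the diff hunk; assumed)
print(output)
```

Using `model.device` also stays reasonable if `device_map="auto"` shards the model across several GPUs: it reports the device of the first parameters, which is generally where the first layers, and hence the input ids, need to be.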