pfzhu awkrail commited on
Commit
b7f8796
1 Parent(s): ff0d639

Fix sample code by incorporating device into model, image, and text variables (#1)

Browse files

- Fix sample code by incorporating device into model, image, and text variables (26b45f800eed56011a5632aad3cd2d0e272c3f9c)


Co-authored-by: welix <awkrail@users.noreply.huggingface.co>

Files changed (1) hide show
  1. README.md +4 -4
README.md CHANGED
@@ -25,14 +25,14 @@ import torch
25
  from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
26
 
27
  HF_MODEL_PATH = 'line-corporation/clip-japanese-base'
 
28
  tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH, trust_remote_code=True)
29
  processor = AutoImageProcessor.from_pretrained(HF_MODEL_PATH, trust_remote_code=True)
30
- model = AutoModel.from_pretrained(HF_MODEL_PATH, trust_remote_code=True)
31
- device = "cuda" if torch.cuda.is_available() else "cpu"
32
 
33
  image = Image.open(io.BytesIO(requests.get('https://images.pexels.com/photos/2253275/pexels-photo-2253275.jpeg?auto=compress&cs=tinysrgb&dpr=3&h=750&w=1260').content))
34
- image = processor(image, return_tensors="pt")
35
- text = tokenizer(["犬", "猫", "象"])
36
 
37
  with torch.no_grad():
38
  image_features = model.get_image_features(**image)
 
25
  from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
26
 
27
  HF_MODEL_PATH = 'line-corporation/clip-japanese-base'
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
29
  tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH, trust_remote_code=True)
30
  processor = AutoImageProcessor.from_pretrained(HF_MODEL_PATH, trust_remote_code=True)
31
+ model = AutoModel.from_pretrained(HF_MODEL_PATH, trust_remote_code=True).to(device)
 
32
 
33
  image = Image.open(io.BytesIO(requests.get('https://images.pexels.com/photos/2253275/pexels-photo-2253275.jpeg?auto=compress&cs=tinysrgb&dpr=3&h=750&w=1260').content))
34
+ image = processor(image, return_tensors="pt").to(device)
35
+ text = tokenizer(["犬", "猫", "象"]).to(device)
36
 
37
  with torch.no_grad():
38
  image_features = model.get_image_features(**image)