File size: 712 Bytes
d66d160
8090b75
 
 
4625865
 
8090b75
d66d160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import ViTConfig, ViTForImageClassification
from transformers import ViTFeatureExtractor
from PIL import Image
import requests
import matplotlib.pyplot as plt



# Demo script: build a Vision Transformer classifier, create a matching
# feature extractor, then download a sample COCO image and save it locally.

# Option 1: build the model from a config only, i.e. with randomly
# initialized weights (train from scratch).
config = ViTConfig(num_hidden_layers=12, hidden_size=768)
model = ViTForImageClassification(config)

print(config)

# Feature extractor with library-default preprocessing settings.
# NOTE(review): recent transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor -- confirm against the installed version.
feature_extractor = ViTFeatureExtractor()

# Or, to load one that corresponds to a checkpoint on the hub (this rebinds
# `feature_extractor`, replacing the default instance created above):
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'

# A timeout keeps the script from hanging forever on a stalled server, and
# the context manager releases the streamed connection when we are done.
with requests.get(url, stream=True, timeout=30) as response:
    # Fail with a clear HTTP error instead of letting PIL choke on an
    # error page body.
    response.raise_for_status()
    image = Image.open(response.raw)
    # Image.open is lazy: force the pixel data to be read while the
    # underlying stream is still open.
    image.load()

image.save("cats.png")

# Bare expression: displays the image when run as the last line of a notebook
# cell; it has no effect when the file is executed as a plain script.
image