Update README.md
Browse files
README.md
CHANGED
@@ -76,3 +76,29 @@ The following hyperparameters were used during training:
- Pytorch 1.13.1+cu117
- Datasets 2.9.0
- Tokenizers 0.13.2

### Code to Run

```python
from transformers import ViTFeatureExtractor
from transformers import ViTForImageClassification
import torch

vit = ViTForImageClassification.from_pretrained("oschamp/vit-artworkclassifier")
vit.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vit.to(device)

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

def vit_classify(image):
    encoding = feature_extractor(images=image, return_tensors="pt")
    encoding.keys()

    pixel_values = encoding['pixel_values'].to(device)

    outputs = vit(pixel_values)
    logits = outputs.logits

    prediction = logits.argmax(-1)
    return prediction.item()  # vit.config.id2label[prediction.item()]
```
|