oschamp commited on
Commit
9aa63b6
1 Parent(s): a7bf3db

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +26 -0
README.md CHANGED
@@ -76,3 +76,29 @@ The following hyperparameters were used during training:
76
  - Pytorch 1.13.1+cu117
77
  - Datasets 2.9.0
78
  - Tokenizers 0.13.2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  - Pytorch 1.13.1+cu117
77
  - Datasets 2.9.0
78
  - Tokenizers 0.13.2
79
+
80
+ ### Code to Run
81
+
82
+ from transformers import ViTFeatureExtractor
83
+ from transformers import ViTForImageClassification
84
+ import torch
85
+
86
+ vit = ViTForImageClassification.from_pretrained("oschamp/vit-artworkclassifier")
87
+ vit.eval()
88
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
89
+ vit.to(device)
90
+
91
+ model_name_or_path = 'google/vit-base-patch16-224-in21k'
92
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
93
+
94
+ def vit_classify(image):
95
+ encoding = feature_extractor(images=image, return_tensors="pt")
96
+ encoding.keys()
97
+
98
+ pixel_values = encoding['pixel_values'].to(device)
99
+
100
+ outputs = vit(pixel_values)
101
+ logits = outputs.logits
102
+
103
+ prediction = logits.argmax(-1)
104
+ return prediction.item() #vit.config.id2label[prediction.item()]