oschamp commited on
Commit
4375df4
1 Parent(s): bc0d2d6

update model card README.md

Browse files
Files changed (1) hide show
  1. README.md +16 -37
README.md CHANGED
@@ -21,7 +21,7 @@ model-index:
21
  metrics:
22
  - name: Accuracy
23
  type: accuracy
24
- value: 0.4887640449438202
25
  ---
26
 
27
  <!-- This model card has been generated automatically according to the information the Trainer had access to. You
@@ -29,10 +29,10 @@ should probably proofread and complete it, then remove this comment. -->
29
 
30
  # vit-artworkclassifier
31
 
32
- This model is a fine-tuned version of [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k) on the imagefolder dataset, a subset of the artbench-10 dataset. Train set size 1800, test set size 180, split equally over the 9 classes.
33
  It achieves the following results on the evaluation set:
34
- - Loss: 1.3363
35
- - Accuracy: 0.4888
36
 
37
  ## Model description
38
 
@@ -57,17 +57,24 @@ The following hyperparameters were used during training:
57
  - seed: 42
58
  - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
59
  - lr_scheduler_type: linear
60
- - num_epochs: 8
61
  - mixed_precision_training: Native AMP
62
 
63
  ### Training results
64
 
65
  | Training Loss | Epoch | Step | Validation Loss | Accuracy |
66
  |:-------------:|:-----:|:----:|:---------------:|:--------:|
67
- | 1.4136 | 1.79 | 100 | 1.5093 | 0.5112 |
68
- | 0.7189 | 3.57 | 200 | 1.3363 | 0.4888 |
69
- | 0.2717 | 5.36 | 300 | 1.4907 | 0.5281 |
70
- | 0.1227 | 7.14 | 400 | 1.4826 | 0.5562 |
 
 
 
 
 
 
 
71
 
72
 
73
  ### Framework versions
@@ -76,31 +83,3 @@ The following hyperparameters were used during training:
76
  - Pytorch 1.13.1+cu117
77
  - Datasets 2.9.0
78
  - Tokenizers 0.13.2
79
-
80
- ### Code to Run
81
-
82
- def vit_classify(image):
83
- from transformers import ViTFeatureExtractor
84
- from transformers import ViTForImageClassification
85
- import torch
86
-
87
- vit = ViTForImageClassification.from_pretrained("oschamp/vit-artworkclassifier")
88
- vit.eval()
89
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
90
- vit.to(device)
91
-
92
- model_name_or_path = 'google/vit-base-patch16-224-in21k'
93
- feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
94
-
95
- #LOAD IMAGE
96
-
97
- encoding = feature_extractor(images=image, return_tensors="pt")
98
- encoding.keys()
99
-
100
- pixel_values = encoding['pixel_values'].to(device)
101
-
102
- outputs = vit(pixel_values)
103
- logits = outputs.logits
104
-
105
- prediction = logits.argmax(-1)
106
- return prediction.item() #vit.config.id2label[prediction.item()]
 
21
  metrics:
22
  - name: Accuracy
23
  type: accuracy
24
+ value: 0.5947786606129398
25
  ---
26
 
27
  <!-- This model card has been generated automatically according to the information the Trainer had access to. You
 
29
 
30
  # vit-artworkclassifier
31
 
32
+ This model is a fine-tuned version of [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k) on the imagefolder dataset.
33
  It achieves the following results on the evaluation set:
34
+ - Loss: 1.1392
35
+ - Accuracy: 0.5948
36
 
37
  ## Model description
38
 
 
57
  - seed: 42
58
  - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
59
  - lr_scheduler_type: linear
60
+ - num_epochs: 4
61
  - mixed_precision_training: Native AMP
62
 
63
  ### Training results
64
 
65
  | Training Loss | Epoch | Step | Validation Loss | Accuracy |
66
  |:-------------:|:-----:|:----:|:---------------:|:--------:|
67
+ | 1.5906 | 0.36 | 100 | 1.4709 | 0.4847 |
68
+ | 1.3395 | 0.72 | 200 | 1.3208 | 0.5074 |
69
+ | 1.1461 | 1.08 | 300 | 1.3363 | 0.5165 |
70
+ | 0.9593 | 1.44 | 400 | 1.1790 | 0.5846 |
71
+ | 0.8761 | 1.8 | 500 | 1.1252 | 0.5902 |
72
+ | 0.5922 | 2.16 | 600 | 1.1392 | 0.5948 |
73
+ | 0.4803 | 2.52 | 700 | 1.1560 | 0.5936 |
74
+ | 0.4454 | 2.88 | 800 | 1.1545 | 0.6118 |
75
+ | 0.2271 | 3.24 | 900 | 1.2284 | 0.6039 |
76
+ | 0.207 | 3.6 | 1000 | 1.2625 | 0.5959 |
77
+ | 0.1958 | 3.96 | 1100 | 1.2621 | 0.6005 |
78
 
79
 
80
  ### Framework versions
 
83
  - Pytorch 1.13.1+cu117
84
  - Datasets 2.9.0
85
  - Tokenizers 0.13.2