---
base_model: openai/clip-vit-base-patch32
tags:
- generated_from_trainer
metrics:
- accuracy
model-index:
- name: outputs
  results: []
license: apache-2.0
datasets:
- Andron00e/CIFAR10-custom
language:
- en
library_name: transformers
---
# outputs

This model is a fine-tuned version of openai/clip-vit-base-patch32 on a custom CIFAR10 dataset (Andron00e/CIFAR10-custom). It achieves the following results on the evaluation set:
- Loss: 0.8115
- Accuracy: 0.8255
## Model description

## Training and evaluation data

## Training procedure

### Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.0002
- train_batch_size: 10
- eval_batch_size: 8
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- num_epochs: 4
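
With `lr_scheduler_type: linear`, the learning rate falls from 0.0002 to zero over the run. The sketch below shows that schedule, assuming no warmup (the card lists none) and 19,200 optimizer steps in total (the last step in the training results table, i.e. 4,800 steps per epoch). It mirrors the formula used by the linear schedule in transformers, but it is an illustration, not code taken from this training run.

```python
def linear_lr(step: int, base_lr: float = 2e-4, total_steps: int = 19_200,
              warmup_steps: int = 0) -> float:
    """Learning rate at `step`: linear warmup, then linear decay to zero.

    Defaults are assumptions read off this card: learning_rate=0.0002,
    19,200 total steps, and no warmup.
    """
    if step < warmup_steps:
        # Ramp up linearly from 0 to base_lr during warmup
        return base_lr * step / max(1, warmup_steps)
    # Decay linearly from base_lr to 0 at total_steps, never going negative
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


# Learning rate at the start of each epoch and at the final step
for step in (0, 4_800, 9_600, 14_400, 19_200):
    print(step, linear_lr(step))
```

Under these assumptions the rate is halved (1e-4) at the end of epoch 2 and reaches zero at step 19,200, which fits the validation loss flattening out over the last epoch.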
### Training results
Training Loss | Epoch | Step | Validation Loss | Accuracy |
---|---|---|---|---|
1.7258 | 0.02 | 100 | 1.6999 | 0.8048 |
1.669 | 0.04 | 200 | 1.6798 | 0.8055 |
1.6704 | 0.06 | 300 | 1.6599 | 0.8053 |
1.6655 | 0.08 | 400 | 1.6407 | 0.8047 |
1.5754 | 0.1 | 500 | 1.6223 | 0.809 |
1.6159 | 0.12 | 600 | 1.6040 | 0.8068 |
1.5663 | 0.15 | 700 | 1.5858 | 0.8073 |
1.5426 | 0.17 | 800 | 1.5677 | 0.8095 |
1.5794 | 0.19 | 900 | 1.5506 | 0.808 |
1.5504 | 0.21 | 1000 | 1.5342 | 0.8035 |
1.554 | 0.23 | 1100 | 1.5179 | 0.802 |
1.4831 | 0.25 | 1200 | 1.5022 | 0.7972 |
1.4718 | 0.27 | 1300 | 1.4867 | 0.7955 |
1.5206 | 0.29 | 1400 | 1.4716 | 0.796 |
1.4534 | 0.31 | 1500 | 1.4567 | 0.7963 |
1.3932 | 0.33 | 1600 | 1.4427 | 0.7875 |
1.4635 | 0.35 | 1700 | 1.4289 | 0.789 |
1.4339 | 0.38 | 1800 | 1.4151 | 0.793 |
1.4492 | 0.4 | 1900 | 1.4016 | 0.7973 |
1.4369 | 0.42 | 2000 | 1.3881 | 0.8018 |
1.4007 | 0.44 | 2100 | 1.3754 | 0.801 |
1.3697 | 0.46 | 2200 | 1.3627 | 0.8025 |
1.3298 | 0.48 | 2300 | 1.3505 | 0.8048 |
1.2809 | 0.5 | 2400 | 1.3386 | 0.8068 |
1.2989 | 0.52 | 2500 | 1.3272 | 0.8067 |
1.2958 | 0.54 | 2600 | 1.3159 | 0.81 |
1.3072 | 0.56 | 2700 | 1.3048 | 0.8097 |
1.2545 | 0.58 | 2800 | 1.2943 | 0.809 |
1.2722 | 0.6 | 2900 | 1.2834 | 0.8112 |
1.2628 | 0.62 | 3000 | 1.2732 | 0.8102 |
1.2357 | 0.65 | 3100 | 1.2632 | 0.8105 |
1.3189 | 0.67 | 3200 | 1.2532 | 0.8093 |
1.2465 | 0.69 | 3300 | 1.2436 | 0.8097 |
1.2579 | 0.71 | 3400 | 1.2342 | 0.8087 |
1.1963 | 0.73 | 3500 | 1.2249 | 0.8085 |
1.1701 | 0.75 | 3600 | 1.2159 | 0.8092 |
1.2117 | 0.77 | 3700 | 1.2069 | 0.8113 |
1.1907 | 0.79 | 3800 | 1.1984 | 0.8112 |
1.1903 | 0.81 | 3900 | 1.1902 | 0.8115 |
1.2357 | 0.83 | 4000 | 1.1821 | 0.8115 |
1.1924 | 0.85 | 4100 | 1.1738 | 0.8117 |
1.1914 | 0.88 | 4200 | 1.1657 | 0.8133 |
1.1536 | 0.9 | 4300 | 1.1580 | 0.8148 |
1.1893 | 0.92 | 4400 | 1.1505 | 0.8158 |
1.1811 | 0.94 | 4500 | 1.1433 | 0.8158 |
1.0182 | 0.96 | 4600 | 1.1358 | 0.8165 |
1.0396 | 0.98 | 4700 | 1.1287 | 0.8158 |
1.1502 | 1.0 | 4800 | 1.1217 | 0.816 |
1.1764 | 1.02 | 4900 | 1.1147 | 0.8158 |
1.1508 | 1.04 | 5000 | 1.1080 | 0.8152 |
1.0518 | 1.06 | 5100 | 1.1015 | 0.8155 |
1.0648 | 1.08 | 5200 | 1.0952 | 0.816 |
1.1631 | 1.1 | 5300 | 1.0889 | 0.8153 |
1.0629 | 1.12 | 5400 | 1.0826 | 0.8152 |
1.1151 | 1.15 | 5500 | 1.0771 | 0.815 |
1.1377 | 1.17 | 5600 | 1.0711 | 0.8145 |
1.0353 | 1.19 | 5700 | 1.0652 | 0.8158 |
1.068 | 1.21 | 5800 | 1.0594 | 0.815 |
1.0834 | 1.23 | 5900 | 1.0538 | 0.8162 |
1.0002 | 1.25 | 6000 | 1.0483 | 0.8165 |
1.0024 | 1.27 | 6100 | 1.0428 | 0.817 |
1.0609 | 1.29 | 6200 | 1.0376 | 0.817 |
1.0901 | 1.31 | 6300 | 1.0324 | 0.816 |
1.0772 | 1.33 | 6400 | 1.0275 | 0.8173 |
0.9434 | 1.35 | 6500 | 1.0226 | 0.817 |
0.9692 | 1.38 | 6600 | 1.0178 | 0.8157 |
1.0461 | 1.4 | 6700 | 1.0131 | 0.8155 |
1.0583 | 1.42 | 6800 | 1.0086 | 0.8143 |
0.9369 | 1.44 | 6900 | 1.0042 | 0.8157 |
1.0685 | 1.46 | 7000 | 0.9998 | 0.8152 |
1.062 | 1.48 | 7100 | 0.9955 | 0.8153 |
1.0394 | 1.5 | 7200 | 0.9912 | 0.8142 |
1.031 | 1.52 | 7300 | 0.9870 | 0.8157 |
0.9556 | 1.54 | 7400 | 0.9829 | 0.8155 |
0.9846 | 1.56 | 7500 | 0.9789 | 0.8152 |
0.9995 | 1.58 | 7600 | 0.9750 | 0.8158 |
1.0273 | 1.6 | 7700 | 0.9711 | 0.8163 |
0.9383 | 1.62 | 7800 | 0.9674 | 0.817 |
0.951 | 1.65 | 7900 | 0.9634 | 0.8163 |
0.9457 | 1.67 | 8000 | 0.9598 | 0.8167 |
1.012 | 1.69 | 8100 | 0.9563 | 0.816 |
0.9683 | 1.71 | 8200 | 0.9529 | 0.8158 |
0.9582 | 1.73 | 8300 | 0.9495 | 0.8157 |
0.9005 | 1.75 | 8400 | 0.9461 | 0.8162 |
0.888 | 1.77 | 8500 | 0.9428 | 0.8175 |
0.9267 | 1.79 | 8600 | 0.9396 | 0.8168 |
0.9298 | 1.81 | 8700 | 0.9364 | 0.8168 |
1.0072 | 1.83 | 8800 | 0.9334 | 0.8167 |
0.9425 | 1.85 | 8900 | 0.9303 | 0.8158 |
0.9729 | 1.88 | 9000 | 0.9273 | 0.8168 |
0.9104 | 1.9 | 9100 | 0.9244 | 0.8175 |
0.9153 | 1.92 | 9200 | 0.9216 | 0.817 |
0.9115 | 1.94 | 9300 | 0.9188 | 0.8165 |
0.9079 | 1.96 | 9400 | 0.9161 | 0.8168 |
0.8453 | 1.98 | 9500 | 0.9133 | 0.8175 |
0.8323 | 2.0 | 9600 | 0.9107 | 0.817 |
0.9071 | 2.02 | 9700 | 0.9080 | 0.8183 |
0.9331 | 2.04 | 9800 | 0.9054 | 0.8185 |
0.886 | 2.06 | 9900 | 0.9029 | 0.8193 |
0.8562 | 2.08 | 10000 | 0.9006 | 0.8183 |
0.8904 | 2.1 | 10100 | 0.8980 | 0.8193 |
0.8247 | 2.12 | 10200 | 0.8956 | 0.8188 |
0.8114 | 2.15 | 10300 | 0.8934 | 0.8202 |
0.96 | 2.17 | 10400 | 0.8912 | 0.8198 |
0.9326 | 2.19 | 10500 | 0.8889 | 0.8198 |
0.8057 | 2.21 | 10600 | 0.8867 | 0.8195 |
0.8266 | 2.23 | 10700 | 0.8846 | 0.8188 |
0.7909 | 2.25 | 10800 | 0.8823 | 0.82 |
0.886 | 2.27 | 10900 | 0.8803 | 0.8192 |
0.8691 | 2.29 | 11000 | 0.8783 | 0.8193 |
0.8676 | 2.31 | 11100 | 0.8763 | 0.8187 |
0.8147 | 2.33 | 11200 | 0.8744 | 0.819 |
0.7723 | 2.35 | 11300 | 0.8725 | 0.8195 |
0.9222 | 2.38 | 11400 | 0.8705 | 0.8188 |
0.9692 | 2.4 | 11500 | 0.8687 | 0.8195 |
0.8792 | 2.42 | 11600 | 0.8669 | 0.8188 |
0.939 | 2.44 | 11700 | 0.8650 | 0.8193 |
0.9093 | 2.46 | 11800 | 0.8633 | 0.8188 |
0.7794 | 2.48 | 11900 | 0.8616 | 0.8182 |
0.8572 | 2.5 | 12000 | 0.8599 | 0.8182 |
0.9035 | 2.52 | 12100 | 0.8582 | 0.8185 |
0.8063 | 2.54 | 12200 | 0.8566 | 0.8193 |
0.8935 | 2.56 | 12300 | 0.8550 | 0.8195 |
0.7991 | 2.58 | 12400 | 0.8535 | 0.8192 |
0.856 | 2.6 | 12500 | 0.8520 | 0.8195 |
0.8374 | 2.62 | 12600 | 0.8505 | 0.8197 |
0.8418 | 2.65 | 12700 | 0.8490 | 0.8203 |
0.9232 | 2.67 | 12800 | 0.8475 | 0.8208 |
0.8335 | 2.69 | 12900 | 0.8462 | 0.8207 |
0.8659 | 2.71 | 13000 | 0.8449 | 0.8205 |
0.9798 | 2.73 | 13100 | 0.8435 | 0.8205 |
0.7288 | 2.75 | 13200 | 0.8423 | 0.8205 |
0.9086 | 2.77 | 13300 | 0.8411 | 0.821 |
0.7912 | 2.79 | 13400 | 0.8398 | 0.8205 |
0.8675 | 2.81 | 13500 | 0.8386 | 0.8202 |
0.8045 | 2.83 | 13600 | 0.8374 | 0.8198 |
0.8421 | 2.85 | 13700 | 0.8362 | 0.8202 |
0.7453 | 2.88 | 13800 | 0.8350 | 0.8202 |
0.7348 | 2.9 | 13900 | 0.8339 | 0.8203 |
0.8977 | 2.92 | 14000 | 0.8328 | 0.8205 |
0.859 | 2.94 | 14100 | 0.8318 | 0.821 |
0.8571 | 2.96 | 14200 | 0.8307 | 0.8212 |
0.8158 | 2.98 | 14300 | 0.8297 | 0.8215 |
0.8635 | 3.0 | 14400 | 0.8287 | 0.8215 |
0.9095 | 3.02 | 14500 | 0.8277 | 0.8215 |
0.8491 | 3.04 | 14600 | 0.8268 | 0.8217 |
0.9136 | 3.06 | 14700 | 0.8259 | 0.8223 |
0.8652 | 3.08 | 14800 | 0.8250 | 0.8218 |
0.9299 | 3.1 | 14900 | 0.8242 | 0.8215 |
0.8259 | 3.12 | 15000 | 0.8233 | 0.8215 |
0.775 | 3.15 | 15100 | 0.8225 | 0.8222 |
0.801 | 3.17 | 15200 | 0.8217 | 0.8217 |
0.8535 | 3.19 | 15300 | 0.8209 | 0.8215 |
0.7973 | 3.21 | 15400 | 0.8202 | 0.8217 |
0.8937 | 3.23 | 15500 | 0.8195 | 0.8213 |
0.7632 | 3.25 | 15600 | 0.8188 | 0.821 |
0.8117 | 3.27 | 15700 | 0.8181 | 0.8212 |
0.8941 | 3.29 | 15800 | 0.8174 | 0.8217 |
0.802 | 3.31 | 15900 | 0.8168 | 0.8225 |
0.8303 | 3.33 | 16000 | 0.8161 | 0.8217 |
0.8264 | 3.35 | 16100 | 0.8155 | 0.8218 |
0.8411 | 3.38 | 16200 | 0.8149 | 0.8213 |
0.9378 | 3.4 | 16300 | 0.8143 | 0.8218 |
0.8514 | 3.42 | 16400 | 0.8138 | 0.8217 |
0.7313 | 3.44 | 16500 | 0.8133 | 0.8222 |
0.8238 | 3.46 | 16600 | 0.8128 | 0.8218 |
0.7876 | 3.48 | 16700 | 0.8123 | 0.8222 |
0.8364 | 3.5 | 16800 | 0.8118 | 0.8222 |
0.7049 | 3.52 | 16900 | 0.8114 | 0.8222 |
0.9101 | 3.54 | 17000 | 0.8109 | 0.8218 |
0.7984 | 3.56 | 17100 | 0.8105 | 0.822 |
0.85 | 3.58 | 17200 | 0.8101 | 0.8218 |
0.8677 | 3.6 | 17300 | 0.8098 | 0.822 |
0.8797 | 3.62 | 17400 | 0.8094 | 0.8218 |
0.7847 | 3.65 | 17500 | 0.8091 | 0.8222 |
0.8415 | 3.67 | 17600 | 0.8088 | 0.8218 |
0.8702 | 3.69 | 17700 | 0.8085 | 0.8222 |
0.8979 | 3.71 | 17800 | 0.8082 | 0.8222 |
0.8387 | 3.73 | 17900 | 0.8080 | 0.8222 |
0.8467 | 3.75 | 18000 | 0.8077 | 0.822 |
0.8729 | 3.77 | 18100 | 0.8075 | 0.822 |
0.8291 | 3.79 | 18200 | 0.8073 | 0.8222 |
0.7897 | 3.81 | 18300 | 0.8072 | 0.8222 |
0.8039 | 3.83 | 18400 | 0.8070 | 0.822 |
0.771 | 3.85 | 18500 | 0.8069 | 0.8223 |
0.7704 | 3.88 | 18600 | 0.8067 | 0.8223 |
0.7695 | 3.9 | 18700 | 0.8066 | 0.8223 |
0.8958 | 3.92 | 18800 | 0.8066 | 0.8223 |
0.8342 | 3.94 | 18900 | 0.8065 | 0.8223 |
0.8725 | 3.96 | 19000 | 0.8064 | 0.8225 |
0.8657 | 3.98 | 19100 | 0.8064 | 0.8225 |
0.779 | 4.0 | 19200 | 0.8064 | 0.8225 |
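
The usage demo sets `load_best_model_at_end=True` without `metric_for_best_model`; in that case `Trainer` ranks evaluated checkpoints by evaluation loss, lower being better. The sketch below applies that rule to the last six rows of the table above. It uses the rounded values as printed, while the real comparison uses unrounded losses and only checkpoints that were actually saved, so the result is illustrative only. `parse_rows` and `best_by_val_loss` are hypothetical helper names.

```python
# Last rows of the training results table:
# Training Loss | Epoch | Step | Validation Loss | Accuracy
TABLE_ROWS = """\
0.7695 | 3.9 | 18700 | 0.8066 | 0.8223 |
0.8958 | 3.92 | 18800 | 0.8066 | 0.8223 |
0.8342 | 3.94 | 18900 | 0.8065 | 0.8223 |
0.8725 | 3.96 | 19000 | 0.8064 | 0.8225 |
0.8657 | 3.98 | 19100 | 0.8064 | 0.8225 |
0.779 | 4.0 | 19200 | 0.8064 | 0.8225 |
"""


def parse_rows(text: str) -> list[dict]:
    """Turn pipe-separated table rows into dictionaries."""
    rows = []
    for line in text.strip().splitlines():
        cells = [c.strip() for c in line.split("|") if c.strip()]
        train_loss, epoch, step, val_loss, acc = cells
        rows.append({
            "train_loss": float(train_loss),
            "epoch": float(epoch),
            "step": int(step),
            "val_loss": float(val_loss),
            "accuracy": float(acc),
        })
    return rows


def best_by_val_loss(rows: list[dict]) -> dict:
    """Return the row with the lowest validation loss.

    The strict '<' keeps the earliest row when losses tie.
    """
    best = None
    for row in rows:
        if best is None or row["val_loss"] < best["val_loss"]:
            best = row
    return best


rows = parse_rows(TABLE_ROWS)
best = best_by_val_loss(rows)
print(best["step"], best["val_loss"], best["accuracy"])
```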

### Framework versions
- Transformers 4.35.2
- Pytorch 2.1.0+cu118
- Datasets 2.15.0
- Tokenizers 0.15.0
## Example of usage

### Simple demo for Google Colab
```python
# Pinned to the framework versions listed above; newer releases removed
# APIs this demo relies on (e.g. datasets.load_metric).
!pip install "transformers[torch]==4.35.2" "datasets==2.15.0" "accelerate<1.0"
!git clone https://github.com/Andron00e/CLIPForImageClassification
%cd CLIPForImageClassification/clip_for_classification

import numpy as np
import torch
from datasets import DatasetDict, load_dataset, load_metric
from transformers import CLIPProcessor, Trainer, TrainingArguments
# Custom model class from the cloned repository
from modeling_clipforimageclassification import CLIPForImageClassification

processor = CLIPProcessor.from_pretrained("Andron00e/CLIPForImageClassification-v1")
# Second argument: presumably the number of target classes (CIFAR10 has 10)
model = CLIPForImageClassification.from_pretrained("Andron00e/CLIPForImageClassification-v1", 10)

# 80/10/10 train/validation/test split
dataset = load_dataset("Andron00e/CIFAR10-custom")
dataset = dataset["train"].train_test_split(test_size=0.2)
val_test = dataset["test"].train_test_split(test_size=0.5)
dataset = DatasetDict({
    "train": dataset["train"],
    "validation": val_test["train"],
    "test": val_test["test"],
})

classes = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer",
           5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}

def transform(example_batch):
    # Tokenize each example's class name and preprocess its image
    inputs = processor(
        text=[classes[x] for x in example_batch["labels"]],
        images=[x for x in example_batch["image"]],
        padding=True,
        return_tensors="pt",
    )
    inputs["labels"] = example_batch["labels"]
    return inputs

def collate_fn(batch):
    # Stack per-example tensors into batch tensors
    return {
        "input_ids": torch.stack([x["input_ids"] for x in batch]),
        "attention_mask": torch.stack([x["attention_mask"] for x in batch]),
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "labels": torch.tensor([x["labels"] for x in batch]),
    }

metric = load_metric("accuracy")

def compute_metrics(p):
    # Predicted class = highest-scoring class for each example
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

training_args = TrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=False,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,  # keep the raw "image" column for transform()
    push_to_hub=False,
    report_to="tensorboard",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=dataset.with_transform(transform)["train"],
    eval_dataset=dataset.with_transform(transform)["validation"],
    tokenizer=processor,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Evaluate on the held-out test split
metrics = trainer.evaluate(dataset.with_transform(transform)["test"])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

%cd ..
%cd ..
```
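
`compute_metrics` in the demo turns the model's per-class scores into class indices with `np.argmax(..., axis=1)` and compares them with the reference labels. The NumPy-only sketch below reproduces that computation, assuming `p.predictions` is a 2-D array of shape `(num_examples, num_classes)`; if the custom model returned a tuple of outputs, the logits would have to be selected first. `accuracy_from_logits` is a hypothetical name, and the toy arrays are made-up inputs, not model outputs.

```python
import numpy as np


def accuracy_from_logits(logits, labels) -> dict:
    """Fraction of examples whose highest-scoring class matches the label."""
    preds = np.argmax(np.asarray(logits), axis=1)  # one class index per row
    return {"accuracy": float((preds == np.asarray(labels)).mean())}


# Toy example with 3 classes (CIFAR10 logits would have 10 columns)
logits = np.array([
    [2.0, 0.1, -1.0],
    [0.3, 0.2, 0.9],
    [0.0, 1.5, 0.2],
    [1.0, 0.5, 0.0],
])
labels = [0, 2, 1, 2]
print(accuracy_from_logits(logits, labels))
```

Three of the four argmax predictions match their labels here, so the reported accuracy is 0.75.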