metadata
base_model: openai/clip-vit-base-patch32
tags:
  - generated_from_trainer
metrics:
  - accuracy
model-index:
  - name: outputs
    results: []
license: apache-2.0
datasets:
  - Andron00e/CIFAR10-custom
language:
  - en
library_name: transformers

outputs

This model is a fine-tuned version of openai/clip-vit-base-patch32 on a custom CIFAR-10 dataset (Andron00e/CIFAR10-custom). It achieves the following results on the evaluation set:

  • Loss: 0.8115
  • Accuracy: 0.8255

Model description

Training and evaluation data

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 0.0002
  • train_batch_size: 10
  • eval_batch_size: 8
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 4
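With lr_scheduler_type: linear, the learning rate decays from 0.0002 to zero over the run. The sketch below mirrors the formula behind transformers' `get_linear_schedule_with_warmup`. It rests on two assumptions. First, zero warmup steps: the card does not list a warmup value, and 0 is the `TrainingArguments` default. Second, 19,200 total optimizer steps, taken from the last row of the training-results table.

```python
def linear_lr(step: int, base_lr: float = 2e-4, total_steps: int = 19_200,
              warmup_steps: int = 0) -> float:
    """Learning rate at an optimizer step under linear warmup + linear decay.

    Assumptions (not stated in the card): warmup_steps=0, and total_steps is
    the final step count from the training-results table.
    """
    if step < warmup_steps:
        # Linear ramp from 0 to base_lr during warmup (unused when warmup_steps=0).
        return base_lr * step / max(1, warmup_steps)
    # Linear decay from base_lr at the end of warmup down to 0 at total_steps.
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


for s in (0, 4_800, 9_600, 19_200):
    print(s, linear_lr(s))
```

Under these assumptions the rate is 1e-4 at step 9,600 (the end of epoch 2) and reaches 0 at step 19,200.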

Training results

Training Loss | Epoch | Step | Validation Loss | Accuracy
1.7258 0.02 100 1.6999 0.8048
1.669 0.04 200 1.6798 0.8055
1.6704 0.06 300 1.6599 0.8053
1.6655 0.08 400 1.6407 0.8047
1.5754 0.1 500 1.6223 0.809
1.6159 0.12 600 1.6040 0.8068
1.5663 0.15 700 1.5858 0.8073
1.5426 0.17 800 1.5677 0.8095
1.5794 0.19 900 1.5506 0.808
1.5504 0.21 1000 1.5342 0.8035
1.554 0.23 1100 1.5179 0.802
1.4831 0.25 1200 1.5022 0.7972
1.4718 0.27 1300 1.4867 0.7955
1.5206 0.29 1400 1.4716 0.796
1.4534 0.31 1500 1.4567 0.7963
1.3932 0.33 1600 1.4427 0.7875
1.4635 0.35 1700 1.4289 0.789
1.4339 0.38 1800 1.4151 0.793
1.4492 0.4 1900 1.4016 0.7973
1.4369 0.42 2000 1.3881 0.8018
1.4007 0.44 2100 1.3754 0.801
1.3697 0.46 2200 1.3627 0.8025
1.3298 0.48 2300 1.3505 0.8048
1.2809 0.5 2400 1.3386 0.8068
1.2989 0.52 2500 1.3272 0.8067
1.2958 0.54 2600 1.3159 0.81
1.3072 0.56 2700 1.3048 0.8097
1.2545 0.58 2800 1.2943 0.809
1.2722 0.6 2900 1.2834 0.8112
1.2628 0.62 3000 1.2732 0.8102
1.2357 0.65 3100 1.2632 0.8105
1.3189 0.67 3200 1.2532 0.8093
1.2465 0.69 3300 1.2436 0.8097
1.2579 0.71 3400 1.2342 0.8087
1.1963 0.73 3500 1.2249 0.8085
1.1701 0.75 3600 1.2159 0.8092
1.2117 0.77 3700 1.2069 0.8113
1.1907 0.79 3800 1.1984 0.8112
1.1903 0.81 3900 1.1902 0.8115
1.2357 0.83 4000 1.1821 0.8115
1.1924 0.85 4100 1.1738 0.8117
1.1914 0.88 4200 1.1657 0.8133
1.1536 0.9 4300 1.1580 0.8148
1.1893 0.92 4400 1.1505 0.8158
1.1811 0.94 4500 1.1433 0.8158
1.0182 0.96 4600 1.1358 0.8165
1.0396 0.98 4700 1.1287 0.8158
1.1502 1.0 4800 1.1217 0.816
1.1764 1.02 4900 1.1147 0.8158
1.1508 1.04 5000 1.1080 0.8152
1.0518 1.06 5100 1.1015 0.8155
1.0648 1.08 5200 1.0952 0.816
1.1631 1.1 5300 1.0889 0.8153
1.0629 1.12 5400 1.0826 0.8152
1.1151 1.15 5500 1.0771 0.815
1.1377 1.17 5600 1.0711 0.8145
1.0353 1.19 5700 1.0652 0.8158
1.068 1.21 5800 1.0594 0.815
1.0834 1.23 5900 1.0538 0.8162
1.0002 1.25 6000 1.0483 0.8165
1.0024 1.27 6100 1.0428 0.817
1.0609 1.29 6200 1.0376 0.817
1.0901 1.31 6300 1.0324 0.816
1.0772 1.33 6400 1.0275 0.8173
0.9434 1.35 6500 1.0226 0.817
0.9692 1.38 6600 1.0178 0.8157
1.0461 1.4 6700 1.0131 0.8155
1.0583 1.42 6800 1.0086 0.8143
0.9369 1.44 6900 1.0042 0.8157
1.0685 1.46 7000 0.9998 0.8152
1.062 1.48 7100 0.9955 0.8153
1.0394 1.5 7200 0.9912 0.8142
1.031 1.52 7300 0.9870 0.8157
0.9556 1.54 7400 0.9829 0.8155
0.9846 1.56 7500 0.9789 0.8152
0.9995 1.58 7600 0.9750 0.8158
1.0273 1.6 7700 0.9711 0.8163
0.9383 1.62 7800 0.9674 0.817
0.951 1.65 7900 0.9634 0.8163
0.9457 1.67 8000 0.9598 0.8167
1.012 1.69 8100 0.9563 0.816
0.9683 1.71 8200 0.9529 0.8158
0.9582 1.73 8300 0.9495 0.8157
0.9005 1.75 8400 0.9461 0.8162
0.888 1.77 8500 0.9428 0.8175
0.9267 1.79 8600 0.9396 0.8168
0.9298 1.81 8700 0.9364 0.8168
1.0072 1.83 8800 0.9334 0.8167
0.9425 1.85 8900 0.9303 0.8158
0.9729 1.88 9000 0.9273 0.8168
0.9104 1.9 9100 0.9244 0.8175
0.9153 1.92 9200 0.9216 0.817
0.9115 1.94 9300 0.9188 0.8165
0.9079 1.96 9400 0.9161 0.8168
0.8453 1.98 9500 0.9133 0.8175
0.8323 2.0 9600 0.9107 0.817
0.9071 2.02 9700 0.9080 0.8183
0.9331 2.04 9800 0.9054 0.8185
0.886 2.06 9900 0.9029 0.8193
0.8562 2.08 10000 0.9006 0.8183
0.8904 2.1 10100 0.8980 0.8193
0.8247 2.12 10200 0.8956 0.8188
0.8114 2.15 10300 0.8934 0.8202
0.96 2.17 10400 0.8912 0.8198
0.9326 2.19 10500 0.8889 0.8198
0.8057 2.21 10600 0.8867 0.8195
0.8266 2.23 10700 0.8846 0.8188
0.7909 2.25 10800 0.8823 0.82
0.886 2.27 10900 0.8803 0.8192
0.8691 2.29 11000 0.8783 0.8193
0.8676 2.31 11100 0.8763 0.8187
0.8147 2.33 11200 0.8744 0.819
0.7723 2.35 11300 0.8725 0.8195
0.9222 2.38 11400 0.8705 0.8188
0.9692 2.4 11500 0.8687 0.8195
0.8792 2.42 11600 0.8669 0.8188
0.939 2.44 11700 0.8650 0.8193
0.9093 2.46 11800 0.8633 0.8188
0.7794 2.48 11900 0.8616 0.8182
0.8572 2.5 12000 0.8599 0.8182
0.9035 2.52 12100 0.8582 0.8185
0.8063 2.54 12200 0.8566 0.8193
0.8935 2.56 12300 0.8550 0.8195
0.7991 2.58 12400 0.8535 0.8192
0.856 2.6 12500 0.8520 0.8195
0.8374 2.62 12600 0.8505 0.8197
0.8418 2.65 12700 0.8490 0.8203
0.9232 2.67 12800 0.8475 0.8208
0.8335 2.69 12900 0.8462 0.8207
0.8659 2.71 13000 0.8449 0.8205
0.9798 2.73 13100 0.8435 0.8205
0.7288 2.75 13200 0.8423 0.8205
0.9086 2.77 13300 0.8411 0.821
0.7912 2.79 13400 0.8398 0.8205
0.8675 2.81 13500 0.8386 0.8202
0.8045 2.83 13600 0.8374 0.8198
0.8421 2.85 13700 0.8362 0.8202
0.7453 2.88 13800 0.8350 0.8202
0.7348 2.9 13900 0.8339 0.8203
0.8977 2.92 14000 0.8328 0.8205
0.859 2.94 14100 0.8318 0.821
0.8571 2.96 14200 0.8307 0.8212
0.8158 2.98 14300 0.8297 0.8215
0.8635 3.0 14400 0.8287 0.8215
0.9095 3.02 14500 0.8277 0.8215
0.8491 3.04 14600 0.8268 0.8217
0.9136 3.06 14700 0.8259 0.8223
0.8652 3.08 14800 0.8250 0.8218
0.9299 3.1 14900 0.8242 0.8215
0.8259 3.12 15000 0.8233 0.8215
0.775 3.15 15100 0.8225 0.8222
0.801 3.17 15200 0.8217 0.8217
0.8535 3.19 15300 0.8209 0.8215
0.7973 3.21 15400 0.8202 0.8217
0.8937 3.23 15500 0.8195 0.8213
0.7632 3.25 15600 0.8188 0.821
0.8117 3.27 15700 0.8181 0.8212
0.8941 3.29 15800 0.8174 0.8217
0.802 3.31 15900 0.8168 0.8225
0.8303 3.33 16000 0.8161 0.8217
0.8264 3.35 16100 0.8155 0.8218
0.8411 3.38 16200 0.8149 0.8213
0.9378 3.4 16300 0.8143 0.8218
0.8514 3.42 16400 0.8138 0.8217
0.7313 3.44 16500 0.8133 0.8222
0.8238 3.46 16600 0.8128 0.8218
0.7876 3.48 16700 0.8123 0.8222
0.8364 3.5 16800 0.8118 0.8222
0.7049 3.52 16900 0.8114 0.8222
0.9101 3.54 17000 0.8109 0.8218
0.7984 3.56 17100 0.8105 0.822
0.85 3.58 17200 0.8101 0.8218
0.8677 3.6 17300 0.8098 0.822
0.8797 3.62 17400 0.8094 0.8218
0.7847 3.65 17500 0.8091 0.8222
0.8415 3.67 17600 0.8088 0.8218
0.8702 3.69 17700 0.8085 0.8222
0.8979 3.71 17800 0.8082 0.8222
0.8387 3.73 17900 0.8080 0.8222
0.8467 3.75 18000 0.8077 0.822
0.8729 3.77 18100 0.8075 0.822
0.8291 3.79 18200 0.8073 0.8222
0.7897 3.81 18300 0.8072 0.8222
0.8039 3.83 18400 0.8070 0.822
0.771 3.85 18500 0.8069 0.8223
0.7704 3.88 18600 0.8067 0.8223
0.7695 3.9 18700 0.8066 0.8223
0.8958 3.92 18800 0.8066 0.8223
0.8342 3.94 18900 0.8065 0.8223
0.8725 3.96 19000 0.8064 0.8225
0.8657 3.98 19100 0.8064 0.8225
0.779 4.0 19200 0.8064 0.8225
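The table logs one evaluation every 100 optimizer steps. Parsing the epoch-boundary rows (copied verbatim from the table) makes the epoch arithmetic explicit: 19,200 steps over 4 epochs is 4,800 steps per epoch. At train_batch_size 10 that implies about 48,000 training images per epoch. This assumes a single device and no gradient accumulation; the card states neither. The helper names here are illustrative.

```python
# Epoch-boundary rows copied verbatim from the training-results table.
ROWS = """\
1.1502 1.0 4800 1.1217 0.816
0.8323 2.0 9600 0.9107 0.817
0.8635 3.0 14400 0.8287 0.8215
0.779 4.0 19200 0.8064 0.8225
"""

FIELDS = ("train_loss", "epoch", "step", "val_loss", "accuracy")


def parse_rows(text: str) -> list:
    """Split each whitespace-separated row into the five table columns."""
    records = []
    for line in text.strip().splitlines():
        record = dict(zip(FIELDS, map(float, line.split())))
        record["step"] = int(record["step"])
        records.append(record)
    return records


records = parse_rows(ROWS)
last = records[-1]
steps_per_epoch = last["step"] / last["epoch"]  # 19200 / 4.0 = 4800.0
# Assumption: one device, no gradient accumulation, train_batch_size = 10.
implied_train_examples = int(steps_per_epoch) * 10
# Row with the lowest validation loss among the parsed rows.
best = min(records, key=lambda r: r["val_loss"])
print(steps_per_epoch, implied_train_examples, best["step"])
```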

Framework versions

  • Transformers 4.35.2
  • Pytorch 2.1.0+cu118
  • Datasets 2.15.0
  • Tokenizers 0.15.0
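The reported numbers were produced with the versions listed above. A small standard-library check like the following can flag where the current environment differs. It is a sketch and not part of the original card. It relies on `importlib.metadata`, and the helper names `installed_version` and `version_mismatches` are illustrative.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Versions listed under "Framework versions" in this card.
EXPECTED = {
    "transformers": "4.35.2",
    "torch": "2.1.0+cu118",
    "datasets": "2.15.0",
    "tokenizers": "0.15.0",
}


def installed_version(package: str) -> Optional[str]:
    """Installed version string, or None if the package is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def version_mismatches(expected: Dict[str, str],
                       installed: Dict[str, Optional[str]]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {package: (expected, installed)} for every package that differs."""
    return {
        name: (want, installed.get(name))
        for name, want in expected.items()
        if installed.get(name) != want
    }


if __name__ == "__main__":
    found = {name: installed_version(name) for name in EXPECTED}
    for name, (want, have) in version_mismatches(EXPECTED, found).items():
        print(f"{name}: card used {want}, found {have}")
```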

Example of usage

Simple demo for Google Colab

!pip install datasets transformers[torch] accelerate evaluate -U
!git clone https://github.com/Andron00e/CLIPForImageClassification
%cd CLIPForImageClassification/clip_for_classification

import numpy as np
import torch
import evaluate
from transformers import TrainingArguments
from datasets import load_dataset
from transformers import CLIPProcessor
from modeling_clipforimageclassification import CLIPForImageClassification

processor = CLIPProcessor.from_pretrained("Andron00e/CLIPForImageClassification-v1")
model = CLIPForImageClassification.from_pretrained("Andron00e/CLIPForImageClassification-v1", 10)

# Split the "train" split 80/10/10 into train, validation and test sets.
dataset = load_dataset("Andron00e/CIFAR10-custom")
dataset = dataset["train"].train_test_split(test_size=0.2)
from datasets import DatasetDict

val_test = dataset["test"].train_test_split(test_size=0.5)
dataset = DatasetDict({
    "train": dataset["train"],
    "validation": val_test["train"],
    "test": val_test["test"],
})

classes = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}

def transform(example_batch):
    # Encode the class names as text and the images as pixel values in a single processor call.
    inputs = processor(text=[classes[x] for x in example_batch['labels']], images=[x for x in example_batch['image']], padding=True, return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs

def collate_fn(batch):
    # Stack the per-example tensors into batch tensors for the Trainer.
    return {
        'input_ids': torch.stack([x['input_ids'] for x in batch]),
        'attention_mask': torch.stack([x['attention_mask'] for x in batch]),
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

# load_metric has been removed from recent datasets releases; the evaluate library provides the same accuracy metric.
metric = evaluate.load("accuracy")

def compute_metrics(p):
    # Predicted class = index of the highest score in each row of logits.
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

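training_args = TrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=10,  # train_batch_size listed under Training hyperparameters
    # Transformers 4.35.2 (see Framework versions) uses `evaluation_strategy`;
    # later releases rename this argument to `eval_strategy`.
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=False,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,  # keep the "image" column for the on-the-fly transform
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)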

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=dataset.with_transform(transform)["train"],
    eval_dataset=dataset.with_transform(transform)["validation"],
    tokenizer=model.processor,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

metrics = trainer.evaluate(dataset.with_transform(transform)["test"])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

%cd ..
%cd ..
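`compute_metrics` in the demo reduces each row of `p.predictions` to its highest-scoring class index and reports the fraction that match `p.label_ids`. The pure-Python sketch below spells that out without NumPy. It assumes `p.predictions` is a 2-D array of per-class scores with one column per CIFAR-10 class; the custom model's output format is not documented in this card. The logits are made-up toy values, not model outputs.

```python
from typing import Dict, List, Sequence

# Class names in index order, as in the `classes` dict of the demo.
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]


def argmax_rows(logits: Sequence[Sequence[float]]) -> List[int]:
    """Index of the largest value in each row; ties go to the first index,
    as with np.argmax(logits, axis=1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> Dict[str, float]:
    """Fraction of rows whose argmax equals the reference label."""
    preds = argmax_rows(logits)
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return {"accuracy": correct / len(labels)}


# Toy batch of 3 examples with 10 scores each (made-up numbers).
toy_logits = [
    [4.0] + [0.0] * 9,        # highest score at index 0 -> airplane
    [0.0, 3.5] + [0.0] * 8,   # index 1 -> automobile
    [0.0] * 9 + [2.0],        # index 9 -> truck
]
toy_labels = [0, 1, 3]         # the third example is actually a cat

preds = argmax_rows(toy_logits)
print([CLASSES[i] for i in preds], accuracy(toy_logits, toy_labels))
```

Two of the three toy predictions match their labels, so the sketch reports an accuracy of 2/3.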