# Install

In [1]:
%pip install uv

Note: you may need to restart the kernel to use updated packages.


In [2]:
!uv pip install dagshub setuptools accelerate toml torch torchvision transformers mlflow datasets ipywidgets python-dotenv evaluate

[2mAudited [1m12 packages[0m in 8ms[0m


# Setup

In [3]:
import os
import toml
import torch
import mlflow
import dagshub
import datasets
import evaluate
from dotenv import load_dotenv
from torchvision.transforms import v2
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer

ENV_PATH = "/Users/andrewmayes/Openclassroom/CanineNet/.env"
CONFIG_PATH = "/Users/andrewmayes/Openclassroom/CanineNet/code/config.toml"
CONFIG = toml.load(CONFIG_PATH)

load_dotenv(ENV_PATH)

dagshub.init(repo_name=os.environ['MLFLOW_TRACKING_PROJECTNAME'], repo_owner=os.environ['MLFLOW_TRACKING_USERNAME'], mlflow=True, dvc=True)

os.environ['MLFLOW_TRACKING_USERNAME'] = "amaye15"

mlflow.set_tracking_uri(f'https://dagshub.com/' + os.environ['MLFLOW_TRACKING_USERNAME']
                         + '/' + os.environ['MLFLOW_TRACKING_PROJECTNAME'] + '.mlflow')

CREATE_DATASET = True
ORIGINAL_DATASET = "Alanox/stanford-dogs"
MODIFIED_DATASET = "amaye15/stanford-dogs"
REMOVE_COLUMNS = ["name", "annotations"]
RENAME_COLUMNS = {"image":"pixel_values", "target":"label"}
SPLIT = 0.2

METRICS = ["accuracy", "f1", "precision", "recall"]
# MODELS = 'google/vit-base-patch16-224'
# MODELS = "google/siglip-base-patch16-224"



# Dataset

In [4]:
if CREATE_DATASET:
    ds = datasets.load_dataset(ORIGINAL_DATASET, token=os.getenv("HF_TOKEN"), split="full", trust_remote_code=True)
    ds = ds.remove_columns(REMOVE_COLUMNS).rename_columns(RENAME_COLUMNS)

    labels = ds.select_columns("label").to_pandas().sort_values("label").get("label").unique().tolist()
    numbers = range(len(labels))
    label2int = dict(zip(labels, numbers))
    int2label = dict(zip(numbers, labels))

    for key, val in label2int.items():
        print(f"{key}: {val}")

    ds = ds.class_encode_column("label")
    ds = ds.align_labels_with_mapping(label2int, "label")

    ds = ds.train_test_split(test_size=SPLIT, stratify_by_column = "label")
    #ds.push_to_hub(MODIFIED_DATASET, token=os.getenv("HF_TOKEN"))

    CONFIG["label2int"] = str(label2int)
    CONFIG["int2label"] = str(int2label)

    # with open("output.toml", "w") as toml_file:
    #     toml.dump(toml.dumps(CONFIG), toml_file)

    #ds = datasets.load_dataset(MODIFIED_DATASET, token=os.getenv("HF_TOKEN"), trust_remote_code=True, streaming=True)

Affenpinscher: 0
Afghan Hound: 1
African Hunting Dog: 2
Airedale: 3
American Staffordshire Terrier: 4
Appenzeller: 5
Australian Terrier: 6
Basenji: 7
Basset: 8
Beagle: 9
Bedlington Terrier: 10
Bernese Mountain Dog: 11
Black And Tan Coonhound: 12
Blenheim Spaniel: 13
Bloodhound: 14
Bluetick: 15
Border Collie: 16
Border Terrier: 17
Borzoi: 18
Boston Bull: 19
Bouvier Des Flandres: 20
Boxer: 21
Brabancon Griffon: 22
Briard: 23
Brittany Spaniel: 24
Bull Mastiff: 25
Cairn: 26
Cardigan: 27
Chesapeake Bay Retriever: 28
Chihuahua: 29
Chow: 30
Clumber: 31
Cocker Spaniel: 32
Collie: 33
Curly Coated Retriever: 34
Dandie Dinmont: 35
Dhole: 36
Dingo: 37
Doberman: 38
English Foxhound: 39
English Setter: 40
English Springer: 41
Entlebucher: 42
Eskimo Dog: 43
Flat Coated Retriever: 44
French Bulldog: 45
German Shepherd: 46
German Short Haired Pointer: 47
Giant Schnauzer: 48
Golden Retriever: 49
Gordon Setter: 50
Great Dane: 51
Great Pyrenees: 52
Greater Swiss Mountain Dog: 53
Groenendael: 54
Ibizan Hou

In [5]:
metrics = {metric: evaluate.load(metric) for metric in METRICS}


# for lr in [5e-3, 5e-4, 5e-5]: # 5e-5
#     for batch in [64]: # 32
#         for model_name in ["google/vit-base-patch16-224", "microsoft/swinv2-base-patch4-window16-256", "google/siglip-base-patch16-224"]: # "facebook/dinov2-base"

lr = 5e-4
batch = 32
model_name = "microsoft/resnet-50"

image_processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(
model_name,
num_labels=len(label2int),
id2label=int2label,
label2id=label2int,
ignore_mismatched_sizes=True,
)

# Then, in your transformations:
def train_transform(examples, num_ops=10, magnitude=9, num_magnitude_bins=31):

    transformation = v2.Compose(
        [
            v2.RandAugment(
                num_ops=num_ops,
                magnitude=magnitude,
                num_magnitude_bins=num_magnitude_bins,
            )
        ]
    )
    # Ensure each image has three dimensions (in this case, ensure it's RGB)
    examples["pixel_values"] = [
        image.convert("RGB") for image in examples["pixel_values"]
    ]
    # Apply transformations
    examples["pixel_values"] = [
        image_processor(transformation(image), return_tensors="pt")[
            "pixel_values"
        ].squeeze()
        for image in examples["pixel_values"]
    ]
    return examples


def test_transform(examples):
    # Ensure each image is RGB
    examples["pixel_values"] = [
        image.convert("RGB") for image in examples["pixel_values"]
    ]
    # Apply processing
    examples["pixel_values"] = [
        image_processor(image, return_tensors="pt")["pixel_values"].squeeze()
        for image in examples["pixel_values"]
    ]
    return examples


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # predictions = np.argmax(logits, axis=-1)
    results = {}
    for key, val in metrics.items():
        if "accuracy" == key:
            result = next(
                iter(val.compute(predictions=predictions, references=labels).items())
            )
        if "accuracy" != key:
            result = next(
                iter(
                    val.compute(
                        predictions=predictions, references=labels, average="macro"
                    ).items()
                )
            )
        results[result[0]] = result[1]
    return results


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids

ds["train"].set_transform(train_transform)
ds["test"].set_transform(test_transform)

training_args = TrainingArguments(**CONFIG["training_args"])
training_args.per_device_train_batch_size = batch
training_args.per_device_eval_batch_size = batch
training_args.hub_model_id = f"amaye15/{model_name.replace('/','-')}-batch{batch}-lr{lr}-standford-dogs"

mlflow.start_run(run_name=f"{model_name.replace('/','-')}-batch{batch}-lr{lr}")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    tokenizer=image_processor,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # callbacks=[early_stopping_callback],
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

# Train the model
trainer.train()

trainer.push_to_hub()

mlflow.end_run()

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([120]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([120, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
max_steps is given, it will override any value given in num_train_epochs


  0%|          | 0/1000 [00:00<?, ?it/s]

{'loss': 4.7829, 'grad_norm': 0.6043907999992371, 'learning_rate': 4.9500000000000004e-05, 'epoch': 0.08}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.77471923828125, 'eval_accuracy': 0.2118561710398445, 'eval_f1': 0.187375517726323, 'eval_precision': 0.3919036860239945, 'eval_recall': 0.19824327355121704, 'eval_runtime': 33.4309, 'eval_samples_per_second': 123.12, 'eval_steps_per_second': 3.859, 'epoch': 0.08}
{'loss': 4.7714, 'grad_norm': 0.6754865050315857, 'learning_rate': 4.9e-05, 'epoch': 0.16}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.757228851318359, 'eval_accuracy': 0.20383867832847424, 'eval_f1': 0.18416981866925827, 'eval_precision': 0.42618136770448983, 'eval_recall': 0.18363998800713158, 'eval_runtime': 30.7622, 'eval_samples_per_second': 133.801, 'eval_steps_per_second': 4.193, 'epoch': 0.16}
{'loss': 4.7606, 'grad_norm': 0.6417286992073059, 'learning_rate': 4.85e-05, 'epoch': 0.23}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.736657619476318, 'eval_accuracy': 0.358600583090379, 'eval_f1': 0.3433113409864575, 'eval_precision': 0.6517178219942168, 'eval_recall': 0.3306848427836897, 'eval_runtime': 29.5337, 'eval_samples_per_second': 139.366, 'eval_steps_per_second': 4.368, 'epoch': 0.23}
{'loss': 4.747, 'grad_norm': 0.6243997812271118, 'learning_rate': 4.8e-05, 'epoch': 0.31}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.714941501617432, 'eval_accuracy': 0.4302721088435374, 'eval_f1': 0.42721541071711805, 'eval_precision': 0.773414620851018, 'eval_recall': 0.40385862239180403, 'eval_runtime': 31.0076, 'eval_samples_per_second': 132.742, 'eval_steps_per_second': 4.16, 'epoch': 0.31}
{'loss': 4.7253, 'grad_norm': 0.6433669328689575, 'learning_rate': 4.75e-05, 'epoch': 0.39}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.684640884399414, 'eval_accuracy': 0.4361030126336249, 'eval_f1': 0.4677602303574034, 'eval_precision': 0.7906333558807621, 'eval_recall': 0.4160270614713831, 'eval_runtime': 31.2777, 'eval_samples_per_second': 131.595, 'eval_steps_per_second': 4.124, 'epoch': 0.39}
{'loss': 4.7069, 'grad_norm': 0.7207397818565369, 'learning_rate': 4.7e-05, 'epoch': 0.47}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.653403282165527, 'eval_accuracy': 0.533041788143829, 'eval_f1': 0.5396864056951644, 'eval_precision': 0.804847761263092, 'eval_recall': 0.5092981237432466, 'eval_runtime': 28.6761, 'eval_samples_per_second': 143.534, 'eval_steps_per_second': 4.499, 'epoch': 0.47}
{'loss': 4.6857, 'grad_norm': 0.7303667068481445, 'learning_rate': 4.6500000000000005e-05, 'epoch': 0.54}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.617745399475098, 'eval_accuracy': 0.5500485908649174, 'eval_f1': 0.5511369002526866, 'eval_precision': 0.7998093864476505, 'eval_recall': 0.5263615811202424, 'eval_runtime': 28.5304, 'eval_samples_per_second': 144.267, 'eval_steps_per_second': 4.521, 'epoch': 0.54}
{'loss': 4.6569, 'grad_norm': 0.744701623916626, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.62}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.576382160186768, 'eval_accuracy': 0.5738581146744413, 'eval_f1': 0.5800354551041117, 'eval_precision': 0.8207891649420048, 'eval_recall': 0.5516830965289926, 'eval_runtime': 28.9367, 'eval_samples_per_second': 142.241, 'eval_steps_per_second': 4.458, 'epoch': 0.62}
{'loss': 4.6293, 'grad_norm': 0.8225492238998413, 'learning_rate': 4.55e-05, 'epoch': 0.7}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.535852432250977, 'eval_accuracy': 0.6141885325558795, 'eval_f1': 0.6148517759248673, 'eval_precision': 0.807489842077252, 'eval_recall': 0.5926437581767611, 'eval_runtime': 29.8128, 'eval_samples_per_second': 138.062, 'eval_steps_per_second': 4.327, 'epoch': 0.7}
{'loss': 4.5953, 'grad_norm': 0.835442066192627, 'learning_rate': 4.5e-05, 'epoch': 0.78}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.482782363891602, 'eval_accuracy': 0.6207482993197279, 'eval_f1': 0.6233347480319061, 'eval_precision': 0.8108960881073339, 'eval_recall': 0.5999664720807305, 'eval_runtime': 30.0674, 'eval_samples_per_second': 136.893, 'eval_steps_per_second': 4.29, 'epoch': 0.78}
{'loss': 4.5651, 'grad_norm': 0.8578382134437561, 'learning_rate': 4.4500000000000004e-05, 'epoch': 0.85}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.425670146942139, 'eval_accuracy': 0.6591350826044704, 'eval_f1': 0.6584699253003153, 'eval_precision': 0.8147592711787498, 'eval_recall': 0.6393439002306762, 'eval_runtime': 28.4304, 'eval_samples_per_second': 144.775, 'eval_steps_per_second': 4.537, 'epoch': 0.85}
{'loss': 4.5296, 'grad_norm': 0.9620392322540283, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.93}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.364680290222168, 'eval_accuracy': 0.706268221574344, 'eval_f1': 0.7012054635300039, 'eval_precision': 0.8284350125904834, 'eval_recall': 0.688199507444556, 'eval_runtime': 28.5471, 'eval_samples_per_second': 144.183, 'eval_steps_per_second': 4.519, 'epoch': 0.93}
{'loss': 4.4911, 'grad_norm': 0.9173192977905273, 'learning_rate': 4.35e-05, 'epoch': 1.01}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.299846649169922, 'eval_accuracy': 0.7089407191448007, 'eval_f1': 0.7073568856764126, 'eval_precision': 0.8325596625698185, 'eval_recall': 0.6924090542708233, 'eval_runtime': 28.5965, 'eval_samples_per_second': 143.934, 'eval_steps_per_second': 4.511, 'epoch': 1.01}
{'loss': 4.4442, 'grad_norm': 0.9183776378631592, 'learning_rate': 4.3e-05, 'epoch': 1.09}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.228794574737549, 'eval_accuracy': 0.6938775510204082, 'eval_f1': 0.6890499178440211, 'eval_precision': 0.8302365826885487, 'eval_recall': 0.6758939664483897, 'eval_runtime': 28.7618, 'eval_samples_per_second': 143.106, 'eval_steps_per_second': 4.485, 'epoch': 1.09}
{'loss': 4.3912, 'grad_norm': 1.0323781967163086, 'learning_rate': 4.25e-05, 'epoch': 1.17}


  0%|          | 0/129 [00:00<?, ?it/s]

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 4.152723789215088, 'eval_accuracy': 0.6873177842565598, 'eval_f1': 0.6863011851876918, 'eval_precision': 0.8261897457310591, 'eval_recall': 0.6702606718880093, 'eval_runtime': 29.7578, 'eval_samples_per_second': 138.317, 'eval_steps_per_second': 4.335, 'epoch': 1.17}
