{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "RobustViT.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNP00yXydKk0stZEJQyT5pO",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hzG96yZskJSy",
"outputId": "3eab22fa-e246-4cfb-d4a9-c35878cf75f2"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'RobustViT'...\n",
"remote: Enumerating objects: 139, done.\u001b[K\n",
"remote: Counting objects: 100% (139/139), done.\u001b[K\n",
"remote: Compressing objects: 100% (119/119), done.\u001b[K\n",
"remote: Total 139 (delta 54), reused 84 (delta 18), pack-reused 0\u001b[K\n",
"Receiving objects: 100% (139/139), 4.50 MiB | 17.11 MiB/s, done.\n",
"Resolving deltas: 100% (54/54), done.\n"
]
}
],
"source": [
"!git clone https://github.com/hila-chefer/RobustViT.git\n",
"\n",
"import os\n",
"os.chdir(f'./RobustViT')"
]
},
{
"cell_type": "code",
"source": [
"!pip install timm\n",
"!pip install einops"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hZK84BL3mZQg",
"outputId": "f9273be0-3410-47cf-f52e-2a45989e9b1f"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting timm\n",
" Downloading timm-0.5.4-py3-none-any.whl (431 kB)\n",
"\u001b[K |████████████████████████████████| 431 kB 4.4 MB/s \n",
"\u001b[?25hRequirement already satisfied: torch>=1.4 in /usr/local/lib/python3.7/dist-packages (from timm) (1.11.0+cu113)\n",
"Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from timm) (0.12.0+cu113)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.4->timm) (4.2.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (1.21.6)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (7.1.2)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from torchvision->timm) (2.23.0)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (2022.5.18.1)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->timm) (2.10)\n",
"Installing collected packages: timm\n",
"Successfully installed timm-0.5.4\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting einops\n",
" Downloading einops-0.4.1-py3-none-any.whl (28 kB)\n",
"Installing collected packages: einops\n",
"Successfully installed einops-0.4.1\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from PIL import Image\n",
"import torchvision.transforms as transforms\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import numpy as np\n",
"import cv2\n",
"from CLS2IDX import CLS2IDX\n",
"\n",
"normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
"transform = transforms.Compose([\n",
" transforms.Resize(256),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" normalize,\n",
"])\n",
"transform_224 = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" normalize,\n",
"])\n",
"\n",
"# create heatmap from mask on image\n",
"def show_cam_on_image(img, mask):\n",
" heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)\n",
" heatmap = np.float32(heatmap) / 255\n",
" cam = heatmap + np.float32(img)\n",
" cam = cam / np.max(cam)\n",
" return cam"
],
"metadata": {
"id": "uqDTsTS2k8pl"
},
"execution_count": 3,
"outputs": []
},
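{
"cell_type": "markdown",
"source": [
"Optional sanity check (a minimal sketch added for this walkthrough, not part of the original notebook): overlay a synthetic mask on a flat grey image with `show_cam_on_image`, exactly as the ViT relevance maps will be overlaid later. Note that `cv2.applyColorMap` returns BGR, so the colours here are only indicative."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative check of show_cam_on_image on synthetic data (hypothetical example).\n",
"_img = np.full((224, 224, 3), 0.5, dtype=np.float32)  # flat grey image in [0, 1]\n",
"_yy, _xx = np.mgrid[0:224, 0:224]\n",
"_mask = np.exp(-((_yy - 112) ** 2 + (_xx - 112) ** 2) / (2 * 40.0 ** 2))  # centred blob in [0, 1]\n",
"_overlay = show_cam_on_image(_img, _mask)\n",
"plt.imshow(_overlay)\n",
"plt.axis('off')\n",
"plt.show()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},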
{
"cell_type": "code",
"source": [
"from pydrive.auth import GoogleAuth\n",
"from pydrive.drive import GoogleDrive\n",
"from google.colab import auth\n",
"from oauth2client.client import GoogleCredentials\n",
"\n",
"# Authenticate and create the PyDrive client.\n",
"auth.authenticate_user()\n",
"gauth = GoogleAuth()\n",
"gauth.credentials = GoogleCredentials.get_application_default()\n",
"drive = GoogleDrive(gauth)\n",
"\n",
"# downloads weights\n",
"ids = ['1jbWiuBrL4sKpAjG3x4oGbs3WOC2UdbIb', '1DHKX_s8rVCDiX4pwnuCCZdGWsOl4SFMn', '1vDmuvbdLbYVAqWz6yVM4vT1Wdzt8KV-g']\n",
"for file_id in ids:\n",
" downloaded = drive.CreateFile({'id':file_id})\n",
" downloaded.FetchMetadata(fetch_all=True)\n",
" downloaded.GetContentFile(downloaded.metadata['title'])"
],
"metadata": {
"id": "eImo3TAenFbo"
},
"execution_count": 4,
"outputs": []
},
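{
"cell_type": "markdown",
"source": [
"Optional sanity check (a sketch added here, assuming the three Drive files carry the names `ar_base.tar`, `vit_base.tar`, and `deit_base.tar` that the next cell loads): confirm the checkpoints actually landed in the working directory before loading them."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Assumed filenames, taken from the torch.load calls in the next cell.\n",
"for fname in ['ar_base.tar', 'vit_base.tar', 'deit_base.tar']:\n",
"    if os.path.exists(fname):\n",
"        print(f'{fname}: {os.path.getsize(fname) / 1e6:.1f} MB')\n",
"    else:\n",
"        print(f'{fname}: missing')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},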
{
"cell_type": "code",
"source": [
"model_name = 'ar_base' #@param ['ar_base','vit_base', 'deit_base']\n",
"\n",
"if model_name == 'ar_base':\n",
" from ViT.ViT_new import vit_base_patch16_224 as vit\n",
"\n",
" # initialize ViT pretrained\n",
" model = vit(pretrained=True).cuda()\n",
" model.eval()\n",
"\n",
" model_finetuned = vit().cuda()\n",
" checkpoint = torch.load('ar_base.tar')\n",
"\n",
"if model_name == 'vit_base':\n",
" from ViT.ViT import vit_base_patch16_224 as vit\n",
"\n",
" # initialize ViT pretrained\n",
" model = vit(pretrained=True).cuda()\n",
" model.eval()\n",
"\n",
" model_finetuned = vit().cuda()\n",
" checkpoint = torch.load('vit_base.tar')\n",
"\n",
"if model_name == 'deit_base':\n",
" from ViT.ViT import deit_base_patch16_224 as vit\n",
"\n",
" # initialize ViT pretrained\n",
" model = vit(pretrained=True).cuda()\n",
" model.eval()\n",
"\n",
" model_finetuned = vit().cuda()\n",
" checkpoint = torch.load('deit_base.tar')\n",
"\n",
"model_finetuned.load_state_dict(checkpoint['state_dict'])\n",
"model_finetuned.eval()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"cellView": "form",
"id": "tayM9OIWlLOT",
"outputId": "ac6f819a-ef56-4bbf-c25c-5f9cc9b2bb5c"
},
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"VisionTransformer(\n",
" (patch_embed): PatchEmbed(\n",
" (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
" (norm): Identity()\n",
" )\n",
" (pos_drop): Dropout(p=0.0, inplace=False)\n",
" (blocks): ModuleList(\n",
" (0): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (1): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (2): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (3): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (4): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (5): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (6): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (7): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (8): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (9): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (10): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (11): Block(\n",
" (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (attn): Attention(\n",
" (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (drop_path): Identity()\n",
" (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
" (act): GELU()\n",
" (drop1): Dropout(p=0.0, inplace=False)\n",
" (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
" (drop2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
" (pre_logits): Identity()\n",
" (head): Linear(in_features=768, out_features=1000, bias=True)\n",
")"
]
},
"metadata": {},
"execution_count": 7
}
]
},
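{
"cell_type": "markdown",
"source": [
"Quick check (illustrative, not part of the original notebook) that the loaded checkpoint actually changed the weights: compare the classification-head parameters of the pretrained and finetuned models. This assumes both models expose a `head` Linear layer, as the architecture printed above shows."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sanity check: after finetuning, the head weights should differ from the pretrained ones.\n",
"with torch.no_grad():\n",
"    same = torch.allclose(model.head.weight, model_finetuned.head.weight)\n",
"print('head weights identical:', same)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},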
{
"cell_type": "code",
"source": [
"start_layer = 0\n",
"\n",
"# rule 5 from paper\n",
"def avg_heads(cam, grad):\n",
" cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])\n",
" grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])\n",
" cam = grad * cam\n",
" cam = cam.clamp(min=0).mean(dim=0)\n",
" return cam\n",
"\n",
"# rule 6 from paper\n",
"def apply_self_attention_rules(R_ss, cam_ss):\n",
" R_ss_addition = torch.matmul(cam_ss, R_ss)\n",
" return R_ss_addition\n",
"\n",
"def generate_relevance(model, input, index=None):\n",
" output = model(input, register_hook=True)\n",
" if index == None:\n",
" index = np.argmax(output.cpu().data.numpy(), axis=-1)\n",
"\n",
" one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)\n",
" one_hot[0, index] = 1\n",
" one_hot_vector = one_hot\n",
" one_hot = torch.from_numpy(one_hot).requires_grad_(True)\n",
" one_hot = torch.sum(one_hot.cuda() * output)\n",
" model.zero_grad()\n",
" one_hot.backward(retain_graph=True)\n",
"\n",
" num_tokens = model.blocks[0].attn.get_attention_map().shape[-1]\n",
" R = torch.eye(num_tokens, num_tokens).cuda()\n",
" for i,blk in enumerate(model.blocks):\n",
" if i < start_layer:\n",
" continue\n",
" grad = blk.attn.get_attn_gradients()\n",
" cam = blk.attn.get_attention_map()\n",
" cam = avg_heads(cam, grad)\n",
" R += apply_self_attention_rules(R.cuda(), cam.cuda())\n",
" return R[0, 1:]"
],
"metadata": {
"id": "tKp64OSWlC7w"
},
"execution_count": 8,
"outputs": []
},
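{
"cell_type": "markdown",
"source": [
"The two helpers above implement the per-block relevance update that `generate_relevance` accumulates (a hedged summary read off the code, written in generic notation rather than the paper's exact numbering). For block $b$ with attention map $A^{(b)}$ and its gradient $\\nabla A^{(b)}$:\n",
"\n",
"$$\\bar{A}^{(b)} = \\mathbb{E}_h\\big[(\\nabla A^{(b)} \\odot A^{(b)})^{+}\\big], \\qquad R \\leftarrow R + \\bar{A}^{(b)} R$$\n",
"\n",
"where $\\odot$ is element-wise multiplication, $(\\cdot)^{+}$ clamps negatives to zero, $\\mathbb{E}_h$ averages over attention heads, and $R$ is initialised to the identity over tokens. The returned slice `R[0, 1:]` is the relevance of each patch token with respect to the `[CLS]` token."
],
"metadata": {}
},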
{
"cell_type": "code",
"source": [
"def generate_visualization(model, original_image, class_index=None):\n",
" with torch.enable_grad():\n",
" transformer_attribution = generate_relevance(model, original_image.unsqueeze(0).cuda(), index=class_index).detach()\n",
" transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)\n",
" transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')\n",
" transformer_attribution = transformer_attribution.reshape(224, 224).cuda().data.cpu().numpy()\n",
" transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())\n",
" \n",
" image_transformer_attribution = original_image.permute(1, 2, 0).data.cpu().numpy()\n",
" image_transformer_attribution = (image_transformer_attribution - image_transformer_attribution.min()) / (image_transformer_attribution.max() - image_transformer_attribution.min())\n",
" vis = show_cam_on_image(image_transformer_attribution, transformer_attribution)\n",
" vis = np.uint8(255 * vis)\n",
" vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)\n",
" return vis\n",
"\n",
"def print_top_classes(predictions, **kwargs): \n",
" # Print Top-5 predictions\n",
" prob = torch.softmax(predictions, dim=1)\n",
" class_indices = predictions.data.topk(5, dim=1)[1][0].tolist()\n",
" max_str_len = 0\n",
" class_names = []\n",
" for cls_idx in class_indices:\n",
" class_names.append(CLS2IDX[cls_idx])\n",
" if len(CLS2IDX[cls_idx]) > max_str_len:\n",
" max_str_len = len(CLS2IDX[cls_idx])\n",
" \n",
" print('Top 5 classes:')\n",
" for cls_idx in class_indices:\n",
" output_string = '\\t{} : {}'.format(cls_idx, CLS2IDX[cls_idx])\n",
" output_string += ' ' * (max_str_len - len(CLS2IDX[cls_idx])) + '\\t\\t'\n",
" output_string += 'value = {:.3f}\\t prob = {:.1f}%'.format(predictions[0, cls_idx], 100 * prob[0, cls_idx])\n",
" print(output_string)"
],
"metadata": {
"id": "rmQ9pacLoGze"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# ImageNet-A"
],
"metadata": {
"id": "5o-p6euNMibE"
}
},
{
"cell_type": "code",
"source": [
"with torch.no_grad():\n",
" image = Image.open(f'samples/{model_name}/a.png')\n",
" dog_cat_image = transform_224(image)\n",
"\n",
" fig, axs = plt.subplots(1, 2)\n",
" fig.set_size_inches(10, 7)\n",
" axs[0].imshow(image);\n",
" axs[0].axis('off');\n",
"\n",
" output = model(dog_cat_image.unsqueeze(0).cuda())\n",
" print(\"original model\")\n",
" print_top_classes(output)\n",
"\n",
" out = generate_visualization(model, dog_cat_image)\n",
"\n",
" fig.suptitle('original model',y=0.8)\n",
" axs[1].imshow(out);\n",
" axs[1].axis('off');\n",
"\n",
" fig, axs = plt.subplots(1, 2)\n",
" fig.set_size_inches(10, 7)\n",
" axs[0].imshow(image);\n",
" axs[0].axis('off');\n",
" output = model_finetuned(dog_cat_image.unsqueeze(0).cuda())\n",
" print(\"finetuned model\")\n",
" print_top_classes(output)\n",
"\n",
" out = generate_visualization(model_finetuned, dog_cat_image)\n",
"\n",
" fig.suptitle('finetuned model',y=0.8)\n",
" axs[1].imshow(out);\n",
" axs[1].axis('off');"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 842
},
"id": "q8hsWi_WMlkb",
"outputId": "2a89083c-53b6-4b5e-c7b0-c01d9c90446c"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"original model\n",
"Top 5 classes:\n",
"\t829 : streetcar, tram, tramcar, trolley, trolley car\t\tvalue = 10.911\t prob = 57.2%\n",
"\t874 : trolleybus, trolley coach, trackless trolley \t\tvalue = 10.221\t prob = 28.7%\n",
"\t466 : bullet train, bullet \t\tvalue = 6.897\t prob = 1.0%\n",
"\t733 : pole \t\tvalue = 6.878\t prob = 1.0%\n",
"\t547 : electric locomotive \t\tvalue = 6.626\t prob = 0.8%\n",
"finetuned model\n",
"Top 5 classes:\n",
"\t847 : tank, army tank, armored combat vehicle, armoured combat vehicle\t\tvalue = 11.573\t prob = 60.1%\n",
"\t408 : amphibian, amphibious vehicle \t\tvalue = 10.085\t prob = 13.6%\n",
"\t874 : trolleybus, trolley coach, trackless trolley \t\tvalue = 9.585\t prob = 8.2%\n",
"\t829 : streetcar, tram, tramcar, trolley, trolley car \t\tvalue = 9.583\t prob = 8.2%\n",
"\t586 : half track \t\tvalue = 7.935\t prob = 1.6%\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"