{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [ "-tNVQkHnZfrs" ] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "gpuClass": "standard" }, "cells": [ { "cell_type": "markdown", "source": [ "## Setup" ], "metadata": { "id": "-tNVQkHnZfrs" } },
{ "cell_type": "code", "source": [
"# Fresh clone of the segmentation repo, installed in editable mode.\n",
"# `-q` keeps pip's dependency-resolution log out of the notebook output.\n",
"%cd /content/\n",
"%rm -rf semantic-segmentation\n",
"!git clone https://github.com/hb0313/semantic-segmentation\n",
"%cd semantic-segmentation\n",
"%pip install -q -e .\n",
"%pip install -q -U gdown"
], "metadata": { "id": "pzBeWQDQZdic", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "30620197-c859-44ec-d7aa-eeccd032cdcc" }, "execution_count": 5, "outputs": [] },
{ "cell_type": "markdown", "source": [ "## Definitions for loading the model and checkpoints" ], "metadata": { "id": "L-WE4q6_ZHdQ" } },
{ "cell_type": "code", "source": [ "import gdown\n", "from pathlib import Path\n", "import torch\n", "from torchvision 
import io\n",
"from torchvision import transforms as T\n",
"from PIL import Image\n",
"from semseg.models import SegFormer\n",
"from google.colab import files\n",
"\n",
"# Pretrained SegFormer-B3 weights for ADE20K (150 classes), hosted on the Hugging Face Hub.\n",
"CKPT_URL = 'https://huggingface.co/hashb/semantic-segmentation-segformer/resolve/main/segformer.b3.ade.pth'\n",
"CKPT_PATH = Path('./checkpoints/pretrained/segformer/segformer.b3.ade.pth')\n",
"\n",
"\n",
"def get_checkpoints(url=CKPT_URL, output=CKPT_PATH):\n",
"    \"\"\"Download the checkpoint at `url` to `output` unless it already exists.\n",
"\n",
"    Returns the local checkpoint path as a `Path`.\n",
"    \"\"\"\n",
"    output = Path(output)\n",
"    output.parent.mkdir(exist_ok=True, parents=True)\n",
"    # Skip the ~190 MB download when this cell is re-run.\n",
"    if not output.exists():\n",
"        gdown.download(url, str(output), quiet=False)\n",
"    return output\n",
"\n",
"\n",
"def show_image(image):\n",
"    \"\"\"Convert a uint8 image tensor (HxWx3 or CxHxW) into a PIL image for display.\"\"\"\n",
"    # A last dimension of size 3 is treated as channels-last; otherwise assume CHW.\n",
"    if image.shape[2] != 3:\n",
"        image = image.permute(1, 2, 0)\n",
"    return Image.fromarray(image.cpu().numpy())\n",
"\n",
"\n",
"def load_model(weights=CKPT_PATH):\n",
"    \"\"\"Build SegFormer (MiT-B3 backbone, 150 classes) and load pretrained weights.\n",
"\n",
"    Raises FileNotFoundError if the checkpoint is missing, instead of silently\n",
"    returning a randomly initialised model. Errors from `load_state_dict`\n",
"    (e.g. mismatched keys) are propagated rather than swallowed.\n",
"    \"\"\"\n",
"    weights = Path(weights)\n",
"    if not weights.exists():\n",
"        raise FileNotFoundError(\n",
"            f\"Checkpoint not found at {weights}; run get_checkpoints() first.\"\n",
"        )\n",
"    model = SegFormer(backbone='MiT-B3', num_classes=150)\n",
"    # NOTE: torch.load unpickles arbitrary objects; only load checkpoints you trust.\n",
"    model.load_state_dict(torch.load(weights, map_location='cpu'))\n",
"    model.eval()\n",
"    print('Loaded Model')\n",
"    return model"
], "metadata": { "id": "QnfB4lrzjo33" }, "execution_count": 6, "outputs": [] },
{ "cell_type": "code", "source": [ "get_checkpoints()\n", "model = load_model()" ], "metadata": { "id": "L1YCKFKJKP1H", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "0b5b2aeb-0978-46fd-8638-bb4976431c60" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Downloading...\n", "From: https://huggingface.co/hashb/semantic-segmentation-segformer/resolve/main/segformer.b3.ade.pth\n", "To: /content/semantic-segmentation/checkpoints/pretrained/segformer/segformer.b3.ade.pth\n", "100%|██████████| 190M/190M [00:00<00:00, 206MB/s]\n" ] } ] },
{ "cell_type": "markdown", "source": [ "## Upload an image file" ], "metadata": { "id": "FRDvSMmvoK_0" } },
{ "cell_type": "code", "source": [ 
"uploaded = files.upload()\n", "for i in uploaded:\n", " image_path = i\n", "image = io.read_image(image_path)" ], "metadata": { "colab": { "resources": { "http://localhost:8080/nbextensions/google.colab/files.js": { "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3
VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbm
NlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2
VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgZG8gewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwoKICAgICAgbGV0IHBlcmNlbnREb25lID0gZmlsZURhdGEuYnl0ZUxlbmd0aCA9PT0gMCA/CiAgICAgICAgICAxMDAgOgogICAgICAgICAgTWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCk7CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPSBgJHtwZXJjZW50RG9uZX0lIGRvbmVgOwoKICAgIH0gd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCk7CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK", "ok": true, "headers": [ [ "content-type", "application/javascript" ] ], "status": 200, "status_text": "" } }, "base_uri": "https://localhost:8080/", "height": 73 }, "id": "k2cOX2CUaZuK", "outputId": "10aa9d41-cedf-4ebe-e39e-501ce52060f9" }, "execution_count": 8, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", " \n", " \n", " Upload widget is only available when the cell has been executed in the\n", " current browser session. 
Please rerun this cell to enable.\n", " \n", " " ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "Saving pexels-photo-1485031.jpeg to pexels-photo-1485031.jpeg\n" ] } ] },
{ "cell_type": "code", "source": [
"# keep exactly 3 channels: drop an alpha channel, replicate a grayscale one\n",
"rgb = image[:3] if image.shape[0] >= 3 else image[:1].expand(3, -1, -1)\n",
"# center-crop to 512x512 (crops only; the image is not rescaled)\n",
"img = T.CenterCrop((512, 512))(rgb)\n",
"# scale to [0.0, 1.0]\n",
"img = img.float() / 255\n",
"# normalize with the standard ImageNet mean/std\n",
"img = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(img)\n",
"# add a batch dimension; `image` itself is left untouched so this cell can be re-run\n",
"input_tensor = img.unsqueeze(0)\n",
"input_tensor.shape"
], "metadata": { "id": "Mnfbjt3vjzmI", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "e3255e3f-c3ed-47f6-f8db-c72c679af496" }, "execution_count": 9, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([1, 3, 512, 512])" ] }, "metadata": {}, "execution_count": 9 } ] },
{ "cell_type": "code", "source": [
"with torch.inference_mode():\n",
"    logits = model(input_tensor)\n",
"logits.shape"
], "metadata": { "id": "lJ6xNAwzj17M", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "7253e889-2b43-4f5c-ebcd-731001f4dd17" }, "execution_count": 10, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "torch.Size([1, 150, 512, 512])" ] }, "metadata": {}, "execution_count": 10 } ] },
{ "cell_type": "code", "source": [
"# per-pixel class index: argmax over the raw logits equals argmax over their softmax\n",
"seg = logits.argmax(1)\n",
"seg.unique()"
], "metadata": { "id": "EqbG_qokj30-", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "ad1c2e60-777b-45ee-f27f-44fff275fabb" }, "execution_count": 11, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "tensor([ 0, 1, 4, 12])" ] }, "metadata": {}, "execution_count": 11 } ] },
{ "cell_type": "code", "source": [
"from semseg.datasets import ADE20K\n",
"\n",
"palette = ADE20K.PALETTE"
], "metadata": { "id": "dyj4dPoMj7aq" }, "execution_count": 12, "outputs": [] },
{ "cell_type": "code", "source": [ "seg_map = palette[seg].squeeze().to(torch.uint8)\n", "show_image(seg_map)" ], "metadata":
{ "id": "02ptBGPAj8_g", "colab": { "base_uri": "https://localhost:8080/", "height": 529 }, "outputId": "0df8f3b6-7bfe-4ad3-a5e3-f794363abdbe" }, "execution_count": 13, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIACAIAAAB7GkOtAAAQuElEQVR4nO3dyXEkR5aA4eAY5IAgRFMHSDIXWJmVDmW4UBIcqMEMRhBIMoekgSCWRGYs7m/5PuOpu7o72OnxfvfIhb/99fPnAqu8PD7PvgSuc/twN/sS6ri//zX7Erb6r9kXAMAcAsB6tpOQmgBAF4LNOwIA0JQAADQlAGziqQJtPT39mH0JWwkAQFMCALBS9kOAAAA0JQDQhW9uHyH1IUAAAJoSAGjEIeAIeQ8BAgC9aMARkjZAAKCdl8dnGWARAGhLBvaV8RAgANCaDHR2M/sCyM3sqOH1dfTbHq04AbCe6V+PA0ErAsBKxkRhXtwmBAD4hKNABwLAGkYDfJTug0ACAHxJ6WsTAK5mKLTi5S5MAIBvaEBVAsB1zAL4yv39r9mXcB0BAL4n/CUJAJfyucDmvPrf8ikganLzs1gG5QgA33Pb88piqEQA+IYbHqoSAM4x/fnIqihDAICraUANAsCX3OScYXkUIAAATQkAn7O/41sWSXYCAKynAe88Pf1I9HUwAeAT7mrYIksGBADYxHbhK/EbIADAVhqQlADwnpuZFSybjAQA2IcGpCMA/It7mC2sn1wEAKApAeAftm9sZxUlIgD8zX0L3QgAQFMCAOzMaTILAQBoSgBYFls2aEkAAJoSAGz/2Z9FlYIAADQlAN3ZqUFbAgAcwt4iPgFozS3KoSyw4AQAoCkBAA7kEBCZAAAc4v7+1+xL+IYA9GVrxhhWWlgCALC/+Nv/RQDasiljpG7rLcX0XwQAGKNbA1IQgI7cisAiAA2Z/sxi7UUjAABNCUAvtmDMZQWGIgAATQlAIzZfRFB+HWb5DOgiAAA7SjT9FwEAxit/CMhCALpwywHvCEALpj/RWJMRCEB97jTgUwIA0JQAFGf7T1gW53QCANCUAAA0JQCVOWITnCU6lwAANCUAZdlbkYKFOpEA1OSmAr4lAMBk9iuzCABAUwJQkP0U6dRYtLl+C3oRgHpq3Eg0ZOmOJwCluIWAywkAEEXqHUy65z+LAAC0JQB1pN49wUnSZZxx+78IAEBbAlBE0n0TfJRxMT89/Zh9CWsIAMAOMjZAACrIuGOCM5Iu6XQNEID0kt4qcJ6FPYAAAEFpwNEEIDd3CLCaAABx5dripPs2gAAkluveAKIRACC0RBsdnwJikER3BRCTAADR2e4cRABScj8A2wkAQFMCACTg1HsEAQBoSgAAmhKAfJyF6cnK350AADQlAABNCUAyTsF0Zv3vSwAAmhIAgKYEIJnbh7vZlwAzBX8KlOsHQQUgmeCrH0hEAACaEgCApgQgE89/YHEj7EcAgDVuH+58JCE7AUjDrgfiu7//NfsSriAACbw8Ppv+8JY7YhcCEJ2FDlnk2v4vAgAkZW+03c3sC+BL1jeRRVifL4/P3ojeQgAiinBrAeUJQBSGPmT39PQj19sA3gMIwfSHddw7WwjAfFYwMIUATGb6w0ZuotUEYCYLF5hIAObw5V7YkbtpHZ8CGs1KhapyfQRocQIYya4fjuPmWsEJYARLEwjICeBwpj+M4V67lhPAgSxH6CPdGwCLE8BxTH8Yz313FSeAnVl/9HT7cGfxpyMA+7D0IQi/EX05j4B2YPqDuyAjAdjER/shIHflhQRgPYsMOMn4EaBlWX776+fP2deQjLkPKQx+JyBjA5wArmP6A2UIwBVMf+ArT08/Zl/C1Q
TgIt7shXTcs9/yPYBzLCCgMCeAL5n+kJ27+DwngE9YNLCL0+dw3FBhOQG8Z7FCJe7oMwTgX6wVoA8B+IfpDyW5tb8iAH+zRIBuBGBZTH+gJQEw/aG+Abe53wLKx/QH2modANMfDhXqN1TiXEkcfQNgNUA37vp3+gYAaEgD3moaAIsAoGkAgLbs/151DICXH5ozBE46BgBgXxm/BLA0DIDyA4tRsCxLwwAAcNIrAJoPvDIQegUA4K3mDRAAgKYaBaB56mG824e701+zL+SczpOhUQCAkd7O/fgZ6KlLADpHHjiv7XzoEgBgpE/3+w4B0bQIQNu8A5zRIgDASBl3+j23iQIAsEnSHwJaOgSgZ9iBazWcFfUDAMCnigegYdJhujP3XfC3B7pNjOIBAMYLPuV5JQDAbgp84/faQ0Ded4CXZbmZfQEH6naag4myz/2enACArdpO/9Tb/0UAgO2KnbaL/e2cUTYAfV5CiODl8bnbTZd9+78UDgAwXrcGZCcAwJ5qNODbv4sC2/+lagBqLEFI6swN6N4MpWYAgLkKDPoCfwvfKvg9gA4vG8T39k48fU7UvRmNEwBwuKSfEcp4zVepFoDyLxgwUu2RUi0AAFxIAADOKXwIKBWAwq8TwO5KBQDgCFU3lwIAcJ0aXwNeKgWgaqKBCEpOmDoBADjUqQFltv+LAABcrtg5oEgAir0qAAMUCQDAGH/+8fvsS9iNAAA0JQAATVUIgDcAgJHKPAWqEAAAVhAAgKvVOAQIAEBT6QPgDQBgigKHgPQBAGAdAQBoSgAAmhIAgKZyB8A7wMBE2d8Hzh0AAFYTAICmBACgKQEAaCpxALwDDLBF4gAATJf6g0ACANCUAAA0JQAATQkAwCZ53wYQAICmBACgKQEAaEoAAJoSAICmBACgKQEAaEoAAJoSAICmBABgk//+n/+bfQkrCQBAUwIA0JQAAGzix+AASEYAAJoSAICtkj4FEgCAHWRsgAAA7OPPP37PlYGb2RcAUMrbBgT/jpgAABzlNQYxS+AREMDhYj4aEgCAEQK+Q+AREMA4oR4KOQEATBDhQJA1AC+Pz7MvAWCruQ3IGgCAGiYeBQQAYL4pGRAAgCgGN0AAAAIZ2QABAIhlWAMEACCcMQ0QAICIBjRAAACaShkA3wIDOjj6EJAyAABNHNoAAQAI7bgGpAzA7cPd7EsASC9lAADYTgCgNefpFA56CuQfCAM1vZ3s7z44927on/mTxPHnH7/v/s+QEQAo6MyIv/A/qAQdeAQE1ezyVOf24e701/b/Kvay+4MgAQDOkYFQ9m2AR0BQykHD2vsEJaUMgPUHnxqzVX/3v+J+zCtlAIA4HA4G2/HjQAIARUR4Uu9wkIsAQAURpv9Hn17VaxXO/7sMIADAUDFb1ZOPgQIks9eHQQUA0rOnbmiXBggAQFMCALlV2v57B/gq2w8BAgCJVZr+jJcyABY9wLL5EJAyAMBiJ8SyLNsakDIAHhRCvenvvh4vZQCgOdOft1YfAgQAkjH92YsAADOZ/hMJAEBuq38dWgCAaWz/5xIAyKTeGwBMJADAHLb/0wkAQClPTz8u/JP+gTCQRqXnP7b/BzlN/7cNuL//9emfWQQAsqg0/RnpNO5PGXh3OBAAYDTb//E+fS7kPQBIwPafIwgAMJTt/+5efwvo8rd/TzwCgugqbf9N/4NcO/pP8p0ALCBaqTT9Oc66wZgvAEBSdm/RCADEVWn7b/oHlOw9AGsI0nHbhuUEABzI9B/m/v7Xxy/9npfsBAB9FHj+Y/qP91UDfBEMoKlPwyAAAE0JAHAIz38Ge/0+8OUEANif6Z+CAAC08PFtAAEAdmb7P8u1T4EEACLK+xlQ03+up6cfl/8wnAAANCUAEI7tP6td9RIIAEAX794HFgCIxfafYQQA2IHpn5EAAJRyeYwFANjK9j8pAQBoKlMA7DIgIDdmQF99Jfjdd8QyBQDKy/sRIKK55Gch0gTALgOieXl8dmNG9ucfv5
//WYg0AQBCMfqzOPPrQAIAUXj+w+7edvpjBgQAQsg1/W3/Ezm9WIn/ofBWG7WZ/hzqq5csRwCAIEz/pD594QQAJku0/Tf9ixEAmCnR9Ce7j/0WAJgm1/S3/S/g3YsoADCH6c8Ub1/Km4nXAW2Z/kz0+oImOAFYfBRj+hNEggBAJaY/cQgAjJNr+lNeggC4Z2AK2//yEgQAasi1lTH9OxAAGMH0JyABAP7F9O9DAIB/mP6tCADwN9O/mwQBsChhADdaQwkCABzN9O9JAACa8mNw0Jq9f2dOAHC4sF8CMP2bEwBoyvRHAKAj059FAGAA05aYBADaESROBAB6Mf15JQAATQkANGL7z1sCANCUAEAXtv+8IwDQgunPR9EDYNXCdu4jPhU9AMBGpj9fEQCozPTnDAGAskx/zhMAGGH8LDb9+ZYAADQlAFCQ7T+XEACoxvTnQqEDYB3Dtdw1XC50AICrmP5cRQAAmhIAKML2n2sJAFRg+rOCAEB6pj/r3My+AGA9o58tnAAAmgodgNuHu9mXAHHZ/rNR6AAAXzH92U4AIB/Tn10IACRj+rMXAQBoysdAIQ17f/blBAA5mP7sTgBgEBOcaAQAxlnXgJfHZ/HgCAIAoRn9HEcAYKirBrrpz6EEAIIy/Tla9AD4OSDquWSym/4MED0AUNKZ+e4tX4YRAJjj00Fv9DOSAMBMbye+6c9gfgoCJjP3mSXBCcD7wABHSBAAAI6QIwAOAQC7yxEAAHaXJgAOAQD7ShOARQMAdpUpAADsKFkAHAIA9pIsAADsJV8AHAIAdpEvAADsImUAHAIAtksZgEUDADbLGgAANkocAIcAgC0SBwCALXIHwCEAYLXcAVg0AGCt9AEAYB0BAGiqQgA8BQJYoUIAAFhBAACaKhIAT4EArlUkAIsGAFypTgAAuEqpADgEAFyuVAAAuFy1ADgEAFyoWgAAuFDBADgEAFyiYAAWDQC4QM0ALBoA8J2yAQDgvMoBcAgAOKNyABYNAPha8QAsGgDwhfoBAOBTLQLgEADwUYsAAPBRlwA4BAC80yUAALzTKAAOAQBvNQoAAG/1CoBDAMCrm9kXALCnb/d5L4/PY64kvt/++vlz9jWM5uWHYq493BsCJx1PALcPd15+KMBD3Y16vQfwyrqB7LbcxSbASdMALFYA0F7fAAB52cDtQgAAmhIAIJldtv/OEIsAALQlAEAmdu47EgCgKS0RACCN3Ud28wYIAEBTAgDQlAAANCUAQA4HPa/v/DaAAAA0JQAATQkAkEDnBzXHEQCApgQAiM72/yACANCUAAA0JQBAaAOe/7R9xCQAAE31DcDL4/PsSwC+0XZvPkbfAADBmf5HaxoA23/grZ6xaRoAILieE3kwAQDCMf3HEAAgFtN/mI4B8AYAhGX6j9QxAEBMpv9gAgCEYPqPJwAATXUMgI0GROOunKJjAABYegbAp4AgFNv/WToGAIjD9J9IAACaEgBgmlDb/1AXM0a7AHgDAOCkVwBMf4ij4Y47mpvZFzCI0Q/wTv0AGP0Anyr+CMj0h5g8/4mgcgBMf4AzCj4CMvchuJjb/4ajo9oJoOFLCLBOqQCY/hBfzO1/TxUeAZn7kIXpH0riAJj7kEvw6X/7cNdtqmR9BNTtdYLsgk//nlIGwPQH2C5fAEx/SMf2P6ZM7wEY/ZCR6R9WjgAY/ZCU6R9ZgkdApj/AEUKfAIx+yCvd3r/hwAkagIavBFSSbvr3FC4ARj+kZvQnEigARj9kZ/rnEuVNYNMfYLDJJwBzH8qw/U9nUwD+94t//T8X/GeNfqjE9M9oTQC+mvtv/8CZBhj9UEyN6d/w10CvDsC30//1j31sQLf/c6G8GqO/rZsLB/oKb/+bb41+gGBGfArI9IeSbP+zC/Q9ACALo7+Gw08Atv9QjOlfRpQvggEwmAAAV7D9r+TYAHj+A5WY/sU4AQA0JQDARWz/6xEAgKYODIA3AKAM2/+SnAAAmjoqAJf8IjSQgu1/VU
cFwA9/AgTnERBAU/8PzcTX+DN9v8kAAAAASUVORK5CYII=\n" }, "metadata": {}, "execution_count": 13 } ] } ] }