add config
Browse files- README.md +42 -0
- finetuning.ipynb +60 -0
- training_loop.py +16 -3
README.md
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
tags:
|
4 |
+
- vision
|
5 |
+
- image-segmentation
|
6 |
+
datasets:
|
7 |
+
- segments/sidewalk-semantic
|
8 |
+
---
|
9 |
+
|
10 |
+
# SegFormer (b0-sized) model fine-tuned on sidewalk-semantic dataset
|
11 |
+
|
12 |
+
SegFormer model fine-tuned on segments/sidewalk-semantic at resolution 512x512. It was introduced in the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Xie et al. and first released in [this repository](https://github.com/NVlabs/SegFormer).
|
13 |
+
|
14 |
+
## Model description
|
15 |
+
|
16 |
+
SegFormer consists of a hierarchical Transformer encoder and a lightweight all-MLP decode head to achieve great results on semantic segmentation benchmarks such as ADE20K and Cityscapes. The hierarchical Transformer is first pre-trained on ImageNet-1k, after which a decode head is added and the two are fine-tuned together on a downstream dataset.
|
17 |
+
|
18 |
+
## Intended uses & limitations
|
19 |
+
|
20 |
+
You can use the raw model for semantic segmentation. See the [model hub](https://huggingface.co/models?other=segformer) to look for fine-tuned versions on a task that interests you.
|
21 |
+
|
22 |
+
### How to use
|
23 |
+
|
24 |
+
Here is how to use this model to perform semantic segmentation on an image from the COCO 2017 dataset:
|
25 |
+
|
26 |
+
```python
|
27 |
+
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
|
28 |
+
from PIL import Image
|
29 |
+
import requests
|
30 |
+
|
31 |
+
feature_extractor = SegformerFeatureExtractor(reduce_labels=True)
|
32 |
+
model = SegformerForSemanticSegmentation.from_pretrained("ChainYo/segformer-sidewalk")
|
33 |
+
|
34 |
+
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
35 |
+
image = Image.open(requests.get(url, stream=True).raw)
|
36 |
+
|
37 |
+
inputs = feature_extractor(images=image, return_tensors="pt")
|
38 |
+
outputs = model(**inputs)
|
39 |
+
logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
|
40 |
+
```
|
41 |
+
|
42 |
+
For more code examples, please refer to the [documentation](https://huggingface.co/transformers/model_doc/segformer.html#).
|
finetuning.ipynb
CHANGED
@@ -938,6 +938,66 @@
|
|
938 |
"batch = next(iter(train_dataloader))"
|
939 |
]
|
940 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
941 |
{
|
942 |
"cell_type": "code",
|
943 |
"execution_count": null,
|
|
|
938 |
"batch = next(iter(train_dataloader))"
|
939 |
]
|
940 |
},
|
941 |
+
{
|
942 |
+
"cell_type": "code",
|
943 |
+
"execution_count": 31,
|
944 |
+
"metadata": {},
|
945 |
+
"outputs": [
|
946 |
+
{
|
947 |
+
"name": "stderr",
|
948 |
+
"output_type": "stream",
|
949 |
+
"text": [
|
950 |
+
"Cloning https://huggingface.co/ChainYo/segformer-sidewalk into local empty directory.\n",
|
951 |
+
"remote: Enforcing permissions... \n",
|
952 |
+
"remote: Allowed refs: all \n",
|
953 |
+
"To https://huggingface.co/ChainYo/segformer-sidewalk\n",
|
954 |
+
" c75c928..5d5f276 main -> main\n",
|
955 |
+
"\n"
|
956 |
+
]
|
957 |
+
},
|
958 |
+
{
|
959 |
+
"ename": "OSError",
|
960 |
+
"evalue": "It looks like the config file at '/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt' is not a valid JSON file.",
|
961 |
+
"output_type": "error",
|
962 |
+
"traceback": [
|
963 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
964 |
+
"\u001b[0;31mUnicodeDecodeError\u001b[0m Traceback (most recent call last)",
|
965 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py:650\u001b[0m, in \u001b[0;36mPretrainedConfig._get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=647'>648</a>\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=648'>649</a>\u001b[0m \u001b[39m# Load config dict\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=649'>650</a>\u001b[0m config_dict \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49m_dict_from_json_file(resolved_config_file)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=650'>651</a>\u001b[0m \u001b[39mexcept\u001b[39;00m (json\u001b[39m.\u001b[39mJSONDecodeError, \u001b[39mUnicodeDecodeError\u001b[39;00m):\n",
|
966 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py:733\u001b[0m, in \u001b[0;36mPretrainedConfig._dict_from_json_file\u001b[0;34m(cls, json_file)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=731'>732</a>\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39m(json_file, \u001b[39m\"\u001b[39m\u001b[39mr\u001b[39m\u001b[39m\"\u001b[39m, encoding\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mas\u001b[39;00m reader:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=732'>733</a>\u001b[0m text \u001b[39m=\u001b[39m reader\u001b[39m.\u001b[39;49mread()\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=733'>734</a>\u001b[0m \u001b[39mreturn\u001b[39;00m json\u001b[39m.\u001b[39mloads(text)\n",
|
967 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/codecs.py:322\u001b[0m, in \u001b[0;36mBufferedIncrementalDecoder.decode\u001b[0;34m(self, input, final)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/codecs.py?line=320'>321</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuffer \u001b[39m+\u001b[39m \u001b[39minput\u001b[39m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/codecs.py?line=321'>322</a>\u001b[0m (result, consumed) \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_buffer_decode(data, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49merrors, final)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/codecs.py?line=322'>323</a>\u001b[0m \u001b[39m# keep undecoded input until the next call\u001b[39;00m\n",
|
968 |
+
"\u001b[0;31mUnicodeDecodeError\u001b[0m: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte",
|
969 |
+
"\nDuring handling of the above exception, another exception occurred:\n",
|
970 |
+
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
|
971 |
+
"\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 23'\u001b[0m in \u001b[0;36m<cell line: 11>\u001b[0;34m()\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=7'>8</a>\u001b[0m config \u001b[39m=\u001b[39m AutoConfig\u001b[39m.\u001b[39mfrom_pretrained(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mnvidia/mit-b0\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=8'>9</a>\u001b[0m config\u001b[39m.\u001b[39mpush_to_hub(\u001b[39m\"\u001b[39m\u001b[39msegformer-sidewalk\u001b[39m\u001b[39m\"\u001b[39m, repo_url\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mhttps://huggingface.co/ChainYo/segformer-sidewalk\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=10'>11</a>\u001b[0m model \u001b[39m=\u001b[39m SegformerForSemanticSegmentation\u001b[39m.\u001b[39;49mfrom_pretrained(\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=11'>12</a>\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39m/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt\u001b[39;49m\u001b[39m\"\u001b[39;49m, \n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=12'>13</a>\u001b[0m num_labels\u001b[39m=\u001b[39;49mnum_labels, \n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=13'>14</a>\u001b[0m id2label\u001b[39m=\u001b[39;49mid2label, \n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=14'>15</a>\u001b[0m label2id\u001b[39m=\u001b[39;49mid2label,\n\u001b[1;32m <a 
href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=15'>16</a>\u001b[0m )\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=16'>17</a>\u001b[0m model\u001b[39m.\u001b[39mpush_to_hub(\u001b[39m\"\u001b[39m\u001b[39msegformer-sidewalk\u001b[39m\u001b[39m\"\u001b[39m, repo_url\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mhttps://huggingface.co/ChainYo/segformer-sidewalk\u001b[39m\u001b[39m\"\u001b[39m)\n",
|
972 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py:1764\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1761'>1762</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1762'>1763</a>\u001b[0m config_path \u001b[39m=\u001b[39m config \u001b[39mif\u001b[39;00m config \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1763'>1764</a>\u001b[0m config, model_kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49mconfig_class\u001b[39m.\u001b[39;49mfrom_pretrained(\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1764'>1765</a>\u001b[0m config_path,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1765'>1766</a>\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1766'>1767</a>\u001b[0m return_unused_kwargs\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1767'>1768</a>\u001b[0m 
force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1768'>1769</a>\u001b[0m resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1769'>1770</a>\u001b[0m proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1770'>1771</a>\u001b[0m local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1771'>1772</a>\u001b[0m use_auth_token\u001b[39m=\u001b[39;49muse_auth_token,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1772'>1773</a>\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1773'>1774</a>\u001b[0m _from_auto\u001b[39m=\u001b[39;49mfrom_auto_class,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1774'>1775</a>\u001b[0m _from_pipeline\u001b[39m=\u001b[39;49mfrom_pipeline,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1775'>1776</a>\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1776'>1777</a>\u001b[0m )\n\u001b[1;32m <a 
href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1777'>1778</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1778'>1779</a>\u001b[0m model_kwargs \u001b[39m=\u001b[39m kwargs\n",
|
973 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py:526\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=451'>452</a>\u001b[0m \u001b[39m@classmethod\u001b[39m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=452'>453</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mfrom_pretrained\u001b[39m(\u001b[39mcls\u001b[39m, pretrained_model_name_or_path: Union[\u001b[39mstr\u001b[39m, os\u001b[39m.\u001b[39mPathLike], \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mPretrainedConfig\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=453'>454</a>\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=454'>455</a>\u001b[0m \u001b[39m Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=455'>456</a>\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=523'>524</a>\u001b[0m \u001b[39m assert unused_kwargs == {\"foo\": False}\u001b[39;00m\n\u001b[1;32m <a 
href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=524'>525</a>\u001b[0m \u001b[39m ```\"\"\"\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=525'>526</a>\u001b[0m config_dict, kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49mget_config_dict(pretrained_model_name_or_path, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=526'>527</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m config_dict \u001b[39mand\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mcls\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mand\u001b[39;00m config_dict[\u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m!=\u001b[39m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mmodel_type:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=527'>528</a>\u001b[0m logger\u001b[39m.\u001b[39mwarning(\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=528'>529</a>\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mYou are using a model of type \u001b[39m\u001b[39m{\u001b[39;00mconfig_dict[\u001b[39m'\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m'\u001b[39m]\u001b[39m}\u001b[39;00m\u001b[39m to instantiate a model of type \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m <a 
href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=529'>530</a>\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mmodel_type\u001b[39m}\u001b[39;00m\u001b[39m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=530'>531</a>\u001b[0m )\n",
|
974 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py:553\u001b[0m, in \u001b[0;36mPretrainedConfig.get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=550'>551</a>\u001b[0m original_kwargs \u001b[39m=\u001b[39m copy\u001b[39m.\u001b[39mdeepcopy(kwargs)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=551'>552</a>\u001b[0m \u001b[39m# Get config dict associated with the base config file\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=552'>553</a>\u001b[0m config_dict, kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49m_get_config_dict(pretrained_model_name_or_path, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=554'>555</a>\u001b[0m \u001b[39m# That config file may point us toward another config file to use.\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=555'>556</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39mconfiguration_files\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m config_dict:\n",
|
975 |
+
"File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py:652\u001b[0m, in \u001b[0;36mPretrainedConfig._get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=649'>650</a>\u001b[0m config_dict \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m_dict_from_json_file(resolved_config_file)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=650'>651</a>\u001b[0m \u001b[39mexcept\u001b[39;00m (json\u001b[39m.\u001b[39mJSONDecodeError, \u001b[39mUnicodeDecodeError\u001b[39;00m):\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=651'>652</a>\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=652'>653</a>\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mIt looks like the config file at \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mresolved_config_file\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m is not a valid JSON file.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=653'>654</a>\u001b[0m )\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=655'>656</a>\u001b[0m \u001b[39mif\u001b[39;00m resolved_config_file \u001b[39m==\u001b[39m config_file:\n\u001b[1;32m <a 
href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=656'>657</a>\u001b[0m logger\u001b[39m.\u001b[39minfo(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mloading configuration file \u001b[39m\u001b[39m{\u001b[39;00mconfig_file\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
|
976 |
+
"\u001b[0;31mOSError\u001b[0m: It looks like the config file at '/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt' is not a valid JSON file."
|
977 |
+
]
|
978 |
+
}
|
979 |
+
],
|
980 |
+
"source": [
|
981 |
+
"import json\n",
|
982 |
+
"from transformers import AutoConfig\n",
|
983 |
+
"\n",
|
984 |
+
"id2label_file = json.load(open(\"id2label.json\", \"r\"))\n",
|
985 |
+
"id2label = {int(k): v for k, v in id2label_file.items()}\n",
|
986 |
+
"num_labels = len(id2label)\n",
|
987 |
+
"\n",
|
988 |
+
"config = AutoConfig.from_pretrained(f\"nvidia/mit-b0\")\n",
|
989 |
+
"config.push_to_hub(\".\", repo_url=\"https://huggingface.co/ChainYo/segformer-sidewalk\")\n",
|
990 |
+
"\n",
|
991 |
+
"model = SegformerForSemanticSegmentation.from_pretrained(\n",
|
992 |
+
" \"/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt\", \n",
|
993 |
+
" num_labels=num_labels, \n",
|
994 |
+
" id2label=id2label, \n",
|
995 |
+
" label2id=id2label,\n",
|
996 |
+
" config=config,\n",
|
997 |
+
")\n",
|
998 |
+
"model.push_to_hub(\".\", repo_url=\"https://huggingface.co/ChainYo/segformer-sidewalk\")"
|
999 |
+
]
|
1000 |
+
},
|
1001 |
{
|
1002 |
"cell_type": "code",
|
1003 |
"execution_count": null,
|
training_loop.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
"""
|
2 |
Minimal command:
|
3 |
-
python training_loop.py --hub_dir "segments/sidewalk-semantic"
|
4 |
|
5 |
Maximal command:
|
6 |
-
python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train
|
7 |
"""
|
8 |
|
9 |
import json
|
@@ -12,6 +12,8 @@ import torch
|
|
12 |
from pytorch_lightning import Trainer, callbacks, seed_everything
|
13 |
from pytorch_lightning.loggers import WandbLogger
|
14 |
|
|
|
|
|
15 |
from dataloader import SidewalkSegmentationDataLoader
|
16 |
from model import SidewalkSegmentationModel
|
17 |
|
@@ -23,6 +25,7 @@ def main(
|
|
23 |
model_flavor: int = 0,
|
24 |
seed: int = 42,
|
25 |
split: str = "train",
|
|
|
26 |
):
|
27 |
seed_everything(seed)
|
28 |
logger = WandbLogger(project="sidewalk-segmentation")
|
@@ -69,6 +72,15 @@ def main(
|
|
69 |
)
|
70 |
trainer.fit(model, data_module)
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
if __name__ == "__main__":
|
74 |
import argparse
|
@@ -80,6 +92,7 @@ if __name__ == "__main__":
|
|
80 |
parser.add_argument("--model_flavor", type=int, default=0)
|
81 |
parser.add_argument("--seed", type=int, default=42)
|
82 |
parser.add_argument("--split", type=str, default="train")
|
|
|
83 |
args = parser.parse_args()
|
84 |
|
85 |
main(
|
@@ -89,5 +102,5 @@ if __name__ == "__main__":
|
|
89 |
model_flavor=args.model_flavor,
|
90 |
seed=args.seed,
|
91 |
split=args.split,
|
|
|
92 |
)
|
93 |
-
|
|
|
1 |
"""
|
2 |
Minimal command:
|
3 |
+
python training_loop.py --hub_dir "segments/sidewalk-semantic" --push_to_hub
|
4 |
|
5 |
Maximal command:
|
6 |
+
python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train --push_to_hub
|
7 |
"""
|
8 |
|
9 |
import json
|
|
|
12 |
from pytorch_lightning import Trainer, callbacks, seed_everything
|
13 |
from pytorch_lightning.loggers import WandbLogger
|
14 |
|
15 |
+
from transformers import AutoConfig, SegformerForSemanticSegmentation, SegformerFeatureExtractor
|
16 |
+
|
17 |
from dataloader import SidewalkSegmentationDataLoader
|
18 |
from model import SidewalkSegmentationModel
|
19 |
|
|
|
25 |
model_flavor: int = 0,
|
26 |
seed: int = 42,
|
27 |
split: str = "train",
|
28 |
+
push_to_hub: bool = False,
|
29 |
):
|
30 |
seed_everything(seed)
|
31 |
logger = WandbLogger(project="sidewalk-segmentation")
|
|
|
72 |
)
|
73 |
trainer.fit(model, data_module)
|
74 |
|
75 |
+
if push_to_hub:
|
76 |
+
config = AutoConfig.from_pretrained(f"nvidia/mit-b{model_flavor}")
|
77 |
+
config.push_to_hub("segformer-sidewalk", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
|
78 |
+
checkpoint_path = checkpoint_callback.best_model_filepath
|
79 |
+
model = SegformerForSemanticSegmentation.from_pretrained(
|
80 |
+
checkpoint_path, num_labels=num_labels, id2label=id2label, label2id=id2label, config=config,
|
81 |
+
)
|
82 |
+
model.push_to_hub("segformer-sidewalk", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
|
83 |
+
|
84 |
|
85 |
if __name__ == "__main__":
|
86 |
import argparse
|
|
|
92 |
parser.add_argument("--model_flavor", type=int, default=0)
|
93 |
parser.add_argument("--seed", type=int, default=42)
|
94 |
parser.add_argument("--split", type=str, default="train")
|
95 |
+
parser.add_argument("--push_to_hub", action="store_true")
|
96 |
args = parser.parse_args()
|
97 |
|
98 |
main(
|
|
|
102 |
model_flavor=args.model_flavor,
|
103 |
seed=args.seed,
|
104 |
split=args.split,
|
105 |
+
push_to_hub=args.push_to_hub,
|
106 |
)
|
|