---
library_name: keras-hub
tags:
- image-segmentation
- keras
---

## Model Overview

A Keras model implementing the MixTransformer (MiT) architecture. It is intended as the backbone (hierarchical image encoder) of the SegFormer semantic segmentation architecture. This model is supported in both KerasCV and KerasHub. KerasCV is no longer actively developed, so please use KerasHub.
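
MiT works as a backbone because it is a hierarchical encoder. It has four stages, and each one opens with an overlapping patch embedding. The stages produce feature maps at 1/4, 1/8, 1/16 and 1/32 of the input resolution, and SegFormer's decoder fuses them. The pure-Python sketch below computes those sizes. It assumes per-stage strides of 4, 2, 2, 2 and an input side divisible by 32. `pyramid_sizes` is a hypothetical helper, not a KerasHub API.

```python
# Per-stage downsampling of the four MiT stages. Each stage opens with an
# overlapping patch embedding with strides 4, 2, 2, 2 (assumed from the
# SegFormer paper), giving 1/4, 1/8, 1/16 and 1/32 of the input resolution.
STAGE_STRIDES = (4, 2, 2, 2)


def pyramid_sizes(height, width, strides=STAGE_STRIDES):
    """Hypothetical helper: (height, width) of each stage's feature map.

    Assumes height and width are divisible by 32, so no rounding occurs.
    """
    sizes = []
    for stride in strides:
        height, width = height // stride, width // stride
        sizes.append((height, width))
    return sizes


print(pyramid_sizes(512, 512))  # [(128, 128), (64, 64), (32, 32), (16, 16)]
```

For the 1024x1024 Cityscapes presets, the same arithmetic gives a 32x32 final feature map.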

References:

- [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203)
- [Based on the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/classification/mix_transformer)

## Links

* MiT Quickstart Notebook: coming soon
* MiT API Documentation: coming soon

## Installation

Keras and KerasHub can be installed with:

```
pip install -U -q keras-hub
pip install -U -q "keras>=3"
```

JAX, TensorFlow, and PyTorch come preinstalled in Kaggle Notebooks. For instructions on installing them in another environment, see the [Keras Getting Started](https://keras.io/getting_started/) page.
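
Keras 3 chooses its backend from the `KERAS_BACKEND` environment variable when `keras` is first imported. The minimal sketch below sets that variable before Keras or KerasHub is imported. JAX is only an example choice; `"tensorflow"` and `"torch"` also work. The sketch then checks the installed Keras major version through package metadata, without importing Keras. `keras_major_version` is a hypothetical helper.

```python
import os
from importlib.metadata import PackageNotFoundError, version

# Keras 3 reads KERAS_BACKEND when `keras` is first imported, so set it before
# importing keras or keras_hub. setdefault keeps any value already set.
os.environ.setdefault("KERAS_BACKEND", "jax")


def keras_major_version():
    """Hypothetical helper: installed Keras major version, or None if absent."""
    try:
        return int(version("keras").split(".")[0])
    except PackageNotFoundError:
        return None


major = keras_major_version()
print("Keras major version:", major)
```

If this prints `None` or a version below 3, rerun the `pip install` commands for Keras 3.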

## Presets

The following model checkpoints are provided by the Keras team. Weights have been ported from the original SegFormer checkpoints. Each preset name encodes the MiT variant, the dataset it was trained on, and the input resolution. Code examples are available below.

| Preset name | Parameters | Description |
|------------------------|------------|-------------|
| mit_b0_ade20k_512 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
| mit_b1_ade20k_512 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
| mit_b2_ade20k_512 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
| mit_b3_ade20k_512 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
| mit_b4_ade20k_512 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. |
| mit_b5_ade20k_640 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks, trained on the ADE20K dataset with an input resolution of 640x640 pixels. |
| mit_b0_cityscapes_1024 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
| mit_b1_cityscapes_1024 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
| mit_b2_cityscapes_1024 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
| mit_b3_cityscapes_1024 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
| mit_b4_cityscapes_1024 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
| mit_b5_cityscapes_1024 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. |
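
Each block count in the table is the sum of the per-stage depths of that MiT variant in the SegFormer configuration. The sketch below reproduces those counts and reads the input resolution from a preset name. `total_blocks`, `input_resolution`, and the name-parsing rule are illustrative assumptions, not part of KerasHub.

```python
# Per-stage transformer-block counts (depths) of each MiT variant, following
# the SegFormer configurations.
MIT_STAGE_DEPTHS = {
    "mit_b0": (2, 2, 2, 2),
    "mit_b1": (2, 2, 2, 2),
    "mit_b2": (3, 4, 6, 3),
    "mit_b3": (3, 4, 18, 3),
    "mit_b4": (3, 8, 27, 3),
    "mit_b5": (3, 6, 40, 3),
}


def total_blocks(preset):
    """Hypothetical helper: total transformer blocks for a preset name."""
    variant = "_".join(preset.split("_")[:2])  # e.g. "mit_b4"
    return sum(MIT_STAGE_DEPTHS[variant])


def input_resolution(preset):
    """Hypothetical helper: input side length encoded at the end of the name."""
    return int(preset.rsplit("_", 1)[-1])


for name in ("mit_b0_ade20k_512", "mit_b4_cityscapes_1024", "mit_b5_ade20k_640"):
    print(name, total_blocks(name), input_resolution(name))
```

Note that `mit_b0` and `mit_b1` share the same depths. They differ in channel width, which accounts for the gap in parameter count.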
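
## Example Usage

Load the backbone from a preset, extract features, and train it with a minimal segmentation head attached:

```
import keras
import keras_hub
import numpy as np

# Dummy data at the preset's 1024x1024 input resolution.
images = np.ones(shape=(1, 1024, 1024, 3))
labels = np.zeros(shape=(1, 1024, 1024, 1))

backbone = keras_hub.models.MiTBackbone.from_preset("mit_b4_cityscapes_1024")

# Extract features: the backbone returns its last-stage feature map,
# downsampled 32x per side relative to the input.
features = backbone(images)

# The backbone has no segmentation head. For illustration, attach a minimal
# one: a 1x1 convolution to one channel, upsampled back to 1024x1024.
inputs = keras.Input(shape=(1024, 1024, 3))
x = keras.layers.Conv2D(1, kernel_size=1, activation="sigmoid")(backbone(inputs))
outputs = keras.layers.UpSampling2D(size=32, interpolation="bilinear")(x)
model = keras.Model(inputs, outputs)

# Evaluate model
model(images)

# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
```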

## Example Usage with Hugging Face URI

Presets can also be loaded directly from the Hugging Face Hub by passing an `hf://` URI to `from_preset`:

```
import keras_hub
import numpy as np

# Dummy image at the preset's 1024x1024 input resolution.
images = np.ones(shape=(1, 1024, 1024, 3))

backbone = keras_hub.models.MiTBackbone.from_preset("hf://keras/mit_b4_cityscapes_1024")

# Extract the last-stage feature map, downsampled 32x per side.
features = backbone(images)
```