import gradio as gr

from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
from PIL import Image
import numpy as np
import torch

# Load the image processor and the pretrained DeepLabV3 + MobileViT segmentation model.
feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-small")
model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
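
# One RGB color per predicted class index (the standard PASCAL VOC color map).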
COLORS = np.array([
    [0, 0, 0],
    [128, 0, 0],
    [0, 128, 0],
    [128, 128, 0],
    [0, 0, 128],
    [128, 0, 128],
    [0, 128, 128],
    [128, 128, 128],
    [64, 0, 0],
    [192, 0, 0],
    [64, 128, 0],
    [192, 128, 0],
    [64, 0, 128],
    [192, 0, 128],
    [64, 128, 128],
    [192, 128, 128],
    [0, 64, 0],
    [128, 64, 0],
    [0, 192, 0],
    [128, 192, 0],
    [0, 64, 128],
    [128, 64, 128],
], dtype=np.uint8)
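
# Run semantic segmentation on a PIL image and return the mask as a color image.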
def segment_image(image):
    # Ensure a 3-channel RGB input (uploads may be RGBA or grayscale).
    image = image.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # shape: (1, num_classes, height, width)

    # Most likely class per pixel, then map each class index to its color.
    predicted_mask = logits.argmax(dim=1).squeeze(0).numpy()
    colored_mask = COLORS[predicted_mask]
    colored_mask_image = Image.fromarray(colored_mask)

    # The logits are lower resolution than the input, so resize the mask back up.
    colored_mask_resized = colored_mask_image.resize(image.size, Image.NEAREST)

    return colored_mask_resized
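
# A simple Gradio UI: upload an image, get back the colorized segmentation mask.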
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="Image Segmentation with MobileViT",
    description="Upload an image to see the semantic segmentation result. The segmentation mask uses different colors to indicate different classes.",
)
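
# share=True also exposes the demo through a temporary public Gradio link.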
interface.launch(share=True)