"""Gradio demo: semantic segmentation with DeepLabV3-MobileViT (small).

Each pixel of the uploaded image is assigned a class id by the model, and the
class ids are rendered as colours from the ``COLORS`` palette.
"""

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation

MODEL_NAME = "apple/deeplabv3-mobilevit-small"

feature_extractor = MobileViTFeatureExtractor.from_pretrained(MODEL_NAME)
model = MobileViTForSemanticSegmentation.from_pretrained(MODEL_NAME)

# Colour palette indexed by class id: row i is the RGB colour for class i.
# It holds 22 entries; a 21-class model never uses the last one.
COLORS = np.array([
    [0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
    [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0],
    [64, 128, 0], [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128],
    [192, 128, 128], [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
    [0, 64, 128], [128, 64, 128],
], dtype=np.uint8)  # uint8 so the indexed result can go straight to Image.fromarray

# Fail at startup, not on every request, if the checkpoint predicts more
# classes than the palette can colour (COLORS[mask] would raise IndexError).
if model.config.num_labels > len(COLORS):
    raise ValueError(
        f"Palette has {len(COLORS)} colours but the model predicts "
        f"{model.config.num_labels} classes"
    )


def segment_image(image):
    """Return a colour-coded segmentation mask the same size as *image*.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submitted without uploading anything.

    Returns:
        A PIL RGB image whose pixel colours encode the predicted class ids,
        or ``None`` when no input image was given.
    """
    if image is None:
        return None
    image = image.convert("RGB")

    # Disable center-cropping so the model segments the whole image. With the
    # default crop, the mask covers only the central square and is misaligned
    # once stretched back over a non-square input.
    inputs = feature_extractor(images=image, do_center_crop=False, return_tensors="pt")

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        logits = model(**inputs).logits  # (batch, num_labels, h, w), low resolution

    # Upsample the logits (not the argmax'd labels) to the preprocessed input
    # size so class boundaries are interpolated smoothly. Stopping at the input
    # size, not the original photo size, keeps the float tensor small.
    input_size = inputs["pixel_values"].shape[-2:]
    logits = torch.nn.functional.interpolate(
        logits, size=input_size, mode="bilinear", align_corners=False
    )
    predicted_mask = logits.argmax(dim=1)[0].cpu().numpy()

    mask_image = Image.fromarray(COLORS[predicted_mask])
    # Nearest-neighbour keeps the colours exact (no blended, meaningless hues).
    return mask_image.resize(image.size, Image.NEAREST)


interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="Image Segmentation with MobileViT",
    description="Upload an image to see the semantic segmentation result. The segmentation mask uses different colors to indicate different classes.",
)

interface.launch(share=True)