File size: 1,677 Bytes
95fcb89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
from transformers import MobileViTFeatureExtractor, MobileViTForSemanticSegmentation
from PIL import Image
import numpy as np

# Load the preprocessor and the segmentation model from the same Hub
# checkpoint so their configurations stay paired.
# NOTE(review): newer transformers releases deprecate MobileViTFeatureExtractor
# in favour of MobileViTImageProcessor — consider migrating.
_CHECKPOINT = "apple/deeplabv3-mobilevit-small"
feature_extractor = MobileViTFeatureExtractor.from_pretrained(_CHECKPOINT)
model = MobileViTForSemanticSegmentation.from_pretrained(_CHECKPOINT)

def _pascal_voc_palette(num_entries):
    """Build the PASCAL VOC colormap as a ``(num_entries, 3)`` uint8 array.

    The bits of each label index are dealt out three at a time to the red,
    green and blue channels, filling every channel from its most significant
    bit downwards. Index 0 therefore maps to black.
    """
    rows = []
    for label in range(num_entries):
        red = green = blue = 0
        remaining = label
        for shift in range(7, -1, -1):
            red |= (remaining & 1) << shift
            green |= ((remaining >> 1) & 1) << shift
            blue |= ((remaining >> 2) & 1) << shift
            remaining >>= 3
        rows.append((red, green, blue))
    # uint8 so that Image.fromarray() on an indexed (H, W, 3) slice yields RGB.
    return np.array(rows, dtype=np.uint8)


# Colour lookup table: row i is the RGB colour painted for class index i.
# NOTE(review): 22 rows are produced; if the checkpoint predicts 21 classes
# (indices 0-20), row 21 is never used — confirm against model.config.num_labels.
COLORS = _pascal_voc_palette(22)

def segment_image(image):
    """Run MobileViT semantic segmentation on *image* and return a colour mask.

    Args:
        image: The uploaded picture as a ``PIL.Image.Image``, or ``None`` when
            the Gradio form is submitted without an image.

    Returns:
        A ``PIL.Image.Image`` with the same size as *image*, in which every
        pixel is painted with the ``COLORS`` row of its predicted class index;
        ``None`` when no image was supplied.
    """
    # An empty submission arrives as None; return None (an empty output)
    # instead of letting the preprocessor raise.
    if image is None:
        return None

    import torch  # local import: torch is not among the file's top-level imports

    # Normalise to 3-channel RGB so RGBA, grayscale or palette uploads reach
    # the preprocessor in the channel layout it expects.
    rgb_image = image.convert("RGB")
    inputs = feature_extractor(images=rgb_image, return_tensors="pt")

    # Pure inference: disable autograd so no gradient graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits

    # argmax over dim 1 picks the best class per pixel; squeeze(0) drops the
    # batch dimension of the single input image.
    predicted_mask = logits.argmax(1).squeeze(0).numpy()

    colored_mask_image = Image.fromarray(COLORS[predicted_mask])
    # Nearest-neighbour resampling keeps every pixel an exact palette colour.
    # NOTE(review): if the preprocessor resizes/center-crops its input, stretching
    # the mask to image.size may misalign it on non-square images — verify.
    return colored_mask_image.resize(image.size, Image.NEAREST)

# Gradio UI: one PIL image in, the colour-coded segmentation mask out.
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="Image Segmentation with MobileViT",
    description="Upload an image to see the semantic segmentation result. The segmentation mask uses different colors to indicate different classes.",
)

# Only start the server when run as a script, so importing this module does
# not launch the app as a side effect.
if __name__ == "__main__":
    # share=True asks Gradio for a temporary public URL in addition to the
    # local server; anyone holding that link can use the app.
    interface.launch(share=True)