File size: 4,209 Bytes
8037bf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13b7ef2
8037bf7
 
 
9594fde
8037bf7
9594fde
8037bf7
 
 
 
 
 
 
2120a9a
 
8037bf7
9594fde
8037bf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2120a9a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import gradio as gr
from PIL import Image
import torch
import matplotlib.pyplot as plt
import cv2
import torch
import numpy as np

# Load the CLIPSeg checkpoint once, at import time. The processor prepares both
# the text prompt and the image; the model produces the segmentation logits.
# Both come from the same Hugging Face repository id.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)


def process_image(image, prompt, threhsold, alpha_value, draw_rectangles):
    """Segment the part of ``image`` described by ``prompt`` using CLIPSeg.

    Args:
        image: Input ``PIL.Image.Image`` (the UI supplies it through
            ``gr.Image(type="pil")``).
        prompt: Free-text description of what to segment.
        threhsold: Cut-off applied to the min-max normalised mask, expected in
            [0, 1]. The misspelt name is kept so existing keyword callers
            keep working.
        alpha_value: Opacity of the heat-map drawn over the image in the plot.
        draw_rectangles: When truthy, draw a yellow bounding box around every
            connected region of the thresholded mask.

    Returns:
        ``(fig, mask, output_image)``:
        a matplotlib figure showing the heat-map overlay;
        the normalised float mask with pixels not above the threshold zeroed;
        an RGBA copy of ``image`` that is transparent outside the mask.

    Raises:
        ValueError: If no image was provided.
    """
    if image is None:
        raise ValueError("Please upload an image before processing.")

    mask = _normalize_mask(_predict_mask(image, prompt))

    # A single binary mask drives the heat-map, the rectangles and the cut-out,
    # so pixels exactly at the threshold are treated the same in every output.
    bmask = mask > threhsold
    mask[~bmask] = 0

    fig = _plot_overlay(image, mask, bmask, alpha_value, draw_rectangles)
    output_image = _cut_out(image, bmask)
    return fig, mask, output_image


def _predict_mask(image, prompt):
    """Run CLIPSeg and return a uint8 probability map resized to ``image.size``."""
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # Sigmoid maps logits to per-pixel probabilities. squeeze() drops any
    # singleton batch axis so PIL always receives a 2-D (grayscale) array.
    probs = torch.sigmoid(logits).cpu().numpy().squeeze()
    low_res = Image.fromarray(np.uint8(probs * 255))
    return np.array(low_res.resize(image.size))


def _normalize_mask(mask):
    """Min-max rescale ``mask`` to floats in [0, 1]; a constant mask becomes all zeros."""
    mask = mask.astype(np.float64)
    lo, hi = mask.min(), mask.max()
    if hi == lo:
        # Avoid 0/0 -> NaN, which would otherwise leak into every output.
        return np.zeros_like(mask)
    return (mask - lo) / (hi - lo)


def _plot_overlay(image, mask, bmask, alpha_value, draw_rectangles):
    """Draw ``image`` with a ``mask`` heat-map, optionally boxing each region of ``bmask``."""
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.imshow(mask, alpha=alpha_value, cmap="jet")

    if draw_rectangles:
        # RETR_EXTERNAL returns only outer contours, so holes inside a region
        # do not get their own box. [-2] selects the contour list on both the
        # OpenCV 3 (3-tuple) and OpenCV 4 (2-tuple) return signatures.
        contours = cv2.findContours(
            bmask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            ax.add_patch(
                plt.Rectangle(
                    (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
                )
            )

    ax.axis("off")
    # Operate on this figure explicitly rather than pyplot's "current" one.
    fig.tight_layout()
    # Unregister the figure from pyplot so repeated requests do not accumulate
    # open figures. The Figure object itself stays usable for rendering.
    plt.close(fig)
    return fig


def _cut_out(image, bmask):
    """Return an RGBA copy of ``image`` that is fully transparent where ``bmask`` is False."""
    alpha = Image.fromarray(bmask.astype(np.uint8) * 255)
    output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
    output_image.paste(image, mask=alpha)
    return output_image


# Static text shown in the Gradio UI: page/tab title, usage instructions and
# links to the paper and the Hugging Face documentation (HTML).
title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. To use it, upload an image and describe in the text box what you want to identify in the image, or use one of the examples below, then click the button to run the model. Results will show up in a few seconds."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10003'>CLIPSeg: Image Segmentation Using Text and Image Prompts</a> | <a href='https://huggingface.co/docs/transformers/main/en/model_doc/clipseg'>HuggingFace docs</a></p>"


# Build the Gradio UI: inputs on the left, the three results on the right.
# ``title`` sets the browser-tab title.
with gr.Blocks(title=title) as demo:
    gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts")
    gr.Markdown(article)
    gr.Markdown(description)
    gr.Markdown(
        "*Example images are taken from the [ImageNet-A](https://paperswithcode.com/dataset/imagenet-a) dataset*"
    )

    with gr.Row():
        # Inputs, created in the same order as process_image's parameters.
        with gr.Column():
            input_image = gr.Image(type="pil")
            input_prompt = gr.Textbox(label="Please describe what you want to identify")
            input_slider_T = gr.Slider(
                minimum=0, maximum=1, value=0.4, label="Threshold"
            )
            input_slider_A = gr.Slider(minimum=0, maximum=1, value=0.5, label="Alpha")
            draw_rectangles = gr.Checkbox(label="Draw rectangles")
            # A button's visible text is its ``value``, not ``label``.
            btn_process = gr.Button("Process")

        # One output component per value returned by process_image.
        with gr.Column():
            output_plot = gr.Plot(label="Segmentation Result")
            output_mask = gr.Image(label="Mask")
            output_image = gr.Image(label="Output Image")

    # process_image returns (fig, mask, output_image); map each value to its
    # own component, in that order.
    btn_process.click(
        process_image,
        inputs=[
            input_image,
            input_prompt,
            input_slider_T,
            input_slider_A,
            draw_rectangles,
        ],
        outputs=[output_plot, output_mask, output_image],
        api_name="masking",
    )

    # Clicking an example only fills in the input components.
    gr.Examples(
        [
            ["0.003473_cliff _ cliff_0.51112.jpg", "dog", 0.5, 0.5, True],
            ["0.001861_submarine _ submarine_0.9862991.jpg", "beacon", 0.55, 0.4, True],
            ["0.004658_spatula _ spatula_0.35416836.jpg", "banana", 0.4, 0.5, True],
        ],
        inputs=[
            input_image,
            input_prompt,
            input_slider_T,
            input_slider_A,
            draw_rectangles,
        ],
    )

# Start the web server only when this file is executed as a script, so that
# importing the module does not start a blocking server as a side effect.
if __name__ == "__main__":
    demo.launch()