File size: 4,573 Bytes
5e03e65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9271906
 
 
30ef691
 
9271906
5e03e65
 
 
 
 
 
30ef691
5e03e65
 
30ef691
 
 
 
 
 
 
 
 
 
 
 
 
 
5e03e65
 
 
 
 
 
 
 
 
 
783a0a3
 
5e03e65
 
 
783a0a3
5e03e65
783a0a3
 
30ef691
783a0a3
 
 
 
5e03e65
 
783a0a3
 
 
 
 
 
 
5e03e65
 
 
783a0a3
 
 
 
30ef691
 
 
 
 
 
 
 
5e03e65
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131

import torch

from transformers import AutoImageProcessor, AutoModelForObjectDetection
#from transformers import pipeline

from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import io
from random import choice


def _load_yolos(checkpoint):
    """Return the (image processor, detection model) pair for one checkpoint.

    The processor is loaded before the model.
    """
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    detector = AutoModelForObjectDetection.from_pretrained(checkpoint)
    return processor, detector


# Both checkpoints are loaded once, at import time, so each request from the
# UI only has to run inference.
image_processor_tiny, model_tiny = _load_yolos("hustvl/yolos-tiny")
image_processor_small, model_small = _load_yolos("hustvl/yolos-small")


import gradio as gr


# Palette for bounding-box outlines; get_figure picks one entry at random
# for every detected box.
COLORS = [
    "#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
    "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
    "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f",
]

# Matplotlib font properties passed as ``fontdict`` for every box caption.
fdic = dict(
    family="Impact",
    style="italic",
    size=15,
    color="yellow",
    weight="bold",
)


def get_figure(in_pil_img, in_results):
    """Draw detection boxes and captions over an image.

    Args:
        in_pil_img: The image the detections were computed on.
        in_results: One entry of ``post_process_object_detection`` output:
            a mapping with parallel ``"scores"``, ``"labels"`` and ``"boxes"``
            tensors. Each box is read as (x_min, y_min, x_max, y_max).

    Returns:
        The matplotlib ``Figure`` holding the annotated image. The caller
        owns it and should close it (``plt.close``) once it is rendered.
    """
    # Use explicit Figure/Axes handles rather than pyplot's implicit
    # "current figure", which is global state shared by every request.
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(in_pil_img)

    # NOTE(review): labels are always resolved through the tiny model's
    # config, even for yolos-small results; presumably both checkpoints
    # share the same label map -- confirm.
    id2label = model_tiny.config.id2label

    for score, label, box in zip(in_results["scores"], in_results["labels"], in_results["boxes"]):
        x_min, y_min, x_max, y_max = box.tolist()
        ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                   fill=False, color=choice(COLORS), linewidth=2))
        ax.text(x_min, y_min, f"{id2label[label.item()]}: {round(score.item()*100, 2)}%",
                fontdict=fdic, alpha=0.8)

    ax.axis("off")

    return fig


def infer(in_pil_img, in_model="yolos-tiny", in_threshold=0.9):
    """Run YOLOS object detection on an image and return an annotated copy.

    Args:
        in_pil_img: Input PIL image from the Gradio ``Image`` component, or
            ``None`` when "Infer" is clicked before an image is chosen.
        in_model: ``"yolos-small"`` selects the small checkpoint; any other
            value uses the tiny one.
        in_threshold: Minimum score for a detection to be kept.

    Returns:
        A PIL image of the rendered detections, or ``None`` when no input
        image was given (this clears the output component).
    """
    if in_pil_img is None:
        # Nothing to run on; previously this crashed on ``None.size``.
        return None

    if in_model == "yolos-small":
        image_processor, model = image_processor_small, model_small
    else:
        image_processor, model = image_processor_tiny, model_tiny

    # PIL reports (width, height); reverse it to (height, width).
    target_sizes = torch.tensor([in_pil_img.size[::-1]])

    inputs = image_processor(images=in_pil_img, return_tensors="pt")
    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # convert outputs (bounding boxes and class logits) to COCO API
    results = image_processor.post_process_object_detection(
        outputs, threshold=in_threshold, target_sizes=target_sizes)[0]

    figure = get_figure(in_pil_img, results)
    buf = io.BytesIO()
    try:
        figure.savefig(buf, bbox_inches='tight')
    finally:
        # Each call creates a new figure; release it so a long-running
        # server does not accumulate open figures.
        plt.close(figure)
    buf.seek(0)
    output_pil_img = Image.open(buf)

    return output_pil_img


# Gradio UI: model picker, input/output images with bundled examples,
# a threshold slider, and an "Infer" button wired to infer().
with gr.Blocks(
    title="YOLOS Object Detection - ClassCat",
    css=".gradio-container {background:lightyellow;}",
) as demo:
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">YOLOS Object Detection</div>""")

    gr.HTML("""<h4 style="color:navy;">1. Select a model.</h4>""")

    model = gr.Radio(["yolos-tiny", "yolos-small"], value="yolos-tiny", label="Model name")

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">2-a. Select an example by clicking a thumbnail below.</h4>""")
    gr.HTML("""<h4 style="color:navy;">2-b. Or upload an image by clicking on the canvas.</h4>""")

    with gr.Row():
        input_image = gr.Image(label="Input image", type="pil")
        output_image = gr.Image(label="Output image with predicted instances", type="pil")

    gr.Examples(['samples/cats.jpg', 'samples/detectron2.png', 'samples/cat.jpg', 'samples/hotdog.jpg'], inputs=input_image)

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">3. Set threshold value (default to 0.9)</h4>""")

    threshold = gr.Slider(0, 1.0, value=0.9, label='threshold')

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">4. Then, click "Infer" button to predict object instances. It will take about 10 seconds (on cpu)</h4>""")

    send_btn = gr.Button("Infer")
    # Inputs map positionally onto infer(in_pil_img, in_model, in_threshold).
    send_btn.click(fn=infer, inputs=[input_image, model, threshold], outputs=[output_image])

    gr.HTML("""<br/>""")
    gr.HTML("""<h4 style="color:navy;">Reference</h4>""")
    # Each gr.HTML call creates a separate component, so the list markup must
    # live in a single call for <ul> to actually wrap its <li>.
    gr.HTML("""<ul><li><a href="https://huggingface.co/docs/transformers/model_doc/yolos" target="_blank">Hugging Face Transformers - YOLOS</a></li></ul>""")


#demo.queue()
demo.launch(debug=True)




### EOF ###