File size: 3,372 Bytes
0362b22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6d3b8f
0362b22
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
import model
from config import app_config


def init():
    """Load the classifier and store it on the shared app configuration.

    Prints a start-up message, then assigns the object returned by
    ``model.load_model()`` to ``app_config.model``.
    """
    # ``model`` is the module imported at the top of this file, so it can never
    # be None; the former ``model != None`` guard was always true and is dropped.
    print("Initializing App...")
    app_config.model = model.load_model()


def clear():
    """Return the reset values for the components wired to the clear handlers.

    The 5-tuple is ordered to match ``outputs=[img, k, landmarks, proba, plot]``
    used by the clear button and the image's clear event: every component is
    emptied except the slider, which goes back to 2.
    """
    empty = None
    default_k = 2  # same as the slider's initial ``value``
    return (empty, default_k, empty, empty, empty)


def create_interface():
    """Build the Gradio Blocks UI for the landmark classifier and return it.

    Layout, top to bottom:
      * an intro row with the description and a collapsed list of classes,
      * an input column (image picker, ``k`` slider, Predict/Clear buttons)
        next to an output column (two JSON panels and a plot),
      * a collapsed accordion of example images.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    md = """
       # Famous Landmark Classifier using CNN
       **Choose an image containing any of the `50 possible classes` of world famous landmarks,** 
       **choose the number of prediction required (k) and hit `Predict`, model will try to identify**
       **the landmark in the image.** 
       **Please note that the model is trained on a small set of only 4,000 images hence it may not** 
       **be right all the time, but its fun to try out.**  
       Visit the [project's repo](https://github.com/sssingh/landmark-classification-tagging)
       """
    with gr.Blocks(
        title=app_config.title, theme=app_config.theme, css=app_config.css
    ) as app:
        # --- Intro: description plus the collapsed list of classes ---
        with gr.Row():
            gr.Markdown(md)
            with gr.Accordion(
                "Expand to see 50 classes:", open=False, elem_classes="accordion"
            ):
                gr.JSON(app_config.classes, elem_classes="json-box")
        # --- Main row: inputs on the left, prediction outputs on the right ---
        with gr.Row():
            with gr.Column():
                img = gr.Image(type="pil", elem_classes="image-picker")
                k = gr.Slider(
                    label="Number of predictions (k):",
                    minimum=2,
                    maximum=5,
                    value=2,
                    step=1,
                    elem_classes="slider",
                )
                with gr.Row():
                    submit_btn = gr.Button(
                        "Predict",
                        icon="assets/button-icon.png",
                        elem_classes="submit-button",
                    )
                    clear_btn = gr.ClearButton(elem_classes="clear-button")
            with gr.Column():
                landmarks = gr.JSON(
                    label="Predicted Landmarks:", elem_classes="json-box"
                )
                proba = gr.JSON(
                    label="Predicted Probabilities:", elem_classes="json-box"
                )
                plot = gr.Plot(container=True, elem_classes="plot")
        # --- Examples: each entry is [image path, k] ---
        with gr.Row():
            with gr.Accordion(
                "Expand for examples:", open=False, elem_classes="accordion"
            ):
                gr.Examples(
                    examples=[
                        ["assets/examples/gateway-of-india.jpg", 3],
                        ["assets/examples/grand-canyon.jpg", 2],
                        ["assets/examples/opera-house.jpg", 3],
                        ["assets/examples/stone-henge.jpg", 4],
                        ["assets/examples/temple-of-zeus.jpg", 5],
                    ],
                    inputs=[img, k],
                    # Caching needs a function to pre-compute the outputs;
                    # Gradio rejects ``cache_examples=True`` without ``fn``.
                    # The outputs mirror the Predict button wiring below, where
                    # ``model.predict`` feeds all three components.
                    # NOTE(review): presumably ``model.predict`` returns three
                    # values (landmarks, probabilities, plot) - confirm.
                    outputs=[landmarks, proba, plot],
                    fn=model.predict,
                    elem_id="examples",
                    cache_examples=True,
                )
        # --- Event wiring ---
        submit_btn.click(
            fn=model.predict, inputs=[img, k], outputs=[landmarks, proba, plot]
        )
        # Both the Clear button and the image's own clear event reset every
        # input and output component through ``clear``.
        clear_btn.click(fn=clear, inputs=[], outputs=[img, k, landmarks, proba, plot])
        img.clear(fn=clear, inputs=[], outputs=[img, k, landmarks, proba, plot])
    return app


if __name__ == "__main__":
    # Script entry point: init() stores the loaded model on app_config, then
    # the Blocks UI is built and launched.
    init()
    app = create_interface()
    app.launch()