|
import gradio as gr |
|
import model |
|
from config import app_config |
|
|
|
|
|
def init() -> None:
    """Load the classifier and store it on the shared ``app_config`` object.

    Prints a start-up message, then assigns the result of
    ``model.load_model()`` to ``app_config.model``.
    """
    # Identity comparison is the correct way to test against the None
    # singleton (``!= None`` dispatches to ``__ne__`` instead).  ``model`` is
    # the module imported at the top of the file, so this guard is expected
    # to pass on every normal start-up.
    if model is not None:
        print("Initializing App...")
        app_config.model = model.load_model()
|
|
|
|
|
def clear():
    """Return the reset values pushed to the UI when the form is cleared.

    The tuple has five positions; the second one is the number 2 and every
    other position is ``None``.
    """
    empty = None
    reset_k = 2
    return (empty, reset_k, empty, empty, empty)
|
|
|
|
|
def create_interface():
    """Build the Gradio ``Blocks`` UI for the landmark classifier.

    The page is laid out top to bottom as:

    * an introduction (Markdown) plus a collapsed accordion listing
      ``app_config.classes``;
    * a two-column row: image picker, ``k`` slider and the Predict / Clear
      buttons on the left; predicted landmarks, probabilities and a plot on
      the right;
    * a collapsed accordion with clickable example images.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` application.
    """
    md = """
    # Famous Landmark Classifier using CNN
    **Choose an image containing any of the `50 possible classes` of world famous landmarks,**
    **choose the number of prediction required (k) and hit `Predict`, model will try to identify**
    **the landmark in the image.**
    **Please note that the model is trained on a small set of only 4,000 images hence it may not**
    **be right all the time, but its fun to try out.**
    Visit the [project's repo](https://github.com/sssingh/landmark-classification-tagging)
    """

    def _predict_for_examples(image, top_k):
        """Run ``model.predict`` and keep only the first two returned values.

        The Predict button maps ``model.predict`` onto three components
        (landmarks, probabilities, plot); the examples widget only lists the
        first two as outputs, so the third value is dropped here.
        """
        return model.predict(image, top_k)[:2]

    with gr.Blocks(
        title=app_config.title, theme=app_config.theme, css=app_config.css
    ) as app:
        # --- Header: description and the list of supported classes ---------
        with gr.Row():
            gr.Markdown(md)
            with gr.Accordion(
                "Expand to see 50 classes:", open=False, elem_classes="accordion"
            ):
                gr.JSON(app_config.classes, elem_classes="json-box")

        # --- Main row: inputs on the left, predictions on the right --------
        with gr.Row():
            with gr.Column():
                img = gr.Image(type="pil", elem_classes="image-picker")
                k = gr.Slider(
                    label="Number of predictions (k):",
                    minimum=2,
                    maximum=5,
                    value=2,
                    step=1,
                    elem_classes="slider",
                )
                with gr.Row():
                    submit_btn = gr.Button(
                        "Predict",
                        icon="assets/button-icon.png",
                        elem_classes="submit-button",
                    )
                    clear_btn = gr.ClearButton(elem_classes="clear-button")
            with gr.Column():
                landmarks = gr.JSON(
                    label="Predicted Landmarks:", elem_classes="json-box"
                )
                proba = gr.JSON(
                    label="Predicted Probabilities:", elem_classes="json-box"
                )
                plot = gr.Plot(container=True, elem_classes="plot")

        # --- Examples ------------------------------------------------------
        with gr.Row():
            with gr.Accordion(
                "Expand for examples:", open=False, elem_classes="accordion"
            ):
                gr.Examples(
                    examples=[
                        ["assets/examples/gateway-of-india.jpg", 3],
                        ["assets/examples/grand-canyon.jpg", 2],
                        ["assets/examples/opera-house.jpg", 3],
                        ["assets/examples/stone-henge.jpg", 4],
                        ["assets/examples/temple-of-zeus.jpg", 5],
                    ],
                    inputs=[img, k],
                    outputs=[landmarks, proba],
                    # Caching examples requires the function that produces
                    # the outputs; without `fn` Gradio rejects
                    # `cache_examples=True` when the widget is constructed.
                    fn=_predict_for_examples,
                    elem_id="examples",
                    cache_examples=True,
                )

        # --- Event wiring --------------------------------------------------
        submit_btn.click(
            fn=model.predict, inputs=[img, k], outputs=[landmarks, proba, plot]
        )
        # Both the Clear button and clearing the image reset every widget
        # through the module-level `clear` callback.
        clear_btn.click(fn=clear, inputs=[], outputs=[img, k, landmarks, proba, plot])
        img.clear(fn=clear, inputs=[], outputs=[img, k, landmarks, proba, plot])
    return app
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: runs only when this file is executed directly.
    # The model is loaded before the UI is built and served.
    init()
    app = create_interface()
    app.launch()
|
|