# Commit metadata (author: Tanaanan; message: "Update app.py"; commit 71732b9, verified)
import html
import os
import warnings

import gradio
import gradio as gr
import numpy as np
import pandas as pd
import timm
import torch
from PIL import Image
from torchvision import transforms
# Suppress every warning category process-wide. This keeps console output
# quiet, but it also hides deprecation notices from torch/timm/gradio.
warnings.filterwarnings("ignore")

# Backbone: timm SwinV2-CR tiny (224x224 input) with ImageNet-1k pretrained weights.
model = timm.create_model('swinv2_cr_tiny_ns_224.sw_in1k', pretrained=True)

# Number of snake species the fine-tuned model predicts; must equal len(class_names).
output_shape = 60

# Extra head: dropout followed by a 1000 -> output_shape linear layer.
# NOTE(review): timm's SwinV2-CR models appear to expose their classifier as
# `model.head`, not `model.classifier`. If so, this assignment only registers
# an additional submodule (so the checkpoint keys below still load) and
# forward() never calls it, i.e. the network emits 1000 logits rather than
# `output_shape`. The checkpoint was saved with this exact structure, so
# verify against the training code before changing it.
model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=True),
    torch.nn.Linear(in_features=1000, out_features=output_shape, bias=True).to('cpu'),
)

# Load the fine-tuned weights onto the CPU (load_state_dict is strict by default).
# NOTE(review): the checkpoint is a local, trusted file; on torch >= 1.13
# consider passing weights_only=True to torch.load.
model.load_state_dict(torch.load('./swin_70_65.pth', map_location=torch.device('cpu')))

# Inference only: leave training mode so that dropout and stochastic depth
# are disabled and predictions are deterministic. torch.no_grad() alone does
# not switch these layers off.
model.eval()
# Image -> normalized input tensor pipeline applied before every model call.
preprocess = transforms.Compose(
    [
        # Force a fixed 224x224 size (aspect ratio is not preserved).
        transforms.Resize(size=(224, 224)),
        # PIL image / ndarray -> float tensor in channel-first layout.
        transforms.ToTensor(),
        # Per-channel standardization with the usual ImageNet statistics.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Species label for each model output index (index i <-> class_names[i]).
# The names are in alphabetical order, presumably matching the class ordering
# used at training time (e.g. sorted dataset folders) -- verify before reordering.
# Its length (60) must agree with `output_shape`.
class_names = [
    "Ahaetulla_prasina",
    "Bitis_arietans", "Boa_constrictor", "Boa_imperator", "Bothriechis_schlegelii",
    "Bothrops_asper", "Bothrops_atrox", "Bungarus_fasciatus",
    "Chrysopelea_ornata", "Coelognathus_radiatus", "Corallus_hortulana",
    "Coronella_austriaca", "Crotaphopeltis_hotamboeia",
    "Dendrelaphis_pictus", "Dolichophis_caspius", "Drymarchon_melanurus",
    "Drymobius_margaritiferus",
    "Elaphe_dione", "Epicrates_cenchria", "Erythrolamprus_poecilogyrus", "Eunectes_murinus",
    "Fowlea_flavipunctata",
    "Gonyosoma_oxycephalum",
    "Helicops_angulatus", "Hierophis_viridiflavus",
    "Imantodes_cenchoa", "Indotyphlops_braminus",
    "Laticauda_colubrina", "Leptodeira_annulata", "Leptodeira_ornata",
    "Leptodeira_septentrionalis", "Leptophis_ahaetulla", "Leptophis_mexicanus",
    "Lycodon_capucinus",
    "Malayopython_reticulatus", "Malpolon_insignitus", "Mastigodryas_boddaerti",
    "Natrix_helvetica", "Natrix_maura", "Natrix_natrix", "Natrix_tessellata", "Ninia_sebae",
    "Ophiophagus_hannah", "Oxybelis_aeneus", "Oxybelis_fulgidus", "Oxyrhopus_petolarius",
    "Phrynonax_poecilonotus", "Psammodynastes_pulverulentus", "Ptyas_korros",
    "Ptyas_mucosa", "Python_bivittatus",
    "Rhabdophis_tigrinus",
    "Sibon_nebulatus", "Spilotes_pullatus",
    "Tantilla_melanocephala", "Trimeresurus_albolabris",
    "Vipera_ammodytes", "Vipera_aspis", "Vipera_berus",
    "Zamenis_longissimus",
]
# HTML for one prediction row: label and percentage on one line, then a bar
# whose CSS width equals the percentage. Filled via str.format().
_PREDICTION_ROW_HTML = """
<div style='margin-bottom: 10px; position: relative;'>
<div style='display: flex; align-items: center;'>
<strong style='flex: 1;'>{label}</strong>
<span style='flex-shrink: 0; color: black; margin-left: 10px;'>
{confidence:.2f}%
</span>
</div>
<div style='background-color: #f3f3f3; border-radius: 5px; width: 100%; height: 20px; margin-top: 5px;'>
<div style='background-color: #4CAF50; height: 100%; width: {confidence:.2f}%; border-radius: 5px;'></div>
</div>
</div>
"""


def _render_predictions_html(predictions):
    """Render ``(label, percentage)`` pairs as a stack of labelled progress bars."""
    rows = "".join(
        _PREDICTION_ROW_HTML.format(label=html.escape(str(label)), confidence=confidence)
        for label, confidence in predictions
    )
    return "<div style='font-family: Arial, sans-serif;'>" + rows + "</div>"


def predict(image):
    """Classify a snake image and return the top-5 species as an HTML snippet.

    Args:
        image: The input image, normally a ``PIL.Image.Image`` supplied by the
            ``gr.Image(type='pil')`` component, or ``None`` when the input is empty.

    Returns:
        An HTML string with one progress bar per predicted species (highest
        confidence first, percentages from a softmax over the model output).
        If no image was given, or if preprocessing or inference raised, a
        short message string is returned instead.
    """
    if image is None:
        return "No image provided."
    try:
        # Normalize() has 3 channel statistics, so RGBA / grayscale / palette
        # images must become RGB first. The Gradio component may already do
        # this, but direct callers might not.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")
        input_tensor = preprocess(image)
    except Exception as e:
        # Output goes to a gr.HTML component: escape so "<...>" text stays visible.
        return f"Error in preprocessing: {html.escape(str(e))}"
    input_batch = input_tensor.unsqueeze(0).to('cpu')  # add a batch dimension of 1
    try:
        with torch.no_grad():
            output = model(input_batch)
    except Exception as e:
        return f"Error in model inference: {html.escape(str(e))}"
    probabilities = torch.nn.functional.softmax(output, dim=1)
    percentages = probabilities[0].cpu().numpy() * 100
    top_n = 5
    # NOTE(review): zip() silently truncates if the model emits more logits
    # than there are class names.
    ranked = sorted(zip(class_names, percentages), key=lambda pair: pair[1], reverse=True)
    return _render_predictions_html(ranked[:top_n])
# Intro text rendered under the app title.
_DESCRIPTION_HTML = """
<div style='font-family: Arial, sans-serif; line-height: 1.6;'>
<p>Datasets and classes are referenced from: <a href="https://www.imageclef.org/node/319" target="_blank">ImageCLEF</a> and for more details: <a href="https://github.com/Tanaanan/SnakeCLEF2024_MLCS" target="_blank">GitHub repository</a></p>
<p style='font-size: smaller;'>This project is part of the course 'Machine Learning Systems (01418262).'</p>
<p style='font-size: smaller;'>Developed by Tanaanan, Narakorn, Department of Computer Science, Kasetsart University.</p>
</div>
"""

# Species that have a bundled sample image at ./sample_imgs/<species>.png.
_EXAMPLE_SPECIES = (
    'Bitis_arietans',
    'Boa_imperator',
    'Coelognathus_radiatus',
    'Leptodeira_septentrionalis',
    'Natrix_tessellata',
    'Psammodynastes_pulverulentus',
    'Ptyas_mucosa',
    'Vipera_ammodytes',
    'Zamenis_longissimus',
)

# Single-image-in, HTML-out UI around predict().
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.HTML(),
    title="Snake Species Classification (SnakeCLEF2024)",
    description=_DESCRIPTION_HTML,
    examples=[f'./sample_imgs/{species}.png' for species in _EXAMPLE_SPECIES],
    live=True,
)

if __name__ == "__main__":
    interface.launch()