|
import html
import os
import warnings

import gradio
import gradio as gr
import numpy as np
import pandas as pd
import timm
import torch
from PIL import Image
from torchvision import transforms
|
# Silence every warning process-wide to keep the app logs quiet.
# NOTE(review): this also hides genuinely useful warnings; consider narrowing
# the filter by category or module.
warnings.filterwarnings("ignore")

# Backbone: timm Swin-V2 (CR variant) with ImageNet-1k pretrained weights.
# Those weights are fully replaced by the fine-tuned checkpoint loaded below
# (load_state_dict is strict by default, so every key must be present).
model = timm.create_model('swinv2_cr_tiny_ns_224.sw_in1k', pretrained=True)

# Number of snake species predicted; must equal len(class_names).
output_shape = 60

# Extra head mapping a 1000-d output to `output_shape` classes. The attribute
# name `classifier` must stay as-is so the checkpoint's keys match.
# NOTE(review): timm's SwinTransformerV2Cr.forward() appears to route through
# `self.head`, not `self.classifier`; depending on the timm version this module
# may never be called during forward(). Verify against the training code before
# changing the architecture.
model.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=True),
    torch.nn.Linear(in_features=1000, out_features=output_shape, bias=True).to('cpu'),
)

# Load the fine-tuned weights onto the CPU. torch.load unpickles the file, so
# it must only ever point at this trusted local checkpoint.
model.load_state_dict(torch.load('./swin_70_65.pth', map_location=torch.device('cpu')))

# Put the network in inference mode. Without this, Dropout (including the
# p=0.2 layer above) stays active and the same image yields different scores
# on every call.
model.eval()
|
|
|
# Standard ImageNet per-channel statistics (R, G, B).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Input pipeline for the classifier: force a 224x224 size (aspect ratio is not
# preserved), convert the PIL image to a float tensor in [0, 1], then
# normalise each channel with the ImageNet statistics above.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
|
|
|
# Output labels of the fine-tuned head, in logit order (index i <-> logit i).
# The list is alphabetical and holds exactly 60 entries, matching
# `output_shape` used when building the classifier.
class_names = [
    'Ahaetulla_prasina',
    'Bitis_arietans',
    'Boa_constrictor',
    'Boa_imperator',
    'Bothriechis_schlegelii',
    'Bothrops_asper',
    'Bothrops_atrox',
    'Bungarus_fasciatus',
    'Chrysopelea_ornata',
    'Coelognathus_radiatus',
    'Corallus_hortulana',
    'Coronella_austriaca',
    'Crotaphopeltis_hotamboeia',
    'Dendrelaphis_pictus',
    'Dolichophis_caspius',
    'Drymarchon_melanurus',
    'Drymobius_margaritiferus',
    'Elaphe_dione',
    'Epicrates_cenchria',
    'Erythrolamprus_poecilogyrus',
    'Eunectes_murinus',
    'Fowlea_flavipunctata',
    'Gonyosoma_oxycephalum',
    'Helicops_angulatus',
    'Hierophis_viridiflavus',
    'Imantodes_cenchoa',
    'Indotyphlops_braminus',
    'Laticauda_colubrina',
    'Leptodeira_annulata',
    'Leptodeira_ornata',
    'Leptodeira_septentrionalis',
    'Leptophis_ahaetulla',
    'Leptophis_mexicanus',
    'Lycodon_capucinus',
    'Malayopython_reticulatus',
    'Malpolon_insignitus',
    'Mastigodryas_boddaerti',
    'Natrix_helvetica',
    'Natrix_maura',
    'Natrix_natrix',
    'Natrix_tessellata',
    'Ninia_sebae',
    'Ophiophagus_hannah',
    'Oxybelis_aeneus',
    'Oxybelis_fulgidus',
    'Oxyrhopus_petolarius',
    'Phrynonax_poecilonotus',
    'Psammodynastes_pulverulentus',
    'Ptyas_korros',
    'Ptyas_mucosa',
    'Python_bivittatus',
    'Rhabdophis_tigrinus',
    'Sibon_nebulatus',
    'Spilotes_pullatus',
    'Tantilla_melanocephala',
    'Trimeresurus_albolabris',
    'Vipera_ammodytes',
    'Vipera_aspis',
    'Vipera_berus',
    'Zamenis_longissimus',
]
|
|
|
|
|
def predict(image):
    """Classify a snake photo and render the top-5 species as HTML confidence bars.

    Args:
        image: The PIL image from the ``gr.Image(type='pil')`` input, or
            ``None`` when the input is empty.

    Returns:
        An HTML string with one labelled percentage bar per top prediction
        (highest first), or a short message if no image was given or
        preprocessing / inference raised.
    """
    if image is None:
        return "No image provided."

    try:
        # ToTensor + a 3-channel Normalize need an RGB image; PNG uploads can
        # be RGBA or palette-based, and grayscale images have one channel.
        if image.mode != "RGB":
            image = image.convert("RGB")
        input_tensor = preprocess(image)
    except Exception as e:
        # The result is rendered as HTML, so escape the exception text.
        return f"Error in preprocessing: {html.escape(str(e), quote=False)}"

    input_batch = input_tensor.unsqueeze(0).to('cpu')

    try:
        with torch.no_grad():
            output = model(input_batch)
    except Exception as e:
        return f"Error in model inference: {html.escape(str(e), quote=False)}"

    probabilities = torch.nn.functional.softmax(output, dim=1)
    percentages = probabilities[0].cpu().numpy() * 100

    # NOTE(review): zip() silently stops at len(class_names); if the model
    # emits more logits than there are class names, the extras are ignored
    # (but still count in the softmax denominator above).
    top_n = 5
    ranked = sorted(zip(class_names, percentages), key=lambda pair: pair[1], reverse=True)
    top_predictions = ranked[:top_n]

    rows = "".join(
        _render_prediction_row(class_label, confidence)
        for class_label, confidence in top_predictions
    )
    return f"<div style='font-family: Arial, sans-serif;'>{rows}</div>"


def _render_prediction_row(class_label, confidence):
    """Return the HTML for one label plus a bar whose width is ``confidence`` percent."""
    label = html.escape(str(class_label))
    return f"""
    <div style='margin-bottom: 10px; position: relative;'>
        <div style='display: flex; align-items: center;'>
            <strong style='flex: 1;'>{label}</strong>
            <span style='flex-shrink: 0; color: black; margin-left: 10px;'>
                {confidence:.2f}%
            </span>
        </div>
        <div style='background-color: #f3f3f3; border-radius: 5px; width: 100%; height: 20px; margin-top: 5px;'>
            <div style='background-color: #4CAF50; height: 100%; width: {confidence:.2f}%; border-radius: 5px;'></div>
        </div>
    </div>
    """
|
|
|
|
|
# Species shown as clickable examples; each maps to ./sample_imgs/<name>.png.
_EXAMPLE_SPECIES = (
    'Bitis_arietans',
    'Boa_imperator',
    'Coelognathus_radiatus',
    'Leptodeira_septentrionalis',
    'Natrix_tessellata',
    'Psammodynastes_pulverulentus',
    'Ptyas_mucosa',
    'Vipera_ammodytes',
    'Zamenis_longissimus',
)

# HTML blurb shown under the title: data source, repository and credits.
_DESCRIPTION_HTML = """

<div style='font-family: Arial, sans-serif; line-height: 1.6;'>

<p>Datasets and classes are referenced from: <a href="https://www.imageclef.org/node/319" target="_blank">ImageCLEF</a> and for more details: <a href="https://github.com/Tanaanan/SnakeCLEF2024_MLCS" target="_blank">GitHub repository</a></p>

<p style='font-size: smaller;'>This project is part of the course 'Machine Learning Systems (01418262).'</p>

<p style='font-size: smaller;'>Developed by Tanaanan, Narakorn, Department of Computer Science, Kasetsart University.</p>

</div>

"""

# Gradio UI: one PIL image in, HTML confidence bars out. With live=True,
# predictions re-run whenever the input image changes.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=gr.HTML(),
    title="Snake Species Classification (SnakeCLEF2024)",
    description=_DESCRIPTION_HTML,
    examples=[f'./sample_imgs/{species}.png' for species in _EXAMPLE_SPECIES],
    live=True,
)
|
|
|
def _main():
    """Start the Gradio server for the snake-classification interface."""
    interface.launch()


# Only serve the app when run as a script, not when imported.
if __name__ == "__main__":
    _main()
|
|