"""Gradio app that screens an uploaded image for anomalies, then classifies it.

Pipeline: the image is preprocessed, passed through a ResNet-50 feature
extractor whose features feed an Isolation Forest anomaly detector, and, if
the image is not flagged as an anomaly, classified as 'abnormal' or 'normal'
by a second ResNet-50 based model.
"""
import io

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from gradio import File
from joblib import load
from PIL import Image

# Transformation and device setup
device = torch.device("cpu")
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Index -> label mapping used to decode the classifier's output.
CLASS_NAMES = ['abnormal', 'normal']

# Load the Isolation Forest model
clf = load('Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib')

# Load feature extractor
feature_extractor_path = 'Models/feature_extractor.pth'
feature_extractor = models.resnet50(weights=None)
# An empty Sequential acts as identity, so the network outputs pooled features
# instead of class logits.
feature_extractor.fc = nn.Sequential()
feature_extractor.load_state_dict(
    torch.load(feature_extractor_path, map_location=device)
)
feature_extractor.to(device)
feature_extractor.eval()

# Load gastric classification model
GASTRIC_MODEL_PATH = 'Gastric_Models/the_resnet_50_model.pth'
# NOTE(review): this unpickles a whole model object, so only trusted checkpoint
# files must be used here. Recent PyTorch releases default to
# weights_only=True, which rejects whole-model checkpoints - confirm against
# the installed version.
model_ft = torch.load(GASTRIC_MODEL_PATH, map_location=device)
model_ft.to(device)
model_ft.eval()


def is_anomaly(clf, feature_extractor, input_image):
    """Return True when the anomaly detector flags *input_image* as an outlier.

    Args:
        clf: Fitted anomaly detector exposing ``predict``. Assumed to follow
            the scikit-learn outlier convention (-1 outlier, +1 inlier) and to
            have been trained on features from this extractor - TODO confirm.
        feature_extractor: Module mapping an image batch to feature vectors.
        input_image: Preprocessed image tensor with a leading batch dimension.

    Returns:
        True if the first (only) image in the batch is predicted as an outlier.
    """
    with torch.no_grad():
        features = feature_extractor(input_image)
    # Estimators expect a 2-D (n_samples, n_features) array.
    features = features.reshape(features.shape[0], -1).cpu().numpy()
    return bool(clf.predict(features)[0] == -1)


# Anomaly detection and classification function
def classify_image(uploaded_image):
    """Screen an uploaded image for anomalies and classify it.

    Args:
        uploaded_image: Path of the uploaded file as supplied by the Gradio
            ``File`` input, or None when nothing was uploaded.

    Returns:
        A ``(message, None)`` tuple. ``message`` reports the predicted class
        and its probability, an anomaly notice, or an input error. The second
        element is always None (no image is produced).
    """
    if uploaded_image is None:
        return "No image uploaded.", None

    try:
        # The context manager closes the file handle once pixels are copied.
        with Image.open(uploaded_image) as raw_image:
            image = raw_image.convert('RGB')
    except OSError:
        # Covers unreadable files and PIL's UnidentifiedImageError.
        return "Could not read the uploaded file as an image.", None

    input_image = data_transforms(image).unsqueeze(0).to(device)

    # Anomaly detection
    if is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None

    # Classification
    with torch.no_grad():
        outputs = model_ft(input_image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)

    _, predicted = torch.max(outputs, 1)
    predicted_class_index = predicted.item()
    predicted_class_name = CLASS_NAMES[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100
    return (
        f"Class: {predicted_class_name}, Probability: {predicted_probability:.2f}%",
        None,
    )


iface = gr.Interface(
    fn=classify_image,
    inputs=File(type="filepath"),
    # Two components to match the (message, image) tuple returned by
    # classify_image; the image slot is currently always empty.
    outputs=[gr.Textbox(label="Result"), gr.Image(label="Visualization")],
    title="GastroHub AI Gastric Image Classifier",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above is a sample image to test the results of the model. Click it to see the results.",
    examples=[
        ["Gastric_Images/Ladybug.png"],
    ],
    allow_flagging="never",
)

# Run the Gradio app
iface.launch()