"""Gradio app that screens gastric images for anomalies, then classifies them.

Pipeline for each uploaded image:
1. A ResNet-50 feature extractor embeds the image, and an Isolation Forest
   (loaded with joblib) decides whether the image is an outlier.
2. If it is not flagged, a separately trained model classifies it as
   ``abnormal`` or ``normal``.
"""

import io
import os

import gradio as gr
import matplotlib.pyplot as plt  # noqa: F401  (not used yet; kept for plot outputs)
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from joblib import load
from PIL import Image

# Transformation and device setup
device = torch.device("cpu")

# Resize the short side to 224, center-crop to 224x224, convert to a tensor
# and normalize with the ImageNet per-channel mean/std.
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Class labels, indexed by the classifier's output position.
CLASS_NAMES = ['abnormal', 'normal']

# Load the Isolation Forest model
clf = load('Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib')

# Load feature extractor: a ResNet-50 whose final `fc` head is replaced by an
# empty Sequential (a pass-through), so it returns backbone features instead
# of class logits.
feature_extractor_path = 'Models/feature_extractor.pth'
feature_extractor = models.resnet50(weights=None)
feature_extractor.fc = nn.Sequential()
feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location=device))
feature_extractor.to(device)
feature_extractor.eval()

# Load gastric classification model.
# NOTE(review): this checkpoint appears to be a whole pickled nn.Module (the
# result is used directly as a model), not a state_dict. torch>=2.6 defaults to
# torch.load(..., weights_only=True), which rejects such files; on those
# versions pass weights_only=False -- only for trusted files, because
# unpickling can execute arbitrary code.
GASTRIC_MODEL_PATH = 'Gastric_Models/the_resnet_50_model.pth'
model_ft = torch.load(GASTRIC_MODEL_PATH, map_location=device)
model_ft.to(device)
model_ft.eval()


def is_anomaly(clf, feature_extractor, input_image):
    """Return True if the outlier detector flags ``input_image`` as anomalous.

    Args:
        clf: Fitted outlier detector exposing ``predict``. It is assumed to
            follow the scikit-learn convention of returning -1 for outliers
            and 1 for inliers (TODO confirm for the saved model).
        feature_extractor: Network that maps ``input_image`` to the kind of
            features ``clf`` was fitted on (presumably this same extractor --
            verify).
        input_image: Preprocessed batch containing a single image, as built by
            ``classify_image`` (shape (1, 3, 224, 224)).

    Returns:
        bool: True when the (single) image is predicted to be an outlier.
    """
    with torch.no_grad():
        features = feature_extractor(input_image)
    # scikit-learn estimators expect a 2-D (n_samples, n_features) array.
    features = features.reshape(features.shape[0], -1).cpu().numpy()
    prediction = clf.predict(features)
    return bool(prediction[0] == -1)


def _load_rgb_image(uploaded_image):
    """Open an uploaded image and return it as an RGB PIL image.

    Accepts a filesystem path (what ``gr.File(type="filepath")`` passes), raw
    bytes, a file-like object with ``read()``, or an object exposing the file
    path as ``.name``.
    """
    if isinstance(uploaded_image, (bytes, bytearray)):
        source = io.BytesIO(uploaded_image)
    elif isinstance(uploaded_image, (str, os.PathLike)):
        source = uploaded_image
    elif hasattr(uploaded_image, "read"):
        source = io.BytesIO(uploaded_image.read())
    else:
        # Fall back to a `.name` attribute holding the temp-file path.
        source = getattr(uploaded_image, "name", uploaded_image)
    # Close the underlying file handle; convert() returns an independent copy.
    with Image.open(source) as img:
        return img.convert('RGB')


# Anomaly detection and classification function
def classify_image(uploaded_image):
    """Screen an uploaded image for anomalies, then classify it.

    Args:
        uploaded_image: Upload received from the Gradio input component. This
            is normally a file path; bytes and file-like objects are also
            accepted. ``None`` means nothing was uploaded.

    Returns:
        tuple[str, None]: A human-readable result message, and a placeholder
        for the plot output, which is currently always ``None``.
    """
    if uploaded_image is None:
        return "Please upload an image.", None

    image = _load_rgb_image(uploaded_image)
    # Add a batch dimension: the models expect (batch, C, H, W).
    input_image = data_transforms(image).unsqueeze(0).to(device)

    # Anomaly detection: refuse to classify images the outlier model rejects.
    if is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None

    # Classification
    with torch.no_grad():
        outputs = model_ft(input_image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
    _, predicted = torch.max(outputs, 1)
    predicted_class_index = predicted.item()
    predicted_class_name = CLASS_NAMES[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100
    return f"Class: {predicted_class_name}, Probability: {predicted_probability:.2f}%", None


iface = gr.Interface(
    fn=classify_image,
    inputs=gr.File(type="filepath", file_types=["image"]),
    # classify_image returns (message, plot); the plot slot is currently None.
    outputs=[gr.Textbox(label="Result"), gr.Plot()],
    title="GastroHub AI Gastric Image Classifier",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above is a sample image to test the results of the model. Click it to see the results.",
    examples=[
        ["Gastric_Images/Ladybug.png"],
    ],
    allow_flagging="never",
)

# Run the Gradio app
if __name__ == "__main__":
    iface.launch()