import torch
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
from joblib import load
from PIL import Image
import gradio as gr

# Device configuration
device = torch.device("cpu")

# Transformations for the input images (ImageNet normalization)
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the Isolation Forest model
def load_isolation_forest():
    path = 'Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib'
    return load(path)

# Load the feature extractor (ResNet-50 with its classification head removed)
def load_feature_extractor():
    feature_extractor_path = 'Models/feature_extractor.pth'
    feature_extractor = models.resnet50(weights=None)
    feature_extractor.fc = nn.Sequential()  # acts as an identity; forward() returns the 2048-d pooled features
    feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location=device))
    feature_extractor.to(device)
    feature_extractor.eval()
    return feature_extractor

# Anomaly detection: the Isolation Forest labels outliers as -1
def is_anomaly(clf, feature_extractor, image):
    with torch.no_grad():
        image_features = feature_extractor(image)
    return clf.predict(image_features.cpu().numpy().reshape(1, -1))[0] == -1

# Classification function
def classify_image(model, image):
    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)
    class_names = ['abnormal', 'normal']
    predicted_class_index = predicted.item()
    predicted_class_name = class_names[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100
    return predicted_class_name, predicted_probability

# Load the classification model (the .pth file stores the full pickled model, not a state_dict)
def load_classification_model():
    model_path = 'Gastric_Models/the_resnet_50_model.pth'
    model = torch.load(model_path, map_location=device)
    model.to(device)
    model.eval()
    return model

# Process the uploaded image and return the results
def process_image(image_path):
    # Open the uploaded file (Gradio passes a file path) and apply the transforms
    image = Image.open(image_path).convert('RGB')
    input_image = data_transforms(image).unsqueeze(0).to(device)

    # Load models (reloaded on every request; move these to module level to speed up inference)
    clf = load_isolation_forest()
    feature_extractor = load_feature_extractor()
    classification_model = load_classification_model()

    # Reject out-of-distribution images before classification
    if is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None, None

    # Classify the image
    predicted_class, probability = classify_image(classification_model, input_image)
    result = f"The predicted class is: {predicted_class} with a probability of {probability:.2f}%"

    # Further processing for a heatmap or additional outputs can be added here
    return result, None, None  # placeholders for the additional outputs

# Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="filepath"),
    outputs=[gr.Textbox(label="Result"), gr.Image(label="Heatmap"), gr.Image(label="Additional Output")],
    title="GastroHub AI Gastric Image Classifier",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above is a sample image you can use to test the model. Click it to see the results.",
    examples=[
        ["Gastric_Images/Ladybug.png"],
    ],
    allow_flagging="never",
)

iface.launch()
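
# ---------------------------------------------------------------------------
# Reference sketch (not called by the app): one plausible way the artifacts in
# Models/ could have been produced. This is an illustrative assumption, not the
# actual training/export code. It assumes the extractor is a torchvision
# ResNet-50 whose classification head was replaced with an empty nn.Sequential()
# (matching load_feature_extractor() above), and that the Isolation Forest was
# fitted on feature vectors extracted from in-distribution training images.
# The function names and parameters below are hypothetical.
# ---------------------------------------------------------------------------
def export_feature_extractor_sketch(out_path='Models/feature_extractor.pth'):
    # Hypothetical export step: save the backbone's state_dict so that
    # load_feature_extractor() can rebuild the same architecture and load it.
    # In practice the fine-tuned weights would be saved, not the ImageNet ones.
    backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    backbone.fc = nn.Sequential()  # strip the head; forward() returns 2048-d features
    torch.save(backbone.state_dict(), out_path)

def fit_isolation_forest_sketch(feature_extractor, image_paths,
                                out_path='Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib'):
    # Hypothetical fitting step: extract one feature vector per training image,
    # fit an Isolation Forest on them, and persist it with joblib.
    from sklearn.ensemble import IsolationForest
    from joblib import dump
    feats = []
    with torch.no_grad():
        for p in image_paths:
            img = data_transforms(Image.open(p).convert('RGB')).unsqueeze(0).to(device)
            feats.append(feature_extractor(img).cpu().numpy().ravel())
    clf = IsolationForest(contamination=0.01, random_state=0).fit(feats)
    dump(clf, out_path)
    return clf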