import torch
import torchvision.transforms as transforms
import torchvision.models as models
import torch.nn as nn
from joblib import load
from gradio import File
from PIL import Image
import matplotlib.pyplot as plt
import io
import numpy as np
import gradio as gr

# Device configuration
device = torch.device("cpu")

# Transformation for the input images
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the Isolation Forest model
def load_isolation_forest():
    path = 'Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib'
    return load(path)

# Load the feature extractor (ResNet-50 backbone with the classification head removed)
def load_feature_extractor():
    feature_extractor_path = 'Models/feature_extractor.pth'
    feature_extractor = models.resnet50(weights=None)
    feature_extractor.fc = nn.Sequential()
    feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location=device))
    feature_extractor.to(device)
    feature_extractor.eval()
    return feature_extractor

# Anomaly detection: Isolation Forest predicts -1 for out-of-distribution images
def is_anomaly(clf, feature_extractor, image):
    with torch.no_grad():
        image_features = feature_extractor(image)
    return clf.predict(image_features.cpu().numpy().reshape(1, -1))[0] == -1

# Classification function
def classify_image(model, image):
    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)
    class_names = ['abnormal', 'normal']
    predicted_class_index = predicted.item()
    predicted_class_name = class_names[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100
    return predicted_class_name, predicted_probability

# Load the classification model
def load_classification_model():
    model_path = 'Gastric_Models/the_resnet_50_model.pth'
    model = torch.load(model_path, map_location=device)
    model.to(device)
    model.eval()
    return model

# Function to process the image and get results
def process_image(image_path):
    # Convert to PIL and apply transforms
    image = Image.open(image_path).convert('RGB')
    input_image = data_transforms(image).unsqueeze(0).to(device)

    # Load models
    clf = load_isolation_forest()
    feature_extractor = load_feature_extractor()
    classification_model = load_classification_model()

    # Check for anomaly before attempting classification
    if is_anomaly(clf, feature_extractor, input_image):
        return "Anomaly detected. Image will not be classified.", None

    # Classify image
    predicted_class, probability = classify_image(classification_model, input_image)
    result = f"The predicted class is: {predicted_class} with a probability of {probability:.2f}%"

    # Generate heatmap and render it with the 'hot' colormap
    heatmap = generate_heatmap(classification_model, input_image)
    heatmap_image = Image.fromarray(np.uint8(plt.cm.hot(heatmap) * 255))

    return result, heatmap_image

# Function to generate a coarse heatmap from ReLU activations
def generate_heatmap(model, image):
    activation = []

    def hook_fn(module, input, output):
        activation.append(output)

    # Register forward hooks on every ReLU layer to capture activations
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.ReLU):
            handles.append(module.register_forward_hook(hook_fn))

    # Forward pass
    output = model(image)
    prediction = output.argmax(1)

    # Backpropagate a one-hot vector for the predicted class
    model.zero_grad()
    one_hot_output = torch.zeros(1, output.size(-1), device=device)
    one_hot_output[0][prediction] = 1
    output.backward(gradient=one_hot_output)

    # Remove hooks so repeated calls do not keep accumulating them
    for handle in handles:
        handle.remove()

    # Build the heatmap from the first captured ReLU activation
    if len(activation) > 0:
        activation_map = activation[0][0].detach().cpu().numpy()
        heatmap = activation_map.max(axis=0)
        # Normalize to [0, 1] so the threshold and the colormap behave consistently
        if heatmap.max() > 0:
            heatmap = heatmap / heatmap.max()
        threshold = 0.5  # Adjust this threshold value as needed
        heatmap[heatmap < threshold] = 0
        return heatmap
    else:
        return np.zeros((224, 224))  # Return an empty heatmap if no activation is recorded

# Gradio interface (two outputs to match the two values returned by process_image)
iface = gr.Interface(
    fn=process_image,
    inputs=File(type="filepath"),
    outputs=[gr.Textbox(label="Result"), gr.Image(label="Heatmap")],
    title="GastroHub AI Gastric Image Classifier",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above is a sample image to test the results of the model. Click it to see the results.",
    examples=[
        ["/Gastric_Images/Ladybug.png"],
    ],
    allow_flagging="never",
)

iface.launch()