"""Gradio app: anomaly-gated gastric image classification with an activation heatmap.

For each uploaded image the app:
1. preprocesses it with ``data_transforms``;
2. rejects inputs that the Isolation Forest flags as outliers in the feature
   space of a ResNet-50 feature extractor;
3. classifies the image as one of ``CLASS_NAMES``;
4. renders a thresholded activation map taken from the classifier's first ReLU.
"""
import functools
import io

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from joblib import load
from PIL import Image

# Assuming that the model paths and other configurations are correctly set.
# The anomaly gate and the heatmap run on the CPU.
device = torch.device("cpu")

CLASSIFIER_PATH = "Gastric_Models/the_resnet_50_model.pth"
# Must be listed in the same order as the classifier's output units.
CLASS_NAMES = ["abnormal", "normal"]
# Heatmap values below this are zeroed before display.
HEATMAP_THRESHOLD = 0.5

# Preprocessing: resize the shorter side to 224, center-crop 224x224, convert
# to a tensor and normalize with the ImageNet channel mean/std.
data_transforms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _load_trusted_model(path, map_location):
    """Load a checkpoint that stores a whole pickled ``nn.Module``.

    torch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which refuses
    pickled modules, so opt out explicitly. Unpickling can execute arbitrary
    code: only call this on trusted local files.
    """
    try:
        return torch.load(path, map_location=map_location, weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the ``weights_only`` argument.
        return torch.load(path, map_location=map_location)


@functools.lru_cache(maxsize=4)
def _get_classifier(model_path, target_device):
    """Load a classifier once per (path, device) pair, move it there and set eval mode.

    Cached so that requests do not re-read the checkpoint from disk; a changed
    file on disk is therefore only picked up after a restart.
    """
    model = _load_trusted_model(model_path, map_location=target_device)
    model = model.to(target_device)
    model.eval()
    return model


# Load the Isolation Forest model (joblib also unpickles: trusted file only).
clf = load("Models/Anomaly_MSI_MSS_Isolation_Forest_model.joblib")


def is_anomaly(clf, feature_extractor, input_image):
    """Return True when *clf* flags *input_image* as an outlier.

    Parameters:
        clf - fitted estimator whose ``predict`` returns -1 for outliers.
        feature_extractor - module that maps the image batch to feature vectors.
        input_image - preprocessed image batch; expected to hold a single
                      image, since all features are flattened into one row.

    Returns:
        True if the prediction for the image is -1 (outlier), else False.
    """
    feature_extractor.to(device)
    with torch.no_grad():
        image_features = feature_extractor(input_image.to(device))
    # The estimator expects a 2-D (n_samples, n_features) array.
    is_outlier = clf.predict(image_features.cpu().numpy().reshape(1, -1))
    return bool(is_outlier[0] == -1)


# Feature extractor: ResNet-50 whose classification head is replaced by an
# empty Sequential (identity), so it outputs the pooled feature vector.
feature_extractor = models.resnet50(weights=None)
feature_extractor.fc = nn.Sequential()
feature_extractor.load_state_dict(
    torch.load("Models/feature_extractor.pth", map_location=device)
)
feature_extractor.to(device)
feature_extractor.eval()

# Classifier instance used for the heatmap; shares the cache with classify_image.
model_ft = _get_classifier(CLASSIFIER_PATH, device)


def classify_image(clf, feature_extractor, input_image, model_path, class_names):
    """
    Detects anomalies and classifies the image.

    Parameters:
        clf - Anomaly detection model.
        feature_extractor - Feature extractor for the anomaly detection model.
        input_image - The preprocessed image batch to be classified (any device).
        model_path - Path to the classification model (loaded once, then cached).
        class_names - List of class names, indexed by the model's output units.

    Returns:
        A tuple (predicted class name, probability in percent), or (None, None)
        when the image is flagged as an anomaly.
    """
    if is_anomaly(clf, feature_extractor, input_image):
        print("Anomaly detected. Image will not be classified.")
        return None, None
    print("No anomaly detected. Proceeding with classification.")

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _get_classifier(model_path, target_device)

    with torch.no_grad():
        # The input must live on the same device as the model (it may be CUDA).
        outputs = model(input_image.to(target_device))
        probabilities = nn.functional.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs, 1)

    predicted_class_index = predicted.item()
    predicted_class_name = class_names[predicted_class_index]
    predicted_probability = probabilities[0][predicted_class_index].item() * 100
    return predicted_class_name, predicted_probability


def generate_heatmap(model, input_image, device, threshold=HEATMAP_THRESHOLD):
    """
    Generates a heatmap from the output of the first ReLU executed by the model.

    Parameters:
        model - The classification model.
        input_image - The image batch for which the heatmap is generated.
        device - The device (CPU or GPU) where the model is loaded.
        threshold - Values below this are set to 0 (default 0.5).

    Returns:
        A 2-D numpy array (channel-wise maximum of the first sample's
        activation), or None if the model contains no ``nn.ReLU`` modules.
    """
    activations = []

    def capture_first(module, inputs, output):
        # A ReLU module may be called several times per forward pass; only the
        # very first activation is used, so ignore the rest.
        if not activations:
            activations.append(output.detach())

    handles = [
        module.register_forward_hook(capture_first)
        for module in model.modules()
        if isinstance(module, nn.ReLU)
    ]
    try:
        # Only forward activations are needed, so no gradients are computed.
        with torch.no_grad():
            model(input_image.to(device))
    finally:
        # Remove the hooks; otherwise every call stacks more of them on the model.
        for handle in handles:
            handle.remove()

    if not activations:
        print("No activations recorded.")
        return None

    # Assumes a (batch, channels, H, W) activation: take sample 0, max over channels.
    feature_map = activations[0][0].cpu().numpy()
    heatmap = feature_map.max(axis=0)
    heatmap[heatmap < threshold] = 0
    return heatmap


def _render_heatmap(heatmap):
    """Render a 2-D array with the 'hot' colormap (axes hidden) into a PIL image."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(heatmap, cmap="hot")
        ax.axis("off")
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        plt.close(fig)
    buffer.seek(0)
    heatmap_image = Image.open(buffer)
    heatmap_image.load()  # decode now rather than lazily
    return heatmap_image


def process_image(image):
    """
    Processes the image and returns the classification results and heatmap.

    Returns:
        (result text, heatmap PIL image or None).
    """
    if image is None:
        return "Please upload an image.", None

    # Defensive: the networks and Normalize expect exactly 3 channels, while
    # uploaded PNGs may be RGBA or grayscale.
    input_image = data_transforms(image.convert("RGB")).unsqueeze(0).to(device)

    predicted_class_name, predicted_probability = classify_image(
        clf, feature_extractor, input_image, CLASSIFIER_PATH, CLASS_NAMES
    )
    if predicted_class_name is None:
        return "Anomaly Detected - Image not classified", None

    heatmap = generate_heatmap(model_ft, input_image, device)
    if heatmap is None:
        # A plain string is not a valid value for the gr.Image output.
        return f"{predicted_class_name} ({predicted_probability:.2f}%)", None

    heatmap_image = _render_heatmap(heatmap)
    description = "\n\nRed Regions - highest importance for the predicted class.\nYellow Regions - moderately high importance for the predicted class."
    return (
        f"The predicted class is: {predicted_class_name} at a probability of ({predicted_probability:.2f}%) {description}",
        heatmap_image,
    )


demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Classification Result"), gr.Image(label="Heatmap")],
    title="Gastrohub Cancer Detection",
    description="Upload an image to classify it as normal or abnormal.",
    article="Above are a few sample images to test the results of the model. Click any to see the results.",
    examples=[
        ["Gastric_Images/Abnormal-04038-test2.png"],
        ["Gastric_Images/cancer_108_WSI.png"],
        ["Gastric_Images/Ladybug.png"],
        ["Gastric_Images/normal_78_WSI.png"],
        ["Gastric_Images/Normal-01006-test1.png"],
        ["Gastric_Images/Normal-07000.png"],
    ],
    allow_flagging="never",
)
demo.launch()