# oral-lesions-detection / plots / gradcam / detectron2_gradcam.py
# Federico Galatolo
# gradcam working on cv image
# fa81659
# raw
# history blame
# No virus
# 3.83 kB
# Author: Alexander Riedel
# License: Unlicensed
# Link: https://github.com/alexriedel1/detectron2-GradCAM
import os

import detectron2.data.transforms as T
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.data.detection_utils import read_image
from detectron2.modeling import build_model

from plots.gradcam.gradcam import GradCAM, GradCamPlusPlus
class Detectron2GradCAM():
    """
    Grad-CAM / Grad-CAM++ front end for a detectron2 model.

    The constructor assembles and freezes a detectron2 config; ``get_cam``
    then builds the model from that config and computes a class activation
    map for one detected instance.

    Parameters
    ----------
    config_file : str
        detectron2 model config file path
    cfg_list : list
        List of additional model configurations, merged last so they can
        override everything set before (including the device)
    root_dir : str [optional]
        directory of coco.json and dataset images for custom dataset registration
    custom_dataset : str [optional]
        Name of the custom dataset to register

    Attributes
    ----------
    cfg : CfgNode
        The frozen detectron2 config used to build the model
    """

    def __init__(self, config_file, cfg_list, root_dir=None, custom_dataset=None):
        """
        Build and freeze the config; optionally register a COCO-style dataset.

        Raises
        ------
        TypeError
            If ``custom_dataset`` still has to be registered and ``root_dir``
            is None (the annotation path cannot be built).
        """
        # load config from file
        cfg = get_cfg()
        cfg.merge_from_file(config_file)

        if custom_dataset:
            # detectron2 asserts when a dataset name is registered twice, so a
            # second instance for the same dataset must not re-register it.
            if custom_dataset not in DatasetCatalog.list():
                register_coco_instances(
                    custom_dataset,
                    {},
                    # join instead of "+" so root_dir works with or without a
                    # trailing path separator
                    os.path.join(root_dir, "coco.json"),
                    root_dir,
                )
            cfg.DATASETS.TRAIN = (custom_dataset,)
            MetadataCatalog.get(custom_dataset)
            # NOTE(review): the return value is unused; loading the dataset is
            # presumably what makes detectron2's COCO loader fill in metadata
            # such as thing_classes, which get_cam reads for the label.
            DatasetCatalog.get(custom_dataset)

        cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        # merged after the defaults above so that cfg_list has the final say
        cfg.merge_from_list(cfg_list)
        cfg.freeze()
        self.cfg = cfg

    def _get_input_dict(self, original_image):
        """
        Convert an image array into the input dict passed to the model.

        Parameters
        ----------
        original_image : numpy.ndarray
            Image indexed as (height, width, channel). No channel reordering
            is done here, so the channel order must already be the one the
            model expects.

        Returns
        -------
        inputs : dict
            {"image": float32 tensor in (channel, height, width) order with
            requires_grad enabled, "height": original height,
            "width": original width}
        """
        height, width = original_image.shape[:2]
        transform_gen = T.ResizeShortestEdge(
            [self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MIN_SIZE_TEST],
            self.cfg.INPUT.MAX_SIZE_TEST,
        )
        image = transform_gen.get_transform(original_image).apply_image(original_image)
        # HWC -> CHW; gradients w.r.t. the input are enabled for the CAM pass
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)).requires_grad_(True)
        return {"image": image, "height": height, "width": width}

    def get_cam(self, img, target_instance, layer_name, grad_cam_type="GradCAM"):
        """
        Compute a class activation map for one detected instance.

        Parameters
        ----------
        img : numpy.ndarray
            Inference image as an array indexed (height, width, channel),
            e.g. an image already loaded with OpenCV
        target_instance : int
            The target instance index
        layer_name : str
            Convolutional layer to perform GradCAM on
        grad_cam_type : str
            GradCAM or GradCAM++ (for multiple instances of the same object, GradCAM++ can be favorable)

        Returns
        -------
        image_dict : dict
            {"image" : <image>, "cam" : <cam>, "output" : <output>, "label" : <label>}
            <image> original input image
            <cam> class activation map resized to original image shape
            <output> instances object generated by the model
            <label> class name of the target instance, looked up in the
                    thing_classes of the first dataset in cfg.DATASETS.TRAIN
        cam_orig : numpy.ndarray
            unprocessed raw cam

        Raises
        ------
        ValueError
            If ``grad_cam_type`` is neither "GradCAM" nor "GradCAM++".
        """
        # Validate first: building the model and loading weights is expensive
        # and pointless if the request cannot be served.
        if grad_cam_type not in ("GradCAM", "GradCAM++"):
            raise ValueError(
                f"Unknown grad_cam_type {grad_cam_type!r}; expected 'GradCAM' or 'GradCAM++'"
            )

        model = build_model(self.cfg)
        DetectionCheckpointer(model).load(self.cfg.MODEL.WEIGHTS)

        input_image_dict = self._get_input_dict(img)

        if grad_cam_type == "GradCAM":
            grad_cam = GradCAM(model, layer_name)
        else:
            grad_cam = GradCamPlusPlus(model, layer_name)

        # Distinct name for the context object so the resulting map does not
        # shadow the extractor that produced it.
        with grad_cam as cam_extractor:
            cam, cam_orig, output = cam_extractor(
                input_image_dict, target_category=target_instance
            )

        prediction = output[0]
        class_index = int(prediction["instances"].pred_classes[target_instance])
        thing_classes = MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0]).thing_classes

        image_dict = {
            "image": img,
            "cam": cam,
            "output": prediction,
            "label": thing_classes[class_index],
        }
        return image_dict, cam_orig