File size: 3,834 Bytes
bc679dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa81659
bc679dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Author: Alexander Riedel 
# License: Unlicensed
# Link: https://github.com/alexriedel1/detectron2-GradCAM

import os

import detectron2.data.transforms as T
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.data.detection_utils import read_image
from detectron2.modeling import build_model

from plots.gradcam.gradcam import GradCAM, GradCamPlusPlus

class Detectron2GradCAM():
    """
    Wrapper that configures a detectron2 model and runs Grad-CAM / Grad-CAM++ on it.

    Parameters
    ----------
    config_file : str
        detectron2 model config file path
    cfg_list : list
        List of additional model configurations (merged after the config file)
    root_dir : str [optional]
        directory of coco.json and dataset images for custom dataset registration;
        required when ``custom_dataset`` is given
    custom_dataset : str [optional]
        Name of the custom dataset to register

    Attributes
    ----------
    cfg : CfgNode
        The frozen detectron2 configuration used to build the model in ``get_cam``

    Raises
    ------
    ValueError
        If ``custom_dataset`` is given without ``root_dir``
    """

    def __init__(self, config_file, cfg_list, root_dir=None, custom_dataset=None):
        # Fail early with a clear message instead of a TypeError from `None + str`.
        if custom_dataset and root_dir is None:
            raise ValueError("root_dir must be provided when custom_dataset is set")

        # load config from file
        cfg = get_cfg()
        cfg.merge_from_file(config_file)

        if custom_dataset:
            # Registering the same name twice is rejected by detectron2, so only
            # register when the name is not known yet (e.g. a second wrapper instance).
            if custom_dataset not in DatasetCatalog.list():
                register_coco_instances(
                    custom_dataset,
                    {},
                    # os.path.join works with or without a trailing separator on root_dir
                    os.path.join(root_dir, "coco.json"),
                    root_dir,
                )
            cfg.DATASETS.TRAIN = (custom_dataset,)
            # Loading the dataset once lets detectron2's COCO loader fill in the
            # metadata (thing_classes), which get_cam reads for the label lookup.
            MetadataCatalog.get(custom_dataset)
            DatasetCatalog.get(custom_dataset)

        cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        # cfg_list is merged last so it can override everything set above
        cfg.merge_from_list(cfg_list)
        cfg.freeze()

        self.cfg = cfg

    def _get_input_dict(self, original_image):
        """
        Build the model input dict for a single image.

        Parameters
        ----------
        original_image : numpy.ndarray
            Image array indexed as (height, width, channels)

        Returns
        -------
        inputs : dict
            {"image": <tensor>, "height": <int>, "width": <int>} where <tensor> is the
            resized image as a float32 (channels, height, width) tensor with gradients
            enabled, and height/width are the ORIGINAL image dimensions
        """
        height, width = original_image.shape[:2]
        transform_gen = T.ResizeShortestEdge(
            [self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MIN_SIZE_TEST],
            self.cfg.INPUT.MAX_SIZE_TEST,
        )
        image = transform_gen.get_transform(original_image).apply_image(original_image)
        # HWC -> CHW; gradients w.r.t. the input are enabled for the CAM computation
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)).requires_grad_(True)
        return {"image": image, "height": height, "width": width}

    def get_cam(self, img, target_instance, layer_name, grad_cam_type="GradCAM"):
        """
        Builds the model and runs GradCAM or GradCAM++ for one detected instance

        Parameters
        ----------
        img : numpy.ndarray
            Inference image as an array indexed (height, width, channels).
            No color-channel conversion is performed here.
        target_instance : int
            The target instance index
        layer_name : str
            Convolutional layer to perform GradCAM on
        grad_cam_type : str
            GradCAM or GradCAM++ (for multiple instances of the same object, GradCAM++ can be favorable)

        Returns
        -------
        image_dict : dict
          {"image" : <image>, "cam" : <cam>, "output" : <output>, "label" : <label>}
          <image> original input image
          <cam> class activation map resized to original image shape
          <output> first element of the model output (holds the "instances" object)
          <label> class name of the target instance, looked up in the metadata of
                  the first dataset in cfg.DATASETS.TRAIN
        cam_orig : numpy.ndarray
          unprocessed raw cam

        Raises
        ------
        ValueError
            If ``grad_cam_type`` is neither "GradCAM" nor "GradCAM++"
        """
        # Validate before the expensive model build / checkpoint load below.
        if grad_cam_type not in ("GradCAM", "GradCAM++"):
            raise ValueError(
                f"Unsupported grad_cam_type {grad_cam_type!r}; expected 'GradCAM' or 'GradCAM++'"
            )

        model = build_model(self.cfg)
        DetectionCheckpointer(model).load(self.cfg.MODEL.WEIGHTS)

        input_image_dict = self._get_input_dict(img)

        cam_class = GradCAM if grad_cam_type == "GradCAM" else GradCamPlusPlus
        with cam_class(model, layer_name) as cam_extractor:
            cam, cam_orig, output = cam_extractor(input_image_dict, target_category=target_instance)

        prediction = output[0]
        # pred_classes holds tensor elements; convert to a plain int for list indexing
        class_index = int(prediction["instances"].pred_classes[target_instance])
        thing_classes = MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0]).thing_classes

        image_dict = {
            "image": img,
            "cam": cam,
            "output": prediction,
            "label": thing_classes[class_index],
        }
        return image_dict, cam_orig