Upload 2 files
Browse files
- cam.py +80 -0
- glaucoma.py +51 -0
cam.py
ADDED
@@ -0,0 +1,80 @@
import cv2
import torch
import numpy as np

from PIL import Image
from typing import List, Callable, Optional
from functools import partial

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image


class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """Model wrapper that unwraps the Hugging Face output object and returns the logits tensor."""

    def __init__(self, model):
        super(HuggingfaceToTensorModelWrapper, self).__init__()
        self.model = model

    def forward(self, x):
        return self.model(x).logits


class ClassActivationMap(object):
    def __init__(self, model, processor):
        self.model = HuggingfaceToTensorModelWrapper(model)
        # Use the final layer norm of the Swin V2 backbone as the Grad-CAM target layer.
        target_layer = model.swinv2.layernorm
        self.target_layer = [target_layer]
        self.processor = processor

    def swinT_reshape_transform_huggingface(self, tensor, width, height):
        # Swin outputs tokens as (batch, tokens, channels); reshape to
        # (batch, channels, height, width) so Grad-CAM can treat it like a CNN feature map.
        result = tensor.reshape(tensor.size(0),
                                height,
                                width,
                                tensor.size(2))
        result = result.transpose(2, 3).transpose(1, 2)
        return result

    def run_grad_cam_on_image(self,
                              targets_for_gradcam: List[Callable],
                              reshape_transform: Optional[Callable],
                              input_tensor: torch.Tensor,
                              input_image: Image.Image,
                              method: Callable = GradCAM):
        with method(model=self.model,
                    target_layers=self.target_layer,
                    reshape_transform=reshape_transform) as cam:

            # Replicate the tensor for each of the categories we want to create Grad-CAM for:
            repeated_tensor = input_tensor[None, :].repeat(len(targets_for_gradcam), 1, 1, 1)

            batch_results = cam(input_tensor=repeated_tensor,
                                targets=targets_for_gradcam)
            results = []
            for grayscale_cam in batch_results:
                visualization = show_cam_on_image(np.float32(input_image) / 255,
                                                  grayscale_cam,
                                                  use_rgb=True)
                # Optionally shrink the visualization; a divisor of 1 keeps the original size.
                visualization = cv2.resize(visualization,
                                           (visualization.shape[1] // 1, visualization.shape[0] // 1))
                results.append(visualization)
            return np.hstack(results)

    def get_cam(self, image, category_id):
        # Resize the input to the processor's expected resolution and build the model input tensor.
        image = Image.fromarray(image).resize((self.processor.size['height'], self.processor.size['width']))
        img_tensor = self.processor(images=image, return_tensors="pt")['pixel_values'].squeeze()
        targets_for_gradcam = [ClassifierOutputTarget(category_id)]
        # Swin V2 downsamples by a factor of 32, so the token grid is the input size // 32.
        reshape_transform = partial(self.swinT_reshape_transform_huggingface,
                                    width=img_tensor.shape[2] // 32,
                                    height=img_tensor.shape[1] // 32)
        cam = self.run_grad_cam_on_image(input_tensor=img_tensor,
                                         input_image=image,
                                         targets_for_gradcam=targets_for_gradcam,
                                         reshape_transform=reshape_transform)

        return cam
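
For context, a minimal sketch of exercising this helper on its own (not part of this upload). It assumes the pytorch_grad_cam and transformers packages are installed, that cam.py is importable as lib.cam (as the companion file below imports it), and it uses placeholder file names fundus.jpg and cam_overlay.png.

# Standalone sketch; fundus.jpg / cam_overlay.png are placeholder paths,
# and the lib.cam import path is assumed.
import cv2
from transformers import AutoImageProcessor, Swinv2ForImageClassification

from lib.cam import ClassActivationMap

checkpoint = "pamixsun/swinv2_tiny_for_glaucoma_classification"
processor = AutoImageProcessor.from_pretrained(checkpoint)
model = Swinv2ForImageClassification.from_pretrained(checkpoint).eval()

# Load an RGB image (OpenCV reads BGR, so convert).
image = cv2.cvtColor(cv2.imread("fundus.jpg"), cv2.COLOR_BGR2RGB)

cam = ClassActivationMap(model, processor)
overlay = cam.get_cam(image, category_id=1)  # index of the class to explain
cv2.imwrite("cam_overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
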
glaucoma.py
ADDED
@@ -0,0 +1,51 @@
import cv2
import torch

from transformers import AutoImageProcessor, Swinv2ForImageClassification

from lib.cam import ClassActivationMap


class GlaucomaModel(object):
    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 device=torch.device('cpu')):
        # Where to run the model: GPU or CPU.
        self.device = device
        # Glaucoma classification model (Swin V2) and its image processor.
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
        # Class activation map helper for visual explanations.
        self.cam = ClassActivationMap(self.cls_model, self.cls_extractor)

        # Classification id-to-label mapping.
        self.id2label = self.cls_model.config.id2label

        # Number of classes the classifier predicts.
        self.num_diseases = len(self.id2label)

    def glaucoma_pred(self, image):
        """
        Args:
            image: image array in RGB order.
        """
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        with torch.no_grad():
            inputs = inputs.to(self.device)
            outputs = self.cls_model(**inputs).logits
        disease_idx = outputs.cpu()[0, :].detach().numpy().argmax()

        return disease_idx

    def process(self, image):
        """
        Args:
            image: image array in RGB order.
        """
        image_shape = image.shape[:2]
        disease_idx = self.glaucoma_pred(image)
        cam = self.cam.get_cam(image, disease_idx)
        # Resize the CAM overlay back to the original resolution (cv2.resize expects (width, height)).
        cam = cv2.resize(cam, image_shape[::-1])

        return disease_idx, cam
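
For reference, a minimal end-to-end sketch (not part of this upload) of running the pipeline on a single image. It assumes glaucoma.py is importable as lib.glaucoma and again uses placeholder file names fundus.jpg and cam_overlay.png.

# End-to-end sketch; file paths are placeholders and the lib.glaucoma import path is assumed.
import cv2
import torch

from lib.glaucoma import GlaucomaModel

model = GlaucomaModel(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# Load an RGB image (OpenCV reads BGR, so convert).
image = cv2.cvtColor(cv2.imread("fundus.jpg"), cv2.COLOR_BGR2RGB)

disease_idx, cam_overlay = model.process(image)
print("Prediction:", model.id2label[int(disease_idx)])
cv2.imwrite("cam_overlay.png", cv2.cvtColor(cam_overlay, cv2.COLOR_RGB2BGR))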