Changed Code for CUDA
Browse files- Model_Class.py +8 -5
- Model_Seg.py +3 -5
- app.py +6 -3
Model_Class.py
CHANGED
@@ -59,14 +59,15 @@ val_transforms_416x628 = Compose(
|
|
59 |
]
|
60 |
)
|
61 |
|
62 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
63 |
checkpoint = torch.load("classification_model.ckpt", map_location=torch.device('cpu'))
|
64 |
-
model = ResNet()
|
65 |
model.load_state_dict(checkpoint["state_dict"])
|
66 |
model.eval()
|
67 |
|
68 |
|
69 |
-
def load_and_classify_image(image_path):
|
|
|
|
|
70 |
image = val_transforms_416x628(image_path)
|
71 |
image = image.unsqueeze(0).to(device)
|
72 |
|
@@ -76,8 +77,10 @@ def load_and_classify_image(image_path):
|
|
76 |
return prediction.to('cpu'), image.to('cpu')
|
77 |
|
78 |
|
79 |
-
def make_GradCAM(image):
|
80 |
|
|
|
|
|
81 |
model.eval()
|
82 |
target_layers = [model.model.layer4[-1]]
|
83 |
|
@@ -90,7 +93,7 @@ def make_GradCAM(image):
|
|
90 |
aug_smooth=False,
|
91 |
eigen_smooth=True,
|
92 |
)
|
93 |
-
grayscale_cam = grayscale_cam.squeeze()
|
94 |
|
95 |
jet = plt.colormaps.get_cmap("inferno")
|
96 |
newcolors = jet(np.linspace(0, 1, 256))
|
|
|
59 |
]
|
60 |
)
|
61 |
|
|
|
62 |
checkpoint = torch.load("classification_model.ckpt", map_location=torch.device('cpu'))
|
63 |
+
model = ResNet()
|
64 |
model.load_state_dict(checkpoint["state_dict"])
|
65 |
model.eval()
|
66 |
|
67 |
|
68 |
+
def load_and_classify_image(image_path, device):
|
69 |
+
|
70 |
+
model = model.to(device)
|
71 |
image = val_transforms_416x628(image_path)
|
72 |
image = image.unsqueeze(0).to(device)
|
73 |
|
|
|
77 |
return prediction.to('cpu'), image.to('cpu')
|
78 |
|
79 |
|
80 |
+
def make_GradCAM(image, device):
|
81 |
|
82 |
+
model = model.to(device)
|
83 |
+
image = image.to(device)
|
84 |
model.eval()
|
85 |
target_layers = [model.model.layer4[-1]]
|
86 |
|
|
|
93 |
aug_smooth=False,
|
94 |
eigen_smooth=True,
|
95 |
)
|
96 |
+
grayscale_cam = grayscale_cam.to('cpu').squeeze()
|
97 |
|
98 |
jet = plt.colormaps.get_cmap("inferno")
|
99 |
newcolors = jet(np.linspace(0, 1, 256))
|
Model_Seg.py
CHANGED
@@ -39,10 +39,8 @@ model = UNet(
|
|
39 |
num_res_units=3
|
40 |
)
|
41 |
|
42 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
43 |
-
|
44 |
checkpoint_path = 'segmentation_model.pt'
|
45 |
-
checkpoint = torch.load(checkpoint_path, map_location=
|
46 |
assert model.state_dict().keys() == checkpoint['network'].keys(), "Model and checkpoint keys do not match"
|
47 |
|
48 |
model.load_state_dict(checkpoint['network'])
|
@@ -73,9 +71,9 @@ post_transforms = Compose([
|
|
73 |
|
74 |
|
75 |
|
76 |
-
def load_and_segment_image(input_image_path):
|
77 |
|
78 |
-
|
79 |
image_tensor = pre_transforms(input_image_path)
|
80 |
image_tensor = image_tensor.unsqueeze(0).to(device)
|
81 |
|
|
|
39 |
num_res_units=3
|
40 |
)
|
41 |
|
|
|
|
|
42 |
checkpoint_path = 'segmentation_model.pt'
|
43 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
44 |
assert model.state_dict().keys() == checkpoint['network'].keys(), "Model and checkpoint keys do not match"
|
45 |
|
46 |
model.load_state_dict(checkpoint['network'])
|
|
|
71 |
|
72 |
|
73 |
|
74 |
+
def load_and_segment_image(input_image_path, device):
|
75 |
|
76 |
+
model = model.to(device)
|
77 |
image_tensor = pre_transforms(input_image_path)
|
78 |
image_tensor = image_tensor.unsqueeze(0).to(device)
|
79 |
|
app.py
CHANGED
@@ -7,6 +7,9 @@ import SimpleITK as sitk
|
|
7 |
import torch
|
8 |
from numpy import uint8
|
9 |
import spaces
|
|
|
|
|
|
|
10 |
image_base64 = utils.image_to_base64("anatomy_aware_pipeline.png")
|
11 |
article_html = f"<img src='data:image/png;base64,{image_base64}' alt='Anatomical pipeline illustration' style='width:100%;'>"
|
12 |
|
@@ -64,7 +67,7 @@ def predict_image(input_image, input_file):
|
|
64 |
else:
|
65 |
return None , None , "Please input an image before pressing run" , None , None
|
66 |
|
67 |
-
image_mask = Model_Seg.load_and_segment_image(image_path)
|
68 |
|
69 |
overlay_image_np, original_image_np = utils.overlay_mask(image_path, image_mask)
|
70 |
|
@@ -75,10 +78,10 @@ def predict_image(input_image, input_file):
|
|
75 |
cropped_boxed_array = sitk.GetArrayFromImage(cropped_boxed_im)
|
76 |
cropped_boxed_array_disp = cropped_boxed_array.squeeze()
|
77 |
cropped_boxed_tensor = torch.Tensor(cropped_boxed_array)
|
78 |
-
prediction, image_transformed = Model_Class.load_and_classify_image(cropped_boxed_tensor)
|
79 |
|
80 |
|
81 |
-
gradcam = Model_Class.make_GradCAM(image_transformed)
|
82 |
|
83 |
nr_axSpA_prob = float(prediction[0].item())
|
84 |
r_axSpA_prob = float(prediction[1].item())
|
|
|
7 |
import torch
|
8 |
from numpy import uint8
|
9 |
import spaces
|
10 |
+
|
11 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
12 |
+
|
13 |
image_base64 = utils.image_to_base64("anatomy_aware_pipeline.png")
|
14 |
article_html = f"<img src='data:image/png;base64,{image_base64}' alt='Anatomical pipeline illustration' style='width:100%;'>"
|
15 |
|
|
|
67 |
else:
|
68 |
return None , None , "Please input an image before pressing run" , None , None
|
69 |
|
70 |
+
image_mask = Model_Seg.load_and_segment_image(image_path, device)
|
71 |
|
72 |
overlay_image_np, original_image_np = utils.overlay_mask(image_path, image_mask)
|
73 |
|
|
|
78 |
cropped_boxed_array = sitk.GetArrayFromImage(cropped_boxed_im)
|
79 |
cropped_boxed_array_disp = cropped_boxed_array.squeeze()
|
80 |
cropped_boxed_tensor = torch.Tensor(cropped_boxed_array)
|
81 |
+
prediction, image_transformed = Model_Class.load_and_classify_image(cropped_boxed_tensor, device)
|
82 |
|
83 |
|
84 |
+
gradcam = Model_Class.make_GradCAM(image_transformed, device)
|
85 |
|
86 |
nr_axSpA_prob = float(prediction[0].item())
|
87 |
r_axSpA_prob = float(prediction[1].item())
|