FJDorfner commited on
Commit
f722806
1 Parent(s): 4b61dc5

Changed Code for CUDA

Browse files
Files changed (3) hide show
  1. Model_Class.py +8 -5
  2. Model_Seg.py +3 -5
  3. app.py +6 -3
Model_Class.py CHANGED
@@ -59,14 +59,15 @@ val_transforms_416x628 = Compose(
59
  ]
60
  )
61
 
62
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
  checkpoint = torch.load("classification_model.ckpt", map_location=torch.device('cpu'))
64
- model = ResNet().to(device)
65
  model.load_state_dict(checkpoint["state_dict"])
66
  model.eval()
67
 
68
 
69
- def load_and_classify_image(image_path):
 
 
70
  image = val_transforms_416x628(image_path)
71
  image = image.unsqueeze(0).to(device)
72
 
@@ -76,8 +77,10 @@ def load_and_classify_image(image_path):
76
  return prediction.to('cpu'), image.to('cpu')
77
 
78
 
79
- def make_GradCAM(image):
80
 
 
 
81
  model.eval()
82
  target_layers = [model.model.layer4[-1]]
83
 
@@ -90,7 +93,7 @@ def make_GradCAM(image):
90
  aug_smooth=False,
91
  eigen_smooth=True,
92
  )
93
- grayscale_cam = grayscale_cam.squeeze()
94
 
95
  jet = plt.colormaps.get_cmap("inferno")
96
  newcolors = jet(np.linspace(0, 1, 256))
 
59
  ]
60
  )
61
 
 
62
  checkpoint = torch.load("classification_model.ckpt", map_location=torch.device('cpu'))
63
+ model = ResNet()
64
  model.load_state_dict(checkpoint["state_dict"])
65
  model.eval()
66
 
67
 
68
+ def load_and_classify_image(image_path, device):
69
+
70
+ model = model.to(device)
71
  image = val_transforms_416x628(image_path)
72
  image = image.unsqueeze(0).to(device)
73
 
 
77
  return prediction.to('cpu'), image.to('cpu')
78
 
79
 
80
+ def make_GradCAM(image, device):
81
 
82
+ model = model.to(device)
83
+ image = image.to(device)
84
  model.eval()
85
  target_layers = [model.model.layer4[-1]]
86
 
 
93
  aug_smooth=False,
94
  eigen_smooth=True,
95
  )
96
+ grayscale_cam = grayscale_cam.to('cpu').squeeze()
97
 
98
  jet = plt.colormaps.get_cmap("inferno")
99
  newcolors = jet(np.linspace(0, 1, 256))
Model_Seg.py CHANGED
@@ -39,10 +39,8 @@ model = UNet(
39
  num_res_units=3
40
  )
41
 
42
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
-
44
  checkpoint_path = 'segmentation_model.pt'
45
- checkpoint = torch.load(checkpoint_path, map_location=device)
46
  assert model.state_dict().keys() == checkpoint['network'].keys(), "Model and checkpoint keys do not match"
47
 
48
  model.load_state_dict(checkpoint['network'])
@@ -73,9 +71,9 @@ post_transforms = Compose([
73
 
74
 
75
 
76
- def load_and_segment_image(input_image_path):
77
 
78
-
79
  image_tensor = pre_transforms(input_image_path)
80
  image_tensor = image_tensor.unsqueeze(0).to(device)
81
 
 
39
  num_res_units=3
40
  )
41
 
 
 
42
  checkpoint_path = 'segmentation_model.pt'
43
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
44
  assert model.state_dict().keys() == checkpoint['network'].keys(), "Model and checkpoint keys do not match"
45
 
46
  model.load_state_dict(checkpoint['network'])
 
71
 
72
 
73
 
74
+ def load_and_segment_image(input_image_path, device):
75
 
76
+ model = model.to(device)
77
  image_tensor = pre_transforms(input_image_path)
78
  image_tensor = image_tensor.unsqueeze(0).to(device)
79
 
app.py CHANGED
@@ -7,6 +7,9 @@ import SimpleITK as sitk
7
  import torch
8
  from numpy import uint8
9
  import spaces
 
 
 
10
  image_base64 = utils.image_to_base64("anatomy_aware_pipeline.png")
11
  article_html = f"<img src='data:image/png;base64,{image_base64}' alt='Anatomical pipeline illustration' style='width:100%;'>"
12
 
@@ -64,7 +67,7 @@ def predict_image(input_image, input_file):
64
  else:
65
  return None , None , "Please input an image before pressing run" , None , None
66
 
67
- image_mask = Model_Seg.load_and_segment_image(image_path)
68
 
69
  overlay_image_np, original_image_np = utils.overlay_mask(image_path, image_mask)
70
 
@@ -75,10 +78,10 @@ def predict_image(input_image, input_file):
75
  cropped_boxed_array = sitk.GetArrayFromImage(cropped_boxed_im)
76
  cropped_boxed_array_disp = cropped_boxed_array.squeeze()
77
  cropped_boxed_tensor = torch.Tensor(cropped_boxed_array)
78
- prediction, image_transformed = Model_Class.load_and_classify_image(cropped_boxed_tensor)
79
 
80
 
81
- gradcam = Model_Class.make_GradCAM(image_transformed)
82
 
83
  nr_axSpA_prob = float(prediction[0].item())
84
  r_axSpA_prob = float(prediction[1].item())
 
7
  import torch
8
  from numpy import uint8
9
  import spaces
10
+
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+
13
  image_base64 = utils.image_to_base64("anatomy_aware_pipeline.png")
14
  article_html = f"<img src='data:image/png;base64,{image_base64}' alt='Anatomical pipeline illustration' style='width:100%;'>"
15
 
 
67
  else:
68
  return None , None , "Please input an image before pressing run" , None , None
69
 
70
+ image_mask = Model_Seg.load_and_segment_image(image_path, device)
71
 
72
  overlay_image_np, original_image_np = utils.overlay_mask(image_path, image_mask)
73
 
 
78
  cropped_boxed_array = sitk.GetArrayFromImage(cropped_boxed_im)
79
  cropped_boxed_array_disp = cropped_boxed_array.squeeze()
80
  cropped_boxed_tensor = torch.Tensor(cropped_boxed_array)
81
+ prediction, image_transformed = Model_Class.load_and_classify_image(cropped_boxed_tensor, device)
82
 
83
 
84
+ gradcam = Model_Class.make_GradCAM(image_transformed, device)
85
 
86
  nr_axSpA_prob = float(prediction[0].item())
87
  r_axSpA_prob = float(prediction[1].item())