Delete utils/pipeline.py
Browse files- utils/pipeline.py +0 -501
utils/pipeline.py
DELETED
@@ -1,501 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import sys
|
3 |
-
import tempfile
|
4 |
-
from glob import glob
|
5 |
-
from torchsummary import summary
|
6 |
-
import numpy as np
|
7 |
-
import pandas as pd
|
8 |
-
from tqdm import tqdm
|
9 |
-
import torch
|
10 |
-
from torch.utils.tensorboard import SummaryWriter
|
11 |
-
from torch.cuda.amp import autocast, GradScaler
|
12 |
-
import torch.nn as nn
|
13 |
-
import torchvision
|
14 |
-
import monai
|
15 |
-
from monai.metrics import DiceMetric, ConfusionMatrixMetric, MeanIoU
|
16 |
-
from monai.visualize import plot_2d_or_3d_image
|
17 |
-
from visualization import visualize_patient
|
18 |
-
from sliding_window import sw_inference
|
19 |
-
from data_preparation import build_dataset
|
20 |
-
from models import UNet2D, UNet3D
|
21 |
-
from loss import WeaklyDiceFocalLoss
|
22 |
-
from sklearn.linear_model import LinearRegression
|
23 |
-
from nrrd import write, read
|
24 |
-
import morphsnakes as ms
|
25 |
-
from monai.data import decollate_batch
|
26 |
-
|
27 |
-
|
28 |
-
def build_optimizer(model, config):
|
29 |
-
|
30 |
-
if config['LOSS'] == "gdice":
|
31 |
-
loss_function = monai.losses.GeneralizedDiceLoss(
|
32 |
-
include_background=config['EVAL_INCLUDE_BACKGROUND'],
|
33 |
-
reduction="mean", to_onehot_y=True, sigmoid=True) if len(config['KEEP_CLASSES'])<=2 else monai.losses.GeneralizedDiceLoss(
|
34 |
-
include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean", to_onehot_y=False, softmax=True)
|
35 |
-
elif config['LOSS'] == 'cdice':
|
36 |
-
loss_function = monai.losses.DiceCELoss(
|
37 |
-
include_background=config['EVAL_INCLUDE_BACKGROUND'],
|
38 |
-
reduction="mean", to_onehot_y=True, sigmoid=True) if len(config['KEEP_CLASSES'])<=2 else monai.losses.DiceCELoss(
|
39 |
-
include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean", to_onehot_y=False, softmax=True)
|
40 |
-
elif config['LOSS'] == 'mdice':
|
41 |
-
loss_function = monai.losses.MaskedDiceLoss()
|
42 |
-
elif config['LOSS'] == 'wdice':
|
43 |
-
# Example with 3 classes (including the background: label 0).
|
44 |
-
# The distance between the background class (label 0) and the other classes is the maximum, equal to 1.
|
45 |
-
# The distance between class 1 and class 2 is 0.5.
|
46 |
-
dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32)
|
47 |
-
loss_function = monai.losses.GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
|
48 |
-
elif config['LOSS'] == "fdice":
|
49 |
-
loss_function = monai.losses.DiceFocalLoss(
|
50 |
-
include_background=config['EVAL_INCLUDE_BACKGROUND'], to_onehot_y=True, sigmoid=True) if len(config['KEEP_CLASSES'])<=2 else monai.losses.DiceFocalLoss(
|
51 |
-
include_background=config['EVAL_INCLUDE_BACKGROUND'], to_onehot_y=False, softmax=True)
|
52 |
-
elif config['LOSS'] == "wfdice":
|
53 |
-
loss_function = WeaklyDiceFocalLoss(include_background=config['EVAL_INCLUDE_BACKGROUND'], to_onehot_y=True, sigmoid=True, lambda_weak=config['LAMBDA_WEAK']) if len(config['KEEP_CLASSES'])<=2 else WeaklyDiceFocalLoss(include_background=config['EVAL_INCLUDE_BACKGROUND'], to_onehot_y=False, softmax=True, lambda_weak=config['LAMBDA_WEAK'])
|
54 |
-
else:
|
55 |
-
loss_function = monai.losses.DiceLoss(
|
56 |
-
include_background=config['EVAL_INCLUDE_BACKGROUND'],
|
57 |
-
reduction="mean", to_onehot_y=True, sigmoid=True, squared_pred=True) if len(config['KEEP_CLASSES'])<=2 else monai.losses.DiceLoss(
|
58 |
-
include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean", to_onehot_y=False, softmax=True, squared_pred=True)
|
59 |
-
|
60 |
-
eval_metrics = [
|
61 |
-
("sensitivity", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], metric_name='sensitivity', reduction="mean_batch")),
|
62 |
-
("specificity", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], metric_name='specificity', reduction="mean_batch")),
|
63 |
-
("accuracy", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], metric_name='accuracy', reduction="mean_batch")),
|
64 |
-
("dice", DiceMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean_batch")),
|
65 |
-
("IoU", MeanIoU(include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean_batch"))
|
66 |
-
]
|
67 |
-
|
68 |
-
optimizer = torch.optim.Adam(model.parameters(), config['LEARNING_RATE'], weight_decay=1e-5, amsgrad=True)
|
69 |
-
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['MAX_EPOCHS'])
|
70 |
-
return loss_function, optimizer, lr_scheduler, eval_metrics
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
def load_weights(model, config):
|
75 |
-
try:
|
76 |
-
model.load_state_dict(torch.load("checkpoints/" + config['PRETRAINED_WEIGHTS'] + ".pth", map_location=torch.device(config['DEVICE'])))
|
77 |
-
print("Model weights from", config['PRETRAINED_WEIGHTS'], "have been loaded")
|
78 |
-
except Exception as e:
|
79 |
-
try:
|
80 |
-
model.load_state_dict(torch.load(config['PRETRAINED_WEIGHTS'], map_location=torch.device(config['DEVICE'])))
|
81 |
-
print("Model weights from", config['PRETRAINED_WEIGHTS'], "have been loaded")
|
82 |
-
except Exception as e: # load
|
83 |
-
print("WARNING: weights were not loaded. ", e)
|
84 |
-
pass
|
85 |
-
|
86 |
-
return model
|
87 |
-
|
88 |
-
|
89 |
-
def build_model(config):
|
90 |
-
|
91 |
-
config = get_defaults(config)
|
92 |
-
|
93 |
-
dropout_prob = config['DROPOUT']
|
94 |
-
|
95 |
-
if "SegResNetVAE" in config["MODEL_NAME"]:
|
96 |
-
model = monai.networks.nets.SegResNetVAE(
|
97 |
-
input_image_size=config['ROI_SIZE'] if "3D" in config['MODEL_NAME'] else (config['ROI_SIZE'][0], config['ROI_SIZE'][1]),
|
98 |
-
vae_estimate_std=False,
|
99 |
-
vae_default_std=0.3,
|
100 |
-
vae_nz=256,
|
101 |
-
spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
|
102 |
-
blocks_down=[1, 2, 2, 4],
|
103 |
-
blocks_up=[1, 1, 1],
|
104 |
-
init_filters=16,
|
105 |
-
in_channels=1,
|
106 |
-
norm='instance',
|
107 |
-
out_channels=len(config['KEEP_CLASSES']),
|
108 |
-
dropout_prob=dropout_prob,
|
109 |
-
).to(config['DEVICE'])
|
110 |
-
|
111 |
-
elif "SegResNet" in config["MODEL_NAME"]:
|
112 |
-
model = monai.networks.nets.SegResNet(
|
113 |
-
spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
|
114 |
-
blocks_down=[1, 2, 2, 4],
|
115 |
-
blocks_up=[1, 1, 1],
|
116 |
-
init_filters=16,
|
117 |
-
in_channels=1,
|
118 |
-
out_channels=len(config['KEEP_CLASSES']),
|
119 |
-
dropout_prob=dropout_prob,
|
120 |
-
norm="instance"
|
121 |
-
).to(config['DEVICE'])
|
122 |
-
|
123 |
-
elif "SwinUNETR" in config["MODEL_NAME"]:
|
124 |
-
model = monai.networks.nets.SwinUNETR(
|
125 |
-
img_size=config['ROI_SIZE'],
|
126 |
-
in_channels=1,
|
127 |
-
out_channels=len(config['KEEP_CLASSES']),
|
128 |
-
feature_size=48,
|
129 |
-
drop_rate=dropout_prob,
|
130 |
-
attn_drop_rate=0.0,
|
131 |
-
dropout_path_rate=0.0,
|
132 |
-
use_checkpoint=True
|
133 |
-
).to(config['DEVICE'])
|
134 |
-
|
135 |
-
elif "UNETR" in config["MODEL_NAME"]:
|
136 |
-
model = monai.networks.nets.UNETR(
|
137 |
-
img_size=config['ROI_SIZE'] if "3D" in config['MODEL_NAME'] else (config['ROI_SIZE'][0], config['ROI_SIZE'][1]),
|
138 |
-
in_channels=1,
|
139 |
-
out_channels=len(config['KEEP_CLASSES']),
|
140 |
-
feature_size=16,
|
141 |
-
hidden_size=256,
|
142 |
-
mlp_dim=3072,
|
143 |
-
num_heads=8,
|
144 |
-
pos_embed="perceptron",
|
145 |
-
norm_name="instance",
|
146 |
-
res_block=True,
|
147 |
-
dropout_rate=dropout_prob,
|
148 |
-
).to(config['DEVICE'])
|
149 |
-
|
150 |
-
elif "MANet" in config["MODEL_NAME"]:
|
151 |
-
if "2D" in config["MODEL_NAME"]:
|
152 |
-
model = UNet2D(
|
153 |
-
1,
|
154 |
-
len(config['KEEP_CLASSES']),
|
155 |
-
pab_channels=64,
|
156 |
-
use_batchnorm=True
|
157 |
-
).to(config['DEVICE'])
|
158 |
-
else:
|
159 |
-
model = UNet3D(
|
160 |
-
1,
|
161 |
-
len(config['KEEP_CLASSES']),
|
162 |
-
pab_channels=32,
|
163 |
-
use_batchnorm=True
|
164 |
-
).to(config['DEVICE'])
|
165 |
-
|
166 |
-
elif "UNetPlusPlus" in config["MODEL_NAME"]:
|
167 |
-
model = monai.networks.nets.BasicUNetPlusPlus(
|
168 |
-
spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
|
169 |
-
in_channels=1,
|
170 |
-
out_channels=len(config['KEEP_CLASSES']),
|
171 |
-
features=(32, 32, 64, 128, 256, 32),
|
172 |
-
norm="instance",
|
173 |
-
dropout=dropout_prob,
|
174 |
-
).to(config['DEVICE'])
|
175 |
-
|
176 |
-
elif "UNet1" in config['MODEL_NAME']:
|
177 |
-
model = monai.networks.nets.UNet(
|
178 |
-
spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
|
179 |
-
in_channels=1,
|
180 |
-
out_channels=len(config['KEEP_CLASSES']),
|
181 |
-
channels=(16, 32, 64, 128, 256),
|
182 |
-
strides=(2, 2, 2, 2),
|
183 |
-
num_res_units=2,
|
184 |
-
norm="instance"
|
185 |
-
).to(config['DEVICE'])
|
186 |
-
|
187 |
-
elif "UNet2" in config['MODEL_NAME']:
|
188 |
-
model = monai.networks.nets.UNet(
|
189 |
-
spatial_dims=3 if "3D" in config["MODEL_NAME"] else 2,
|
190 |
-
in_channels=1,
|
191 |
-
out_channels=len(config['KEEP_CLASSES']),
|
192 |
-
channels=(32, 64, 128, 256),
|
193 |
-
strides=(2, 2, 2, 2),
|
194 |
-
num_res_units=4,
|
195 |
-
norm="instance"
|
196 |
-
).to(config['DEVICE'])
|
197 |
-
|
198 |
-
else:
|
199 |
-
print(config["MODEL_NAME"], "is not a valid model name")
|
200 |
-
return None
|
201 |
-
|
202 |
-
try:
|
203 |
-
if "3D" in config['MODEL_NAME']:
|
204 |
-
print(summary(model, input_size=(1, config['ROI_SIZE'][0], config['ROI_SIZE'][1], config['ROI_SIZE'][2])))
|
205 |
-
else:
|
206 |
-
print(summary(model, input_size=(1, config['ROI_SIZE'][0], config['ROI_SIZE'][1])))
|
207 |
-
except Exception as e:
|
208 |
-
print("could not load model summary:", e)
|
209 |
-
|
210 |
-
if config['PRETRAINED_WEIGHTS'] is not None and config['PRETRAINED_WEIGHTS']:
|
211 |
-
model = load_weights(model, config)
|
212 |
-
return model
|
213 |
-
|
214 |
-
|
215 |
-
def train(model, train_loader, val_loader, loss_function, eval_metrics, optimizer, config,
|
216 |
-
scheduler=None, writer=None, postprocessing_transforms = None, weak_labels = None):
|
217 |
-
|
218 |
-
if writer is None: writer = SummaryWriter(log_dir="runs/" + config['EXPORT_FILE_NAME'])
|
219 |
-
best_metric, best_metric_epoch = -1, -1
|
220 |
-
prev_metric, patience, patience_counter = 1, config['EARLY_STOPPING_PATIENCE'], 0
|
221 |
-
if config['AUTOCAST']: scaler = GradScaler() # Initialize GradScaler for mixed precision training
|
222 |
-
|
223 |
-
for epoch in range(config['MAX_EPOCHS']):
|
224 |
-
print("-" * 10)
|
225 |
-
model.train()
|
226 |
-
epoch_loss, step = 0, 0
|
227 |
-
with tqdm(train_loader) as progress_bar:
|
228 |
-
for batch_data in progress_bar:
|
229 |
-
step += 1
|
230 |
-
inputs, labels = batch_data["image"].to(config['DEVICE']), batch_data["mask"].to(config['DEVICE'])
|
231 |
-
|
232 |
-
# only train with batches that have tumor; skip those without tumor
|
233 |
-
if config['TYPE'] == "tumor":
|
234 |
-
if torch.sum(labels[:,-1]) == 0:
|
235 |
-
continue
|
236 |
-
|
237 |
-
# check input shapes
|
238 |
-
if inputs is None or labels is None:
|
239 |
-
continue
|
240 |
-
if inputs.shape[-1] != labels.shape[-1] or inputs.shape[0] != labels.shape[0]:
|
241 |
-
print("WARNING: Batch skipped. Image and mask shape does not match:", inputs.shape[0], labels.shape[0])
|
242 |
-
continue
|
243 |
-
|
244 |
-
optimizer.zero_grad()
|
245 |
-
if not config['AUTOCAST']:
|
246 |
-
|
247 |
-
# segmentation output
|
248 |
-
outputs = model(inputs)
|
249 |
-
if "SegResNetVAE" in config["MODEL_NAME"]: outputs = outputs[0]
|
250 |
-
if isinstance(outputs, list): outputs = outputs[0]
|
251 |
-
|
252 |
-
# loss
|
253 |
-
if weak_labels is not None:
|
254 |
-
weak_label = torch.tensor([weak_labels[step]]).to(config['DEVICE'])
|
255 |
-
loss = loss_function(outputs, labels, weak_label) if config['LOSS'] == 'wfdice' else loss_function(outputs, labels)
|
256 |
-
loss.backward()
|
257 |
-
optimizer.step()
|
258 |
-
|
259 |
-
else:
|
260 |
-
with autocast():
|
261 |
-
outputs = model(inputs)
|
262 |
-
if "SegResNetVAE" in config["MODEL_NAME"]: outputs = outputs[0]
|
263 |
-
if isinstance(outputs, list): outputs = outputs[0]
|
264 |
-
loss = loss_function(outputs, labels, [weak_labels[step]]) if config['LOSS'] == 'wfdice' else loss_function(outputs, labels)
|
265 |
-
|
266 |
-
scaler.scale(loss).backward()
|
267 |
-
scaler.unscale_(optimizer)
|
268 |
-
if torch.isinf(loss).any():
|
269 |
-
print("Detected inf in gradients.")
|
270 |
-
else:
|
271 |
-
scaler.step(optimizer)
|
272 |
-
scaler.update()
|
273 |
-
|
274 |
-
epoch_loss += loss.item()
|
275 |
-
progress_bar.set_description(f'Epoch [{epoch+1}/{config["MAX_EPOCHS"]}], Loss: {epoch_loss/step:.4f}')
|
276 |
-
|
277 |
-
epoch_loss /= step
|
278 |
-
writer.add_scalar("train_loss_epoch", epoch_loss, epoch)
|
279 |
-
progress_bar.set_description(f'Epoch [{epoch+1}/{config["MAX_EPOCHS"]}], Loss: {epoch_loss:.4f}')
|
280 |
-
|
281 |
-
# validation
|
282 |
-
if (epoch + 1) % config['VAL_INTERVAL'] == 0:
|
283 |
-
|
284 |
-
# get a list of validation measures, pick one to be the decision maker
|
285 |
-
val_metrics, (val_images, val_labels, val_outputs) = evaluate(model, val_loader, eval_metrics, config, postprocessing_transforms)
|
286 |
-
if isinstance(config['EVAL_METRIC'], list):
|
287 |
-
cur_metric = np.mean([val_metrics[m] for m in config['EVAL_METRIC']])
|
288 |
-
else:
|
289 |
-
cur_metric = val_metrics[config['EVAL_METRIC']]
|
290 |
-
|
291 |
-
# determine if better than previous best validation metric
|
292 |
-
if cur_metric > best_metric:
|
293 |
-
best_metric, best_metric_epoch = cur_metric, epoch + 1
|
294 |
-
torch.save(model.state_dict(), "checkpoints/" + config['EXPORT_FILE_NAME'] + ".pth")
|
295 |
-
|
296 |
-
# early stopping
|
297 |
-
patience_counter = patience_counter + 1 if prev_metric > cur_metric else 0
|
298 |
-
if patience_counter == patience or epoch - best_metric_epoch > patience:
|
299 |
-
print("Early stopping at epoch", epoch + 1)
|
300 |
-
break
|
301 |
-
print(f'Current epoch: {epoch + 1} current avg {config["EVAL_METRIC"]}: {cur_metric :.4f} best avg {config["EVAL_METRIC"]}: {best_metric:.4f} at epoch {best_metric_epoch}')
|
302 |
-
prev_metric = cur_metric
|
303 |
-
|
304 |
-
# writer
|
305 |
-
for key, value in val_metrics.items():
|
306 |
-
writer.add_scalar("val_" + key, value, epoch)
|
307 |
-
plot_2d_or_3d_image(val_images, epoch + 1, writer, index=len(val_outputs)//2, tag="image",frame_dim=-1)
|
308 |
-
plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=len(val_outputs)//2, tag="label",frame_dim=-1)
|
309 |
-
plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=len(val_outputs)//2, tag="output",frame_dim=-1)
|
310 |
-
|
311 |
-
# update scheduler
|
312 |
-
try:
|
313 |
-
if scheduler is not None: scheduler.step()
|
314 |
-
except:
|
315 |
-
pass
|
316 |
-
|
317 |
-
print(f"Train completed, best {config['EVAL_METRIC']}: {best_metric:.4f} at epoch: {best_metric_epoch}")
|
318 |
-
writer.close()
|
319 |
-
return model, writer
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
def evaluate(model, val_loader, eval_metrics, config, postprocessing_transforms=None, use_liver_seg=False, export_filenames = [], export_file_metadata = []):
|
324 |
-
|
325 |
-
val_metrics = {}
|
326 |
-
model.eval()
|
327 |
-
with torch.no_grad():
|
328 |
-
|
329 |
-
step = 0
|
330 |
-
for val_data in val_loader:
|
331 |
-
# 3D: val_images has shape (1,C,H,W,Z)
|
332 |
-
# 2D: val_images has shape (B,C,H,W)
|
333 |
-
val_images, val_labels = val_data["image"].to(config['DEVICE']), val_data["mask"].to(config['DEVICE'])
|
334 |
-
if use_liver_seg: val_liver = val_data["pred_liver"].to(config['DEVICE'])
|
335 |
-
|
336 |
-
if (val_images[0].shape[-1] != val_labels[0].shape[-1]) or (
|
337 |
-
"3D" not in config["MODEL_NAME"] and val_images.shape[0] != val_labels.shape[0]):
|
338 |
-
print("WARNING: Batch skipped. Image and mask shape does not match:", val_images.shape, val_labels.shape)
|
339 |
-
continue
|
340 |
-
|
341 |
-
# convert outputs to probability
|
342 |
-
if "3D" in config["MODEL_NAME"]:
|
343 |
-
val_outputs = sw_inference(model, val_images, config['ROI_SIZE'], config['AUTOCAST'], discard_second_output='SegResNetVAE' in config['MODEL_NAME'])
|
344 |
-
else:
|
345 |
-
if "SegResNetVAE" in config["MODEL_NAME"]: val_outputs, _ = model(val_images)
|
346 |
-
else: val_outputs = model(val_images)
|
347 |
-
|
348 |
-
# post-procesing
|
349 |
-
if postprocessing_transforms is not None:
|
350 |
-
val_outputs = [postprocessing_transforms(i) for i in decollate_batch(val_outputs)]
|
351 |
-
|
352 |
-
# remove tumor predictions outside liver
|
353 |
-
for i in range(len(val_outputs)):
|
354 |
-
val_outputs[i][-1][torch.where(val_images[i][0] <= 1e-6)] = 0
|
355 |
-
|
356 |
-
# apply morphological snakes algorithm
|
357 |
-
if config['POSTPROCESSING_MORF']:
|
358 |
-
for i in range(len(val_outputs)):
|
359 |
-
val_outputs[i][-1] = torch.from_numpy(ms.morphological_chan_vese(val_images[i][0].cpu(), iterations=2, init_level_set=val_outputs[i][-1].cpu())).to(config['DEVICE'])
|
360 |
-
|
361 |
-
for i in range(len(val_outputs)):
|
362 |
-
if use_liver_seg:
|
363 |
-
# use liver model outputs for liver channel
|
364 |
-
val_outputs[i][1] = val_liver[i]
|
365 |
-
# if region is tumor, assign liver prediction to 0
|
366 |
-
val_outputs[i][1] -= val_outputs[i][2]
|
367 |
-
|
368 |
-
# compute metric for current iteration
|
369 |
-
for metric_name, metric in eval_metrics:
|
370 |
-
if isinstance(val_outputs[0], list):
|
371 |
-
val_outputs = val_outputs[0]
|
372 |
-
metric(val_outputs, val_labels)
|
373 |
-
|
374 |
-
# save prediction to local folder
|
375 |
-
if len(export_filenames) > 0:
|
376 |
-
for _ in range(len(val_outputs)):
|
377 |
-
numpy_array = val_outputs[_].cpu().detach().numpy()
|
378 |
-
write(export_filenames[step], numpy_array[-1], header=export_file_metadata[step])
|
379 |
-
print(" Segmentation exported to", export_filenames[step])
|
380 |
-
step += 1
|
381 |
-
|
382 |
-
# aggregate the final mean metric
|
383 |
-
for metric_name, metric in eval_metrics:
|
384 |
-
if "dice" in metric_name or "IoU" in metric_name: metric_value = metric.aggregate().tolist()
|
385 |
-
else: metric_value = metric.aggregate()[0].tolist() # a list of accuracies, one per class
|
386 |
-
val_metrics[metric_name + "_avg"] = np.mean(metric_value)
|
387 |
-
if config['TYPE'] != "liver":
|
388 |
-
for c in range(1, len(metric_value) + 1): # class-wise accuracies
|
389 |
-
val_metrics[metric_name + "_class" + str(c)] = metric_value[c-1]
|
390 |
-
metric.reset()
|
391 |
-
|
392 |
-
return val_metrics, (val_images, val_labels, val_outputs)
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
def get_defaults(config):
|
398 |
-
|
399 |
-
if 'TRAIN' not in config.keys(): config['TRAIN'] = True
|
400 |
-
if 'VALID_PATIENT_RATIO' not in config.keys(): config['VALID_PATIENT_RATIO'] = 0.2
|
401 |
-
if 'VAL_INTERVAL' not in config.keys(): config['VAL_INTERVAL'] = 1
|
402 |
-
if 'VAL_INTERVAL' not in config.keys(): config['DROPOUT'] = 0.1
|
403 |
-
if 'EARLY_STOPPING_PATIENCE' not in config.keys(): config['EARLY_STOPPING_PATIENCE'] = 20
|
404 |
-
if 'AUTOCAST' not in config.keys(): config['AUTOCAST'] = False
|
405 |
-
if 'NUM_WORKERS' not in config.keys(): config['NUM_WORKERS'] = 0
|
406 |
-
if 'DROPOUT' not in config.keys(): config['DROPOUT'] = 0.1
|
407 |
-
if 'ONESAMPLETESTRUN' not in config.keys(): config['ONESAMPLETESTRUN'] = False
|
408 |
-
if 'TRAIN' not in config.keys(): config['TRAIN'] = True
|
409 |
-
if 'DATA_AUGMENTATION' not in config.keys(): config['DATA_AUGMENTATION'] = False
|
410 |
-
if 'POSTPROCESSING_MORF' not in config.keys(): config['POSTPROCESSING_MORF'] = False
|
411 |
-
if 'PREPROCESSING' not in config.keys(): config['PREPROCESSING'] = ""
|
412 |
-
if 'PRETRAINED_WEIGHTS' not in config.keys(): config['PRETRAINED_WEIGHTS'] = ""
|
413 |
-
|
414 |
-
if 'EVAL_INCLUDE_BACKGROUND' not in config.keys():
|
415 |
-
if config['TYPE'] == "liver": config['EVAL_INCLUDE_BACKGROUND'] = True
|
416 |
-
else: config['EVAL_INCLUDE_BACKGROUND'] = False
|
417 |
-
if 'EVAL_METRIC' not in config.keys():
|
418 |
-
if config['TYPE'] == "liver": config['EVAL_METRIC'] = ["dice_avg"]
|
419 |
-
else: config['EVAL_METRIC'] = ["dice_class2"]
|
420 |
-
|
421 |
-
if 'CLINICAL_DATA_FILE' not in config.keys(): config['CLINICAL_DATA_FILE'] = "Dataset/HCC-TACE-Seg_clinical_data-V2.xlsx"
|
422 |
-
if 'CLINICAL_PREDICTORS' not in config.keys(): config['CLINICAL_PREDICTORS'] = ['T_involvment', 'CLIP_Score','Personal history of cancer', 'TNM', 'Metastasis','fhx_can', 'Alcohol', 'Smoking', 'Evidence_of_cirh', 'AFP', 'age', 'Diabetes', 'Lymphnodes', 'Interval_BL', 'TTP']
|
423 |
-
if 'LAMBDA_WEAK' not in config.keys(): config['LAMBDA_WEAK'] = 0.5
|
424 |
-
if 'MASKNONLIVER' not in config.keys(): config['MASKNONLIVER'] = False
|
425 |
-
|
426 |
-
if config['TYPE'] == "liver": config['KEEP_CLASSES']=["normal", "liver"]
|
427 |
-
elif config['TYPE'] == "tumor": config['KEEP_CLASSES']=["normal", "liver", "tumor"]
|
428 |
-
else: config['KEEP_CLASSES'] = ["normal", "liver", "tumor", "portal vein", "abdominal aorta"]
|
429 |
-
|
430 |
-
config['DEVICE'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
431 |
-
config['EXPORT_FILE_NAME'] = config['TYPE']+ "_" + config['MODEL_NAME'] + "_" + config['LOSS'] + "_batchsize" + str(config['BATCH_SIZE']) + "_DA" + str(config['DATA_AUGMENTATION']) + "_HU" + str(config['HU_RANGE'][0]) + "-" + str(config['HU_RANGE'][1]) + "_" + config['PREPROCESSING'] + "_" + str(config['ROI_SIZE'][0]) + "_" + str(config['ROI_SIZE'][1]) + "_" + str(config['ROI_SIZE'][2]) + "_dropout" + str(config['DROPOUT'])
|
432 |
-
if config['MASKNONLIVER']: config['EXPORT_FILE_NAME'] += "_wobackground"
|
433 |
-
if config['LOSS'] == "wfdice": config['EXPORT_FILE_NAME'] += "_weaklambda" + str(config['LAMBDA_WEAK'])
|
434 |
-
if config['PRETRAINED_WEIGHTS'] != "" and config['PRETRAINED_WEIGHTS'] != config['EXPORT_FILE_NAME']: config['EXPORT_FILE_NAME'] += "_pretraining"
|
435 |
-
if config['POSTPROCESSING_MORF']: config['EXPORT_FILE_NAME'] += "_wpostmorf"
|
436 |
-
if not config['EVAL_INCLUDE_BACKGROUND']: config['EXPORT_FILE_NAME'] += "_evalnobackground"
|
437 |
-
|
438 |
-
return config
|
439 |
-
|
440 |
-
|
441 |
-
def train_clinical(df_clinical):
|
442 |
-
|
443 |
-
clinical_model = LinearRegression()
|
444 |
-
|
445 |
-
# train model
|
446 |
-
print("Training model using", df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'].shape[1], "features")
|
447 |
-
print(df_clinical.head())
|
448 |
-
clinical_model.fit(df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'], df_clinical['tumor_ratio'])
|
449 |
-
|
450 |
-
# obtain predicted ratios
|
451 |
-
pred = clinical_model.predict(df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'])
|
452 |
-
|
453 |
-
# evaluate
|
454 |
-
corr = np.corrcoef(pred, df_clinical['tumor_ratio'])[0][1]
|
455 |
-
mae = np.mean(np.abs(pred - df_clinical['tumor_ratio']))
|
456 |
-
print(f"The clinical model was fitted. Corr = {corr: .6f} MAE = {mae: .6f}")
|
457 |
-
|
458 |
-
return pred
|
459 |
-
|
460 |
-
|
461 |
-
def model_pipeline(config=None, plot=True):
|
462 |
-
|
463 |
-
torch.cuda.empty_cache()
|
464 |
-
config = get_defaults(config)
|
465 |
-
print(f"You Are Running on a: {config['DEVICE']}")
|
466 |
-
print("file name:", config['EXPORT_FILE_NAME'])
|
467 |
-
|
468 |
-
writer = SummaryWriter(log_dir="runs/" + config['EXPORT_FILE_NAME'])
|
469 |
-
|
470 |
-
# prepare data
|
471 |
-
train_loader, valid_loader, test_loader, postprocessing_transforms, df_clinical_train = build_dataset(config, get_clinical=config['LOSS']=="wfdice")
|
472 |
-
|
473 |
-
# train clinical model
|
474 |
-
if config['LOSS'] == "wfdice": weak_labels = train_clinical(df_clinical_train)
|
475 |
-
else: weak_labels = None
|
476 |
-
|
477 |
-
# train segmentation model
|
478 |
-
model = build_model(config)
|
479 |
-
loss_function, optimizer, lr_scheduler, eval_metrics = build_optimizer(model, config)
|
480 |
-
if config['TRAIN']:
|
481 |
-
train(model, train_loader, valid_loader, loss_function, eval_metrics, optimizer, config, lr_scheduler, writer, postprocessing_transforms, weak_labels)
|
482 |
-
model.load_state_dict(torch.load("checkpoints/" + config['EXPORT_FILE_NAME'] + ".pth", map_location=torch.device(config['DEVICE'])))
|
483 |
-
if config['ONESAMPLETESTRUN']:
|
484 |
-
return None, None, None
|
485 |
-
|
486 |
-
# test segmentation model
|
487 |
-
test_metrics, (test_images, test_labels, test_outputs) = evaluate(model, test_loader, eval_metrics, config, postprocessing_transforms)
|
488 |
-
print("Test metrics")
|
489 |
-
for key, value in test_metrics.items():
|
490 |
-
print(f" {key}: {value:.4f}")
|
491 |
-
|
492 |
-
# visualize
|
493 |
-
if plot:
|
494 |
-
if "3D" in config['MODEL_NAME']:
|
495 |
-
visualize_patient(test_images[0].cpu(), mask=test_labels[0].cpu(), n_slices=9, title="ground truth", z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
|
496 |
-
visualize_patient(test_images[0].cpu(), mask=test_outputs[0].cpu(), n_slices=9, title="predicted", z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
|
497 |
-
else:
|
498 |
-
visualize_patient(test_images.cpu(), mask=test_labels.cpu(), n_slices=9, title="ground truth", z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
|
499 |
-
visualize_patient(test_images.cpu(), mask=torch.stack(test_outputs).cpu(), n_slices=9, title="predicted", z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
|
500 |
-
|
501 |
-
return (test_images, test_labels, test_outputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|