lingchmao committed on
Commit fc3517c · verified · 1 parent: 12d5907

Delete utils/pipeline.py

Files changed (1): utils/pipeline.py (+0 -501)
utils/pipeline.py DELETED
@@ -1,501 +0,0 @@
# Standard library
import logging
import sys
import tempfile
from glob import glob

# Third-party
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import monai
import morphsnakes as ms
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast, GradScaler
from torchsummary import summary
from sklearn.linear_model import LinearRegression
from nrrd import write, read
from monai.metrics import DiceMetric, ConfusionMatrixMetric, MeanIoU
from monai.visualize import plot_2d_or_3d_image
from monai.data import decollate_batch

# Local modules
from visualization import visualize_patient
from sliding_window import sw_inference
from data_preparation import build_dataset
from models import UNet2D, UNet3D
from loss import WeaklyDiceFocalLoss

def build_optimizer(model, config):

    # Binary tasks (<= 2 classes) use one-hot targets with a sigmoid activation;
    # multi-class tasks use softmax on already one-hot labels.
    binary = len(config['KEEP_CLASSES']) <= 2

    if config['LOSS'] == "gdice":
        loss_function = monai.losses.GeneralizedDiceLoss(
            include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean",
            to_onehot_y=binary, sigmoid=binary, softmax=not binary)
    elif config['LOSS'] == 'cdice':
        loss_function = monai.losses.DiceCELoss(
            include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean",
            to_onehot_y=binary, sigmoid=binary, softmax=not binary)
    elif config['LOSS'] == 'mdice':
        loss_function = monai.losses.MaskedDiceLoss()
    elif config['LOSS'] == 'wdice':
        # Example with 3 classes (including the background, label 0):
        # the distance between the background class and the other classes is
        # the maximum (1.0); the distance between class 1 and class 2 is 0.5.
        dist_mat = np.array([[0.0, 1.0, 1.0],
                             [1.0, 0.0, 0.5],
                             [1.0, 0.5, 0.0]], dtype=np.float32)
        loss_function = monai.losses.GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)
    elif config['LOSS'] == "fdice":
        loss_function = monai.losses.DiceFocalLoss(
            include_background=config['EVAL_INCLUDE_BACKGROUND'],
            to_onehot_y=binary, sigmoid=binary, softmax=not binary)
    elif config['LOSS'] == "wfdice":
        loss_function = WeaklyDiceFocalLoss(
            include_background=config['EVAL_INCLUDE_BACKGROUND'],
            to_onehot_y=binary, sigmoid=binary, softmax=not binary,
            lambda_weak=config['LAMBDA_WEAK'])
    else:
        loss_function = monai.losses.DiceLoss(
            include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean",
            to_onehot_y=binary, sigmoid=binary, softmax=not binary, squared_pred=True)

    eval_metrics = [
        ("sensitivity", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], metric_name='sensitivity', reduction="mean_batch")),
        ("specificity", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], metric_name='specificity', reduction="mean_batch")),
        ("accuracy", ConfusionMatrixMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], metric_name='accuracy', reduction="mean_batch")),
        ("dice", DiceMetric(include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean_batch")),
        ("IoU", MeanIoU(include_background=config['EVAL_INCLUDE_BACKGROUND'], reduction="mean_batch")),
    ]

    optimizer = torch.optim.Adam(model.parameters(), config['LEARNING_RATE'], weight_decay=1e-5, amsgrad=True)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['MAX_EPOCHS'])
    return loss_function, optimizer, lr_scheduler, eval_metrics

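# ---------------------------------------------------------------------------
# Editor's usage sketch (not part of the original file): how build_optimizer()
# is typically wired up after build_model(). Every config value below is an
# illustrative assumption; get_defaults() fills in the remaining keys.
def _example_build_optimizer():
    config = get_defaults({
        "TYPE": "tumor", "MODEL_NAME": "UNet1_3D", "LOSS": "dice",
        "BATCH_SIZE": 2, "HU_RANGE": (-100, 200), "ROI_SIZE": (96, 96, 32),
        "LEARNING_RATE": 1e-4, "MAX_EPOCHS": 100,
    })
    model = build_model(config)
    # returns (loss_function, optimizer, lr_scheduler, eval_metrics)
    return build_optimizer(model, config)
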
def load_weights(model, config):
    device = torch.device(config['DEVICE'])
    try:
        # first look for the checkpoint under checkpoints/<name>.pth ...
        model.load_state_dict(torch.load("checkpoints/" + config['PRETRAINED_WEIGHTS'] + ".pth", map_location=device))
        print("Model weights from", config['PRETRAINED_WEIGHTS'], "have been loaded")
    except Exception:
        try:
            # ... then fall back to treating PRETRAINED_WEIGHTS as a full path
            model.load_state_dict(torch.load(config['PRETRAINED_WEIGHTS'], map_location=device))
            print("Model weights from", config['PRETRAINED_WEIGHTS'], "have been loaded")
        except Exception as e:
            print("WARNING: weights were not loaded.", e)
    return model

def build_model(config):

    config = get_defaults(config)
    dropout_prob = config['DROPOUT']
    spatial_dims = 3 if "3D" in config["MODEL_NAME"] else 2
    out_channels = len(config['KEEP_CLASSES'])
    roi_size = config['ROI_SIZE'] if spatial_dims == 3 else (config['ROI_SIZE'][0], config['ROI_SIZE'][1])

    if "SegResNetVAE" in config["MODEL_NAME"]:
        model = monai.networks.nets.SegResNetVAE(
            input_image_size=roi_size,
            vae_estimate_std=False,
            vae_default_std=0.3,
            vae_nz=256,
            spatial_dims=spatial_dims,
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=1,
            norm='instance',
            out_channels=out_channels,
            dropout_prob=dropout_prob,
        ).to(config['DEVICE'])

    elif "SegResNet" in config["MODEL_NAME"]:
        model = monai.networks.nets.SegResNet(
            spatial_dims=spatial_dims,
            blocks_down=[1, 2, 2, 4],
            blocks_up=[1, 1, 1],
            init_filters=16,
            in_channels=1,
            out_channels=out_channels,
            dropout_prob=dropout_prob,
            norm="instance",
        ).to(config['DEVICE'])

    elif "SwinUNETR" in config["MODEL_NAME"]:
        model = monai.networks.nets.SwinUNETR(
            img_size=config['ROI_SIZE'],
            in_channels=1,
            out_channels=out_channels,
            feature_size=48,
            drop_rate=dropout_prob,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=True,
        ).to(config['DEVICE'])

    elif "UNETR" in config["MODEL_NAME"]:
        model = monai.networks.nets.UNETR(
            img_size=roi_size,
            in_channels=1,
            out_channels=out_channels,
            feature_size=16,
            hidden_size=256,
            mlp_dim=3072,
            num_heads=8,
            pos_embed="perceptron",
            norm_name="instance",
            res_block=True,
            dropout_rate=dropout_prob,
        ).to(config['DEVICE'])

    elif "MANet" in config["MODEL_NAME"]:
        if "2D" in config["MODEL_NAME"]:
            model = UNet2D(1, out_channels, pab_channels=64, use_batchnorm=True).to(config['DEVICE'])
        else:
            model = UNet3D(1, out_channels, pab_channels=32, use_batchnorm=True).to(config['DEVICE'])

    elif "UNetPlusPlus" in config["MODEL_NAME"]:
        model = monai.networks.nets.BasicUNetPlusPlus(
            spatial_dims=spatial_dims,
            in_channels=1,
            out_channels=out_channels,
            features=(32, 32, 64, 128, 256, 32),
            norm="instance",
            dropout=dropout_prob,
        ).to(config['DEVICE'])

    elif "UNet1" in config['MODEL_NAME']:
        model = monai.networks.nets.UNet(
            spatial_dims=spatial_dims,
            in_channels=1,
            out_channels=out_channels,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm="instance",
        ).to(config['DEVICE'])

    elif "UNet2" in config['MODEL_NAME']:
        model = monai.networks.nets.UNet(
            spatial_dims=spatial_dims,
            in_channels=1,
            out_channels=out_channels,
            channels=(32, 64, 128, 256),
            strides=(2, 2, 2),  # one stride per downsampling step: len(channels) - 1
            num_res_units=4,
            norm="instance",
        ).to(config['DEVICE'])

    else:
        print(config["MODEL_NAME"], "is not a valid model name")
        return None

    try:
        if "3D" in config['MODEL_NAME']:
            print(summary(model, input_size=(1, config['ROI_SIZE'][0], config['ROI_SIZE'][1], config['ROI_SIZE'][2])))
        else:
            print(summary(model, input_size=(1, config['ROI_SIZE'][0], config['ROI_SIZE'][1])))
    except Exception as e:
        print("could not load model summary:", e)

    if config['PRETRAINED_WEIGHTS']:
        model = load_weights(model, config)
    return model

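# Editor's note (not in the original file): build_model() dispatches on
# substrings of MODEL_NAME, so one name selects both the architecture and the
# dimensionality -- e.g. the hypothetical names "SegResNet2D" and "SegResNet3D"
# both hit the SegResNet branch, and "3D" anywhere in the name switches
# spatial_dims to 3. Branch order matters: "SegResNetVAE" must be tested before
# "SegResNet", and "SwinUNETR" before "UNETR", as done above.
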
def train(model, train_loader, val_loader, loss_function, eval_metrics, optimizer, config,
          scheduler=None, writer=None, postprocessing_transforms=None, weak_labels=None):

    if writer is None:
        writer = SummaryWriter(log_dir="runs/" + config['EXPORT_FILE_NAME'])
    best_metric, best_metric_epoch = -1, -1
    prev_metric, patience, patience_counter = 1, config['EARLY_STOPPING_PATIENCE'], 0
    if config['AUTOCAST']:
        scaler = GradScaler()  # gradient scaler for mixed-precision training

    for epoch in range(config['MAX_EPOCHS']):
        print("-" * 10)
        model.train()
        epoch_loss, step = 0, 0
        with tqdm(train_loader) as progress_bar:
            for batch_data in progress_bar:
                step += 1
                inputs, labels = batch_data["image"].to(config['DEVICE']), batch_data["mask"].to(config['DEVICE'])

                # only train on batches that contain tumor; skip the rest
                if config['TYPE'] == "tumor" and torch.sum(labels[:, -1]) == 0:
                    continue

                # check input shapes
                if inputs is None or labels is None:
                    continue
                if inputs.shape[-1] != labels.shape[-1] or inputs.shape[0] != labels.shape[0]:
                    print("WARNING: Batch skipped. Image and mask shapes do not match:", inputs.shape[0], labels.shape[0])
                    continue

                optimizer.zero_grad()
                if not config['AUTOCAST']:

                    # segmentation output
                    outputs = model(inputs)
                    if "SegResNetVAE" in config["MODEL_NAME"]: outputs = outputs[0]
                    if isinstance(outputs, list): outputs = outputs[0]

                    # loss; the weakly supervised loss also takes the patient's
                    # predicted tumor ratio (step is 1-based, hence step - 1)
                    if config['LOSS'] == 'wfdice':
                        weak_label = torch.tensor([weak_labels[step - 1]]).to(config['DEVICE'])
                        loss = loss_function(outputs, labels, weak_label)
                    else:
                        loss = loss_function(outputs, labels)
                    loss.backward()
                    optimizer.step()

                else:
                    with autocast():
                        outputs = model(inputs)
                        if "SegResNetVAE" in config["MODEL_NAME"]: outputs = outputs[0]
                        if isinstance(outputs, list): outputs = outputs[0]
                        # pass the weak label as a tensor, as in the non-autocast branch
                        if config['LOSS'] == 'wfdice':
                            weak_label = torch.tensor([weak_labels[step - 1]]).to(config['DEVICE'])
                            loss = loss_function(outputs, labels, weak_label)
                        else:
                            loss = loss_function(outputs, labels)

                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    if torch.isinf(loss).any():
                        print("Detected inf loss; skipping optimizer step.")
                    else:
                        scaler.step(optimizer)
                        scaler.update()

                epoch_loss += loss.item()
                progress_bar.set_description(f'Epoch [{epoch+1}/{config["MAX_EPOCHS"]}], Loss: {epoch_loss/step:.4f}')

        epoch_loss /= step
        writer.add_scalar("train_loss_epoch", epoch_loss, epoch)

        # validation
        if (epoch + 1) % config['VAL_INTERVAL'] == 0:

            # compute the validation metrics and pick one (or the mean of
            # several) as the model-selection criterion
            val_metrics, (val_images, val_labels, val_outputs) = evaluate(model, val_loader, eval_metrics, config, postprocessing_transforms)
            if isinstance(config['EVAL_METRIC'], list):
                cur_metric = np.mean([val_metrics[m] for m in config['EVAL_METRIC']])
            else:
                cur_metric = val_metrics[config['EVAL_METRIC']]

            # checkpoint whenever the validation metric improves
            if cur_metric > best_metric:
                best_metric, best_metric_epoch = cur_metric, epoch + 1
                torch.save(model.state_dict(), "checkpoints/" + config['EXPORT_FILE_NAME'] + ".pth")

            # early stopping
            patience_counter = patience_counter + 1 if prev_metric > cur_metric else 0
            if patience_counter == patience or epoch - best_metric_epoch > patience:
                print("Early stopping at epoch", epoch + 1)
                break
            print(f'Current epoch: {epoch + 1} current avg {config["EVAL_METRIC"]}: {cur_metric:.4f} best avg {config["EVAL_METRIC"]}: {best_metric:.4f} at epoch {best_metric_epoch}')
            prev_metric = cur_metric

            # log metrics and example slices to TensorBoard
            for key, value in val_metrics.items():
                writer.add_scalar("val_" + key, value, epoch)
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=len(val_outputs)//2, tag="image", frame_dim=-1)
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=len(val_outputs)//2, tag="label", frame_dim=-1)
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=len(val_outputs)//2, tag="output", frame_dim=-1)

        # update the learning-rate scheduler
        try:
            if scheduler is not None: scheduler.step()
        except Exception:
            pass

    print(f"Train completed, best {config['EVAL_METRIC']}: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()
    return model, writer

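# Editor's usage sketch (not part of the original file): calling train()
# directly rather than through model_pipeline(); the loaders come from
# build_dataset().
#   loss_fn, opt, sched, metrics = build_optimizer(model, config)
#   model, writer = train(model, train_loader, val_loader, loss_fn, metrics,
#                         opt, config, scheduler=sched)
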
def evaluate(model, val_loader, eval_metrics, config, postprocessing_transforms=None, use_liver_seg=False, export_filenames=[], export_file_metadata=[]):

    val_metrics = {}
    model.eval()
    with torch.no_grad():

        step = 0
        for val_data in val_loader:
            # 3D: val_images has shape (1, C, H, W, Z)
            # 2D: val_images has shape (B, C, H, W)
            val_images, val_labels = val_data["image"].to(config['DEVICE']), val_data["mask"].to(config['DEVICE'])
            if use_liver_seg:
                val_liver = val_data["pred_liver"].to(config['DEVICE'])

            if (val_images[0].shape[-1] != val_labels[0].shape[-1]) or (
                    "3D" not in config["MODEL_NAME"] and val_images.shape[0] != val_labels.shape[0]):
                print("WARNING: Batch skipped. Image and mask shapes do not match:", val_images.shape, val_labels.shape)
                continue

            # run inference (sliding-window for 3D models)
            if "3D" in config["MODEL_NAME"]:
                val_outputs = sw_inference(model, val_images, config['ROI_SIZE'], config['AUTOCAST'], discard_second_output='SegResNetVAE' in config['MODEL_NAME'])
            else:
                if "SegResNetVAE" in config["MODEL_NAME"]:
                    val_outputs, _ = model(val_images)
                else:
                    val_outputs = model(val_images)

            # post-processing
            if postprocessing_transforms is not None:
                val_outputs = [postprocessing_transforms(i) for i in decollate_batch(val_outputs)]

            # remove tumor predictions outside the liver
            for i in range(len(val_outputs)):
                val_outputs[i][-1][torch.where(val_images[i][0] <= 1e-6)] = 0

            # refine the tumor channel with the morphological snakes algorithm
            if config['POSTPROCESSING_MORF']:
                for i in range(len(val_outputs)):
                    val_outputs[i][-1] = torch.from_numpy(ms.morphological_chan_vese(val_images[i][0].cpu(), iterations=2, init_level_set=val_outputs[i][-1].cpu())).to(config['DEVICE'])

            for i in range(len(val_outputs)):
                if use_liver_seg:
                    # use the liver model's outputs for the liver channel
                    val_outputs[i][1] = val_liver[i]
                    # where the region is tumor, set the liver prediction to 0
                    val_outputs[i][1] -= val_outputs[i][2]

            # compute the metrics for the current iteration
            for metric_name, metric in eval_metrics:
                if isinstance(val_outputs[0], list):
                    val_outputs = val_outputs[0]
                metric(val_outputs, val_labels)

            # save predictions to a local folder
            if len(export_filenames) > 0:
                for i in range(len(val_outputs)):
                    numpy_array = val_outputs[i].cpu().detach().numpy()
                    write(export_filenames[step], numpy_array[-1], header=export_file_metadata[step])
                    print("  Segmentation exported to", export_filenames[step])
                    step += 1

        # aggregate the final mean metrics
        for metric_name, metric in eval_metrics:
            if "dice" in metric_name or "IoU" in metric_name:
                metric_value = metric.aggregate().tolist()
            else:
                metric_value = metric.aggregate()[0].tolist()  # one value per class
            val_metrics[metric_name + "_avg"] = np.mean(metric_value)
            if config['TYPE'] != "liver":
                for c in range(1, len(metric_value) + 1):  # class-wise values
                    val_metrics[metric_name + "_class" + str(c)] = metric_value[c - 1]
            metric.reset()

    return val_metrics, (val_images, val_labels, val_outputs)

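# Editor's usage sketch (not part of the original file): evaluate() can also
# export the predicted masks as NRRD files; export_filenames and
# export_file_metadata are parallel lists with one entry per exported volume.
# The path and header below are hypothetical.
#   metrics, _ = evaluate(model, test_loader, eval_metrics, config,
#                         postprocessing_transforms,
#                         export_filenames=["exports/patient01.nrrd"],
#                         export_file_metadata=[patient01_header])
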
def get_defaults(config):

    if 'TRAIN' not in config: config['TRAIN'] = True
    if 'VALID_PATIENT_RATIO' not in config: config['VALID_PATIENT_RATIO'] = 0.2
    if 'VAL_INTERVAL' not in config: config['VAL_INTERVAL'] = 1
    if 'EARLY_STOPPING_PATIENCE' not in config: config['EARLY_STOPPING_PATIENCE'] = 20
    if 'AUTOCAST' not in config: config['AUTOCAST'] = False
    if 'NUM_WORKERS' not in config: config['NUM_WORKERS'] = 0
    if 'DROPOUT' not in config: config['DROPOUT'] = 0.1
    if 'ONESAMPLETESTRUN' not in config: config['ONESAMPLETESTRUN'] = False
    if 'DATA_AUGMENTATION' not in config: config['DATA_AUGMENTATION'] = False
    if 'POSTPROCESSING_MORF' not in config: config['POSTPROCESSING_MORF'] = False
    if 'PREPROCESSING' not in config: config['PREPROCESSING'] = ""
    if 'PRETRAINED_WEIGHTS' not in config: config['PRETRAINED_WEIGHTS'] = ""

    if 'EVAL_INCLUDE_BACKGROUND' not in config:
        config['EVAL_INCLUDE_BACKGROUND'] = config['TYPE'] == "liver"
    if 'EVAL_METRIC' not in config:
        config['EVAL_METRIC'] = ["dice_avg"] if config['TYPE'] == "liver" else ["dice_class2"]

    if 'CLINICAL_DATA_FILE' not in config: config['CLINICAL_DATA_FILE'] = "Dataset/HCC-TACE-Seg_clinical_data-V2.xlsx"
    if 'CLINICAL_PREDICTORS' not in config: config['CLINICAL_PREDICTORS'] = ['T_involvment', 'CLIP_Score', 'Personal history of cancer', 'TNM', 'Metastasis', 'fhx_can', 'Alcohol', 'Smoking', 'Evidence_of_cirh', 'AFP', 'age', 'Diabetes', 'Lymphnodes', 'Interval_BL', 'TTP']
    if 'LAMBDA_WEAK' not in config: config['LAMBDA_WEAK'] = 0.5
    if 'MASKNONLIVER' not in config: config['MASKNONLIVER'] = False

    if config['TYPE'] == "liver": config['KEEP_CLASSES'] = ["normal", "liver"]
    elif config['TYPE'] == "tumor": config['KEEP_CLASSES'] = ["normal", "liver", "tumor"]
    else: config['KEEP_CLASSES'] = ["normal", "liver", "tumor", "portal vein", "abdominal aorta"]

    config['DEVICE'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # encode the main hyperparameters in the export/checkpoint file name
    config['EXPORT_FILE_NAME'] = "_".join([
        config['TYPE'], config['MODEL_NAME'], config['LOSS'],
        "batchsize" + str(config['BATCH_SIZE']),
        "DA" + str(config['DATA_AUGMENTATION']),
        "HU" + str(config['HU_RANGE'][0]) + "-" + str(config['HU_RANGE'][1]),
        config['PREPROCESSING'],
        str(config['ROI_SIZE'][0]), str(config['ROI_SIZE'][1]), str(config['ROI_SIZE'][2]),
        "dropout" + str(config['DROPOUT']),
    ])
    if config['MASKNONLIVER']: config['EXPORT_FILE_NAME'] += "_wobackground"
    if config['LOSS'] == "wfdice": config['EXPORT_FILE_NAME'] += "_weaklambda" + str(config['LAMBDA_WEAK'])
    if config['PRETRAINED_WEIGHTS'] != "" and config['PRETRAINED_WEIGHTS'] != config['EXPORT_FILE_NAME']: config['EXPORT_FILE_NAME'] += "_pretraining"
    if config['POSTPROCESSING_MORF']: config['EXPORT_FILE_NAME'] += "_wpostmorf"
    if not config['EVAL_INCLUDE_BACKGROUND']: config['EXPORT_FILE_NAME'] += "_evalnobackground"

    return config

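# Editor's note (not in the original file): get_defaults() supplies defaults for
# everything except a handful of required keys -- TYPE, MODEL_NAME, LOSS,
# BATCH_SIZE, HU_RANGE, and ROI_SIZE (plus LEARNING_RATE and MAX_EPOCHS, read
# later by build_optimizer) -- and derives KEEP_CLASSES, DEVICE, and
# EXPORT_FILE_NAME from them. See the sketch after build_optimizer() above for
# a minimal config.
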
def train_clinical(df_clinical):

    clinical_model = LinearRegression()

    # fit a linear model that predicts the tumor ratio from the clinical predictors
    print("Training model using", df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'].shape[1], "features")
    print(df_clinical.head())
    clinical_model.fit(df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'], df_clinical['tumor_ratio'])

    # obtain the predicted ratios
    pred = clinical_model.predict(df_clinical.loc[:, df_clinical.columns != 'tumor_ratio'])

    # evaluate on the training data
    corr = np.corrcoef(pred, df_clinical['tumor_ratio'])[0][1]
    mae = np.mean(np.abs(pred - df_clinical['tumor_ratio']))
    print(f"The clinical model was fitted. Corr = {corr:.6f} MAE = {mae:.6f}")

    return pred

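# Editor's note (not in the original file): train_clinical() expects one row per
# patient, with the CLINICAL_PREDICTORS columns as features and a 'tumor_ratio'
# target column; its in-sample predictions serve as the weak labels consumed by
# the wfdice loss in train().
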
def model_pipeline(config=None, plot=True):

    torch.cuda.empty_cache()
    config = get_defaults(config)
    print(f"Running on: {config['DEVICE']}")
    print("file name:", config['EXPORT_FILE_NAME'])

    writer = SummaryWriter(log_dir="runs/" + config['EXPORT_FILE_NAME'])

    # prepare data
    train_loader, valid_loader, test_loader, postprocessing_transforms, df_clinical_train = build_dataset(config, get_clinical=config['LOSS'] == "wfdice")

    # train the clinical model (only needed for the weakly supervised loss)
    weak_labels = train_clinical(df_clinical_train) if config['LOSS'] == "wfdice" else None

    # train the segmentation model
    model = build_model(config)
    loss_function, optimizer, lr_scheduler, eval_metrics = build_optimizer(model, config)
    if config['TRAIN']:
        train(model, train_loader, valid_loader, loss_function, eval_metrics, optimizer, config, lr_scheduler, writer, postprocessing_transforms, weak_labels)
        model.load_state_dict(torch.load("checkpoints/" + config['EXPORT_FILE_NAME'] + ".pth", map_location=torch.device(config['DEVICE'])))
    if config['ONESAMPLETESTRUN']:
        return None, None, None

    # test the segmentation model
    test_metrics, (test_images, test_labels, test_outputs) = evaluate(model, test_loader, eval_metrics, config, postprocessing_transforms)
    print("Test metrics")
    for key, value in test_metrics.items():
        print(f"  {key}: {value:.4f}")

    # visualize ground truth vs. predictions
    if plot:
        if "3D" in config['MODEL_NAME']:
            visualize_patient(test_images[0].cpu(), mask=test_labels[0].cpu(), n_slices=9, title="ground truth", z_dim_last=True, mask_channel=-1)
            visualize_patient(test_images[0].cpu(), mask=test_outputs[0].cpu(), n_slices=9, title="predicted", z_dim_last=True, mask_channel=-1)
        else:
            visualize_patient(test_images.cpu(), mask=test_labels.cpu(), n_slices=9, title="ground truth", z_dim_last=False, mask_channel=-1)
            visualize_patient(test_images.cpu(), mask=torch.stack(test_outputs).cpu(), n_slices=9, title="predicted", z_dim_last=False, mask_channel=-1)

    return test_images, test_labels, test_outputs
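
# ---------------------------------------------------------------------------
# Editor's usage sketch (not part of the original file): an end-to-end run.
# All config values are illustrative assumptions.
if __name__ == "__main__":
    example_config = {
        "TYPE": "tumor", "MODEL_NAME": "UNet1_3D", "LOSS": "dice",
        "BATCH_SIZE": 2, "HU_RANGE": (-100, 200), "ROI_SIZE": (96, 96, 32),
        "LEARNING_RATE": 1e-4, "MAX_EPOCHS": 100,
    }
    model_pipeline(example_config, plot=False)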