DmitriiKhizbullin commited on
Commit
797d116
1 Parent(s): ff0a933

Split code into finer files

Browse files
Files changed (5) hide show
  1. README.md +10 -1
  2. data.py +231 -0
  3. metrics.py +54 -0
  4. train.py +1 -537
  5. trainer.py +272 -0
README.md CHANGED
@@ -2,12 +2,21 @@
2
 
3
  ## Setup
4
 
 
 
 
 
 
 
 
 
 
5
  ### Gradio app environment
6
 
7
  Install from pip requirements file:
8
 
9
  ```bash
10
- conda create -n retinopathy_app python=3.10
11
  conda activate retinopathy_app
12
  pip install -r requirements.txt
13
  python app.py
 
2
 
3
  ## Setup
4
 
5
+ ### Cloning the repo
6
+
7
+ Install git LFS via [this instruction](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage).
8
+ ```bash
9
+ git clone https://github.com/SDAIA-KAUST-AI/diabetic-retinopathy-detection.git
10
+ git lfs install # to make sure LFS is enabled
11
+ git lfs pull # to bring in demo images and pretrained models
12
+ ```
13
+
14
  ### Gradio app environment
15
 
16
  Install from pip requirements file:
17
 
18
  ```bash
19
+ conda create -y -n retinopathy_app python=3.10
20
  conda activate retinopathy_app
21
  pip install -r requirements.txt
22
  python app.py
data.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import (Dict, Optional, Tuple,
3
+ Union, Callable, Iterable)
4
+ import pandas as pd
5
+ from PIL import Image
6
+ from enum import Enum
7
+ import numpy as np
8
+ from numpy.random import RandomState
9
+ import collections.abc
10
+ from collections import Counter, defaultdict
11
+
12
+ import torch
13
+ import torch.utils.data as data
14
+ from torch.utils.data import DataLoader
15
+
16
+
17
+ from labelmap import DR_LABELMAP
18
+
19
+
20
+ DataRecord = Tuple[Image.Image, int]
21
+
22
+
23
+ class RetinopathyDataset(data.Dataset[DataRecord]):
24
+ """ A class to access the pre-downloaded Diabetic Retinopathy dataset. """
25
+
26
+ def __init__(self, data_path: str) -> None:
27
+ """ Constructor.
28
+
29
+ Args:
30
+ data_path (str): path to the dataset, ex: "retinopathy_data"
31
+ containing "trainLabels.csv" and "train/".
32
+ """
33
+ super().__init__()
34
+
35
+ self.data_path = data_path
36
+
37
+ self.ext = ".jpeg"
38
+
39
+ anno_path = os.path.join(data_path, "trainLabels.csv")
40
+ self.anno_df = pd.read_csv(anno_path) # ['image', 'level']
41
+ anno_name_set = set(self.anno_df['image'])
42
+
43
+ if True:
44
+ train_path = os.path.join(data_path, "train")
45
+ img_path_list = os.listdir(train_path)
46
+ img_name_set = set([os.path.splitext(p)[0] for p in img_path_list])
47
+ assert anno_name_set == img_name_set
48
+
49
+ self.label_map = DR_LABELMAP
50
+
51
+ def __getitem__(self, index: Union[int, slice]) -> DataRecord:
52
+ assert isinstance(index, int)
53
+ img_path = self.get_path_at(index)
54
+ img = Image.open(img_path)
55
+ label = self.get_label_at(index)
56
+ return img, label
57
+
58
+ def __len__(self) -> int:
59
+ return len(self.anno_df)
60
+
61
+ def get_label_at(self, index: int) -> int:
62
+ label = self.anno_df['level'].iloc[index].item()
63
+ return label
64
+
65
+ def get_path_at(self, index: int) -> str:
66
+ img_name = self.anno_df['image'].iloc[index]
67
+ img_path = os.path.join(self.data_path, "train", img_name+self.ext)
68
+ return img_path
69
+
70
+
71
+ """ Purpose of a split: training or validation. """
72
+ class Purpose(Enum):
73
+ Train = 0
74
+ Val = 1
75
+
76
+ """ Augmentation transformations for an image and a label. """
77
+ FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
78
+ Callable[..., torch.Tensor]]
79
+
80
+ """ Feature (image) and target (label) tensors. """
81
+ TensorRecord = Tuple[torch.Tensor, torch.Tensor]
82
+
83
+
84
+ class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
85
+ """ Split is a class that keep a view on a part of a dataset.
86
+ Split is used to hold the imormation about which samples go to training
87
+ and which to validation without a need to put these groups of files into
88
+ separate folders.
89
+ """
90
+ def __init__(self, dataset: RetinopathyDataset,
91
+ indices: np.ndarray,
92
+ purpose: Purpose,
93
+ transforms: FeatureAndTargetTransforms,
94
+ oversample_factor: int = 1,
95
+ stratify_classes: bool = False,
96
+ use_log_frequencies: bool = False,
97
+ ):
98
+ """ Constructor.
99
+
100
+ Args:
101
+ dataset (RetinopathyDataset): The dataset on which the Split "views".
102
+ indices (np.ndarray): Externally provided indices of samples that
103
+ are "viewed" on.
104
+ purpose (Purpose): Either train or val, to be able to replicate
105
+ the data for train split for effecient workers utilization.
106
+ transforms (FeatureAndTargetTransforms): Functors of feature and
107
+ target transforms.
108
+ oversample_factor (int, optional): Expand the training dataset by
109
+ replication to avoid dataloader stalls on epoch ends. Defaults to 1.
110
+ stratify_classes (bool, optional): Whether to apply stratified sampling.
111
+ Defaults to False.
112
+ use_log_frequencies (bool, optional): If stratify_classes=True,
113
+ whether to use logarithmic sampling strategy. If False, apply
114
+ regular even sampling. Defaults to False.
115
+ """
116
+ self.dataset = dataset
117
+ self.indices = indices
118
+ self.purpose = purpose
119
+ self.feature_transform = transforms[0]
120
+ self.target_transform = transforms[1]
121
+ self.oversample_factor = oversample_factor
122
+ self.stratify_classes = stratify_classes
123
+ self.use_log_frequencies = use_log_frequencies
124
+
125
+ self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
126
+ self.frequencies: Optional[Dict[int, float]] = None
127
+ if self.stratify_classes:
128
+ self._bucketize_indices()
129
+ if self.use_log_frequencies:
130
+ self._calc_frequencies()
131
+
132
+ def _calc_frequencies(self):
133
+ assert self.per_class_indices is not None
134
+ counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
135
+ counts = np.array(list(counts_dict.values()))
136
+ counts_nrm = self._normalize(counts)
137
+ temperature = 50.0 # > 1 to even-out frequencies
138
+ freqs = self._normalize(np.log1p(counts_nrm * temperature))
139
+ self.frequencies = {k: freq.item() for k, freq
140
+ in zip(self.per_class_indices.keys(), freqs)}
141
+ print(self.frequencies)
142
+
143
+ @staticmethod
144
+ def _normalize(arr: np.ndarray) -> np.ndarray:
145
+ return arr / np.sum(arr)
146
+
147
+ def _bucketize_indices(self):
148
+ buckets = defaultdict(list)
149
+ for index in self.indices:
150
+ label = self.dataset.get_label_at(index)
151
+ buckets[label].append(index)
152
+ self.per_class_indices = {k: np.array(v)
153
+ for k, v in buckets.items()}
154
+
155
+ def __getitem__(self, index: Union[int, slice]) -> TensorRecord: # type: ignore[override]
156
+ assert isinstance(index, int)
157
+ if self.purpose == Purpose.Train:
158
+ index_rem = index % len(self.indices)
159
+ idx = self.indices[index_rem].item()
160
+ else:
161
+ idx = self.indices[index].item()
162
+ if self.per_class_indices:
163
+ if self.frequencies is not None:
164
+ arange = np.arange(len(self.per_class_indices))
165
+ frequencies = np.zeros(len(self.per_class_indices), dtype=float)
166
+ for k, v in self.frequencies.items():
167
+ frequencies[k] = v
168
+ random_key = np.random.choice(
169
+ arange,
170
+ p=frequencies)
171
+ else:
172
+ random_key = np.random.randint(len(self.per_class_indices))
173
+
174
+ indices = self.per_class_indices[random_key]
175
+ actual_index = np.random.choice(indices).item()
176
+ else:
177
+ actual_index = idx
178
+ feature, target = self.dataset[actual_index]
179
+ feature_tensor = self.feature_transform(feature)
180
+ target_tensor = self.target_transform(target)
181
+ return feature_tensor, target_tensor
182
+
183
+ def __len__(self):
184
+ if self.purpose == Purpose.Train:
185
+ return len(self.indices) * self.oversample_factor
186
+ else:
187
+ return len(self.indices)
188
+
189
+ @staticmethod
190
+ def make_splits(all_data: RetinopathyDataset,
191
+ train_transforms: FeatureAndTargetTransforms,
192
+ val_transforms: FeatureAndTargetTransforms,
193
+ train_fraction: float,
194
+ stratify_train: bool,
195
+ stratify_val: bool,
196
+ seed: int = 54,
197
+ ) -> Tuple['Split', 'Split']:
198
+
199
+ """ Prepare train and val splits deterministically.
200
+
201
+ Returns:
202
+ Tuple[Split, Split]:
203
+ - Train split
204
+ - Val split
205
+ """
206
+
207
+ prng = RandomState(seed)
208
+
209
+ num_train = int(len(all_data) * train_fraction)
210
+ all_indices = prng.permutation(len(all_data))
211
+ train_indices = all_indices[:num_train]
212
+ val_indices = all_indices[num_train:]
213
+ train_data = Split(all_data, train_indices, Purpose.Train,
214
+ train_transforms, stratify_classes=stratify_train)
215
+ val_data = Split(all_data, val_indices, Purpose.Val,
216
+ val_transforms, stratify_classes=stratify_val)
217
+ return train_data, val_data
218
+
219
+
220
+ def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader],
221
+ split_name: str) -> None:
222
+ labels = []
223
+ for _, label in dataset:
224
+ if isinstance(label, torch.Tensor):
225
+ label = label.cpu().numpy()
226
+ labels.append(label)
227
+ labels = np.concatenate(labels)
228
+ cnt = Counter(labels)
229
+ print(cnt)
230
+
231
+
metrics.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Callable
2
+
3
+ import torch
4
+
5
+ from torchmetrics.aggregation import MeanMetric
6
+ from torchmetrics.classification.accuracy import MulticlassAccuracy
7
+ from torchmetrics.classification import MulticlassCohenKappa
8
+
9
+
10
+ class Metrics:
11
+ def __init__(self,
12
+ num_classes: int,
13
+ labelmap: Dict[int, str],
14
+ split: str,
15
+ log_fn: Callable[..., None]) -> None:
16
+ self.labelmap = labelmap
17
+ self.loss = MeanMetric(nan_strategy='ignore')
18
+ self.accuracy = MulticlassAccuracy(num_classes=num_classes)
19
+ self.per_class_accuracies = MulticlassAccuracy(
20
+ num_classes=num_classes, average=None)
21
+ self.kappa = MulticlassCohenKappa(num_classes)
22
+ self.split = split
23
+ self.log_fn = log_fn
24
+
25
+ def update(self,
26
+ loss: torch.Tensor,
27
+ preds: torch.Tensor,
28
+ labels: torch.Tensor) -> None:
29
+ self.loss.update(loss)
30
+ self.accuracy.update(preds, labels)
31
+ self.per_class_accuracies.update(preds, labels)
32
+ self.kappa.update(preds, labels)
33
+
34
+ def log(self) -> None:
35
+ loss = self.loss.compute()
36
+ accuracy = self.accuracy.compute()
37
+ accuracies = self.per_class_accuracies.compute()
38
+ kappa = self.kappa.compute()
39
+ mean_accuracy = torch.nanmean(accuracies)
40
+ self.log_fn(f"{self.split}/loss", loss, sync_dist=True)
41
+ self.log_fn(f"{self.split}/accuracy", accuracy, sync_dist=True)
42
+ self.log_fn(f"{self.split}/mean_accuracy", mean_accuracy, sync_dist=True)
43
+ for i_class, acc in enumerate(accuracies):
44
+ name = self.labelmap[i_class]
45
+ self.log_fn(f"{self.split}/acc/{i_class} {name}", acc, sync_dist=True)
46
+ self.log_fn(f"{self.split}/kappa", kappa, sync_dist=True)
47
+
48
+ def to(self, device) -> 'Metrics':
49
+ self.loss.to(device) # BUG HERE? should I assign it back?
50
+ self.accuracy.to(device)
51
+ self.per_class_accuracies.to(device)
52
+ self.kappa.to(device)
53
+ return self
54
+
train.py CHANGED
@@ -1,549 +1,13 @@
1
- import os
2
- from typing import (Any, List, Dict, Optional, Tuple,
3
- Union, Callable, Iterable, Iterator)
4
- import pandas as pd
5
- from PIL import Image
6
  import datetime
7
  from argparse import ArgumentParser
8
- from enum import Enum
9
- import numpy as np
10
- from numpy.random import RandomState
11
- import collections.abc
12
- from collections import Counter, defaultdict
13
- import math
14
 
15
  import torch
16
- import torch.nn as nn
17
- import torch.utils.data as data
18
- from torch.utils.data import DataLoader
19
 
20
- from torchvision.transforms import (
21
- CenterCrop,
22
- Compose,
23
- Normalize,
24
- RandomHorizontalFlip,
25
- RandomResizedCrop,
26
- RandomRotation,
27
- RandomAffine,
28
- Resize,
29
- ToTensor)
30
-
31
- from transformers import ViTImageProcessor
32
- from transformers import ViTForImageClassification
33
- from transformers import AdamW
34
-
35
- from transformers import AutoImageProcessor, ResNetForImageClassification
36
-
37
- import lightning as L
38
  from lightning import Trainer
39
  from lightning.pytorch.loggers import TensorBoardLogger
40
  from lightning.pytorch.callbacks import ModelSummary
41
- from torchmetrics.aggregation import MeanMetric
42
- from torchmetrics.classification.accuracy import MulticlassAccuracy
43
- from torchmetrics.classification import MulticlassCohenKappa
44
-
45
- from labelmap import DR_LABELMAP
46
-
47
-
48
- DataRecord = Tuple[Image.Image, int]
49
-
50
-
51
- class RetinopathyDataset(data.Dataset[DataRecord]):
52
- """ A class to access the pre-downloaded Diabetic Retinopathy dataset. """
53
-
54
- def __init__(self, data_path: str) -> None:
55
- """ Constructor.
56
-
57
- Args:
58
- data_path (str): path to the dataset, ex: "retinopathy_data"
59
- containing "trainLabels.csv" and "train/".
60
- """
61
- super().__init__()
62
-
63
- self.data_path = data_path
64
-
65
- self.ext = ".jpeg"
66
-
67
- anno_path = os.path.join(data_path, "trainLabels.csv")
68
- self.anno_df = pd.read_csv(anno_path) # ['image', 'level']
69
- anno_name_set = set(self.anno_df['image'])
70
-
71
- if True:
72
- train_path = os.path.join(data_path, "train")
73
- img_path_list = os.listdir(train_path)
74
- img_name_set = set([os.path.splitext(p)[0] for p in img_path_list])
75
- assert anno_name_set == img_name_set
76
-
77
- self.label_map = DR_LABELMAP
78
-
79
- def __getitem__(self, index: Union[int, slice]) -> DataRecord:
80
- assert isinstance(index, int)
81
- img_path = self.get_path_at(index)
82
- img = Image.open(img_path)
83
- label = self.get_label_at(index)
84
- return img, label
85
-
86
- def __len__(self) -> int:
87
- return len(self.anno_df)
88
-
89
- def get_label_at(self, index: int) -> int:
90
- label = self.anno_df['level'].iloc[index].item()
91
- return label
92
-
93
- def get_path_at(self, index: int) -> str:
94
- img_name = self.anno_df['image'].iloc[index]
95
- img_path = os.path.join(self.data_path, "train", img_name+self.ext)
96
- return img_path
97
-
98
-
99
- """ Purpose of a split: training or validation. """
100
- class Purpose(Enum):
101
- Train = 0
102
- Val = 1
103
-
104
- """ Augmentation transformations for an image and a label. """
105
- FeatureAndTargetTransforms = Tuple[Callable[..., torch.Tensor],
106
- Callable[..., torch.Tensor]]
107
-
108
- """ Feature (image) and target (label) tensors. """
109
- TensorRecord = Tuple[torch.Tensor, torch.Tensor]
110
-
111
-
112
- class Split(data.Dataset[TensorRecord], collections.abc.Sequence[TensorRecord]):
113
- """ Split is a class that keep a view on a part of a dataset.
114
- Split is used to hold the imormation about which samples go to training
115
- and which to validation without a need to put these groups of files into
116
- separate folders.
117
- """
118
- def __init__(self, dataset: RetinopathyDataset,
119
- indices: np.ndarray,
120
- purpose: Purpose,
121
- transforms: FeatureAndTargetTransforms,
122
- oversample_factor: int = 1,
123
- stratify_classes: bool = False,
124
- use_log_frequencies: bool = False,
125
- ):
126
- """ Constructor.
127
-
128
- Args:
129
- dataset (RetinopathyDataset): The dataset on which the Split "views".
130
- indices (np.ndarray): Externally provided indices of samples that
131
- are "viewed" on.
132
- purpose (Purpose): Either train or val, to be able to replicate
133
- the data for train split for effecient workers utilization.
134
- transforms (FeatureAndTargetTransforms): Functors of feature and
135
- target transforms.
136
- oversample_factor (int, optional): Expand the training dataset by
137
- replication to avoid dataloader stalls on epoch ends. Defaults to 1.
138
- stratify_classes (bool, optional): Whether to apply stratified sampling.
139
- Defaults to False.
140
- use_log_frequencies (bool, optional): If stratify_classes=True,
141
- whether to use logarithmic sampling strategy. If False, apply
142
- regular even sampling. Defaults to False.
143
- """
144
- self.dataset = dataset
145
- self.indices = indices
146
- self.purpose = purpose
147
- self.feature_transform = transforms[0]
148
- self.target_transform = transforms[1]
149
- self.oversample_factor = oversample_factor
150
- self.stratify_classes = stratify_classes
151
- self.use_log_frequencies = use_log_frequencies
152
-
153
- self.per_class_indices: Optional[Dict[int, np.ndarray]] = None
154
- self.frequencies: Optional[Dict[int, float]] = None
155
- if self.stratify_classes:
156
- self._bucketize_indices()
157
- if self.use_log_frequencies:
158
- self._calc_frequencies()
159
-
160
- def _calc_frequencies(self):
161
- assert self.per_class_indices is not None
162
- counts_dict = {lbl: len(arr) for lbl, arr in self.per_class_indices.items()}
163
- counts = np.array(list(counts_dict.values()))
164
- counts_nrm = self._normalize(counts)
165
- temperature = 50.0 # > 1 to even-out frequencies
166
- freqs = self._normalize(np.log1p(counts_nrm * temperature))
167
- self.frequencies = {k: freq.item() for k, freq
168
- in zip(self.per_class_indices.keys(), freqs)}
169
- print(self.frequencies)
170
-
171
- @staticmethod
172
- def _normalize(arr: np.ndarray) -> np.ndarray:
173
- return arr / np.sum(arr)
174
-
175
- def _bucketize_indices(self):
176
- buckets = defaultdict(list)
177
- for index in self.indices:
178
- label = self.dataset.get_label_at(index)
179
- buckets[label].append(index)
180
- self.per_class_indices = {k: np.array(v)
181
- for k, v in buckets.items()}
182
-
183
- def __getitem__(self, index: Union[int, slice]) -> TensorRecord: # type: ignore[override]
184
- assert isinstance(index, int)
185
- if self.purpose == Purpose.Train:
186
- index_rem = index % len(self.indices)
187
- idx = self.indices[index_rem].item()
188
- else:
189
- idx = self.indices[index].item()
190
- if self.per_class_indices:
191
- if self.frequencies is not None:
192
- arange = np.arange(len(self.per_class_indices))
193
- frequencies = np.zeros(len(self.per_class_indices), dtype=float)
194
- for k, v in self.frequencies.items():
195
- frequencies[k] = v
196
- random_key = np.random.choice(
197
- arange,
198
- p=frequencies)
199
- else:
200
- random_key = np.random.randint(len(self.per_class_indices))
201
-
202
- indices = self.per_class_indices[random_key]
203
- actual_index = np.random.choice(indices).item()
204
- else:
205
- actual_index = idx
206
- feature, target = self.dataset[actual_index]
207
- feature_tensor = self.feature_transform(feature)
208
- target_tensor = self.target_transform(target)
209
- return feature_tensor, target_tensor
210
-
211
- def __len__(self):
212
- if self.purpose == Purpose.Train:
213
- return len(self.indices) * self.oversample_factor
214
- else:
215
- return len(self.indices)
216
-
217
- @staticmethod
218
- def make_splits(all_data: RetinopathyDataset,
219
- train_transforms: FeatureAndTargetTransforms,
220
- val_transforms: FeatureAndTargetTransforms,
221
- train_fraction: float,
222
- stratify_train: bool,
223
- stratify_val: bool,
224
- seed: int = 54,
225
- ) -> Tuple['Split', 'Split']:
226
-
227
- """ Prepare train and val splits deterministically.
228
-
229
- Returns:
230
- Tuple[Split, Split]:
231
- - Train split
232
- - Val split
233
- """
234
-
235
- prng = RandomState(seed)
236
-
237
- num_train = int(len(all_data) * train_fraction)
238
- all_indices = prng.permutation(len(all_data))
239
- train_indices = all_indices[:num_train]
240
- val_indices = all_indices[num_train:]
241
- train_data = Split(all_data, train_indices, Purpose.Train,
242
- train_transforms, stratify_classes=stratify_train)
243
- val_data = Split(all_data, val_indices, Purpose.Val,
244
- val_transforms, stratify_classes=stratify_val)
245
- return train_data, val_data
246
-
247
-
248
- def print_data_stats(dataset: Union[Iterable[DataRecord], DataLoader],
249
- split_name: str) -> None:
250
- labels = []
251
- for _, label in dataset:
252
- if isinstance(label, torch.Tensor):
253
- label = label.cpu().numpy()
254
- labels.append(label)
255
- labels = np.concatenate(labels)
256
- cnt = Counter(labels)
257
- print(cnt)
258
-
259
-
260
- class Metrics:
261
- def __init__(self,
262
- num_classes: int,
263
- labelmap: Dict[int, str],
264
- split: str,
265
- log_fn: Callable[..., None]) -> None:
266
- self.labelmap = labelmap
267
- self.loss = MeanMetric(nan_strategy='ignore')
268
- self.accuracy = MulticlassAccuracy(num_classes=num_classes)
269
- self.per_class_accuracies = MulticlassAccuracy(
270
- num_classes=num_classes, average=None)
271
- self.kappa = MulticlassCohenKappa(num_classes)
272
- self.split = split
273
- self.log_fn = log_fn
274
-
275
- def update(self,
276
- loss: torch.Tensor,
277
- preds: torch.Tensor,
278
- labels: torch.Tensor) -> None:
279
- self.loss.update(loss)
280
- self.accuracy.update(preds, labels)
281
- self.per_class_accuracies.update(preds, labels)
282
- self.kappa.update(preds, labels)
283
-
284
- def log(self) -> None:
285
- loss = self.loss.compute()
286
- accuracy = self.accuracy.compute()
287
- accuracies = self.per_class_accuracies.compute()
288
- kappa = self.kappa.compute()
289
- mean_accuracy = torch.nanmean(accuracies)
290
- self.log_fn(f"{self.split}/loss", loss, sync_dist=True)
291
- self.log_fn(f"{self.split}/accuracy", accuracy, sync_dist=True)
292
- self.log_fn(f"{self.split}/mean_accuracy", mean_accuracy, sync_dist=True)
293
- for i_class, acc in enumerate(accuracies):
294
- name = self.labelmap[i_class]
295
- self.log_fn(f"{self.split}/acc/{i_class} {name}", acc, sync_dist=True)
296
- self.log_fn(f"{self.split}/kappa", kappa, sync_dist=True)
297
-
298
- def to(self, device) -> 'Metrics':
299
- self.loss.to(device) # BUG HERE? should I assign it back?
300
- self.accuracy.to(device)
301
- self.per_class_accuracies.to(device)
302
- self.kappa.to(device)
303
- return self
304
-
305
-
306
- def worker_init_fn(worker_id: int) -> None:
307
- """ Initialize workers in a way that they draw different
308
- random samples and do not repeat identical pseudorandom
309
- sequences of each other, which may be the case with Fork
310
- multiprocessing.
311
-
312
- Args:
313
- worker_id (int): id of a preprocessing worker process launched
314
- by one DDP training process.
315
- """
316
- state = np.random.get_state()
317
- assert isinstance(state, tuple)
318
- assert isinstance(state[1], np.ndarray)
319
- seed_arr = state[1]
320
- seed_np = seed_arr[0] + worker_id
321
- np.random.seed(seed_np)
322
- seed_pt = seed_np + 1111
323
- torch.manual_seed(seed_pt)
324
- print(f"Setting numpy seed to {seed_np} and pytorch seed to {seed_pt} in worker {worker_id}")
325
-
326
-
327
- class ViTLightningModule(L.LightningModule):
328
- """ Lightning Module that implements neural network training hooks. """
329
- def __init__(self, debug: bool) -> None:
330
- super().__init__()
331
-
332
- self.save_hyperparameters()
333
-
334
- np.random.seed(53)
335
-
336
- # pretrained_name = 'google/vit-base-patch16-224-in21k'
337
- # pretrained_name = 'google/vit-base-patch16-384-in21k'
338
-
339
- # pretrained_name = "microsoft/resnet-50"
340
- pretrained_name = "microsoft/resnet-34"
341
-
342
- # processor = ViTImageProcessor.from_pretrained(pretrained_name)
343
- processor = AutoImageProcessor.from_pretrained(pretrained_name)
344
-
345
- image_mean = processor.image_mean # type: ignore
346
- image_std = processor.image_std # type: ignore
347
- # size = processor.size["height"] # type: ignore
348
- # size = processor.size["shortest_edge"] # type: ignore
349
- size = 896 # 448
350
-
351
- normalize = Normalize(mean=image_mean, std=image_std)
352
- train_transforms = Compose(
353
- [
354
- # RandomRotation((-180, 180)),
355
- RandomAffine((-180, 180), shear=10),
356
- RandomResizedCrop(size, scale=(0.5, 1.0)),
357
- RandomHorizontalFlip(),
358
- ToTensor(),
359
- normalize,
360
- ]
361
- )
362
- val_transforms = Compose(
363
- [
364
- Resize(size),
365
- CenterCrop(size),
366
- ToTensor(),
367
- normalize,
368
- ]
369
- )
370
-
371
- self.dataset = RetinopathyDataset("retinopathy_data")
372
-
373
- # print_data_stats(self.dataset, "all_data")
374
-
375
- train_data, val_data = Split.make_splits(
376
- self.dataset,
377
- train_transforms=(train_transforms, torch.tensor),
378
- val_transforms=(val_transforms, torch.tensor),
379
- train_fraction=0.9,
380
- stratify_train=True,
381
- stratify_val=True,
382
- )
383
-
384
- assert len(set(train_data.indices).intersection(set(val_data.indices))) == 0
385
-
386
- label2id = {label: id for id, label in self.dataset.label_map.items()}
387
-
388
- num_classes = len(self.dataset.label_map)
389
- labelmap = self.dataset.label_map
390
- assert len(labelmap) == num_classes
391
- assert set(labelmap.keys()) == set(range(num_classes))
392
-
393
- train_batch_size = 4 if debug else 20
394
- val_batch_size = 4 if debug else 20
395
-
396
- num_gpus = torch.cuda.device_count()
397
- print(f"{num_gpus=}")
398
-
399
- num_cores = torch.get_num_threads()
400
- print(f"{num_cores=}")
401
-
402
- num_threads_per_gpu = max(1, int(math.ceil(num_cores / num_gpus))) \
403
- if num_gpus > 0 else 1
404
-
405
- num_workers = 1 if debug else num_threads_per_gpu
406
- print(f"{num_workers=}")
407
-
408
- self._train_dataloader = DataLoader(
409
- train_data,
410
- shuffle=True,
411
- num_workers=num_workers,
412
- persistent_workers=num_workers > 0,
413
- pin_memory=True,
414
- batch_size=train_batch_size,
415
- worker_init_fn=worker_init_fn,
416
- )
417
- self._val_dataloader = DataLoader(
418
- val_data,
419
- shuffle=False,
420
- num_workers=num_workers,
421
- persistent_workers=num_workers > 0,
422
- pin_memory=True,
423
- batch_size=val_batch_size,
424
- )
425
-
426
- # print_data_stats(self._val_dataloader, "val")
427
- # print_data_stats(self._train_dataloader, "train")
428
-
429
- img_batch, label_batch = next(iter(self._train_dataloader))
430
- assert isinstance(img_batch, torch.Tensor)
431
- assert isinstance(label_batch, torch.Tensor)
432
- print(f"{img_batch.shape=} {label_batch.shape=}")
433
-
434
- assert img_batch.shape == (train_batch_size, 3, size, size)
435
- assert label_batch.shape == (train_batch_size,)
436
-
437
- self.example_input_array = torch.randn_like(img_batch)
438
-
439
- # self._model = ViTForImageClassification.from_pretrained(
440
- # pretrained_name,
441
- # num_labels=len(self.dataset.label_map),
442
- # id2label=self.dataset.label_map,
443
- # label2id=label2id)
444
-
445
- self._model = ResNetForImageClassification.from_pretrained(
446
- pretrained_name,
447
- num_labels=len(self.dataset.label_map),
448
- id2label=self.dataset.label_map,
449
- label2id=label2id,
450
- ignore_mismatched_sizes=True)
451
-
452
- assert isinstance(self._model, nn.Module)
453
-
454
- self.train_metrics: Optional[Metrics] = None
455
- self.val_metrics: Optional[Metrics] = None
456
-
457
- @property
458
- def num_classes(self):
459
- return len(self.dataset.label_map)
460
-
461
- @property
462
- def labelmap(self):
463
- return self.dataset.label_map
464
-
465
- def forward(self, img_batch):
466
- outputs = self._model(img_batch) # type: ignore
467
- return outputs.logits
468
-
469
- def common_step(self, batch, batch_idx):
470
- img_batch, label_batch = batch
471
-
472
- logits = self(img_batch)
473
-
474
- criterion = nn.CrossEntropyLoss()
475
- loss = criterion(logits, label_batch)
476
- preds_batch = logits.argmax(-1)
477
-
478
- return loss, preds_batch, label_batch
479
-
480
- def on_train_epoch_start(self) -> None:
481
- self.train_metrics = Metrics(
482
- self.num_classes,
483
- self.labelmap,
484
- "train",
485
- self.log).to(self.device)
486
-
487
- def training_step(self, batch, batch_idx):
488
- loss, preds, labels = self.common_step(batch, batch_idx)
489
- assert self.train_metrics is not None
490
- self.train_metrics.update(loss, preds, labels)
491
-
492
- if False and batch_idx == 0:
493
- self._dump_train_images()
494
-
495
- return loss
496
-
497
- def _dump_train_images(self) -> None:
498
- """ Save augmented images to disk for inspection. """
499
- img_batch, label_batch = next(iter(self._train_dataloader))
500
- for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
501
- img_np = img.cpu().numpy()
502
- denorm_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
503
- img_uint8 = (255 * denorm_np).astype(np.uint8)
504
- pil_img = Image.fromarray(np.transpose(img_uint8, (1, 2, 0)))
505
- if self.logger is not None and self.logger.log_dir is not None:
506
- assert isinstance(self.logger.log_dir, str)
507
- os.makedirs(self.logger.log_dir, exist_ok=True)
508
- path = os.path.join(self.logger.log_dir,
509
- f"img_{i_img:02d}_{label.item()}.png")
510
- pil_img.save(path)
511
-
512
- def on_train_epoch_end(self) -> None:
513
- assert self.train_metrics is not None
514
- self.train_metrics.log()
515
- assert self.logger is not None
516
- if self.logger.log_dir is not None:
517
- path = os.path.join(self.logger.log_dir, "inference")
518
- self.save_checkpoint_dk(path)
519
-
520
- def save_checkpoint_dk(self, dirpath: str) -> None:
521
- if self.global_rank == 0:
522
- self._model.save_pretrained(dirpath)
523
-
524
- def validation_step(self, batch, batch_idx):
525
- loss, preds, labels = self.common_step(batch, batch_idx)
526
- assert self.val_metrics is not None
527
- self.val_metrics.update(loss, preds, labels)
528
- return loss
529
-
530
- def on_validation_epoch_start(self) -> None:
531
- self.val_metrics = Metrics(
532
- self.num_classes,
533
- self.labelmap,
534
- "val",
535
- self.log).to(self.device)
536
-
537
- def on_validation_epoch_end(self) -> None:
538
- assert self.val_metrics is not None
539
- self.val_metrics.log()
540
 
541
- def configure_optimizers(self):
542
- # No WD is the same as 1e-3 and better than 1e-2
543
- # LR 1e-3 is worse than 1e-4 (without LR scheduler)
544
- return AdamW(self.parameters(),
545
- lr=1e-4,
546
- )
547
 
548
 
549
  def main():
 
 
 
 
 
 
1
  import datetime
2
  from argparse import ArgumentParser
 
 
 
 
 
 
3
 
4
  import torch
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from lightning import Trainer
7
  from lightning.pytorch.loggers import TensorBoardLogger
8
  from lightning.pytorch.callbacks import ModelSummary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ from trainer import ViTLightningModule
 
 
 
 
 
11
 
12
 
13
  def main():
trainer.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Optional
3
+ import numpy as np
4
+ import math
5
+ from PIL import Image
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.data import DataLoader
10
+
11
+ from torchvision.transforms import (
12
+ CenterCrop,
13
+ Compose,
14
+ Normalize,
15
+ RandomHorizontalFlip,
16
+ RandomResizedCrop,
17
+ RandomRotation,
18
+ RandomAffine,
19
+ Resize,
20
+ ToTensor)
21
+
22
+ # from transformers import ViTImageProcessor
23
+ # from transformers import ViTForImageClassification
24
+ from transformers import AdamW
25
+ from transformers import AutoImageProcessor, ResNetForImageClassification
26
+ import lightning as L
27
+
28
+ from data import RetinopathyDataset, Split
29
+ from metrics import Metrics
30
+
31
+
32
+ def worker_init_fn(worker_id: int) -> None:
33
+ """ Initialize workers in a way that they draw different
34
+ random samples and do not repeat identical pseudorandom
35
+ sequences of each other, which may be the case with Fork
36
+ multiprocessing.
37
+
38
+ Args:
39
+ worker_id (int): id of a preprocessing worker process launched
40
+ by one DDP training process.
41
+ """
42
+ state = np.random.get_state()
43
+ assert isinstance(state, tuple)
44
+ assert isinstance(state[1], np.ndarray)
45
+ seed_arr = state[1]
46
+ seed_np = seed_arr[0] + worker_id
47
+ np.random.seed(seed_np)
48
+ seed_pt = seed_np + 1111
49
+ torch.manual_seed(seed_pt)
50
+ print(f"Setting numpy seed to {seed_np} and pytorch seed to {seed_pt} in worker {worker_id}")
51
+
52
+
53
+ class ViTLightningModule(L.LightningModule):
54
+ """ Lightning Module that implements neural network training hooks. """
55
+ def __init__(self, debug: bool) -> None:
56
+ super().__init__()
57
+
58
+ self.save_hyperparameters()
59
+
60
+ np.random.seed(53)
61
+
62
+ # pretrained_name = 'google/vit-base-patch16-224-in21k'
63
+ # pretrained_name = 'google/vit-base-patch16-384-in21k'
64
+
65
+ # pretrained_name = "microsoft/resnet-50"
66
+ pretrained_name = "microsoft/resnet-34"
67
+
68
+ # processor = ViTImageProcessor.from_pretrained(pretrained_name)
69
+ processor = AutoImageProcessor.from_pretrained(pretrained_name)
70
+
71
+ image_mean = processor.image_mean # type: ignore
72
+ image_std = processor.image_std # type: ignore
73
+ # size = processor.size["height"] # type: ignore
74
+ # size = processor.size["shortest_edge"] # type: ignore
75
+ size = 896 # 448
76
+
77
+ normalize = Normalize(mean=image_mean, std=image_std)
78
+ train_transforms = Compose(
79
+ [
80
+ # RandomRotation((-180, 180)),
81
+ RandomAffine((-180, 180), shear=10),
82
+ RandomResizedCrop(size, scale=(0.5, 1.0)),
83
+ RandomHorizontalFlip(),
84
+ ToTensor(),
85
+ normalize,
86
+ ]
87
+ )
88
+ val_transforms = Compose(
89
+ [
90
+ Resize(size),
91
+ CenterCrop(size),
92
+ ToTensor(),
93
+ normalize,
94
+ ]
95
+ )
96
+
97
+ self.dataset = RetinopathyDataset("retinopathy_data")
98
+
99
+ # print_data_stats(self.dataset, "all_data")
100
+
101
+ train_data, val_data = Split.make_splits(
102
+ self.dataset,
103
+ train_transforms=(train_transforms, torch.tensor),
104
+ val_transforms=(val_transforms, torch.tensor),
105
+ train_fraction=0.9,
106
+ stratify_train=True,
107
+ stratify_val=True,
108
+ )
109
+
110
+ assert len(set(train_data.indices).intersection(set(val_data.indices))) == 0
111
+
112
+ label2id = {label: id for id, label in self.dataset.label_map.items()}
113
+
114
+ num_classes = len(self.dataset.label_map)
115
+ labelmap = self.dataset.label_map
116
+ assert len(labelmap) == num_classes
117
+ assert set(labelmap.keys()) == set(range(num_classes))
118
+
119
+ train_batch_size = 4 if debug else 20
120
+ val_batch_size = 4 if debug else 20
121
+
122
+ num_gpus = torch.cuda.device_count()
123
+ print(f"{num_gpus=}")
124
+
125
+ num_cores = torch.get_num_threads()
126
+ print(f"{num_cores=}")
127
+
128
+ num_threads_per_gpu = max(1, int(math.ceil(num_cores / num_gpus))) \
129
+ if num_gpus > 0 else 1
130
+
131
+ num_workers = 1 if debug else num_threads_per_gpu
132
+ print(f"{num_workers=}")
133
+
134
+ self._train_dataloader = DataLoader(
135
+ train_data,
136
+ shuffle=True,
137
+ num_workers=num_workers,
138
+ persistent_workers=num_workers > 0,
139
+ pin_memory=True,
140
+ batch_size=train_batch_size,
141
+ worker_init_fn=worker_init_fn,
142
+ )
143
+ self._val_dataloader = DataLoader(
144
+ val_data,
145
+ shuffle=False,
146
+ num_workers=num_workers,
147
+ persistent_workers=num_workers > 0,
148
+ pin_memory=True,
149
+ batch_size=val_batch_size,
150
+ )
151
+
152
+ # print_data_stats(self._val_dataloader, "val")
153
+ # print_data_stats(self._train_dataloader, "train")
154
+
155
+ img_batch, label_batch = next(iter(self._train_dataloader))
156
+ assert isinstance(img_batch, torch.Tensor)
157
+ assert isinstance(label_batch, torch.Tensor)
158
+ print(f"{img_batch.shape=} {label_batch.shape=}")
159
+
160
+ assert img_batch.shape == (train_batch_size, 3, size, size)
161
+ assert label_batch.shape == (train_batch_size,)
162
+
163
+ self.example_input_array = torch.randn_like(img_batch)
164
+
165
+ # self._model = ViTForImageClassification.from_pretrained(
166
+ # pretrained_name,
167
+ # num_labels=len(self.dataset.label_map),
168
+ # id2label=self.dataset.label_map,
169
+ # label2id=label2id)
170
+
171
+ self._model = ResNetForImageClassification.from_pretrained(
172
+ pretrained_name,
173
+ num_labels=len(self.dataset.label_map),
174
+ id2label=self.dataset.label_map,
175
+ label2id=label2id,
176
+ ignore_mismatched_sizes=True)
177
+
178
+ assert isinstance(self._model, nn.Module)
179
+
180
+ self.train_metrics: Optional[Metrics] = None
181
+ self.val_metrics: Optional[Metrics] = None
182
+
183
+ @property
184
+ def num_classes(self):
185
+ return len(self.dataset.label_map)
186
+
187
+ @property
188
+ def labelmap(self):
189
+ return self.dataset.label_map
190
+
191
+ def forward(self, img_batch):
192
+ outputs = self._model(img_batch) # type: ignore
193
+ return outputs.logits
194
+
195
+ def common_step(self, batch, batch_idx):
196
+ img_batch, label_batch = batch
197
+
198
+ logits = self(img_batch)
199
+
200
+ criterion = nn.CrossEntropyLoss()
201
+ loss = criterion(logits, label_batch)
202
+ preds_batch = logits.argmax(-1)
203
+
204
+ return loss, preds_batch, label_batch
205
+
206
+ def on_train_epoch_start(self) -> None:
207
+ self.train_metrics = Metrics(
208
+ self.num_classes,
209
+ self.labelmap,
210
+ "train",
211
+ self.log).to(self.device)
212
+
213
+ def training_step(self, batch, batch_idx):
214
+ loss, preds, labels = self.common_step(batch, batch_idx)
215
+ assert self.train_metrics is not None
216
+ self.train_metrics.update(loss, preds, labels)
217
+
218
+ if False and batch_idx == 0:
219
+ self._dump_train_images()
220
+
221
+ return loss
222
+
223
+ def _dump_train_images(self) -> None:
224
+ """ Save augmented images to disk for inspection. """
225
+ img_batch, label_batch = next(iter(self._train_dataloader))
226
+ for i_img, (img, label) in enumerate(zip(img_batch, label_batch)):
227
+ img_np = img.cpu().numpy()
228
+ denorm_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
229
+ img_uint8 = (255 * denorm_np).astype(np.uint8)
230
+ pil_img = Image.fromarray(np.transpose(img_uint8, (1, 2, 0)))
231
+ if self.logger is not None and self.logger.log_dir is not None:
232
+ assert isinstance(self.logger.log_dir, str)
233
+ os.makedirs(self.logger.log_dir, exist_ok=True)
234
+ path = os.path.join(self.logger.log_dir,
235
+ f"img_{i_img:02d}_{label.item()}.png")
236
+ pil_img.save(path)
237
+
238
+ def on_train_epoch_end(self) -> None:
239
+ assert self.train_metrics is not None
240
+ self.train_metrics.log()
241
+ assert self.logger is not None
242
+ if self.logger.log_dir is not None:
243
+ path = os.path.join(self.logger.log_dir, "inference")
244
+ self.save_checkpoint_dk(path)
245
+
246
+ def save_checkpoint_dk(self, dirpath: str) -> None:
247
+ if self.global_rank == 0:
248
+ self._model.save_pretrained(dirpath)
249
+
250
+ def validation_step(self, batch, batch_idx):
251
+ loss, preds, labels = self.common_step(batch, batch_idx)
252
+ assert self.val_metrics is not None
253
+ self.val_metrics.update(loss, preds, labels)
254
+ return loss
255
+
256
+ def on_validation_epoch_start(self) -> None:
257
+ self.val_metrics = Metrics(
258
+ self.num_classes,
259
+ self.labelmap,
260
+ "val",
261
+ self.log).to(self.device)
262
+
263
+ def on_validation_epoch_end(self) -> None:
264
+ assert self.val_metrics is not None
265
+ self.val_metrics.log()
266
+
267
+ def configure_optimizers(self):
268
+ # No WD is the same as 1e-3 and better than 1e-2
269
+ # LR 1e-3 is worse than 1e-4 (without LR scheduler)
270
+ return AdamW(self.parameters(),
271
+ lr=1e-4,
272
+ )