Files changed (2) hide show
  1. configs/config.yaml +174 -0
  2. configs/config_full_tile.yaml +176 -0
configs/config.yaml ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # lightning.pytorch==2.1.1
2
+ seed_everything: 0
3
+
4
+ ### Trainer configuration
5
+ trainer:
6
+ accelerator: auto
7
+ strategy: auto
8
+ devices: auto
9
+ num_nodes: 1
10
+ # precision: 16-mixed
11
+ logger:
12
+ # You can swtich to TensorBoard for logging by uncommenting the below line and commenting out the procedding line
13
+ #class_path: TensorBoardLogger
14
+ class_path: lightning.pytorch.loggers.csv_logs.CSVLogger
15
+ init_args:
16
+ save_dir: ./experiments
17
+ name: fine_tune_suhi
18
+ callbacks:
19
+ - class_path: RichProgressBar
20
+ - class_path: LearningRateMonitor
21
+ init_args:
22
+ logging_interval: epoch
23
+ - class_path: EarlyStopping
24
+ init_args:
25
+ monitor: val/loss
26
+ patience: 600
27
+ max_epochs: 600
28
+ check_val_every_n_epoch: 1
29
+ log_every_n_steps: 10
30
+ enable_checkpointing: true
31
+ default_root_dir: ./experiments
32
+ out_dtype: float32
33
+
34
+ ### Data configuration
35
+ data:
36
+ class_path: GenericNonGeoPixelwiseRegressionDataModule
37
+ init_args:
38
+ batch_size: 64
39
+ num_workers: 8
40
+ train_transform:
41
+ - class_path: albumentations.HorizontalFlip
42
+ init_args:
43
+ p: 0.5
44
+ - class_path: albumentations.Rotate
45
+ init_args:
46
+ limit: 30
47
+ border_mode: 0 # cv2.BORDER_CONSTANT
48
+ value: 0
49
+ mask_value: 1
50
+ p: 0.5
51
+ - class_path: ToTensorV2
52
+ # Specify all bands which are in the input data.
53
+ dataset_bands:
54
+ # 6 HLS bands
55
+ - BLUE
56
+ - GREEN
57
+ - RED
58
+ - NIR_NARROW
59
+ - SWIR_1
60
+ - SWIR_2
61
+ # ERA5-Land t2m_spatial_avg
62
+ - 7
63
+ # ERA5-Land t2m_sunrise_avg
64
+ - 8
65
+ # ERA5-Land t2m_midnight_avg
66
+ - 9
67
+ # ERA5-Land t2m_delta_avg
68
+ - 10
69
+ # cos_tod
70
+ - 11
71
+ # sin_tod
72
+ - 12
73
+ # cos_doy
74
+ - 13
75
+ # sin_doy
76
+ - 14
77
+ # Specify the bands which are used from the input data.
78
+ # Bands 8 - 14 were discarded in the final model
79
+ output_bands:
80
+ - BLUE
81
+ - GREEN
82
+ - RED
83
+ - NIR_NARROW
84
+ - SWIR_1
85
+ - SWIR_2
86
+ - 7
87
+ rgb_indices:
88
+ - 2
89
+ - 1
90
+ - 0
91
+ # Directory roots to training, validation and test datasplits:
92
+ train_data_root: train/inputs
93
+ train_label_data_root: train/targets
94
+ val_data_root: val/inputs
95
+ val_label_data_root: val/targets
96
+ test_data_root: test/inputs
97
+ test_label_data_root: test/targets
98
+ img_grep: "*.inputs.tif"
99
+ label_grep: "*.lst.tif"
100
+ # Nodata value in the input data
101
+ no_data_replace: 0
102
+ # Nodata value in label (target) data
103
+ no_label_replace: -9999
104
+ # Mean value of the training dataset per band
105
+ means:
106
+ - 702.4754028320312
107
+ - 1023.23291015625
108
+ - 1118.8924560546875
109
+ - 2440.750732421875
110
+ - 2052.705810546875
111
+ - 1514.15087890625
112
+ - 21.031919479370117
113
+ # Standard deviation of the training dataset per band
114
+ stds:
115
+ - 554.8255615234375
116
+ - 613.5565185546875
117
+ - 745.929443359375
118
+ - 715.0111083984375
119
+ - 761.47607421875
120
+ - 734.991943359375
121
+ - 8.66781997680664
122
+
123
+ ### Model configuration
124
+ model:
125
+ class_path: terratorch.tasks.PixelwiseRegressionTask
126
+ init_args:
127
+ model_args:
128
+ decoder: UperNetDecoder
129
+ pretrained: false
130
+ backbone: prithvi_swin_L
131
+ img_size: 224
132
+ backbone_drop_path_rate: 0.3
133
+ decoder_channels: 256
134
+ in_channels: 7
135
+ bands:
136
+ - BLUE
137
+ - GREEN
138
+ - RED
139
+ - NIR_NARROW
140
+ - SWIR_1
141
+ - SWIR_2
142
+ - 7
143
+ num_frames: 1
144
+ loss: rmse
145
+ aux_heads:
146
+ - name: aux_head
147
+ decoder: IdentityDecoder
148
+ decoder_args:
149
+ head_dropout: 0.5
150
+ head_channel_list:
151
+ - 1
152
+ head_final_act: torch.nn.LazyLinear
153
+ aux_loss:
154
+ aux_head: 0.4
155
+ ignore_index: -9999
156
+ freeze_backbone: false
157
+ freeze_decoder: false
158
+ model_factory: PrithviModelFactory
159
+ # uncomment this block for tiled inference
160
+ tiled_inference_parameters:
161
+ h_crop: 224
162
+ h_stride: 224
163
+ w_crop: 224
164
+ w_stride: 224
165
+ average_patches: true
166
+ optimizer:
167
+ class_path: torch.optim.AdamW
168
+ init_args:
169
+ lr: 0.0001
170
+ weight_decay: 0.05
171
+ lr_scheduler:
172
+ class_path: ReduceLROnPlateau
173
+ init_args:
174
+ monitor: val/loss
configs/config_full_tile.yaml ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # lightning.pytorch==2.1.1
2
+ seed_everything: 0
3
+
4
+ ### Trainer configuration
5
+ trainer:
6
+ accelerator: auto
7
+ strategy: auto
8
+ devices: auto
9
+ num_nodes: 1
10
+ # precision: 16-mixed
11
+ logger:
12
+ # You can swtich to TensorBoard for logging by uncommenting the below line and commenting out the procedding line
13
+ #class_path: TensorBoardLogger
14
+ class_path: lightning.pytorch.loggers.csv_logs.CSVLogger
15
+ init_args:
16
+ save_dir: ./experiments
17
+ name: fine_tune_suhi
18
+ callbacks:
19
+ - class_path: RichProgressBar
20
+ - class_path: LearningRateMonitor
21
+ init_args:
22
+ logging_interval: epoch
23
+ - class_path: EarlyStopping
24
+ init_args:
25
+ monitor: val/loss
26
+ patience: 600
27
+ max_epochs: 600
28
+ check_val_every_n_epoch: 1
29
+ log_every_n_steps: 10
30
+ enable_checkpointing: true
31
+ default_root_dir: ./experiments
32
+ out_dtype: float32
33
+
34
+ ### Data configuration
35
+ data:
36
+ class_path: GenericNonGeoPixelwiseRegressionDataModule
37
+ init_args:
38
+ batch_size: 1
39
+ num_workers: 8
40
+ train_transform:
41
+ - class_path: albumentations.HorizontalFlip
42
+ init_args:
43
+ p: 0.5
44
+ - class_path: albumentations.Rotate
45
+ init_args:
46
+ limit: 30
47
+ border_mode: 0 # cv2.BORDER_CONSTANT
48
+ value: 0
49
+ mask_value: 1
50
+ p: 0.5
51
+ - class_path: ToTensorV2
52
+ # Specify all bands which are in the input data.
53
+ dataset_bands:
54
+ # 6 HLS bands
55
+ - BLUE
56
+ - GREEN
57
+ - RED
58
+ - NIR_NARROW
59
+ - SWIR_1
60
+ - SWIR_2
61
+ # ERA5-Land t2m_spatial_avg
62
+ - 7
63
+ # ERA5-Land t2m_sunrise_avg
64
+ - 8
65
+ # ERA5-Land t2m_midnight_avg
66
+ - 9
67
+ # ERA5-Land t2m_delta_avg
68
+ - 10
69
+ # cos_tod
70
+ - 11
71
+ # sin_tod
72
+ - 12
73
+ # cos_doy
74
+ - 13
75
+ # sin_doy
76
+ - 14
77
+ # Specify the bands which are used from the input data.
78
+ # Bands 8 - 14 were discarded in the final model
79
+ output_bands:
80
+ - BLUE
81
+ - GREEN
82
+ - RED
83
+ - NIR_NARROW
84
+ - SWIR_1
85
+ - SWIR_2
86
+ - 7
87
+ rgb_indices:
88
+ - 2
89
+ - 1
90
+ - 0
91
+ # Directory roots to training, validation and test datasplits:
92
+ train_data_root: train/inputs
93
+ train_label_data_root: train/targets
94
+ val_data_root: val/inputs
95
+ val_label_data_root: val/targets
96
+ test_data_root: test/inputs
97
+ test_label_data_root: test/targets
98
+ img_grep: "*.inputs.tif"
99
+ label_grep: "*.lst.tif"
100
+ # Nodata value in the input data
101
+ no_data_replace: 0
102
+ # Nodata value in label (target) data
103
+ no_label_replace: -9999
104
+ # Mean value of the training dataset per band
105
+ means:
106
+ - 702.4754028320312
107
+ - 1023.23291015625
108
+ - 1118.8924560546875
109
+ - 2440.750732421875
110
+ - 2052.705810546875
111
+ - 1514.15087890625
112
+ - 21.031919479370117
113
+ # Standard deviation of the training dataset per band
114
+ stds:
115
+ - 554.8255615234375
116
+ - 613.5565185546875
117
+ - 745.929443359375
118
+ - 715.0111083984375
119
+ - 761.47607421875
120
+ - 734.991943359375
121
+ - 8.66781997680664
122
+
123
+ ### Model configuration
124
+ model:
125
+ class_path: terratorch.tasks.PixelwiseRegressionTask
126
+ init_args:
127
+ model_args:
128
+ decoder: UperNetDecoder
129
+ pretrained: false
130
+ backbone: prithvi_swin_L
131
+ img_size: 224
132
+ backbone_drop_path_rate: 0.3
133
+ decoder_channels: 256
134
+ in_channels: 7
135
+ bands:
136
+ - BLUE
137
+ - GREEN
138
+ - RED
139
+ - NIR_NARROW
140
+ - SWIR_1
141
+ - SWIR_2
142
+ - 7
143
+ num_frames: 1
144
+ loss: rmse
145
+ aux_heads:
146
+ - name: aux_head
147
+ decoder: IdentityDecoder
148
+ decoder_args:
149
+ head_dropout: 0.5
150
+ head_channel_list:
151
+ - 1
152
+ head_final_act: torch.nn.LazyLinear
153
+ aux_loss:
154
+ aux_head: 0.4
155
+ ignore_index: -9999
156
+ freeze_backbone: false
157
+ freeze_decoder: false
158
+ model_factory: PrithviModelFactory
159
+ # This block is commented out when inferencing on full tiles.
160
+ # It is possible to inference on full tiles with this paramter on, the benefit is that the compute requirement is smaller.
161
+ # However, using this to inference on a full tile will introduce artefacting/"patchy" predictions.
162
+ # tiled_inference_parameters:
163
+ # h_crop: 224
164
+ # h_stride: 224
165
+ # w_crop: 224
166
+ # w_stride: 224
167
+ # average_patches: true
168
+ optimizer:
169
+ class_path: torch.optim.AdamW
170
+ init_args:
171
+ lr: 0.0001
172
+ weight_decay: 0.05
173
+ lr_scheduler:
174
+ class_path: ReduceLROnPlateau
175
+ init_args:
176
+ monitor: val/loss