File size: 1,293 Bytes
920ac9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch.cuda

import chabud as ch

# Dataset file handed to ch.main() as `datafile` in the entry-point block below.
# NOTE(review): drive-absolute path — presumably specific to one machine; adjust before running elsewhere.
ds_path = "A:/CodingProjekte/DataMining/src/train_eval.hdf5"

# Script entry point: print environment info, resolve the selected channels
# through ch.CHANNEL_MAP and pass the run configuration to ch.main().
if __name__ == '__main__':
    # Environment report: installed chabud version and whether torch sees CUDA.
    print(ch.__version__)
    print(torch.cuda.is_available())

    # Channel names to use; every name must be a key of ch.CHANNEL_MAP.
    channel_names = [
        "band_1", "band_2", "band_3", "band_4", "band_5", "band_6",
        "band_7", "band_8", "band_8a", "band_9", "band_11", "band_12",
        "nbr", "ndvi", "gndvi", "evi", "avi", "savi", "ndmi", "msi",
        "gci", "bsi", "ndwi", "ndgi",
    ]
    # Alternative, smaller selection kept for reference:
    # channel_names = [
    #     "band_1", "band_2", "band_3", "band_4", "band_5", "band_6",
    #     "band_7", "band_8", "band_8a", "band_9", "band_11", "band_12",
    #     "nbr", "ndmi", "ndvi", "bsi", "ndwi",
    # ]

    # Translate each name into its ch.CHANNEL_MAP entry, preserving order.
    channels_fun = [ch.CHANNEL_MAP[name] for name in channel_names]

    # Keyword arguments forwarded unchanged to ch.main().
    # NOTE(review): "gpu" is requested unconditionally, although CUDA
    # availability is only printed above — confirm a GPU is present.
    run_options = {
        "accelerator": "gpu",
        "datafile": ds_path,
        "batch_size": 5,
        "learning_rate": 0.00025,
        "channels": channels_fun,
        "n_cpus": 0,
        "model": "unet",
        "encoder": "resnet34",
        "encoder_depth": 5,
        "encoder_weights": "imagenet",
        "loss": "dice",
        "train_use_pre_fire": False,
        "train_use_augmentation": True,
    }
    ch.main(**run_options)