|
import os

import torch.cuda

import chabud as ch
|
|
|
ds_path = "A:/CodingProjekte/DataMining/src/train_eval.hdf5" |
|
|
|
|
|
if __name__ == '__main__': |
|
print(ch.__version__) |
|
print(torch.cuda.is_available()) |
|
|
|
channels = ["band_1", "band_2", "band_3", "band_4", "band_5", "band_6", "band_7", "band_8", "band_8a", "band_9", |
|
"band_11", "band_12", "nbr", "ndvi", "gndvi", "evi", "avi", "savi", "ndmi", "msi", "gci", "bsi", "ndwi", |
|
"ndgi"] |
|
|
|
|
|
channels_fun = [] |
|
|
|
for channel in channels: |
|
channels_fun.append(ch.CHANNEL_MAP[channel]) |
|
|
|
ch.main(accelerator="gpu", |
|
datafile=ds_path, |
|
batch_size=5, |
|
learning_rate=0.00025, |
|
channels=channels_fun, |
|
n_cpus=0, |
|
model="unet", |
|
encoder="resnet34", |
|
encoder_depth=5, |
|
encoder_weights="imagenet", |
|
loss="dice", |
|
train_use_pre_fire=False, |
|
train_use_augmentation=True) |
|
|
|
|
|
|