|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
from pathlib import Path |
|
import pandas as pd |
|
import albumentations as A |
|
import albumentations.pytorch.transforms as Atorch |
|
import h5py |
|
import numpy as np |
|
import pytorch_lightning as pl |
|
import segmentation_models_pytorch as smp |
|
import torch |
|
import xarray as xr |
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
|
|
# HDF5 data file opened by basic_df() and miss_dp_df().
fn = Path("A:/CodingProjekte/DataMining/src/train_eval.hdf5")
|
|
|
|
|
|
|
|
|
def basic_df():
    """Summarise every top-level group of the HDF5 file ``fn``.

    Returns
    -------
    pandas.DataFrame
        One row per group with the columns ``name``, ``pre_missing`` (1 when
        the group has no ``pre_fire`` entry, otherwise 0), ``burnt_pixel_abs``
        (sum over the group's ``mask`` dataset) and ``burnt_pixel_rel`` (that
        sum divided by the number of pixels of the mask).
    """
    rows = []

    with h5py.File(fn, "r") as fd:
        for name, ds in fd.items():
            # Read the mask once into memory instead of summing through the
            # h5py dataset handle.
            mask = ds["mask"][...]
            burnt_abs = np.sum(mask)
            # Take the pixel count from the mask itself rather than assuming
            # a fixed 512 x 512 tile.
            pixel_count = mask.shape[0] * mask.shape[1]

            rows.append({
                "name": name,
                "pre_missing": int("pre_fire" not in ds),
                "burnt_pixel_abs": burnt_abs,
                "burnt_pixel_rel": burnt_abs / pixel_count,
            })

    return pd.DataFrame(rows)
|
|
|
def miss_dp_df():
    """Count zero-valued pixels per band for every scene of the data file.

    Returns
    -------
    list of dict
        One dictionary per scene with the keys ``name``, ``pre`` and ``post``.
        ``post`` is an array holding, for each band, the number of pixels
        equal to 0 in the post-fire image. ``pre`` is the same for the
        pre-fire image, or an empty list when the scene has no pre-fire image.
    """
    # Iterate over plain dict records; rows of ``DataFrame.values`` are numpy
    # arrays and cannot be indexed by column name.
    records = basic_df().to_dict("records")

    miss_dp = []
    with h5py.File(fn, "r") as fd:
        for record in records:
            name = record["name"]
            arrays = to_xarray(fd[name])

            # Reduce over the spatial dimensions so exactly one count per
            # band remains (the band axis is the last one, not the first).
            post_miss = (arrays["post"] == 0).sum(dim=["x", "y"]).values

            if record["pre_missing"] == 1:
                # to_xarray() substitutes zeros for a missing pre-fire image,
                # so counting zeros there would be meaningless.
                pre_miss = []
            else:
                pre_miss = (arrays["pre"] == 0).sum(dim=["x", "y"]).values

            miss_dp.append({"name": name, "pre": pre_miss, "post": post_miss})

    return miss_dp
|
|
|
|
|
def wl():
    """Return the scene names that pass the burnt-area filter.

    Every row of ``basic_df()`` whose ``burnt_pixel_rel`` is below 0.0025 is
    dropped; the names of all remaining rows are returned as a list.
    """
    return [
        entry["name"]
        for _, entry in basic_df().iterrows()
        if not entry["burnt_pixel_rel"] < 0.0025
    ]
|
|
|
# Keep the three checkpoints with the highest validation IoU.
# The metric is logged as "valid_iou" (stage name "valid" + "_iou"), so the
# filename template must reference that exact key; a key that was never
# logged would not show the real score in the file name.
checkpoint_callback = ModelCheckpoint(
    filename='model-{epoch:02d}-{valid_iou:.2f}',
    monitor='valid_iou',
    mode='max',
    save_top_k=3,
)
|
|
|
|
|
# Version string of this script.
__version__ = "1.0.0"

# Module-level logger used by the dataset and the training entry point.
logger = logging.getLogger(__name__)

# Default value of the --datafile command-line option.
ds_path = "A:/CodingProjekte/DataMining/src/train_eval.hdf5"
|
|
|
|
|
def to_xarray(dataset, pretty_band_names=True):
    """Convert an HDF5 group holding pre- and post-fire images into xarray DataArrays.

    Parameters
    ----------
    dataset : h5py.Group
        HDF5 group providing the entries ``post_fire``, ``mask`` and
        optionally ``pre_fire``, plus a ``fold`` attribute.
    pretty_band_names : bool, optional
        If True (default), descriptive band names are used as band
        coordinates, otherwise the original MSI band numbers.

    Returns
    -------
    dict
        Dictionary with the keys ``pre``, ``post`` and ``mask`` (xarray
        DataArrays) and ``fold`` (the group's ``fold`` attribute). Image
        values are converted to float32 and divided by 10000. When the group
        has no ``pre_fire`` entry, ``pre`` is filled with zeros.
    """
    if pretty_band_names:
        bands = ["coastal_aerosol", "blue", "green", "red",
                 "veg_red_1", "veg_red_2", "veg_red_3", "nir",
                 "veg_red_4", "water_vapour", "swir_1", "swir_2"]
    else:
        bands = ["1", "2", "3", "4", "5", "6", "7", "8", "8a", "9", "11", "12"]

    post = dataset["post_fire"][...].astype("float32") / 10000.0

    try:
        pre = dataset["pre_fire"][...].astype("float32") / 10000.0
    except KeyError:
        # No pre-fire acquisition stored for this scene.
        pre = np.zeros_like(post, dtype="float32")

    # Drop the trailing singleton axis of the mask.
    mask = dataset["mask"][..., 0]

    # Derive the pixel coordinates from the data instead of assuming a fixed
    # 512 x 512 tile; for 512 x 512 inputs the result is unchanged.
    spatial = {"x": range(post.shape[0]), "y": range(post.shape[1])}
    image_coords = {**spatial, "band": bands}

    return {"pre": xr.DataArray(pre, dims=["x", "y", "band"], coords=image_coords),
            "post": xr.DataArray(post, dims=["x", "y", "band"], coords=image_coords),
            "mask": xr.DataArray(mask, dims=["x", "y"], coords=spatial),
            "fold": dataset.attrs["fold"]}
|
|
|
class BandExtractor:
    """Callable that extracts a single band from an image.

    Parameters
    ----------
    index : int
        Position of the band on the last axis of a numpy array.
    name : str
        Label of the band on the ``band`` dimension of an xarray DataArray.
    """

    def __init__(self, index, name) -> None:
        self.index = index
        self.name = name

    def __call__(self, data):
        """Return the selected band of ``data`` as a numpy array.

        Raises
        ------
        TypeError
            If ``data`` is neither a numpy array nor an xarray DataArray.
        """
        if isinstance(data, np.ndarray):
            return data[..., self.index]
        if isinstance(data, xr.DataArray):
            return data.sel(band=self.name).values
        # TypeError is a subclass of Exception, so existing handlers that
        # catch Exception keep working.
        raise TypeError("Unknown data format.")

    def __repr__(self) -> str:
        return f'BandExtractor({self.index}, "{self.name}")'
|
|
|
|
|
# One extractor per band: the first argument is the position on the last
# array axis, the second the matching label on the "band" dimension.
band_1 = BandExtractor(0, "coastal_aerosol")
band_2 = BandExtractor(1, "blue")
band_3 = BandExtractor(2, "green")
band_4 = BandExtractor(3, "red")
band_5 = BandExtractor(4, "veg_red_1")
band_6 = BandExtractor(5, "veg_red_2")
band_7 = BandExtractor(6, "veg_red_3")
band_8 = BandExtractor(7, "nir")
band_8a = BandExtractor(8, "veg_red_4")
band_9 = BandExtractor(9, "water_vapour")
band_11 = BandExtractor(10, "swir_1")
band_12 = BandExtractor(11, "swir_2")
|
|
|
|
|
def NBR(data):
    """Normalized Burn Ratio.

    nbr = (nir - swir_2) / (nir + swir_2)

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose denominator is 0
    yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        nir = data[..., 7]
        swir_2 = data[..., 11]
    elif isinstance(data, xr.DataArray):
        nir = data.sel(band="nir").values
        swir_2 = data.sel(band="swir_2").values
    else:
        raise TypeError("Unknown data format.")

    numerator = nir - swir_2
    denominator = nir + swir_2
    # The pre-filled zeros stay in place wherever the division is skipped.
    return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0.0)
|
|
|
|
|
def NDVI(data):
    """Normalized Difference Vegetation Index.

    ndvi = (nir - red) / (nir + red)

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose denominator is 0
    yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        red = data[..., 3]
        nir = data[..., 7]
    elif isinstance(data, xr.DataArray):
        red = data.sel(band="red").values
        nir = data.sel(band="nir").values
    else:
        raise TypeError("Unknown data format.")

    numerator = nir - red
    denominator = nir + red
    # The pre-filled zeros stay in place wherever the division is skipped.
    return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0.0)
|
|
|
|
|
def GNDVI(data):
    """Green Normalized Difference Vegetation Index.

    gndvi = (nir - green) / (nir + green)

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose denominator is 0
    yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        green = data[..., 2]
        nir = data[..., 7]
    elif isinstance(data, xr.DataArray):
        green = data.sel(band="green").values
        nir = data.sel(band="nir").values
    else:
        raise TypeError("Unknown data format.")

    numerator = nir - green
    # The normalisation uses the same two bands as the numerator; the red
    # band is not part of this index.
    denominator = nir + green
    return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0.0)
|
|
|
|
|
def EVI(data):
    """Enhanced Vegetation Index.

    evi = 2.5 * (nir - red) / (nir + 6 * red - 7.5 * blue + 1)

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose denominator is 0
    yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        blue = data[..., 1]
        red = data[..., 3]
        nir = data[..., 7]
    elif isinstance(data, xr.DataArray):
        blue = data.sel(band="blue").values
        red = data.sel(band="red").values
        nir = data.sel(band="nir").values
    else:
        raise TypeError("Unknown data format.")

    numerator = nir - red
    denominator = nir + 6 * red - 7.5 * blue + 1

    ratio = np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0.0)
    # Gain factor of the standard EVI definition.
    return 2.5 * ratio
|
|
|
|
|
def AVI(data):
    """Advanced Vegetation Index.

    avi = (nir * (1 - red) * (nir - red)) ** (1 / 3)

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose base is not
    strictly positive yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        red = data[..., 3]
        nir = data[..., 7]
    elif isinstance(data, xr.DataArray):
        red = data.sel(band="red").values
        nir = data.sel(band="nir").values
    else:
        raise TypeError("Unknown data format.")

    base = nir * (1 - red) * (nir - red)

    # Only take the cube root of positive values; everything else keeps the
    # pre-filled 0.
    return np.power(base, 1./3., out=np.zeros_like(base), where=base > 0)
|
|
|
|
|
def SAVI(data):
    """Soil Adjusted Vegetation Index.

    savi = (nir - red) / (nir + red + 0.428) * 1.428

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        red = data[..., 3]
        nir = data[..., 7]
    elif isinstance(data, xr.DataArray):
        red = data.sel(band="red").values
        nir = data.sel(band="nir").values
    else:
        raise TypeError("Unknown data format.")

    return (nir - red) / (nir + red + 0.428) * 1.428
|
|
|
|
|
def NDMI(data):
    """Normalized difference of the ``nir`` and ``swir_1`` bands.

    ndmi = (nir - swir_1) / (nir + swir_1)

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose denominator is 0
    yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        nir = data[..., 7]
        swir_1 = data[..., 10]
    elif isinstance(data, xr.DataArray):
        nir = data.sel(band="nir").values
        swir_1 = data.sel(band="swir_1").values
    else:
        raise TypeError("Unknown data format.")

    numerator = nir - swir_1
    denominator = nir + swir_1
    # The pre-filled zeros stay in place wherever the division is skipped.
    return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0.0)
|
|
|
|
|
def MSI(data):
    """Moisture Stress Index.

    msi = swir_1 / nir

    Moisture Stress Index is used for canopy stress analysis, productivity
    prediction and biophysical modeling. Interpretation of the MSI is inverted
    relative to other water vegetation indices; thus, higher values of the
    index indicate greater plant water stress and in inference, less soil
    moisture content. The values of this index range from 0 to more than 3 with
    the common range for green vegetation being 0.2 to 2.

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose ``nir`` value is
    0 yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        nir = data[..., 7]
        swir_1 = data[..., 10]
    elif isinstance(data, xr.DataArray):
        nir = data.sel(band="nir").values
        swir_1 = data.sel(band="swir_1").values
    else:
        raise TypeError("Unknown data format.")

    # The index is a band ratio (hence the non-negative value range described
    # above), not a band difference.
    return np.divide(swir_1, nir, out=np.zeros_like(swir_1), where=nir != 0.0)
|
|
|
|
|
def GCI(data):
    """Green Chlorophyll Index.

    gci = water_vapour / green - 1

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose ``green`` value
    is 0 yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        green = data[..., 2]
        water_vapour = data[..., 9]
    elif isinstance(data, xr.DataArray):
        green = data.sel(band="green").values
        water_vapour = data.sel(band="water_vapour").values
    else:
        raise TypeError("Unknown data format.")

    # water_vapour / green - 1 rewritten as a single guarded division so that
    # pixels with green == 0 come out as 0.
    # NOTE(review): the band choice (water_vapour rather than nir) is kept
    # from the original code - confirm it is intended.
    numerator = water_vapour - green
    return np.divide(numerator, green, out=np.zeros_like(numerator), where=green != 0.0)
|
|
|
|
|
def BSI(data):
    """Bare Soil Index.

    bsi = ((swir_1 + red) - (nir + blue)) / ((swir_1 + red) + (nir + blue))

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose denominator is 0
    yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        blue = data[..., 1]
        red = data[..., 3]
        nir = data[..., 7]
        swir_1 = data[..., 10]
    elif isinstance(data, xr.DataArray):
        blue = data.sel(band="blue").values
        red = data.sel(band="red").values
        nir = data.sel(band="nir").values
        swir_1 = data.sel(band="swir_1").values
    else:
        raise TypeError("Unknown data format.")

    swir_red = swir_1 + red
    nir_blue = nir + blue
    numerator = swir_red - nir_blue
    denominator = swir_red + nir_blue
    # The pre-filled zeros stay in place wherever the division is skipped.
    return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0.0)
|
|
|
|
|
def NDWI(data):
    """Normalized Difference Water Index.

    ndwi = (green - nir) / (green + nir)

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose denominator is 0
    yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        green = data[..., 2]
        nir = data[..., 7]
    elif isinstance(data, xr.DataArray):
        green = data.sel(band="green").values
        nir = data.sel(band="nir").values
    else:
        raise TypeError("Unknown data format.")

    numerator = green - nir
    denominator = green + nir
    # The pre-filled zeros stay in place wherever the division is skipped.
    return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0.0)
|
|
|
|
|
def NDSI(data):
    """Normalized Difference Snow Index.

    ndsi = (green - swir_1) / (green + swir_1)

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose denominator is 0
    yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        green = data[..., 2]
        swir_1 = data[..., 10]
    elif isinstance(data, xr.DataArray):
        green = data.sel(band="green").values
        swir_1 = data.sel(band="swir_1").values
    else:
        raise TypeError("Unknown data format.")

    numerator = green - swir_1
    denominator = green + swir_1
    # The pre-filled zeros stay in place wherever the division is skipped.
    return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0.0)
|
|
|
|
|
def NDGI(data):
    """Normalized difference of the ``green`` and ``red`` bands.

    ndgi = (green - red) / (green + red)

    ``data`` is either a numpy array with the bands on the last axis or an
    xarray DataArray with a ``band`` dimension. Pixels whose denominator is 0
    yield 0.

    Raises
    ------
    TypeError
        If ``data`` is neither a numpy array nor an xarray DataArray.
    """
    if isinstance(data, np.ndarray):
        green = data[..., 2]
        red = data[..., 3]
    elif isinstance(data, xr.DataArray):
        green = data.sel(band="green").values
        red = data.sel(band="red").values
    else:
        raise TypeError("Unknown data format.")

    numerator = green - red
    denominator = green + red
    # The pre-filled zeros stay in place wherever the division is skipped.
    return np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0.0)
|
|
|
|
|
class FiresDataset(torch.utils.data.Dataset):
    """Dataset serving image/mask pairs from an HDF5 file.

    Parameters
    ----------
    filename : path-like
        HDF5 file to read the scenes from.
    folds : sequence of int, optional
        Only groups whose ``fold`` attribute is contained in ``folds`` are used.
    channels : sequence of callables, optional
        Each callable receives the image array and returns one 2-D channel.
        Defaults to no channels.
    include_pre : bool, optional
        If True, pre-fire images in which more than 80 % of the values are
        greater than 0 are added as extra samples with an all-zero mask.
    transform : callable, optional
        Called as ``transform(image=..., mask=...)`` with a channels-last
        image; must return a mapping with the keys ``image`` and ``mask``.
    """

    def __init__(self, filename, folds=(0, 1, 2, 3, 4),
                 channels=None,
                 include_pre=False,
                 transform=None) -> None:
        self._filename = filename
        # The file handle is opened lazily on first item access so that the
        # dataset can be handed to DataLoader worker processes without
        # sharing (or pickling) an open h5py handle.
        self._fd = None
        # Avoid a shared mutable default argument.
        self._channels = [] if channels is None else channels
        self._transform = transform
        self._names = []

        # Set membership instead of a linear list scan per scene.
        # NOTE(review): wl() reads the module-level default data file, not
        # ``filename`` - confirm both always point to the same file.
        whitelist = set(wl())
        with h5py.File(filename, "r") as fd:
            for name in fd:
                group = fd[name]
                if group.attrs["fold"] not in folds:
                    continue
                if name in whitelist:
                    self._names.append((name, "post_fire"))
                # NOTE(review): the source's indentation was lost; this check
                # is assumed to be independent of the whitelist - confirm.
                if include_pre and "pre_fire" in group:
                    pre_image = group["pre_fire"][...]
                    if np.mean(pre_image > 0) > 0.8:
                        self._names.append((name, "pre_fire"))

    def __getstate__(self):
        """Drop the open file handle when the dataset is pickled."""
        state = self.__dict__.copy()
        state["_fd"] = None
        return state

    def _file(self):
        """Return the HDF5 handle, opening the file on first use."""
        if self._fd is None:
            self._fd = h5py.File(self._filename, "r")
        return self._fd

    def number_of_channels(self):
        """Return the number of configured channels."""
        return len(self._channels)

    def __getitem__(self, idx):
        """Return ``{"image": ..., "mask": ...}`` for sample ``idx``.

        The image is converted to float32, divided by 10000 and stacked
        channel-first from the configured channel callables. The mask gets a
        leading singleton axis.
        """
        name, state = self._names[idx]
        group = self._file()[name]
        data = group[state][...].astype("float32") / 10000.0
        if state == "pre_fire":
            # Pre-fire samples contain no burnt area; size the empty mask
            # after the image instead of assuming 512 x 512.
            mask = np.zeros(data.shape[:2], dtype="float32")
        else:
            mask = group["mask"][..., 0].astype("float32")

        image = np.stack([channel(data) for channel in self._channels])

        if self._transform:
            # The transform expects a channels-last image.
            image = image.transpose((1, 2, 0))
            xfrm = self._transform(image=image, mask=mask)
            image, mask = xfrm["image"], xfrm["mask"]
            logger.debug("Final tensor shape: %s", image.shape)

        return {"image": image, "mask": mask[None, :]}

    def __len__(self) -> int:
        return len(self._names)
|
|
|
|
|
class FireModel(pl.LightningModule):
    """Lightning module wrapping a binary segmentation network.

    Parameters
    ----------
    datafile : path-like
        HDF5 file handed to ``FiresDataset``.
    model : str
        One of ``"unet"``, ``"unetpp"``, ``"fpn"``, ``"dlv3"``, ``"dlv3p"``.
    encoder, encoder_depth, encoder_weights
        Forwarded to the segmentation_models_pytorch model constructor.
    loss : str
        ``"dice"`` or ``"bce"``.
    channels : sequence of callables
        Channel extractors; their count defines the number of input channels.
    train_transform : callable or None
        Transform applied to training samples.
    train_use_pre_fire : bool
        Whether pre-fire images are included in the training set.
    n_cpus : int
        Number of DataLoader workers.
    batch_size : int
        Batch size of the training and validation loaders.
    lr : float, optional
        Learning rate of the Adam optimizer.

    Raises
    ------
    ValueError
        For an unsupported model name, loss name or encoder depth.
    """

    def __init__(self,
                 datafile,
                 model,
                 encoder,
                 encoder_depth,
                 encoder_weights,
                 loss,
                 channels,
                 train_transform,
                 train_use_pre_fire,
                 n_cpus,
                 batch_size,
                 lr=0.00025,
                 **kwargs) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.datafile = datafile
        self.lr = lr
        self.channels = channels

        n_channels = len(channels)
        if model in ("unet", "unetpp"):
            architecture = smp.Unet if model == "unet" else smp.UnetPlusPlus
            decoder_channels = [2**(8 - d) for d in range(encoder_depth, 0, -1)]
            self.model = architecture(encoder_name=encoder, encoder_depth=encoder_depth, encoder_weights=encoder_weights,
                                      decoder_channels=decoder_channels,
                                      in_channels=n_channels, classes=1)
        elif model == "fpn":
            # Upsampling factor used for each supported encoder depth.
            upsampling_by_depth = {3: 1, 4: 2, 5: 4}
            if encoder_depth not in upsampling_by_depth:
                # Strings cannot be raised; raise a real exception so the
                # message actually reaches the user.
                raise ValueError(f"FPN: Unsupported encoder depth {encoder_depth}.")
            self.model = smp.FPN(encoder_name=encoder, encoder_weights=encoder_weights, encoder_depth=encoder_depth,
                                 upsampling=upsampling_by_depth[encoder_depth],
                                 in_channels=n_channels, classes=1)
        elif model == "dlv3":
            self.model = smp.DeepLabV3(encoder_name=encoder, encoder_weights=encoder_weights, encoder_depth=encoder_depth,
                                       in_channels=n_channels, classes=1)
        elif model == "dlv3p":
            if encoder_depth != 5:
                raise ValueError(f"Unsupported encoder depth {encoder_depth} for DeepLabV3+ (must be 5).")
            self.model = smp.DeepLabV3Plus(encoder_name=encoder, encoder_weights=encoder_weights, encoder_depth=encoder_depth,
                                           in_channels=n_channels, classes=1)
        else:
            raise ValueError(f"Unsupported model '{model}'.")

        if loss == "dice":
            self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
        elif loss == "bce":
            self.loss_fn = smp.losses.SoftBCEWithLogitsLoss()
        else:
            raise ValueError(f"Unsupported loss function '{loss}'.")

        self.train_transform = train_transform
        self.train_use_pre_fire = train_use_pre_fire
        self.n_cpus = n_cpus
        self.batch_size = batch_size

    def forward(self, image):
        """Return the raw (logit) mask predicted for ``image``."""
        return self.model(image)

    def shared_step(self, batch, stage):
        """Compute loss and IoU for ``batch`` and log them as ``<stage>_loss`` / ``<stage>_iou``."""
        image, mask = batch["image"], batch["mask"]

        logits_mask = self.forward(image)
        loss = self.loss_fn(logits_mask, mask)

        # Threshold the sigmoid probabilities at 0.5 to get a binary mask.
        prob_mask = logits_mask.sigmoid()
        pred_mask = (prob_mask > 0.5).long()
        # ``fn_`` avoids shadowing the module-level ``fn`` path.
        tp, fp, fn_, tn = smp.metrics.get_stats(pred_mask, mask.long(), mode="binary")
        iou = smp.metrics.iou_score(tp, fp, fn_, tn, reduction="micro-imagewise")

        self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(f"{stage}_iou", iou, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch, "train")

    def train_dataloader(self):
        """Build the shuffled training loader over folds 1-4."""
        train_ds = FiresDataset(self.datafile, folds=[1, 2, 3, 4],
                                channels=self.channels,
                                transform=self.train_transform,
                                include_pre=self.train_use_pre_fire)
        train_dl = torch.utils.data.DataLoader(train_ds,
                                               batch_size=self.batch_size,
                                               num_workers=self.n_cpus,
                                               shuffle=True,
                                               pin_memory=True,
                                               drop_last=False)
        return train_dl

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")

    def val_dataloader(self):
        """Build the validation loader over fold 0 (no transform, no pre-fire images)."""
        val_ds = FiresDataset(self.datafile, folds=[0],
                              channels=self.channels,
                              transform=None,
                              include_pre=False)
        val_dl = torch.utils.data.DataLoader(val_ds,
                                             batch_size=self.batch_size,
                                             num_workers=self.n_cpus,
                                             shuffle=False,
                                             pin_memory=True,
                                             drop_last=False)
        return val_dl

    def test_step(self, batch, batch_idx):
        return self.shared_step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
|
|
|
|
|
def main(accelerator,
         datafile,
         batch_size,
         channels,
         n_cpus,
         model,
         encoder,
         encoder_depth,
         encoder_weights,
         loss,
         train_use_pre_fire,
         train_use_augmentation,
         learning_rate,
         max_epochs=30,
         ):
    """Build a ``FireModel`` and train it.

    Parameters
    ----------
    accelerator : str
        Passed to ``pl.Trainer``.
    train_use_augmentation : bool
        If True, training samples are randomly flipped, transposed and
        rotated by multiples of 90 degrees and converted to tensors;
        otherwise no training transform is used.
    learning_rate : float
        Learning rate handed to the model.
    max_epochs : int, optional
        Number of training epochs (default 30).

    All remaining parameters are forwarded unchanged to ``FireModel``.
    """
    if train_use_augmentation:
        train_xfrm = A.Compose([
            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.Transpose(p=0.5),
            A.RandomRotate90(p=0.5),
            Atorch.ToTensorV2(),
        ])
    else:
        train_xfrm = None

    logger.info("Instantiating model.")
    mdl = FireModel(datafile=datafile,
                    model=model,
                    encoder=encoder,
                    encoder_depth=encoder_depth,
                    encoder_weights=encoder_weights,
                    loss=loss,
                    channels=channels,
                    n_cpus=n_cpus,
                    train_transform=train_xfrm,
                    train_use_pre_fire=train_use_pre_fire,
                    batch_size=batch_size,
                    lr=learning_rate)

    trainer = pl.Trainer(accelerator=accelerator, devices="auto",
                         log_every_n_steps=10, max_epochs=max_epochs, callbacks=[checkpoint_callback])

    logger.info("Start training.")
    trainer.fit(mdl)
|
|
|
|
|
# Maps the names accepted by the --channels command-line option to the
# callable that computes the corresponding channel from an image.
CHANNEL_MAP = {
    "band_1": band_1,
    "band_2": band_2,
    "band_3": band_3,
    "band_4": band_4,
    "band_5": band_5,
    "band_6": band_6,
    "band_7": band_7,
    "band_8": band_8,
    "band_8a": band_8a,
    "band_9": band_9,
    "band_11": band_11,
    "band_12": band_12,
    "nbr": NBR,
    "ndvi": NDVI,
    "gndvi": GNDVI,
    "evi": EVI,
    "avi": AVI,
    "savi": SAVI,
    "ndmi": NDMI,
    "msi": MSI,
    "gci": GCI,
    "bsi": BSI,
    "ndwi": NDWI,
    "ndsi": NDSI,
    "ndgi": NDGI,
}
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, configure logging and
    # torch, then start the training run.
    import argparse

    N_CPUS = int(os.getenv("SLURM_CPUS_PER_TASK", 1))
    LOGGING_MAP = {"info": logging.INFO, "debug": logging.DEBUG}

    parser = argparse.ArgumentParser("chabud.py")
    parser.add_argument("--accelerator", type=str, default="auto",
                        choices=["cpu", "gpu", "auto"])
    parser.add_argument("--datafile", type=Path, default=ds_path,
                        help="Location of data file used for training.")
    parser.add_argument("--n-cpus", type=int, default=N_CPUS,
                        help="Number of CPU cores to use.")
    parser.add_argument("--batch-size", type=int, default=2,
                        help="Training and validation batch size.")
    parser.add_argument("--learning-rate", type=float, default=0.00025,
                        help="Learning rate of optimizer.")
    parser.add_argument("--model", default="unet",
                        choices=["unet", "unetpp", "fpn", "dlv3", "dlv3p"],
                        help="Segmentation model")
    parser.add_argument("--encoder", default="resnet34",
                        choices=["resnet18", "resnet34", "resnet50", "vgg13", "dpn68", "dpn92", "timm-efficientnet-b0"],
                        help="Encoder of segmentation model")
    parser.add_argument("--encoder-depth", type=int, default=5,
                        help="Depth of encoder stage")
    parser.add_argument("--encoder-weights", default="imagenet",
                        choices=["random", "imagenet"],
                        help="Weight initialization for encoder")
    parser.add_argument("--loss", default="dice", choices=["dice", "bce"],
                        help="Loss function")
    parser.add_argument("--train-use-pre_fire", action="store_true",
                        help="Use pre_fire data for training?")
    parser.add_argument("--train-use-augmentation", action="store_true",
                        help="Use data augmentation in training step?")
    parser.add_argument("--channels", nargs="+", choices=CHANNEL_MAP.keys(),
                        default=["band_1", "band_2", "band_3", "band_4", "band_5", "band_6", "band_7", "band_8", "band_8a", "band_9", "band_11", "band_12"],
                        help="Channels to use for prediction")
    parser.add_argument("--log-level", type=str, default="info",
                        choices=["info", "debug"])

    args = parser.parse_args()

    logging.basicConfig(level=LOGGING_MAP[args.log_level],
                        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                        datefmt="%d-%b-%y %H:%M:%S")

    # "random" on the command line means: no pretrained encoder weights.
    if args.encoder_weights == "random":
        args.encoder_weights = None

    # Translate the selected channel names into their callables.
    logger.info(f"Selected channels: {args.channels}")
    channels = [CHANNEL_MAP[channel_name] for channel_name in args.channels]

    torch.set_num_threads(args.n_cpus)
    torch.set_float32_matmul_precision("medium")

    main(
        accelerator=args.accelerator,
        datafile=args.datafile,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        channels=channels,
        n_cpus=args.n_cpus,
        model=args.model,
        encoder=args.encoder,
        encoder_depth=args.encoder_depth,
        encoder_weights=args.encoder_weights,
        loss=args.loss,
        train_use_pre_fire=args.train_use_pre_fire,
        train_use_augmentation=args.train_use_augmentation,
    )
|
|
|
|