##
## chabud.py - Helper functions for the ChaBuD ECML Challenge 2023
##
## CHANGES:
## 2023-05-23: First version published
##
## TODO:
## * Function to save predictions as CSV for the leaderboard
## * Argument to control the number of training epochs (epoch, max_epoch, ... ?)
## * Export the final model and, if applicable, also save predictions on the validation data
##
import logging
import os
from pathlib import Path
import pandas as pd
import albumentations as A
import albumentations.pytorch.transforms as Atorch
import h5py
import numpy as np
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import torch
import xarray as xr
from pytorch_lightning.callbacks import ModelCheckpoint
fn = Path("A:/CodingProjekte/DataMining/src/train_eval.hdf5")
# Build a per-dataset statistics DataFrame (name, missing pre image, absolute
# and relative burnt area). wl() below turns this into a whitelist of the
# datasets whose relative burnt area is at least 0.25%.
def basic_df():
    res = []
    with h5py.File(fn, "r") as fd:
        for name, ds in fd.items():
            # By default a pre_fire image is assumed to be available.
            pre_miss = 0
            # post_fire and mask are present in every dataset, so direct
            # access to the mask is safe here (no whole datasets are missing).
            mask = ds["mask"][...]
            count_burnt_pixels = np.sum(mask)
            count_pixels = 512 * 512
            burnt_pixel_rel = count_burnt_pixels / count_pixels
            # Unlike mask and post_fire, "pre_fire" may be absent, so check
            # before accessing it to avoid a KeyError.
            if "pre_fire" not in ds:
                pre_miss = 1
            res.append({"name": name, "pre_missing": pre_miss, "burnt_pixel_abs": count_burnt_pixels, "burnt_pixel_rel": burnt_pixel_rel})
    return pd.DataFrame(res)
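# Example (sketch): basic_df() returns one row per dataset with the columns
# "name", "pre_missing" (0/1), "burnt_pixel_abs" and "burnt_pixel_rel", e.g.
#
#   df = basic_df()
#   df.sort_values("burnt_pixel_rel", ascending=False).head()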
def miss_dp_df():
    BANDS = ["coastal_aerosol", "blue", "green", "red",
             "veg_red_1", "veg_red_2", "veg_red_3", "nir",
             "veg_red_4", "water_vapour", "swir_1", "swir_2"]
    res = basic_df().to_dict("records")
    # miss_dp is a list of dicts with "name", "pre" and "post"; the pre/post
    # values are per-band counts of zero pixels, indexed by band name.
    miss_dp = []
    with h5py.File(fn, "r") as fd:
        for x in res:
            # Skip datasets with a missing pre image:
            # if x["pre_missing"] == 1:
            #     continue
            # Load the data from the original dataset.
            name = x["name"]
            ds = to_xarray(fd[name])
            pre = ds["pre"]
            post = ds["post"]
            if x["pre_missing"] == 1:
                # Case: 'pre_missing' is 1 - only the post image exists.
                post_miss = [int((post.isel(band=band) == 0).sum()) for band in range(post.sizes["band"])]
                x_post_miss = xr.DataArray(post_miss, dims=["band"], coords={"band": BANDS})
                miss_dp.append({"name": name, "pre": [], "post": x_post_miss.values})
            else:
                # Case: 'pre_missing' is not 1 - both images exist.
                pre_miss = [int((pre.isel(band=band) == 0).sum()) for band in range(pre.sizes["band"])]
                post_miss = [int((post.isel(band=band) == 0).sum()) for band in range(post.sizes["band"])]
                x_pre_miss = xr.DataArray(pre_miss, dims=["band"], coords={"band": BANDS})
                x_post_miss = xr.DataArray(post_miss, dims=["band"], coords={"band": BANDS})
                miss_dp.append({"name": name, "pre": x_pre_miss.values, "post": x_post_miss.values})
    return miss_dp
def wl():
    """Return the whitelist: names of all datasets whose burnt area is at least 0.25% of the image."""
    whitelist = []
    df = basic_df()
for index, row in df.iterrows():
if row["burnt_pixel_rel"] < 0.0025:
continue
whitelist.append(row["name"])
return whitelist
checkpoint_callback = ModelCheckpoint(
#dirpath='checkpoints/',
    filename='model-{epoch:02d}-{valid_iou:.2f}',
monitor='valid_iou',
mode='max',
save_top_k=3
)
__version__ = "1.0.0"
logger = logging.getLogger(__name__)
ds_path = "A:/CodingProjekte/DataMining/src/train_eval.hdf5"
def to_xarray(dataset, pretty_band_names=True):
    """Convert an HDF5 group containing pre- and post-fire images into xarray DataArrays.
    Parameters
    ----------
    dataset : h5py.Group
        HDF5 group containing the pre- and post-fire images, the mask and the metadata.
    pretty_band_names : bool, optional
        If True (default), use the "pretty" band names; otherwise the original MSI band numbers.
    Returns
    -------
    dict
        Dictionary with the xarray DataArrays for the pre- and post-fire images,
        the mask, and the fold information.
    """
if pretty_band_names:
BANDS = ["coastal_aerosol", "blue", "green", "red",
"veg_red_1", "veg_red_2", "veg_red_3", "nir",
"veg_red_4", "water_vapour", "swir_1", "swir_2"]
else:
BANDS = ["1", "2", "3", "4", "5", "6", "7", "8", "8a", "9", "11", "12"]
post = dataset["post_fire"][...].astype("float32") / 10000.0
    try:
        pre = dataset["pre_fire"][...].astype("float32") / 10000.0
    except KeyError:
        # No pre_fire image available: fall back to an all-zero image.
        pre = np.zeros_like(post, dtype="float32")
mask = dataset["mask"][..., 0]
return {"pre": xr.DataArray(pre, dims=["x", "y", "band"], coords={"x": range(512), "y": range(512), "band": BANDS}),
"post": xr.DataArray(post, dims=["x", "y", "band"], coords={"x": range(512), "y": range(512), "band": BANDS}),
"mask": xr.DataArray(mask, dims=["x", "y"], coords={"x": range(512), "y": range(512)}),
"fold": dataset.attrs["fold"]}
class BandExtractor:
    """Callable that extracts a single band: by index from a numpy array, by name from an xarray DataArray."""
    def __init__(self, index, name) -> None:
self.index = index
self.name = name
def __call__(self, data):
if isinstance(data, np.ndarray):
return data[..., self.index]
elif isinstance(data, xr.DataArray):
return data.sel(band=self.name).values
else:
msg = "Unknown data format."
raise Exception(msg)
def __repr__(self) -> str:
return f'BandExtractor({self.index}, "{self.name}")'
band_1 = BandExtractor(0, "coastal_aerosol")
band_2 = BandExtractor(1, "blue")
band_3 = BandExtractor(2, "green")
band_4 = BandExtractor(3, "red")
band_5 = BandExtractor(4, "veg_red_1")
band_6 = BandExtractor(5, "veg_red_2")
band_7 = BandExtractor(6, "veg_red_3")
band_8 = BandExtractor(7, "nir")
band_8a = BandExtractor(8, "veg_red_4")
band_9 = BandExtractor(9, "water_vapour")
band_11 = BandExtractor(10, "swir_1")
band_12 = BandExtractor(11, "swir_2")
def NBR(data):
"""Normalized Burn Ratio.
nbr = (nir - swir_2) / (nir + swir_2)
"""
if isinstance(data, np.ndarray):
nir = data[..., 7]
swir_2 = data[..., 11]
elif isinstance(data, xr.DataArray):
nir = data.sel(band="nir").values
swir_2 = data.sel(band="swir_2").values
else:
msg = "Unknown data format."
raise Exception(msg)
zaehler = nir - swir_2
nenner = nir + swir_2
return np.divide(zaehler, nenner, out=np.zeros_like(zaehler), where=nenner != 0.0)
def NDVI(data):
"""Normalized Difference Vegetation Index."""
if isinstance(data, np.ndarray):
red = data[..., 3]
nir = data[..., 7]
elif isinstance(data, xr.DataArray):
red = data.sel(band="red").values
nir = data.sel(band="nir").values
else:
msg = "Unknown data format."
raise Exception(msg)
zaehler = nir - red
nenner = nir + red
return np.divide(zaehler, nenner, out=np.zeros_like(zaehler), where=nenner != 0.0)
def GNDVI(data):
    """Green Normalized Difference Vegetation Index.
    gndvi = (nir - green) / (nir + green)
    """
    if isinstance(data, np.ndarray):
        green = data[..., 2]
        nir = data[..., 7]
    elif isinstance(data, xr.DataArray):
        green = data.sel(band="green").values
        nir = data.sel(band="nir").values
    else:
        msg = "Unknown data format."
        raise Exception(msg)
    zaehler = nir - green
    nenner = nir + green
    return np.divide(zaehler, nenner, out=np.zeros_like(zaehler), where=nenner != 0.0)
def EVI(data):
"""Enhanced Vegetation Index."""
if isinstance(data, np.ndarray):
blue = data[..., 1]
red = data[..., 3]
nir = data[..., 7]
elif isinstance(data, xr.DataArray):
blue = data.sel(band="blue").values
red = data.sel(band="red").values
nir = data.sel(band="nir").values
else:
msg = "Unknown data format."
raise Exception(msg)
    zaehler = 2.5 * (nir - red)
nenner = nir + 6 * red - 7.5 * blue + 1
return np.divide(zaehler, nenner, out=np.zeros_like(zaehler), where=nenner != 0.0)
def AVI(data):
"""Advanced Vegetation Index."""
if isinstance(data, np.ndarray):
red = data[..., 3]
nir = data[..., 7]
elif isinstance(data, xr.DataArray):
red = data.sel(band="red").values
nir = data.sel(band="nir").values
else:
msg = "Unknown data format."
raise Exception(msg)
base = nir * (1 - red) * (nir - red)
## FIXME: Deal with cube roots of negative values?
return np.power(base, 1./3., out=np.zeros_like(base), where=base>0)
def SAVI(data):
"""Soil Adjusted Vegetation Index."""
if isinstance(data, np.ndarray):
red = data[..., 3]
nir = data[..., 7]
elif isinstance(data, xr.DataArray):
red = data.sel(band="red").values
nir = data.sel(band="nir").values
else:
msg = "Unknown data format."
raise Exception(msg)
return (nir - red) / (nir + red + 0.428) * 1.428
def NDMI(data):
    """Normalized Difference Moisture Index.
    ndmi = (nir - swir_1) / (nir + swir_1)
    """
if isinstance(data, np.ndarray):
nir = data[..., 7]
swir_1 = data[..., 10]
elif isinstance(data, xr.DataArray):
nir = data.sel(band="nir").values
swir_1 = data.sel(band="swir_1").values
else:
msg = "Unknown data format."
raise Exception(msg)
zaehler = nir - swir_1
nenner = nir + swir_1
return np.divide(zaehler, nenner, out=np.zeros_like(zaehler), where=nenner != 0.0)
def MSI(data):
    """Moisture Stress Index.
    msi = swir_1 / nir
    The Moisture Stress Index is used for canopy stress analysis, productivity
    prediction and biophysical modeling. Its interpretation is inverted
    relative to other water vegetation indices: higher values indicate greater
    plant water stress and, by inference, lower soil moisture content. Values
    range from 0 to more than 3, with the common range for green vegetation
    being 0.2 to 2.
    """
if isinstance(data, np.ndarray):
nir = data[..., 7]
swir_1 = data[..., 10]
elif isinstance(data, xr.DataArray):
nir = data.sel(band="nir").values
swir_1 = data.sel(band="swir_1").values
else:
msg = "Unknown data format."
raise Exception(msg)
    return np.divide(swir_1, nir, out=np.zeros_like(swir_1), where=nir != 0.0)
def GCI(data):
"""Green Chlorophyll Index."""
if isinstance(data, np.ndarray):
green = data[..., 2]
water_vapour = data[..., 9]
elif isinstance(data, xr.DataArray):
green = data.sel(band="green").values
water_vapour = data.sel(band="water_vapour").values
else:
msg = "Unknown data format."
raise Exception(msg)
    # GCI as a band ratio (Sentinel-2 B9/B3 formulation): water_vapour / green - 1.
    return np.divide(water_vapour, green, out=np.zeros_like(green), where=green != 0.0) - 1.0
def BSI(data):
"""Bare Soil Index."""
if isinstance(data, np.ndarray):
blue = data[..., 1]
red = data[..., 3]
nir = data[..., 7]
swir_1 = data[..., 10]
elif isinstance(data, xr.DataArray):
blue = data.sel(band="blue").values
red = data.sel(band="red").values
nir = data.sel(band="nir").values
swir_1 = data.sel(band="swir_1").values
else:
msg = "Unknown data format."
raise Exception(msg)
swir_red = swir_1 + red
nir_blue = nir + blue
zaehler = swir_red - nir_blue
nenner = swir_red + nir_blue
return np.divide(zaehler, nenner, out=np.zeros_like(zaehler), where=nenner != 0.0)
def NDWI(data):
"""Normalized Difference Water Index."""
if isinstance(data, np.ndarray):
green = data[..., 2]
nir = data[..., 7]
elif isinstance(data, xr.DataArray):
green = data.sel(band="green").values
nir = data.sel(band="nir").values
else:
msg = "Unknown data format."
raise Exception(msg)
zaehler = green - nir
nenner = green + nir
return np.divide(zaehler, nenner, out=np.zeros_like(zaehler), where=nenner != 0.0)
def NDSI(data):
"""Normalized Difference Snow Index."""
if isinstance(data, np.ndarray):
green = data[..., 2]
swir_1 = data[..., 10]
elif isinstance(data, xr.DataArray):
green = data.sel(band="green").values
swir_1 = data.sel(band="swir_1").values
else:
msg = "Unknown data format."
raise Exception(msg)
zaehler = green - swir_1
nenner = green + swir_1
return np.divide(zaehler, nenner, out=np.zeros_like(zaehler), where=nenner != 0.0)
def NDGI(data):
    """NDGI: ndgi = (green - red) / (green + red)."""
if isinstance(data, np.ndarray):
green = data[..., 2]
red = data[..., 3]
elif isinstance(data, xr.DataArray):
green = data.sel(band="green").values
red = data.sel(band="red").values
else:
msg = "Unknown data format."
raise Exception(msg)
zaehler = green - red
nenner = green + red
return np.divide(zaehler, nenner, out=np.zeros_like(zaehler), where=nenner != 0.0)
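# Quick sanity check for the index functions (sketch; assumes the HDF5 file
# referenced by `fn` exists):
#
#   with h5py.File(fn, "r") as fd:
#       sample = to_xarray(next(iter(fd.values())))
#       assert NBR(sample["post"]).shape == (512, 512)
#       assert NDVI(sample["post"]).shape == (512, 512)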
# The bands come in as channels (one extractor callable per channel).
class FiresDataset(torch.utils.data.Dataset):
    def __init__(self, filename, folds=(0, 1, 2, 3, 4),
                 channels=(),
                 include_pre=False,
                 transform=None) -> None:
self._filename = filename
self._fd = h5py.File(filename, "r")
self._channels = channels
self._transform = transform
self._names = []
whitelist = wl()
for name in self._fd:
if self._fd[name].attrs["fold"] not in folds:
continue
if name in whitelist:
self._names.append((name, "post_fire"))
if include_pre and "pre_fire" in self._fd[name]:
pre_image = self._fd[name]["pre_fire"][...]
# Include only "real" pre_fire images
if np.mean(pre_image > 0) > 0.8:
self._names.append((name, "pre_fire"))
def number_of_channels(self):
return len(self._channels)
def __getitem__(self, idx):
name, state = self._names[idx]
data = self._fd[name][state][...].astype("float32") / 10000.0
if state == "pre_fire":
mask = np.zeros((512, 512), dtype="float32")
else:
mask = self._fd[name]["mask"][..., 0].astype("float32")
channels = []
for channel in self._channels:
channels.append(channel(data))
# Stack indices into a new image in CHW format.
image = np.stack(channels)
if self._transform:
# Transpose image so we get HWC instead of CHW format.
# Transform is responsible for transposing back as required by PyTorch.
image = image.transpose((1, 2, 0))
xfrm = self._transform(image=image, mask=mask)
image, mask = xfrm["image"], xfrm["mask"]
logger.debug("Final tensor shape: %s", image.shape)
return {"image": image, "mask": mask[None, :]}
def __len__(self) -> int:
return len(self._names)
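# Example usage (sketch; the channel extractors chosen here are illustrative):
#
#   train_ds = FiresDataset(ds_path, folds=(1, 2, 3, 4),
#                           channels=[band_4, band_8, NBR],
#                           transform=None)
#   sample = train_ds[0]  # {"image": ndarray (3, 512, 512), "mask": ndarray (1, 512, 512)}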
class FireModel(pl.LightningModule):
def __init__(self,
datafile,
model,
encoder,
encoder_depth,
encoder_weights,
loss,
channels,
train_transform,
train_use_pre_fire,
n_cpus,
batch_size,
lr=0.00025,
**kwargs) -> None:
super().__init__()
self.save_hyperparameters()
self.datafile = datafile
self.lr = lr
self.channels = channels
if model == "unet":
decoder_channels = [2**(8 - d) for d in range(encoder_depth, 0, -1)]
self.model = smp.Unet(encoder_name=encoder, encoder_depth=encoder_depth, encoder_weights=encoder_weights,
decoder_channels=decoder_channels,
in_channels=len(channels), classes=1)
elif model == "unetpp":
decoder_channels = [2**(8 - d) for d in range(encoder_depth, 0, -1)]
self.model = smp.UnetPlusPlus(encoder_name=encoder, encoder_depth=encoder_depth, encoder_weights=encoder_weights,
decoder_channels=decoder_channels,
in_channels=len(channels), classes=1)
elif model == "fpn":
if encoder_depth == 3:
upsampling = 1
elif encoder_depth == 4:
upsampling = 2
elif encoder_depth == 5:
upsampling = 4
else:
raise "FPN: Unsupported encoder depth {encoder_depth}."
self.model = smp.FPN(encoder_name=encoder, encoder_weights=encoder_weights, encoder_depth=encoder_depth,
upsampling=upsampling,
in_channels=len(channels), classes=1)
elif model == "dlv3":
self.model = smp.DeepLabV3(encoder_name=encoder, encoder_weights=encoder_weights, encoder_depth=encoder_depth,
in_channels=len(channels), classes=1)
elif model == "dlv3p":
if encoder_depth != 5:
raise f"Unsupported encoder depth {encoder_depth} for DeepLabV3+ (must be 5)."
self.model = smp.DeepLabV3Plus(encoder_name=encoder, encoder_weights=encoder_weights, encoder_depth=encoder_depth,
in_channels=len(channels), classes=1)
else:
raise f"Unsupported model '{model}'."
if loss == "dice":
self.loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
elif loss == "bce":
self.loss_fn = smp.losses.SoftBCEWithLogitsLoss()
else:
raise f"Unsupported loss function '{loss}'."
self.train_transform = train_transform
self.train_use_pre_fire = train_use_pre_fire
self.n_cpus = n_cpus
self.batch_size = batch_size
def forward(self, image):
mask = self.model(image)
return mask
def shared_step(self, batch, stage):
image, mask = batch["image"], batch["mask"]
logits_mask = self.forward(image)
loss = self.loss_fn(logits_mask, mask)
prob_mask = logits_mask.sigmoid()
pred_mask = (prob_mask > 0.5).long()
tp, fp, fn, tn = smp.metrics.get_stats(pred_mask, mask.long(), mode="binary")
iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
self.log(f"{stage}_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
self.log(f"{stage}_iou", iou, on_step=False, on_epoch=True, prog_bar=True, logger=True)
return loss
def training_step(self, batch, batch_idx):
return self.shared_step(batch, "train")
def train_dataloader(self):
train_ds = FiresDataset(self.datafile, folds=[1, 2, 3, 4],
channels=self.channels,
transform=self.train_transform,
include_pre=self.train_use_pre_fire)
train_dl = torch.utils.data.DataLoader(train_ds,
batch_size=self.batch_size,
num_workers=self.n_cpus,
shuffle=True,
pin_memory=True,
drop_last=False)
return train_dl
def validation_step(self, batch, batch_idx):
return self.shared_step(batch, "valid")
def val_dataloader(self):
val_ds = FiresDataset(self.datafile, folds=[0],
channels=self.channels,
transform=None,
include_pre=False)
val_dl = torch.utils.data.DataLoader(val_ds,
batch_size=self.batch_size,
num_workers=self.n_cpus,
shuffle=False,
pin_memory=True,
drop_last=False)
return val_dl
def test_step(self, batch, batch_idx):
return self.shared_step(batch, "test")
def configure_optimizers(self):
        # TODO: Can we do better? We should probably implement a learning rate schedule.
return torch.optim.Adam(self.parameters(), lr=self.lr)
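    # A possible extension (sketch, not enabled): pair Adam with a cosine
    # learning-rate schedule, as suggested by the TODO above. T_max would need
    # to track the trainer's max_epochs (30 in main() below).
    #
    #   def configure_optimizers(self):
    #       optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
    #       scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)
    #       return {"optimizer": optimizer, "lr_scheduler": scheduler}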
def main(accelerator,
datafile,
batch_size,
channels,
n_cpus,
model,
encoder,
encoder_depth,
encoder_weights,
loss,
train_use_pre_fire,
train_use_augmentation,
learning_rate,
):
if train_use_augmentation:
train_xfrm = A.Compose([
A.VerticalFlip(p=0.5),
A.HorizontalFlip(p=0.5),
A.Transpose(p=0.5),
A.RandomRotate90(p=0.5),
Atorch.ToTensorV2(),
])
else:
train_xfrm = None
logger.info("Instantiating model.")
mdl = FireModel(datafile=datafile,
model=model,
encoder=encoder,
encoder_depth=encoder_depth,
encoder_weights=encoder_weights,
loss=loss,
channels=channels,
n_cpus=n_cpus,
train_transform=train_xfrm,
train_use_pre_fire=train_use_pre_fire,
batch_size=batch_size,
lr=learning_rate)
trainer = pl.Trainer(accelerator=accelerator, devices="auto",
log_every_n_steps=10, max_epochs=30, callbacks=[checkpoint_callback])
logger.info("Start training.")
trainer.fit(mdl)
CHANNEL_MAP = {
"band_1": band_1,
"band_2": band_2,
"band_3": band_3,
"band_4": band_4,
"band_5": band_5,
"band_6": band_6,
"band_7": band_7,
"band_8": band_8,
"band_8a": band_8a,
"band_9": band_9,
"band_11": band_11,
"band_12": band_12,
"nbr": NBR,
"ndvi": NDVI,
"gndvi": GNDVI,
"evi": EVI,
"avi": AVI,
"savi": SAVI,
"ndmi": NDMI,
"msi": MSI,
"gci": GCI,
"bsi": BSI,
"ndwi": NDWI,
"ndsi": NDSI,
"ndgi": NDGI,
}
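# Example invocation (sketch; hyperparameters are illustrative):
#
#   python chabud.py --model unet --encoder resnet34 --encoder-depth 5 \
#       --loss dice --channels band_4 band_8 nbr ndvi --train-use-augmentation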
if __name__ == "__main__":
import argparse # Only import when needed
N_CPUS = int(os.getenv("SLURM_CPUS_PER_TASK", 1))
parser = argparse.ArgumentParser("chabud.py")
parser.add_argument("--accelerator", type=str, choices=["cpu", "gpu", "auto"], default="auto")
parser.add_argument("--datafile", type=Path, default=ds_path,
help="Location of data file used for training.")
parser.add_argument("--n-cpus", type=int, default=N_CPUS, help="Number of CPU cores to use.")
parser.add_argument("--batch-size", type=int, default=2,
help="Training and validation batch size.")
parser.add_argument("--learning-rate", type=float, default=0.00025,
help="Learning rate of optimizer.")
parser.add_argument("--model", choices=["unet", "unetpp", "fpn", "dlv3", "dlv3p"], default="unet",
help="Segmentation model")
parser.add_argument("--encoder", choices=["resnet18", "resnet34", "resnet50", "vgg13", "dpn68", "dpn92", "timm-efficientnet-b0"], default="resnet34",
help="Encoder of segmentation model")
parser.add_argument("--encoder-depth", type=int, default=5,
help="Depth of encoder stage")
parser.add_argument("--encoder-weights", choices=["random", "imagenet"], default="imagenet",
help="Weight initialization for encoder")
parser.add_argument("--loss", choices=["dice", "bce"], default="dice",
help="Loss function")
parser.add_argument("--train-use-pre_fire", action="store_true",
help="Use pre_fire data for training?")
parser.add_argument("--train-use-augmentation", action="store_true",
help="Use data augmentation in training step?")
parser.add_argument("--channels", nargs="+", choices=CHANNEL_MAP.keys(),
default=["band_1", "band_2", "band_3", "band_4", "band_5", "band_6", "band_7", "band_8", "band_8a", "band_9", "band_11", "band_12"],
help="Channels to use for prediction")
parser.add_argument("--log-level", type=str, choices=["info", "debug"], default="info")
args = parser.parse_args()
LOGGING_MAP = {"info": logging.INFO, "debug": logging.DEBUG}
logging.basicConfig(level=LOGGING_MAP[args.log_level],
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%d-%b-%y %H:%M:%S")
if args.encoder_weights == "random":
args.encoder_weights = None
    # Translate channel names to the functions that compute each channel / index.
logger.info(f"Selected channels: {args.channels}")
channels = []
for channel in args.channels:
channels.append(CHANNEL_MAP[channel])
torch.set_num_threads(args.n_cpus)
torch.set_float32_matmul_precision("medium")
main(accelerator=args.accelerator,
datafile=args.datafile,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
channels=channels,
n_cpus=args.n_cpus,
model=args.model,
encoder=args.encoder,
encoder_depth=args.encoder_depth,
encoder_weights=args.encoder_weights,
loss=args.loss,
train_use_pre_fire=args.train_use_pre_fire,
train_use_augmentation=args.train_use_augmentation)