import cv2
import torch
import yaml
import numpy as np
import albumentations as A
from torch.utils.data import Dataset


def get_train_augs(IMAGE_SIZE):
    # Training-time augmentations: resize to a square, then random flips.
    return A.Compose([
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
    ])


def get_valid_augs(IMAGE_SIZE):
    # Validation-time augmentations: deterministic resize only.
    return A.Compose([
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    ])
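

# Illustrative usage sketch (not called anywhere in this module): how the
# Albumentations pipelines above are applied to an image/mask pair. The array
# shapes and the size of 256 are placeholder assumptions.
def _example_apply_augs():
    dummy_image = np.zeros((512, 512, 3), dtype=np.uint8)  # (h, w, c) RGB image
    dummy_mask = np.zeros((512, 512, 1), dtype=np.uint8)   # (h, w, 1) binary mask
    augs = get_train_augs(256)
    data = augs(image=dummy_image, mask=dummy_mask)         # returns a dict
    return data['image'].shape, data['mask'].shape          # (256, 256, 3), (256, 256, 1)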


def train_fn(data_loader, model, optimizer, DEVICE):
    # Run one training epoch and return the mean loss over all batches.
    model.train()
    total_loss = 0.0
    for images, masks in data_loader:
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)
        optimizer.zero_grad()
        logits, loss = model(images, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)


def eval_fn(data_loader, model, DEVICE):
    # Evaluate for one epoch without gradient tracking; return the mean loss.
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, masks in data_loader:
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)
            logits, loss = model(images, masks)
            total_loss += loss.item()
    return total_loss / len(data_loader)
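

# Illustrative sketch (not called anywhere in this module): a typical epoch loop
# built on train_fn / eval_fn. The model is assumed to follow this project's
# convention of returning (logits, loss) when called as model(images, masks);
# the optimizer choice, learning rate, epoch count, and checkpoint path are
# placeholder assumptions.
def _example_epoch_loop(model, train_loader, valid_loader, device, epochs=10):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    best_valid_loss = float('inf')
    for epoch in range(epochs):
        train_loss = train_fn(train_loader, model, optimizer, device)
        valid_loss = eval_fn(valid_loader, model, device)
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'best_model.pt')  # keep the best checkpoint
        print(f'Epoch {epoch + 1}: train loss {train_loss:.4f}, valid loss {valid_loss:.4f}')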


def load_config():
    # Read hyperparameters from config/config.yaml (relative to the working directory).
    config_file = 'config/config.yaml'
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)
    return config
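

# Illustrative sketch: reading values out of the loaded config. The key names
# (IMAGE_SIZE, BATCH_SIZE) and their defaults are placeholder assumptions, not
# the project's actual schema.
def _example_read_config():
    config = load_config()
    image_size = config.get('IMAGE_SIZE', 256)
    batch_size = config.get('BATCH_SIZE', 16)
    return image_size, batch_size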


class SegmentationDataset(Dataset):
    # Expects a DataFrame with 'images' and 'masks' columns holding file paths.
    def __init__(self, df, augmentations):
        self.df = df
        self.augmentations = augmentations

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row.images
        mask_path = row.masks

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)       # (h, w, c)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)   # (h, w)
        # Resize the mask to the same dimensions as the image
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
        mask = np.expand_dims(mask, axis=-1)                  # (h, w, 1)

        if self.augmentations:
            data = self.augmentations(image=image, mask=mask)
            image = data['image']
            mask = data['mask']

        # (h, w, c) -> (c, h, w)
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        mask = np.transpose(mask, (2, 0, 1)).astype(np.float32)

        # Scale to [0, 1]; round the mask back to hard 0/1 labels.
        image = torch.Tensor(image) / 255.0
        mask = torch.round(torch.Tensor(mask) / 255.0)
        return image, mask
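

# Illustrative sketch (not called anywhere in this module): wiring the pieces
# together. The pandas import, CSV path, train/valid split, image size, and batch
# size are placeholder assumptions; the only hard requirement from
# SegmentationDataset is a DataFrame with 'images' and 'masks' path columns.
def _example_build_loaders():
    import pandas as pd
    from torch.utils.data import DataLoader

    df = pd.read_csv('train.csv')                      # assumed columns: 'images', 'masks'
    train_df = df.sample(frac=0.8, random_state=42)
    valid_df = df.drop(train_df.index)

    train_ds = SegmentationDataset(train_df, get_train_augs(256))
    valid_ds = SegmentationDataset(valid_df, get_valid_augs(256))

    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=16, shuffle=False)
    return train_loader, valid_loader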