import configargparse
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.utils as tvutils
import torchvision.transforms

from video3d.utils.segmentation_transforms import *
from video3d.utils.misc import setup_runtime
from video3d import networks
from video3d.trainer import Trainer
from video3d.dataloaders import SegmentationDataset
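
# Binary-mask segmentation model: a single EDDeconv network (3-channel image in, 1-channel
# mask logits out) trained with BCE-with-logits on SegmentationDataset frames and driven by
# the generic video3d Trainer.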
class Segmentation:
    def __init__(self, cfgs, _):
        self.cfgs = cfgs
        self.device = cfgs.get('device', 'cpu')
        self.total_loss = None
        self.net = networks.EDDeconv(cin=3, cout=1, zdim=128, nf=64, activation=None)
        self.optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.net.parameters()),
            lr=cfgs.get('lr', 1e-4),
            betas=(0.9, 0.999),
            weight_decay=5e-4)
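
    # Checkpoint hooks used by the Trainer to restore and export network and optimizer state.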
    def load_model_state(self, cp):
        self.net.load_state_dict(cp["net"])
    def load_optimizer_state(self, cp):
        self.optimizer.load_state_dict(cp["optimizer"])
    def get_data_loaders(self, cfgs):
        batch_size = cfgs.get('batch_size', 64)
        num_workers = cfgs.get('num_workers', 4)
        data_dir = cfgs.get('data_dir', './data')
        img_size = cfgs.get('image_size', 64)
        min_size = int(img_size * cfgs.get('aug_min_resize', 0.5))
        max_size = int(img_size * cfgs.get('aug_max_resize', 2.0))

        transform = Compose([RandomResize(min_size, max_size),
                             RandomHorizontalFlip(cfgs.get("aug_horizontal_flip", 0.4)),
                             RandomCrop(img_size),
                             ImageOnly(torchvision.transforms.ColorJitter(**cfgs.get("aug_color_jitter", {}))),
                             ImageOnly(torchvision.transforms.RandomGrayscale(cfgs.get("aug_grayscale", 0.2))),
                             ToTensor()])
        train_loader = torch.utils.data.DataLoader(
            SegmentationDataset(data_dir, is_validation=False, transform=transform, sequence_range=(0, 0.5)),
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True
        )

        transform = Compose([ToTensor()])
        val_loader = torch.utils.data.DataLoader(
            SegmentationDataset(data_dir, is_validation=True, transform=transform, sequence_range=(0.5, 1.0)),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True
        )
        return train_loader, val_loader, None
    def get_state_dict(self):
        return {
            "net": self.net.state_dict(),
            "optimizer": self.optimizer.state_dict()
        }

    def to(self, device):
        self.device = device
        self.net.to(device)

    def set_train(self):
        self.net.train()

    def set_eval(self):
        self.net.eval()

    def backward(self):
        self.optimizer.zero_grad()
        self.total_loss.backward()
        self.optimizer.step()
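
    # One forward pass: images come in as [0, 1] tensors and are rescaled to [-1, 1]; the loss
    # is binary cross-entropy with logits against the single-channel mask target. When
    # visualize=True, image grids for input, target and predicted mask are also returned.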
    def forward(self, batch, visualize=False):
        image, target = batch
        image = image.to(self.device) * 2 - 1
        target = target[:, 0, :, :].to(self.device).unsqueeze(1)
        pred = self.net(image)
        self.total_loss = nn.functional.binary_cross_entropy_with_logits(pred, target)
        metrics = {'loss': self.total_loss}
        visuals = {}
        if visualize:
            visuals['rgb'] = self.image_visual(image, normalize=True, range=(-1, 1))
            visuals['target'] = self.image_visual(target, normalize=True, range=(0, 1))
            visuals['pred'] = self.image_visual(torch.sigmoid(pred), normalize=True, range=(0, 1))
            return metrics, visuals
        return metrics
    def visualize(self, logger, total_iter, max_bs=25):
        pass

    def save_results(self, save_dir):
        pass

    def save_scores(self, path):
        pass
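
    # Tile a batch into a roughly square image grid and convert it to a uint8 HxWx3 CPU tensor
    # suitable for logging; single-channel inputs are repeated to 3 channels first.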
    def image_visual(self, tensor, **kwargs):
        if tensor.shape[1] == 1:
            tensor = tensor.repeat(1, 3, 1, 1)
        n = int(tensor.shape[0]**0.5 + 0.5)
        tensor = tvutils.make_grid(tensor.detach(), nrow=n, **kwargs).permute(1, 2, 0)
        return torch.clamp(tensor[:, :, :3] * 255, 0, 255).byte().cpu()
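
# Entry point: parse the runtime config and hand this model class to the shared Trainer loop.
# Example invocation (script name assumed from the default config path):
#   python train_segmentation.py --config config/train_segmentation.yml --gpu 0 --seed 0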
if __name__ == "__main__": | |
parser = configargparse.ArgumentParser(description='Training configurations.') | |
parser.add_argument('--config', default="config/train_segmentation.yml", type=str, is_config_file=True, | |
help='Specify a config file path') | |
parser.add_argument('--gpu', default=1, type=int, help='Specify a GPU device') | |
parser.add_argument('--seed', default=0, type=int, help='Specify a random seed') | |
args, _ = parser.parse_known_args() | |
cfgs = setup_runtime(args) | |
trainer = Trainer(cfgs, Segmentation) | |
trainer.train() | |