Spaces:
Running
on
T4
Running
on
T4
from abc import abstractmethod | |
import torchvision.transforms as transforms | |
from datasets import augmentations | |
class TransformsConfig(object):
    """Base class for task-specific image transform configurations.

    Subclasses keep the options object on ``self.opts`` and implement
    ``get_transforms`` to return a dict that maps pipeline names
    (e.g. ``'transform_gt_train'``) to transform pipelines or ``None``.
    """

    def __init__(self, opts):
        """Store the options object for later use by ``get_transforms``.

        Args:
            opts: options object; subclasses read attributes from it
                (e.g. ``label_nc``, ``resize_factors``).
        """
        self.opts = opts

    @abstractmethod
    def get_transforms(self):
        """Return a dict mapping pipeline names to transform pipelines.

        Marked ``@abstractmethod`` to document that subclasses must override
        it. This class does not use ``ABCMeta``, so the marker is not enforced
        at instantiation time; the base implementation simply returns ``None``.
        """
        pass
class EncodeTransforms(TransformsConfig):
    """Transform set for the encoding task at 320x320 resolution.

    Every pipeline resizes to 320x320, converts to a tensor and normalizes
    each of the three channels with mean 0.5 and std 0.5. Only the training
    target pipeline adds a random horizontal flip (p=0.5). No source
    transform is defined, so ``transform_source`` is ``None``.
    """

    def __init__(self, opts):
        super().__init__(opts)

    def get_transforms(self):
        """Return the dict of named transform pipelines for this task."""

        def build(flip):
            # A fresh Compose (and fresh leaf transforms) per pipeline.
            steps = [transforms.Resize((320, 320))]
            if flip:
                steps.append(transforms.RandomHorizontalFlip(0.5))
            steps.append(transforms.ToTensor())
            steps.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
            return transforms.Compose(steps)

        result = {}
        result['transform_gt_train'] = build(flip=True)
        result['transform_source'] = None
        result['transform_test'] = build(flip=False)
        result['transform_inference'] = build(flip=False)
        return result
class FrontalizationTransforms(TransformsConfig):
    """Transform set for the frontalization task at 256x256 resolution.

    The training target and source pipelines each include their own
    ``RandomHorizontalFlip(0.5)``, so the two are flipped by independent
    random draws. Test and inference pipelines only resize, convert to a
    tensor and normalize with mean 0.5 / std 0.5 per channel.
    """

    def __init__(self, opts):
        super().__init__(opts)

    def get_transforms(self):
        """Return the dict of named transform pipelines for this task."""
        size = (256, 256)

        def augmented():
            return transforms.Compose([
                transforms.Resize(size),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])

        def deterministic():
            return transforms.Compose([
                transforms.Resize(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])

        return {
            'transform_gt_train': augmented(),
            'transform_source': augmented(),
            'transform_test': deterministic(),
            'transform_inference': deterministic(),
        }
class SketchToImageTransforms(TransformsConfig):
    """Transform set for sketch-to-image translation at 320x320 resolution.

    Image pipelines (``transform_gt_train``, ``transform_test``) resize,
    convert to a tensor and normalize with mean 0.5 / std 0.5 per channel.
    Sketch pipelines (``transform_source``, ``transform_inference``) only
    resize and convert to a tensor, without normalization.
    """

    def __init__(self, opts):
        super().__init__(opts)

    def get_transforms(self):
        """Return the dict of named transform pipelines for this task."""

        def pipeline(normalize):
            ops = [transforms.Resize((320, 320)), transforms.ToTensor()]
            if normalize:
                ops.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
            return transforms.Compose(ops)

        # (pipeline name, whether it is normalized), in output order.
        layout = (
            ('transform_gt_train', True),
            ('transform_source', False),
            ('transform_test', True),
            ('transform_inference', False),
        )
        return {name: pipeline(normalize) for name, normalize in layout}
class SegToImageTransforms(TransformsConfig):
    """Transform set for segmentation-map-to-image synthesis at 320x320.

    Image pipelines (``transform_gt_train``, ``transform_test``) resize to
    320x320, convert to a tensor and normalize with mean 0.5 / std 0.5 per
    channel. Label-map pipelines (``transform_source``,
    ``transform_inference``) resize to 320x320 with nearest-neighbour
    interpolation, one-hot encode via ``augmentations.ToOneHot(opts.label_nc)``
    and convert to a tensor, without normalization.
    """

    def __init__(self, opts):
        super(SegToImageTransforms, self).__init__(opts)

    def get_transforms(self):
        """Return the dict of named transform pipelines for this task."""
        # Label maps hold discrete class ids. Resize's default (bilinear)
        # interpolation blends neighbouring ids along region borders, which
        # can yield intermediate ids (e.g. 5 between classes 3 and 7) before
        # ToOneHot runs. This assumes ToOneHot treats pixel values as class
        # ids; confirm against datasets.augmentations. Nearest-neighbour keeps
        # ids intact. InterpolationMode exists from torchvision 0.9 onward;
        # older releases take the PIL integer code, where 0 == NEAREST.
        interpolation_mode = getattr(transforms, 'InterpolationMode', None)
        nearest = interpolation_mode.NEAREST if interpolation_mode is not None else 0

        def image_pipeline():
            return transforms.Compose([
                transforms.Resize((320, 320)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        def label_pipeline():
            # A separate ToOneHot instance per pipeline.
            return transforms.Compose([
                transforms.Resize((320, 320), interpolation=nearest),
                augmentations.ToOneHot(self.opts.label_nc),
                transforms.ToTensor()])

        return {
            'transform_gt_train': image_pipeline(),
            'transform_source': label_pipeline(),
            'transform_test': image_pipeline(),
            'transform_inference': label_pipeline(),
        }
class SuperResTransforms(TransformsConfig):
    """Super-resolution transforms: 1280x1280 targets, degraded 320x320 inputs.

    Target pipelines (``transform_gt_train``, ``transform_test``) resize to
    1280x1280. Input pipelines (``transform_source``, ``transform_inference``)
    resize to 320x320, apply ``augmentations.BilinearResize`` with the factors
    parsed from ``opts.resize_factors``, then resize back to 320x320. Every
    pipeline ends with ToTensor and a per-channel Normalize(mean 0.5, std 0.5).
    """

    # Used when ``opts.resize_factors`` is None; written back onto opts.
    DEFAULT_RESIZE_FACTORS = '1,2,4,8,16,32'

    def __init__(self, opts):
        super(SuperResTransforms, self).__init__(opts)

    def _resize_factors(self):
        """Parse ``opts.resize_factors`` (comma-separated ints) into a list.

        If the option is ``None``, it is set to ``DEFAULT_RESIZE_FACTORS``
        on ``self.opts``. Blank entries (a trailing comma, ``"2,,4"``) are
        skipped instead of failing inside ``int('')``.

        Raises:
            ValueError: if an entry is not an integer, or no entries remain.
        """
        if self.opts.resize_factors is None:
            self.opts.resize_factors = self.DEFAULT_RESIZE_FACTORS
        raw = self.opts.resize_factors
        factors = [int(token) for token in raw.split(",") if token.strip()]
        if not factors:
            raise ValueError("resize_factors contains no factors: {!r}".format(raw))
        return factors

    def get_transforms(self):
        """Return the dict of named transform pipelines for this task.

        Raises:
            ValueError: if ``opts.resize_factors`` cannot be parsed.
        """
        factors = self._resize_factors()
        print("Performing down-sampling with factors: {}".format(factors))

        def target():
            return transforms.Compose([
                transforms.Resize((1280, 1280)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        def degraded():
            # Each input pipeline gets its own BilinearResize instance; both
            # instances share the same parsed ``factors`` list.
            return transforms.Compose([
                transforms.Resize((320, 320)),
                augmentations.BilinearResize(factors=factors),
                transforms.Resize((320, 320)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        return {
            'transform_gt_train': target(),
            'transform_source': degraded(),
            'transform_test': target(),
            'transform_inference': degraded(),
        }
class SuperResTransforms_320(TransformsConfig):
    """Super-resolution transforms with every pipeline at 320x320.

    Target pipelines (``transform_gt_train``, ``transform_test``) resize to
    320x320. Input pipelines (``transform_source``, ``transform_inference``)
    resize to 320x320, apply ``augmentations.BilinearResize`` with the factors
    parsed from ``opts.resize_factors``, then resize back to 320x320. Every
    pipeline ends with ToTensor and a per-channel Normalize(mean 0.5, std 0.5).
    """

    # Used when ``opts.resize_factors`` is None; written back onto opts.
    DEFAULT_RESIZE_FACTORS = '1,2,4,8,16,32'

    def __init__(self, opts):
        super(SuperResTransforms_320, self).__init__(opts)

    def _resize_factors(self):
        """Parse ``opts.resize_factors`` (comma-separated ints) into a list.

        If the option is ``None``, it is set to ``DEFAULT_RESIZE_FACTORS``
        on ``self.opts``. Blank entries (a trailing comma, ``"2,,4"``) are
        skipped instead of failing inside ``int('')``.

        Raises:
            ValueError: if an entry is not an integer, or no entries remain.
        """
        if self.opts.resize_factors is None:
            self.opts.resize_factors = self.DEFAULT_RESIZE_FACTORS
        raw = self.opts.resize_factors
        factors = [int(token) for token in raw.split(",") if token.strip()]
        if not factors:
            raise ValueError("resize_factors contains no factors: {!r}".format(raw))
        return factors

    def get_transforms(self):
        """Return the dict of named transform pipelines for this task.

        Raises:
            ValueError: if ``opts.resize_factors`` cannot be parsed.
        """
        factors = self._resize_factors()
        print("Performing down-sampling with factors: {}".format(factors))

        def target():
            return transforms.Compose([
                transforms.Resize((320, 320)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        def degraded():
            # Each input pipeline gets its own BilinearResize instance; both
            # instances share the same parsed ``factors`` list.
            return transforms.Compose([
                transforms.Resize((320, 320)),
                augmentations.BilinearResize(factors=factors),
                transforms.Resize((320, 320)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        return {
            'transform_gt_train': target(),
            'transform_source': degraded(),
            'transform_test': target(),
            'transform_inference': degraded(),
        }
class ToonifyTransforms(TransformsConfig):
    """Transform set for toonification: 1024x1024 targets, 256x256 inputs.

    Every pipeline resizes to a square, converts to a tensor and normalizes
    with mean 0.5 / std 0.5 per channel. ``transform_gt_train`` and
    ``transform_test`` use 1024x1024; ``transform_source`` and
    ``transform_inference`` use 256x256.
    """

    def __init__(self, opts):
        super().__init__(opts)

    def get_transforms(self):
        """Return the dict of named transform pipelines for this task."""

        def square_pipeline(side):
            return transforms.Compose([
                transforms.Resize((side, side)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])

        target_side = 1024
        input_side = 256
        return {
            'transform_gt_train': square_pipeline(target_side),
            'transform_source': square_pipeline(input_side),
            'transform_test': square_pipeline(target_side),
            'transform_inference': square_pipeline(input_side),
        }
class EditingTransforms(TransformsConfig):
    """Transform set for editing: 1280x1280 targets, 320x320 inputs.

    Every pipeline resizes to a square, converts to a tensor and normalizes
    with mean 0.5 / std 0.5 per channel. ``transform_gt_train`` and
    ``transform_test`` use 1280x1280; ``transform_source`` and
    ``transform_inference`` use 320x320.
    """

    def __init__(self, opts):
        super().__init__(opts)

    def get_transforms(self):
        """Return the dict of named transform pipelines for this task."""
        # Square side length per pipeline, in output order.
        sides = {
            'transform_gt_train': 1280,
            'transform_source': 320,
            'transform_test': 1280,
            'transform_inference': 320,
        }
        pipelines = {}
        for name, side in sides.items():
            pipelines[name] = transforms.Compose([
                transforms.Resize((side, side)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
        return pipelines