import torch as th


class Normalize(object):
    """Standardize a tensor with fixed per-channel mean and std.

    The statistics are stored with shape ``(1, C, 1, 1)``, so for a 4-D
    input they broadcast along dimension 1.
    """

    def __init__(self, mean, std):
        """Store the per-channel statistics.

        Args:
            mean: Sequence of per-channel means.
            std: Sequence of per-channel standard deviations, same length
                as ``mean``.
        """
        # -1 lets the channel count follow len(mean)/len(std) instead of
        # being fixed at 3; for three values the shape is still (1, 3, 1, 1).
        self.mean = th.FloatTensor(mean).view(1, -1, 1, 1)
        self.std = th.FloatTensor(std).view(1, -1, 1, 1)

    def __call__(self, tensor):
        """Return ``(tensor - mean) / (std + 1e-8)``.

        Args:
            tensor: Input whose dimension 1 matches the number of stored
                channels (presumably ``(batch, channels, H, W)`` -- confirm
                against callers).

        Returns:
            The standardized tensor; the input is not modified in place.
        """
        mean, std = self.mean, self.std
        # The statistics are created on the CPU; follow the input's device
        # so inputs living elsewhere do not raise a device-mismatch error.
        if isinstance(tensor, th.Tensor) and tensor.device != mean.device:
            mean = mean.to(tensor.device)
            std = std.to(tensor.device)
        # The small epsilon guards against division by a zero std.
        return (tensor - mean) / (std + 1e-8)


class Preprocessing(object):
    """Scale values by 1/255 and standardize with fixed 3-channel statistics."""

    def __init__(self):
        # NOTE(review): these constants look like the published CLIP image
        # mean/std -- confirm against the model this pipeline feeds.
        self.norm = Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711])

    def __call__(self, tensor):
        """Divide ``tensor`` by 255.0, then apply the stored normalization.

        Args:
            tensor: Input with 3 channels on dimension 1 (presumably pixel
                values in ``[0, 255]`` -- confirm against callers).

        Returns:
            The rescaled and standardized tensor.
        """
        return self.norm(tensor / 255.0)