AdversarialArt / src /.ipynb_checkpoints /utils-checkpoint.py
will33am's picture
update
be2ced2
raw
history blame
995 Bytes
from PIL import Image
import torch
import torch.nn as nn
from typing import Dict, Iterable, Callable
from torch import Tensor
import glob
from tqdm import tqdm
import numpy as np
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
# +
class RobustModel(nn.Module):
    """Adapter that calls the wrapped network with the input tensor alone.

    ``forward`` accepts any additional positional or keyword arguments and
    silently drops them, so the inner model is only ever invoked as
    ``model(x)``.
    """

    def __init__(self, model):
        super().__init__()
        # When ``model`` is an nn.Module, assigning it as an attribute
        # registers it as a child module (its parameters are then visible
        # through ``self.parameters()``).
        self.model = model

    def forward(self, x, *args, **kwargs):
        # Extra arguments are accepted for call-site compatibility and
        # intentionally ignored.
        inner = self.model
        return inner(x)
class CustomArt(torch.utils.data.Dataset):
    """Dataset that serves one fixed image for every index.

    Args:
        image: The image to serve. It is passed to ``transforms`` (when
            given) and the result is converted with ``torch.as_tensor``.
        transforms: Optional callable applied to ``image`` on every access.
    """

    def __init__(self, image, transforms=None):
        self.transforms = transforms
        self.image = image
        # Per-channel mean/std; the values match the commonly used ImageNet
        # normalization statistics. Not used inside this class -- presumably
        # read by callers (TODO confirm).
        self.mean = torch.tensor([0.4850, 0.4560, 0.4060])
        self.std = torch.tensor([0.2290, 0.2240, 0.2250])

    def __getitem__(self, idx):
        """Return the (optionally transformed) image as a float tensor.

        ``idx`` is ignored: every index yields the same image.
        """
        # Bind ``img`` on both paths so a dataset built without transforms
        # returns the raw image instead of failing with UnboundLocalError.
        if self.transforms is None:
            img = self.image
        else:
            img = self.transforms(self.image)
        return torch.as_tensor(img, dtype=torch.float)

    def __len__(self):
        """Return ``len(image)``, or 1 when the image object has no length.

        An object without ``__len__`` (such as a ``PIL.Image.Image``) is a
        single image, and ``__getitem__`` always serves exactly that image.
        """
        try:
            return len(self.image)
        except TypeError:
            return 1