Spaces:
Build error
Build error
from PIL import Image | |
import torch | |
import torch.nn as nn | |
from typing import Dict, Iterable, Callable | |
from torch import Tensor | |
import glob | |
from tqdm import tqdm | |
import numpy as np | |
from PIL import ImageFile | |
# Lift Pillow's pixel-count ceiling entirely so arbitrarily large images open.
# NOTE(review): this also disables Pillow's decompression-bomb guard; if the
# images come from untrusted uploads, consider a finite limit instead.
Image.MAX_IMAGE_PIXELS = None
# Let Pillow decode image files that end prematurely instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# + | |
class RobustModel(nn.Module):
    """Wrap a model so that only the first input reaches it.

    ``forward`` accepts arbitrary extra positional and keyword arguments and
    silently drops them, calling the wrapped model with ``x`` alone.
    """

    def __init__(self, model):
        super().__init__()
        # When ``model`` is an nn.Module, assigning it here registers it as a
        # submodule, so its parameters are visible through this wrapper.
        self.model = model

    def forward(self, x, *args, **kwargs):
        # Extra args/kwargs are intentionally ignored.
        wrapped = self.model
        output = wrapped(x)
        return output
class CustomArt(torch.utils.data.Dataset):
    """Dataset that serves one image, optionally passed through ``transforms``.

    ``__getitem__`` ignores its index: every lookup returns the same (possibly
    transformed) image converted to a float tensor.

    Args:
        image: The image to serve. Its type is whatever ``transforms`` (or
            ``torch.as_tensor`` when no transform is given) accepts --
            presumably a PIL image or array; TODO confirm against the caller.
        transforms: Optional callable applied to ``image`` on every access.
    """

    def __init__(self, image, transforms=None):
        self.transforms = transforms
        self.image = image
        # Per-channel mean/std constants; these match the commonly used
        # ImageNet normalisation values. They are not used inside this class,
        # presumably read by callers for (de)normalisation -- verify.
        self.mean = torch.tensor([0.4850, 0.4560, 0.4060])
        self.std = torch.tensor([0.2290, 0.2240, 0.2250])

    def __getitem__(self, idx):
        """Return ``image`` (transformed if ``transforms`` is set) as a float tensor.

        ``idx`` is ignored; every index yields the same sample.
        """
        # Without a transform, ``img`` used to be unbound (UnboundLocalError);
        # fall back to the raw image instead.
        img = self.transforms(self.image) if self.transforms else self.image
        return torch.as_tensor(img, dtype=torch.float)

    def __len__(self):
        """Return ``len(image)`` when it is defined, otherwise 1.

        NOTE(review): because ``__getitem__`` ignores ``idx``, a length above 1
        (e.g. a tensor/array image, whose ``len`` is its first dimension) makes
        a DataLoader yield that many identical samples. Kept for backward
        compatibility -- confirm whether a length of 1 is intended.
        """
        try:
            return len(self.image)
        except TypeError:
            # Objects with no __len__ (e.g. a PIL image) used to crash here;
            # a single un-sized image counts as one sample.
            return 1