Spaces:
Runtime error
Runtime error
import logging | |
import os | |
from typing import Callable, Optional | |
from torchvision.datasets import ImageFolder | |
from torchvision.datasets.utils import check_integrity, download_and_extract_archive, verify_str_arg | |
_logger = logging.getLogger(__name__) | |
class ImageNetA(ImageFolder):
    """ImageNet-A dataset, served through the ``ImageFolder`` interface.

    The archive ``filename`` is stored in ``root`` and its contents are read
    from ``root/base_folder``.

    - Paper: [https://arxiv.org/abs/1907.07174](https://arxiv.org/abs/1907.07174).

    Args:
        root (string): Directory that holds (or will receive) the archive and
            the extracted ``base_folder``. A leading ``~`` is expanded.
        split (string, optional): Accepted for interface compatibility, not used.
        transform (callable, optional): Transform applied to each loaded image.
        download (bool, optional): If true, download and extract the archive
            unless a verified copy and the extracted folder already exist.
        **kwargs: Additional arguments passed to
            :class:`~torchvision.datasets.ImageFolder`.

    Raises:
        RuntimeError: If the archive is missing or its MD5 does not match.
    """

    base_folder = "imagenet-a"
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar"
    filename = "imagenet-a.tar"
    tgz_md5 = "c3e55429088dc681f30d81f4726b6595"

    def __init__(self, root: str, split=None, transform: Optional[Callable] = None, download: bool = False, **kwargs):
        # Expand "~" up front: the torchvision download helpers expand it
        # themselves, while check_integrity does not, so an unexpanded root
        # would make the integrity check look in a different directory than
        # the one the archive was downloaded to.
        self.root = os.path.expanduser(root)
        # The parent constructor rebinds ``self.root`` to the extracted folder,
        # so the download directory is kept in its own attribute for the
        # helper methods below.
        self._download_root = self.root

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")

        super().__init__(root=os.path.join(self._download_root, self.base_folder), transform=transform, **kwargs)

    def _check_exists(self) -> bool:
        """Return True if the extracted dataset folder is present."""
        return os.path.exists(os.path.join(self._download_root, self.base_folder))

    def _check_integrity(self) -> bool:
        """Return True if the archive is present and matches ``tgz_md5``."""
        return check_integrity(os.path.join(self._download_root, self.filename), self.tgz_md5)

    def download(self) -> None:
        """Download and extract the archive unless a verified copy already exists."""
        if self._check_integrity() and self._check_exists():
            _logger.debug("Files already downloaded and verified")
            return
        download_and_extract_archive(self.url, self._download_root, filename=self.filename, md5=self.tgz_md5)
class ImageNetO(ImageNetA):
    """ImageNet-O dataset.

    Contains classes that are unknown to ImageNet-1k. Only the archive
    constants differ from :class:`ImageNetA`; all loading logic is inherited.

    - Paper: [https://arxiv.org/abs/1907.07174](https://arxiv.org/abs/1907.07174)
    """

    # Archive location and checksum.
    filename = "imagenet-o.tar"
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-o.tar"
    tgz_md5 = "86bd7a50c1c4074fb18fc5f219d6d50b"

    # Folder (relative to ``root``) that the images are read from.
    base_folder = "imagenet-o"
class ImageNetR(ImageNetA):
    """ImageNet-R(endition) dataset.

    Contains renditions of ImageNet-1k classes: art, cartoons, deviantart,
    graffiti, embroidery, graphics, origami, paintings, patterns, plastic
    objects, plush objects, sculptures, sketches, tattoos, toys, and video
    game renditions. Only the archive constants differ from
    :class:`ImageNetA`; all loading logic is inherited.

    - Paper: [https://arxiv.org/abs/2006.16241](https://arxiv.org/abs/2006.16241)
    """

    # Archive location and checksum.
    filename = "imagenet-r.tar"
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar"
    tgz_md5 = "a61312130a589d0ca1a8fca1f2bd3337"

    # Folder (relative to ``root``) that the images are read from.
    base_folder = "imagenet-r"
class NINCOFull(ImageFolder):
    """`NINCO` dataset (full ``NINCO_all`` archive).

    The archive ``filename`` is stored in ``root`` and extracted into
    ``root/base_folder``, which is then read through the ``ImageFolder``
    interface.

    Args:
        root (string): Root directory of dataset where directory
            exists or will be saved to if download is set to True.
        split (string, optional): The dataset split, not used.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, `transforms.RandomCrop`.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        **kwargs: Additional arguments passed to :class:`~torchvision.datasets.ImageFolder`.

    Raises:
        RuntimeError: If the archive is missing or its MD5 does not match.
    """

    PAPER_URL = "https://arxiv.org/pdf/2306.00826.pdf"
    base_folder = "ninco"
    filename = "NINCO_all.tar.gz"
    file_md5 = "b9ffae324363cd900a81ce3c367cd834"
    url = "https://zenodo.org/record/8013288/files/NINCO_all.tar.gz"
    # size: 15393

    def __init__(
        self, root: str, split=None, transform: Optional[Callable] = None, download: bool = False, **kwargs
    ) -> None:
        self.root = os.path.expanduser(root)
        # The parent constructor rebinds ``self.root`` to the dataset folder,
        # so the download directory is kept in its own attribute to keep
        # ``download`` consistent with ``self.archive`` after construction.
        self._download_root = self.root
        self.dataset_folder = os.path.join(self._download_root, self.base_folder)
        self.archive = os.path.join(self._download_root, self.filename)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it")

        super().__init__(self.dataset_folder, transform=transform, **kwargs)

    def _check_integrity(self) -> bool:
        """Return True if the archive is present and matches ``file_md5``."""
        return check_integrity(self.archive, self.file_md5)

    def _check_exists(self) -> bool:
        """Return True if the extracted dataset folder is present."""
        return os.path.exists(self.dataset_folder)

    def download(self) -> None:
        """Download and extract the archive unless a verified copy already exists."""
        if self._check_integrity() and self._check_exists():
            _logger.debug("Files already downloaded and verified")
            return
        download_and_extract_archive(
            self.url,
            download_root=self._download_root,
            extract_root=self.dataset_folder,
            # Explicit so the saved file always matches ``self.archive``.
            filename=self.filename,
            md5=self.file_md5,
        )
if __name__ == "__main__":
    # Fetch every dataset into ./data, in the same order as before.
    for dataset_cls in (ImageNetR, ImageNetO, ImageNetA, NINCOFull):
        dataset_cls(root="data", download=True)