Spaces:
Runtime error
Runtime error
File size: 4,736 Bytes
301b1c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import logging
import os
from typing import Callable, Optional
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import check_integrity, download_and_extract_archive, verify_str_arg
# Module-level logger; ImageNetA.download emits a debug message through it when the files are already present.
_logger = logging.getLogger(__name__)
class ImageNetA(ImageFolder):
    """ImageNet-A dataset.

    The archive ``filename`` is stored in (or downloaded to) ``root`` and the folder
    ``root/base_folder`` is loaded with :class:`~torchvision.datasets.ImageFolder`.
    Subclasses reuse this logic by overriding only the class attributes below.

    - Paper: [https://arxiv.org/abs/1907.07174](https://arxiv.org/abs/1907.07174).

    Args:
        root: Directory holding the archive and the extracted folder; ``~`` is expanded.
        split: Accepted for interface compatibility with other datasets; ignored.
        transform: Optional transform forwarded to ``ImageFolder``.
        download: If True, download and extract the archive unless it already verifies
            and the extracted folder exists.
        **kwargs: Additional arguments forwarded to ``ImageFolder``.

    Raises:
        RuntimeError: If the archive ``root/filename`` is missing or fails its MD5 check.
    """

    base_folder = "imagenet-a"
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar"
    filename = "imagenet-a.tar"
    tgz_md5 = "c3e55429088dc681f30d81f4726b6595"

    def __init__(self, root: str, split=None, transform: Optional[Callable] = None, download: bool = False, **kwargs):
        # Expand "~" (as NINCOFull does) so "~/data" does not create a literal "~" directory.
        self.root = os.path.expanduser(root)
        # VisionDataset.__init__ (reached through ImageFolder) later overwrites self.root with
        # the extracted folder, so keep every path the helpers need in dedicated attributes.
        self.download_root = self.root
        self.dataset_folder = os.path.join(self.root, self.base_folder)
        self.archive = os.path.join(self.root, self.filename)
        if download:
            self.download()
        # Integrity is verified on the archive itself, so the archive must remain in root.
        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
        super().__init__(root=self.dataset_folder, transform=transform, **kwargs)

    def _check_exists(self) -> bool:
        """Return True if the extracted dataset folder exists."""
        return os.path.exists(self.dataset_folder)

    def _check_integrity(self) -> bool:
        """Return True if the archive exists and matches ``tgz_md5``."""
        return check_integrity(self.archive, self.tgz_md5)

    def download(self) -> None:
        """Download the archive into ``download_root`` and extract it there, unless already done."""
        if self._check_integrity() and self._check_exists():
            _logger.debug("Files already downloaded and verified")
            return
        download_and_extract_archive(self.url, self.download_root, filename=self.filename, md5=self.tgz_md5)
class ImageNetO(ImageNetA):
    """ImageNet-O dataset.

    Contains images of classes that are unknown to ImageNet-1k. Only the download
    metadata differs from :class:`ImageNetA`; all loading logic is inherited.

    - Paper: [https://arxiv.org/abs/1907.07174](https://arxiv.org/abs/1907.07174)
    """

    # Folder under ``root`` loaded by ImageFolder, and the archive name stored in ``root``.
    base_folder = "imagenet-o"
    filename = "imagenet-o.tar"
    # Where the archive is fetched from and its expected MD5 checksum.
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-o.tar"
    tgz_md5 = "86bd7a50c1c4074fb18fc5f219d6d50b"
class ImageNetR(ImageNetA):
    """ImageNet-R(endition) dataset.

    Renditions of ImageNet-1k classes: art, cartoons, deviantart, graffiti, embroidery,
    graphics, origami, paintings, patterns, plastic objects, plush objects, sculptures,
    sketches, tattoos, toys, and video game renditions. Only the download metadata
    differs from :class:`ImageNetA`; all loading logic is inherited.

    - Paper: [https://arxiv.org/abs/2006.16241](https://arxiv.org/abs/2006.16241)
    """

    # Folder under ``root`` loaded by ImageFolder, and the archive name stored in ``root``.
    base_folder = "imagenet-r"
    filename = "imagenet-r.tar"
    # Where the archive is fetched from and its expected MD5 checksum.
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar"
    tgz_md5 = "a61312130a589d0ca1a8fca1f2bd3337"
class NINCOFull(ImageFolder):
    """`NINCO` dataset, loaded from the full ``NINCO_all.tar.gz`` archive.

    The archive is stored in (or downloaded to) ``root`` and extracted into ``root/ninco``,
    which is then loaded with :class:`~torchvision.datasets.ImageFolder`.

    Args:
        root (string): Root directory where the archive is stored or will be saved to
            if download is set to True; ``~`` is expanded.
        split (string, optional): The dataset split, not used.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., `transforms.RandomCrop`.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If the archive already verifies and the extracted
            folder exists, it is not downloaded again.
        **kwargs: Additional arguments passed to :class:`~torchvision.datasets.ImageFolder`.

    Raises:
        RuntimeError: If ``root/NINCO_all.tar.gz`` is missing or fails its MD5 check.
    """

    PAPER_URL = "https://arxiv.org/pdf/2306.00826.pdf"
    base_folder = "ninco"
    filename = "NINCO_all.tar.gz"
    file_md5 = "b9ffae324363cd900a81ce3c367cd834"
    url = "https://zenodo.org/record/8013288/files/NINCO_all.tar.gz"
    # size: 15393

    def __init__(
        self, root: str, split=None, transform: Optional[Callable] = None, download: bool = False, **kwargs
    ) -> None:
        self.root = os.path.expanduser(root)
        # VisionDataset.__init__ (reached through ImageFolder) overwrites self.root with the
        # dataset folder, so remember the download location separately for later download() calls.
        self.download_root = self.root
        self.dataset_folder = os.path.join(self.root, self.base_folder)
        self.archive = os.path.join(self.root, self.filename)
        if download:
            self.download()
        # Integrity is verified on the archive itself, so the archive must remain in root.
        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
        super().__init__(self.dataset_folder, transform=transform, **kwargs)

    def _check_integrity(self) -> bool:
        """Return True if the archive exists and matches ``file_md5``."""
        return check_integrity(self.archive, self.file_md5)

    def _check_exists(self) -> bool:
        """Return True if the extracted dataset folder exists."""
        return os.path.exists(self.dataset_folder)

    def download(self) -> None:
        """Download the archive into ``download_root`` and extract it into ``dataset_folder``."""
        if self._check_integrity() and self._check_exists():
            _logger.debug("Files already downloaded and verified")
            return
        # Pass filename explicitly so the archive is saved exactly at self.archive, the path
        # _check_integrity verifies, instead of a name derived from the URL.
        download_and_extract_archive(
            self.url,
            download_root=self.download_root,
            extract_root=self.dataset_folder,
            filename=self.filename,
            md5=self.file_md5,
        )
if __name__ == "__main__":
    # Fetch (if needed) and verify every dataset under ./data, in a fixed order.
    for dataset_cls in (ImageNetR, ImageNetO, ImageNetA, NINCOFull):
        dataset_cls(root="data", download=True)
|