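# NOTE(review): the original first lines ("Spaces: / Running / Running") appear to be
# residue from the hosting page's file viewer, not part of the module.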
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from gzip import GzipFile
from io import BytesIO
from mmap import ACCESS_READ, mmap
import os
from typing import Any, Callable, List, Optional, Set, Tuple
import warnings

import numpy as np

from .extended import ExtendedVisionDataset
# Label type of the dataset; re-exported on the dataset class as ``Labels``.
_Labels = int

# Default value of the ``mmap_cache_size`` constructor argument.
# Warning: This can exhaust file descriptors
_DEFAULT_MMAP_CACHE_SIZE = 16
@dataclass
class _ClassEntry:
    """One parsed line of a per-class blocks log.

    Instances are built positionally while parsing the log, which is why this
    class must be a dataclass (it needs the generated ``__init__``).
    """

    # Position of the tar member, counted in 512-byte blocks.
    block_offset: int
    # Member file name, or None for the "** Block of NULs **" end marker.
    maybe_filename: Optional[str] = None
@dataclass
class _Entry:
    """Location of one image inside a class tarball.

    Instances are built positionally as
    ``_Entry(class_index, start_offset, end_offset, filename)``, which is why
    this class must be a dataclass (it needs the generated ``__init__``).
    """

    # Index of the class in the sorted list of class ids.
    class_index: int  # noqa: E701
    # Byte offset in the tarball where the member (header block included) starts.
    start_offset: int
    # Byte offset in the tarball where the next member starts.
    end_offset: int
    # Name of the image file inside the tarball.
    filename: str
class _Split(Enum):
    """Named dataset splits together with their hard-coded sizes."""

    TRAIN = "train"
    VAL = "val"

    # NOTE(review): decorators look stripped elsewhere in this file; confirm
    # whether `length` was meant to be a property. It is kept as a plain method.
    def length(self) -> int:
        """Return the number of samples recorded for this split."""
        sample_counts = {
            "train": 11_797_647,
            "val": 561_050,
        }
        return sample_counts[self.value]

    def entries_path(self):
        """Return the name of the text file listing this split's entries."""
        return "imagenet21kp_{}.txt".format(self.value)
def _get_tarball_path(class_id: str) -> str:
    """Return the tarball file name (relative to the tarballs root) for a class id."""
    return "{}.tar".format(class_id)
def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
    """Build a cached lookup from class id to a read-only mmap of its tarball.

    Args:
        tarballs_root: Directory that contains the per-class tarballs.
        mmap_cache_size: Maximum number of memory maps kept alive by the LRU
            cache; older maps are dropped once the limit is exceeded.

    Returns:
        A function taking a class id and returning the ``mmap`` object of the
        corresponding tarball.
    """

    # Cache the mappings: without it every sample access would create a new
    # memory map of the whole tarball and `mmap_cache_size` would be ignored.
    @lru_cache(maxsize=mmap_cache_size)
    def _mmap_tarball(class_id: str) -> mmap:
        tarball_path = _get_tarball_path(class_id)
        tarball_full_path = os.path.join(tarballs_root, tarball_path)
        # Only the file descriptor is needed, so open in binary mode; the
        # mapping remains usable after the file object is closed.
        with open(tarball_full_path, "rb") as f:
            return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)

    return _mmap_tarball
class ImageNet22k(ExtendedVisionDataset):
    """ImageNet-22k dataset that reads images straight out of per-class tarballs.

    Samples are described by two NumPy arrays stored under the ``extra``
    directory ("entries.npy" and "class-ids.npy"). Image bytes are sliced out
    of memory-mapped ``<class_id>.tar`` files located under ``root``.
    """

    # Sample indices whose payload is expected to be gzip-compressed; a gzip
    # header found on any other sample is reported as an error.
    _GZIPPED_INDICES: Set[int] = {
        841_545,
        1_304_131,
        2_437_921,
        2_672_079,
        2_795_676,
        2_969_786,
        6_902_965,
        6_903_550,
        6_903_628,
        7_432_557,
        7_432_589,
        7_813_809,
        8_329_633,
        10_296_990,
        10_417_652,
        10_492_265,
        10_598_078,
        10_782_398,
        10_902_612,
        11_203_736,
        11_342_890,
        11_397_596,
        11_589_762,
        11_705_103,
        12_936_875,
        13_289_782,
    }
    Labels = _Labels

    def __init__(
        self,
        *,
        root: str,
        extra: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
    ) -> None:
        """Load the precomputed index arrays and set up the tarball mmap cache.

        Args:
            root: Dataset root, forwarded to the base class.
            extra: Directory holding "entries.npy" and "class-ids.npy".
            transforms: Forwarded to the base class.
            transform: Forwarded to the base class.
            target_transform: Forwarded to the base class.
            mmap_cache_size: Forwarded to ``_make_mmap_tarball``.
        """
        super().__init__(root, transforms, transform, target_transform)
        self._extra_root = extra

        entries_path = self._get_entries_path(root)
        self._entries = self._load_extra(entries_path)

        class_ids_path = self._get_class_ids_path(root)
        self._class_ids = self._load_extra(class_ids_path)

        self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
        self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)

    def _get_entries_path(self, root: Optional[str] = None) -> str:
        """Return the file name of the entries array (``root`` is unused)."""
        return "entries.npy"

    def _get_class_ids_path(self, root: Optional[str] = None) -> str:
        """Return the file name of the class ids array (``root`` is unused)."""
        return "class-ids.npy"

    def _find_class_ids(self, path: str) -> List[str]:
        """Return the sorted stems of all ``*.tar`` entries found directly in ``path``."""
        with os.scandir(path) as entries:
            stems_and_exts = (os.path.splitext(entry.name) for entry in entries)
            return sorted(stem for stem, ext in stems_and_exts if ext == ".tar")

    def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
        """Build the sample entries from the per-class ``blocks/<class_id>.log`` files.

        Each log line is split at its first ":"; the integer after the first
        six characters of the left part is the block offset, and the right part
        (minus its leading character) is the member name or the
        "** Block of NULs **" end marker.

        Args:
            root: Optional root, resolved through ``self.get_root``.

        Returns:
            A tuple ``(entries, class_ids)``.

        Raises:
            RuntimeError: If a blocks log cannot be read.
        """
        root = self.get_root(root)
        entries: List[_Entry] = []
        class_ids = self._find_class_ids(root)

        for class_index, class_id in enumerate(class_ids):
            path = os.path.join(root, "blocks", f"{class_id}.log")
            class_entries = []

            try:
                with open(path) as f:
                    for line in f:
                        line = line.rstrip()
                        # Split on the first colon only, so file names that
                        # contain ":" do not break the unpacking.
                        block, filename = line.split(":", 1)
                        block_offset = int(block[6:])
                        filename = filename[1:]

                        maybe_filename = None
                        if filename != "** Block of NULs **":
                            maybe_filename = filename

                        class_entry = _ClassEntry(block_offset, maybe_filename)
                        class_entries.append(class_entry)
            except OSError as e:
                raise RuntimeError(f'can not read blocks file "{path}"') from e

            # The log is expected to end with the NUL-block marker.
            assert class_entries[-1].maybe_filename is None

            # Consecutive log lines delimit one member: it spans from its own
            # block offset up to the next line's block offset.
            for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
                assert class_entry1.block_offset <= class_entry2.block_offset
                start_offset = 512 * class_entry1.block_offset
                end_offset = 512 * class_entry2.block_offset
                assert class_entry1.maybe_filename is not None
                filename = class_entry1.maybe_filename
                # Skip invalid image files (PIL throws UnidentifiedImageError)
                if filename == "n06470073_47249.JPEG":
                    continue
                entries.append(_Entry(class_index, start_offset, end_offset, filename))

        return entries, class_ids

    def _load_extra(self, extra_path: str) -> np.ndarray:
        """Memory-map (read-only) the array stored at ``extra_path`` under the extra directory."""
        extra_root = self._extra_root
        extra_full_path = os.path.join(extra_root, extra_path)
        return np.load(extra_full_path, mmap_mode="r")

    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
        """Save ``extra_array`` at ``extra_path`` under the extra directory, creating it if needed."""
        extra_root = self._extra_root
        extra_full_path = os.path.join(extra_root, extra_path)
        os.makedirs(extra_root, exist_ok=True)
        np.save(extra_full_path, extra_array)

    # Must be a property: `__init__` reads it as an attribute and passes the
    # value to `_make_mmap_tarball`, which joins it with tarball file names.
    @property
    def _tarballs_root(self) -> str:
        """Directory containing the class tarballs (``self.root``)."""
        return self.root

    def find_class_id(self, class_index: int) -> str:
        """Return the class id stored at ``class_index`` in the class ids array."""
        return str(self._class_ids[class_index])

    def get_image_data(self, index: int) -> bytes:
        """Return the raw (decompressed if needed) image bytes of sample ``index``.

        Raises:
            RuntimeError: If the data cannot be read from the tarball, or if a
                gzip header is found on a sample not listed as gzipped.
        """
        entry = self._entries[index]
        class_id = entry["class_id"]
        class_mmap = self._mmap_tarball(class_id)

        start_offset, end_offset = entry["start_offset"], entry["end_offset"]
        try:
            mapped_data = class_mmap[start_offset:end_offset]
            data = mapped_data[512:]  # Skip entry header block

            # 0x1F 0x8B is the gzip magic number.
            if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
                assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
                with GzipFile(fileobj=BytesIO(data)) as g:
                    data = g.read()
        except Exception as e:
            raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e

        return data

    def get_target(self, index: int) -> Any:
        """Return the class index of sample ``index`` as a Python int."""
        return int(self._entries[index]["class_index"])

    def get_targets(self) -> np.ndarray:
        """Return the "class_index" field of all entries."""
        return self._entries["class_index"]

    def get_class_id(self, index: int) -> str:
        """Return the class id of sample ``index``."""
        return str(self._entries[index]["class_id"])

    def get_class_ids(self) -> np.ndarray:
        """Return the "class_id" field of all entries."""
        return self._entries["class_id"]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the base class item for ``index`` with all warnings silenced."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return super().__getitem__(index)

    def __len__(self) -> int:
        """Return the number of entries."""
        return len(self._entries)

    def _dump_entries(self, *args, **kwargs) -> None:
        """Scan the blocks logs and save the resulting structured entries array."""
        entries, class_ids = self._load_entries_class_ids(*args, **kwargs)

        # Fixed-width unicode fields must fit the longest value.
        max_class_id_length, max_filename_length = -1, -1
        for entry in entries:
            class_id = class_ids[entry.class_index]
            max_class_id_length = max(len(class_id), max_class_id_length)
            max_filename_length = max(len(entry.filename), max_filename_length)

        dtype = np.dtype(
            [
                ("class_index", "<u4"),
                ("class_id", f"U{max_class_id_length}"),
                ("start_offset", "<u4"),
                ("end_offset", "<u4"),
                ("filename", f"U{max_filename_length}"),
            ]
        )
        sample_count = len(entries)
        entries_array = np.empty(sample_count, dtype=dtype)
        for i, entry in enumerate(entries):
            entries_array[i] = (
                entry.class_index,
                class_ids[entry.class_index],
                entry.start_offset,
                entry.end_offset,
                entry.filename,
            )

        entries_path = self._get_entries_path(*args, **kwargs)
        self._save_extra(entries_array, entries_path)

    def _dump_class_ids(self, *args, **kwargs) -> None:
        """Derive the class-index-to-class-id array from the saved entries and save it."""
        entries_path = self._get_entries_path(*args, **kwargs)
        entries_array = self._load_extra(entries_path)

        max_class_id_length, max_class_index = -1, -1
        for entry in entries_array:
            class_index, class_id = entry["class_index"], entry["class_id"]
            max_class_index = max(int(class_index), max_class_index)
            max_class_id_length = max(len(str(class_id)), max_class_id_length)

        class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
        for entry in entries_array:
            class_index, class_id = entry["class_index"], entry["class_id"]
            class_ids_array[class_index] = class_id

        class_ids_path = self._get_class_ids_path(*args, **kwargs)
        self._save_extra(class_ids_array, class_ids_path)

    def _dump_extra(self, *args, **kwargs) -> None:
        """Regenerate both extra arrays: entries first, then class ids."""
        # Keyword arguments must be forwarded with `**`; unpacking them with a
        # single `*` would pass the keyword names as extra positional arguments.
        self._dump_entries(*args, **kwargs)
        self._dump_class_ids(*args, **kwargs)

    def dump_extra(self, root: Optional[str] = None) -> None:
        """Regenerate the extra arrays, optionally scanning ``root``."""
        return self._dump_extra(root)