|
import json |
|
from torch.utils import data |
|
from torchvision.datasets import ImageFolder |
|
import torch |
|
import os |
|
from PIL import Image |
|
import numpy as np |
|
import argparse |
|
from tqdm import tqdm |
|
from munkres import Munkres |
|
import multiprocessing |
|
from multiprocessing import Process, Manager |
|
import collections |
|
import torchvision.transforms as transforms |
|
import torchvision.transforms.functional as TF |
|
import random |
|
import torchvision |
|
import cv2 |
|
|
|
# Seed PyTorch's global RNG at import time so torch-driven randomness
# (e.g. the crop offsets drawn by torchvision) is repeatable between runs.
# Python's `random` module is not seeded by this call.
torch.manual_seed(0)

# One dataset entry: `image_name` is the segmentation file name without its
# extension, `tag` is the class directory the file was found in.
SegItem = collections.namedtuple('SegItem', ('image_name', 'tag'))
|
# Per-channel normalisation (x - 0.5) / 0.5, applied after ToTensor.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])

# Random-crop/flip augmentation pipeline.
# NOTE(review): not referenced in the visible part of this file; the dataset
# below performs its own paired crop/flip so image and mask stay aligned.
TRANSFORM_TRAIN = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
])

# Deterministic evaluation pipeline: resize the shorter side to 256, then
# take the central 224x224 window.
TRANSFORM_EVAL = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])

# Final image-only step: PIL image -> tensor, followed by `normalize`.
IMAGE_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
|
|
|
# Class tags that are skipped when the dataset enumerates the segmentation
# directory (presumably ImageNet synset ids that were merged with another
# class -- confirm against the data preparation script).
MERGED_TAGS = {
    'n04356056', 'n04355933',
    'n04493381', 'n02808440',
    'n03642806', 'n03832673',
    'n04008634', 'n03773504',
    'n03887697', 'n15075141',
}

# Names of the supported dataset partitions.
TRAIN_PARTITION = "train"
VAL_PARTITION = "val"
LEGAL_PARTITIONS = {TRAIN_PARTITION, VAL_PARTITION}
|
|
|
|
|
|
|
|
|
class SegmentationDataset(ImageFolder):
    """Dataset of (segmentation map, image, class index) triples.

    Segmentation maps are read from ``<seg_path>/<tag>/<name>.pt`` and the
    matching image from ``<imagenet_path>/<name-before-'_tokencut_bfs'>.JPEG``.
    Class tags are split by position: the first ``train_classes`` tags (in
    sorted order) form the train partition, the remainder the val partition.

    ``ImageFolder.__init__`` is not invoked; the instance only carries the
    attributes assigned here and relies on ``__getitem__`` / ``__len__``.
    """

    # Masks whose element sum is below this value are skipped.
    _MIN_MASK_SUM = 520

    def __init__(self, seg_path, imagenet_path, partition=TRAIN_PARTITION, num_samples=2, train_classes=500,
                 imagenet_classes_path='imagenet_classes.json'):
        """Index the segmentation files that will be served by this dataset.

        Args:
            seg_path: Directory holding one sub-directory per class tag.
            imagenet_path: Directory holding the ``.JPEG`` images.
            partition: ``TRAIN_PARTITION`` or ``VAL_PARTITION``.
            num_samples: Maximum number of items kept per class tag.
            train_classes: Number of leading tags assigned to the train split.
            imagenet_classes_path: JSON file mapping class tag -> class index.

        Raises:
            AssertionError: If ``partition`` is not legal or a selected tag is
                missing from the class mapping.
        """
        assert partition in LEGAL_PARTITIONS, f"Unsupported partition type {partition}"
        self._partition = partition
        self._seg_path = seg_path
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r') as f:
            self._imagenet_classes = json.load(f)

        # os.listdir returns entries in arbitrary order; sort so that the
        # train/val split is identical on every machine and every run.
        self._tag_list = sorted(tag for tag in os.listdir(self._seg_path) if tag not in MERGED_TAGS)
        if partition == TRAIN_PARTITION:
            self._tag_list = self._tag_list[:train_classes]
        elif partition == VAL_PARTITION:
            self._tag_list = self._tag_list[train_classes:]

        for tag in self._tag_list:
            assert tag in self._imagenet_classes, f"Tag {tag} missing from {imagenet_classes_path}"

        self._all_segementations = []
        for tag in self._tag_list:
            self._all_segementations.extend(self._collect_tag_items(tag, num_samples))

    @staticmethod
    def _load_seg_map(path):
        """Load one segmentation file and return it as a float32 tensor.

        The stored object must provide ``astype`` (i.e. be a NumPy array).
        """
        # NOTE(review): torch.load unpickles the file, so only trusted files
        # should be read here. Recent PyTorch releases may also require an
        # explicit `weights_only=False` for pickled NumPy arrays -- confirm
        # against the PyTorch version in use.
        seg_map = torch.load(path)
        return torch.from_numpy(seg_map.astype(np.float32))

    def _collect_tag_items(self, tag, num_samples):
        """Return up to ``num_samples`` usable ``SegItem`` entries for ``tag``."""
        items = []
        base_dir = os.path.join(self._seg_path, tag)
        # Sorted for the same reproducibility reason as the tag list.
        for seg in sorted(os.listdir(base_dir)):
            # Stop before touching the disk again once the quota is filled.
            if len(items) >= num_samples:
                break
            seg_name = seg.split('.')[0]
            if 'bfs' not in seg_name:
                continue
            seg_map = self._load_seg_map(os.path.join(base_dir, seg))
            if torch.sum(seg_map) < self._MIN_MASK_SUM:
                continue
            items.append(SegItem(seg_name, tag))
        return items

    def __getitem__(self, item):
        """Return ``(seg_map, image_tensor, class_index)`` for index ``item``.

        ``seg_map`` is a float32 tensor with a leading channel dimension of 1,
        ``image_tensor`` is the RGB image after ``IMAGE_TRANSFORMS`` and
        ``class_index`` is the integer mapped to the item's tag. Both spatial
        outputs are cropped to 224x224.

        Raises:
            Exception: If the stored partition is neither train nor val.
        """
        seg_item = self._all_segementations[item]

        seg_path = os.path.join(self._seg_path, seg_item.tag, seg_item.image_name + ".pt")
        image_path = os.path.join(self._imagenet_path, seg_item.image_name.split('_tokencut_bfs')[0] + ".JPEG")

        # convert() fully decodes the pixels, so the file handle can be
        # released immediately instead of lingering until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        seg_map = self._load_seg_map(seg_path)
        # Add a leading channel dimension: (H, W) -> (1, H, W).
        seg_map = seg_map.reshape(1, seg_map.shape[-2], seg_map.shape[-1])

        # NOTE(review): only the image receives this pre-resize; the mask does
        # not. The partition-specific resizes below bring both to a comparable
        # size, so this step looks redundant -- confirm before removing, since
        # dropping it changes the resampled pixel values.
        resize = transforms.Resize(224)
        image = resize(image)

        if self._partition == VAL_PARTITION:
            image = TRANSFORM_EVAL(image)
            seg_map = TRANSFORM_EVAL(seg_map)

        elif self._partition == TRAIN_PARTITION:
            resize = transforms.Resize(size=(256, 256))
            image = resize(image)
            seg_map = resize(seg_map)

            # Draw the crop window once and apply it to both inputs so the
            # mask stays aligned with the image.
            i, j, h, w = transforms.RandomCrop.get_params(
                image, output_size=(224, 224))
            image = TF.crop(image, i, j, h, w)
            seg_map = TF.crop(seg_map, i, j, h, w)

            # Same idea for the horizontal flip: one coin toss for the pair.
            if random.random() > 0.5:
                image = TF.hflip(image)
                seg_map = TF.hflip(seg_map)

        else:
            raise Exception(f"Unsupported partition type {self._partition}")

        image_ten = IMAGE_TRANSFORMS(image)
        class_name = int(self._imagenet_classes[seg_item.tag])

        return seg_map, image_ten, class_name

    def __len__(self):
        """Return the number of indexed segmentation items."""
        return len(self._all_segementations)
|
|