import os
import random
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
import einops
import timm
from PIL import Image
from torch import nn

from toolkit.dhelper import traverse_recursively
|


class SRMConv2d_simple(nn.Module):
    """Fixed SRM high-pass filter bank (three 5x5 residual kernels) with hardtanh truncation."""

    def __init__(self, inc=3):
        super().__init__()
        self.trunc = nn.Hardtanh(-3, 3)
        # Fixed (non-trainable) SRM kernels of shape (3, inc, 5, 5).
        self.kernel = torch.from_numpy(self._build_kernel(inc)).float()

    def forward(self, x):
        # x: (N, inc, H, W) -> (N, 3, H, W); padding=2 preserves spatial size.
        out = F.conv2d(x, self.kernel, stride=1, padding=2)
        out = self.trunc(out)
        return out

    def _build_kernel(self, inc):
        # 3x3 second-order residual kernel, embedded in a 5x5 window.
        filter1 = [[0, 0, 0, 0, 0],
                   [0, -1, 2, -1, 0],
                   [0, 2, -4, 2, 0],
                   [0, -1, 2, -1, 0],
                   [0, 0, 0, 0, 0]]
        # 5x5 second-order residual kernel.
        filter2 = [[-1, 2, -2, 2, -1],
                   [2, -6, 8, -6, 2],
                   [-2, 8, -12, 8, -2],
                   [2, -6, 8, -6, 2],
                   [-1, 2, -2, 2, -1]]
        # 1D horizontal second-order residual kernel.
        filter3 = [[0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0],
                   [0, 1, -2, 1, 0],
                   [0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0]]

        filter1 = np.asarray(filter1, dtype=float) / 4.
        filter2 = np.asarray(filter2, dtype=float) / 12.
        filter3 = np.asarray(filter3, dtype=float) / 2.

        # (3, 1, 5, 5) -> repeat across input channels -> (3, inc, 5, 5).
        filters = np.array([[filter1], [filter2], [filter3]])
        filters = np.repeat(filters, inc, axis=1)
        return filters
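

# A minimal usage sketch (for illustration only; shapes are an assumption based
# on the default inc=3). The kernel is a plain tensor rather than a Parameter,
# so the module is typically run without gradients:
#
#   srm = SRMConv2d_simple(inc=3)
#   with torch.no_grad():
#       residuals = srm(torch.randn(4, 3, 224, 224))  # -> (4, 3, 224, 224), clipped to [-3, 3]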
|


class MultiClassificationProcessor_mfolder(torch.utils.data.Dataset):
    """Multi-class image dataset that appends SRM noise residuals as extra channels."""

    def __init__(self, transform=None):
        self.transformer_ = transform
        self.extension_ = '.jpg .jpeg .png .bmp .webp .tif .eps'

        self.ctg_names_ = []
        self.ctg_name2idx_ = OrderedDict()

        self.img_names_ = []
        self.img_paths_ = []
        self.img_labels_ = []

        self.srm = SRMConv2d_simple()
|

    def load_data_from_dir_test(self, folders):
        """Collect image paths recursively from `folders` (unlabeled, for test time)."""
        print(folders)
        img_paths = []
        traverse_recursively(folders, img_paths, self.extension_)

        for img_path in img_paths:
            img_name = os.path.basename(img_path)
            self.img_names_.append(img_name)
            self.img_paths_.append(img_path)

        length = len(img_paths)
        print('log: {} image num is {}'.format(folders, length))
|

    def load_data_from_dir(self, dataset_list):
        """Load images from a {category_name: [image_root, ...]} mapping and label them.

        Args:
            dataset_list: dict mapping each category name to a list of image folders.
        """
        for ctg_name, folders in dataset_list.items():
            if ctg_name not in self.ctg_name2idx_:
                self.ctg_name2idx_[ctg_name] = len(self.ctg_names_)
                self.ctg_names_.append(ctg_name)

            for img_root in folders:
                img_paths = []
                traverse_recursively(img_root, img_paths, self.extension_)

                print(img_root)

                for img_path in img_paths:
                    img_name = os.path.basename(img_path)
                    self.img_names_.append(img_name)
                    self.img_paths_.append(img_path)
                    self.img_labels_.append(self.ctg_name2idx_[ctg_name])

                print('log: category is %d(%s), image num is %d' % (
                    self.ctg_name2idx_[ctg_name], ctg_name, len(img_paths)))
|

    def load_data_from_txt(self, img_list_txt, ctg_list_txt):
        """Load the image list from txt files.

        Args:
            img_list_txt: image list; each line is '<file_path> <ctg_idx>'.
            ctg_list_txt: category list; each line is '<ctg_name> <ctg_idx>'.
        """
        assert os.path.exists(img_list_txt), 'log: does not exist: {}'.format(img_list_txt)
        assert os.path.exists(ctg_list_txt), 'log: does not exist: {}'.format(ctg_list_txt)

        # Parse the category list.
        with open(ctg_list_txt) as f:
            ctg_infos = [line.strip() for line in f.readlines()]

        for ctg_info in ctg_infos:
            tmp = ctg_info.split(' ')
            ctg_name = tmp[0]
            ctg_idx = int(tmp[1])
            self.ctg_name2idx_[ctg_name] = ctg_idx
            self.ctg_names_.append(ctg_name)

        # Parse and shuffle the image list.
        with open(img_list_txt) as f:
            img_infos = [line.strip() for line in f.readlines()]
        random.shuffle(img_infos)

        for img_info in img_infos:
            img_path, ctg_idx = img_info.rsplit(' ', 1)
            img_name = img_path.split('/')[-1]
            self.img_names_.append(img_name)
            self.img_paths_.append(img_path)
            self.img_labels_.append(int(ctg_idx))

        for ctg_name in self.ctg_names_:
            print('log: category is %d(%s), image num is %d' % (
                self.ctg_name2idx_[ctg_name], ctg_name,
                self.img_labels_.count(self.ctg_name2idx_[ctg_name])))
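
    # Example file contents for load_data_from_txt (hypothetical names and paths,
    # shown only to illustrate the expected line format):
    #
    #   ctg_list_txt:        img_list_txt:
    #     real 0               /data/real/0001.jpg 0
    #     fake 1               /data/fake/0001.jpg 1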
|

    def _add_new_channels_worker(self, image):
        new_channels = []

        # Normalize with ImageNet statistics before extracting SRM residuals.
        image = einops.rearrange(image, "h w c -> c h w")
        mean = torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_MEAN).view(-1, 1, 1)
        std = torch.as_tensor(timm.data.constants.IMAGENET_DEFAULT_STD).view(-1, 1, 1)
        image = (image - mean) / std

        srm = self.srm(image.unsqueeze(0)).squeeze(0)
        new_channels.append(einops.rearrange(srm, "c h w -> h w c").numpy())

        new_channels = np.concatenate(new_channels, axis=2)
        return torch.from_numpy(new_channels).float()

    def add_new_channels(self, images):
        # (3, H, W) RGB -> (6, H, W): original channels plus SRM residual channels.
        images_copied = einops.rearrange(images, "c h w -> h w c")
        new_channels = self._add_new_channels_worker(images_copied)
        images_copied = torch.cat([images_copied, new_channels], dim=-1)
        images_copied = einops.rearrange(images_copied, "h w c -> c h w")

        return images_copied
|

    def __getitem__(self, index):
        img_path = self.img_paths_[index]

        img_data = Image.open(img_path).convert('RGB')
        img_size = img_data.size[::-1]  # PIL gives (W, H); store (H, W).

        # Apply each transform and append SRM residual channels to every view.
        all_data = []
        for transform in self.transformer_:
            current_data = transform(img_data)
            current_data = self.add_new_channels(current_data)
            all_data.append(current_data)
        img_label = self.img_labels_[index]

        return torch.stack(all_data, dim=0), img_label, img_path, img_size

    def __len__(self):
        return len(self.img_names_)
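

# A minimal end-to-end sketch (assumptions for illustration: torchvision is
# installed, the folder names below exist, and each transform ends with
# ToTensor so add_new_channels receives a (3, H, W) float tensor).
if __name__ == "__main__":
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    dataset = MultiClassificationProcessor_mfolder(transform=[transform])
    dataset.load_data_from_dir({'real': ['data/real'], 'fake': ['data/fake']})  # hypothetical folders

    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    images, labels, paths, sizes = next(iter(loader))
    print(images.shape)  # (4, 1, 6, 224, 224): one transform, RGB + 3 SRM channels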
|