|
from torch.utils.data import Dataset |
|
import numpy as np |
|
import h5py, torch, random, logging |
|
from skimage.feature import peak_local_max |
|
from skimage import measure |
|
|
|
def clean_patch(p, center):
    """Keep only the connected component of `p` whose peak centroid is
    closest to `center`, zeroing out every other component.

    Args:
        p: 2-D intensity patch; values > 0 are treated as foreground.
        center: (row, col) expected peak position within the patch.

    Returns:
        `p` unchanged when it has at most one connected component (or when
        no component exhibits a local maximum); otherwise `p` masked down
        to the best-matching component.
    """
    cc = measure.label(p > 0)
    # 0 components (all-background) or exactly 1: nothing to clean.
    # (Previously an all-background patch fell through to `cc == None`,
    # which relies on deprecated elementwise comparison semantics.)
    if cc.max() <= 1:
        return p

    lmin = np.inf
    cc_lmin = None
    for _c in range(1, cc.max() + 1):
        # Mean of the component's local maxima approximates its peak centroid.
        lmax = peak_local_max(p * (cc == _c), min_distance=1)
        if lmax.shape[0] == 0:
            continue  # flat component: no detectable peak
        lc = lmax.mean(axis=0)
        # Squared distance is sufficient for an argmin comparison.
        dist = ((lc - center) ** 2).sum()
        if dist < lmin:
            cc_lmin = _c
            lmin = dist

    if cc_lmin is None:
        # No component produced a local maximum; masking with None would
        # silently zero the entire patch, so return the input untouched.
        return p
    return p * (cc == cc_lmin)
|
|
|
class BraggNNDataset(Dataset):
    """Peak-patch dataset for BraggNN training/validation.

    Crops a `psz` x `psz` patch around each single-peak location in the
    diffraction frames and yields (patch, normalized peak position) pairs.

    Args:
        pfile: HDF5 file with peak metadata ('peak_fidx', 'peak_row',
               'peak_col', 'npeaks', 'deviations').
        ffile: HDF5 file holding the raw detector 'frames'.
        psz: output patch size in pixels.
        rnd_shift: max random +/- shift (pixels) of the crop center,
                   used as a data-augmentation jitter.
        use: split selector, 'train' or 'validation'.
        train_frac: fraction of peaks assigned to the training split.

    Raises:
        ValueError: if `use` is neither 'train' nor 'validation'.
    """
    def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=0.8):
        # Fail fast on a bad split name. The original code only logged an
        # error here and then crashed later with NameError on undefined
        # sti/edi; raising up front makes the failure explicit.
        if use not in ('train', 'validation'):
            raise ValueError(f"unsupported use: {use}. This class is written for building either training or validation set")

        self.psz = psz
        self.rnd_shift = rnd_shift

        with h5py.File(pfile, "r") as h5fd:
            if use == 'train':
                sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
            else:  # 'validation': everything after the training fraction
                sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None

            # Keep only clean samples: exactly one peak and deviation in [0, 1).
            mask = h5fd['npeaks'][sti:edi] == 1
            mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))

            self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
            self.peak_row = h5fd['peak_row'][sti:edi][mask]
            self.peak_col = h5fd['peak_col'][sti:edi][mask]

        # Load only the frame range actually referenced by the kept peaks.
        self.fidx_base = self.peak_fidx.min()

        with h5py.File(ffile, "r") as h5fd:
            self.frames = h5fd['frames'][self.peak_fidx.min():self.peak_fidx.max()+1]

        self.len = self.peak_fidx.shape[0]

    def __getitem__(self, idx):
        """Return (1 x psz x psz patch, float32 (px, py) target in [0, 1))."""
        _frame = self.frames[self.peak_fidx[idx] - self.fidx_base]
        # Augmentation: jitter the crop window around the annotated peak.
        if self.rnd_shift > 0:
            row_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
            col_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
        else:
            row_shift, col_shift = 0, 0
        prow_rnd = int(self.peak_row[idx]) + row_shift
        pcol_rnd = int(self.peak_col[idx]) + col_shift

        row_base = max(0, prow_rnd - self.psz//2)
        col_base = max(0, pcol_rnd - self.psz//2)

        crop_img = _frame[row_base:(prow_rnd + self.psz//2 + self.psz%2), \
                          col_base:(pcol_rnd + self.psz//2 + self.psz%2)]

        # The crop comes out smaller than psz x psz when the peak touches a
        # frame edge; zero-pad it back to full size, keeping it centered.
        if crop_img.size != self.psz ** 2:
            c_pad_l = (self.psz - crop_img.shape[1]) // 2
            c_pad_r = self.psz - c_pad_l - crop_img.shape[1]

            r_pad_t = (self.psz - crop_img.shape[0]) // 2
            r_pad_b = self.psz - r_pad_t - crop_img.shape[0]

            # logging.warn is a deprecated alias; use logging.warning.
            logging.warning(f"sample {idx} touched edge when crop the patch: {crop_img.shape}")
            crop_img = np.pad(crop_img, ((r_pad_t, r_pad_b), (c_pad_l, c_pad_r)), mode='constant')
        else:
            c_pad_l, r_pad_t = 0, 0

        # Suppress secondary connected components that leaked into the crop.
        _center = np.array([self.peak_row[idx] - row_base + r_pad_t, self.peak_col[idx] - col_base + c_pad_l])
        crop_img = clean_patch(crop_img, _center)
        # Min-max normalize, unless the patch is constant (avoid div by zero).
        if crop_img.max() != crop_img.min():
            _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
            feature = (crop_img - _min) / (_max - _min)
        else:
            logging.warning("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
            feature = crop_img

        # Peak position normalized by the patch size -> regression target.
        px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
        py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz

        # Leading new axis gives the single-channel (1, psz, psz) layout.
        return feature[np.newaxis], np.array([px, py]).astype(np.float32)

    def __len__(self):
        return self.len
|
|
|
|
|
class PatchWiseDataset(Dataset):
    """Dataset over pre-cropped peak patches stored directly in HDF5.

    Unlike BraggNNDataset, `ffile['frames']` already holds one patch per
    peak, so __getitem__ only normalizes intensities and computes the
    (px, py) regression target. `rnd_shift` is accepted for interface
    parity but no jitter is applied here (shifts are fixed to 0).

    Args:
        pfile: HDF5 file with peak metadata ('peak_fidx', 'peak_row',
               'peak_col', 'npeaks', 'deviations').
        ffile: HDF5 file holding pre-cropped patches in 'frames'.
        psz: patch size in pixels.
        rnd_shift: unused; kept for signature compatibility.
        use: split selector, 'train' or 'validation'.
        train_frac: fraction of samples assigned to the training split.

    Raises:
        ValueError: if `use` is neither 'train' nor 'validation'.
    """
    def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=1):
        # Fail fast on a bad split name. The original code only logged an
        # error (in both h5 blocks) and then crashed with NameError on
        # undefined sti/edi; raising up front makes the failure explicit.
        if use not in ('train', 'validation'):
            raise ValueError(f"unsupported use: {use}. This class is written for building either training or validation set")

        self.psz = psz
        self.rnd_shift = rnd_shift
        with h5py.File(pfile, "r") as h5fd:
            if use == 'train':
                sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
            else:  # 'validation'
                sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None

            # Keep only clean samples: exactly one peak, deviation in [0, 1).
            mask = h5fd['npeaks'][sti:edi] == 1
            mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))

            self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
            self.peak_row = h5fd['peak_row'][sti:edi][mask]
            self.peak_col = h5fd['peak_col'][sti:edi][mask]

        self.fidx_base = self.peak_fidx.min()

        with h5py.File(ffile, 'r') as h5fd:
            # NOTE(review): the split here is a fraction of the *frames*
            # axis while the mask above filters peaks — these are assumed
            # to stay aligned one-to-one; verify against the data producer.
            if use == 'train':
                sti, edi = 0, int(train_frac * h5fd['frames'].shape[0])
            else:  # 'validation'
                sti, edi = int(train_frac * h5fd['frames'].shape[0]), None

            self.crop_img = h5fd['frames'][sti:edi]
        self.len = self.peak_fidx.shape[0]

    def __getitem__(self, idx):
        """Return (1 x psz x psz patch, float32 (px, py) target in [0, 1))."""
        crop_img = self.crop_img[idx]

        # No augmentation / padding for pre-cropped patches.
        row_shift, col_shift = 0, 0
        c_pad_l, r_pad_t = 0, 0
        prow_rnd = int(self.peak_row[idx]) + row_shift
        pcol_rnd = int(self.peak_col[idx]) + col_shift

        row_base = max(0, prow_rnd - self.psz//2)
        col_base = max(0, pcol_rnd - self.psz//2)

        # Min-max normalize, unless the patch is constant (avoid div by zero).
        if crop_img.max() != crop_img.min():
            _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
            feature = (crop_img - _min) / (_max - _min)
        else:
            feature = crop_img

        # Peak position normalized by the patch size -> regression target.
        px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
        py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz

        # Leading new axis gives the single-channel (1, psz, psz) layout.
        return feature[np.newaxis], np.array([px, py]).astype(np.float32)

    def __len__(self):
        return self.len
|
|
|
|