from torch.utils.data import Dataset
import numpy as np
import h5py, torch, random, logging
from skimage.feature import peak_local_max
from skimage import measure


def clean_patch(p, center):
    """Keep only the connected component of `p` closest to `center`.

    Labels the non-zero pixels of the patch; if more than one component is
    present, zeroes out every component except the one whose local-maxima
    centroid is nearest (squared Euclidean distance) to `center`.

    Args:
        p: 2-D ndarray intensity patch (background pixels are 0).
        center: (row, col) expected peak position inside the patch.

    Returns:
        2-D ndarray; `p` unchanged when it already has a single component
        (or when no component with a detectable local max exists), otherwise
        `p` masked down to the chosen component.
    """
    cc = measure.label(p > 0)
    if cc.max() == 1:
        return p

    # logging.warning(f"{cc.max()} peaks located in a patch")
    lmin = np.inf
    cc_lmin = None
    for _c in range(1, cc.max() + 1):
        lmax = peak_local_max(p * (cc == _c), min_distance=1)
        if lmax.shape[0] == 0:
            continue  # single-pixel component, no local max to score
        lc = lmax.mean(axis=0)
        dist = ((lc - center) ** 2).sum()
        if dist < lmin:
            cc_lmin = _c
            lmin = dist
    # BUGFIX: previously, if every component was a single pixel (or the patch
    # was empty), cc_lmin stayed None and `cc == None` produced an all-False
    # mask, silently zeroing the entire patch. Return the input instead.
    if cc_lmin is None:
        return p
    return p * (cc == cc_lmin)


class BraggNNDataset(Dataset):
    """Peak-patch dataset cropped on the fly from full detector frames.

    Reads peak metadata from `pfile` and frames from `ffile` (both HDF5),
    keeps only single-peak entries with deviation in [0, 1), and yields
    (1, psz, psz) normalized patches plus the fractional (px, py) peak
    position within the patch.
    """

    def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=0.8):
        """
        Args:
            pfile: HDF5 with 'peak_fidx', 'peak_row', 'peak_col', 'npeaks',
                   'deviations' datasets.
            ffile: HDF5 with a 'frames' dataset indexed by peak_fidx.
            psz: output patch size (psz x psz).
            rnd_shift: max random crop jitter (pixels) for augmentation.
            use: 'train' or 'validation' split selector.
            train_frac: fraction of peaks assigned to the training split.

        Raises:
            ValueError: if `use` is neither 'train' nor 'validation'.
        """
        self.psz = psz
        self.rnd_shift = rnd_shift

        with h5py.File(pfile, "r") as h5fd:
            if use == 'train':
                sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
            elif use == 'validation':
                sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
            else:
                # BUGFIX: was logging.error only, which then crashed with a
                # NameError on undefined sti/edi; fail fast and explicitly.
                logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
                raise ValueError(f"unsupported use: {use}")

            mask = h5fd['npeaks'][sti:edi] == 1  # use only single-peak patches
            mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))

            self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
            self.peak_row = h5fd['peak_row'][sti:edi][mask]
            self.peak_col = h5fd['peak_col'][sti:edi][mask]

        self.fidx_base = self.peak_fidx.min()
        # only load frames that will be used by this split
        with h5py.File(ffile, "r") as h5fd:
            self.frames = h5fd['frames'][self.peak_fidx.min():self.peak_fidx.max() + 1]

        self.len = self.peak_fidx.shape[0]

    def __getitem__(self, idx):
        """Crop, pad, clean and min-max normalize the patch for peak `idx`.

        Returns:
            feature: float ndarray of shape (1, psz, psz).
            position: float32 ndarray [px, py], peak location as a fraction
                      of the patch size.
        """
        _frame = self.frames[self.peak_fidx[idx] - self.fidx_base]
        if self.rnd_shift > 0:
            row_shift = np.random.randint(-self.rnd_shift, self.rnd_shift + 1)
            col_shift = np.random.randint(-self.rnd_shift, self.rnd_shift + 1)
        else:
            row_shift, col_shift = 0, 0
        prow_rnd = int(self.peak_row[idx]) + row_shift
        pcol_rnd = int(self.peak_col[idx]) + col_shift

        row_base = max(0, prow_rnd - self.psz // 2)
        col_base = max(0, pcol_rnd - self.psz // 2)

        crop_img = _frame[row_base:(prow_rnd + self.psz // 2 + self.psz % 2),
                          col_base:(pcol_rnd + self.psz // 2 + self.psz % 2)]
        # pad symmetrically when the crop window ran off the frame edge
        if crop_img.size != self.psz ** 2:
            c_pad_l = (self.psz - crop_img.shape[1]) // 2
            c_pad_r = self.psz - c_pad_l - crop_img.shape[1]

            r_pad_t = (self.psz - crop_img.shape[0]) // 2
            r_pad_b = self.psz - r_pad_t - crop_img.shape[0]

            logging.warning(f"sample {idx} touched edge when crop the patch: {crop_img.shape}")
            crop_img = np.pad(crop_img, ((r_pad_t, r_pad_b), (c_pad_l, c_pad_r)), mode='constant')
        else:
            c_pad_l, r_pad_t = 0, 0

        _center = np.array([self.peak_row[idx] - row_base + r_pad_t,
                            self.peak_col[idx] - col_base + c_pad_l])
        crop_img = clean_patch(crop_img, _center)
        if crop_img.max() != crop_img.min():
            _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
            feature = (crop_img - _min) / (_max - _min)
        else:
            # NOTE(review): constant patch is returned un-normalized, so its
            # dtype may differ from the normalized branch — confirm downstream
            # tolerates this.
            logging.warning("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
            feature = crop_img

        px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
        py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz

        return feature[np.newaxis], np.array([px, py]).astype(np.float32)

    def __len__(self):
        return self.len


class PatchWiseDataset(Dataset):
    """Dataset over pre-cropped patches stored directly in `ffile`.

    Unlike BraggNNDataset, frames in `ffile` are already patch-sized; no
    cropping, shifting or padding is performed (shift/pad offsets are kept
    at 0 so the label arithmetic mirrors BraggNNDataset).
    """

    def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=1):
        """Same arguments as BraggNNDataset.__init__; `ffile['frames']` holds
        pre-cropped patches, split by the same train_frac.

        Raises:
            ValueError: if `use` is neither 'train' nor 'validation'.
        """
        self.psz = psz
        self.rnd_shift = rnd_shift

        with h5py.File(pfile, "r") as h5fd:
            if use == 'train':
                sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
            elif use == 'validation':
                sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
            else:
                # BUGFIX: was logging.error only, followed by NameError on
                # undefined sti/edi; fail fast and explicitly.
                logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
                raise ValueError(f"unsupported use: {use}")

            mask = h5fd['npeaks'][sti:edi] == 1  # use only single-peak patches
            mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))

            self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
            self.peak_row = h5fd['peak_row'][sti:edi][mask]
            self.peak_col = h5fd['peak_col'][sti:edi][mask]

        self.fidx_base = self.peak_fidx.min()
        # only load patches that belong to this split
        with h5py.File(ffile, 'r') as h5fd:
            if use == 'train':
                sti, edi = 0, int(train_frac * h5fd['frames'].shape[0])
            elif use == 'validation':
                sti, edi = int(train_frac * h5fd['frames'].shape[0]), None
            else:
                logging.error(f"unsupported use: {use}. This class is written for building either training or validation set")
                raise ValueError(f"unsupported use: {use}")
            self.crop_img = h5fd['frames'][sti:edi]

        self.len = self.peak_fidx.shape[0]

    def __getitem__(self, idx):
        """Min-max normalize patch `idx` and return it with its (px, py) label."""
        crop_img = self.crop_img[idx]
        # patches are pre-cropped: no jitter, no padding offsets
        row_shift, col_shift = 0, 0
        c_pad_l, r_pad_t = 0, 0
        prow_rnd = int(self.peak_row[idx]) + row_shift
        pcol_rnd = int(self.peak_col[idx]) + col_shift

        row_base = max(0, prow_rnd - self.psz // 2)
        col_base = max(0, pcol_rnd - self.psz // 2)

        if crop_img.max() != crop_img.min():
            _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
            feature = (crop_img - _min) / (_max - _min)
        else:
            # logging.warning("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
            feature = crop_img

        px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
        py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz

        return feature[np.newaxis], np.array([px, py]).astype(np.float32)

    def __len__(self):
        return self.len