from torch.utils.data import Dataset
import numpy as np
import h5py
import logging
from skimage.feature import peak_local_max
from skimage import measure

def clean_patch(p, center):
    """Keep only the connected component of ``p`` whose local maximum is closest to ``center``."""
    cc = measure.label(p > 0)
    if cc.max() == 1:
        return p

    # logging.warning(f"{cc.max()} peaks located in a patch")
    lmin = np.inf
    cc_lmin = None
    for _c in range(1, cc.max() + 1):
        lmax = peak_local_max(p * (cc == _c), min_distance=1)
        if lmax.shape[0] == 0:
            continue  # single-pixel component, no local maximum found
        lc = lmax.mean(axis=0)
        dist = ((lc - center) ** 2).sum()
        if dist < lmin:
            cc_lmin = _c
            lmin = dist
    return p * (cc == cc_lmin)
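
# Illustrative example (not part of the original code; values below are synthetic):
# given a patch with two blobs, clean_patch keeps only the component whose local
# maximum lies closest to the supplied center and zeroes out the other one.
#
#   patch = np.zeros((15, 15), dtype=np.float32)
#   patch[7, 7] = 5.0                 # main peak near the patch center
#   patch[6, 7] = patch[8, 7] = 1.0
#   patch[2, 2] = 4.0                 # spurious secondary blob
#   patch[2, 3] = 1.0
#   cleaned = clean_patch(patch, np.array([7.0, 7.0]))
#   # cleaned retains the central blob; the blob around (2, 2) is zeroed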

class BraggNNDataset(Dataset):
    def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=0.8):
        self.psz = psz
        self.rnd_shift = rnd_shift

        with h5py.File(pfile, "r") as h5fd: 
            if use == 'train':
                sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
            elif use == 'validation':
                sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
            else:
                logging.error(f"unsupported use: {use}. This class is written for building either a training or a validation set")
                raise ValueError(f"unsupported use: {use}")

            mask = h5fd['npeaks'][sti:edi] == 1 # use only single-peak patches
            mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))

            self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
            self.peak_row  = h5fd['peak_row'][sti:edi][mask]
            self.peak_col  = h5fd['peak_col'][sti:edi][mask]

        self.fidx_base = self.peak_fidx.min()
        # only load the frames that will actually be used
        with h5py.File(ffile, "r") as h5fd: 
            self.frames = h5fd['frames'][self.peak_fidx.min():self.peak_fidx.max()+1]

        self.len = self.peak_fidx.shape[0]

    def __getitem__(self, idx):
        _frame = self.frames[self.peak_fidx[idx] - self.fidx_base]
        if self.rnd_shift > 0:
            row_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
            col_shift = np.random.randint(-self.rnd_shift, self.rnd_shift+1)
        else:
            row_shift, col_shift = 0, 0
        prow_rnd  = int(self.peak_row[idx]) + row_shift
        pcol_rnd  = int(self.peak_col[idx]) + col_shift

        row_base = max(0, prow_rnd-self.psz//2)
        col_base = max(0, pcol_rnd-self.psz//2 )

        crop_img = _frame[row_base:(prow_rnd + self.psz//2 + self.psz%2),
                          col_base:(pcol_rnd + self.psz//2 + self.psz%2)]
        # if((crop_img > 0).sum() == 1): continue # ignore single non-zero peak
        if crop_img.size != self.psz ** 2:
            c_pad_l = (self.psz - crop_img.shape[1]) // 2
            c_pad_r = self.psz - c_pad_l - crop_img.shape[1]

            r_pad_t = (self.psz - crop_img.shape[0]) // 2
            r_pad_b = self.psz - r_pad_t - crop_img.shape[0]

            logging.warn(f"sample {idx} touched edge when crop the patch: {crop_img.shape}")
            crop_img = np.pad(crop_img, ((r_pad_t, r_pad_b), (c_pad_l, c_pad_r)), mode='constant')
        else:
            c_pad_l, r_pad_t = 0, 0

        _center = np.array([self.peak_row[idx] - row_base + r_pad_t, self.peak_col[idx] - col_base + c_pad_l])
        crop_img = clean_patch(crop_img, _center)
        if crop_img.max() != crop_img.min():
            _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
            feature = (crop_img - _min) / (_max - _min)
        else:
            logging.warning("sample %d has a constant-intensity patch (sum=%d)" % (idx, crop_img.sum()))
            feature = crop_img

        px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
        py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz

        return feature[np.newaxis], np.array([px, py]).astype(np.float32)

    def __len__(self):
        return self.len


# Unlike BraggNNDataset, this dataset treats each entry of 'frames' in ffile as an
# already-cropped patch, so __getitem__ performs no cropping, padding, or random shift.
class PatchWiseDataset(Dataset):
    def __init__(self, pfile, ffile, psz=15, rnd_shift=0, use='train', train_frac=1):
        self.psz = psz
        self.rnd_shift = rnd_shift
        with h5py.File(pfile, "r") as h5fd: 
            if use == 'train':
                sti, edi = 0, int(train_frac * h5fd['peak_fidx'].shape[0])
            elif use == 'validation':
                sti, edi = int(train_frac * h5fd['peak_fidx'].shape[0]), None
            else:
                logging.error(f"unsupported use: {use}. This class is written for building either a training or a validation set")
                raise ValueError(f"unsupported use: {use}")

            mask = h5fd['npeaks'][sti:edi] == 1 # use only single-peak patches
            mask = mask & ((h5fd['deviations'][sti:edi] >= 0) & (h5fd['deviations'][sti:edi] < 1))

            self.peak_fidx = h5fd['peak_fidx'][sti:edi][mask]
            self.peak_row  = h5fd['peak_row'][sti:edi][mask]
            self.peak_col  = h5fd['peak_col'][sti:edi][mask]

        self.fidx_base = self.peak_fidx.min()
        # only load the frames that will actually be used
        with h5py.File(ffile, 'r') as h5fd: 
            if use == 'train':
                sti, edi = 0, int(train_frac * h5fd['frames'].shape[0])
            elif use == 'validation':
                sti, edi = int(train_frac * h5fd['frames'].shape[0]), None
            else:
                logging.error(f"unsupported use: {use}. This class is written for building either a training or a validation set")
                raise ValueError(f"unsupported use: {use}")

            self.crop_img = h5fd['frames'][sti:edi]
            self.len = self.peak_fidx.shape[0]
    
    def __getitem__(self, idx):
        crop_img = self.crop_img[idx] 
        
        # rnd_shift is kept for API compatibility; no random shift is applied to pre-cropped patches
        row_shift, col_shift = 0, 0
        c_pad_l, r_pad_t = 0, 0
        prow_rnd  = int(self.peak_row[idx]) + row_shift
        pcol_rnd  = int(self.peak_col[idx]) + col_shift

        row_base = max(0, prow_rnd-self.psz//2)
        col_base = max(0, pcol_rnd-self.psz//2)
        
        if crop_img.max() != crop_img.min():
            _min, _max = crop_img.min().astype(np.float32), crop_img.max().astype(np.float32)
            feature = (crop_img - _min) / (_max - _min)
        else:
            #logging.warn("sample %d has unique intensity sum of %d" % (idx, crop_img.sum()))
            feature = crop_img

        px = (self.peak_col[idx] - col_base + c_pad_l) / self.psz
        py = (self.peak_row[idx] - row_base + r_pad_t) / self.psz

        return feature[np.newaxis], np.array([px, py]).astype(np.float32)

    def __len__(self):
        return self.len
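

# Minimal usage sketch (assumptions: the HDF5 file names below are hypothetical and
# the files follow the peak/frame layout expected by the constructors above).
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # build matching training and validation splits from the same files
    ds_train = BraggNNDataset('peaks.h5', 'frames.h5', psz=15, rnd_shift=1, use='train', train_frac=0.8)
    ds_valid = BraggNNDataset('peaks.h5', 'frames.h5', psz=15, rnd_shift=0, use='validation', train_frac=0.8)

    dl_train = DataLoader(ds_train, batch_size=64, shuffle=True, drop_last=True)
    dl_valid = DataLoader(ds_valid, batch_size=64, shuffle=False)

    # each batch is typically a (N, 1, psz, psz) float tensor plus an (N, 2) tensor of
    # normalized (col, row) peak positions in [0, 1)
    features, targets = next(iter(dl_train))
    print(features.shape, targets.shape)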