"""
Edge datasets built from the temporal complex.
"""
from abc import ABC, abstractmethod
from torch.utils.data import Dataset
from PIL import Image
import tensorflow as tf
import numpy as np

class DataHandlerAbstractClass(Dataset, ABC):
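    """Abstract base for edge datasets: stores edge endpoints and the per-sample feature vectors."""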
    def __init__(self, edge_to, edge_from, feature_vector) -> None:
        super().__init__()
        self.edge_to = edge_to
        self.edge_from = edge_from
        self.data = feature_vector

    @abstractmethod
    def __getitem__(self, item):
        pass

    @abstractmethod
    def __len__(self):
        pass

class DataHandler(Dataset):
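    """Edge dataset returning (edge_to, edge_from, attention_to, attention_from) per edge."""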
    def __init__(self, edge_to, edge_from, feature_vector, attention, transform=None):
        self.edge_to = edge_to
        self.edge_from = edge_from
        self.data = feature_vector
        self.attention = attention
        self.transform = transform

    def __getitem__(self, item):
        edge_to_idx = self.edge_to[item]
        edge_from_idx = self.edge_from[item]
        edge_to = self.data[edge_to_idx]
        edge_from = self.data[edge_from_idx]
        a_to = self.attention[edge_to_idx]
        a_from = self.attention[edge_from_idx]
        if self.transform is not None:
            # TODO: verify; this assumes the feature is an image array that PIL
            # can convert before the transform is applied
            edge_to = Image.fromarray(edge_to)
            edge_to = self.transform(edge_to)
            edge_from = Image.fromarray(edge_from)
            edge_from = self.transform(edge_from)
        return edge_to, edge_from, a_to, a_from

    def __len__(self):
        # return the number of all edges
        return len(self.edge_to)
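
# Minimal usage sketch (not part of the original code; the toy shapes below are
# assumptions): DataHandler is a plain torch Dataset, so it plugs straight into
# a DataLoader.
#
#     from torch.utils.data import DataLoader
#
#     edge_to = np.array([0, 1, 2])
#     edge_from = np.array([1, 2, 0])
#     features = np.random.rand(3, 512).astype(np.float32)
#     attention = np.ones_like(features)
#     loader = DataLoader(DataHandler(edge_to, edge_from, features, attention),
#                         batch_size=2, shuffle=True)
#     edge_to_batch, edge_from_batch, a_to, a_from = next(iter(loader))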


class HybridDataHandler(Dataset):
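    """Edge dataset that, in addition to the DataHandler outputs, returns the replayed
    embedding and its coefficient for the head (edge_to) sample of each edge."""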
    def __init__(self, edge_to, edge_from, feature_vector, attention, embedded, coefficient, transform=None):
        self.edge_to = edge_to
        self.edge_from = edge_from
        self.data = feature_vector
        self.attention = attention
        self.embedded = embedded  # positions generated by a previous visualization, replayed here
        self.coefficient = coefficient  # per-sample flag: whether a previously generated position exists
        self.transform = transform

    def __getitem__(self, item):
        edge_to_idx = self.edge_to[item]
        edge_from_idx = self.edge_from[item]
        edge_to = self.data[edge_to_idx]
        edge_from = self.data[edge_from_idx]
        a_to = self.attention[edge_to_idx]
        a_from = self.attention[edge_from_idx]

        embedded_to = self.embedded[edge_to_idx]
        coeffi_to = self.coefficient[edge_to_idx]
        if self.transform is not None:
            # TODO: verify; this assumes the feature is an image array that PIL
            # can convert before the transform is applied
            edge_to = Image.fromarray(edge_to)
            edge_to = self.transform(edge_to)
            edge_from = Image.fromarray(edge_from)
            edge_from = self.transform(edge_from)
        return edge_to, edge_from, a_to, a_from, embedded_to, coeffi_to

    def __len__(self):
        # return the number of all edges
        return len(self.edge_to)
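
# Sketch of the extra fields returned per item (reusing the toy arrays from the
# DataHandler sketch above; the 2-D embedding shape is an assumption):
#
#     embedded = np.zeros((3, 2), dtype=np.float32)       # previously generated positions
#     coefficient = np.array([1.0, 0.0, 1.0], dtype=np.float32)
#     hybrid = HybridDataHandler(edge_to, edge_from, features, attention, embedded, coefficient)
#     edge_to_0, edge_from_0, a_to, a_from, embedded_to, coeffi_to = hybrid[0]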

class DVIDataHandler(Dataset):
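    """Edge dataset for DVI; returns the same (edge_to, edge_from, attention_to, attention_from)
    tuple as DataHandler."""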
    def __init__(self, edge_to, edge_from, feature_vector, attention, transform=None):
        self.edge_to = edge_to
        self.edge_from = edge_from
        self.data = feature_vector
        self.attention = attention
        self.transform = transform

    def __getitem__(self, item):
        edge_to_idx = self.edge_to[item]
        edge_from_idx = self.edge_from[item]
        edge_to = self.data[edge_to_idx]
        edge_from = self.data[edge_from_idx]
        a_to = self.attention[edge_to_idx]
        a_from = self.attention[edge_from_idx]
        if self.transform is not None:
            # TODO: verify; this assumes the feature is an image array that PIL
            # can convert before the transform is applied
            edge_to = Image.fromarray(edge_to)
            edge_to = self.transform(edge_to)
            edge_from = Image.fromarray(edge_from)
            edge_from = self.transform(edge_from)
        return edge_to, edge_from, a_to, a_from

    def __len__(self):
        # return the number of all edges
        return len(self.edge_to)

# tf.data edge dataset (TensorFlow counterpart of the torch handlers above)
def construct_edge_dataset(
    edges_to_exp, edges_from_exp, weight, data, alpha, n_rate, batch_size
):
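    """
    Build a repeating, shuffled tf.data pipeline that yields batches of edges.

    Parameter descriptions below are inferred from how the values are used here:
    edges_to_exp / edges_from_exp: integer sample indices of each edge's endpoints.
    weight: per-edge weight, passed through as a model input.
    data: per-sample feature vectors, gathered for both edge endpoints.
    alpha: per-sample attention values, indexed like `data`.
    n_rate: per-sample rate values, gathered for the edge head only.
    batch_size: number of edges per batch (incomplete final batches are dropped).
    """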

    # Helpers run eagerly through tf.py_function so that the (potentially large)
    # `data` / `alpha` arrays are not baked into the graph as constants.
    def gather_index(index):
        return data[index]

    def gather_alpha(index):
        return alpha[index]

    # Gather through Python (tf.py_function) when the feature data exceeds ~0.5 GB.
    gather_indices_in_python = data.nbytes * 1e-9 > 0.5

    def gather_X(edge_to, edge_from, weight):
        if gather_indices_in_python:
            edge_to_batch = tf.py_function(gather_index, [edge_to], [tf.float32])[0]
            edge_from_batch = tf.py_function(gather_index, [edge_from], [tf.float32])[0]
            alpha_to = tf.py_function(gather_alpha, [edge_to], [tf.float32])[0]
            alpha_from = tf.py_function(gather_alpha, [edge_from], [tf.float32])[0]
        else:
            edge_to_batch = tf.gather(data, edge_to)
            edge_from_batch = tf.gather(data, edge_from)
            alpha_to = tf.gather(alpha, edge_to)
            alpha_from = tf.gather(alpha, edge_from)
            
        # `n_rates` is assigned below, before the dataset map is executed
        to_n_rate = tf.gather(n_rates, edge_to)
        # placeholder target for the umap loss; the reconstruction target is the input itself
        outputs = {"umap": 0, "reconstruction": edge_to_batch}

        return (edge_to_batch, edge_from_batch, alpha_to, alpha_from, to_n_rate, weight), outputs

    # shuffle edges
    shuffle_mask = np.random.permutation(len(edges_to_exp))
    edges_to_exp = edges_to_exp[shuffle_mask].astype(np.int64)
    edges_from_exp = edges_from_exp[shuffle_mask].astype(np.int64)
    weight = weight[shuffle_mask].astype(np.float64)
    weight = np.expand_dims(weight, axis=1)
    n_rates = np.expand_dims(n_rate, axis=1)

    # create edge iterator
    edge_dataset = tf.data.Dataset.from_tensor_slices(
        (edges_to_exp, edges_from_exp, weight)
    )
    edge_dataset = edge_dataset.repeat()
    edge_dataset = edge_dataset.shuffle(10000)
    edge_dataset = edge_dataset.batch(batch_size, drop_remainder=True)
    edge_dataset = edge_dataset.map(
        gather_X, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    edge_dataset = edge_dataset.prefetch(10)
    return edge_dataset
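
# Minimal sketch of driving construct_edge_dataset (toy shapes chosen purely for
# illustration; in practice the resulting dataset would feed a Keras training loop):
#
#     n_samples, dim, n_edges = 100, 512, 1000
#     data = np.random.rand(n_samples, dim).astype(np.float32)
#     alpha = np.ones((n_samples, dim), dtype=np.float32)
#     n_rate = np.ones(n_samples, dtype=np.float32)
#     edges_to = np.random.randint(0, n_samples, size=n_edges)
#     edges_from = np.random.randint(0, n_samples, size=n_edges)
#     weight = np.random.rand(n_edges)
#     ds = construct_edge_dataset(edges_to, edges_from, weight,
#                                 data, alpha, n_rate, batch_size=64)
#     inputs, outputs = next(iter(ds))
#     # inputs:  (edge_to_batch, edge_from_batch, alpha_to, alpha_from, to_n_rate, weight)
#     # outputs: {"umap": ..., "reconstruction": edge_to_batch}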