""" Edge dataset from temporal complex """ from abc import ABC, abstractmethod from torch.utils.data import Dataset from PIL import Image import tensorflow as tf import numpy as np class DataHandlerAbstractClass(Dataset, ABC): def __init__(self, edge_to, edge_from, feature_vector) -> None: super().__init__() self.edge_to = edge_to self.edge_from = edge_from self.data = feature_vector @abstractmethod def __getitem__(self, item): pass @abstractmethod def __len__(self): pass class DataHandler(Dataset): def __init__(self, edge_to, edge_from, feature_vector, attention, transform=None): self.edge_to = edge_to self.edge_from = edge_from self.data = feature_vector self.attention = attention self.transform = transform def __getitem__(self, item): edge_to_idx = self.edge_to[item] edge_from_idx = self.edge_from[item] edge_to = self.data[edge_to_idx] edge_from = self.data[edge_from_idx] a_to = self.attention[edge_to_idx] a_from = self.attention[edge_from_idx] if self.transform is not None: # TODO correct or not? edge_to = Image.fromarray(edge_to) edge_to = self.transform(edge_to) edge_from = Image.fromarray(edge_from) edge_from = self.transform(edge_from) return edge_to, edge_from, a_to, a_from def __len__(self): # return the number of all edges return len(self.edge_to) class HybridDataHandler(Dataset): def __init__(self, edge_to, edge_from, feature_vector, attention, embedded, coefficient, transform=None): self.edge_to = edge_to self.edge_from = edge_from self.data = feature_vector self.attention = attention self.embedded = embedded # replay of positions generated by previous visuaization self.coefficient = coefficient # whether samples have generated positions self.transform = transform def __getitem__(self, item): edge_to_idx = self.edge_to[item] edge_from_idx = self.edge_from[item] edge_to = self.data[edge_to_idx] edge_from = self.data[edge_from_idx] a_to = self.attention[edge_to_idx] a_from = self.attention[edge_from_idx] embedded_to = self.embedded[edge_to_idx] coeffi_to = self.coefficient[edge_to_idx] if self.transform is not None: # TODO correct or not? edge_to = Image.fromarray(edge_to) edge_to = self.transform(edge_to) edge_from = Image.fromarray(edge_from) edge_from = self.transform(edge_from) return edge_to, edge_from, a_to, a_from, embedded_to, coeffi_to def __len__(self): # return the number of all edges return len(self.edge_to) class DVIDataHandler(Dataset): def __init__(self, edge_to, edge_from, feature_vector, attention, transform=None): self.edge_to = edge_to self.edge_from = edge_from self.data = feature_vector self.attention = attention self.transform = transform def __getitem__(self, item): edge_to_idx = self.edge_to[item] edge_from_idx = self.edge_from[item] edge_to = self.data[edge_to_idx] edge_from = self.data[edge_from_idx] a_to = self.attention[edge_to_idx] a_from = self.attention[edge_from_idx] if self.transform is not None: # TODO correct or not? 
edge_to = Image.fromarray(edge_to) edge_to = self.transform(edge_to) edge_from = Image.fromarray(edge_from) edge_from = self.transform(edge_from) return edge_to, edge_from, a_to, a_from def __len__(self): # return the number of all edges return len(self.edge_to) # tf.dataset def construct_edge_dataset( edges_to_exp, edges_from_exp, weight, data, alpha, n_rate, batch_size ): def gather_index(index): return data[index] def gather_alpha(index): return alpha[index] gather_indices_in_python = True if data.nbytes * 1e-9 > 0.5 else False def gather_X(edge_to, edge_from, weight): if gather_indices_in_python: # if True: edge_to_batch = tf.py_function(gather_index, [edge_to], [tf.float32])[0] edge_from_batch = tf.py_function(gather_index, [edge_from], [tf.float32])[0] alpha_to = tf.py_function(gather_alpha, [edge_to], [tf.float32])[0] alpha_from = tf.py_function(gather_alpha, [edge_from], [tf.float32])[0] else: edge_to_batch = tf.gather(data, edge_to) edge_from_batch = tf.gather(data, edge_from) alpha_to = tf.gather(alpha, edge_to) alpha_from = tf.gather(alpha, edge_from) to_n_rate = tf.gather(n_rates, edge_to) outputs = {"umap": 0} outputs["reconstruction"] = edge_to_batch return (edge_to_batch, edge_from_batch, alpha_to, alpha_from, to_n_rate, weight), outputs # shuffle edges shuffle_mask = np.random.permutation(range(len(edges_to_exp))) edges_to_exp = edges_to_exp[shuffle_mask].astype(np.int64) edges_from_exp = edges_from_exp[shuffle_mask].astype(np.int64) weight = weight[shuffle_mask].astype(np.float64) weight = np.expand_dims(weight, axis=1) n_rates = np.expand_dims(n_rate, axis=1) # create edge iterator edge_dataset = tf.data.Dataset.from_tensor_slices( (edges_to_exp, edges_from_exp, weight) ) edge_dataset = edge_dataset.repeat() edge_dataset = edge_dataset.shuffle(10000) edge_dataset = edge_dataset.batch(batch_size, drop_remainder=True) edge_dataset = edge_dataset.map( gather_X, num_parallel_calls=tf.data.experimental.AUTOTUNE ) edge_dataset = edge_dataset.prefetch(10) return edge_dataset
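

# ---------------------------------------------------------------------------
# Usage sketch (assumption): a minimal example of driving DataHandler and
# construct_edge_dataset with toy NumPy arrays. The shapes, edge list, and
# hyperparameter values below are made up for illustration only and are not
# taken from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # 10 samples with 4-dimensional features, 6 (to, from) edges between them
    feature_vector = np.random.rand(10, 4).astype(np.float32)
    attention = np.ones_like(feature_vector, dtype=np.float32)
    edge_to = np.array([0, 1, 2, 3, 4, 5])
    edge_from = np.array([5, 4, 3, 2, 1, 0])
    weight = np.random.rand(6)
    n_rate = np.ones(10, dtype=np.float32)  # one rate per sample, gathered by sample index

    # PyTorch-style handler: index it like any torch Dataset
    handler = DataHandler(edge_to, edge_from, feature_vector, attention)
    to_vec, from_vec, a_to, a_from = handler[0]
    print(len(handler), to_vec.shape)

    # tf.data pipeline: an infinite, shuffled, batched edge iterator
    edge_dataset = construct_edge_dataset(
        edge_to, edge_from, weight, feature_vector, attention, n_rate, batch_size=3
    )
    (to_b, from_b, a_to_b, a_from_b, rate_b, w_b), outputs = next(iter(edge_dataset))
    print(to_b.shape, outputs["reconstruction"].shape)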