SalazarPevelll
be
f291f4a
"""
Edge dataset from temporal complex
"""
from abc import ABC, abstractmethod
from torch.utils.data import Dataset
from PIL import Image
import tensorflow as tf
import numpy as np
class DataHandlerAbstractClass(Dataset, ABC):
    """Abstract base class for edge datasets.

    Keeps the two edge-endpoint index sequences and the feature vectors
    they index into. Concrete subclasses must provide ``__getitem__`` and
    ``__len__``.
    """

    def __init__(self, edge_to, edge_from, feature_vector) -> None:
        """Store the edge endpoints and the feature vectors.

        Args:
            edge_to: Sequence of "to" endpoint indices, one entry per edge.
            edge_from: Sequence of "from" endpoint indices, one entry per edge.
            feature_vector: Container of feature vectors; exposed as ``data``.
        """
        super().__init__()
        self.edge_to, self.edge_from = edge_to, edge_from
        self.data = feature_vector

    @abstractmethod
    def __getitem__(self, item):
        """Return the sample for edge ``item``; implemented by subclasses."""

    @abstractmethod
    def __len__(self):
        """Return the number of samples; implemented by subclasses."""
class DataHandler(Dataset):
    """Edge dataset yielding both endpoint features and their attention values.

    One sample corresponds to one edge: the feature vectors of its "to" and
    "from" endpoints together with the attention entries at those same
    endpoint indices.
    """

    def __init__(self, edge_to, edge_from, feature_vector, attention, transform=None):
        """Store the edge lists and the per-sample lookup containers.

        Args:
            edge_to: Sequence of "to" endpoint indices, one entry per edge.
            edge_from: Sequence of "from" endpoint indices, one entry per edge.
            feature_vector: Container indexed by endpoint index; kept as ``data``.
            attention: Container indexed by endpoint index.
            transform: Optional callable applied to each endpoint feature
                after it has been wrapped in a PIL image.
        """
        self.edge_to, self.edge_from = edge_to, edge_from
        self.data = feature_vector
        self.attention = attention
        self.transform = transform

    def __getitem__(self, item):
        """Return ``(edge_to, edge_from, a_to, a_from)`` for edge ``item``."""
        to_idx, from_idx = self.edge_to[item], self.edge_from[item]
        to_vec, from_vec = self.data[to_idx], self.data[from_idx]
        attn_to, attn_from = self.attention[to_idx], self.attention[from_idx]

        if self.transform is not None:
            # TODO(review): carried over from the original ("correct or not?") -
            # it is unresolved whether features should go through PIL here.
            to_vec, from_vec = (
                self.transform(Image.fromarray(vec)) for vec in (to_vec, from_vec)
            )

        return to_vec, from_vec, attn_to, attn_from

    def __len__(self):
        """Return the number of edges."""
        return len(self.edge_to)
class HybridDataHandler(Dataset):
    """Edge dataset that also yields replayed positions for the "to" endpoint.

    Besides endpoint features and attention values, every sample carries the
    ``embedded`` and ``coefficient`` entries looked up at the "to" endpoint.
    """

    def __init__(self, edge_to, edge_from, feature_vector, attention, embedded, coefficient, transform=None):
        """Store the edge lists and the per-sample lookup containers.

        Args:
            edge_to: Sequence of "to" endpoint indices, one entry per edge.
            edge_from: Sequence of "from" endpoint indices, one entry per edge.
            feature_vector: Container indexed by endpoint index; kept as ``data``.
            attention: Container indexed by endpoint index.
            embedded: Replay of positions generated by a previous visualization.
            coefficient: Whether samples have generated positions.
            transform: Optional callable applied to each endpoint feature
                after it has been wrapped in a PIL image.
        """
        self.edge_to = edge_to
        self.edge_from = edge_from
        self.data = feature_vector
        self.attention = attention
        # replay of positions generated by previous visualization
        self.embedded = embedded
        # whether samples have generated positions
        self.coefficient = coefficient
        self.transform = transform

    def _apply_transform(self, vec):
        """Wrap ``vec`` in a PIL image and run the configured transform on it."""
        return self.transform(Image.fromarray(vec))

    def __getitem__(self, item):
        """Return ``(edge_to, edge_from, a_to, a_from, embedded_to, coeffi_to)``."""
        idx_to = self.edge_to[item]
        idx_from = self.edge_from[item]

        feat_to = self.data[idx_to]
        feat_from = self.data[idx_from]
        attn_to = self.attention[idx_to]
        attn_from = self.attention[idx_from]
        # Only the "to" endpoint contributes an embedded position / coefficient.
        position_to = self.embedded[idx_to]
        coeff_to = self.coefficient[idx_to]

        if self.transform is None:
            return feat_to, feat_from, attn_to, attn_from, position_to, coeff_to

        # TODO(review): carried over from the original ("correct or not?") -
        # it is unresolved whether features should go through PIL here.
        feat_to = self._apply_transform(feat_to)
        feat_from = self._apply_transform(feat_from)
        return feat_to, feat_from, attn_to, attn_from, position_to, coeff_to

    def __len__(self):
        """Return the number of edges."""
        return len(self.edge_to)
class DVIDataHandler(Dataset):
    """Edge dataset returning endpoint features and attention values per edge."""

    def __init__(self, edge_to, edge_from, feature_vector, attention, transform=None):
        """Store the edge lists and the per-sample lookup containers.

        Args:
            edge_to: Sequence of "to" endpoint indices, one entry per edge.
            edge_from: Sequence of "from" endpoint indices, one entry per edge.
            feature_vector: Container indexed by endpoint index; kept as ``data``.
            attention: Container indexed by endpoint index.
            transform: Optional callable applied to each endpoint feature
                after it has been wrapped in a PIL image.
        """
        self.edge_to = edge_to
        self.edge_from = edge_from
        self.data = feature_vector
        self.attention = attention
        self.transform = transform

    def __getitem__(self, item):
        """Return ``(edge_to, edge_from, a_to, a_from)`` for edge ``item``."""
        # Endpoint order is always ("to", "from").
        endpoints = (self.edge_to[item], self.edge_from[item])
        features = [self.data[node] for node in endpoints]
        attentions = [self.attention[node] for node in endpoints]

        if self.transform is not None:
            # TODO(review): carried over from the original ("correct or not?") -
            # it is unresolved whether features should go through PIL here.
            features = [self.transform(Image.fromarray(vec)) for vec in features]

        return features[0], features[1], attentions[0], attentions[1]

    def __len__(self):
        """Return the number of edges."""
        return len(self.edge_to)
# TensorFlow (tf.data) edge dataset
def construct_edge_dataset(
    edges_to_exp, edges_from_exp, weight, data, alpha, n_rate, batch_size
):
    """Build a repeating, shuffled and batched ``tf.data`` dataset over edges.

    Args:
        edges_to_exp: Array of "to" endpoint indices, one per edge. Must
            support fancy indexing and ``astype`` (e.g. a NumPy array).
        edges_from_exp: Array of "from" endpoint indices, aligned with
            ``edges_to_exp``.
        weight: Array of per-edge weights, aligned with the edge arrays.
        data: Array of feature vectors indexed by endpoint index; its
            ``nbytes`` decides which gather strategy is used.
        alpha: Values indexed by endpoint index, gathered for both endpoints.
        n_rate: Values indexed by endpoint index, gathered for the "to"
            endpoint only (after a trailing axis is added).
        batch_size: Number of edges per batch; incomplete batches are dropped.

    Returns:
        A ``tf.data.Dataset`` whose elements are
        ``((edge_to_batch, edge_from_batch, alpha_to, alpha_from, to_n_rate,
        weight), outputs)`` with
        ``outputs == {"umap": 0, "reconstruction": edge_to_batch}``.
    """
    # When ``data`` is larger than this many gigabytes, rows are gathered in
    # Python through tf.py_function instead of with tf.gather.
    python_gather_threshold_gb = 0.5
    gather_indices_in_python = data.nbytes * 1e-9 > python_gather_threshold_gb

    # Defined before the closures below so that gather_X does not read a
    # name that only appears further down in the function.
    n_rates = np.expand_dims(n_rate, axis=1)

    def gather_index(index):
        return data[index]

    def gather_alpha(index):
        return alpha[index]

    def gather_X(edge_to, edge_from, edge_weight):
        """Turn one batch of edge indices into model inputs and targets."""
        if gather_indices_in_python:
            edge_to_batch = tf.py_function(gather_index, [edge_to], [tf.float32])[0]
            edge_from_batch = tf.py_function(gather_index, [edge_from], [tf.float32])[0]
            alpha_to = tf.py_function(gather_alpha, [edge_to], [tf.float32])[0]
            alpha_from = tf.py_function(gather_alpha, [edge_from], [tf.float32])[0]
        else:
            edge_to_batch = tf.gather(data, edge_to)
            edge_from_batch = tf.gather(data, edge_from)
            alpha_to = tf.gather(alpha, edge_to)
            alpha_from = tf.gather(alpha, edge_from)
        to_n_rate = tf.gather(n_rates, edge_to)

        outputs = {"umap": 0, "reconstruction": edge_to_batch}
        inputs = (
            edge_to_batch,
            edge_from_batch,
            alpha_to,
            alpha_from,
            to_n_rate,
            edge_weight,
        )
        return inputs, outputs

    # Shuffle all three edge arrays with one shared permutation so that they
    # stay aligned.
    shuffle_mask = np.random.permutation(len(edges_to_exp))
    edges_to_exp = edges_to_exp[shuffle_mask].astype(np.int64)
    edges_from_exp = edges_from_exp[shuffle_mask].astype(np.int64)
    weight = weight[shuffle_mask].astype(np.float64)
    weight = np.expand_dims(weight, axis=1)

    # Create the edge iterator: repeat -> shuffle -> batch -> gather -> prefetch.
    edge_dataset = tf.data.Dataset.from_tensor_slices(
        (edges_to_exp, edges_from_exp, weight)
    )
    edge_dataset = edge_dataset.repeat()
    edge_dataset = edge_dataset.shuffle(10000)
    edge_dataset = edge_dataset.batch(batch_size, drop_remainder=True)
    edge_dataset = edge_dataset.map(
        gather_X, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    edge_dataset = edge_dataset.prefetch(10)
    return edge_dataset