# SpaGAN/models/spabert/datasets/dataset_loader_ver2.py
import numpy as np
import torch
from torch.utils.data import Dataset
class SpatialDataset(Dataset):
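    """Base Dataset that linearizes a pivot entity and its spatial neighbors
    into a fixed-length pseudo-sentence with per-subtoken relative-coordinate
    features (SpaBERT-style input encoding).

    Subclasses implement __len__ and __getitem__, typically by calling
    parse_spatial_context on one record per index.
    """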
    def __init__(self, tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors=False):
self.tokenizer = tokenizer
self.max_token_len = max_token_len
self.distance_norm_factor = distance_norm_factor
self.sep_between_neighbors = sep_between_neighbors
    def parse_spatial_context(self, pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill, pivot_dist_fill=0):
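        """Build one training example: a pseudo-sentence for the pivot entity.

        The sequence is the pivot's subtokens followed by each neighbor's
        subtokens (optionally separated by [SEP]). Every neighbor subtoken
        carries the neighbor's (lng, lat) offset from pivot_pos, normalized
        by distance_norm_factor; pivot subtokens carry pivot_dist_fill, and
        CLS/SEP/padding positions carry spatial_dist_fill. Returns a dict of
        tensors (masked_input, pseudo_sentence, sent_position_ids,
        attention_mask, norm_lng_list, norm_lat_list) plus pivot_name and
        pivot_token_len.
        """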
sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token)
cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token)
#mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
max_token_len = self.max_token_len
#print("Module reloaded and changes are reflected")
# process pivot
pivot_name_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(pivot_name))
pivot_token_len = len(pivot_name_tokens)
pivot_lng = pivot_pos[0]
pivot_lat = pivot_pos[1]
# prepare entity mask
entity_mask_arr = []
        rand_entity = np.random.uniform(size=len(neighbor_name_list) + 1)  # random draws for masking entities (pivot + neighbors)
# True for mask, False for unmask
# check if pivot entity needs to be masked out, 15% prob. to be masked out
#if rand_entity[0] < 0.15:
# entity_mask_arr.extend([True] * pivot_token_len)
#else:
entity_mask_arr.extend([False] * pivot_token_len)
# process neighbors
neighbor_token_list = []
neighbor_lng_list = []
neighbor_lat_list = []
        # add a separator between pivot and neighbor tokens;
        # skipped when pivot_dist_fill != 0, a trick that avoids placing a separator
        # after a class name (used when encoding class names for the margin-ranking loss)
        if self.sep_between_neighbors and pivot_dist_fill == 0:
neighbor_lng_list.append(spatial_dist_fill)
neighbor_lat_list.append(spatial_dist_fill)
neighbor_token_list.append(sep_token_id)
for neighbor_name, neighbor_geometry, rnd in zip(neighbor_name_list, neighbor_geometry_list, rand_entity[1:]):
            if not neighbor_name or not neighbor_name[0].isalpha():
                # skip neighbors with empty names or names not starting with a letter
                continue
neighbor_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(neighbor_name))
neighbor_token_len = len(neighbor_token)
            # compute the neighbor's (lng, lat) offset from the pivot,
            # normalize it by distance_norm_factor, and repeat the result
            # for every subtoken of the neighbor
            if 'coordinates' in neighbor_geometry:  # GeoJSON-style dict
                neighbor_lng = neighbor_geometry['coordinates'][0]
                neighbor_lat = neighbor_geometry['coordinates'][1]
            else:  # plain [lng, lat] sequence
                neighbor_lng = neighbor_geometry[0]
                neighbor_lat = neighbor_geometry[1]
            neighbor_lng_list.extend([(neighbor_lng - pivot_lng) / self.distance_norm_factor] * neighbor_token_len)
            neighbor_lat_list.extend([(neighbor_lat - pivot_lat) / self.distance_norm_factor] * neighbor_token_len)
            neighbor_token_list.extend(neighbor_token)
if self.sep_between_neighbors:
neighbor_lng_list.append(spatial_dist_fill)
neighbor_lat_list.append(spatial_dist_fill)
neighbor_token_list.append(sep_token_id)
                entity_mask_arr.extend([False])  # keep the mask aligned with the appended separator token
#if rnd < 0.15:
# #True: mask out, False: Keey original token
# entity_mask_arr.extend([True] * neighbor_token_len)
#else:
entity_mask_arr.extend([False] * neighbor_token_len)
pseudo_sentence = pivot_name_tokens + neighbor_token_list
dist_lng_list = [pivot_dist_fill] * pivot_token_len + neighbor_lng_list
dist_lat_list = [pivot_dist_fill] * pivot_token_len + neighbor_lat_list
        sent_len = len(pseudo_sentence)  # length before adding CLS and SEP
        max_token_len_middle = max_token_len - 2  # reserve 2 positions for the CLS and SEP tokens
# padding and truncation
        if sent_len > max_token_len_middle:
            pseudo_sentence = [cls_token_id] + pseudo_sentence[:max_token_len_middle] + [sep_token_id]
            dist_lat_list = [spatial_dist_fill] + dist_lat_list[:max_token_len_middle] + [spatial_dist_fill]
            dist_lng_list = [spatial_dist_fill] + dist_lng_list[:max_token_len_middle] + [spatial_dist_fill]
            attention_mask = [0] + [1] * max_token_len_middle + [0]  # CLS and SEP are not attended to
else:
pad_len = max_token_len_middle - sent_len
assert pad_len >= 0
pseudo_sentence = [cls_token_id] + pseudo_sentence + [sep_token_id] + [pad_token_id] * pad_len
dist_lat_list = [spatial_dist_fill] + dist_lat_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
dist_lng_list = [spatial_dist_fill] + dist_lng_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            attention_mask = [0] + [1] * sent_len + [0] + [0] * pad_len  # CLS, SEP, and padding are not attended to
        norm_lng_list = np.array(dist_lng_list)
        norm_lat_list = np.array(dist_lat_list)
## mask entity in the pseudo sentence
#entity_mask_indices = np.where(entity_mask_arr)[0]
#masked_entity_input = [mask_token_id if i in entity_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]
#
#
## mask token in the pseudo sentence
#rand_token = np.random.uniform(size = len(pseudo_sentence))
        ## do not mask out the CLS and SEP tokens. True: mask the token, False: keep the original token
#token_mask_arr = (rand_token <0.15) & (np.array(pseudo_sentence) != cls_token_id) & (np.array(pseudo_sentence) != sep_token_id) & (np.array(pseudo_sentence) != pad_token_id)
#token_mask_indices = np.where(token_mask_arr)[0]
#
#masked_token_input = [mask_token_id if i in token_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]
#
#
## yield masked_token with 50% prob, masked_entity with 50% prob
#if np.random.rand() > 0.5:
# masked_input = torch.tensor(masked_entity_input)
#else:
# masked_input = torch.tensor(masked_token_input)
        masked_input = torch.tensor(pseudo_sentence)  # masking currently disabled: feed the unmasked pseudo-sentence
train_data = {}
train_data['pivot_name'] = pivot_name
train_data['pivot_token_len'] = pivot_token_len
train_data['masked_input'] = masked_input
train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence)))
train_data['attention_mask'] = torch.tensor(attention_mask)
train_data['norm_lng_list'] = torch.tensor(norm_lng_list).to(torch.float32)
train_data['norm_lat_list'] = torch.tensor(norm_lat_list).to(torch.float32)
train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence)
return train_data
def __len__(self):
        raise NotImplementedError
def __getitem__(self, index):
raise NotImplementedError
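

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# ListSpatialDataset and its record schema ('pivot_name', 'pivot_pos',
# 'neighbor_names', 'neighbor_geometries') are hypothetical; the smoke test
# assumes a HuggingFace BertTokenizer is available via `transformers`.
# ---------------------------------------------------------------------------
class ListSpatialDataset(SpatialDataset):
    """Toy subclass backed by an in-memory list of pivot/neighbor records."""

    def __init__(self, records, tokenizer, max_token_len, distance_norm_factor,
                 sep_between_neighbors=False, spatial_dist_fill=100):
        super().__init__(tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors)
        self.records = records
        self.spatial_dist_fill = spatial_dist_fill

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        rec = self.records[index]
        return self.parse_spatial_context(
            rec['pivot_name'], rec['pivot_pos'],
            rec['neighbor_names'], rec['neighbor_geometries'],
            self.spatial_dist_fill)


if __name__ == '__main__':
    from transformers import BertTokenizer  # assumed dependency for the demo
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    records = [{
        'pivot_name': 'Los Angeles',
        'pivot_pos': (-118.24, 34.05),  # (lng, lat)
        'neighbor_names': ['Pasadena', 'Long Beach'],
        'neighbor_geometries': [[-118.14, 34.15], [-118.19, 33.77]],
    }]
    dataset = ListSpatialDataset(records, tokenizer, max_token_len=32,
                                 distance_norm_factor=0.0001)
    sample = dataset[0]
    print(sample['pseudo_sentence'].shape, int(sample['attention_mask'].sum()))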