import pdb

import numpy as np
import torch
from torch.utils.data import Dataset


class SpatialDataset(Dataset):
    """Base dataset that turns a pivot entity and its neighbors into model inputs.

    The central method, :meth:`parse_spatial_context`, builds a "pseudo
    sentence" of token ids (pivot name followed by neighbor names), the
    per-token normalized longitude/latitude offsets relative to the pivot,
    an attention mask, and a randomly masked copy of the sentence.

    ``__len__`` and ``__getitem__`` are left for subclasses to implement.
    """

    # Probability used both for masking a whole entity and for masking an
    # individual token.
    MASK_PROBABILITY = 0.15

    def __init__(self, tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors=False):
        """Store the tokenizer and sequence-building configuration.

        Args:
            tokenizer: object providing ``tokenize``, ``convert_tokens_to_ids``
                and the ``sep_token``/``cls_token``/``mask_token``/``pad_token``
                attributes.
            max_token_len: total length of every produced sequence, including
                the leading CLS and trailing SEP token.
            distance_norm_factor: divisor applied to raw coordinate offsets.
            sep_between_neighbors: when true, a SEP token is inserted after
                each neighbor (and after the pivot, see
                :meth:`parse_spatial_context`).
        """
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len
        self.distance_norm_factor = distance_norm_factor
        self.sep_between_neighbors = sep_between_neighbors

    def parse_spatial_context(self, pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list,
                              spatial_dist_fill, pivot_dist_fill=0):
        """Build the padded/truncated model inputs for one pivot entity.

        Args:
            pivot_name: name of the pivot entity (tokenized with the tokenizer).
            pivot_pos: indexable whose items 0 and 1 are the pivot longitude
                and latitude.
            neighbor_name_list: names of the neighbor entities. Names that are
                empty or do not start with a letter are skipped.
            neighbor_geometry_list: one geometry per neighbor; either an
                indexable ``(lng, lat)`` or a mapping with a ``'coordinates'``
                entry holding such an indexable.
            spatial_dist_fill: distance value assigned to CLS, SEP and PAD
                positions.
            pivot_dist_fill: distance value assigned to the pivot tokens. A
                separator between pivot and neighbors is only inserted when
                this is 0 (and ``sep_between_neighbors`` is set).

        Returns:
            dict with keys ``pivot_name``, ``pivot_token_len``,
            ``masked_input``, ``sent_position_ids``, ``attention_mask``,
            ``norm_lng_list``, ``norm_lat_list`` and ``pseudo_sentence``.
            All sequence-valued entries are tensors of length
            ``max_token_len``.

        Raises:
            ValueError: if ``max_token_len`` leaves no room for CLS and SEP.
        """
        max_token_len = self.max_token_len
        if max_token_len < 2:
            raise ValueError("max_token_len must be at least 2 to hold the CLS and SEP tokens")

        tokenizer = self.tokenizer
        to_ids = tokenizer.convert_tokens_to_ids
        sep_token_id = to_ids(tokenizer.sep_token)
        cls_token_id = to_ids(tokenizer.cls_token)
        mask_token_id = to_ids(tokenizer.mask_token)
        pad_token_id = to_ids(tokenizer.pad_token)
        mask_prob = self.MASK_PROBABILITY

        # process pivot
        pivot_name_tokens = to_ids(tokenizer.tokenize(pivot_name))
        pivot_token_len = len(pivot_name_tokens)
        pivot_lng = pivot_pos[0]
        pivot_lat = pivot_pos[1]

        # One random draw per entity (index 0 is the pivot, the rest are the
        # neighbors in order); an entity is masked when its draw is below
        # mask_prob.
        rand_entity = np.random.uniform(size=len(neighbor_name_list) + 1)

        # The four lists below are kept strictly parallel: one entry per token
        # of the (not yet padded) pseudo sentence.
        tokens = list(pivot_name_tokens)
        dist_lng = [pivot_dist_fill] * pivot_token_len
        dist_lat = [pivot_dist_fill] * pivot_token_len
        entity_mask = [bool(rand_entity[0] < mask_prob)] * pivot_token_len  # True: mask out

        def append_separator():
            # Separators carry the fill distance and are never entity-masked.
            tokens.append(sep_token_id)
            dist_lng.append(spatial_dist_fill)
            dist_lat.append(spatial_dist_fill)
            entity_mask.append(False)

        # Add a separator between pivot and neighbor tokens. The
        # pivot_dist_fill check is a trick to avoid adding a separator after
        # the class name (class-name encoding for the margin-ranking loss).
        if self.sep_between_neighbors and pivot_dist_fill == 0:
            append_separator()

        # process neighbors
        for neighbor_name, neighbor_geometry, rnd in zip(neighbor_name_list, neighbor_geometry_list, rand_entity[1:]):
            # only consider neighbors starting with letters (empty names are skipped too)
            if not neighbor_name or not neighbor_name[0].isalpha():
                continue

            neighbor_token = to_ids(tokenizer.tokenize(neighbor_name))
            neighbor_token_len = len(neighbor_token)

            # handle both json dict structures and plain (lng, lat) sequences
            if 'coordinates' in neighbor_geometry:
                coords = neighbor_geometry['coordinates']
            else:
                coords = neighbor_geometry

            # Relative offset from the pivot, normalized by
            # distance_norm_factor and shared by all sub-tokens of the neighbor.
            rel_lng = (coords[0] - pivot_lng) / self.distance_norm_factor
            rel_lat = (coords[1] - pivot_lat) / self.distance_norm_factor

            tokens.extend(neighbor_token)
            dist_lng.extend([rel_lng] * neighbor_token_len)
            dist_lat.extend([rel_lat] * neighbor_token_len)
            entity_mask.extend([bool(rnd < mask_prob)] * neighbor_token_len)

            if self.sep_between_neighbors:
                append_separator()

        # padding and truncation (2 positions are reserved for CLS and SEP)
        max_token_len_middle = max_token_len - 2
        tokens = tokens[:max_token_len_middle]
        dist_lng = dist_lng[:max_token_len_middle]
        dist_lat = dist_lat[:max_token_len_middle]
        entity_mask = entity_mask[:max_token_len_middle]

        body_len = len(tokens)
        pad_len = max_token_len_middle - body_len
        tail_len = 1 + pad_len  # closing SEP plus padding

        pseudo_sentence = [cls_token_id] + tokens + [sep_token_id] + [pad_token_id] * pad_len
        dist_lng_list = [spatial_dist_fill] + dist_lng + [spatial_dist_fill] * tail_len
        dist_lat_list = [spatial_dist_fill] + dist_lat + [spatial_dist_fill] * tail_len
        # CLS, the closing SEP and padding are not attended to
        attention_mask = [0] + [1] * body_len + [0] * tail_len
        # Shift the entity mask by one so it lines up with the CLS-prefixed
        # sentence; CLS, closing SEP and padding are never masked.
        entity_mask_full = np.array([False] + entity_mask + [False] * tail_len, dtype=bool)

        sentence_arr = np.array(pseudo_sentence, dtype=np.int64)

        # mask whole entities in the pseudo sentence
        masked_entity_input = np.where(entity_mask_full, mask_token_id, sentence_arr)

        # mask individual tokens in the pseudo sentence; never mask CLS, SEP or PAD
        rand_token = np.random.uniform(size=len(pseudo_sentence))
        is_special = np.isin(sentence_arr, (cls_token_id, sep_token_id, pad_token_id))
        token_mask_arr = (rand_token < mask_prob) & ~is_special  # True: masked, False: keep original token
        masked_token_input = np.where(token_mask_arr, mask_token_id, sentence_arr)

        # yield masked_token with 50% prob, masked_entity with 50% prob
        if np.random.rand() > 0.5:
            masked_input = torch.tensor(masked_entity_input)
        else:
            masked_input = torch.tensor(masked_token_input)

        train_data = {}
        train_data['pivot_name'] = pivot_name
        train_data['pivot_token_len'] = pivot_token_len
        train_data['masked_input'] = masked_input
        train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence)))
        train_data['attention_mask'] = torch.tensor(attention_mask)
        train_data['norm_lng_list'] = torch.tensor(np.array(dist_lng_list)).to(torch.float32)
        train_data['norm_lat_list'] = torch.tensor(np.array(dist_lat_list)).to(torch.float32)
        train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence)

        return train_data

    def __len__(self):
        """Number of samples; must be provided by subclasses."""
        raise NotImplementedError

    def __getitem__(self, index):
        """Return the sample at ``index``; must be provided by subclasses."""
        raise NotImplementedError