import os
import json

import numpy as np
import torch

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def write_to_csv(out_dir, map_name, match_list):
    # Despite the name, this writes JSON Lines: one JSON object per line.
    out_path = os.path.join(out_dir, map_name + '.json')
    with open(out_path, 'w') as f:
        for match_dict in match_list:
            json.dump(match_dict, f)
            f.write('\n')
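# Example usage (a sketch; the dict fields below are hypothetical, nothing in
# this module fixes them):
#   matches = [{'source_id': 0, 'target_id': 3, 'score': 0.91}]
#   write_to_csv('outputs', 'los_angeles', matches)
# Each dict becomes one line of outputs/los_angeles.json.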
def load_spatial_bert_pretrained_weights(model, weight_path):
    # Copy pretrained SpatialBertLM weights into a SpatialBertModel. The LM
    # checkpoint prefixes encoder weights with 'bert.', so keys are matched
    # by prepending that prefix.
    pre_trained_model = torch.load(weight_path, map_location=device)

    cnt_layers = 0
    cur_model_kvpair = model.state_dict()
    for key, value in cur_model_kvpair.items():
        if 'bert.' + key in pre_trained_model:
            cur_model_kvpair[key] = pre_trained_model['bert.' + key]
            cnt_layers += 1
        else:
            print('No weight for', key)
    print(cnt_layers, 'layers loaded')

    model.load_state_dict(cur_model_kvpair)
    return model
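# Example usage (a sketch; SpatialBertModel and its config come from the
# surrounding SpaBERT codebase, not from this file):
#   model = SpatialBertModel(config)
#   model = load_spatial_bert_pretrained_weights(model, weight_path)
#   model.to(device)
#   model.eval()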
def get_spatialbert_embedding(entity, model, use_distance=True, agg='mean'):
    # Add a batch dimension and move the entity's tensors to the active device.
    pseudo_sentence = entity['pseudo_sentence'][None, :].to(device)
    attention_mask = entity['attention_mask'][None, :].to(device)
    sent_position_ids = entity['sent_position_ids'][None, :].to(device)
    pivot_token_len = entity['pivot_token_len']

    if 'norm_lng_list' in entity and use_distance:
        position_list_x = entity['norm_lng_list'][None, :].to(device)
        position_list_y = entity['norm_lat_list'][None, :].to(device)
    else:
        position_list_x = []
        position_list_y = []

    outputs = model(input_ids=pseudo_sentence, attention_mask=attention_mask,
                    sent_position_ids=sent_position_ids,
                    position_list_x=position_list_x, position_list_y=position_list_y)
    embeddings = outputs.last_hidden_state

    # Tokens 1 .. pivot_token_len belong to the pivot entity ([CLS] is at index 0).
    pivot_embed = embeddings[0][1:1 + pivot_token_len]

    if agg == 'mean':
        pivot_embed = torch.mean(pivot_embed, dim=0).detach().cpu().numpy()  # (hidden_size,), e.g. (768,)
    elif agg == 'sum':
        pivot_embed = torch.sum(pivot_embed, dim=0).detach().cpu().numpy()  # (hidden_size,), e.g. (768,)
    else:
        raise NotImplementedError

    return pivot_embed
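# Example usage (a sketch; assumes `dataset[i]` yields an entity dict with the
# tensor fields read above):
#   emb = get_spatialbert_embedding(dataset[0], model, use_distance=True)
#   print(emb.shape)  # (hidden_size,), e.g. (768,)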
def get_bert_embedding(entity, model, agg='mean'):
    # Same as get_spatialbert_embedding, but for a plain BERT model
    # without the spatial position inputs.
    pseudo_sentence = entity['pseudo_sentence'].unsqueeze(0).to(device)
    attention_mask = entity['attention_mask'].unsqueeze(0).to(device)
    pivot_token_len = entity['pivot_token_len']

    outputs = model(input_ids=pseudo_sentence, attention_mask=attention_mask)
    embeddings = outputs.last_hidden_state
    pivot_embed = embeddings[0][1:1 + pivot_token_len]

    if agg == 'mean':
        pivot_embed = torch.mean(pivot_embed, dim=0).detach().cpu().numpy()  # (hidden_size,), e.g. (768,)
    elif agg == 'sum':
        pivot_embed = torch.sum(pivot_embed, dim=0).detach().cpu().numpy()  # (hidden_size,), e.g. (768,)
    else:
        raise NotImplementedError

    return pivot_embed
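# A minimal cosine-similarity helper for comparing the embeddings returned by
# the two functions above (a sketch; for single 1-D vectors this is equivalent
# to sklearn's cosine_similarity).
def cosine_sim(a, b):
    # a, b: 1-D numpy arrays, e.g. the (768,) vectors from the helpers above.
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Example:
#   score = cosine_sim(get_spatialbert_embedding(e1, model),
#                      get_spatialbert_embedding(e2, model))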