import os
import json

import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def write_to_csv(out_dir, map_name, match_list):
    # Despite the name, this writes JSON Lines: one JSON object per line,
    # saved to <out_dir>/<map_name>.json.
    out_path = os.path.join(out_dir, map_name + '.json')

    with open(out_path, 'w') as f:
        for match_dict in match_list:
            json.dump(match_dict, f)
            f.write('\n')


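# Example usage (illustrative; the output directory, map name, and match dicts
# below are hypothetical):
#
#   write_to_csv('outputs/', 'map_001',
#                [{'source': 'St. Louis', 'candidate': 'Saint Louis', 'score': 0.93}])
#
# This produces outputs/map_001.json with one JSON object per line.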

def load_spatial_bert_pretrained_weights(model, weight_path):
    # Load pretrained weights from a SpatialBertLM checkpoint into a
    # SpatialBertModel. Checkpoint keys carry a 'bert.' prefix that the
    # bare model's state dict does not have.
    # map_location ensures the checkpoint also loads on CPU-only machines.
    pre_trained_model = torch.load(weight_path, map_location=device)

    cnt_layers = 0
    cur_model_kvpair = model.state_dict()
    for key, value in cur_model_kvpair.items():
        if 'bert.' + key in pre_trained_model:
            cur_model_kvpair[key] = pre_trained_model['bert.' + key]
            cnt_layers += 1
        else:
            print("No weight for", key)

    print(cnt_layers, 'layers loaded')

    model.load_state_dict(cur_model_kvpair)

    return model


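# Example usage (a sketch; SpatialBertConfig/SpatialBertModel construction is
# assumed to follow this project's model code, and the checkpoint path is
# hypothetical):
#
#   config = SpatialBertConfig()
#   model = SpatialBertModel(config)
#   model = load_spatial_bert_pretrained_weights(model, 'weights/spatialbert_ckpt.pth')
#   model.to(device)
#   model.eval()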

def get_spatialbert_embedding(entity, model, use_distance=True, agg='mean'):
    # Embed the pivot entity: run its pseudo-sentence through SpatialBertModel
    # and aggregate the pivot's token embeddings.
    pseudo_sentence = entity['pseudo_sentence'][None, :].to(device)
    attention_mask = entity['attention_mask'][None, :].to(device)
    sent_position_ids = entity['sent_position_ids'][None, :].to(device)
    pivot_token_len = entity['pivot_token_len']

    # Pass normalized coordinates as continuous spatial features when available;
    # otherwise the model receives empty position lists.
    if 'norm_lng_list' in entity and use_distance:
        position_list_x = entity['norm_lng_list'][None, :].to(device)
        position_list_y = entity['norm_lat_list'][None, :].to(device)
    else:
        position_list_x = []
        position_list_y = []

    outputs = model(input_ids=pseudo_sentence, attention_mask=attention_mask,
                    sent_position_ids=sent_position_ids,
                    position_list_x=position_list_x, position_list_y=position_list_y)

    embeddings = outputs.last_hidden_state

    # Tokens 1 .. pivot_token_len are the pivot's tokens (token 0 is the
    # [CLS]-style start token).
    pivot_embed = embeddings[0][1:1 + pivot_token_len]
    if agg == 'mean':
        pivot_embed = torch.mean(pivot_embed, dim=0).detach().cpu().numpy()  # (768,)
    elif agg == 'sum':
        pivot_embed = torch.sum(pivot_embed, dim=0).detach().cpu().numpy()  # (768,)
    else:
        raise NotImplementedError

    return pivot_embed

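# Example: comparing two geo-entities by cosine similarity of their SpatialBERT
# embeddings ('entity_a' and 'entity_b' are assumed to be feature dicts from this
# project's dataset pipeline, carrying the keys consumed above):
#
#   emb_a = get_spatialbert_embedding(entity_a, model)
#   emb_b = get_spatialbert_embedding(entity_b, model)
#   cos_sim = np.dot(emb_a, emb_b) / (np.linalg.norm(emb_a) * np.linalg.norm(emb_b))
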
def get_bert_embedding(entity, model, agg='mean'):
    # Same as get_spatialbert_embedding, but for a plain BERT model that takes
    # no spatial inputs.
    pseudo_sentence = entity['pseudo_sentence'].unsqueeze(0).to(device)
    attention_mask = entity['attention_mask'].unsqueeze(0).to(device)
    pivot_token_len = entity['pivot_token_len']

    outputs = model(input_ids=pseudo_sentence, attention_mask=attention_mask)

    embeddings = outputs.last_hidden_state

    # Tokens 1 .. pivot_token_len are the pivot's tokens (token 0 is [CLS]).
    pivot_embed = embeddings[0][1:1 + pivot_token_len]
    if agg == 'mean':
        pivot_embed = torch.mean(pivot_embed, dim=0).detach().cpu().numpy()  # (768,)
    elif agg == 'sum':
        pivot_embed = torch.sum(pivot_embed, dim=0).detach().cpu().numpy()  # (768,)
    else:
        raise NotImplementedError

    return pivot_embed
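
# Example usage (a sketch, assuming a HuggingFace BertModel; the entity dict is
# expected to carry 'pseudo_sentence', 'attention_mask', and 'pivot_token_len'):
#
#   from transformers import BertModel
#   bert = BertModel.from_pretrained('bert-base-uncased').to(device)
#   bert.eval()
#   with torch.no_grad():
#       emb = get_bert_embedding(entity, bert)  # emb.shape == (768,)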