JasonTPhillipsJr
committed on
Upload 76 files
This view is limited to 50 files because it contains too many changes.
- .gitattributes +4 -0
- models/spabert/README.md +69 -0
- models/spabert/__init__.py +0 -0
- models/spabert/datasets/__init__.py +0 -0
- models/spabert/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- models/spabert/datasets/__pycache__/dataset_loader.cpython-310.pyc +0 -0
- models/spabert/datasets/__pycache__/dataset_loader_ver2.cpython-310.pyc +0 -0
- models/spabert/datasets/__pycache__/osm_sample_loader.cpython-310.pyc +0 -0
- models/spabert/datasets/__pycache__/usgs_os_sample_loader.cpython-310.pyc +0 -0
- models/spabert/datasets/__pycache__/wikidata_sample_loader.cpython-310.pyc +0 -0
- models/spabert/datasets/const.py +162 -0
- models/spabert/datasets/dataset_loader.py +162 -0
- models/spabert/datasets/dataset_loader_ver2.py +164 -0
- models/spabert/datasets/osm_sample_loader.py +246 -0
- models/spabert/datasets/usgs_os_sample_loader.py +71 -0
- models/spabert/datasets/wikidata_sample_loader.py +127 -0
- models/spabert/experiments/__init__.py +0 -0
- models/spabert/experiments/__pycache__/__init__.cpython-310.pyc +0 -0
- models/spabert/experiments/entity_matching/__init__.py +0 -0
- models/spabert/experiments/entity_matching/__pycache__/__init__.cpython-310.pyc +0 -0
- models/spabert/experiments/entity_matching/data_processing/__init__.py +0 -0
- models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-310.pyc +0 -0
- models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-311.pyc +0 -0
- models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-310.pyc +0 -0
- models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-311.pyc +0 -0
- models/spabert/experiments/entity_matching/data_processing/get_namelist.py +95 -0
- models/spabert/experiments/entity_matching/data_processing/request_wrapper.py +186 -0
- models/spabert/experiments/entity_matching/data_processing/run_linking_query.py +143 -0
- models/spabert/experiments/entity_matching/data_processing/run_map_neighbor_query.py +123 -0
- models/spabert/experiments/entity_matching/data_processing/run_query_sample.py +22 -0
- models/spabert/experiments/entity_matching/data_processing/run_wikidata_neighbor_query.py +31 -0
- models/spabert/experiments/entity_matching/data_processing/samples.sparql +22 -0
- models/spabert/experiments/entity_matching/data_processing/select_ambi.py +18 -0
- models/spabert/experiments/entity_matching/data_processing/wikidata_sample30k/wikidata_30k.json +0 -0
- models/spabert/experiments/entity_matching/src/evaluation-mrr.py +260 -0
- models/spabert/experiments/entity_matching/src/linking_ablation.py +228 -0
- models/spabert/experiments/entity_matching/src/unsupervised_wiki_location_allcand.py +329 -0
- models/spabert/experiments/semantic_typing/__init__.py +0 -0
- models/spabert/experiments/semantic_typing/data_processing/merge_osm_json.py +97 -0
- models/spabert/experiments/semantic_typing/src/__init__.py +0 -0
- models/spabert/experiments/semantic_typing/src/run_baseline_test.py +82 -0
- models/spabert/experiments/semantic_typing/src/test_cls_ablation_spatialbert.py +209 -0
- models/spabert/experiments/semantic_typing/src/test_cls_baseline.py +189 -0
- models/spabert/experiments/semantic_typing/src/test_cls_spatialbert.py +214 -0
- models/spabert/experiments/semantic_typing/src/train_cls_baseline.py +227 -0
- models/spabert/experiments/semantic_typing/src/train_cls_spatialbert.py +276 -0
- models/spabert/models/__init__.py +0 -0
- models/spabert/models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/spabert/models/__pycache__/spatial_bert_model.cpython-310.pyc +0 -0
- models/spabert/models/baseline_typing_model.py +106 -0
.gitattributes
CHANGED
@@ -37,3 +37,7 @@ models/en_core_web_sm/en_core_web_sm-3.7.1/ner/model filter=lfs diff=lfs merge=lfs -text
 models/en_core_web_sm/en_core_web_sm-3.7.1/tok2vec/model filter=lfs diff=lfs merge=lfs -text
 models/en_core_web_sm/ner/model filter=lfs diff=lfs merge=lfs -text
 models/en_core_web_sm/tok2vec/model filter=lfs diff=lfs merge=lfs -text
+models/spabert/notebooks/tutorial_datasets/output.csv.json filter=lfs diff=lfs merge=lfs -text
+models/spabert/notebooks/tutorial_datasets/spabert_osm_mn.json filter=lfs diff=lfs merge=lfs -text
+models/spabert/notebooks/tutorial_datasets/spabert_whg_wikidata.json filter=lfs diff=lfs merge=lfs -text
+models/spabert/notebooks/tutorial_datasets/spabert_wikidata_sampled.json filter=lfs diff=lfs merge=lfs -text
models/spabert/README.md
ADDED
@@ -0,0 +1,69 @@
# SpaBERT: A Pretrained Language Model from Geographic Data for Geo-Entity Representation

This repo contains code for [SpaBERT: A Pretrained Language Model from Geographic Data for Geo-Entity Representation](https://arxiv.org/abs/2210.12213), published at EMNLP 2022. SpaBERT provides a general-purpose geo-entity representation based on neighboring entities in geospatial data. SpaBERT extends BERT to capture linearized spatial context, while incorporating a spatial coordinate embedding mechanism to preserve the spatial relations of entities in two-dimensional space. SpaBERT is pretrained with masked language modeling and masked entity prediction tasks to learn spatial dependencies.

* Slides: [emnlp22-spabert.pdf](https://drive.google.com/file/d/1V1URsRfpw13dbkb_zgBXeNqZJ0AF2744/view?usp=share_link)


## Pretraining
Pretrained model weights can be downloaded from Google Drive for [SpaBERT-base](https://drive.google.com/file/d/1l44FY3DtDxzM_YVh3RR6PJwKnl80IYWB/view?usp=sharing) and [SpaBERT-large](https://drive.google.com/file/d/1LeZayTR92R5bu9gH_cGCwef7nnMX35cR/view?usp=share_link).

Weights can also be obtained by training from scratch with the following sample commands. Data for pretraining can be downloaded [here](https://drive.google.com/drive/folders/1eaeVvUCcJVcNwnyTCk-1N1IKfukihk4j?usp=share_link).

* Command to pretrain the SpaBERT-base model:

```python3 train_mlm.py --lr=5e-5 --sep_between_neighbors --bert_option='bert-base'```

* Command to pretrain the SpaBERT-large model:

```python3 train_mlm.py --lr=1e-6 --sep_between_neighbors --bert_option='bert-large'```

## Downstream Tasks
### Supervised Geo-entity Typing
The goal is to predict a geo-entity's semantic type (e.g., transportation or healthcare) given the target geo-entity's name and its spatial context (i.e., the names and locations of surrounding neighbors).

Models trained on OSM in the London and California regions can be downloaded from Google Drive for [SpaBERT-base](https://drive.google.com/file/d/1XFcA3sxC4wTlt7VjvMp1zNrWY5rjafzE/view?usp=share_link) and [SpaBERT-large](https://drive.google.com/file/d/12_FDVeSYkl_HQ61JmuMU6cRjQdKNpgR_/view?usp=share_link).

Data used for training and testing can be downloaded [here](https://drive.google.com/drive/folders/1uyvGdiJdu-Cym4dOKhQLIkKpfgHvfo01?usp=share_link).

* Sample command for training the SpaBERT-base typing model:
```
python3 train_cls_spatialbert.py --lr=5e-5 --sep_between_neighbors --bert_option='bert-base' --with_type --mlm_checkpoint_path='mlm_mem_keeppos_ep0_iter06000_0.2936.pth'
```

* Sample command for training the SpaBERT-large typing model:
```
python3 train_cls_spatialbert.py --lr=1e-6 --sep_between_neighbors --bert_option='bert-large' --with_type --mlm_checkpoint_path='mlm_mem_keeppos_ep1_iter02000_0.4400.pth' --epochs=20
```

### Unsupervised Geo-entity Linking

Geo-entity linking matches geo-entities from a geographic information system (GIS) oriented dataset against a knowledge base (KB). The task is unsupervised and thus requires no further training; the pretrained models can be used directly.

Linking with SpaBERT-base:
```
python3 unsupervised_wiki_location_allcand.py --model_name='spatial_bert-base' --sep_between_neighbors \
    --spatial_bert_weight_dir='weights/' --spatial_bert_weight_name='mlm_mem_keeppos_ep0_iter06000_0.2936.pth'
```

Linking with SpaBERT-large:
```
python3 unsupervised_wiki_location_allcand.py --model_name='spatial_bert-large' --sep_between_neighbors \
    --spatial_bert_weight_dir='weights/' --spatial_bert_weight_name='mlm_mem_keeppos_ep1_iter02000_0.4400.pth'
```

Data used for linking from USGS historical maps to the WikiData KB is provided [here](https://drive.google.com/drive/folders/1qKJnj71qxnca_TaygK-Y3EIySnMyFpFn?usp=share_link).

## Acknowledgement
```
@article{li2022spabert,
  title={SpaBERT: A Pretrained Language Model from Geographic Data for Geo-Entity Representation},
  author={Li, Zekun and Kim, Jina and Chiang, Yao-Yi and Chen, Muhao},
  journal={EMNLP},
  year={2022}
}
```
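As an illustrative aside on the typing task above (not part of the committed README), here is a minimal sketch of how the nine-class label space might be encoded with scikit-learn, assuming the `CLASS_9_LIST` constant from `models/spabert/datasets/const.py` (added later in this commit) is importable, e.g. with that directory on `sys.path`:

```python
# Sketch: encode the 9 coarse geo-entity classes used by the typing experiments.
# Assumes models/spabert/datasets is on sys.path (the loaders in this commit
# arrange this via sys.path.append).
from sklearn.preprocessing import LabelEncoder
from const import CLASS_9_LIST

label_encoder = LabelEncoder()
label_encoder.fit(CLASS_9_LIST)

print(label_encoder.transform(['healthcare', 'transportation']))  # e.g. [4 7] (alphabetical encoding)
print(list(label_encoder.inverse_transform([0])))                 # e.g. ['education']
```

A fitted encoder of this kind is what `PbfMapDataset` (added in this commit under `models/spabert/datasets/osm_sample_loader.py`) expects via its `label_encoder` argument when `with_type=True`.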
models/spabert/__init__.py
ADDED
File without changes

models/spabert/datasets/__init__.py
ADDED
File without changes

models/spabert/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (152 Bytes)

models/spabert/datasets/__pycache__/dataset_loader.cpython-310.pyc
ADDED
Binary file (3.75 kB)

models/spabert/datasets/__pycache__/dataset_loader_ver2.cpython-310.pyc
ADDED
Binary file (3 kB)

models/spabert/datasets/__pycache__/osm_sample_loader.cpython-310.pyc
ADDED
Binary file (5.79 kB)

models/spabert/datasets/__pycache__/usgs_os_sample_loader.cpython-310.pyc
ADDED
Binary file (2.26 kB)

models/spabert/datasets/__pycache__/wikidata_sample_loader.cpython-310.pyc
ADDED
Binary file (3.36 kB)

models/spabert/datasets/const.py
ADDED
@@ -0,0 +1,162 @@
def revert_dict(coarse_to_fine_dict):
    fine_to_coarse_dict = dict()
    for key, value in coarse_to_fine_dict.items():
        for v in value:
            fine_to_coarse_dict[v] = key
    return fine_to_coarse_dict

CLASS_9_LIST = ['education', 'entertainment_arts_culture', 'facilities', 'financial', 'healthcare', 'public_service', 'sustenance', 'transportation', 'waste_management']

CLASS_118_LIST = ['animal_boarding', 'animal_breeding', 'animal_shelter', 'arts_centre', 'atm', 'baby_hatch', 'baking_oven', 'bank', 'bar', 'bbq', 'bench', 'bicycle_parking', 'bicycle_rental', 'bicycle_repair_station', 'biergarten', 'boat_rental', 'boat_sharing', 'brothel', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'casino', 'charging_station', 'childcare', 'cinema', 'clinic', 'clock', 'college', 'community_centre', 'compressed_air', 'conference_centre', 'courthouse', 'crematorium', 'dentist', 'dive_centre', 'doctors', 'dog_toilet', 'dressing_room', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'food_court', 'fountain', 'fuel', 'funeral_hall', 'gambling', 'give_box', 'grave_yard', 'grit_bin', 'hospital', 'hunting_stand', 'ice_cream', 'internet_cafe', 'kindergarten', 'kitchen', 'kneipp_water_cure', 'language_school', 'library', 'lounger', 'love_hotel', 'marketplace', 'monastery', 'motorcycle_parking', 'music_school', 'nightclub', 'nursing_home', 'parcel_locker', 'parking', 'parking_entrance', 'parking_space', 'pharmacy', 'photo_booth', 'place_of_mourning', 'place_of_worship', 'planetarium', 'police', 'post_box', 'post_depot', 'post_office', 'prison', 'pub', 'public_bath', 'public_bookcase', 'ranger_station', 'recycling', 'refugee_site', 'restaurant', 'sanitary_dump_station', 'school', 'shelter', 'shower', 'social_centre', 'social_facility', 'stripclub', 'studio', 'swingerclub', 'taxi', 'telephone', 'theatre', 'toilets', 'townhall', 'toy_library', 'training', 'university', 'vehicle_inspection', 'vending_machine', 'veterinary', 'waste_basket', 'waste_disposal', 'waste_transfer_station', 'water_point', 'watering_place']

CLASS_95_LIST = ['arts_centre', 'atm', 'baby_hatch', 'bank', 'bar', 'bbq', 'bench', 'bicycle_parking', 'bicycle_rental', 'bicycle_repair_station', 'biergarten', 'boat_rental', 'boat_sharing', 'brothel', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'casino', 'charging_station', 'cinema', 'clinic', 'college', 'community_centre', 'compressed_air', 'conference_centre', 'courthouse', 'dentist', 'doctors', 'dog_toilet', 'dressing_room', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'food_court', 'fountain', 'fuel', 'gambling', 'give_box', 'grit_bin', 'hospital', 'ice_cream', 'kindergarten', 'language_school', 'library', 'love_hotel', 'motorcycle_parking', 'music_school', 'nightclub', 'nursing_home', 'parcel_locker', 'parking', 'parking_entrance', 'parking_space', 'pharmacy', 'planetarium', 'police', 'post_box', 'post_depot', 'post_office', 'prison', 'pub', 'public_bookcase', 'ranger_station', 'recycling', 'restaurant', 'sanitary_dump_station', 'school', 'shelter', 'shower', 'social_centre', 'social_facility', 'stripclub', 'studio', 'swingerclub', 'taxi', 'telephone', 'theatre', 'toilets', 'townhall', 'toy_library', 'training', 'university', 'vehicle_inspection', 'veterinary', 'waste_basket', 'waste_disposal', 'waste_transfer_station', 'water_point', 'watering_place']

CLASS_74_LIST = ['arts_centre', 'atm', 'bank', 'bar', 'bench', 'bicycle_parking', 'bicycle_rental', 'bicycle_repair_station', 'boat_rental', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'charging_station', 'cinema', 'clinic', 'college', 'community_centre', 'conference_centre', 'courthouse', 'dentist', 'doctors', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'food_court', 'fountain', 'fuel', 'gambling', 'hospital', 'kindergarten', 'language_school', 'library', 'motorcycle_parking', 'music_school', 'nightclub', 'nursing_home', 'parcel_locker', 'parking', 'pharmacy', 'police', 'post_box', 'post_depot', 'post_office', 'pub', 'public_bookcase', 'recycling', 'restaurant', 'sanitary_dump_station', 'school', 'shelter', 'social_centre', 'social_facility', 'stripclub', 'studio', 'swingerclub', 'taxi', 'telephone', 'theatre', 'toilets', 'townhall', 'university', 'vehicle_inspection', 'veterinary', 'waste_basket', 'waste_disposal', 'waste_transfer_station', 'water_point', 'watering_place']

FEWSHOT_CLASS_55_LIST = ['arts_centre', 'atm', 'bank', 'bar', 'bench', 'bicycle_parking', 'bicycle_rental', 'boat_rental', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'charging_station', 'cinema', 'clinic', 'college', 'community_centre', 'courthouse', 'dentist', 'doctors', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'fountain', 'fuel', 'hospital', 'kindergarten', 'library', 'music_school', 'nightclub', 'parking', 'pharmacy', 'police', 'post_box', 'post_office', 'pub', 'public_bookcase', 'recycling', 'restaurant', 'school', 'shelter', 'social_centre', 'social_facility', 'studio', 'theatre', 'toilets', 'townhall', 'university', 'vehicle_inspection', 'veterinary']


DICT_9to74 = {'sustenance': ['bar','cafe','fast_food','food_court','pub','restaurant'],
              'education': ['college','driving_school','kindergarten','language_school','library','music_school','school','university'],
              'transportation': ['bicycle_parking','bicycle_repair_station','bicycle_rental','boat_rental',
                                 'bus_station','car_rental','car_sharing','car_wash','vehicle_inspection','charging_station','ferry_terminal',
                                 'fuel','motorcycle_parking','parking','taxi'],
              'financial': ['atm','bank','bureau_de_change'],
              'healthcare': ['clinic','dentist','doctors','hospital','nursing_home','pharmacy','social_facility','veterinary'],
              'entertainment_arts_culture': ['arts_centre','cinema','community_centre',
                                             'conference_centre','events_venue','fountain','gambling',
                                             'nightclub','public_bookcase','social_centre','stripclub','studio','swingerclub','theatre'],
              'public_service': ['courthouse','fire_station','police','post_box',
                                 'post_depot','post_office','townhall'],
              'facilities': ['bench','drinking_water','parcel_locker','shelter',
                             'telephone','toilets','water_point','watering_place'],
              'waste_management': ['sanitary_dump_station','recycling','waste_basket','waste_disposal','waste_transfer_station']
              }

DICT_74to9 = revert_dict(DICT_9to74)

DICT_9to95 = {
    'education': {'college','driving_school','kindergarten','language_school','library','toy_library','training','music_school','school','university'},
    'entertainment_arts_culture': {'arts_centre','brothel','casino','cinema','community_centre','conference_centre','events_venue','fountain','gambling','love_hotel','nightclub','planetarium','public_bookcase','social_centre','stripclub','studio','swingerclub','theatre'},
    'facilities': {'bbq','bench','dog_toilet','dressing_room','drinking_water','give_box','parcel_locker','shelter','shower','telephone','toilets','water_point','watering_place'},
    'financial': {'atm','bank','bureau_de_change'},
    'healthcare': {'baby_hatch','clinic','dentist','doctors','hospital','nursing_home','pharmacy','social_facility','veterinary'},
    'public_service': {'courthouse','fire_station','police','post_box','post_depot','post_office','prison','ranger_station','townhall'},
    'sustenance': {'bar','biergarten','cafe','fast_food','food_court','ice_cream','pub','restaurant'},
    'transportation': {'bicycle_parking','bicycle_repair_station','bicycle_rental','boat_rental','boat_sharing','bus_station','car_rental','car_sharing','car_wash','compressed_air','vehicle_inspection','charging_station','ferry_terminal','fuel','grit_bin','motorcycle_parking','parking','parking_entrance','parking_space','taxi'},
    'waste_management': {'sanitary_dump_station','recycling','waste_basket','waste_disposal','waste_transfer_station'}}

DICT_95to9 = revert_dict(DICT_9to95)

# NOTE: the original file additionally keeps a hand-written DICT_95to9 mapping and
# alternative CLASS_9_LIST / FINE_LIST definitions commented out below this point;
# the generated DICT_95to9 above supersedes them.

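For orientation, a small usage sketch of the mappings defined above (illustrative only; it assumes this datasets directory is on `sys.path`, as the sample loaders in this commit arrange via `sys.path.append`):

```python
# Sketch: fine-grained OSM amenity classes map to the 9 coarse classes via the
# reverse dictionaries generated by revert_dict above.
from const import CLASS_9_LIST, DICT_9to74, DICT_74to9

print(DICT_74to9['cafe'])      # 'sustenance'
print(DICT_74to9['hospital'])  # 'healthcare'

# Every fine class listed in DICT_9to74 appears exactly once as a key of DICT_74to9.
assert set(DICT_74to9) == {fine for fines in DICT_9to74.values() for fine in fines}
print(len(CLASS_9_LIST))       # 9 coarse semantic types
```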
models/spabert/datasets/dataset_loader.py
ADDED
@@ -0,0 +1,162 @@
import numpy as np
import torch
from torch.utils.data import Dataset
import pdb

class SpatialDataset(Dataset):
    def __init__(self, tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors=False):
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len
        self.distance_norm_factor = distance_norm_factor
        self.sep_between_neighbors = sep_between_neighbors

    def parse_spatial_context(self, pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill, pivot_dist_fill=0):

        sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token)
        cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token)
        mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
        max_token_len = self.max_token_len

        # process pivot
        pivot_name_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(pivot_name))
        pivot_token_len = len(pivot_name_tokens)

        pivot_lng = pivot_pos[0]
        pivot_lat = pivot_pos[1]

        # prepare entity mask
        entity_mask_arr = []
        rand_entity = np.random.uniform(size=len(neighbor_name_list) + 1)  # random numbers for masking entities, including neighbors and pivot
        # True for mask, False for unmask

        # check if the pivot entity needs to be masked out; 15% prob. to be masked out
        if rand_entity[0] < 0.15:
            entity_mask_arr.extend([True] * pivot_token_len)
        else:
            entity_mask_arr.extend([False] * pivot_token_len)

        # process neighbors
        neighbor_token_list = []
        neighbor_lng_list = []
        neighbor_lat_list = []

        # add separator between pivot and neighbor tokens
        # a trick to avoid adding a separator token after the class name (for class name encoding of the margin-ranking loss)
        if self.sep_between_neighbors and pivot_dist_fill == 0:
            neighbor_lng_list.append(spatial_dist_fill)
            neighbor_lat_list.append(spatial_dist_fill)
            neighbor_token_list.append(sep_token_id)

        for neighbor_name, neighbor_geometry, rnd in zip(neighbor_name_list, neighbor_geometry_list, rand_entity[1:]):

            if not neighbor_name[0].isalpha():
                # only consider neighbors whose names start with a letter
                continue

            neighbor_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(neighbor_name))
            neighbor_token_len = len(neighbor_token)

            # compute the relative distance from neighbor to pivot,
            # normalize the relative distance by distance_norm_factor,
            # and apply the calculated distance to all the subtokens of the neighbor
            if 'coordinates' in neighbor_geometry:  # to handle different json dict structures
                neighbor_lng_list.extend([(neighbor_geometry['coordinates'][0] - pivot_lng) / self.distance_norm_factor] * neighbor_token_len)
                neighbor_lat_list.extend([(neighbor_geometry['coordinates'][1] - pivot_lat) / self.distance_norm_factor] * neighbor_token_len)
                neighbor_token_list.extend(neighbor_token)
            else:
                neighbor_lng_list.extend([(neighbor_geometry[0] - pivot_lng) / self.distance_norm_factor] * neighbor_token_len)
                neighbor_lat_list.extend([(neighbor_geometry[1] - pivot_lat) / self.distance_norm_factor] * neighbor_token_len)
                neighbor_token_list.extend(neighbor_token)

            if self.sep_between_neighbors:
                neighbor_lng_list.append(spatial_dist_fill)
                neighbor_lat_list.append(spatial_dist_fill)
                neighbor_token_list.append(sep_token_id)

                entity_mask_arr.extend([False])

            if rnd < 0.15:
                # True: mask out, False: keep original token
                entity_mask_arr.extend([True] * neighbor_token_len)
            else:
                entity_mask_arr.extend([False] * neighbor_token_len)

        pseudo_sentence = pivot_name_tokens + neighbor_token_list
        dist_lng_list = [pivot_dist_fill] * pivot_token_len + neighbor_lng_list
        dist_lat_list = [pivot_dist_fill] * pivot_token_len + neighbor_lat_list

        # sentence length before adding CLS and SEP
        sent_len = len(pseudo_sentence)

        max_token_len_middle = max_token_len - 2  # 2 for the CLS and SEP tokens

        # padding and truncation
        if sent_len > max_token_len_middle:
            pseudo_sentence = [cls_token_id] + pseudo_sentence[:max_token_len_middle] + [sep_token_id]
            dist_lat_list = [spatial_dist_fill] + dist_lat_list[:max_token_len_middle] + [spatial_dist_fill]
            dist_lng_list = [spatial_dist_fill] + dist_lng_list[:max_token_len_middle] + [spatial_dist_fill]
            attention_mask = [False] + [1] * max_token_len_middle + [False]  # make sure SEP and CLS are not attended to
        else:
            pad_len = max_token_len_middle - sent_len
            assert pad_len >= 0

            pseudo_sentence = [cls_token_id] + pseudo_sentence + [sep_token_id] + [pad_token_id] * pad_len
            dist_lat_list = [spatial_dist_fill] + dist_lat_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            dist_lng_list = [spatial_dist_fill] + dist_lng_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            attention_mask = [False] + [1] * sent_len + [0] * pad_len + [False]

        norm_lng_list = np.array(dist_lng_list)  # / 0.0001
        norm_lat_list = np.array(dist_lat_list)  # / 0.0001

        # mask entities in the pseudo sentence
        entity_mask_indices = np.where(entity_mask_arr)[0]
        masked_entity_input = [mask_token_id if i in entity_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]

        # mask tokens in the pseudo sentence
        rand_token = np.random.uniform(size=len(pseudo_sentence))
        # do not mask out the cls and sep tokens. True: masked token, False: keep original token
        token_mask_arr = (rand_token < 0.15) & (np.array(pseudo_sentence) != cls_token_id) & (np.array(pseudo_sentence) != sep_token_id) & (np.array(pseudo_sentence) != pad_token_id)
        token_mask_indices = np.where(token_mask_arr)[0]

        masked_token_input = [mask_token_id if i in token_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]

        # yield masked_token with 50% prob, masked_entity with 50% prob
        if np.random.rand() > 0.5:
            masked_input = torch.tensor(masked_entity_input)
        else:
            masked_input = torch.tensor(masked_token_input)

        train_data = {}
        train_data['pivot_name'] = pivot_name
        train_data['pivot_token_len'] = pivot_token_len
        train_data['masked_input'] = masked_input
        train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence)))
        train_data['attention_mask'] = torch.tensor(attention_mask)
        train_data['norm_lng_list'] = torch.tensor(norm_lng_list).to(torch.float32)
        train_data['norm_lat_list'] = torch.tensor(norm_lat_list).to(torch.float32)
        train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence)

        return train_data

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError
models/spabert/datasets/dataset_loader_ver2.py
ADDED
@@ -0,0 +1,164 @@
import numpy as np
import torch
from torch.utils.data import Dataset
import pdb

class SpatialDataset(Dataset):
    def __init__(self, tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors=False):
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len
        self.distance_norm_factor = distance_norm_factor
        self.sep_between_neighbors = sep_between_neighbors

    def parse_spatial_context(self, pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill, pivot_dist_fill=0):

        sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token)
        cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token)
        #mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
        max_token_len = self.max_token_len

        #print("Module reloaded and changes are reflected")
        # process pivot
        pivot_name_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(pivot_name))
        pivot_token_len = len(pivot_name_tokens)

        pivot_lng = pivot_pos[0]
        pivot_lat = pivot_pos[1]

        # prepare entity mask
        entity_mask_arr = []
        rand_entity = np.random.uniform(size=len(neighbor_name_list) + 1)  # random numbers for masking entities, including neighbors and pivot
        # True for mask, False for unmask

        # check if the pivot entity needs to be masked out; 15% prob. to be masked out
        #if rand_entity[0] < 0.15:
        #    entity_mask_arr.extend([True] * pivot_token_len)
        #else:
        entity_mask_arr.extend([False] * pivot_token_len)

        # process neighbors
        neighbor_token_list = []
        neighbor_lng_list = []
        neighbor_lat_list = []

        # add separator between pivot and neighbor tokens
        # a trick to avoid adding a separator token after the class name (for class name encoding of the margin-ranking loss)
        if self.sep_between_neighbors and pivot_dist_fill == 0:
            neighbor_lng_list.append(spatial_dist_fill)
            neighbor_lat_list.append(spatial_dist_fill)
            neighbor_token_list.append(sep_token_id)

        for neighbor_name, neighbor_geometry, rnd in zip(neighbor_name_list, neighbor_geometry_list, rand_entity[1:]):

            if not neighbor_name[0].isalpha():
                # only consider neighbors whose names start with a letter
                continue

            neighbor_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(neighbor_name))
            neighbor_token_len = len(neighbor_token)

            # compute the relative distance from neighbor to pivot,
            # normalize the relative distance by distance_norm_factor,
            # and apply the calculated distance to all the subtokens of the neighbor
            if 'coordinates' in neighbor_geometry:  # to handle different json dict structures
                neighbor_lng_list.extend([(neighbor_geometry['coordinates'][0] - pivot_lng) / self.distance_norm_factor] * neighbor_token_len)
                neighbor_lat_list.extend([(neighbor_geometry['coordinates'][1] - pivot_lat) / self.distance_norm_factor] * neighbor_token_len)
                neighbor_token_list.extend(neighbor_token)
            else:
                neighbor_lng_list.extend([(neighbor_geometry[0] - pivot_lng) / self.distance_norm_factor] * neighbor_token_len)
                neighbor_lat_list.extend([(neighbor_geometry[1] - pivot_lat) / self.distance_norm_factor] * neighbor_token_len)
                neighbor_token_list.extend(neighbor_token)

            if self.sep_between_neighbors:
                neighbor_lng_list.append(spatial_dist_fill)
                neighbor_lat_list.append(spatial_dist_fill)
                neighbor_token_list.append(sep_token_id)

                entity_mask_arr.extend([False])

            #if rnd < 0.15:
            #    # True: mask out, False: keep original token
            #    entity_mask_arr.extend([True] * neighbor_token_len)
            #else:
            entity_mask_arr.extend([False] * neighbor_token_len)

        pseudo_sentence = pivot_name_tokens + neighbor_token_list
        dist_lng_list = [pivot_dist_fill] * pivot_token_len + neighbor_lng_list
        dist_lat_list = [pivot_dist_fill] * pivot_token_len + neighbor_lat_list

        # sentence length before adding CLS and SEP
        sent_len = len(pseudo_sentence)

        max_token_len_middle = max_token_len - 2  # 2 for the CLS and SEP tokens

        # padding and truncation
        if sent_len > max_token_len_middle:
            pseudo_sentence = [cls_token_id] + pseudo_sentence[:max_token_len_middle] + [sep_token_id]
            dist_lat_list = [spatial_dist_fill] + dist_lat_list[:max_token_len_middle] + [spatial_dist_fill]
            dist_lng_list = [spatial_dist_fill] + dist_lng_list[:max_token_len_middle] + [spatial_dist_fill]
            attention_mask = [False] + [1] * max_token_len_middle + [False]  # make sure SEP and CLS are not attended to
        else:
            pad_len = max_token_len_middle - sent_len
            assert pad_len >= 0

            pseudo_sentence = [cls_token_id] + pseudo_sentence + [sep_token_id] + [pad_token_id] * pad_len
            dist_lat_list = [spatial_dist_fill] + dist_lat_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            dist_lng_list = [spatial_dist_fill] + dist_lng_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
            attention_mask = [False] + [1] * sent_len + [0] * pad_len + [False]

        norm_lng_list = np.array(dist_lng_list)  # / 0.0001
        norm_lat_list = np.array(dist_lat_list)  # / 0.0001

        ## mask entities in the pseudo sentence
        #entity_mask_indices = np.where(entity_mask_arr)[0]
        #masked_entity_input = [mask_token_id if i in entity_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]
        #
        ## mask tokens in the pseudo sentence
        #rand_token = np.random.uniform(size=len(pseudo_sentence))
        ## do not mask out the cls and sep tokens. True: masked token, False: keep original token
        #token_mask_arr = (rand_token < 0.15) & (np.array(pseudo_sentence) != cls_token_id) & (np.array(pseudo_sentence) != sep_token_id) & (np.array(pseudo_sentence) != pad_token_id)
        #token_mask_indices = np.where(token_mask_arr)[0]
        #
        #masked_token_input = [mask_token_id if i in token_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]
        #
        ## yield masked_token with 50% prob, masked_entity with 50% prob
        #if np.random.rand() > 0.5:
        #    masked_input = torch.tensor(masked_entity_input)
        #else:
        #    masked_input = torch.tensor(masked_token_input)
        masked_input = torch.tensor(pseudo_sentence)

        train_data = {}
        train_data['pivot_name'] = pivot_name
        train_data['pivot_token_len'] = pivot_token_len
        train_data['masked_input'] = masked_input
        train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence)))
        train_data['attention_mask'] = torch.tensor(attention_mask)
        train_data['norm_lng_list'] = torch.tensor(norm_lng_list).to(torch.float32)
        train_data['norm_lat_list'] = torch.tensor(norm_lat_list).to(torch.float32)
        train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence)

        return train_data

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError

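The practical difference from `dataset_loader.py` above is that this version applies no masking, so `masked_input` is simply the full pseudo-sentence, which suits inference and embedding extraction rather than MLM pretraining. A quick, illustrative check with made-up inputs:

```python
# Sketch: confirm that dataset_loader_ver2 returns an unmasked input sequence.
import torch
from transformers import BertTokenizer
from dataset_loader_ver2 import SpatialDataset  # assumes this directory is on sys.path

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = SpatialDataset(tokenizer, max_token_len=32,
                         distance_norm_factor=0.0001, sep_between_neighbors=True)
out = dataset.parse_spatial_context('Griffith Observatory', [-118.3004, 34.1184],
                                    ['Greek Theatre'], [[-118.2963, 34.1194]],
                                    spatial_dist_fill=100)
assert torch.equal(out['masked_input'], out['pseudo_sentence'])
```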
models/spabert/datasets/osm_sample_loader.py
ADDED
@@ -0,0 +1,246 @@
import os
import sys
import numpy as np
import json
import math

import torch
from transformers import RobertaTokenizer, BertTokenizer
from torch.utils.data import Dataset
#sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
sys.path.append('/content/drive/MyDrive/spaBERT/spabert/datasets')
from dataset_loader_ver2 import SpatialDataset

import pdb

class PbfMapDataset(SpatialDataset):
    def __init__(self, data_file_path, tokenizer=None, max_token_len=512, distance_norm_factor=0.0001, spatial_dist_fill=10,
                 with_type=True, sep_between_neighbors=False, label_encoder=None, mode=None, num_neighbor_limit=None, random_remove_neighbor=0., type_key_str='class'):

        if tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        else:
            self.tokenizer = tokenizer

        self.max_token_len = max_token_len
        self.spatial_dist_fill = spatial_dist_fill  # should be a normalized distance fill, larger than all normalized neighbor distances
        self.with_type = with_type
        self.sep_between_neighbors = sep_between_neighbors
        self.label_encoder = label_encoder
        self.num_neighbor_limit = num_neighbor_limit
        self.read_file(data_file_path, mode)
        self.random_remove_neighbor = random_remove_neighbor
        self.type_key_str = type_key_str  # key name of the class type in the input data dictionary

        super(PbfMapDataset, self).__init__(self.tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors)

    def read_file(self, data_file_path, mode):

        with open(data_file_path, 'r') as f:
            data = f.readlines()

        if mode == 'train':
            data = data[0:int(len(data) * 0.8)]
        elif mode == 'test':
            data = data[int(len(data) * 0.8):]
        elif mode is None:  # use the full dataset (for mlm)
            pass
        else:
            raise NotImplementedError

        self.len_data = len(data)  # updated data length
        self.data = data

    def load_data(self, index):

        spatial_dist_fill = self.spatial_dist_fill
        line = self.data[index]  # take one line from the input data according to the index

        line_data_dict = json.loads(line)

        # process pivot
        pivot_name = line_data_dict['info']['name']
        pivot_pos = line_data_dict['info']['geometry']['coordinates']

        neighbor_info = line_data_dict['neighbor_info']
        neighbor_name_list = neighbor_info['name_list']
        neighbor_geometry_list = neighbor_info['geometry_list']

        if self.random_remove_neighbor != 0:
            num_neighbors = len(neighbor_name_list)
            rand_neighbor = np.random.uniform(size=num_neighbors)

            neighbor_keep_arr = (rand_neighbor >= self.random_remove_neighbor)  # select the neighbors to keep
            neighbor_keep_arr = np.where(neighbor_keep_arr)[0]

            new_neighbor_name_list, new_neighbor_geometry_list = [], []
            for i in range(0, num_neighbors):
                if i in neighbor_keep_arr:
                    new_neighbor_name_list.append(neighbor_name_list[i])
                    new_neighbor_geometry_list.append(neighbor_geometry_list[i])

            neighbor_name_list = new_neighbor_name_list
            neighbor_geometry_list = new_neighbor_geometry_list

        if self.num_neighbor_limit is not None:
            neighbor_name_list = neighbor_name_list[0:self.num_neighbor_limit]
            neighbor_geometry_list = neighbor_geometry_list[0:self.num_neighbor_limit]

        train_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill)

        if self.with_type:
            pivot_type = line_data_dict['info'][self.type_key_str]
            train_data['pivot_type'] = torch.tensor(self.label_encoder.transform([pivot_type])[0])  # scalar, label_id

        if 'ogc_fid' in line_data_dict['info']:
            train_data['ogc_fid'] = line_data_dict['info']['ogc_fid']

        return train_data

    def __len__(self):
        return self.len_data

    def __getitem__(self, index):
        return self.load_data(index)


class PbfMapDatasetMarginRanking(SpatialDataset):
    def __init__(self, data_file_path, type_list=None, tokenizer=None, max_token_len=512, distance_norm_factor=0.0001, spatial_dist_fill=10,
                 sep_between_neighbors=False, mode=None, num_neighbor_limit=None, random_remove_neighbor=0., type_key_str='class'):

        if tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        else:
            self.tokenizer = tokenizer

        self.type_list = type_list
        self.type_key_str = type_key_str  # key name of the class type in the input data dictionary
        self.max_token_len = max_token_len
        self.spatial_dist_fill = spatial_dist_fill  # should be a normalized distance fill, larger than all normalized neighbor distances
        self.sep_between_neighbors = sep_between_neighbors
        # self.label_encoder = label_encoder
        self.num_neighbor_limit = num_neighbor_limit
        self.read_file(data_file_path, mode)
        self.random_remove_neighbor = random_remove_neighbor
        self.mode = mode

        super(PbfMapDatasetMarginRanking, self).__init__(self.tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors)

    def read_file(self, data_file_path, mode):

        with open(data_file_path, 'r') as f:
            data = f.readlines()

        if mode == 'train':
            data = data[0:int(len(data) * 0.8)]
        elif mode == 'test':
            data = data[int(len(data) * 0.8):]
            self.all_types_data = self.prepare_all_types_data()
        elif mode is None:  # use the full dataset (for mlm)
            pass
        else:
            raise NotImplementedError

        self.len_data = len(data)  # updated data length
        self.data = data

    def prepare_all_types_data(self):
        type_list = self.type_list
        spatial_dist_fill = self.spatial_dist_fill
        type_data_dict = dict()
        for type_name in type_list:
            type_pos = [None, None]  # use filler values
            type_data = self.parse_spatial_context(type_name, type_pos, pivot_dist_fill=0.,
                                                   neighbor_name_list=[], neighbor_geometry_list=[], spatial_dist_fill=spatial_dist_fill)
            type_data_dict[type_name] = type_data

        return type_data_dict

    def load_data(self, index):

        spatial_dist_fill = self.spatial_dist_fill
        line = self.data[index]  # take one line from the input data according to the index

        line_data_dict = json.loads(line)

        # process pivot
        pivot_name = line_data_dict['info']['name']
        pivot_pos = line_data_dict['info']['geometry']['coordinates']

        neighbor_info = line_data_dict['neighbor_info']
        neighbor_name_list = neighbor_info['name_list']
        neighbor_geometry_list = neighbor_info['geometry_list']

        if self.random_remove_neighbor != 0:
            num_neighbors = len(neighbor_name_list)
            rand_neighbor = np.random.uniform(size=num_neighbors)

            neighbor_keep_arr = (rand_neighbor >= self.random_remove_neighbor)  # select the neighbors to keep
            neighbor_keep_arr = np.where(neighbor_keep_arr)[0]

            new_neighbor_name_list, new_neighbor_geometry_list = [], []
            for i in range(0, num_neighbors):
                if i in neighbor_keep_arr:
                    new_neighbor_name_list.append(neighbor_name_list[i])
                    new_neighbor_geometry_list.append(neighbor_geometry_list[i])

            neighbor_name_list = new_neighbor_name_list
            neighbor_geometry_list = new_neighbor_geometry_list

        if self.num_neighbor_limit is not None:
            neighbor_name_list = neighbor_name_list[0:self.num_neighbor_limit]
            neighbor_geometry_list = neighbor_geometry_list[0:self.num_neighbor_limit]

        train_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill)

        if 'ogc_fid' in line_data_dict['info']:
            train_data['ogc_fid'] = line_data_dict['info']['ogc_fid']

        # train_data['pivot_type'] = torch.tensor(self.label_encoder.transform([pivot_type])[0])  # scalar, label_id

        pivot_type = line_data_dict['info'][self.type_key_str]
        train_data['pivot_type'] = pivot_type

        if self.mode == 'train':
            # positive class
            positive_name = pivot_type  # class type string as input to the tokenizer
            positive_pos = [None, None]  # use filler values
            positive_type_data = self.parse_spatial_context(positive_name, positive_pos, pivot_dist_fill=0.,
                                                            neighbor_name_list=[], neighbor_geometry_list=[], spatial_dist_fill=spatial_dist_fill)
            train_data['positive_type_data'] = positive_type_data

            # negative class
            other_type_list = self.type_list.copy()
            other_type_list.remove(pivot_type)
            other_type = np.random.choice(other_type_list)
            negative_name = other_type
            negative_pos = [None, None]  # use filler values
            negative_type_data = self.parse_spatial_context(negative_name, negative_pos, pivot_dist_fill=0.,
                                                            neighbor_name_list=[], neighbor_geometry_list=[], spatial_dist_fill=spatial_dist_fill)
            train_data['negative_type_data'] = negative_type_data

        elif self.mode == 'test':
            # return data for all class types in type_list
            train_data['all_types_data'] = self.all_types_data

        else:
            raise NotImplementedError

        return train_data

    def __len__(self):
        return self.len_data

    def __getitem__(self, index):
        return self.load_data(index)

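A minimal, illustrative way to use `PbfMapDataset` (the file name and parameter values below are hypothetical; `with_type=True` additionally needs a fitted label encoder, e.g. over `CLASS_9_LIST` from `const.py`):

```python
# Sketch: load a line-delimited OSM JSON file (one entity per line, with the
# 'info' and 'neighbor_info' fields read by load_data above) and inspect a sample.
from osm_sample_loader import PbfMapDataset  # assumes this directory is on sys.path

dataset = PbfMapDataset(
    data_file_path='spabert_osm_london.json',  # hypothetical path
    max_token_len=300,
    distance_norm_factor=0.0001,
    spatial_dist_fill=100,
    with_type=False,              # skip label encoding for plain pseudo-sentence loading
    sep_between_neighbors=True,
    mode=None,                    # use the full file
)

sample = dataset[0]
print(sample['pivot_name'])
print(sample['pseudo_sentence'].shape)  # torch.Size([300])
```

Samples batch with a standard `torch.utils.data.DataLoader` once the entries in a batch share the same keys.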
models/spabert/datasets/usgs_os_sample_loader.py
ADDED
@@ -0,0 +1,71 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import numpy as np
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
from sklearn.preprocessing import LabelBinarizer, LabelEncoder
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from transformers import RobertaTokenizer, BertTokenizer
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
|
12 |
+
from dataset_loader import SpatialDataset
|
13 |
+
|
14 |
+
import pdb
|
15 |
+
|
16 |
+
|
17 |
+
class USGS_MapDataset(SpatialDataset):
|
18 |
+
def __init__(self, data_file_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 1, spatial_dist_fill=100, sep_between_neighbors = False):
|
19 |
+
|
20 |
+
if tokenizer is None:
|
21 |
+
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
22 |
+
else:
|
23 |
+
self.tokenizer = tokenizer
|
24 |
+
|
25 |
+
self.max_token_len = max_token_len
|
26 |
+
self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance
|
27 |
+
self.sep_between_neighbors = sep_between_neighbors
|
28 |
+
self.read_file(data_file_path)
|
29 |
+
|
30 |
+
|
31 |
+
super(USGS_MapDataset, self).__init__(self.tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors )
|
32 |
+
|
33 |
+
def read_file(self, data_file_path):
|
34 |
+
|
35 |
+
with open(data_file_path, 'r') as f:
|
36 |
+
data = f.readlines()
|
37 |
+
|
38 |
+
len_data = len(data)
|
39 |
+
self.len_data = len_data
|
40 |
+
self.data = data
|
41 |
+
|
42 |
+
|
43 |
+
def load_data(self, index):
|
44 |
+
|
45 |
+
spatial_dist_fill = self.spatial_dist_fill
|
46 |
+
line = self.data[index] # take one line from the input data according to the index
|
47 |
+
|
48 |
+
line_data_dict = json.loads(line)
|
49 |
+
|
50 |
+
# process pivot
|
51 |
+
pivot_name = line_data_dict['info']['name']
|
52 |
+
pivot_pos = line_data_dict['info']['geometry']
|
53 |
+
|
54 |
+
neighbor_info = line_data_dict['neighbor_info']
|
55 |
+
neighbor_name_list = neighbor_info['name_list']
|
56 |
+
neighbor_geometry_list = neighbor_info['geometry_list']
|
57 |
+
|
58 |
+
parsed_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill )
|
59 |
+
|
60 |
+
return parsed_data
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
def __len__(self):
|
65 |
+
return self.len_data
|
66 |
+
|
67 |
+
def __getitem__(self, index):
|
68 |
+
return self.load_data(index)
|
69 |
+
|
70 |
+
|
71 |
+
|
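A minimal usage sketch for USGS_MapDataset, assuming a JSON-lines input file; the path below is a placeholder and the returned keys come from parse_spatial_context in the parent SpatialDataset class.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
usgs_dataset = USGS_MapDataset(
    data_file_path='data/usgs_map_entities.json',  # hypothetical JSON-lines file
    tokenizer=tokenizer,
    distance_norm_factor=1,
    spatial_dist_fill=100)

# one item corresponds to one pivot entity with its spatial context
sample = usgs_dataset[0]
print(len(usgs_dataset), list(sample.keys()))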
models/spabert/datasets/wikidata_sample_loader.py
ADDED
@@ -0,0 +1,127 @@
import sys
import numpy as np
import json
import math

import torch
from transformers import RobertaTokenizer, BertTokenizer
from torch.utils.data import Dataset
sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
from dataset_loader import SpatialDataset

import pdb

'''Prepare candidate list given randomly sampled data and append to data_list'''
class Wikidata_Random_Dataset(SpatialDataset):
    def __init__(self, data_file_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=100, sep_between_neighbors = False):
        if tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        else:
            self.tokenizer = tokenizer

        self.max_token_len = max_token_len
        self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance
        self.sep_between_neighbors = sep_between_neighbors
        self.read_file(data_file_path)


        super(Wikidata_Random_Dataset, self).__init__(self.tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors)

    def read_file(self, data_file_path):

        with open(data_file_path, 'r') as f:
            data = f.readlines()

        len_data = len(data)
        self.len_data = len_data
        self.data = data

    def load_data(self, index):

        spatial_dist_fill = self.spatial_dist_fill
        line = self.data[index] # take one line from the input data according to the index

        line_data_dict = json.loads(line)

        # process pivot
        pivot_name = line_data_dict['info']['name']
        pivot_pos = line_data_dict['info']['geometry']['coordinates']
        pivot_uri = line_data_dict['info']['uri']

        neighbor_info = line_data_dict['neighbor_info']
        neighbor_name_list = neighbor_info['name_list']
        neighbor_geometry_list = neighbor_info['geometry_list']

        parsed_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill)
        parsed_data['uri'] = pivot_uri
        parsed_data['description'] = None # placeholder

        return parsed_data

    def __len__(self):
        return self.len_data

    def __getitem__(self, index):
        return self.load_data(index)



'''Prepare candidate list for each phrase and append to data_list'''

class Wikidata_Geocoord_Dataset(SpatialDataset):

    #DEFAULT_CONFIG_CLS = SpatialBertConfig
    def __init__(self, data_file_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=100, sep_between_neighbors = False):
        if tokenizer is None:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        else:
            self.tokenizer = tokenizer

        self.max_token_len = max_token_len
        self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance
        self.sep_between_neighbors = sep_between_neighbors
        self.read_file(data_file_path)


        super(Wikidata_Geocoord_Dataset, self).__init__(self.tokenizer, max_token_len, distance_norm_factor, sep_between_neighbors)

    def read_file(self, data_file_path):

        with open(data_file_path, 'r') as f:
            data = f.readlines()

        len_data = len(data)
        self.len_data = len_data
        self.data = data

    def load_data(self, index):

        spatial_dist_fill = self.spatial_dist_fill
        line = self.data[index] # take one line from the input data according to the index

        line_data = json.loads(line)
        parsed_data_list = []

        for line_data_dict in line_data:
            # process pivot
            pivot_name = line_data_dict['info']['name']
            pivot_pos = line_data_dict['info']['geometry']['coordinates']
            pivot_uri = line_data_dict['info']['uri']

            neighbor_info = line_data_dict['neighbor_info']
            neighbor_name_list = neighbor_info['name_list']
            neighbor_geometry_list = neighbor_info['geometry_list']

            parsed_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill)
            parsed_data['uri'] = pivot_uri
            parsed_data['description'] = None # placeholder
            parsed_data_list.append(parsed_data)

        return parsed_data_list


    def __len__(self):
        return self.len_data

    def __getitem__(self, index):
        return self.load_data(index)
models/spabert/experiments/__init__.py
ADDED
File without changes
models/spabert/experiments/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (155 Bytes).
models/spabert/experiments/entity_matching/__init__.py
ADDED
File without changes
models/spabert/experiments/entity_matching/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (171 Bytes).
models/spabert/experiments/entity_matching/data_processing/__init__.py
ADDED
File without changes
models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (187 Bytes).
models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (204 Bytes).
models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-310.pyc
ADDED
Binary file (5.97 kB).
models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-311.pyc
ADDED
Binary file (7.83 kB).
models/spabert/experiments/entity_matching/data_processing/get_namelist.py
ADDED
@@ -0,0 +1,95 @@
import json
import os

def get_name_list_osm(ref_paths):
    name_list = []

    for json_path in ref_paths:
        with open(json_path, 'r') as f:
            data = f.readlines()
        for line in data:
            record = json.loads(line)
            name = record['name']
            name_list.append(name)

    name_list = sorted(name_list)
    return name_list

# deprecated
def get_name_list_usgs_od(ref_paths):
    name_list = []

    for json_path in ref_paths:
        with open(json_path, 'r') as f:
            annot_dict = json.load(f)

        for key, place in annot_dict.items():
            place_name = ''
            for idx in range(1, len(place)+1):
                try:
                    place_name += place[str(idx)]['text_label']
                    place_name += ' ' # separate words with spaces

                except Exception as e:
                    print(place)
            place_name = place_name[:-1] # remove last space

            name_list.append(place_name)

    name_list = sorted(name_list)
    return name_list

def get_name_list_usgs_od_per_map(ref_paths):
    all_name_list_dict = dict()

    for json_path in ref_paths:
        map_name = os.path.basename(json_path).split('.json')[0]

        with open(json_path, 'r') as f:
            annot_dict = json.load(f)

        map_name_list = []
        for key, place in annot_dict.items():
            place_name = ''
            for idx in range(1, len(place)+1):
                try:
                    place_name += place[str(idx)]['text_label']
                    place_name += ' ' # separate words with spaces

                except Exception as e:
                    print(place)
            place_name = place_name[:-1] # remove last space

            map_name_list.append(place_name)
        all_name_list_dict[map_name] = sorted(map_name_list)

    return all_name_list_dict


def get_name_list_gb1900(ref_path):
    name_list = []

    with open(ref_path, 'r', encoding='utf-16') as f:
        data = f.readlines()


    for line in data[1:]: # skip the header
        try:
            line = line.split(',')
            text = line[1]
            lat = float(line[-3])
            lng = float(line[-2])
            semantic_type = line[-1]

            name_list.append(text)
        except:
            print(line)

    name_list = sorted(name_list)

    return name_list


if __name__ == '__main__':
    #name_list = get_name_list_usgs_od(['labGISReport-master/output/USGS-15-CA-brawley-e1957-s1957-p1961.json',
    #'labGISReport-master/output/USGS-15-CA-capesanmartin-e1921-s1917.json'])
    name_list = get_name_list_gb1900('data/GB1900_gazetteer_abridged_july_2018/gb1900_abridged.csv')
models/spabert/experiments/entity_matching/data_processing/request_wrapper.py
ADDED
@@ -0,0 +1,186 @@
import requests
import pdb
import time

# for linkedgeodata: #http://linkedgeodata.org/sparql

class RequestWrapper:
    def __init__(self, baseuri = "https://query.wikidata.org/sparql"):

        self.baseuri = baseuri

    def response_handler(self, response, query):
        if response.status_code == requests.codes.ok:
            ret_json = response.json()['results']['bindings']
        elif response.status_code == 500:
            ret_json = []
            #print(q_id)
            print('Internal Error happened. Set ret_json to be empty list')

        elif response.status_code == 429:

            print(response.status_code)
            print(response.text)
            retry_seconds = int(response.text.split('Too Many Requests - Please retry in ')[1].split(' seconds')[0])
            print('rerun in %d seconds' % retry_seconds)
            time.sleep(retry_seconds + 1)

            response = requests.get(self.baseuri, params = {'format': 'json', 'query': query})
            ret_json = response.json()['results']['bindings']
            #print(ret_json)
            print('resumed and succeeded')

        else:
            print(response.status_code, response.text)
            exit(-1)

        return ret_json

    '''Search for wikidata entities given the name string'''
    def wikidata_query(self, name_str):

        query = """
        PREFIX wd: <http://www.wikidata.org/entity/>
        PREFIX wds: <http://www.wikidata.org/entity/statement/>
        PREFIX wdv: <http://www.wikidata.org/value/>
        PREFIX wdt: <http://www.wikidata.org/prop/direct/>
        PREFIX wikibase: <http://wikiba.se/ontology#>
        PREFIX p: <http://www.wikidata.org/prop/>
        PREFIX ps: <http://www.wikidata.org/prop/statement/>
        PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX bd: <http://www.bigdata.com/rdf#>

        SELECT ?item ?coordinates ?itemDescription WHERE {
          ?item rdfs:label \"%s\"@en;
                wdt:P625 ?coordinates .
          SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
        }
        """ % (name_str)

        response = requests.get(self.baseuri, params = {'format': 'json', 'query': query})

        ret_json = self.response_handler(response, query)

        return ret_json


    '''Search for wikidata entities given the name string, restricted to entities within a state'''
    def wikidata_query_withinstate(self, name_str, state_id = 'Q99'):

        query = """
        PREFIX wd: <http://www.wikidata.org/entity/>
        PREFIX wds: <http://www.wikidata.org/entity/statement/>
        PREFIX wdv: <http://www.wikidata.org/value/>
        PREFIX wdt: <http://www.wikidata.org/prop/direct/>
        PREFIX wikibase: <http://wikiba.se/ontology#>
        PREFIX p: <http://www.wikidata.org/prop/>
        PREFIX ps: <http://www.wikidata.org/prop/statement/>
        PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX bd: <http://www.bigdata.com/rdf#>

        SELECT ?item ?coordinates ?itemDescription WHERE {
          ?item rdfs:label \"%s\"@en;
                wdt:P625 ?coordinates ;
                wdt:P131+ wd:%s;
          SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
        }
        """ % (name_str, state_id)

        #print(query)

        response = requests.get(self.baseuri, params = {'format': 'json', 'query': query})

        ret_json = self.response_handler(response, query)

        return ret_json


    '''Search for nearby wikidata entities given the entity id'''
    def wikidata_nearby_query(self, q_id):

        query = """
        PREFIX wd: <http://www.wikidata.org/entity/>
        PREFIX wds: <http://www.wikidata.org/entity/statement/>
        PREFIX wdv: <http://www.wikidata.org/value/>
        PREFIX wdt: <http://www.wikidata.org/prop/direct/>
        PREFIX wikibase: <http://wikiba.se/ontology#>
        PREFIX p: <http://www.wikidata.org/prop/>
        PREFIX ps: <http://www.wikidata.org/prop/statement/>
        PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX bd: <http://www.bigdata.com/rdf#>

        SELECT ?place ?placeLabel ?location ?instanceLabel ?placeDescription
        WHERE
        {
          wd:%s wdt:P625 ?loc .
          SERVICE wikibase:around {
              ?place wdt:P625 ?location .
              bd:serviceParam wikibase:center ?loc .
              bd:serviceParam wikibase:radius "5" .
          }
          OPTIONAL { ?place wdt:P31 ?instance }
          SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
          BIND(geof:distance(?loc, ?location) as ?dist)
        } ORDER BY ?dist
        LIMIT 200
        """ % (q_id)
        # initially 2km

        #pdb.set_trace()

        response = requests.get(self.baseuri, params = {'format': 'json', 'query': query})

        ret_json = self.response_handler(response, query)

        return ret_json


    def linkedgeodata_query(self, name_str):

        query = """
        Prefix lgdo: <http://linkedgeodata.org/ontology/>
        Prefix geom: <http://geovocab.org/geometry#>
        Prefix ogc: <http://www.opengis.net/ont/geosparql#>
        Prefix owl: <http://www.w3.org/2002/07/owl#>
        Prefix wgs84_pos: <http://www.w3.org/2003/01/geo/wgs84_pos#>
        Prefix gn: <http://www.geonames.org/ontology#>

        Select ?s, ?lat, ?long {
          {?s rdfs:label \"%s\";
              wgs84_pos:lat ?lat ;
              wgs84_pos:long ?long;
          }
        }
        """ % (name_str)

        response = requests.get(self.baseuri, params = {'format': 'json', 'query': query})

        ret_json = self.response_handler(response, query)

        return ret_json


if __name__ == '__main__':
    request_wrapper_wikidata = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
    #print(request_wrapper_wikidata.wikidata_nearby_query('Q370771'))
    #print(request_wrapper_wikidata.wikidata_query_withinstate('San Bernardino'))

    # not working now
    print(request_wrapper_wikidata.linkedgeodata_query('San Bernardino'))
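A usage sketch for RequestWrapper, querying Wikidata by place name and reading the returned bindings; the place name is illustrative. Each binding follows the SPARQL JSON results format, with the entity URI under item['value'], a WKT point under coordinates['value'], and an optional itemDescription.

wrapper = RequestWrapper(baseuri='https://query.wikidata.org/sparql')
bindings = wrapper.wikidata_query('Brawley')  # illustrative place name

for b in bindings:
    q_id = b['item']['value'].split('/')[-1]          # e.g. 'Q27001'
    wkt = b['coordinates']['value']                   # e.g. 'Point(-101.734166666 43.175)'
    description = b.get('itemDescription', {}).get('value', '')
    print(q_id, wkt, description)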
models/spabert/experiments/entity_matching/data_processing/run_linking_query.py
ADDED
@@ -0,0 +1,143 @@
#from query_wrapper import QueryWrapper
from request_wrapper import RequestWrapper
from get_namelist import *
import glob
import os
import json
import time

DATASET_OPTIONS = ['OSM', 'OS', 'USGS', 'GB1900']
KB_OPTIONS = ['wikidata', 'linkedgeodata']

DATASET = 'USGS'
KB = 'wikidata'
OVERWRITE = True
WITHIN_CA = True

assert DATASET in DATASET_OPTIONS
assert KB in KB_OPTIONS


def process_one_namelist(sparql_wrapper, namelist, out_path):

    if OVERWRITE:
        # flush the file if it's been written
        with open(out_path, 'w') as f:
            f.write('')


    for name in namelist:
        name = name.replace('"', '')
        name = name.strip("'")
        if len(name) == 0:
            continue
        print(name)
        mydict = dict()

        if KB == 'wikidata':
            if WITHIN_CA:
                mydict[name] = sparql_wrapper.wikidata_query_withinstate(name)
            else:
                mydict[name] = sparql_wrapper.wikidata_query(name)

        elif KB == 'linkedgeodata':
            mydict[name] = sparql_wrapper.linkedgeodata_query(name)
        else:
            raise NotImplementedError

        line = json.dumps(mydict)

        with open(out_path, 'a') as f:
            f.write(line)
            f.write('\n')
        time.sleep(1)


def process_namelist_dict(sparql_wrapper, namelist_dict, out_dir):
    i = 0
    for map_name, namelist in namelist_dict.items():
        # if i <= 5:
        #     i += 1
        #     continue

        print('processing %s' % map_name)

        if WITHIN_CA:
            out_path = os.path.join(out_dir, KB + '_' + map_name + '.json')
        else:
            out_path = os.path.join(out_dir, KB + '_ca_' + map_name + '.json')

        process_one_namelist(sparql_wrapper, namelist, out_path)
        i += 1


if KB == 'linkedgeodata':
    sparql_wrapper = RequestWrapper(baseuri = 'http://linkedgeodata.org/sparql')
elif KB == 'wikidata':
    sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
else:
    raise NotImplementedError



if DATASET == 'OSM':
    osm_dir = '../surface_form/data_sample_london/data_osm/'
    osm_paths = glob.glob(os.path.join(osm_dir, 'embedding*.json'))

    out_path = 'outputs/' + KB + '_linking.json'
    namelist = get_name_list_osm(osm_paths)

    print('# files', len(osm_paths))

    process_one_namelist(sparql_wrapper, namelist, out_path)


elif DATASET == 'OS':
    histmap_dir = 'data/labGISReport-master/output/'
    file_paths = glob.glob(os.path.join(histmap_dir, '10*.json'))

    out_path = 'outputs/' + KB + '_os_linking_descript.json'
    namelist = get_name_list_usgs_od(file_paths)

    print('# files', len(file_paths))


    process_one_namelist(sparql_wrapper, namelist, out_path)

elif DATASET == 'USGS':
    histmap_dir = 'data/labGISReport-master/output/'
    file_paths = glob.glob(os.path.join(histmap_dir, 'USGS*.json'))

    if WITHIN_CA:
        out_dir = 'outputs/' + KB + '_ca'
    else:
        out_dir = 'outputs/' + KB
    namelist_dict = get_name_list_usgs_od_per_map(file_paths)

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    print('# files', len(file_paths))

    process_namelist_dict(sparql_wrapper, namelist_dict, out_dir)

elif DATASET == 'GB1900':

    file_path = 'data/GB1900_gazetteer_abridged_july_2018/gb1900_abridged.csv'
    out_path = 'outputs/' + KB + '_gb1900_linking_descript.json'
    namelist = get_name_list_gb1900(file_path)


    process_one_namelist(sparql_wrapper, namelist, out_path)

else:
    raise NotImplementedError



#namelist = namelist[730:] #for GB1900


print('done')
models/spabert/experiments/entity_matching/data_processing/run_map_neighbor_query.py
ADDED
@@ -0,0 +1,123 @@
from query_wrapper import QueryWrapper
from request_wrapper import RequestWrapper
import glob
import os
import json
import time
import random

DATASET_OPTIONS = ['OSM', 'OS', 'USGS', 'GB1900']
KB_OPTIONS = ['wikidata', 'linkedgeodata']

dataset = 'USGS'
kb = 'wikidata'
overwrite = False

assert dataset in DATASET_OPTIONS
assert kb in KB_OPTIONS

if dataset == 'OSM':
    raise NotImplementedError

elif dataset == 'OS':
    raise NotImplementedError

elif dataset == 'USGS':

    candidate_file_paths = glob.glob('outputs/alignment_dir/wikidata_USGS*.json')
    candidate_file_paths = sorted(candidate_file_paths)

    out_dir = 'outputs/wikidata_neighbors/'

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

elif dataset == 'GB1900':

    raise NotImplementedError

else:
    raise NotImplementedError


if kb == 'linkedgeodata':
    sparql_wrapper = QueryWrapper(baseuri = 'http://linkedgeodata.org/sparql')
elif kb == 'wikidata':
    #sparql_wrapper = QueryWrapper(baseuri = 'https://query.wikidata.org/sparql')
    sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
else:
    raise NotImplementedError

start_map = 6 # 6
start_line = 4 # 4

for candidate_file_path in candidate_file_paths[start_map:]:
    map_name = os.path.basename(candidate_file_path).split('wikidata_')[1]
    out_path = os.path.join(out_dir, 'wikidata_' + map_name)

    with open(candidate_file_path, 'r') as f:
        cand_data = f.readlines()

    with open(out_path, 'a') as out_f:
        for line in cand_data[start_line:]:
            line_dict = json.loads(line)
            ret_line_dict = dict()
            for key, value in line_dict.items(): # actually just one pair

                print(key)

                place_name = key
                for cand_entity in value:
                    time.sleep(2)
                    q_id = cand_entity['item']['value'].split('/')[-1]
                    response = sparql_wrapper.wikidata_nearby_query(str(q_id))

                    if place_name in ret_line_dict:
                        ret_line_dict[place_name].append(response)
                    else:
                        ret_line_dict[place_name] = [response]

                    #time.sleep(random.random()*6)



            out_f.write(json.dumps(ret_line_dict))
            out_f.write('\n')

    print('finished with ', candidate_file_path)
    break

print('done')

'''
{"Martin": [{"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27001"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(18.921388888 49.063611111)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in northern Slovakia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q281028"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-101.734166666 43.175)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in and county seat of Bennett County, South Dakota, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q761390"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(1.3394 51.179)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "hamlet in Kent"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2177502"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-100.115 47.826666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in North Dakota"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2454021"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-85.64168 42.53698)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Michigan, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2481111"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-93.21833 32.09917)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Red River Parish, Louisiana, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2635473"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-82.75944 37.56778)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in Kentucky, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2679547"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-1.9041 50.9759)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Hampshire, England, United Kingdom"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2780056"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-88.851666666 36.341944444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in Tennessee, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q3261150"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-83.18556 34.48639)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "town in Franklin and Stephens Counties, Georgia, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6002227"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-101.709 41.2581)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "census-designated place in Nebraska, 
United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774807"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-82.1906 29.2936)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Marion County, Florida, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774809"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-83.336666666 41.5575)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Ohio, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774810"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(116.037 -32.071)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "suburb of Perth, Western Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q9029707"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-81.476388888 33.069166666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in South Carolina, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q11770660"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-0.325391 53.1245)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village and civil parish in Lincolnshire, UK"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q14692833"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-93.5444 47.4894)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Itasca County, Minnesota"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q14714180"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-79.0889 39.2242)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Grant County, West Virginia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q18496647"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(11.95941 46.34585)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Italy"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q20949553"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(22.851944444 41.954444444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "mountain in Republic of Macedonia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q24065096"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "<http://www.wikidata.org/entity/Q111> Point(290.75 -21.34)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "crater on Mars"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q26300074"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": 
"Point(-0.97879914 51.74729077)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Thame, South Oxfordshire, Oxfordshire, OX9"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27988822"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-87.67361111 38.12444444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Vanderburgh County, Indiana, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27995389"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-121.317 47.28)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Washington, United States of America"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q28345614"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-76.8 48.5)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Senneterre, Quebec, Canada"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q30626037"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-79.90972222 39.80638889)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in United States of America"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q61038281"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-91.13 49.25)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Meteorological Service of Canada's station for Martin (MSC ID: 6035000), Ontario, Canada"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q63526691"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(152.7011111 -29.881666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Parish of Fitzroy County, New South Wales, Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q63526695"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(148.1011111 -33.231666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Parish of Ashburnham County, New South Wales, Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q96149222"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(14.725573779 48.76768332)"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q96158116"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(15.638358708 49.930136157)"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q103777024"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-3.125028822 58.694748091)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Shipwreck off the Scottish Coast, imported from Canmore Nov 2020"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q107077206"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": 
"literal", "value": "Point(11.688165 46.582982)"}}]}


if overwrite:
    # flush the file if it's been written
    with open(out_path, 'w') as f:
        f.write('')

for name in namelist:
    name = name.replace('"', '')
    name = name.strip("'")
    if len(name) == 0:
        continue
    print(name)
    mydict = dict()
    mydict[name] = sparql_wrapper.wikidata_query(name)
    line = json.dumps(mydict)
    #print(line)
    with open(out_path, 'a') as f:
        f.write(line)
        f.write('\n')
    time.sleep(1)

print('done')


{"info": {"name": "10TH ST", "geometry": [4193.118085062303, -831.274950414831]},
"neighbor_info":
{"name_list": ["BM 107", "PALM AVE", "WT", "Hidalgo Sch", "PO", "MAIN ST", "BRYANT CANAL", "BM 123", "Oakley Sch", "BRAWLEY", "Witter Sch", "BM 104", "Pistol Range", "Reid Sch", "STANLEY", "MUNICIPAL AIRPORT", "WESTERN AVE", "CANAL", "Riverview Cem", "BEST CANAL"],
"geometry_list": [[4180.493095652702, -836.0635465095995], [4240.450935702045, -855.345637906981], [4136.084840542623, -917.7895986922882], [4150.386997979736, -948.7258091165079], [4056.955267048625, -847.1018277439381], [4008.112642182582, -849.089249977583], [4124.177447575567, -1004.0706369942257], [4145.382175508665, -626.1608201557082], [4398.137868976953, -764.1087236140554], [4221.1546492913285, -1062.5745271963772], [4015.203890157584, -985.0178210457995], [3989.2345421184878, -948.9340389243871], [4385.585449075614, -660.4590917125413], [3936.505159635338, -803.6822663422273], [3960.1233867112846, -686.7988766730389], [4409.714306709143, -600.6633389979504], [3871.2873706574037, -832.0785684368772], [4304.899727301024, -524.472390102557], [3955.640201659347, -578.5544271698675], [4075.8524354668034, -1183.5837385075774]]}}
'''
models/spabert/experiments/entity_matching/data_processing/run_query_sample.py
ADDED
@@ -0,0 +1,22 @@
from query_wrapper import QueryWrapper
from get_namelist import *
import glob
import os
import json
import time


#sparql_wrapper_linkedgeo = QueryWrapper(baseuri = 'http://linkedgeodata.org/sparql')

#print(sparql_wrapper_linkedgeo.linkedgeodata_query('Los Angeles'))


sparql_wrapper_wikidata = QueryWrapper(baseuri = 'https://query.wikidata.org/sparql')

#print(sparql_wrapper_wikidata.wikidata_query('Los Angeles'))

#time.sleep(3)

#print(sparql_wrapper_wikidata.wikidata_nearby_query('Q370771'))
print(sparql_wrapper_wikidata.wikidata_nearby_query('Q97625145'))
models/spabert/experiments/entity_matching/data_processing/run_wikidata_neighbor_query.py
ADDED
@@ -0,0 +1,31 @@
import pandas as pd
import json
from request_wrapper import RequestWrapper
import time
import pdb

start_idx = 17335
wikidata_sample30k_path = 'wikidata_sample30k/wikidata_30k.json'
out_path = 'wikidata_sample30k/wikidata_30k_neighbor.json'

#with open(out_path, 'w') as out_f:
#    pass

sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')

df = pd.read_json(wikidata_sample30k_path)
df = df[start_idx:]

print('length of df:', len(df))

for index, record in df.iterrows():
    print(index)
    uri = record.results['item']['value']
    q_id = uri.split('/')[-1]
    response = sparql_wrapper.wikidata_nearby_query(str(q_id))
    time.sleep(1)
    with open(out_path, 'a') as out_f:
        out_f.write(json.dumps(response))
        out_f.write('\n')
models/spabert/experiments/entity_matching/data_processing/samples.sparql
ADDED
@@ -0,0 +1,22 @@
'''
Entities near xxx within 1km
SELECT ?place ?placeLabel ?location ?instanceLabel ?placeDescription
WHERE
{
  wd:Q9188 wdt:P625 ?loc .
  SERVICE wikibase:around {
      ?place wdt:P625 ?location .
      bd:serviceParam wikibase:center ?loc .
      bd:serviceParam wikibase:radius "1" .
  }
  OPTIONAL { ?place wdt:P31 ?instance }
  SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
  BIND(geof:distance(?loc, ?location) as ?dist)
} ORDER BY ?dist
'''


SELECT distinct ?item ?itemLabel WHERE {
  ?item wdt:P625 ?geo .
  SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". }
}
models/spabert/experiments/entity_matching/data_processing/select_ambi.py
ADDED
@@ -0,0 +1,18 @@
import json

json_file = 'entity_linking/outputs/wikidata_usgs_linking_descript.json'

with open(json_file, 'r') as f:
    data = f.readlines()

num_ambi = 0
for line in data:
    line_dict = json.loads(line)
    for key, value in line_dict.items():
        len_value = len(value)
        if len_value < 2:
            continue
        else:
            num_ambi += 1
            print(key)
print(num_ambi)
models/spabert/experiments/entity_matching/data_processing/wikidata_sample30k/wikidata_30k.json
ADDED
The diff for this file is too large to render.
models/spabert/experiments/entity_matching/src/evaluation-mrr.py
ADDED
@@ -0,0 +1,260 @@
#!/usr/bin/env python
# coding: utf-8


import sys
import os
import glob
import json
import numpy as np
import pandas as pd
import pdb


prediction_dir = sys.argv[1]

print(prediction_dir)

gt_dir = '../data_processing/outputs/alignment_gt_dir/'
prediction_path_list = sorted(os.listdir(prediction_dir))

DISPLAY = False
DETAIL = False

if DISPLAY:
    from IPython.display import display

def recall_at_k_all_map(all_rank_list, k = 1):

    rank_list = [item for sublist in all_rank_list for item in sublist]
    total_query = len(rank_list)
    prec = np.sum(np.array(rank_list) <= k)
    prec = 1.0 * prec / total_query

    return prec

def recall_at_k_permap(all_rank_list, k = 1):

    prec_list = []
    for rank_list in all_rank_list:
        total_query = len(rank_list)
        prec = np.sum(np.array(rank_list) <= k)
        prec = 1.0 * prec / total_query
        prec_list.append(prec)

    return prec_list


def reciprocal_rank(all_rank_list):

    recip_list = [1. / rank for rank in all_rank_list]
    mean_recip = np.mean(recip_list)

    return mean_recip, recip_list


count_hist_list = []

all_rank_list = []

all_recip_list = []

permap_recip_list = []

for map_path in prediction_path_list:

    pred_path = os.path.join(prediction_dir, map_path)
    gt_path = os.path.join(gt_dir, map_path.split('.json')[0] + '.csv')

    if DETAIL:
        print(pred_path)


    with open(gt_path, 'r') as f:
        gt_data = f.readlines()

    gt_dict = dict()
    for line in gt_data:
        line = line.split(',')
        pivot_name = line[0]
        gt_uri = line[1]
        gt_dict[pivot_name] = gt_uri

    rank_list = []
    pivot_name_list = []
    with open(pred_path, 'r') as f:
        pred_data = f.readlines()
        for line in pred_data:
            pred_dict = json.loads(line)
            #print(pred_dict.keys())
            pivot_name = pred_dict['pivot_name']
            sorted_match_uri = pred_dict['sorted_match_uri']
            #sorted_match_des = pred_dict['sorted_match_des']
            sorted_sim_matrix = pred_dict['sorted_sim_matrix']


            total = len(sorted_match_uri)
            if total == 1:
                continue

            if pivot_name in gt_dict:

                gt_uri = gt_dict[pivot_name]

                try:
                    assert gt_uri in sorted_match_uri
                except Exception as e:
                    #print(e)
                    continue

                pivot_name_list.append(pivot_name)
                count_hist_list.append(total)
                rank = sorted_match_uri.index(gt_uri) + 1

                rank_list.append(rank)
                #print(rank,'/',total)

    all_rank_list.append(rank_list)

    mean_recip, recip_list = reciprocal_rank(rank_list)

    all_recip_list.extend(recip_list)
    permap_recip_list.append(recip_list)

    d = {'pivot': pivot_name_list + ['AVG'], 'rank': rank_list + [' '], 'recip rank': recip_list + [str(mean_recip)]}
    if DETAIL:
        print(pivot_name_list, rank_list, recip_list)

    if DISPLAY:
        df = pd.DataFrame(data=d)

        display(df)



print('all mrr, micro', np.mean(all_recip_list))


if DETAIL:

    len(rank_list)



print(recall_at_k_all_map(all_rank_list, k = 1))
print(recall_at_k_all_map(all_rank_list, k = 2))
print(recall_at_k_all_map(all_rank_list, k = 5))
print(recall_at_k_all_map(all_rank_list, k = 10))


print(prediction_path_list)


prec_list_1 = recall_at_k_permap(all_rank_list, k = 1)
prec_list_2 = recall_at_k_permap(all_rank_list, k = 2)
prec_list_5 = recall_at_k_permap(all_rank_list, k = 5)
prec_list_10 = recall_at_k_permap(all_rank_list, k = 10)

if DETAIL:

    print(np.mean(prec_list_1))
    print(prec_list_1)
    print('\n')

    print(np.mean(prec_list_2))
    print(prec_list_2)
    print('\n')

    print(np.mean(prec_list_5))
    print(prec_list_5)
    print('\n')

    print(np.mean(prec_list_10))
    print(prec_list_10)
    print('\n')




import pandas as pd


map_name_list = [name.split('.json')[0].split('USGS-')[1] for name in prediction_path_list]
d = {'map_name': map_name_list, 'recall@1': prec_list_1, 'recall@2': prec_list_2, 'recall@5': prec_list_5, 'recall@10': prec_list_10}
df = pd.DataFrame(data=d)


if DETAIL:
    print(df)



category = ['15-CA', '30-CA', '60-CA']
col_1 = [np.mean(prec_list_1[0:4]), np.mean(prec_list_1[4:9]), np.mean(prec_list_1[9:])]
col_2 = [np.mean(prec_list_2[0:4]), np.mean(prec_list_2[4:9]), np.mean(prec_list_2[9:])]
col_3 = [np.mean(prec_list_5[0:4]), np.mean(prec_list_5[4:9]), np.mean(prec_list_5[9:])]
col_4 = [np.mean(prec_list_10[0:4]), np.mean(prec_list_10[4:9]), np.mean(prec_list_10[9:])]


mrr_15 = permap_recip_list[0] + permap_recip_list[1] + permap_recip_list[2] + permap_recip_list[3]
mrr_30 = permap_recip_list[4] + permap_recip_list[5] + permap_recip_list[6] + permap_recip_list[7] + permap_recip_list[8]
mrr_60 = permap_recip_list[9] + permap_recip_list[10] + permap_recip_list[11] + permap_recip_list[12] + permap_recip_list[13]


column_5 = [np.mean(mrr_15), np.mean(mrr_30), np.mean(mrr_60)]


d = {'map set': category, 'mrr': column_5, 'prec@1': col_1, 'prec@2': col_2, 'prec@5': col_3, 'prec@10': col_4}
df = pd.DataFrame(data=d)

print(df)



print('all mrr, micro', np.mean(all_recip_list))

print('\n')


print(recall_at_k_all_map(all_rank_list, k = 1))
print(recall_at_k_all_map(all_rank_list, k = 2))
print(recall_at_k_all_map(all_rank_list, k = 5))
print(recall_at_k_all_map(all_rank_list, k = 10))



if DISPLAY:

    import seaborn

    p = seaborn.histplot(data = count_hist_list, color = 'blue', alpha=0.2)
    p.set_xlabel("Number of Candidates")
    p.set_title("Candidate Distribution in USGS")


    len(count_hist_list)
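A worked toy example of the two metrics computed above, using a hypothetical rank list: for ranks [1, 3, 2], MRR = (1/1 + 1/3 + 1/2) / 3 ≈ 0.611 and recall@2 = 2/3, matching what reciprocal_rank() and recall_at_k_all_map() would return for a single map.

import numpy as np

rank_list = [1, 3, 2]
mrr = np.mean([1.0 / r for r in rank_list])                       # 0.6111...
recall_at_2 = np.sum(np.array(rank_list) <= 2) / len(rank_list)   # 0.6666...
print(mrr, recall_at_2)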
models/spabert/experiments/entity_matching/src/linking_ablation.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python
# coding: utf-8

import sys
import os
import numpy as np
import pdb
import json
import scipy.spatial as sp
import argparse

import torch
from torch.utils.data import DataLoader

from transformers import AdamW
from transformers import BertTokenizer
from tqdm import tqdm  # for our progress bar

sys.path.append('../../../')
from datasets.usgs_os_sample_loader import USGS_MapDataset
from datasets.wikidata_sample_loader import Wikidata_Geocoord_Dataset, Wikidata_Random_Dataset
from models.spatial_bert_model import SpatialBertModel
from models.spatial_bert_model import SpatialBertConfig
from utils.find_closest import find_ref_closest_match, sort_ref_closest_match
from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv
from utils.baseline_utils import get_baseline_model

from transformers import BertModel

sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
from dataset_loader import SpatialDataset
from osm_sample_loader import PbfMapDataset


MODEL_OPTIONS = ['spatial_bert-base','spatial_bert-large', 'bert-base','bert-large','roberta-base','roberta-large',
                 'spanbert-base','spanbert-large','luke-base','luke-large',
                 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']

CANDSET_MODES = ['all_map'] # candidate set is constructed based on all maps or one map


def recall_at_k(rank_list, k = 1):

    total_query = len(rank_list)
    recall = np.sum(np.array(rank_list) <= k)
    recall = 1.0 * recall / total_query

    return recall


def reciprocal_rank(all_rank_list):

    recip_list = [1./rank for rank in all_rank_list]
    mean_recip = np.mean(recip_list)

    return mean_recip, recip_list


def link_to_itself(source_embedding_ogc_list, target_embedding_ogc_list):

    source_emb_list = [source_dict['emb'] for source_dict in source_embedding_ogc_list]
    source_ogc_list = [source_dict['ogc_fid'] for source_dict in source_embedding_ogc_list]

    target_emb_list = [target_dict['emb'] for target_dict in target_embedding_ogc_list]
    target_ogc_list = [target_dict['ogc_fid'] for target_dict in target_embedding_ogc_list]

    rank_list = []
    for source_emb, source_ogc in zip(source_emb_list, source_ogc_list):
        sim_matrix = 1 - sp.distance.cdist(np.array(target_emb_list), np.array([source_emb]), 'cosine')
        closest_match_ogc = sort_ref_closest_match(sim_matrix, target_ogc_list)

        closest_match_ogc = [a[0] for a in closest_match_ogc]
        rank = closest_match_ogc.index(source_ogc) + 1
        rank_list.append(rank)

    mean_recip, recip_list = reciprocal_rank(rank_list)
    r1 = recall_at_k(rank_list, k = 1)
    r5 = recall_at_k(rank_list, k = 5)
    r10 = recall_at_k(rank_list, k = 10)

    return mean_recip, r1, r5, r10


def get_embedding_and_ogc(dataset, model_name, model):
    dict_list = []

    for source in dataset:
        if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
            source_emb = get_spatialbert_embedding(source, model)
        else:
            source_emb = get_bert_embedding(source, model)

        source_dict = {}
        source_dict['emb'] = source_emb
        source_dict['ogc_fid'] = source['ogc_fid']
        #wikidata_dict['wikidata_des_list'] = [wikidata_cand['description']]

        dict_list.append(source_dict)

    return dict_list


def entity_linking_func(args):

    model_name = args.model_name
    candset_mode = args.candset_mode

    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    sep_between_neighbors = args.sep_between_neighbors

    spatial_bert_weight_dir = args.spatial_bert_weight_dir
    spatial_bert_weight_name = args.spatial_bert_weight_name

    if_no_spatial_distance = args.no_spatial_distance
    random_remove_neighbor = args.random_remove_neighbor

    assert model_name in MODEL_OPTIONS
    assert candset_mode in CANDSET_MODES

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    if model_name == 'spatial_bert-base':
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        config = SpatialBertConfig()
        model = SpatialBertModel(config)

        model.to(device)
        model.eval()

        # load pretrained weights
        weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
        model = load_spatial_bert_pretrained_weights(model, weight_path)

    elif model_name == 'spatial_bert-large':
        tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

        config = SpatialBertConfig(hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24)
        model = SpatialBertModel(config)

        model.to(device)
        model.eval()

        # load pretrained weights
        weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
        model = load_spatial_bert_pretrained_weights(model, weight_path)

    else:
        model, tokenizer = get_baseline_model(model_name)
        model.to(device)
        model.eval()

    source_file_path = '../data/osm-point-minnesota-full.json'
    source_dataset = PbfMapDataset(data_file_path = source_file_path,
                                   tokenizer = tokenizer,
                                   max_token_len = 512,
                                   distance_norm_factor = distance_norm_factor,
                                   spatial_dist_fill = spatial_dist_fill,
                                   with_type = False,
                                   sep_between_neighbors = sep_between_neighbors,
                                   mode = None,
                                   random_remove_neighbor = random_remove_neighbor,
                                   )

    target_dataset = PbfMapDataset(data_file_path = source_file_path,
                                   tokenizer = tokenizer,
                                   max_token_len = 512,
                                   distance_norm_factor = distance_norm_factor,
                                   spatial_dist_fill = spatial_dist_fill,
                                   with_type = False,
                                   sep_between_neighbors = sep_between_neighbors,
                                   mode = None,
                                   random_remove_neighbor = 0., # keep all
                                   )

    # process candidates for each phrase

    source_embedding_ogc_list = get_embedding_and_ogc(source_dataset, model_name, model)
    target_embedding_ogc_list = get_embedding_and_ogc(target_dataset, model_name, model)

    mean_recip, r1, r5, r10 = link_to_itself(source_embedding_ogc_list, target_embedding_ogc_list)
    print('\n')
    print(random_remove_neighbor, mean_recip, r1, r5, r10)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='spatial_bert-base')
    parser.add_argument('--candset_mode', type=str, default='all_map')

    parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default = 20)

    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true')

    parser.add_argument('--spatial_bert_weight_dir', type = str, default = None)
    parser.add_argument('--spatial_bert_weight_name', type = str, default = None)

    parser.add_argument('--random_remove_neighbor', type = float, default = 0.)

    args = parser.parse_args()
    # print('\n')
    # print(args)
    # print('\n')

    entity_linking_func(args)

# CUDA_VISIBLE_DEVICES='1' python3 linking_ablation.py --sep_between_neighbors --model_name='spatial_bert-base' --spatial_bert_weight_dir='/data/zekun/spatial_bert_weights/typing_lr5e-05_sep_bert-base_nofreeze_london_california_bsize12/ep0_iter06000_0.2936/' --spatial_bert_weight_name='keeppos_ep0_iter02000_0.4879.pth' --random_remove_neighbor=0.1

# CUDA_VISIBLE_DEVICES='1' python3 linking_ablation.py --sep_between_neighbors --model_name='spatial_bert-large' --spatial_bert_weight_dir='/data/zekun/spatial_bert_weights/typing_lr1e-06_sep_bert-large_nofreeze_london_california_bsize12/ep2_iter02000_0.3921/' --spatial_bert_weight_name='keeppos_ep8_iter03568_0.2661_val0.2284.pth' --random_remove_neighbor=0.1


if __name__ == '__main__':

    main()
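As a side note, the ranking inside link_to_itself reduces to a cosine-similarity nearest-neighbour search. The sketch below is not part of the file: random vectors stand in for SpaBERT embeddings, and numpy's argsort replaces the repository's sort_ref_closest_match helper, but the computation mirrors the loop above.

import numpy as np
import scipy.spatial as sp

target_embs = np.random.rand(100, 768)            # candidate embeddings (toy stand-ins)
target_ids = np.arange(100)                       # their ogc_fid identifiers
query_emb = target_embs[42]                       # query one entity against all candidates

sim = 1 - sp.distance.cdist(target_embs, query_emb[None, :], 'cosine')   # cosine similarity column
order = np.argsort(-sim[:, 0])                    # most similar candidate first
rank = int(np.where(target_ids[order] == 42)[0][0]) + 1                  # 1-based rank of the true match
print(rank)                                       # 1: the query is identical to its own entry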
models/spabert/experiments/entity_matching/src/unsupervised_wiki_location_allcand.py
ADDED
@@ -0,0 +1,329 @@
#!/usr/bin/env python
# coding: utf-8

import sys
import os
import numpy as np
import pdb
import json
import scipy.spatial as sp
import argparse

import torch
from torch.utils.data import DataLoader
#from transformers.models.bert.modeling_bert import BertForMaskedLM

from transformers import AdamW
from transformers import BertTokenizer
from tqdm import tqdm  # for our progress bar

sys.path.append('../../../')
from datasets.usgs_os_sample_loader import USGS_MapDataset
from datasets.wikidata_sample_loader import Wikidata_Geocoord_Dataset, Wikidata_Random_Dataset
from models.spatial_bert_model import SpatialBertModel
from models.spatial_bert_model import SpatialBertConfig
#from models.spatial_bert_model import SpatialBertForMaskedLM
from utils.find_closest import find_ref_closest_match, sort_ref_closest_match
from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv
from utils.baseline_utils import get_baseline_model

from transformers import BertModel


MODEL_OPTIONS = ['spatial_bert-base','spatial_bert-large', 'bert-base','bert-large','roberta-base','roberta-large',
                 'spanbert-base','spanbert-large','luke-base','luke-large',
                 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']

MAP_TYPES = ['usgs']
CANDSET_MODES = ['all_map'] # candidate set is constructed based on all maps or one map


def disambiguify(model, model_name, usgs_dataset, wikidata_dict_list, candset_mode = 'all_map', if_use_distance = True, select_indices = None):

    if select_indices is None:
        select_indices = range(0, len(usgs_dataset))

    assert(candset_mode in ['all_map','per_map'])

    wikidata_emb_list = [wikidata_dict['wikidata_emb_list'] for wikidata_dict in wikidata_dict_list]
    wikidata_uri_list = [wikidata_dict['wikidata_uri_list'] for wikidata_dict in wikidata_dict_list]
    #wikidata_des_list = [wikidata_dict['wikidata_des_list'] for wikidata_dict in wikidata_dict_list]

    if candset_mode == 'all_map':
        wikidata_emb_list = [item for sublist in wikidata_emb_list for item in sublist] # flatten
        wikidata_uri_list = [item for sublist in wikidata_uri_list for item in sublist] # flatten
        #wikidata_des_list = [item for sublist in wikidata_des_list for item in sublist] # flatten

    ret_list = []
    for i in select_indices:

        if candset_mode == 'per_map':
            usgs_entity = usgs_dataset[i]
            wikidata_emb_list = wikidata_emb_list[i]
            wikidata_uri_list = wikidata_uri_list[i]
            #wikidata_des_list = wikidata_des_list[i]

        elif candset_mode == 'all_map':
            usgs_entity = usgs_dataset[i]
        else:
            raise NotImplementedError

        if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
            usgs_emb = get_spatialbert_embedding(usgs_entity, model, use_distance = if_use_distance)
        else:
            usgs_emb = get_bert_embedding(usgs_entity, model)

        sim_matrix = 1 - sp.distance.cdist(np.array(wikidata_emb_list), np.array([usgs_emb]), 'cosine')

        closest_match_uri = sort_ref_closest_match(sim_matrix, wikidata_uri_list)
        #closest_match_des = sort_ref_closest_match(sim_matrix, wikidata_des_list)

        sorted_sim_matrix = np.sort(sim_matrix, axis = 0)[::-1] # descending order

        ret_dict = dict()
        ret_dict['pivot_name'] = usgs_entity['pivot_name']
        ret_dict['sorted_match_uri'] = [a[0] for a in closest_match_uri]
        #ret_dict['sorted_match_des'] = [a[0] for a in closest_match_des]
        ret_dict['sorted_sim_matrix'] = [a[0] for a in sorted_sim_matrix]

        ret_list.append(ret_dict)

    return ret_list


def entity_linking_func(args):

    model_name = args.model_name
    map_type = args.map_type
    candset_mode = args.candset_mode

    usgs_distance_norm_factor = args.usgs_distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    sep_between_neighbors = args.sep_between_neighbors

    spatial_bert_weight_dir = args.spatial_bert_weight_dir
    spatial_bert_weight_name = args.spatial_bert_weight_name

    if_no_spatial_distance = args.no_spatial_distance

    assert model_name in MODEL_OPTIONS
    assert map_type in MAP_TYPES
    assert candset_mode in CANDSET_MODES

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    if args.out_dir is None:

        if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':

            if sep_between_neighbors:
                spatialbert_output_dir_str = 'dnorm' + str(usgs_distance_norm_factor) + '_distfill' + str(spatial_dist_fill) + '_sep'
            else:
                spatialbert_output_dir_str = 'dnorm' + str(usgs_distance_norm_factor) + '_distfill' + str(spatial_dist_fill) + '_nosep'

            checkpoint_ep = spatial_bert_weight_name.split('_')[3]
            checkpoint_iter = spatial_bert_weight_name.split('_')[4]
            loss_val = spatial_bert_weight_name.split('_')[5][:-4]

            if if_no_spatial_distance:
                linking_prediction_dir = 'linking_prediction_dir/abalation_no_distance/'
            else:
                linking_prediction_dir = 'linking_prediction_dir'

            if model_name == 'spatial_bert-base':
                out_dir = os.path.join('/data2/zekun/', linking_prediction_dir, spatialbert_output_dir_str) + '/' + map_type + '-' + model_name + '-' + checkpoint_ep + '-' + checkpoint_iter + '-' + loss_val
            elif model_name == 'spatial_bert-large':

                freeze_str = spatial_bert_weight_dir.split('/')[-2].split('_')[1] # either 'freeze' or 'nofreeze'
                out_dir = os.path.join('/data2/zekun/', linking_prediction_dir, spatialbert_output_dir_str) + '/' + map_type + '-' + model_name + '-' + checkpoint_ep + '-' + checkpoint_iter + '-' + loss_val + '-' + freeze_str

        else:
            out_dir = '/data2/zekun/baseline_linking_prediction_dir/' + map_type + '-' + model_name

    else:
        out_dir = args.out_dir

    print('out_dir', out_dir)

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    if model_name == 'spatial_bert-base':
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        config = SpatialBertConfig()
        model = SpatialBertModel(config)

        model.to(device)
        model.eval()

        # load pretrained weights
        weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
        model = load_spatial_bert_pretrained_weights(model, weight_path)

    elif model_name == 'spatial_bert-large':
        tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

        config = SpatialBertConfig(hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24)
        model = SpatialBertModel(config)

        model.to(device)
        model.eval()

        # load pretrained weights
        weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
        model = load_spatial_bert_pretrained_weights(model, weight_path)

    else:
        model, tokenizer = get_baseline_model(model_name)
        model.to(device)
        model.eval()

    if map_type == 'usgs':
        map_name_list = ['USGS-15-CA-brawley-e1957-s1957-p1961',
                         'USGS-15-CA-paloalto-e1899-s1895-rp1911',
                         'USGS-15-CA-capesanmartin-e1921-s1917',
                         'USGS-15-CA-sanfrancisco-e1899-s1892-rp1911',
                         'USGS-30-CA-dardanelles-e1898-s1891-rp1912',
                         'USGS-30-CA-holtville-e1907-s1905-rp1946',
                         'USGS-30-CA-indiospecial-e1904-s1901-rp1910',
                         'USGS-30-CA-lompoc-e1943-s1903-ap1941-rv1941',
                         'USGS-30-CA-sanpedro-e1943-rv1944',
                         'USGS-60-CA-alturas-e1892-rp1904',
                         'USGS-60-CA-amboy-e1942',
                         'USGS-60-CA-amboy-e1943-rv1943',
                         'USGS-60-CA-modoclavabed-e1886-s1884',
                         'USGS-60-CA-saltonsea-e1943-ap1940-rv1942']

    print('processing wikidata...')

    wikidata_dict_list = []

    wikidata_random30k = Wikidata_Random_Dataset(
        data_file_path = '../data_processing/wikidata_sample30k/wikidata_30k_neighbor_reformat.json',
        #neighbor_file_path = '../data_processing/wikidata_sample30k/wikidata_30k_neighbor.json',
        tokenizer = tokenizer,
        max_token_len = 512,
        distance_norm_factor = 0.0001,
        spatial_dist_fill = 100,
        sep_between_neighbors = sep_between_neighbors,
    )

    # process candidates for each phrase
    for wikidata_cand in wikidata_random30k:
        if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
            wikidata_emb = get_spatialbert_embedding(wikidata_cand, model)
        else:
            wikidata_emb = get_bert_embedding(wikidata_cand, model)

        wikidata_dict = {}
        wikidata_dict['wikidata_emb_list'] = [wikidata_emb]
        wikidata_dict['wikidata_uri_list'] = [wikidata_cand['uri']]

        wikidata_dict_list.append(wikidata_dict)

    for map_name in map_name_list:

        print(map_name)

        wikidata_dict_per_map = {}
        wikidata_dict_per_map['wikidata_emb_list'] = []
        wikidata_dict_per_map['wikidata_uri_list'] = []

        wikidata_dataset_permap = Wikidata_Geocoord_Dataset(
            data_file_path = '../data_processing/outputs/wikidata_reformat/wikidata_' + map_name + '.json',
            tokenizer = tokenizer,
            max_token_len = 512,
            distance_norm_factor = 0.0001,
            spatial_dist_fill = 100,
            sep_between_neighbors = sep_between_neighbors)

        for i in range(0, len(wikidata_dataset_permap)):
            # get all candidates for phrases within the map
            wikidata_candidates = wikidata_dataset_permap[i] # dataset for each map, list of [cand for each phrase]

            # process candidates for each phrase
            for wikidata_cand in wikidata_candidates:
                if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
                    wikidata_emb = get_spatialbert_embedding(wikidata_cand, model)
                else:
                    wikidata_emb = get_bert_embedding(wikidata_cand, model)

                wikidata_dict_per_map['wikidata_emb_list'].append(wikidata_emb)
                wikidata_dict_per_map['wikidata_uri_list'].append(wikidata_cand['uri'])

        wikidata_dict_list.append(wikidata_dict_per_map)

    for map_name in map_name_list:

        print(map_name)

        usgs_dataset = USGS_MapDataset(
            data_file_path = '../data_processing/outputs/alignment_dir/map_' + map_name + '.json',
            tokenizer = tokenizer,
            distance_norm_factor = usgs_distance_norm_factor,
            spatial_dist_fill = spatial_dist_fill,
            sep_between_neighbors = sep_between_neighbors)

        ret_list = disambiguify(model, model_name, usgs_dataset, wikidata_dict_list, candset_mode = candset_mode, if_use_distance = not if_no_spatial_distance, select_indices = None)

        write_to_csv(out_dir, map_name, ret_list)

        print('Done')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='spatial_bert-base')
    parser.add_argument('--out_dir', type=str, default=None)
    parser.add_argument('--map_type', type=str, default='usgs')
    parser.add_argument('--candset_mode', type=str, default='all_map')

    parser.add_argument('--usgs_distance_norm_factor', type=float, default = 1)
    parser.add_argument('--spatial_dist_fill', type=float, default = 100)

    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true')

    parser.add_argument('--spatial_bert_weight_dir', type = str, default = None)
    parser.add_argument('--spatial_bert_weight_name', type = str, default = None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    # out_dir not None, and out_dir does not exist, then create out_dir
    if args.out_dir is not None and not os.path.isdir(args.out_dir):
        os.makedirs(args.out_dir)

    entity_linking_func(args)


if __name__ == '__main__':

    main()
models/spabert/experiments/semantic_typing/__init__.py
ADDED
File without changes
|
models/spabert/experiments/semantic_typing/data_processing/merge_osm_json.py
ADDED
@@ -0,0 +1,97 @@
import os
import json
import math
import glob
import re
import pdb

'''
NO LONGER NEEDED

Process the california, london, and minnesota OSM data and prepare pseudo-sentence, spatial context

Load the raw output files generated by sql
Unify the json by changing the structure of dictionary
Save the output into two files, one for training + validation, and the other one for testing
'''

region_list = ['california','london','minnesota']

input_json_dir = '../data/sql_output/sub_files/'
output_json_dir = '../data/sql_output/'

for region_name in region_list:
    file_list = glob.glob(os.path.join(input_json_dir, 'spatialbert-osm-point-' + region_name + '*.json'))
    file_list = sorted(file_list)
    print('found %d files for region %s' % (len(file_list), region_name))

    num_test_files = int(math.ceil(len(file_list) * 0.2))
    num_train_val_files = len(file_list) - num_test_files

    print('%d files for train-val' % num_train_val_files)
    print('%d files for test' % num_test_files)

    train_val_output_path = os.path.join(output_json_dir + 'osm-point-' + region_name + '_train_val.json')
    test_output_path = os.path.join(output_json_dir + 'osm-point-' + region_name + '_test.json')

    # refresh the file
    with open(train_val_output_path, 'w') as f:
        pass
    with open(test_output_path, 'w') as f:
        pass

    for idx in range(len(file_list)):

        if idx < num_train_val_files:
            output_path = train_val_output_path
        else:
            output_path = test_output_path

        file_path = file_list[idx]

        print(file_path)

        with open(file_path, 'r') as f:
            data = f.readlines()

        line = data[0]

        line = re.sub(r'\n', '', line)
        line = re.sub(r'\\n', '', line)
        line = re.sub(r'\\+', '', line)
        line = re.sub(r'\+', '', line)

        line_dict_list = json.loads(line)

        for line_dict in line_dict_list:

            line_dict = line_dict['json_build_object']

            if not line_dict['name'][0].isalpha(): # discard record if the first char is not an English letter
                continue

            neighbor_name_list = line_dict['neighbor_info'][0]['name_list']
            neighbor_geom_list = line_dict['neighbor_info'][0]['geometry_list']

            assert(len(neighbor_name_list) == len(neighbor_geom_list))

            temp_dict = \
                {'info':{'name':line_dict['name'],
                         'geometry':{'coordinates':line_dict['geometry']},
                         'class':line_dict['class']
                         },
                 'neighbor_info':{'name_list': neighbor_name_list,
                                  'geometry_list': neighbor_geom_list
                                  }
                 }

            with open(output_path, 'a') as f:
                json.dump(temp_dict, f)
                f.write('\n')
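For orientation, each line the script appends to the train/val and test JSON files is one object shaped like temp_dict above. The sketch below shows that layout with invented values; the field names come from the script, the concrete values are hypothetical.

example_record = {
    'info': {
        'name': 'Example Cafe',                                # pivot entity name
        'geometry': {'coordinates': [-0.1276, 51.5072]},       # pivot coordinates
        'class': 'amenity',                                    # semantic type label
    },
    'neighbor_info': {
        'name_list': ['Example Station', 'Example Park'],              # nearby entity names
        'geometry_list': [[-0.1280, 51.5075], [-0.1269, 51.5068]],     # matching coordinates
    },
}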
models/spabert/experiments/semantic_typing/src/__init__.py
ADDED
File without changes
|
models/spabert/experiments/semantic_typing/src/run_baseline_test.py
ADDED
@@ -0,0 +1,82 @@
import os
import pdb
import argparse
import numpy as np
import time

MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large',
                 'spanbert-base','spanbert-large','luke-base','luke-large',
                 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']


def execute_command(command, if_print_command):
    t1 = time.time()

    if if_print_command:
        print(command)
    os.system(command)

    t2 = time.time()
    time_usage = t2 - t1
    return time_usage


def run_test(args):
    weight_dir = args.weight_dir
    backbone_option = args.backbone_option
    gpu_id = str(args.gpu_id)
    if_print_command = args.print_command
    sep_between_neighbors = args.sep_between_neighbors

    assert backbone_option in MODEL_OPTIONS

    if sep_between_neighbors:
        sep_str = '_sep'
    else:
        sep_str = ''

    if 'large' in backbone_option:
        checkpoint_dir = os.path.join(weight_dir, 'typing_lr1e-06_%s_nofreeze%s_london_california_bsize12' % (backbone_option, sep_str))
    else:
        checkpoint_dir = os.path.join(weight_dir, 'typing_lr5e-05_%s_nofreeze%s_london_california_bsize12' % (backbone_option, sep_str))
    weight_files = os.listdir(checkpoint_dir)

    val_loss_list = [weight_file.split('_')[-1] for weight_file in weight_files]
    min_loss_weight = weight_files[np.argmin(val_loss_list)]

    checkpoint_path = os.path.join(checkpoint_dir, min_loss_weight)

    if sep_between_neighbors:
        command = 'CUDA_VISIBLE_DEVICES=%s python3 test_cls_baseline.py --sep_between_neighbors --backbone_option=%s --batch_size=8 --with_type --checkpoint_path=%s ' % (gpu_id, backbone_option, checkpoint_path)
    else:
        command = 'CUDA_VISIBLE_DEVICES=%s python3 test_cls_baseline.py --backbone_option=%s --batch_size=8 --with_type --checkpoint_path=%s ' % (gpu_id, backbone_option, checkpoint_path)

    execute_command(command, if_print_command)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--weight_dir', type=str, default='/data2/zekun/spatial_bert_baseline_weights/')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--backbone_option', type=str, default=None)
    parser.add_argument('--gpu_id', type=int, default=0)

    parser.add_argument('--print_command', default=False, action='store_true')

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    run_test(args)


if __name__ == '__main__':

    main()
models/spabert/experiments/semantic_typing/src/test_cls_ablation_spatialbert.py
ADDED
@@ -0,0 +1,209 @@
import os
import sys
from transformers import RobertaTokenizer, BertTokenizer
from tqdm import tqdm  # for our progress bar
from transformers import AdamW

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

sys.path.append('../../../')
from models.spatial_bert_model import SpatialBertModel
from models.spatial_bert_model import SpatialBertConfig
from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
from datasets.osm_sample_loader import PbfMapDataset
from datasets.const import *
from transformers.models.bert.modeling_bert import BertForMaskedLM

from sklearn.metrics import label_ranking_average_precision_score
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import argparse
from sklearn.preprocessing import LabelEncoder
import pdb


DEBUG = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)


def testing(args):

    max_token_len = args.max_token_len
    batch_size = args.batch_size
    num_workers = args.num_workers
    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    checkpoint_path = args.checkpoint_path
    if_no_spatial_distance = args.no_spatial_distance

    bert_option = args.bert_option
    num_neighbor_limit = args.num_neighbor_limit

    london_file_path = '../data/sql_output/osm-point-london-typing.json'
    california_file_path = '../data/sql_output/osm-point-california-typing.json'

    if bert_option == 'bert-base':
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(CLASS_9_LIST))
    elif bert_option == 'bert-large':
        tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
        config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24, num_semantic_types=len(CLASS_9_LIST))
    else:
        raise NotImplementedError

    model = SpatialBertForSemanticTyping(config)

    label_encoder = LabelEncoder()
    label_encoder.fit(CLASS_9_LIST)

    london_dataset = PbfMapDataset(data_file_path = london_file_path,
                                   tokenizer = tokenizer,
                                   max_token_len = max_token_len,
                                   distance_norm_factor = distance_norm_factor,
                                   spatial_dist_fill = spatial_dist_fill,
                                   with_type = with_type,
                                   sep_between_neighbors = sep_between_neighbors,
                                   label_encoder = label_encoder,
                                   num_neighbor_limit = num_neighbor_limit,
                                   mode = 'test',)

    california_dataset = PbfMapDataset(data_file_path = california_file_path,
                                       tokenizer = tokenizer,
                                       max_token_len = max_token_len,
                                       distance_norm_factor = distance_norm_factor,
                                       spatial_dist_fill = spatial_dist_fill,
                                       with_type = with_type,
                                       sep_between_neighbors = sep_between_neighbors,
                                       label_encoder = label_encoder,
                                       num_neighbor_limit = num_neighbor_limit,
                                       mode = 'test')

    test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])

    test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers,
                             shuffle=False, pin_memory=True, drop_last=False)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    model.load_state_dict(torch.load(checkpoint_path))

    model.eval()

    print('start testing...')

    # setup loop with TQDM and dataloader
    loop = tqdm(test_loader, leave=True)

    mrr_total = 0.
    prec_total = 0.
    sample_cnt = 0

    gt_list = []
    pred_list = []

    for batch in loop:
        # initialize calculated gradients (from prev step)

        # pull all tensor batches required for training
        input_ids = batch['pseudo_sentence'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        position_list_x = batch['norm_lng_list'].to(device)
        position_list_y = batch['norm_lat_list'].to(device)
        sent_position_ids = batch['sent_position_ids'].to(device)

        #labels = batch['pseudo_sentence'].to(device)
        labels = batch['pivot_type'].to(device)
        pivot_lens = batch['pivot_token_len'].to(device)

        outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
                        position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)

        onehot_labels = F.one_hot(labels, num_classes=9)

        gt_list.extend(onehot_labels.cpu().detach().numpy())
        pred_list.extend(outputs.logits.cpu().detach().numpy())

        #pdb.set_trace()
        mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
        mrr_total += mrr * input_ids.shape[0]
        sample_cnt += input_ids.shape[0]

    precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average=None)

    precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average='micro')

    # print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
    # print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
    # print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores])
    # print('supports:\n', supports)
    print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))


def main():

    parser = argparse.ArgumentParser()

    parser.add_argument('--max_token_len', type=int, default=300)
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--num_workers', type=int, default=5)

    parser.add_argument('--num_neighbor_limit', type=int, default = None)

    parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default = 20)

    parser.add_argument('--with_type', default=False, action='store_true')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true')

    parser.add_argument('--bert_option', type=str, default='bert-base')
    parser.add_argument('--prediction_save_dir', type=str, default=None)

    parser.add_argument('--checkpoint_path', type=str, default=None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    # out_dir not None, and out_dir does not exist, then create out_dir
    if args.prediction_save_dir is not None and not os.path.isdir(args.prediction_save_dir):
        os.makedirs(args.prediction_save_dir)

    testing(args)


if __name__ == '__main__':

    main()
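The micro-averaged scores printed at the end are computed from the accumulated one-hot labels and logits by taking the argmax of each; the toy sketch below (invented arrays, not model output) reproduces that final sklearn call.

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

gt_onehot = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])                  # 3 samples, 3 classes
logits = np.array([[2.0, 0.1, 0.3], [0.2, 0.4, 1.5], [0.1, 0.2, 3.0]])   # raw scores per class

precision, recall, f1, _ = precision_recall_fscore_support(
    np.argmax(gt_onehot, axis=1), np.argmax(logits, axis=1), average='micro')
print(precision, recall, f1)   # 2 of 3 predictions correct -> 0.667 for all three micro scores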
models/spabert/experiments/semantic_typing/src/test_cls_baseline.py
ADDED
@@ -0,0 +1,189 @@
import os
import sys
from tqdm import tqdm  # for our progress bar
import numpy as np
import argparse
from sklearn.preprocessing import LabelEncoder
import pdb

import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import torch.nn.functional as F

sys.path.append('../../../')
from datasets.osm_sample_loader import PbfMapDataset
from datasets.const import *
from utils.baseline_utils import get_baseline_model
from models.baseline_typing_model import BaselineForSemanticTyping

from sklearn.metrics import label_ranking_average_precision_score
from sklearn.metrics import precision_recall_fscore_support

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)


MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large',
                 'spanbert-base','spanbert-large','luke-base','luke-large',
                 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']


def testing(args):

    num_workers = args.num_workers
    batch_size = args.batch_size
    max_token_len = args.max_token_len

    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    freeze_backbone = args.freeze_backbone

    backbone_option = args.backbone_option

    checkpoint_path = args.checkpoint_path

    assert(backbone_option in MODEL_OPTIONS)

    london_file_path = '../data/sql_output/osm-point-london-typing.json'
    california_file_path = '../data/sql_output/osm-point-california-typing.json'

    label_encoder = LabelEncoder()
    label_encoder.fit(CLASS_9_LIST)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    backbone_model, tokenizer = get_baseline_model(backbone_option)
    model = BaselineForSemanticTyping(backbone_model, backbone_model.config.hidden_size, len(CLASS_9_LIST))

    model.load_state_dict(torch.load(checkpoint_path)) #, strict = False # load sentence position embedding weights as well

    model.to(device)
    model.eval()  # evaluation mode: the script only tests a trained checkpoint

    london_dataset = PbfMapDataset(data_file_path = london_file_path,
                                   tokenizer = tokenizer,
                                   max_token_len = max_token_len,
                                   distance_norm_factor = distance_norm_factor,
                                   spatial_dist_fill = spatial_dist_fill,
                                   with_type = with_type,
                                   sep_between_neighbors = sep_between_neighbors,
                                   label_encoder = label_encoder,
                                   mode = 'test')

    california_dataset = PbfMapDataset(data_file_path = california_file_path,
                                       tokenizer = tokenizer,
                                       max_token_len = max_token_len,
                                       distance_norm_factor = distance_norm_factor,
                                       spatial_dist_fill = spatial_dist_fill,
                                       with_type = with_type,
                                       sep_between_neighbors = sep_between_neighbors,
                                       label_encoder = label_encoder,
                                       mode = 'test')

    test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])

    test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers,
                             shuffle=False, pin_memory=True, drop_last=False)

    print('start testing...')

    # setup loop with TQDM and dataloader
    loop = tqdm(test_loader, leave=True)

    mrr_total = 0.
    prec_total = 0.
    sample_cnt = 0

    gt_list = []
    pred_list = []

    for batch in loop:
        # initialize calculated gradients (from prev step)

        # pull all tensor batches required for training
        input_ids = batch['pseudo_sentence'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        position_ids = batch['sent_position_ids'].to(device)

        #labels = batch['pseudo_sentence'].to(device)
        labels = batch['pivot_type'].to(device)
        pivot_lens = batch['pivot_token_len'].to(device)

        outputs = model(input_ids, attention_mask = attention_mask, position_ids = position_ids,
                        labels = labels, pivot_len_list = pivot_lens)

        onehot_labels = F.one_hot(labels, num_classes=9)

        gt_list.extend(onehot_labels.cpu().detach().numpy())
        pred_list.extend(outputs.logits.cpu().detach().numpy())

        mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
        mrr_total += mrr * input_ids.shape[0]
        sample_cnt += input_ids.shape[0]

    precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average=None)
    precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average='micro')
    print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
    print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
    print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores])
    print('supports:\n', supports)
    print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))

    #pdb.set_trace()
    #print(mrr_total/sample_cnt)


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_workers', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--max_token_len', type=int, default=300)

    parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default = 20)

    parser.add_argument('--with_type', default=False, action='store_true')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--freeze_backbone', default=False, action='store_true')

    parser.add_argument('--backbone_option', type=str, default='bert-base')
    parser.add_argument('--checkpoint_path', type=str, default=None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    testing(args)


if __name__ == '__main__':

    main()
models/spabert/experiments/semantic_typing/src/test_cls_spatialbert.py
ADDED
@@ -0,0 +1,214 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from transformers import RobertaTokenizer, BertTokenizer
|
4 |
+
from tqdm import tqdm # for our progress bar
|
5 |
+
from transformers import AdamW
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
sys.path.append('../../../')
|
12 |
+
from models.spatial_bert_model import SpatialBertModel
|
13 |
+
from models.spatial_bert_model import SpatialBertConfig
|
14 |
+
from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
|
15 |
+
from datasets.osm_sample_loader import PbfMapDataset
|
16 |
+
from datasets.const import *
|
17 |
+
from transformers.models.bert.modeling_bert import BertForMaskedLM
|
18 |
+
|
19 |
+
from sklearn.metrics import label_ranking_average_precision_score
|
20 |
+
from sklearn.metrics import precision_recall_fscore_support
|
21 |
+
import numpy as np
|
22 |
+
import argparse
|
23 |
+
from sklearn.preprocessing import LabelEncoder
|
24 |
+
import pdb
|
25 |
+
|
26 |
+
|
27 |
+
DEBUG = False
|
28 |
+
|
29 |
+
torch.backends.cudnn.deterministic = True
|
30 |
+
torch.backends.cudnn.benchmark = False
|
31 |
+
torch.manual_seed(42)
|
32 |
+
torch.cuda.manual_seed_all(42)
|
33 |
+
|
34 |
+
|
35 |
+
def testing(args):
|
36 |
+
|
37 |
+
max_token_len = args.max_token_len
|
38 |
+
batch_size = args.batch_size
|
39 |
+
num_workers = args.num_workers
|
40 |
+
distance_norm_factor = args.distance_norm_factor
|
41 |
+
spatial_dist_fill=args.spatial_dist_fill
|
42 |
+
with_type = args.with_type
|
43 |
+
sep_between_neighbors = args.sep_between_neighbors
|
44 |
+
checkpoint_path = args.checkpoint_path
|
45 |
+
if_no_spatial_distance = args.no_spatial_distance
|
46 |
+
|
47 |
+
bert_option = args.bert_option
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
if args.num_classes == 9:
|
52 |
+
london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing.json'
|
53 |
+
california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing.json'
|
54 |
+
TYPE_LIST = CLASS_9_LIST
|
55 |
+
type_key_str = 'class'
|
56 |
+
elif args.num_classes == 74:
|
57 |
+
london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing-ranking.json'
|
58 |
+
california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing-ranking.json'
|
59 |
+
TYPE_LIST = CLASS_74_LIST
|
60 |
+
type_key_str = 'fine_class'
|
61 |
+
else:
|
62 |
+
raise NotImplementedError
|
63 |
+
|
64 |
+
if bert_option == 'bert-base':
|
65 |
+
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
66 |
+
config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(TYPE_LIST))
|
67 |
+
elif bert_option == 'bert-large':
|
68 |
+
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
|
69 |
+
config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24,num_semantic_types=len(TYPE_LIST))
|
70 |
+
else:
|
71 |
+
raise NotImplementedError
|
72 |
+
|
73 |
+
|
74 |
+
model = SpatialBertForSemanticTyping(config)
|
75 |
+
|
76 |
+
|
77 |
+
label_encoder = LabelEncoder()
|
78 |
+
label_encoder.fit(TYPE_LIST)
|
79 |
+
|
80 |
+
london_dataset = PbfMapDataset(data_file_path = london_file_path,
|
81 |
+
tokenizer = tokenizer,
|
82 |
+
max_token_len = max_token_len,
|
83 |
+
distance_norm_factor = distance_norm_factor,
|
84 |
+
spatial_dist_fill = spatial_dist_fill,
|
85 |
+
with_type = with_type,
|
86 |
+
type_key_str = type_key_str,
|
87 |
+
sep_between_neighbors = sep_between_neighbors,
|
88 |
+
label_encoder = label_encoder,
|
89 |
+
mode = 'test')
|
90 |
+
|
91 |
+
california_dataset = PbfMapDataset(data_file_path = california_file_path,
|
92 |
+
tokenizer = tokenizer,
|
93 |
+
max_token_len = max_token_len,
|
94 |
+
distance_norm_factor = distance_norm_factor,
|
95 |
+
spatial_dist_fill = spatial_dist_fill,
|
96 |
+
with_type = with_type,
|
97 |
+
type_key_str = type_key_str,
|
98 |
+
sep_between_neighbors = sep_between_neighbors,
|
99 |
+
label_encoder = label_encoder,
|
100 |
+
mode = 'test')
|
101 |
+
|
102 |
+
test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                             shuffle=False, pin_memory=True, drop_last=False)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    model.load_state_dict(torch.load(checkpoint_path))

    model.eval()

    print('start testing...')

    # setup loop with TQDM and dataloader
    loop = tqdm(test_loader, leave=True)

    mrr_total = 0.
    prec_total = 0.
    sample_cnt = 0

    gt_list = []
    pred_list = []

    for batch in loop:
        # pull all tensor batches required for evaluation
        input_ids = batch['pseudo_sentence'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        position_list_x = batch['norm_lng_list'].to(device)
        position_list_y = batch['norm_lat_list'].to(device)
        sent_position_ids = batch['sent_position_ids'].to(device)

        labels = batch['pivot_type'].to(device)
        pivot_lens = batch['pivot_token_len'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, sent_position_ids=sent_position_ids,
                        position_list_x=position_list_x, position_list_y=position_list_y,
                        labels=labels, pivot_len_list=pivot_lens)

        onehot_labels = F.one_hot(labels, num_classes=len(TYPE_LIST))

        gt_list.extend(onehot_labels.cpu().detach().numpy())
        pred_list.extend(outputs.logits.cpu().detach().numpy())

        mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(),
                                                    outputs.logits.cpu().detach().numpy())
        mrr_total += mrr * input_ids.shape[0]
        sample_cnt += input_ids.shape[0]

    precisions, recalls, fscores, supports = precision_recall_fscore_support(
        np.argmax(np.array(gt_list), axis=1), np.argmax(np.array(pred_list), axis=1), average=None)
    precision, recall, f1, _ = precision_recall_fscore_support(
        np.argmax(np.array(gt_list), axis=1), np.argmax(np.array(pred_list), axis=1), average='micro')

    print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
    print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
    print('fscores:\n', ["{:.3f}".format(fscore) for fscore in fscores])
    print('supports:\n', supports)
    print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))


def main():

    parser = argparse.ArgumentParser()

    parser.add_argument('--max_token_len', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--num_workers', type=int, default=5)

    parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default=100)
    parser.add_argument('--num_classes', type=int, default=9)

    parser.add_argument('--with_type', default=False, action='store_true')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true')

    parser.add_argument('--bert_option', type=str, default='bert-base')
    parser.add_argument('--prediction_save_dir', type=str, default=None)

    parser.add_argument('--checkpoint_path', type=str, default=None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    # if prediction_save_dir is given but does not exist, create it
    if args.prediction_save_dir is not None and not os.path.isdir(args.prediction_save_dir):
        os.makedirs(args.prediction_save_dir)

    testing(args)


if __name__ == '__main__':
    main()
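Note: the evaluation loop above reports an MRR-style ranking score via label_ranking_average_precision_score and per-class plus micro-averaged scores via precision_recall_fscore_support. Below is a minimal, self-contained sketch of what those two scikit-learn calls return on toy arrays; the values are illustrative only, not data from this repo.

import numpy as np
from sklearn.metrics import label_ranking_average_precision_score, precision_recall_fscore_support

# 3 toy samples, 4 classes: one-hot ground truth and raw model scores
y_true = np.array([[0, 1, 0, 0],
                   [1, 0, 0, 0],
                   [0, 0, 0, 1]])
y_score = np.array([[0.1, 0.7, 0.1, 0.1],
                    [0.3, 0.4, 0.2, 0.1],
                    [0.2, 0.1, 0.1, 0.6]])

print(label_ranking_average_precision_score(y_true, y_score))   # how highly the true class is ranked
print(precision_recall_fscore_support(y_true.argmax(axis=1),    # micro P/R/F1 over argmax predictions
                                      y_score.argmax(axis=1), average='micro'))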
models/spabert/experiments/semantic_typing/src/train_cls_baseline.py
ADDED
@@ -0,0 +1,227 @@
import os
import sys
from tqdm import tqdm  # for our progress bar
import numpy as np
import argparse
from sklearn.preprocessing import LabelEncoder
import pdb


import torch
from torch.utils.data import DataLoader
from transformers import AdamW

sys.path.append('../../../')
from datasets.osm_sample_loader import PbfMapDataset
from datasets.const import *
from utils.baseline_utils import get_baseline_model
from models.baseline_typing_model import BaselineForSemanticTyping


MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large',
                 'spanbert-base','spanbert-large','luke-base','luke-large',
                 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']


def training(args):

    num_workers = args.num_workers
    batch_size = args.batch_size
    epochs = args.epochs
    lr = args.lr  # 1e-7 # 5e-5
    save_interval = args.save_interval
    max_token_len = args.max_token_len
    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    freeze_backbone = args.freeze_backbone

    backbone_option = args.backbone_option

    assert(backbone_option in MODEL_OPTIONS)

    london_file_path = '../data/sql_output/osm-point-london-typing.json'
    california_file_path = '../data/sql_output/osm-point-california-typing.json'

    if args.model_save_dir is None:
        freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze'
        sep_pathstr = '_sep' if sep_between_neighbors else '_nosep'
        model_save_dir = '/data2/zekun/spatial_bert_baseline_weights/typing_lr' + str("{:.0e}".format(lr)) + '_' + backbone_option + freeze_pathstr + sep_pathstr + '_london_california_bsize' + str(batch_size)

        if not os.path.isdir(model_save_dir):
            os.makedirs(model_save_dir)
    else:
        model_save_dir = args.model_save_dir

    print('model_save_dir', model_save_dir)
    print('\n')

    label_encoder = LabelEncoder()
    label_encoder.fit(CLASS_9_LIST)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    backbone_model, tokenizer = get_baseline_model(backbone_option)
    model = BaselineForSemanticTyping(backbone_model, backbone_model.config.hidden_size, len(CLASS_9_LIST))

    model.to(device)
    model.train()

    london_train_val_dataset = PbfMapDataset(data_file_path = london_file_path,
                                             tokenizer = tokenizer,
                                             max_token_len = max_token_len,
                                             distance_norm_factor = distance_norm_factor,
                                             spatial_dist_fill = spatial_dist_fill,
                                             with_type = with_type,
                                             sep_between_neighbors = sep_between_neighbors,
                                             label_encoder = label_encoder,
                                             mode = 'train')

    percent_80 = int(len(london_train_val_dataset) * 0.8)
    london_train_dataset, london_val_dataset = torch.utils.data.random_split(london_train_val_dataset, [percent_80, len(london_train_val_dataset) - percent_80])

    california_train_val_dataset = PbfMapDataset(data_file_path = california_file_path,
                                                 tokenizer = tokenizer,
                                                 max_token_len = max_token_len,
                                                 distance_norm_factor = distance_norm_factor,
                                                 spatial_dist_fill = spatial_dist_fill,
                                                 with_type = with_type,
                                                 sep_between_neighbors = sep_between_neighbors,
                                                 label_encoder = label_encoder,
                                                 mode = 'train')
    percent_80 = int(len(california_train_val_dataset) * 0.8)
    california_train_dataset, california_val_dataset = torch.utils.data.random_split(california_train_val_dataset, [percent_80, len(california_train_val_dataset) - percent_80])

    train_dataset = torch.utils.data.ConcatDataset([london_train_dataset, california_train_dataset])
    val_dataset = torch.utils.data.ConcatDataset([london_val_dataset, california_val_dataset])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                              shuffle=True, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                            shuffle=False, pin_memory=True, drop_last=False)

    # initialize optimizer
    optim = AdamW(model.parameters(), lr=lr)

    print('start training...')

    for epoch in range(epochs):
        # setup loop with TQDM and dataloader
        loop = tqdm(train_loader, leave=True)
        iter = 0
        for batch in loop:
            # initialize calculated gradients (from prev step)
            optim.zero_grad()
            # pull all tensor batches required for training
            input_ids = batch['pseudo_sentence'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            position_ids = batch['sent_position_ids'].to(device)

            labels = batch['pivot_type'].to(device)
            pivot_lens = batch['pivot_token_len'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, position_ids=position_ids,
                            labels=labels, pivot_len_list=pivot_lens)

            loss = outputs.loss
            loss.backward()
            optim.step()

            loop.set_description(f'Epoch {epoch}')
            loop.set_postfix({'loss': loss.item()})

            iter += 1

            if iter % save_interval == 0 or iter == loop.total:
                loss_valid = validating(val_loader, model, device)
                print('validation loss', loss_valid)

                save_path = os.path.join(model_save_dir, 'ep'+str(epoch) + '_iter' + str(iter).zfill(5) \
                                         + '_' + str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) + '.pth')

                torch.save(model.state_dict(), save_path)
                print('saving model checkpoint to', save_path)


def validating(val_loader, model, device):

    with torch.no_grad():

        loss_valid = 0
        loop = tqdm(val_loader, leave=True)

        for batch in loop:
            input_ids = batch['pseudo_sentence'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            position_ids = batch['sent_position_ids'].to(device)

            labels = batch['pivot_type'].to(device)
            pivot_lens = batch['pivot_token_len'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, position_ids=position_ids,
                            labels=labels, pivot_len_list=pivot_lens)

            loss_valid += outputs.loss

        loss_valid /= len(val_loader)

        return loss_valid


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_workers', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--save_interval', type=int, default=2000)
    parser.add_argument('--max_token_len', type=int, default=300)

    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default=20)

    parser.add_argument('--with_type', default=False, action='store_true')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--freeze_backbone', default=False, action='store_true')

    parser.add_argument('--backbone_option', type=str, default='bert-base')
    parser.add_argument('--model_save_dir', type=str, default=None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    # if model_save_dir is given but does not exist, create it
    if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir):
        os.makedirs(args.model_save_dir)

    training(args)


if __name__ == '__main__':
    main()
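Note: both training scripts build their data the same way: an 80/20 random split per region (London, California), then a concatenation of the two regional splits. A minimal sketch of that pattern with throwaway tensor datasets (the contents are placeholders, not OSM data):

import torch
from torch.utils.data import TensorDataset, ConcatDataset, random_split

london = TensorDataset(torch.arange(10))          # stand-in for the London typing set
california = TensorDataset(torch.arange(10, 20))  # stand-in for the California typing set

n80 = int(len(london) * 0.8)
london_train, london_val = random_split(london, [n80, len(london) - n80])
n80 = int(len(california) * 0.8)
california_train, california_val = random_split(california, [n80, len(california) - n80])

train_dataset = ConcatDataset([london_train, california_train])
val_dataset = ConcatDataset([london_val, california_val])
print(len(train_dataset), len(val_dataset))       # 16 training samples, 4 validation samples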
models/spabert/experiments/semantic_typing/src/train_cls_spatialbert.py
ADDED
@@ -0,0 +1,276 @@
import os
import sys
from transformers import RobertaTokenizer, BertTokenizer
from tqdm import tqdm  # for our progress bar
from transformers import AdamW

import torch
from torch.utils.data import DataLoader

sys.path.append('../../../')
from models.spatial_bert_model import SpatialBertModel
from models.spatial_bert_model import SpatialBertConfig
from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
from datasets.osm_sample_loader import PbfMapDataset
from datasets.const import *

from transformers.models.bert.modeling_bert import BertForMaskedLM

import numpy as np
import argparse
from sklearn.preprocessing import LabelEncoder
import pdb


DEBUG = False


def training(args):

    num_workers = args.num_workers
    batch_size = args.batch_size
    epochs = args.epochs
    lr = args.lr  # 1e-7 # 5e-5
    save_interval = args.save_interval
    max_token_len = args.max_token_len
    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    freeze_backbone = args.freeze_backbone
    mlm_checkpoint_path = args.mlm_checkpoint_path

    if_no_spatial_distance = args.no_spatial_distance

    bert_option = args.bert_option

    assert bert_option in ['bert-base', 'bert-large']

    if args.num_classes == 9:
        london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing.json'
        california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing.json'
        TYPE_LIST = CLASS_9_LIST
        type_key_str = 'class'
    elif args.num_classes == 74:
        london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing-ranking.json'
        california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing-ranking.json'
        TYPE_LIST = CLASS_74_LIST
        type_key_str = 'fine_class'
    else:
        raise NotImplementedError

    if args.model_save_dir is None:
        checkpoint_basename = os.path.basename(mlm_checkpoint_path)
        checkpoint_prefix = checkpoint_basename.replace("mlm_mem_keeppos_", "").strip('.pth')

        sep_pathstr = '_sep' if sep_between_neighbors else '_nosep'
        freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze'
        if if_no_spatial_distance:
            model_save_dir = '/data2/zekun/spatial_bert_weights_ablation/'
        else:
            model_save_dir = '/data2/zekun/spatial_bert_weights/'
        model_save_dir = os.path.join(model_save_dir, 'typing_lr' + str("{:.0e}".format(lr)) + sep_pathstr + '_' + bert_option + freeze_pathstr + '_london_california_bsize' + str(batch_size))
        model_save_dir = os.path.join(model_save_dir, checkpoint_prefix)

        if not os.path.isdir(model_save_dir):
            os.makedirs(model_save_dir)
    else:
        model_save_dir = args.model_save_dir

    print('model_save_dir', model_save_dir)
    print('\n')

    if bert_option == 'bert-base':
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(TYPE_LIST))
    elif bert_option == 'bert-large':
        tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
        config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24, num_semantic_types=len(TYPE_LIST))
    else:
        raise NotImplementedError

    label_encoder = LabelEncoder()
    label_encoder.fit(TYPE_LIST)

    london_train_val_dataset = PbfMapDataset(data_file_path = london_file_path,
                                             tokenizer = tokenizer,
                                             max_token_len = max_token_len,
                                             distance_norm_factor = distance_norm_factor,
                                             spatial_dist_fill = spatial_dist_fill,
                                             with_type = with_type,
                                             type_key_str = type_key_str,
                                             sep_between_neighbors = sep_between_neighbors,
                                             label_encoder = label_encoder,
                                             mode = 'train')

    percent_80 = int(len(london_train_val_dataset) * 0.8)
    london_train_dataset, london_val_dataset = torch.utils.data.random_split(london_train_val_dataset, [percent_80, len(london_train_val_dataset) - percent_80])

    california_train_val_dataset = PbfMapDataset(data_file_path = california_file_path,
                                                 tokenizer = tokenizer,
                                                 max_token_len = max_token_len,
                                                 distance_norm_factor = distance_norm_factor,
                                                 spatial_dist_fill = spatial_dist_fill,
                                                 with_type = with_type,
                                                 type_key_str = type_key_str,
                                                 sep_between_neighbors = sep_between_neighbors,
                                                 label_encoder = label_encoder,
                                                 mode = 'train')
    percent_80 = int(len(california_train_val_dataset) * 0.8)
    california_train_dataset, california_val_dataset = torch.utils.data.random_split(california_train_val_dataset, [percent_80, len(california_train_val_dataset) - percent_80])

    train_dataset = torch.utils.data.ConcatDataset([london_train_dataset, california_train_dataset])
    val_dataset = torch.utils.data.ConcatDataset([london_val_dataset, california_val_dataset])

    if DEBUG:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                                  shuffle=False, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                                shuffle=False, pin_memory=True, drop_last=False)
    else:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                                  shuffle=True, pin_memory=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                                shuffle=False, pin_memory=True, drop_last=False)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = SpatialBertForSemanticTyping(config)
    model.to(device)

    model.load_state_dict(torch.load(mlm_checkpoint_path), strict = False)

    model.train()

    # initialize optimizer
    optim = AdamW(model.parameters(), lr=lr)

    print('start training...')

    for epoch in range(epochs):
        # setup loop with TQDM and dataloader
        loop = tqdm(train_loader, leave=True)
        iter = 0
        for batch in loop:
            # initialize calculated gradients (from prev step)
            optim.zero_grad()
            # pull all tensor batches required for training
            input_ids = batch['pseudo_sentence'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            position_list_x = batch['norm_lng_list'].to(device)
            position_list_y = batch['norm_lat_list'].to(device)
            sent_position_ids = batch['sent_position_ids'].to(device)

            labels = batch['pivot_type'].to(device)
            pivot_lens = batch['pivot_token_len'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, sent_position_ids=sent_position_ids,
                            position_list_x=position_list_x, position_list_y=position_list_y,
                            labels=labels, pivot_len_list=pivot_lens)

            loss = outputs.loss
            loss.backward()
            optim.step()

            loop.set_description(f'Epoch {epoch}')
            loop.set_postfix({'loss': loss.item()})

            if DEBUG:
                print('ep'+str(epoch)+'_' + '_iter' + str(iter).zfill(5), loss.item())

            iter += 1

            if iter % save_interval == 0 or iter == loop.total:
                loss_valid = validating(val_loader, model, device)

                save_path = os.path.join(model_save_dir, 'keeppos_ep'+str(epoch) + '_iter' + str(iter).zfill(5) \
                                         + '_' + str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) + '.pth')

                torch.save(model.state_dict(), save_path)
                print('validation loss', loss_valid)
                print('saving model checkpoint to', save_path)


def validating(val_loader, model, device):

    with torch.no_grad():

        loss_valid = 0
        loop = tqdm(val_loader, leave=True)

        for batch in loop:
            input_ids = batch['pseudo_sentence'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            position_list_x = batch['norm_lng_list'].to(device)
            position_list_y = batch['norm_lat_list'].to(device)
            sent_position_ids = batch['sent_position_ids'].to(device)

            labels = batch['pivot_type'].to(device)
            pivot_lens = batch['pivot_token_len'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, sent_position_ids=sent_position_ids,
                            position_list_x=position_list_x, position_list_y=position_list_y,
                            labels=labels, pivot_len_list=pivot_lens)

            loss_valid += outputs.loss

        loss_valid /= len(val_loader)

        return loss_valid


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_workers', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--save_interval', type=int, default=2000)
    parser.add_argument('--max_token_len', type=int, default=512)

    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default=100)
    parser.add_argument('--num_classes', type=int, default=9)

    parser.add_argument('--with_type', default=False, action='store_true')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--freeze_backbone', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true')

    parser.add_argument('--bert_option', type=str, default='bert-base')
    parser.add_argument('--model_save_dir', type=str, default=None)

    parser.add_argument('--mlm_checkpoint_path', type=str, default=None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    # if model_save_dir is given but does not exist, create it
    if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir):
        os.makedirs(args.model_save_dir)

    training(args)


if __name__ == '__main__':
    main()
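Note: the script above initializes the typing model from a masked-language-model checkpoint with strict=False, so shared encoder weights are reused while the new typing head stays randomly initialized and MLM-only weights in the checkpoint are ignored. A minimal sketch of that loading behavior with throwaway modules (the module names below are made up for illustration):

import torch
from torch import nn

mlm = nn.ModuleDict({'encoder': nn.Linear(8, 8), 'mlm_head': nn.Linear(8, 30)})        # pretrained stage
typing = nn.ModuleDict({'encoder': nn.Linear(8, 8), 'typing_head': nn.Linear(8, 9)})   # downstream stage

# strict=False copies the shared 'encoder' weights and reports the mismatched keys
missing, unexpected = typing.load_state_dict(mlm.state_dict(), strict=False)
print(missing)     # typing_head.* -> left at random initialization
print(unexpected)  # mlm_head.*    -> ignored from the checkpoint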
models/spabert/models/__init__.py
ADDED
File without changes

models/spabert/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (150 Bytes)

models/spabert/models/__pycache__/spatial_bert_model.cpython-310.pyc
ADDED
Binary file (20.8 kB)
models/spabert/models/baseline_typing_model.py
ADDED
@@ -0,0 +1,106 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import SequenceClassifierOutput


class PivotEntityPooler(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, hidden_states, pivot_len_list):
        # We "pool" the model by simply taking the hidden states corresponding
        # to the tokens of the pivot entity

        bsize = hidden_states.shape[0]

        tensor_list = []
        for i in torch.arange(0, bsize):
            pivot_token_full = hidden_states[i, 1:pivot_len_list[i]+1]
            pivot_token_tensor = torch.mean(torch.unsqueeze(pivot_token_full, 0), dim=1)
            tensor_list.append(pivot_token_tensor)

        batch_pivot_tensor = torch.cat(tensor_list, dim=0)

        return batch_pivot_tensor


class BaselineTypingHead(nn.Module):
    def __init__(self, hidden_size, num_semantic_types):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

        self.seq_relationship = nn.Linear(hidden_size, num_semantic_types)

    def forward(self, pivot_pooled_output):

        pivot_pooled_output = self.dense(pivot_pooled_output)
        pivot_pooled_output = self.activation(pivot_pooled_output)

        seq_relationship_score = self.seq_relationship(pivot_pooled_output)
        return seq_relationship_score


class BaselineForSemanticTyping(nn.Module):

    def __init__(self, backbone, hidden_size, num_semantic_types):
        super().__init__()

        self.backbone = backbone
        self.pivot_pooler = PivotEntityPooler()
        self.num_semantic_types = num_semantic_types

        self.cls = BaselineTypingHead(hidden_size, num_semantic_types)

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        pivot_len_list=None,
        attention_mask=None,
        head_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True
    ):

        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        sequence_output = outputs[0]
        pooled_output = self.pivot_pooler(sequence_output, pivot_len_list)

        type_prediction_score = self.cls(pooled_output)

        typing_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            typing_loss = loss_fct(type_prediction_score.view(-1, self.num_semantic_types), labels.view(-1))

        if not return_dict:
            output = (type_prediction_score,) + outputs[2:]
            return ((typing_loss,) + output) if typing_loss is not None else output

        return SequenceClassifierOutput(
            loss=typing_loss,
            logits=type_prediction_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
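Note: BaselineForSemanticTyping wraps any Hugging Face style encoder whose first output is the token-level hidden states, pools the pivot entity tokens, and classifies them into semantic types. A minimal sketch of wiring it to a small, randomly initialized BERT backbone with dummy inputs; the sizes and values below are illustrative assumptions, not part of the repository:

import torch
from transformers import BertConfig, BertModel

config = BertConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=128)
backbone = BertModel(config)
model = BaselineForSemanticTyping(backbone, config.hidden_size, num_semantic_types=9)

input_ids = torch.randint(0, 1000, (2, 16))   # 2 pseudo-sentences of 16 tokens each
attention_mask = torch.ones_like(input_ids)
pivot_len_list = torch.tensor([3, 5])         # pivot entity token lengths per sample
labels = torch.tensor([1, 4])                 # semantic type indices

out = model(input_ids, attention_mask=attention_mask,
            pivot_len_list=pivot_len_list, labels=labels)
print(out.loss, out.logits.shape)             # scalar loss, logits of shape (2, 9)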