JasonTPhillipsJr committed
Commit 46e0dd0 · verified · 1 Parent(s): ebf50a4

Upload 76 files

This view is limited to 50 files because it contains too many changes. See raw diff.

Files changed (50)
  1. .gitattributes +4 -0
  2. models/spabert/README.md +69 -0
  3. models/spabert/__init__.py +0 -0
  4. models/spabert/datasets/__init__.py +0 -0
  5. models/spabert/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  6. models/spabert/datasets/__pycache__/dataset_loader.cpython-310.pyc +0 -0
  7. models/spabert/datasets/__pycache__/dataset_loader_ver2.cpython-310.pyc +0 -0
  8. models/spabert/datasets/__pycache__/osm_sample_loader.cpython-310.pyc +0 -0
  9. models/spabert/datasets/__pycache__/usgs_os_sample_loader.cpython-310.pyc +0 -0
  10. models/spabert/datasets/__pycache__/wikidata_sample_loader.cpython-310.pyc +0 -0
  11. models/spabert/datasets/const.py +162 -0
  12. models/spabert/datasets/dataset_loader.py +162 -0
  13. models/spabert/datasets/dataset_loader_ver2.py +164 -0
  14. models/spabert/datasets/osm_sample_loader.py +246 -0
  15. models/spabert/datasets/usgs_os_sample_loader.py +71 -0
  16. models/spabert/datasets/wikidata_sample_loader.py +127 -0
  17. models/spabert/experiments/__init__.py +0 -0
  18. models/spabert/experiments/__pycache__/__init__.cpython-310.pyc +0 -0
  19. models/spabert/experiments/entity_matching/__init__.py +0 -0
  20. models/spabert/experiments/entity_matching/__pycache__/__init__.cpython-310.pyc +0 -0
  21. models/spabert/experiments/entity_matching/data_processing/__init__.py +0 -0
  22. models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-310.pyc +0 -0
  23. models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-311.pyc +0 -0
  24. models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-310.pyc +0 -0
  25. models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-311.pyc +0 -0
  26. models/spabert/experiments/entity_matching/data_processing/get_namelist.py +95 -0
  27. models/spabert/experiments/entity_matching/data_processing/request_wrapper.py +186 -0
  28. models/spabert/experiments/entity_matching/data_processing/run_linking_query.py +143 -0
  29. models/spabert/experiments/entity_matching/data_processing/run_map_neighbor_query.py +123 -0
  30. models/spabert/experiments/entity_matching/data_processing/run_query_sample.py +22 -0
  31. models/spabert/experiments/entity_matching/data_processing/run_wikidata_neighbor_query.py +31 -0
  32. models/spabert/experiments/entity_matching/data_processing/samples.sparql +22 -0
  33. models/spabert/experiments/entity_matching/data_processing/select_ambi.py +18 -0
  34. models/spabert/experiments/entity_matching/data_processing/wikidata_sample30k/wikidata_30k.json +0 -0
  35. models/spabert/experiments/entity_matching/src/evaluation-mrr.py +260 -0
  36. models/spabert/experiments/entity_matching/src/linking_ablation.py +228 -0
  37. models/spabert/experiments/entity_matching/src/unsupervised_wiki_location_allcand.py +329 -0
  38. models/spabert/experiments/semantic_typing/__init__.py +0 -0
  39. models/spabert/experiments/semantic_typing/data_processing/merge_osm_json.py +97 -0
  40. models/spabert/experiments/semantic_typing/src/__init__.py +0 -0
  41. models/spabert/experiments/semantic_typing/src/run_baseline_test.py +82 -0
  42. models/spabert/experiments/semantic_typing/src/test_cls_ablation_spatialbert.py +209 -0
  43. models/spabert/experiments/semantic_typing/src/test_cls_baseline.py +189 -0
  44. models/spabert/experiments/semantic_typing/src/test_cls_spatialbert.py +214 -0
  45. models/spabert/experiments/semantic_typing/src/train_cls_baseline.py +227 -0
  46. models/spabert/experiments/semantic_typing/src/train_cls_spatialbert.py +276 -0
  47. models/spabert/models/__init__.py +0 -0
  48. models/spabert/models/__pycache__/__init__.cpython-310.pyc +0 -0
  49. models/spabert/models/__pycache__/spatial_bert_model.cpython-310.pyc +0 -0
  50. models/spabert/models/baseline_typing_model.py +106 -0
.gitattributes CHANGED
@@ -37,3 +37,7 @@ models/en_core_web_sm/en_core_web_sm-3.7.1/ner/model filter=lfs diff=lfs merge=l
models/en_core_web_sm/en_core_web_sm-3.7.1/tok2vec/model filter=lfs diff=lfs merge=lfs -text
models/en_core_web_sm/ner/model filter=lfs diff=lfs merge=lfs -text
models/en_core_web_sm/tok2vec/model filter=lfs diff=lfs merge=lfs -text
+ models/spabert/notebooks/tutorial_datasets/output.csv.json filter=lfs diff=lfs merge=lfs -text
+ models/spabert/notebooks/tutorial_datasets/spabert_osm_mn.json filter=lfs diff=lfs merge=lfs -text
+ models/spabert/notebooks/tutorial_datasets/spabert_whg_wikidata.json filter=lfs diff=lfs merge=lfs -text
+ models/spabert/notebooks/tutorial_datasets/spabert_wikidata_sampled.json filter=lfs diff=lfs merge=lfs -text
models/spabert/README.md ADDED
@@ -0,0 +1,69 @@
+ # SpaBERT: A Pretrained Language Model from Geographic Data for Geo-Entity Representation
+
+ This repo contains code for [SpaBERT: A Pretrained Language Model from Geographic Data for Geo-Entity Representation](https://arxiv.org/abs/2210.12213), which was published at EMNLP 2022. SpaBERT provides a general-purpose geo-entity representation based on neighboring entities in geospatial data. SpaBERT extends BERT to capture linearized spatial context, while incorporating a spatial coordinate embedding mechanism to preserve the spatial relations of entities in two-dimensional space. SpaBERT is pretrained with masked language modeling and masked entity prediction tasks to learn spatial dependencies.
+
+ * Slides: [emnlp22-spabert.pdf](https://drive.google.com/file/d/1V1URsRfpw13dbkb_zgBXeNqZJ0AF2744/view?usp=share_link)
+
+
+ ## Pretraining
+ Pretrained model weights can be downloaded from Google Drive for [SpaBERT-base](https://drive.google.com/file/d/1l44FY3DtDxzM_YVh3RR6PJwKnl80IYWB/view?usp=sharing) and [SpaBERT-large](https://drive.google.com/file/d/1LeZayTR92R5bu9gH_cGCwef7nnMX35cR/view?usp=share_link).
+
+ Weights can also be obtained by training from scratch using the following sample code. Data for pretraining can be downloaded [here](https://drive.google.com/drive/folders/1eaeVvUCcJVcNwnyTCk-1N1IKfukihk4j?usp=share_link).
+
+ * Code to pretrain the SpaBERT-base model:
+
+ ```python3 train_mlm.py --lr=5e-5 --sep_between_neighbors --bert_option='bert-base'```
+
+ * Code to pretrain the SpaBERT-large model:
+
+ ```python3 train_mlm.py --lr=1e-6 --sep_between_neighbors --bert_option='bert-large'```
+
+ ## Downstream Tasks
+ ### Supervised Geo-entity Typing
+ The goal is to predict a geo-entity's semantic type (e.g., transportation or healthcare) given the target geo-entity's name and its spatial context (i.e., the surrounding neighbors' names and locations).
+
+ Models trained on OSM in the London and California regions can be downloaded from Google Drive for [SpaBERT-base](https://drive.google.com/file/d/1XFcA3sxC4wTlt7VjvMp1zNrWY5rjafzE/view?usp=share_link) and [SpaBERT-large](https://drive.google.com/file/d/12_FDVeSYkl_HQ61JmuMU6cRjQdKNpgR_/view?usp=share_link).
+
+ Data used for training and testing can be downloaded [here](https://drive.google.com/drive/folders/1uyvGdiJdu-Cym4dOKhQLIkKpfgHvfo01?usp=share_link).
+
+ * Sample code for training the SpaBERT-base typing model:
+
+ ```
+ python3 train_cls_spatialbert.py --lr=5e-5 --sep_between_neighbors --bert_option='bert-base' --with_type --mlm_checkpoint_path='mlm_mem_keeppos_ep0_iter06000_0.2936.pth'
+ ```
+
+ * Sample code for training the SpaBERT-large typing model:
+
+ ```
+ python3 train_cls_spatialbert.py --lr=1e-6 --sep_between_neighbors --bert_option='bert-large' --with_type --mlm_checkpoint_path='mlm_mem_keeppos_ep1_iter02000_0.4400.pth' --epochs=20
+ ```
+
+ ### Unsupervised Geo-entity Linking
+
+ Geo-entity linking links geo-entities from a geographic information system (GIS) oriented dataset to a knowledge base (KB). This task is unsupervised and thus does not require any further training; pretrained models can be used directly.
+
+ Linking with SpaBERT-base:
+ ```
+ python3 unsupervised_wiki_location_allcand.py --model_name='spatial_bert-base' --sep_between_neighbors \
+ --spatial_bert_weight_dir='weights/' --spatial_bert_weight_name='mlm_mem_keeppos_ep0_iter06000_0.2936.pth'
+ ```
+
+ Linking with SpaBERT-large:
+ ```
+ python3 unsupervised_wiki_location_allcand.py --model_name='spatial_bert-large' --sep_between_neighbors \
+ --spatial_bert_weight_dir='weights/' --spatial_bert_weight_name='mlm_mem_keeppos_ep1_iter02000_0.4400.pth'
+ ```
+
+ Data used for linking from USGS historical maps to the WikiData KB is provided [here](https://drive.google.com/drive/folders/1qKJnj71qxnca_TaygK-Y3EIySnMyFpFn?usp=share_link).
+
+ ## Acknowledgement
+ ```
+ @article{li2022spabert,
+   title={SpaBERT: A Pretrained Language Model from Geographic Data for Geo-Entity Representation},
+   author={Zekun Li and Jina Kim and Yao-Yi Chiang and Muhao Chen},
+   journal={EMNLP},
+   year={2022}
+ }
+ ```
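
The typing and linking scripts referenced in the README read JSON-lines files in the layout expected by the dataset loaders added under models/spabert/datasets/ in this commit. Below is a minimal sketch of one such record; the entity names, coordinates, class label, and file name are illustrative assumptions, not taken from the released data.

```python
import json

# One geo-entity per line: a pivot entity plus its spatial neighbors.
# Field names follow the loaders in models/spabert/datasets/ (e.g. osm_sample_loader.py).
record = {
    "info": {
        "name": "Example Clinic",                       # pivot entity name (illustrative)
        "geometry": {"coordinates": [-118.24, 34.05]},  # [longitude, latitude]
        "class": "healthcare",                          # coarse semantic type, see datasets/const.py
    },
    "neighbor_info": {
        "name_list": ["Example Pharmacy", "Example School"],
        "geometry_list": [
            {"coordinates": [-118.25, 34.06]},
            {"coordinates": [-118.23, 34.04]},
        ],
    },
}

# Dataset files are JSON lines: one such record per line.
with open("toy_osm_sample.json", "w") as f:
    f.write(json.dumps(record) + "\n")
```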
models/spabert/__init__.py ADDED
File without changes
models/spabert/datasets/__init__.py ADDED
File without changes
models/spabert/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes).
 
models/spabert/datasets/__pycache__/dataset_loader.cpython-310.pyc ADDED
Binary file (3.75 kB).
 
models/spabert/datasets/__pycache__/dataset_loader_ver2.cpython-310.pyc ADDED
Binary file (3 kB).
 
models/spabert/datasets/__pycache__/osm_sample_loader.cpython-310.pyc ADDED
Binary file (5.79 kB).
 
models/spabert/datasets/__pycache__/usgs_os_sample_loader.cpython-310.pyc ADDED
Binary file (2.26 kB).
 
models/spabert/datasets/__pycache__/wikidata_sample_loader.cpython-310.pyc ADDED
Binary file (3.36 kB).
 
models/spabert/datasets/const.py ADDED
@@ -0,0 +1,162 @@
1
+ def revert_dict(coarse_to_fine_dict):
2
+ fine_to_coarse_dict = dict()
3
+ for key, value in coarse_to_fine_dict.items():
4
+ for v in value:
5
+ fine_to_coarse_dict[v] = key
6
+ return fine_to_coarse_dict
7
+
8
+ CLASS_9_LIST = ['education', 'entertainment_arts_culture', 'facilities', 'financial', 'healthcare', 'public_service', 'sustenance', 'transportation', 'waste_management']
9
+
10
+ CLASS_118_LIST=['animal_boarding', 'animal_breeding', 'animal_shelter', 'arts_centre', 'atm', 'baby_hatch', 'baking_oven', 'bank', 'bar', 'bbq', 'bench', 'bicycle_parking', 'bicycle_rental', 'bicycle_repair_station', 'biergarten', 'boat_rental', 'boat_sharing', 'brothel', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'casino', 'charging_station', 'childcare', 'cinema', 'clinic', 'clock', 'college', 'community_centre', 'compressed_air', 'conference_centre', 'courthouse', 'crematorium', 'dentist', 'dive_centre', 'doctors', 'dog_toilet', 'dressing_room', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'food_court', 'fountain', 'fuel', 'funeral_hall', 'gambling', 'give_box', 'grave_yard', 'grit_bin', 'hospital', 'hunting_stand', 'ice_cream', 'internet_cafe', 'kindergarten', 'kitchen', 'kneipp_water_cure', 'language_school', 'library', 'lounger', 'love_hotel', 'marketplace', 'monastery', 'motorcycle_parking', 'music_school', 'nightclub', 'nursing_home', 'parcel_locker', 'parking', 'parking_entrance', 'parking_space', 'pharmacy', 'photo_booth', 'place_of_mourning', 'place_of_worship', 'planetarium', 'police', 'post_box', 'post_depot', 'post_office', 'prison', 'pub', 'public_bath', 'public_bookcase', 'ranger_station', 'recycling', 'refugee_site', 'restaurant', 'sanitary_dump_station', 'school', 'shelter', 'shower', 'social_centre', 'social_facility', 'stripclub', 'studio', 'swingerclub', 'taxi', 'telephone', 'theatre', 'toilets', 'townhall', 'toy_library', 'training', 'university', 'vehicle_inspection', 'vending_machine', 'veterinary', 'waste_basket', 'waste_disposal', 'waste_transfer_station', 'water_point', 'watering_place']
11
+
12
+ CLASS_95_LIST = ['arts_centre', 'atm', 'baby_hatch', 'bank', 'bar', 'bbq', 'bench', 'bicycle_parking', 'bicycle_rental', 'bicycle_repair_station', 'biergarten', 'boat_rental', 'boat_sharing', 'brothel', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'casino', 'charging_station', 'cinema', 'clinic', 'college', 'community_centre', 'compressed_air', 'conference_centre', 'courthouse', 'dentist', 'doctors', 'dog_toilet', 'dressing_room', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'food_court', 'fountain', 'fuel', 'gambling', 'give_box', 'grit_bin', 'hospital', 'ice_cream', 'kindergarten', 'language_school', 'library', 'love_hotel', 'motorcycle_parking', 'music_school', 'nightclub', 'nursing_home', 'parcel_locker', 'parking', 'parking_entrance', 'parking_space', 'pharmacy', 'planetarium', 'police', 'post_box', 'post_depot', 'post_office', 'prison', 'pub', 'public_bookcase', 'ranger_station', 'recycling', 'restaurant', 'sanitary_dump_station', 'school', 'shelter', 'shower', 'social_centre', 'social_facility', 'stripclub', 'studio', 'swingerclub', 'taxi', 'telephone', 'theatre', 'toilets', 'townhall', 'toy_library', 'training', 'university', 'vehicle_inspection', 'veterinary', 'waste_basket', 'waste_disposal', 'waste_transfer_station', 'water_point', 'watering_place']
13
+
14
+ CLASS_74_LIST = ['arts_centre', 'atm', 'bank', 'bar', 'bench', 'bicycle_parking', 'bicycle_rental', 'bicycle_repair_station', 'boat_rental', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'charging_station', 'cinema', 'clinic', 'college', 'community_centre', 'conference_centre', 'courthouse', 'dentist', 'doctors', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'food_court', 'fountain', 'fuel', 'gambling', 'hospital', 'kindergarten', 'language_school', 'library', 'motorcycle_parking', 'music_school', 'nightclub', 'nursing_home', 'parcel_locker', 'parking', 'pharmacy', 'police', 'post_box', 'post_depot', 'post_office', 'pub', 'public_bookcase', 'recycling', 'restaurant', 'sanitary_dump_station', 'school', 'shelter', 'social_centre', 'social_facility', 'stripclub', 'studio', 'swingerclub', 'taxi', 'telephone', 'theatre', 'toilets', 'townhall', 'university', 'vehicle_inspection', 'veterinary', 'waste_basket', 'waste_disposal', 'waste_transfer_station', 'water_point', 'watering_place']
15
+
16
+ FEWSHOT_CLASS_55_LIST = ['arts_centre', 'atm', 'bank', 'bar', 'bench', 'bicycle_parking', 'bicycle_rental', 'boat_rental', 'bureau_de_change', 'bus_station', 'cafe', 'car_rental', 'car_sharing', 'car_wash', 'charging_station', 'cinema', 'clinic', 'college', 'community_centre', 'courthouse', 'dentist', 'doctors', 'drinking_water', 'driving_school', 'events_venue', 'fast_food', 'ferry_terminal', 'fire_station', 'fountain', 'fuel', 'hospital', 'kindergarten', 'library', 'music_school', 'nightclub', 'parking', 'pharmacy', 'police', 'post_box', 'post_office', 'pub', 'public_bookcase', 'recycling', 'restaurant', 'school', 'shelter', 'social_centre', 'social_facility', 'studio', 'theatre', 'toilets', 'townhall', 'university', 'vehicle_inspection', 'veterinary']
17
+
18
+
19
+ DICT_9to74 = {'sustenance':['bar','cafe','fast_food','food_court','pub','restaurant'],
20
+ 'education':['college','driving_school','kindergarten','language_school','library','music_school','school','university'],
21
+ 'transportation':['bicycle_parking','bicycle_repair_station','bicycle_rental','boat_rental',
22
+ 'bus_station','car_rental','car_sharing','car_wash','vehicle_inspection','charging_station','ferry_terminal',
23
+ 'fuel','motorcycle_parking','parking','taxi'],
24
+ 'financial':['atm','bank','bureau_de_change'],
25
+ 'healthcare':['clinic','dentist','doctors','hospital','nursing_home','pharmacy','social_facility','veterinary'],
26
+ 'entertainment_arts_culture':['arts_centre','cinema','community_centre',
27
+ 'conference_centre','events_venue','fountain','gambling',
28
+ 'nightclub','public_bookcase','social_centre','stripclub','studio','swingerclub','theatre'],
29
+ 'public_service':['courthouse','fire_station','police','post_box',
30
+ 'post_depot','post_office','townhall'],
31
+ 'facilities':['bench','drinking_water','parcel_locker','shelter',
32
+ 'telephone','toilets','water_point','watering_place'],
33
+ 'waste_management':['sanitary_dump_station','recycling','waste_basket','waste_disposal','waste_transfer_station',]
34
+ }
35
+
36
+ DICT_74to9 = revert_dict(DICT_9to74)
37
+
38
+ DICT_9to95 = {
39
+ 'education':{'college','driving_school','kindergarten','language_school','library','toy_library','training','music_school','school','university'},
40
+ 'entertainment_arts_culture':{'arts_centre','brothel','casino','cinema','community_centre','conference_centre','events_venue','fountain','gambling','love_hotel','nightclub','planetarium','public_bookcase','social_centre','stripclub','studio','swingerclub','theatre'},
41
+ 'facilities':{'bbq','bench','dog_toilet','dressing_room','drinking_water','give_box','parcel_locker','shelter','shower','telephone','toilets','water_point','watering_place'},
42
+ 'financial':{'atm','bank','bureau_de_change'},
43
+ 'healthcare':{'baby_hatch','clinic','dentist','doctors','hospital','nursing_home','pharmacy','social_facility','veterinary'},
44
+ 'public_service':{'courthouse','fire_station','police','post_box','post_depot','post_office','prison','ranger_station','townhall'},
45
+ 'sustenance':{'bar','biergarten','cafe','fast_food','food_court','ice_cream','pub','restaurant',},
46
+ 'transportation':{'bicycle_parking','bicycle_repair_station','bicycle_rental','boat_rental','boat_sharing','bus_station','car_rental','car_sharing','car_wash','compressed_air','vehicle_inspection','charging_station','ferry_terminal','fuel','grit_bin','motorcycle_parking','parking','parking_entrance','parking_space','taxi'},
47
+ 'waste_management':{'sanitary_dump_station','recycling','waste_basket','waste_disposal','waste_transfer_station'}}
48
+
49
+ DICT_95to9 = revert_dict(DICT_9to95)
50
+
51
+ # DICT_95to9 = {
52
+ # 'college':'education',
53
+ # 'driving_school':'education',
54
+ # 'kindergarten':'education',
55
+ # 'language_school':'education',
56
+ # 'library':'education',
57
+ # 'toy_library':'education',
58
+ # 'training':'education',
59
+ # 'music_school':'education',
60
+ # 'school':'education',
61
+ # 'university':'education',
62
+
63
+ # 'arts_centre':'entertainment_arts_culture',
64
+ # 'brothel':'entertainment_arts_culture',
65
+ # 'casino':'entertainment_arts_culture',
66
+ # 'cinema':'entertainment_arts_culture',
67
+ # 'community_centre':'entertainment_arts_culture',
68
+ # 'conference_centre':'entertainment_arts_culture',
69
+ # 'events_venue':'entertainment_arts_culture',
70
+ # 'fountain':'entertainment_arts_culture',
71
+ # 'gambling':'entertainment_arts_culture',
72
+ # 'love_hotel':'entertainment_arts_culture',
73
+ # 'nightclub':'entertainment_arts_culture',
74
+ # 'planetarium':'entertainment_arts_culture',
75
+ # 'public_bookcase':'entertainment_arts_culture',
76
+ # 'social_centre':'entertainment_arts_culture',
77
+ # 'stripclub':'entertainment_arts_culture',
78
+ # 'studio':'entertainment_arts_culture',
79
+ # 'swingerclub':'entertainment_arts_culture',
80
+ # 'theatre':'entertainment_arts_culture',
81
+
82
+ # 'bbq': 'facilities',
83
+ # 'bench': 'facilities',
84
+ # 'dog_toilet': 'facilities',
85
+ # 'dressing_room': 'facilities',
86
+ # 'drinking_water': 'facilities',
87
+ # 'give_box': 'facilities',
88
+ # 'parcel_locker': 'facilities',
89
+ # 'shelter': 'facilities',
90
+ # 'shower': 'facilities',
91
+ # 'telephone': 'facilities',
92
+ # 'toilets': 'facilities',
93
+ # 'water_point': 'facilities',
94
+ # 'watering_place': 'facilities',
95
+
96
+ # 'atm': 'financial',
97
+ # 'bank': 'financial',
98
+ # 'bureau_de_change': 'financial',
99
+
100
+ # 'baby_hatch':'healthcare',
101
+ # 'clinic':'healthcare',
102
+ # 'dentist':'healthcare',
103
+ # 'doctors':'healthcare',
104
+ # 'hospital':'healthcare',
105
+ # 'nursing_home':'healthcare',
106
+ # 'pharmacy':'healthcare',
107
+ # 'social_facility':'healthcare',
108
+ # 'veterinary':'healthcare',
109
+
110
+ # 'courthouse': 'public_service',
111
+ # 'fire_station': 'public_service',
112
+ # 'police': 'public_service',
113
+ # 'post_box': 'public_service',
114
+ # 'post_depot': 'public_service',
115
+ # 'post_office': 'public_service',
116
+ # 'prison': 'public_service',
117
+ # 'ranger_station': 'public_service',
118
+ # 'townhall': 'public_service',
119
+
120
+ # 'bar': 'sustenance',
121
+ # 'biergarten': 'sustenance',
122
+ # 'cafe': 'sustenance',
123
+ # 'fast_food': 'sustenance',
124
+ # 'food_court': 'sustenance',
125
+ # 'ice_cream': 'sustenance',
126
+ # 'pub': 'sustenance',
127
+ # 'restaurant': 'sustenance',
128
+
129
+ # 'bicycle_parking': 'transportation',
130
+ # 'bicycle_repair_station': 'transportation',
131
+ # 'bicycle_rental': 'transportation',
132
+ # 'boat_rental': 'transportation',
133
+ # 'boat_sharing': 'transportation',
134
+ # 'bus_station': 'transportation',
135
+ # 'car_rental': 'transportation',
136
+ # 'car_sharing': 'transportation',
137
+ # 'car_wash': 'transportation',
138
+ # 'compressed_air': 'transportation',
139
+ # 'vehicle_inspection': 'transportation',
140
+ # 'charging_station': 'transportation',
141
+ # 'ferry_terminal': 'transportation',
142
+ # 'fuel': 'transportation',
143
+ # 'grit_bin': 'transportation',
144
+ # 'motorcycle_parking': 'transportation',
145
+ # 'parking': 'transportation',
146
+ # 'parking_entrance': 'transportation',
147
+ # 'parking_space': 'transportation',
148
+ # 'taxi': 'transportation',
149
+
150
+ # 'sanitary_dump_station': 'waste_management',
151
+ # 'recycling': 'waste_management',
152
+ # 'waste_basket': 'waste_management',
153
+ # 'waste_disposal': 'waste_management',
154
+ # 'waste_transfer_station': 'waste_management',
155
+
156
+ # }
157
+
158
+ # CLASS_9_LIST = ['sustenance', 'education', 'transportation', 'financial', 'healthcare', 'entertainment_arts_culture', 'public_service', 'facilities', 'waste_management']
159
+
160
+ # FINE_LIST = ['bar','biergarten','cafe','fast_food','food_court','ice_cream','pub','restaurant','college','driving_school','kindergarten','language_school','library','toy_library','training','music_school','school','university','bicycle_parking','bicycle_repair_station','bicycle_rental','boat_rental','boat_sharing','bus_station','car_rental','car_sharing','car_wash','compressed_air','vehicle_inspection','charging_station','ferry_terminal','fuel','grit_bin','motorcycle_parking','parking','parking_entrance','parking_space','taxi','atm','bank','bureau_de_change','baby_hatch','clinic','dentist','doctors','hospital','nursing_home','pharmacy','social_facility','veterinary','arts_centre','brothel','casino','cinema','community_centre','conference_centre','events_venue','fountain','gambling','love_hotel','nightclub','planetarium','public_bookcase','social_centre','stripclub','studio','swingerclub','theatre','courthouse','fire_station','police','post_box','post_depot','post_office','prison','ranger_station','townhall','bbq','bench','dog_toilet','dressing_room','drinking_water','give_box','parcel_locker','shelter','shower','telephone','toilets','water_point','watering_place','sanitary_dump_station','recycling','waste_basket','waste_disposal','waste_transfer_station']
161
+
162
+ # FINE_LIST = ['bar','biergarten','cafe','fast_food','food_court','ice_cream','pub','restaurant','college','driving_school','kindergarten','language_school','library','toy_library','training','music_school','school','university','bicycle_parking','bicycle_repair_station','bicycle_rental','boat_rental','boat_sharing','bus_station','car_rental','car_sharing','car_wash','compressed_air','vehicle_inspection','charging_station','ferry_terminal','fuel','grit_bin','motorcycle_parking','parking','parking_entrance','parking_space','taxi','atm','bank','bureau_de_change','baby_hatch','clinic','dentist','doctors','hospital','nursing_home','pharmacy','social_facility','veterinary','arts_centre','brothel','casino','cinema','community_centre','conference_centre','events_venue','fountain','gambling','love_hotel','nightclub','planetarium','public_bookcase','social_centre','stripclub','studio','swingerclub','theatre','courthouse','fire_station','police','post_box','post_depot','post_office','prison','ranger_station','townhall','bbq','bench','dog_toilet','dressing_room','drinking_water','give_box','parcel_locker','shelter','shower','telephone','toilets','water_point','watering_place','sanitary_dump_station','recycling','waste_basket','waste_disposal','waste_transfer_station','animal_boarding','animal_breeding','animal_shelter','baking_oven','childcare','clock','crematorium','dive_centre','funeral_hall','grave_yard','hunting_stand','internet_cafe','kitchen','kneipp_water_cure','lounger','marketplace','monastery','photo_booth','place_of_mourning','place_of_worship','public_bath','refugee_site','vending_machine']
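
A short usage sketch of the mappings defined above. Pairing CLASS_9_LIST with a scikit-learn LabelEncoder mirrors the label_encoder argument that the typing loaders below expect, but the fitting step shown here is an assumption rather than code from this commit; it also assumes const.py is importable (e.g. run from models/spabert/datasets/).

```python
from sklearn.preprocessing import LabelEncoder

from const import CLASS_9_LIST, DICT_74to9, DICT_95to9  # assumes datasets/ is on sys.path

# revert_dict() inverts the coarse-to-fine dictionaries, giving
# fine-grained OSM amenity -> coarse semantic type lookups.
print(DICT_95to9["pharmacy"])     # 'healthcare'
print(DICT_74to9["bus_station"])  # 'transportation'

# The typing datasets (e.g. PbfMapDataset) expect a label_encoder whose
# transform() maps a coarse class string to an integer label id.
label_encoder = LabelEncoder()
label_encoder.fit(CLASS_9_LIST)
print(label_encoder.transform(["healthcare"])[0])  # integer class id
```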
models/spabert/datasets/dataset_loader.py ADDED
@@ -0,0 +1,162 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ import pdb
5
+
6
+ class SpatialDataset(Dataset):
7
+ def __init__(self, tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors = False ):
8
+ self.tokenizer = tokenizer
9
+ self.max_token_len = max_token_len
10
+ self.distance_norm_factor = distance_norm_factor
11
+ self.sep_between_neighbors = sep_between_neighbors
12
+
13
+
14
+ def parse_spatial_context(self, pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill, pivot_dist_fill = 0):
15
+
16
+ sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token)
17
+ cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token)
18
+ mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
19
+ pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
20
+ max_token_len = self.max_token_len
21
+
22
+
23
+ # process pivot
24
+ pivot_name_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(pivot_name))
25
+ pivot_token_len = len(pivot_name_tokens)
26
+
27
+ pivot_lng = pivot_pos[0]
28
+ pivot_lat = pivot_pos[1]
29
+
30
+ # prepare entity mask
31
+ entity_mask_arr = []
32
+ rand_entity = np.random.uniform(size = len(neighbor_name_list) + 1) # random number for masking entities including neighbors and pivot
33
+ # True for mask, False for unmask
34
+
35
+ # check if pivot entity needs to be masked out, 15% prob. to be masked out
36
+ if rand_entity[0] < 0.15:
37
+ entity_mask_arr.extend([True] * pivot_token_len)
38
+ else:
39
+ entity_mask_arr.extend([False] * pivot_token_len)
40
+
41
+ # process neighbors
42
+ neighbor_token_list = []
43
+ neighbor_lng_list = []
44
+ neighbor_lat_list = []
45
+
46
+ # add separator between pivot and neighbor tokens
47
+ # a trick to avoid adding separator token after the class name (for class name encoding of margin-ranking loss)
48
+ if self.sep_between_neighbors and pivot_dist_fill==0:
49
+ neighbor_lng_list.append(spatial_dist_fill)
50
+ neighbor_lat_list.append(spatial_dist_fill)
51
+ neighbor_token_list.append(sep_token_id)
52
+
53
+ for neighbor_name, neighbor_geometry, rnd in zip(neighbor_name_list, neighbor_geometry_list, rand_entity[1:]):
54
+
55
+ if not neighbor_name[0].isalpha():
56
+ # only consider neighbors starting with letters
57
+ continue
58
+
59
+ neighbor_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(neighbor_name))
60
+ neighbor_token_len = len(neighbor_token)
61
+
62
+ # compute the relative distance from neighbor to pivot,
63
+ # normalize the relative distance by distance_norm_factor
64
+ # apply the calculated distance for all the subtokens of the neighbor
65
+ # neighbor_lng_list.extend([(neighbor_geometry[0]- pivot_lng)/self.distance_norm_factor] * neighbor_token_len)
66
+ # neighbor_lat_list.extend([(neighbor_geometry[1]- pivot_lat)/self.distance_norm_factor] * neighbor_token_len)
67
+
68
+ if 'coordinates' in neighbor_geometry: # to handle different json dict structures
69
+ neighbor_lng_list.extend([(neighbor_geometry['coordinates'][0]- pivot_lng)/self.distance_norm_factor] * neighbor_token_len)
70
+ neighbor_lat_list.extend([(neighbor_geometry['coordinates'][1]- pivot_lat)/self.distance_norm_factor] * neighbor_token_len)
71
+ neighbor_token_list.extend(neighbor_token)
72
+ else:
73
+ neighbor_lng_list.extend([(neighbor_geometry[0]- pivot_lng)/self.distance_norm_factor] * neighbor_token_len)
74
+ neighbor_lat_list.extend([(neighbor_geometry[1]- pivot_lat)/self.distance_norm_factor] * neighbor_token_len)
75
+ neighbor_token_list.extend(neighbor_token)
76
+
77
+ if self.sep_between_neighbors:
78
+ neighbor_lng_list.append(spatial_dist_fill)
79
+ neighbor_lat_list.append(spatial_dist_fill)
80
+ neighbor_token_list.append(sep_token_id)
81
+
82
+ entity_mask_arr.extend([False])
83
+
84
+
85
+ if rnd < 0.15:
86
+ # True: mask out, False: keep original token
87
+ entity_mask_arr.extend([True] * neighbor_token_len)
88
+ else:
89
+ entity_mask_arr.extend([False] * neighbor_token_len)
90
+
91
+
92
+ pseudo_sentence = pivot_name_tokens + neighbor_token_list
93
+ dist_lng_list = [pivot_dist_fill] * pivot_token_len + neighbor_lng_list
94
+ dist_lat_list = [pivot_dist_fill] * pivot_token_len + neighbor_lat_list
95
+
96
+
97
+ #including cls and sep
98
+ sent_len = len(pseudo_sentence)
99
+
100
+ max_token_len_middle = max_token_len -2 # 2 for CLS and SEP token
101
+
102
+ # padding and truncation
103
+ if sent_len > max_token_len_middle :
104
+ pseudo_sentence = [cls_token_id] + pseudo_sentence[:max_token_len_middle] + [sep_token_id]
105
+ dist_lat_list = [spatial_dist_fill] + dist_lat_list[:max_token_len_middle]+ [spatial_dist_fill]
106
+ dist_lng_list = [spatial_dist_fill] + dist_lng_list[:max_token_len_middle]+ [spatial_dist_fill]
107
+ attention_mask = [False] + [1] * max_token_len_middle + [False] # make sure SEP and CLS are not attended to
108
+ else:
109
+ pad_len = max_token_len_middle - sent_len
110
+ assert pad_len >= 0
111
+
112
+ pseudo_sentence = [cls_token_id] + pseudo_sentence + [sep_token_id] + [pad_token_id] * pad_len
113
+ dist_lat_list = [spatial_dist_fill] + dist_lat_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
114
+ dist_lng_list = [spatial_dist_fill] + dist_lng_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
115
+ attention_mask = [False] + [1] * sent_len + [0] * pad_len + [False]
116
+
117
+
118
+
119
+
120
+ norm_lng_list = np.array(dist_lng_list) # / 0.0001
121
+ norm_lat_list = np.array(dist_lat_list) # / 0.0001
122
+
123
+
124
+ # mask entity in the pseudo sentence
125
+ entity_mask_indices = np.where(entity_mask_arr)[0]
126
+ masked_entity_input = [mask_token_id if i in entity_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]
127
+
128
+
129
+ # mask token in the pseudo sentence
130
+ rand_token = np.random.uniform(size = len(pseudo_sentence))
131
+ # do not mask out cls and sep token. True: masked token, False: keep original token
132
+ token_mask_arr = (rand_token <0.15) & (np.array(pseudo_sentence) != cls_token_id) & (np.array(pseudo_sentence) != sep_token_id) & (np.array(pseudo_sentence) != pad_token_id)
133
+ token_mask_indices = np.where(token_mask_arr)[0]
134
+
135
+ masked_token_input = [mask_token_id if i in token_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]
136
+
137
+
138
+ # yield masked_token with 50% prob, masked_entity with 50% prob
139
+ if np.random.rand() > 0.5:
140
+ masked_input = torch.tensor(masked_entity_input)
141
+ else:
142
+ masked_input = torch.tensor(masked_token_input)
143
+
144
+ train_data = {}
145
+ train_data['pivot_name'] = pivot_name
146
+ train_data['pivot_token_len'] = pivot_token_len
147
+ train_data['masked_input'] = masked_input
148
+ train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence)))
149
+ train_data['attention_mask'] = torch.tensor(attention_mask)
150
+ train_data['norm_lng_list'] = torch.tensor(norm_lng_list).to(torch.float32)
151
+ train_data['norm_lat_list'] = torch.tensor(norm_lat_list).to(torch.float32)
152
+ train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence)
153
+
154
+ return train_data
155
+
156
+
157
+
158
+ def __len__(self):
159
+ raise NotImplementedError
160
+
161
+ def __getitem__(self, index):
162
+ raise NotImplementedError
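
A minimal sketch of calling SpatialDataset.parse_spatial_context as defined above, assuming it is run from models/spabert/datasets/ with torch and transformers installed; the tokenizer choice, toy place names, coordinates, and spatial_dist_fill value are illustrative assumptions.

```python
from transformers import BertTokenizer

from dataset_loader import SpatialDataset  # assumes datasets/ is on sys.path

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
dataset = SpatialDataset(tokenizer, max_token_len=64,
                         distance_norm_factor=0.0001, sep_between_neighbors=True)

# Pivot entity plus two neighbors; geometries are [longitude, latitude].
# spatial_dist_fill is chosen larger than the normalized neighbor offsets (~1100 here).
train_data = dataset.parse_spatial_context(
    pivot_name="Los Angeles",
    pivot_pos=[-118.24, 34.05],
    neighbor_name_list=["Pasadena", "Long Beach"],
    neighbor_geometry_list=[[-118.13, 34.15], [-118.19, 33.77]],
    spatial_dist_fill=2000,
)

# Token ids of the pseudo sentence with ~15% of entities or tokens masked,
# plus per-token normalized x/y offsets relative to the pivot.
print(train_data["masked_input"].shape)   # torch.Size([64])
print(train_data["norm_lng_list"].shape)  # torch.Size([64])
print(train_data["attention_mask"][:10])
```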
models/spabert/datasets/dataset_loader_ver2.py ADDED
@@ -0,0 +1,164 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ import pdb
5
+
6
+ class SpatialDataset(Dataset):
7
+ def __init__(self, tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors = False ):
8
+ self.tokenizer = tokenizer
9
+ self.max_token_len = max_token_len
10
+ self.distance_norm_factor = distance_norm_factor
11
+ self.sep_between_neighbors = sep_between_neighbors
12
+
13
+
14
+ def parse_spatial_context(self, pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill, pivot_dist_fill = 0):
15
+
16
+ sep_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.sep_token)
17
+ cls_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.cls_token)
18
+ #mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
19
+ pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
20
+ max_token_len = self.max_token_len
21
+
22
+
23
+ #print("Module reloaded and changes are reflected")
24
+ # process pivot
25
+ pivot_name_tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(pivot_name))
26
+ pivot_token_len = len(pivot_name_tokens)
27
+
28
+ pivot_lng = pivot_pos[0]
29
+ pivot_lat = pivot_pos[1]
30
+
31
+ # prepare entity mask
32
+ entity_mask_arr = []
33
+ rand_entity = np.random.uniform(size = len(neighbor_name_list) + 1) # random number for masking entities including neighbors and pivot
34
+ # True for mask, False for unmask
35
+
36
+ # check if pivot entity needs to be masked out, 15% prob. to be masked out
37
+ #if rand_entity[0] < 0.15:
38
+ # entity_mask_arr.extend([True] * pivot_token_len)
39
+ #else:
40
+ entity_mask_arr.extend([False] * pivot_token_len)
41
+
42
+ # process neighbors
43
+ neighbor_token_list = []
44
+ neighbor_lng_list = []
45
+ neighbor_lat_list = []
46
+
47
+ # add separator between pivot and neighbor tokens
48
+ # a trick to avoid adding separator token after the class name (for class name encoding of margin-ranking loss)
49
+ if self.sep_between_neighbors and pivot_dist_fill==0:
50
+ neighbor_lng_list.append(spatial_dist_fill)
51
+ neighbor_lat_list.append(spatial_dist_fill)
52
+ neighbor_token_list.append(sep_token_id)
53
+
54
+ for neighbor_name, neighbor_geometry, rnd in zip(neighbor_name_list, neighbor_geometry_list, rand_entity[1:]):
55
+
56
+ if not neighbor_name[0].isalpha():
57
+ # only consider neighbors starting with letters
58
+ continue
59
+
60
+ neighbor_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(neighbor_name))
61
+ neighbor_token_len = len(neighbor_token)
62
+
63
+ # compute the relative distance from neighbor to pivot,
64
+ # normalize the relative distance by distance_norm_factor
65
+ # apply the calculated distance for all the subtokens of the neighbor
66
+ # neighbor_lng_list.extend([(neighbor_geometry[0]- pivot_lng)/self.distance_norm_factor] * neighbor_token_len)
67
+ # neighbor_lat_list.extend([(neighbor_geometry[1]- pivot_lat)/self.distance_norm_factor] * neighbor_token_len)
68
+
69
+ if 'coordinates' in neighbor_geometry: # to handle different json dict structures
70
+ neighbor_lng_list.extend([(neighbor_geometry['coordinates'][0]- pivot_lng)/self.distance_norm_factor] * neighbor_token_len)
71
+ neighbor_lat_list.extend([(neighbor_geometry['coordinates'][1]- pivot_lat)/self.distance_norm_factor] * neighbor_token_len)
72
+ neighbor_token_list.extend(neighbor_token)
73
+ else:
74
+ neighbor_lng_list.extend([(neighbor_geometry[0]- pivot_lng)/self.distance_norm_factor] * neighbor_token_len)
75
+ neighbor_lat_list.extend([(neighbor_geometry[1]- pivot_lat)/self.distance_norm_factor] * neighbor_token_len)
76
+ neighbor_token_list.extend(neighbor_token)
77
+
78
+ if self.sep_between_neighbors:
79
+ neighbor_lng_list.append(spatial_dist_fill)
80
+ neighbor_lat_list.append(spatial_dist_fill)
81
+ neighbor_token_list.append(sep_token_id)
82
+
83
+ entity_mask_arr.extend([False])
84
+
85
+
86
+ #if rnd < 0.15:
87
+ # #True: mask out, False: Keey original token
88
+ # entity_mask_arr.extend([True] * neighbor_token_len)
89
+ #else:
90
+ entity_mask_arr.extend([False] * neighbor_token_len)
91
+
92
+
93
+ pseudo_sentence = pivot_name_tokens + neighbor_token_list
94
+ dist_lng_list = [pivot_dist_fill] * pivot_token_len + neighbor_lng_list
95
+ dist_lat_list = [pivot_dist_fill] * pivot_token_len + neighbor_lat_list
96
+
97
+
98
+ #including cls and sep
99
+ sent_len = len(pseudo_sentence)
100
+
101
+ max_token_len_middle = max_token_len -2 # 2 for CLS and SEP token
102
+
103
+ # padding and truncation
104
+ if sent_len > max_token_len_middle :
105
+ pseudo_sentence = [cls_token_id] + pseudo_sentence[:max_token_len_middle] + [sep_token_id]
106
+ dist_lat_list = [spatial_dist_fill] + dist_lat_list[:max_token_len_middle]+ [spatial_dist_fill]
107
+ dist_lng_list = [spatial_dist_fill] + dist_lng_list[:max_token_len_middle]+ [spatial_dist_fill]
108
+ attention_mask = [False] + [1] * max_token_len_middle + [False] # make sure SEP and CLS are not attented to
109
+ else:
110
+ pad_len = max_token_len_middle - sent_len
111
+ assert pad_len >= 0
112
+
113
+ pseudo_sentence = [cls_token_id] + pseudo_sentence + [sep_token_id] + [pad_token_id] * pad_len
114
+ dist_lat_list = [spatial_dist_fill] + dist_lat_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
115
+ dist_lng_list = [spatial_dist_fill] + dist_lng_list + [spatial_dist_fill] + [spatial_dist_fill] * pad_len
116
+ attention_mask = [False] + [1] * sent_len + [0] * pad_len + [False]
117
+
118
+
119
+
120
+
121
+ norm_lng_list = np.array(dist_lng_list) # / 0.0001
122
+ norm_lat_list = np.array(dist_lat_list) # / 0.0001
123
+
124
+
125
+ ## mask entity in the pseudo sentence
126
+ #entity_mask_indices = np.where(entity_mask_arr)[0]
127
+ #masked_entity_input = [mask_token_id if i in entity_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]
128
+ #
129
+ #
130
+ ## mask token in the pseudo sentence
131
+ #rand_token = np.random.uniform(size = len(pseudo_sentence))
132
+ ## do not mask out cls and sep token. True: masked tokens False: Keey original token
133
+ #token_mask_arr = (rand_token <0.15) & (np.array(pseudo_sentence) != cls_token_id) & (np.array(pseudo_sentence) != sep_token_id) & (np.array(pseudo_sentence) != pad_token_id)
134
+ #token_mask_indices = np.where(token_mask_arr)[0]
135
+ #
136
+ #masked_token_input = [mask_token_id if i in token_mask_indices else pseudo_sentence[i] for i in range(0, max_token_len)]
137
+ #
138
+ #
139
+ ## yield masked_token with 50% prob, masked_entity with 50% prob
140
+ #if np.random.rand() > 0.5:
141
+ # masked_input = torch.tensor(masked_entity_input)
142
+ #else:
143
+ # masked_input = torch.tensor(masked_token_input)
144
+ masked_input = torch.tensor(pseudo_sentence)
145
+
146
+ train_data = {}
147
+ train_data['pivot_name'] = pivot_name
148
+ train_data['pivot_token_len'] = pivot_token_len
149
+ train_data['masked_input'] = masked_input
150
+ train_data['sent_position_ids'] = torch.tensor(np.arange(0, len(pseudo_sentence)))
151
+ train_data['attention_mask'] = torch.tensor(attention_mask)
152
+ train_data['norm_lng_list'] = torch.tensor(norm_lng_list).to(torch.float32)
153
+ train_data['norm_lat_list'] = torch.tensor(norm_lat_list).to(torch.float32)
154
+ train_data['pseudo_sentence'] = torch.tensor(pseudo_sentence)
155
+
156
+ return train_data
157
+
158
+
159
+
160
+ def __len__(self):
161
+ raise NotImplementedError
162
+
163
+ def __getitem__(self, index):
164
+ raise NotImplementedError
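
This ver2 loader keeps the same interface but comments out the entity and token masking, so masked_input is simply the unmasked pseudo sentence. A small check under the same assumptions as the previous sketch:

```python
import torch
from transformers import BertTokenizer

from dataset_loader_ver2 import SpatialDataset  # the inference-oriented variant above

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
dataset = SpatialDataset(tokenizer, max_token_len=64,
                         distance_norm_factor=0.0001, sep_between_neighbors=True)

out = dataset.parse_spatial_context(
    "Los Angeles", [-118.24, 34.05],
    ["Pasadena", "Long Beach"],
    [[-118.13, 34.15], [-118.19, 33.77]],
    spatial_dist_fill=2000,
)

# No MLM or entity masking in this version: the "masked" input equals the pseudo sentence.
assert torch.equal(out["masked_input"], out["pseudo_sentence"])
```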
models/spabert/datasets/osm_sample_loader.py ADDED
@@ -0,0 +1,246 @@
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import json
5
+ import math
6
+
7
+ import torch
8
+ from transformers import RobertaTokenizer, BertTokenizer
9
+ from torch.utils.data import Dataset
10
+ #sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
11
+ sys.path.append('/content/drive/MyDrive/spaBERT/spabert/datasets')
12
+ from dataset_loader_ver2 import SpatialDataset
13
+
14
+ import pdb
15
+
16
+ class PbfMapDataset(SpatialDataset):
17
+ def __init__(self, data_file_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=10,
18
+ with_type = True, sep_between_neighbors = False, label_encoder = None, mode = None, num_neighbor_limit = None, random_remove_neighbor = 0.,type_key_str='class'):
19
+
20
+ if tokenizer is None:
21
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
22
+ else:
23
+ self.tokenizer = tokenizer
24
+
25
+ self.max_token_len = max_token_len
26
+ self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance
27
+ self.with_type = with_type
28
+ self.sep_between_neighbors = sep_between_neighbors
29
+ self.label_encoder = label_encoder
30
+ self.num_neighbor_limit = num_neighbor_limit
31
+ self.read_file(data_file_path, mode)
32
+ self.random_remove_neighbor = random_remove_neighbor
33
+ self.type_key_str = type_key_str # key name of the class type in the input data dictionary
34
+
35
+ super(PbfMapDataset, self).__init__(self.tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors )
36
+
37
+
38
+ def read_file(self, data_file_path, mode):
39
+
40
+ with open(data_file_path, 'r') as f:
41
+ data = f.readlines()
42
+
43
+ if mode == 'train':
44
+ data = data[0:int(len(data) * 0.8)]
45
+ elif mode == 'test':
46
+ data = data[int(len(data) * 0.8):]
47
+ elif mode is None: # use the full dataset (for mlm)
48
+ pass
49
+ else:
50
+ raise NotImplementedError
51
+
52
+ self.len_data = len(data) # updated data length
53
+ self.data = data
54
+
55
+ def load_data(self, index):
56
+
57
+ spatial_dist_fill = self.spatial_dist_fill
58
+ line = self.data[index] # take one line from the input data according to the index
59
+
60
+ line_data_dict = json.loads(line)
61
+
62
+ # process pivot
63
+ pivot_name = line_data_dict['info']['name']
64
+ pivot_pos = line_data_dict['info']['geometry']['coordinates']
65
+
66
+
67
+ neighbor_info = line_data_dict['neighbor_info']
68
+ neighbor_name_list = neighbor_info['name_list']
69
+ neighbor_geometry_list = neighbor_info['geometry_list']
70
+
71
+ if self.random_remove_neighbor != 0:
72
+ num_neighbors = len(neighbor_name_list)
73
+ rand_neighbor = np.random.uniform(size = num_neighbors)
74
+
75
+ neighbor_keep_arr = (rand_neighbor >= self.random_remove_neighbor) # select the neighbors to keep
76
+ neighbor_keep_arr = np.where(neighbor_keep_arr)[0]
77
+
78
+ new_neighbor_name_list, new_neighbor_geometry_list = [],[]
79
+ for i in range(0, num_neighbors):
80
+ if i in neighbor_keep_arr:
81
+ new_neighbor_name_list.append(neighbor_name_list[i])
82
+ new_neighbor_geometry_list.append(neighbor_geometry_list[i])
83
+
84
+ neighbor_name_list = new_neighbor_name_list
85
+ neighbor_geometry_list = new_neighbor_geometry_list
86
+
87
+ if self.num_neighbor_limit is not None:
88
+ neighbor_name_list = neighbor_name_list[0:self.num_neighbor_limit]
89
+ neighbor_geometry_list = neighbor_geometry_list[0:self.num_neighbor_limit]
90
+
91
+
92
+ train_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill )
93
+
94
+ if self.with_type:
95
+ pivot_type = line_data_dict['info'][self.type_key_str]
96
+ train_data['pivot_type'] = torch.tensor(self.label_encoder.transform([pivot_type])[0]) # scalar, label_id
97
+
98
+ if 'ogc_fid' in line_data_dict['info']:
99
+ train_data['ogc_fid'] = line_data_dict['info']['ogc_fid']
100
+
101
+ return train_data
102
+
103
+ def __len__(self):
104
+ return self.len_data
105
+
106
+ def __getitem__(self, index):
107
+ return self.load_data(index)
108
+
109
+
110
+
111
+ class PbfMapDatasetMarginRanking(SpatialDataset):
112
+ def __init__(self, data_file_path, type_list = None, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=10,
113
+ sep_between_neighbors = False, mode = None, num_neighbor_limit = None, random_remove_neighbor = 0., type_key_str='class'):
114
+
115
+ if tokenizer is None:
116
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
117
+ else:
118
+ self.tokenizer = tokenizer
119
+
120
+ self.type_list = type_list
121
+ self.type_key_str = type_key_str # key name of the class type in the input data dictionary
122
+ self.max_token_len = max_token_len
123
+ self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance
124
+ self.sep_between_neighbors = sep_between_neighbors
125
+ # self.label_encoder = label_encoder
126
+ self.num_neighbor_limit = num_neighbor_limit
127
+ self.read_file(data_file_path, mode)
128
+ self.random_remove_neighbor = random_remove_neighbor
129
+ self.mode = mode
130
+
131
+
132
+ super(PbfMapDatasetMarginRanking, self).__init__(self.tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors )
133
+
134
+
135
+ def read_file(self, data_file_path, mode):
136
+
137
+ with open(data_file_path, 'r') as f:
138
+ data = f.readlines()
139
+
140
+ if mode == 'train':
141
+ data = data[0:int(len(data) * 0.8)]
142
+ elif mode == 'test':
143
+ data = data[int(len(data) * 0.8):]
144
+ self.all_types_data = self.prepare_all_types_data()
145
+ elif mode is None: # use the full dataset (for mlm)
146
+ pass
147
+ else:
148
+ raise NotImplementedError
149
+
150
+ self.len_data = len(data) # updated data length
151
+ self.data = data
152
+
153
+ def prepare_all_types_data(self):
154
+ type_list = self.type_list
155
+ spatial_dist_fill = self.spatial_dist_fill
156
+ type_data_dict = dict()
157
+ for type_name in type_list:
158
+ type_pos = [None, None] # use filler values
159
+ type_data = self.parse_spatial_context(type_name, type_pos, pivot_dist_fill = 0.,
160
+ neighbor_name_list = [], neighbor_geometry_list=[], spatial_dist_fill= spatial_dist_fill)
161
+ type_data_dict[type_name] = type_data
162
+
163
+ return type_data_dict
164
+
165
+ def load_data(self, index):
166
+
167
+ spatial_dist_fill = self.spatial_dist_fill
168
+ line = self.data[index] # take one line from the input data according to the index
169
+
170
+ line_data_dict = json.loads(line)
171
+
172
+ # process pivot
173
+ pivot_name = line_data_dict['info']['name']
174
+ pivot_pos = line_data_dict['info']['geometry']['coordinates']
175
+
176
+
177
+ neighbor_info = line_data_dict['neighbor_info']
178
+ neighbor_name_list = neighbor_info['name_list']
179
+ neighbor_geometry_list = neighbor_info['geometry_list']
180
+
181
+ if self.random_remove_neighbor != 0:
182
+ num_neighbors = len(neighbor_name_list)
183
+ rand_neighbor = np.random.uniform(size = num_neighbors)
184
+
185
+ neighbor_keep_arr = (rand_neighbor >= self.random_remove_neighbor) # select the neighbors to keep
186
+ neighbor_keep_arr = np.where(neighbor_keep_arr)[0]
187
+
188
+ new_neighbor_name_list, new_neighbor_geometry_list = [],[]
189
+ for i in range(0, num_neighbors):
190
+ if i in neighbor_keep_arr:
191
+ new_neighbor_name_list.append(neighbor_name_list[i])
192
+ new_neighbor_geometry_list.append(neighbor_geometry_list[i])
193
+
194
+ neighbor_name_list = new_neighbor_name_list
195
+ neighbor_geometry_list = new_neighbor_geometry_list
196
+
197
+ if self.num_neighbor_limit is not None:
198
+ neighbor_name_list = neighbor_name_list[0:self.num_neighbor_limit]
199
+ neighbor_geometry_list = neighbor_geometry_list[0:self.num_neighbor_limit]
200
+
201
+
202
+ train_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill )
203
+
204
+ if 'ogc_fid' in line_data_dict['info']:
205
+ train_data['ogc_fid'] = line_data_dict['info']['ogc_fid']
206
+
207
+ # train_data['pivot_type'] = torch.tensor(self.label_encoder.transform([pivot_type])[0]) # scalar, label_id
208
+
209
+ pivot_type = line_data_dict['info'][self.type_key_str]
210
+ train_data['pivot_type'] = pivot_type
211
+
212
+ if self.mode == 'train':
213
+ # positive class
214
+ postive_name = pivot_type # class type string as input to tokenizer
215
+ positive_pos = [None, None] # use filler values
216
+ postive_type_data = self.parse_spatial_context(postive_name, positive_pos, pivot_dist_fill = 0.,
217
+ neighbor_name_list = [], neighbor_geometry_list=[], spatial_dist_fill= spatial_dist_fill)
218
+ train_data['positive_type_data'] = postive_type_data
219
+
220
+
221
+ # negative class
222
+ other_type_list = self.type_list.copy()
223
+ other_type_list.remove(pivot_type)
224
+ other_type = np.random.choice(other_type_list)
225
+ negative_name = other_type
226
+ negative_pos = [None, None] # use filler values
227
+ negative_type_data = self.parse_spatial_context(negative_name, negative_pos, pivot_dist_fill = 0.,
228
+ neighbor_name_list = [], neighbor_geometry_list=[], spatial_dist_fill= spatial_dist_fill)
229
+ train_data['negative_type_data'] = negative_type_data
230
+
231
+ elif self.mode == 'test':
232
+ # return data for all class types in type_list
233
+ train_data['all_types_data'] = self.all_types_data
234
+
235
+ else:
236
+ raise NotImplementedError
237
+
238
+ return train_data
239
+
240
+ def __len__(self):
241
+ return self.len_data
242
+
243
+ def __getitem__(self, index):
244
+ return self.load_data(index)
245
+
246
+
models/spabert/datasets/usgs_os_sample_loader.py ADDED
@@ -0,0 +1,71 @@
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+ import json
5
+ import math
6
+ from sklearn.preprocessing import LabelBinarizer, LabelEncoder
7
+
8
+ import torch
9
+ from transformers import RobertaTokenizer, BertTokenizer
10
+ from torch.utils.data import Dataset
11
+ sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
12
+ from dataset_loader import SpatialDataset
13
+
14
+ import pdb
15
+
16
+
17
+ class USGS_MapDataset(SpatialDataset):
18
+ def __init__(self, data_file_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 1, spatial_dist_fill=100, sep_between_neighbors = False):
19
+
20
+ if tokenizer is None:
21
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
22
+ else:
23
+ self.tokenizer = tokenizer
24
+
25
+ self.max_token_len = max_token_len
26
+ self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance
27
+ self.sep_between_neighbors = sep_between_neighbors
28
+ self.read_file(data_file_path)
29
+
30
+
31
+ super(USGS_MapDataset, self).__init__(self.tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors )
32
+
33
+ def read_file(self, data_file_path):
34
+
35
+ with open(data_file_path, 'r') as f:
36
+ data = f.readlines()
37
+
38
+ len_data = len(data)
39
+ self.len_data = len_data
40
+ self.data = data
41
+
42
+
43
+ def load_data(self, index):
44
+
45
+ spatial_dist_fill = self.spatial_dist_fill
46
+ line = self.data[index] # take one line from the input data according to the index
47
+
48
+ line_data_dict = json.loads(line)
49
+
50
+ # process pivot
51
+ pivot_name = line_data_dict['info']['name']
52
+ pivot_pos = line_data_dict['info']['geometry']
53
+
54
+ neighbor_info = line_data_dict['neighbor_info']
55
+ neighbor_name_list = neighbor_info['name_list']
56
+ neighbor_geometry_list = neighbor_info['geometry_list']
57
+
58
+ parsed_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill )
59
+
60
+ return parsed_data
61
+
62
+
63
+
64
+ def __len__(self):
65
+ return self.len_data
66
+
67
+ def __getitem__(self, index):
68
+ return self.load_data(index)
69
+
70
+
71
+
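
Unlike the OSM loader, USGS_MapDataset reads the pivot geometry directly as an [x, y] pair (there is no 'coordinates' key) and defaults to distance_norm_factor=1, i.e. raw map-coordinate offsets. A small sketch with an illustrative record; the names, coordinates, and file name are assumptions.

```python
import json

from usgs_os_sample_loader import USGS_MapDataset  # assumes datasets/ is on sys.path

record = {
    "info": {"name": "Example Ridge", "geometry": [120.0, 45.0]},
    "neighbor_info": {"name_list": ["Example Creek"],
                      "geometry_list": [[121.0, 44.0]]},
}
with open("toy_usgs_sample.json", "w") as f:  # illustrative file name
    f.write(json.dumps(record) + "\n")

# spatial_dist_fill defaults to 100, larger than the coordinate offsets of 1.0 used here.
dataset = USGS_MapDataset("toy_usgs_sample.json", max_token_len=64,
                          sep_between_neighbors=True)
print(dataset[0]["pseudo_sentence"].shape)  # torch.Size([64])
```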
models/spabert/datasets/wikidata_sample_loader.py ADDED
@@ -0,0 +1,127 @@
1
+ import sys
2
+ import numpy as np
3
+ import json
4
+ import math
5
+
6
+ import torch
7
+ from transformers import RobertaTokenizer, BertTokenizer
8
+ from torch.utils.data import Dataset
9
+ sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
10
+ from dataset_loader import SpatialDataset
11
+
12
+ import pdb
13
+
14
+ '''Prepare candidate list given randomly sampled data and append to data_list'''
15
+ class Wikidata_Random_Dataset(SpatialDataset):
16
+ def __init__(self, data_file_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=100, sep_between_neighbors = False):
17
+ if tokenizer is None:
18
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
19
+ else:
20
+ self.tokenizer = tokenizer
21
+
22
+ self.max_token_len = max_token_len
23
+ self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance
24
+ self.sep_between_neighbors = sep_between_neighbors
25
+ self.read_file(data_file_path)
26
+
27
+
28
+ super(Wikidata_Random_Dataset, self).__init__(self.tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors )
29
+
30
+ def read_file(self, data_file_path):
31
+
32
+ with open(data_file_path, 'r') as f:
33
+ data = f.readlines()
34
+
35
+ len_data = len(data)
36
+ self.len_data = len_data
37
+ self.data = data
38
+
39
+ def load_data(self, index):
40
+
41
+ spatial_dist_fill = self.spatial_dist_fill
42
+ line = self.data[index] # take one line from the input data according to the index
43
+
44
+ line_data_dict = json.loads(line)
45
+
46
+ # process pivot
47
+ pivot_name = line_data_dict['info']['name']
48
+ pivot_pos = line_data_dict['info']['geometry']['coordinates']
49
+ pivot_uri = line_data_dict['info']['uri']
50
+
51
+ neighbor_info = line_data_dict['neighbor_info']
52
+ neighbor_name_list = neighbor_info['name_list']
53
+ neighbor_geometry_list = neighbor_info['geometry_list']
54
+
55
+ parsed_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill )
56
+ parsed_data['uri'] = pivot_uri
57
+ parsed_data['description'] = None # placeholder
58
+
59
+ return parsed_data
60
+
61
+ def __len__(self):
62
+ return self.len_data
63
+
64
+ def __getitem__(self, index):
65
+ return self.load_data(index)
66
+
67
+
68
+
69
+ '''Prepare candidate list for each phrase and append to data_list'''
70
+
71
+ class Wikidata_Geocoord_Dataset(SpatialDataset):
72
+
73
+ #DEFAULT_CONFIG_CLS = SpatialBertConfig
74
+ def __init__(self, data_file_path, tokenizer=None, max_token_len = 512, distance_norm_factor = 0.0001, spatial_dist_fill=100, sep_between_neighbors = False):
75
+ if tokenizer is None:
76
+ self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
77
+ else:
78
+ self.tokenizer = tokenizer
79
+
80
+ self.max_token_len = max_token_len
81
+ self.spatial_dist_fill = spatial_dist_fill # should be normalized distance fill, larger than all normalized neighbor distance
82
+ self.sep_between_neighbors = sep_between_neighbors
83
+ self.read_file(data_file_path)
84
+
85
+
86
+ super(Wikidata_Geocoord_Dataset, self).__init__(self.tokenizer , max_token_len , distance_norm_factor, sep_between_neighbors )
87
+
88
+ def read_file(self, data_file_path):
89
+
90
+ with open(data_file_path, 'r') as f:
91
+ data = f.readlines()
92
+
93
+ len_data = len(data)
94
+ self.len_data = len_data
95
+ self.data = data
96
+
97
+ def load_data(self, index):
98
+
99
+ spatial_dist_fill = self.spatial_dist_fill
100
+ line = self.data[index] # take one line from the input data according to the index
101
+
102
+ line_data = json.loads(line)
103
+ parsed_data_list = []
104
+
105
+ for line_data_dict in line_data:
106
+ # process pivot
107
+ pivot_name = line_data_dict['info']['name']
108
+ pivot_pos = line_data_dict['info']['geometry']['coordinates']
109
+ pivot_uri = line_data_dict['info']['uri']
110
+
111
+ neighbor_info = line_data_dict['neighbor_info']
112
+ neighbor_name_list = neighbor_info['name_list']
113
+ neighbor_geometry_list = neighbor_info['geometry_list']
114
+
115
+ parsed_data = self.parse_spatial_context(pivot_name, pivot_pos, neighbor_name_list, neighbor_geometry_list, spatial_dist_fill )
116
+ parsed_data['uri'] = pivot_uri
117
+ parsed_data['description'] = None # placeholder
118
+ parsed_data_list.append(parsed_data)
119
+
120
+ return parsed_data_list
121
+
122
+
123
+ def __len__(self):
124
+ return self.len_data
125
+
126
+ def __getitem__(self, index):
127
+ return self.load_data(index)
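A minimal usage sketch for the two loaders above; the data file names are hypothetical and the sys.path entry replaces the hard-coded /home/zekun/... path with an assumed repo-relative one so that dataset_loader resolves.

import sys
sys.path.append('models/spabert/datasets')  # assumed repo-relative path

from transformers import BertTokenizer
from wikidata_sample_loader import Wikidata_Random_Dataset, Wikidata_Geocoord_Dataset

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# One parsed dict per input line: pivot + neighbors, plus 'uri' and a 'description' placeholder.
random_ds = Wikidata_Random_Dataset(
    data_file_path='wikidata_random_sample.json',  # hypothetical JSON-lines file
    tokenizer=tokenizer,
    max_token_len=512,
    distance_norm_factor=0.0001,
    spatial_dist_fill=100,
    sep_between_neighbors=False,
)

# One *list* of parsed dicts per input line (all candidate entities for one surface form).
geocoord_ds = Wikidata_Geocoord_Dataset('wikidata_candidates.json', tokenizer=tokenizer)

print(len(random_ds))        # number of lines in the file
print(type(random_ds[0]))    # dict
print(type(geocoord_ds[0]))  # list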
models/spabert/experiments/__init__.py ADDED
File without changes
models/spabert/experiments/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (155 Bytes). View file
 
models/spabert/experiments/entity_matching/__init__.py ADDED
File without changes
models/spabert/experiments/entity_matching/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (171 Bytes). View file
 
models/spabert/experiments/entity_matching/data_processing/__init__.py ADDED
File without changes
models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes). View file
 
models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (204 Bytes). View file
 
models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-310.pyc ADDED
Binary file (5.97 kB). View file
 
models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-311.pyc ADDED
Binary file (7.83 kB). View file
 
models/spabert/experiments/entity_matching/data_processing/get_namelist.py ADDED
@@ -0,0 +1,95 @@
1
+ import json
2
+ import os
3
+
4
+ def get_name_list_osm(ref_paths):
5
+ name_list = []
6
+
7
+ for json_path in ref_paths:
8
+ with open(json_path, 'r') as f:
9
+ data = f.readlines()
10
+ for line in data:
11
+ record = json.loads(line)
12
+ name = record['name']
13
+ name_list.append(name)
14
+
15
+ name_list = sorted(name_list)
16
+ return name_list
17
+
18
+ # deprecated
19
+ def get_name_list_usgs_od(ref_paths):
20
+ name_list = []
21
+
22
+ for json_path in ref_paths:
23
+ with open(json_path, 'r') as f:
24
+ annot_dict = json.load(f)
25
+ for key, place in annot_dict.items():
26
+ place_name = ''
27
+ for idx in range(1, len(place)+1):
28
+ try:
29
+ place_name += place[str(idx)]['text_label']
30
+ place_name += ' ' # separate words with spaces
31
+
32
+ except Exception as e:
33
+ print(place)
34
+ place_name = place_name[:-1] # remove last space
35
+
36
+ name_list.append(place_name)
37
+
38
+ name_list = sorted(name_list)
39
+ return name_list
40
+
41
+ def get_name_list_usgs_od_per_map(ref_paths):
42
+ all_name_list_dict = dict()
43
+
44
+ for json_path in ref_paths:
45
+ map_name = os.path.basename(json_path).split('.json')[0]
46
+
47
+ with open(json_path, 'r') as f:
48
+ annot_dict = json.load(f)
49
+
50
+ map_name_list = []
51
+ for key, place in annot_dict.items():
52
+ place_name = ''
53
+ for idx in range(1, len(place)+1):
54
+ try:
55
+ place_name += place[str(idx)]['text_label']
56
+ place_name += ' ' # separate words with spaces
57
+
58
+ except Exception as e:
59
+ print(place)
60
+ place_name = place_name[:-1] # remove last space
61
+
62
+ map_name_list.append(place_name)
63
+ all_name_list_dict[map_name] = sorted(map_name_list)
64
+
65
+ return all_name_list_dict
66
+
67
+
68
+ def get_name_list_gb1900(ref_path):
69
+ name_list = []
70
+
71
+ with open(ref_path, 'r',encoding='utf-16') as f:
72
+ data = f.readlines()
73
+
74
+
75
+ for line in data[1:]: # skip the header
76
+ try:
77
+ line = line.split(',')
78
+ text = line[1]
79
+ lat = float(line[-3])
80
+ lng = float(line[-2])
81
+ semantic_type = line[-1]
82
+
83
+ name_list.append(text)
84
+ except:
85
+ print(line)
86
+
87
+ name_list = sorted(name_list)
88
+
89
+ return name_list
90
+
91
+
92
+ if __name__ == '__main__':
93
+ #name_list = get_name_list_usgs_od(['labGISReport-master/output/USGS-15-CA-brawley-e1957-s1957-p1961.json',
94
+ #'labGISReport-master/output/USGS-15-CA-capesanmartin-e1921-s1917.json'])
95
+ name_list = get_name_list_gb1900('data/GB1900_gazetteer_abridged_july_2018/gb1900_abridged.csv')
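A short sketch of how the per-map helper above is typically consumed; the glob pattern mirrors the paths used in run_linking_query.py and is an assumption about where the annotation JSONs live.

import glob
from get_namelist import get_name_list_usgs_od_per_map

# Per-map annotation files (hypothetical location, same pattern as in run_linking_query.py).
usgs_annot_paths = glob.glob('data/labGISReport-master/output/USGS*.json')

name_list_dict = get_name_list_usgs_od_per_map(usgs_annot_paths)
for map_name, names in name_list_dict.items():
    print(map_name, len(names), names[:3])  # sorted toponyms recovered from each map sheet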
models/spabert/experiments/entity_matching/data_processing/request_wrapper.py ADDED
@@ -0,0 +1,186 @@
1
+ import requests
2
+ import pdb
3
+ import time
4
+
5
+ # for linkedgeodata: #http://linkedgeodata.org/sparql
6
+
7
+ class RequestWrapper:
8
+ def __init__(self, baseuri = "https://query.wikidata.org/sparql"):
9
+
10
+ self.baseuri = baseuri
11
+
12
+ def response_handler(self, response, query):
13
+ if response.status_code == requests.codes.ok:
14
+ ret_json = response.json()['results']['bindings']
15
+ elif response.status_code == 500:
16
+ ret_json = []
17
+ #print(q_id)
18
+ print('Internal Error happened. Set ret_json to be empty list')
19
+
20
+ elif response.status_code == 429:
21
+
22
+ print(response.status_code)
23
+ print(response.text)
24
+ retry_seconds = int(response.text.split('Too Many Requests - Please retry in ')[1].split(' seconds')[0])
25
+ print('rerun in %d seconds' %retry_seconds)
26
+ time.sleep(retry_seconds + 1)
27
+
28
+ response = requests.get(self.baseuri, params = {'format':'json', 'query':query})
29
+ ret_json = response.json()['results']['bindings']
30
+ #print(ret_json)
31
+ print('resumed and succeeded')
32
+
33
+ else:
34
+ print(response.status_code, response.text)
35
+ exit(-1)
36
+
37
+ return ret_json
38
+
39
+ '''Search for wikidata entities given the name string'''
40
+ def wikidata_query (self, name_str):
41
+
42
+ query = """
43
+ PREFIX wd: <http://www.wikidata.org/entity/>
44
+ PREFIX wds: <http://www.wikidata.org/entity/statement/>
45
+ PREFIX wdv: <http://www.wikidata.org/value/>
46
+ PREFIX wdt: <http://www.wikidata.org/prop/direct/>
47
+ PREFIX wikibase: <http://wikiba.se/ontology#>
48
+ PREFIX p: <http://www.wikidata.org/prop/>
49
+ PREFIX ps: <http://www.wikidata.org/prop/statement/>
50
+ PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
51
+ PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
52
+ PREFIX bd: <http://www.bigdata.com/rdf#>
53
+
54
+ SELECT ?item ?coordinates ?itemDescription WHERE {
55
+ ?item rdfs:label \"%s\"@en;
56
+ wdt:P625 ?coordinates .
57
+ SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
58
+ }
59
+ """%(name_str)
60
+
61
+ response = requests.get(self.baseuri, params = {'format':'json', 'query':query})
62
+
63
+
64
+ ret_json = self.response_handler(response, query)
65
+
66
+ return ret_json
67
+
68
+
69
+ '''Search for wikidata entities given the name string'''
70
+ def wikidata_query_withinstate (self, name_str, state_id = 'Q99'):
71
+
72
+
73
+ query = """
74
+ PREFIX wd: <http://www.wikidata.org/entity/>
75
+ PREFIX wds: <http://www.wikidata.org/entity/statement/>
76
+ PREFIX wdv: <http://www.wikidata.org/value/>
77
+ PREFIX wdt: <http://www.wikidata.org/prop/direct/>
78
+ PREFIX wikibase: <http://wikiba.se/ontology#>
79
+ PREFIX p: <http://www.wikidata.org/prop/>
80
+ PREFIX ps: <http://www.wikidata.org/prop/statement/>
81
+ PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
82
+ PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
83
+ PREFIX bd: <http://www.bigdata.com/rdf#>
84
+
85
+ SELECT ?item ?coordinates ?itemDescription WHERE {
86
+ ?item rdfs:label \"%s\"@en;
87
+ wdt:P625 ?coordinates ;
88
+ wdt:P131+ wd:%s;
89
+ SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
90
+ }
91
+ """%(name_str, state_id)
92
+
93
+ #print(query)
94
+
95
+ response = requests.get(self.baseuri, params = {'format':'json', 'query':query})
96
+
97
+ ret_json = self.response_handler(response, query)
98
+
99
+ return ret_json
100
+
101
+
102
+ '''Search for nearby wikidata entities given the entity id'''
103
+ def wikidata_nearby_query (self, q_id):
104
+
105
+ query = """
106
+ PREFIX wd: <http://www.wikidata.org/entity/>
107
+ PREFIX wds: <http://www.wikidata.org/entity/statement/>
108
+ PREFIX wdv: <http://www.wikidata.org/value/>
109
+ PREFIX wdt: <http://www.wikidata.org/prop/direct/>
110
+ PREFIX wikibase: <http://wikiba.se/ontology#>
111
+ PREFIX p: <http://www.wikidata.org/prop/>
112
+ PREFIX ps: <http://www.wikidata.org/prop/statement/>
113
+ PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
114
+ PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
115
+ PREFIX bd: <http://www.bigdata.com/rdf#>
116
+
117
+ SELECT ?place ?placeLabel ?location ?instanceLabel ?placeDescription
118
+ WHERE
119
+ {
120
+ wd:%s wdt:P625 ?loc .
121
+ SERVICE wikibase:around {
122
+ ?place wdt:P625 ?location .
123
+ bd:serviceParam wikibase:center ?loc .
124
+ bd:serviceParam wikibase:radius "5" .
125
+ }
126
+ OPTIONAL { ?place wdt:P31 ?instance }
127
+ SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
128
+ BIND(geof:distance(?loc, ?location) as ?dist)
129
+ } ORDER BY ?dist
130
+ LIMIT 200
131
+ """%(q_id)
132
+ # initially 2km
133
+
134
+ #pdb.set_trace()
135
+
136
+ response = requests.get(self.baseuri, params = {'format':'json', 'query':query})
137
+
138
+
139
+ ret_json = self.response_handler(response, query)
140
+
141
+
142
+
143
+ return ret_json
144
+
145
+
146
+
147
+
148
+ def linkedgeodata_query (self, name_str):
149
+
150
+ query = """
151
+
152
+ Prefix lgdo: <http://linkedgeodata.org/ontology/>
153
+ Prefix geom: <http://geovocab.org/geometry#>
154
+ Prefix ogc: <http://www.opengis.net/ont/geosparql#>
155
+ Prefix owl: <http://www.w3.org/2002/07/owl#>
156
+ Prefix wgs84_pos: <http://www.w3.org/2003/01/geo/wgs84_pos#>
157
+ Prefix owl: <http://www.w3.org/2002/07/owl#>
158
+ Prefix gn: <http://www.geonames.org/ontology#>
159
+
160
+ Select ?s, ?lat, ?long {
161
+ {?s rdfs:label \"%s\";
162
+ wgs84_pos:lat ?lat ;
163
+ wgs84_pos:long ?long;
164
+ }
165
+ }
166
+ """%(name_str)
167
+
168
+
169
+
170
+ response = requests.get(self.baseuri, params = {'format':'json', 'query':query})
171
+
172
+
173
+ ret_json = self.response_handler(response, query)
174
+
175
+
176
+ return ret_json
177
+
178
+
179
+
180
+ if __name__ == '__main__':
181
+ request_wrapper_wikidata = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
182
+ #print(request_wrapper_wikidata.wikidata_nearby_query('Q370771'))
183
+ #print(request_wrapper_wikidata.wikidata_query_withinstate('San Bernardino'))
184
+
185
+ # not working now
186
+ print(request_wrapper_wikidata.linkedgeodata_query('San Bernardino'))
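A minimal sketch of how RequestWrapper is used by the scripts below; 'Brawley' is only an example label, and the printed fields follow the SELECT clauses above (itemDescription can be missing, so it is read defensively).

from request_wrapper import RequestWrapper

wrapper = RequestWrapper(baseuri='https://query.wikidata.org/sparql')

# Candidate entities whose English label matches the toponym exactly; retries on
# HTTP 429 are handled inside response_handler().
candidates = wrapper.wikidata_query('Brawley')
for binding in candidates:
    q_id = binding['item']['value'].split('/')[-1]         # Wikidata Q-identifier
    wkt = binding['coordinates']['value']                   # WKT literal, e.g. 'Point(lon lat)'
    desc = binding.get('itemDescription', {}).get('value')  # may be absent
    print(q_id, wkt, desc)

# Entities within the 5 (km) radius used by wikidata_nearby_query(), ordered by distance.
if candidates:
    neighbors = wrapper.wikidata_nearby_query(q_id)
    print(len(neighbors))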
models/spabert/experiments/entity_matching/data_processing/run_linking_query.py ADDED
@@ -0,0 +1,143 @@
1
+ #from query_wrapper import QueryWrapper
2
+ from request_wrapper import RequestWrapper
3
+ from get_namelist import *
4
+ import glob
5
+ import os
6
+ import json
7
+ import time
8
+
9
+ DATASET_OPTIONS = ['OSM', 'OS', 'USGS', 'GB1900']
10
+ KB_OPTIONS = ['wikidata', 'linkedgeodata']
11
+
12
+ DATASET = 'USGS'
13
+ KB = 'wikidata'
14
+ OVERWRITE = True
15
+ WITHIN_CA = True
16
+
17
+ assert DATASET in DATASET_OPTIONS
18
+ assert KB in KB_OPTIONS
19
+
20
+
21
+ def process_one_namelist(sparql_wrapper, namelist, out_path):
22
+
23
+ if OVERWRITE:
24
+ # flush the file if it's been written
25
+ with open(out_path, 'w') as f:
26
+ f.write('')
27
+
28
+
29
+ for name in namelist:
30
+ name = name.replace('"', '')
31
+ name = name.strip("'")
32
+ if len(name) == 0:
33
+ continue
34
+ print(name)
35
+ mydict = dict()
36
+
37
+ if KB == 'wikidata':
38
+ if WITHIN_CA:
39
+ mydict[name] = sparql_wrapper.wikidata_query_withinstate(name)
40
+ else:
41
+ mydict[name] = sparql_wrapper.wikidata_query(name)
42
+
43
+ elif KB == 'linkedgeodata':
44
+ mydict[name] = sparql_wrapper.linkedgeodata_query(name)
45
+ else:
46
+ raise NotImplementedError
47
+
48
+ line = json.dumps(mydict)
49
+
50
+ with open(out_path, 'a') as f:
51
+ f.write(line)
52
+ f.write('\n')
53
+ time.sleep(1)
54
+
55
+
56
+ def process_namelist_dict(sparql_wrapper, namelist_dict, out_dir):
57
+ i = 0
58
+ for map_name, namelist in namelist_dict.items():
59
+ # if i <=5:
60
+ # i += 1
61
+ # continue
62
+
63
+ print('processing %s' %map_name)
64
+
65
+ if WITHIN_CA:
66
+ out_path = os.path.join(out_dir, KB + '_' + map_name + '.json')
67
+ else:
68
+ out_path = os.path.join(out_dir, KB + '_ca_' + map_name + '.json')
69
+
70
+ process_one_namelist(sparql_wrapper, namelist, out_path)
71
+ i+=1
72
+
73
+
74
+ if KB == 'linkedgeodata':
75
+ sparql_wrapper = RequestWrapper(baseuri = 'http://linkedgeodata.org/sparql')
76
+ elif KB == 'wikidata':
77
+ sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
78
+ else:
79
+ raise NotImplementedError
80
+
81
+
82
+
83
+ if DATASET == 'OSM':
84
+ osm_dir = '../surface_form/data_sample_london/data_osm/'
85
+ osm_paths = glob.glob(os.path.join(osm_dir, 'embedding*.json'))
86
+
87
+ out_path = 'outputs/'+KB+'_linking.json'
88
+ namelist = get_name_list_osm(osm_paths)
89
+
90
+ print('# files', len(osm_paths))
91
+
92
+ process_one_namelist(sparql_wrapper, namelist, out_path)
93
+
94
+
95
+ elif DATASET == 'OS':
96
+ histmap_dir = 'data/labGISReport-master/output/'
97
+ file_paths = glob.glob(os.path.join(histmap_dir, '10*.json'))
98
+
99
+ out_path = 'outputs/'+KB+'_os_linking_descript.json'
100
+ namelist = get_name_list_usgs_od(file_paths)
101
+
102
+ print('# files',len(file_paths))
103
+
104
+
105
+ process_one_namelist(sparql_wrapper, namelist, out_path)
106
+
107
+ elif DATASET == 'USGS':
108
+ histmap_dir = 'data/labGISReport-master/output/'
109
+ file_paths = glob.glob(os.path.join(histmap_dir, 'USGS*.json'))
110
+
111
+ if WITHIN_CA:
112
+ out_dir = 'outputs/' + KB +'_ca'
113
+ else:
114
+ out_dir = 'outputs/' + KB
115
+ namelist_dict = get_name_list_usgs_od_per_map(file_paths)
116
+
117
+ if not os.path.isdir(out_dir):
118
+ os.makedirs(out_dir)
119
+
120
+ print('# files',len(file_paths))
121
+
122
+ process_namelist_dict(sparql_wrapper, namelist_dict, out_dir)
123
+
124
+ elif DATASET == 'GB1900':
125
+
126
+ file_path = 'data/GB1900_gazetteer_abridged_july_2018/gb1900_abridged.csv'
127
+ out_path = 'outputs/'+KB+'_gb1900_linking_descript.json'
128
+ namelist = get_name_list_gb1900(file_path)
129
+
130
+
131
+ process_one_namelist(sparql_wrapper, namelist, out_path)
132
+
133
+ else:
134
+ raise NotImplementedError
135
+
136
+
137
+
138
+
139
+ #namelist = namelist[730:] #for GB1900
140
+
141
+
142
+
143
+ print('done')
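A small sketch for consuming the JSON-lines output written by process_one_namelist(): each line is a one-key dict mapping a place name to its list of candidate bindings (the path below is hypothetical).

import json

out_path = 'outputs/wikidata_ca/wikidata_USGS-15-CA-brawley-e1957-s1957-p1961.json'  # hypothetical

with open(out_path) as f:
    for line in f:
        record = json.loads(line)
        for name, candidates in record.items():  # exactly one key per line
            uris = [c['item']['value'] for c in candidates]
            print(name, len(candidates), uris[:2])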
models/spabert/experiments/entity_matching/data_processing/run_map_neighbor_query.py ADDED
@@ -0,0 +1,123 @@
1
+ from query_wrapper import QueryWrapper
2
+ from request_wrapper import RequestWrapper
3
+ import glob
4
+ import os
5
+ import json
6
+ import time
7
+ import random
8
+
9
+ DATASET_OPTIONS = ['OSM', 'OS', 'USGS', 'GB1900']
10
+ KB_OPTIONS = ['wikidata', 'linkedgeodata']
11
+
12
+ dataset = 'USGS'
13
+ kb = 'wikidata'
14
+ overwrite = False
15
+
16
+ assert dataset in DATASET_OPTIONS
17
+ assert kb in KB_OPTIONS
18
+
19
+ if dataset == 'OSM':
20
+ raise NotImplementedError
21
+
22
+ elif dataset == 'OS':
23
+ raise NotImplementedError
24
+
25
+ elif dataset == 'USGS':
26
+
27
+ candidate_file_paths = glob.glob('outputs/alignment_dir/wikidata_USGS*.json')
28
+ candidate_file_paths = sorted(candidate_file_paths)
29
+
30
+ out_dir = 'outputs/wikidata_neighbors/'
31
+
32
+ if not os.path.isdir(out_dir):
33
+ os.makedirs(out_dir)
34
+
35
+ elif dataset == 'GB1900':
36
+
37
+ raise NotImplementedError
38
+
39
+ else:
40
+ raise NotImplementedError
41
+
42
+
43
+ if kb == 'linkedgeodata':
44
+ sparql_wrapper = QueryWrapper(baseuri = 'http://linkedgeodata.org/sparql')
45
+ elif kb == 'wikidata':
46
+ #sparql_wrapper = QueryWrapper(baseuri = 'https://query.wikidata.org/sparql')
47
+ sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
48
+ else:
49
+ raise NotImplementedError
50
+
51
+ start_map = 6 # 6
52
+ start_line = 4 # 4
53
+
54
+ for candiate_file_path in candidate_file_paths[start_map:]:
55
+ map_name = os.path.basename(candiate_file_path).split('wikidata_')[1]
56
+ out_path = os.path.join(out_dir, 'wikidata_' + map_name)
57
+
58
+ with open(candiate_file_path, 'r') as f:
59
+ cand_data = f.readlines()
60
+
61
+ with open(out_path, 'a') as out_f:
62
+ for line in cand_data[start_line:]:
63
+ line_dict = json.loads(line)
64
+ ret_line_dict = dict()
65
+ for key, value in line_dict.items(): # actually just one pair
66
+
67
+ print(key)
68
+
69
+ place_name = key
70
+ for cand_entity in value:
71
+ time.sleep(2)
72
+ q_id = cand_entity['item']['value'].split('/')[-1]
73
+ response = sparql_wrapper.wikidata_nearby_query(str(q_id))
74
+
75
+ if place_name in ret_line_dict:
76
+ ret_line_dict[place_name].append(response)
77
+ else:
78
+ ret_line_dict[place_name] = [response]
79
+
80
+ #time.sleep(random.random()*6)
81
+
82
+
83
+
84
+ out_f.write(json.dumps(ret_line_dict))
85
+ out_f.write('\n')
86
+
87
+ print('finished with ',candiate_file_path)
88
+ break
89
+
90
+ print('done')
91
+
92
+ '''
93
+ {"Martin": [{"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27001"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(18.921388888 49.063611111)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in northern Slovakia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q281028"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-101.734166666 43.175)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in and county seat of Bennett County, South Dakota, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q761390"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(1.3394 51.179)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "hamlet in Kent"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2177502"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-100.115 47.826666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in North Dakota"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2454021"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-85.64168 42.53698)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Michigan, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2481111"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-93.21833 32.09917)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Red River Parish, Louisiana, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2635473"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-82.75944 37.56778)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in Kentucky, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2679547"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-1.9041 50.9759)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Hampshire, England, United Kingdom"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2780056"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-88.851666666 36.341944444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in Tennessee, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q3261150"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-83.18556 34.48639)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "town in Franklin and Stephens Counties, Georgia, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6002227"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-101.709 41.2581)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "census-designated place in 
Nebraska, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774807"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-82.1906 29.2936)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Marion County, Florida, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774809"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-83.336666666 41.5575)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Ohio, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774810"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(116.037 -32.071)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "suburb of Perth, Western Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q9029707"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-81.476388888 33.069166666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in South Carolina, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q11770660"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-0.325391 53.1245)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village and civil parish in Lincolnshire, UK"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q14692833"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-93.5444 47.4894)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Itasca County, Minnesota"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q14714180"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-79.0889 39.2242)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Grant County, West Virginia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q18496647"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(11.95941 46.34585)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Italy"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q20949553"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(22.851944444 41.954444444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "mountain in Republic of Macedonia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q24065096"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "<http://www.wikidata.org/entity/Q111> Point(290.75 -21.34)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "crater on Mars"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q26300074"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", 
"value": "Point(-0.97879914 51.74729077)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Thame, South Oxfordshire, Oxfordshire, OX9"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27988822"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-87.67361111 38.12444444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Vanderburgh County, Indiana, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27995389"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-121.317 47.28)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Washington, United States of America"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q28345614"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-76.8 48.5)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Senneterre, Quebec, Canada"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q30626037"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-79.90972222 39.80638889)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in United States of America"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q61038281"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-91.13 49.25)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Meteorological Service of Canada's station for Martin (MSC ID: 6035000), Ontario, Canada"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q63526691"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(152.7011111 -29.881666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Parish of Fitzroy County, New South Wales, Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q63526695"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(148.1011111 -33.231666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Parish of Ashburnham County, New South Wales, Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q96149222"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(14.725573779 48.76768332)"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q96158116"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(15.638358708 49.930136157)"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q103777024"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-3.125028822 58.694748091)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Shipwreck off the Scottish Coast, imported from Canmore Nov 2020"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q107077206"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", 
"type": "literal", "value": "Point(11.688165 46.582982)"}}]}
94
+
95
+
96
+ if overwrite:
97
+ # flush the file if it's been written
98
+ with open(out_path, 'w') as f:
99
+ f.write('')
100
+
101
+ for name in namelist:
102
+ name = name.replace('"', '')
103
+ name = name.strip("'")
104
+ if len(name) == 0:
105
+ continue
106
+ print(name)
107
+ mydict = dict()
108
+ mydict[name] = sparql_wrapper.wikidata_query(name)
109
+ line = json.dumps(mydict)
110
+ #print(line)
111
+ with open(out_path, 'a') as f:
112
+ f.write(line)
113
+ f.write('\n')
114
+ time.sleep(1)
115
+
116
+ print('done')
117
+
118
+
119
+ {"info": {"name": "10TH ST", "geometry": [4193.118085062303, -831.274950414831]},
120
+ "neighbor_info":
121
+ {"name_list": ["BM 107", "PALM AVE", "WT", "Hidalgo Sch", "PO", "MAIN ST", "BRYANT CANAL", "BM 123", "Oakley Sch", "BRAWLEY", "Witter Sch", "BM 104", "Pistol Range", "Reid Sch", "STANLEY", "MUNICIPAL AIRPORT", "WESTERN AVE", "CANAL", "Riverview Cem", "BEST CANAL"],
122
+ "geometry_list": [[4180.493095652702, -836.0635465095995], [4240.450935702045, -855.345637906981], [4136.084840542623, -917.7895986922882], [4150.386997979736, -948.7258091165079], [4056.955267048625, -847.1018277439381], [4008.112642182582, -849.089249977583], [4124.177447575567, -1004.0706369942257], [4145.382175508665, -626.1608201557082], [4398.137868976953, -764.1087236140554], [4221.1546492913285, -1062.5745271963772], [4015.203890157584, -985.0178210457995], [3989.2345421184878, -948.9340389243871], [4385.585449075614, -660.4590917125413], [3936.505159635338, -803.6822663422273], [3960.1233867112846, -686.7988766730389], [4409.714306709143, -600.6633389979504], [3871.2873706574037, -832.0785684368772], [4304.899727301024, -524.472390102557], [3955.640201659347, -578.5544271698675], [4075.8524354668034, -1183.5837385075774]]}}
123
+ '''
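A companion sketch for reading the neighbor file written above: each line maps one map toponym to a list with one entry per Wikidata candidate, where each entry is the raw binding list returned by wikidata_nearby_query() (the path is hypothetical).

import json

neighbor_path = 'outputs/wikidata_neighbors/wikidata_USGS-15-CA-brawley-e1957-s1957-p1961.json'  # hypothetical

with open(neighbor_path) as f:
    for line in f:
        record = json.loads(line)
        for place_name, per_candidate in record.items():
            for bindings in per_candidate:  # one binding list per candidate QID
                labels = [b.get('placeLabel', {}).get('value') for b in bindings[:5]]
                print(place_name, len(bindings), labels)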
models/spabert/experiments/entity_matching/data_processing/run_query_sample.py ADDED
@@ -0,0 +1,22 @@
1
+ from query_wrapper import QueryWrapper
2
+ from get_namelist import *
3
+ import glob
4
+ import os
5
+ import json
6
+ import time
7
+
8
+
9
+ #sparql_wrapper_linkedgeo = QueryWrapper(baseuri = 'http://linkedgeodata.org/sparql')
10
+
11
+ #print(sparql_wrapper_linkedgeo.linkedgeodata_query('Los Angeles'))
12
+
13
+
14
+ sparql_wrapper_wikidata = QueryWrapper(baseuri = 'https://query.wikidata.org/sparql')
15
+
16
+ #print(sparql_wrapper_wikidata.wikidata_query('Los Angeles'))
17
+
18
+ #time.sleep(3)
19
+
20
+ #print(sparql_wrapper_wikidata.wikidata_nearby_query('Q370771'))
21
+ print(sparql_wrapper_wikidata.wikidata_nearby_query('Q97625145'))
22
+
models/spabert/experiments/entity_matching/data_processing/run_wikidata_neighbor_query.py ADDED
@@ -0,0 +1,31 @@
1
+ import pandas as pd
2
+ import json
3
+ from request_wrapper import RequestWrapper
4
+ import time
5
+ import pdb
6
+
7
+ start_idx = 17335
8
+ wikidata_sample30k_path = 'wikidata_sample30k/wikidata_30k.json'
9
+ out_path = 'wikidata_sample30k/wikidata_30k_neighbor.json'
10
+
11
+ #with open(out_path, 'w') as out_f:
12
+ # pass
13
+
14
+ sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
15
+
16
+ df= pd.read_json(wikidata_sample30k_path)
17
+ df = df[start_idx:]
18
+
19
+ print('length of df:', len(df))
20
+
21
+ for index, record in df.iterrows():
22
+ print(index)
23
+ uri = record.results['item']['value']
24
+ q_id = uri.split('/')[-1]
25
+ response = sparql_wrapper.wikidata_nearby_query(str(q_id))
26
+ time.sleep(1)
27
+ with open(out_path, 'a') as out_f:
28
+ out_f.write(json.dumps(response))
29
+ out_f.write('\n')
30
+
31
+
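The script above appends one JSON line per processed record, so an interrupted run can be resumed by counting the lines already written and using that count as start_idx; a small bookkeeping sketch, assuming every previous run also started from index 0 (this resume logic is not part of the original script).

import os

out_path = 'wikidata_sample30k/wikidata_30k_neighbor.json'

start_idx = 0
if os.path.exists(out_path):
    with open(out_path) as f:
        start_idx = sum(1 for _ in f)  # records are written strictly in order above
print('resume from index', start_idx)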
models/spabert/experiments/entity_matching/data_processing/samples.sparql ADDED
@@ -0,0 +1,22 @@
1
+ '''
2
+ Entities near xxx within 1km
3
+ SELECT ?place ?placeLabel ?location ?instanceLabel ?placeDescription
4
+ WHERE
5
+ {
6
+ wd:Q9188 wdt:P625 ?loc .
7
+ SERVICE wikibase:around {
8
+ ?place wdt:P625 ?location .
9
+ bd:serviceParam wikibase:center ?loc .
10
+ bd:serviceParam wikibase:radius "1" .
11
+ }
12
+ OPTIONAL { ?place wdt:P31 ?instance }
13
+ SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
14
+ BIND(geof:distance(?loc, ?location) as ?dist)
15
+ } ORDER BY ?dist
16
+ '''
17
+
18
+
19
+ SELECT distinct ?item ?itemLabel WHERE {
20
+ ?item wdt:P625 ?geo .
21
+ SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". }
22
+ }
models/spabert/experiments/entity_matching/data_processing/select_ambi.py ADDED
@@ -0,0 +1,18 @@
1
+ import json
2
+
3
+ json_file = 'entity_linking/outputs/wikidata_usgs_linking_descript.json'
4
+
5
+ with open(json_file, 'r') as f:
6
+ data = f.readlines()
7
+
8
+ num_ambi = 0
9
+ for line in data:
10
+ line_dict = json.loads(line)
11
+ for key,value in line_dict.items():
12
+ len_value = len(value)
13
+ if len_value < 2:
14
+ continue
15
+ else:
16
+ num_ambi += 1
17
+ print(key)
18
+ print(num_ambi)
models/spabert/experiments/entity_matching/data_processing/wikidata_sample30k/wikidata_30k.json ADDED
The diff for this file is too large to render. See raw diff
 
models/spabert/experiments/entity_matching/src/evaluation-mrr.py ADDED
@@ -0,0 +1,260 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+
5
+ import sys
6
+ import os
7
+ import glob
8
+ import json
9
+ import numpy as np
10
+ import pandas as pd
11
+ import pdb
12
+
13
+
14
+ prediction_dir = sys.argv[1]
15
+
16
+ print(prediction_dir)
17
+
18
+ gt_dir = '../data_processing/outputs/alignment_gt_dir/'
19
+ prediction_path_list = sorted(os.listdir(prediction_dir))
20
+
21
+ DISPLAY = False
22
+ DETAIL = False
23
+
24
+ if DISPLAY:
25
+ from IPython.display import display
26
+
27
+ def recall_at_k_all_map(all_rank_list, k = 1):
28
+
29
+ rank_list = [item for sublist in all_rank_list for item in sublist]
30
+ total_query = len(rank_list)
31
+ prec = np.sum(np.array(rank_list)<=k)
32
+ prec = 1.0 * prec / total_query
33
+
34
+ return prec
35
+
36
+ def recall_at_k_permap(all_rank_list, k = 1):
37
+
38
+ prec_list = []
39
+ for rank_list in all_rank_list:
40
+ total_query = len(rank_list)
41
+ prec = np.sum(np.array(rank_list)<=k)
42
+ prec = 1.0 * prec / total_query
43
+ prec_list.append(prec)
44
+
45
+ return prec_list
46
+
47
+
48
+
49
+
50
+ def reciprocal_rank(all_rank_list):
51
+
52
+ recip_list = [1./rank for rank in all_rank_list]
53
+ mean_recip = np.mean(recip_list)
54
+
55
+ return mean_recip, recip_list
56
+
57
+
58
+
59
+
60
+ count_hist_list = []
61
+
62
+ all_rank_list = []
63
+
64
+ all_recip_list = []
65
+
66
+ permap_recip_list = []
67
+
68
+ for map_path in prediction_path_list:
69
+
70
+ pred_path = os.path.join(prediction_dir, map_path)
71
+ gt_path = os.path.join(gt_dir, map_path.split('.json')[0] + '.csv')
72
+
73
+ if DETAIL:
74
+ print(pred_path)
75
+
76
+
77
+ with open(gt_path, 'r') as f:
78
+ gt_data = f.readlines()
79
+
80
+ gt_dict = dict()
81
+ for line in gt_data:
82
+ line = line.split(',')
83
+ pivot_name = line[0]
84
+ gt_uri = line[1]
85
+ gt_dict[pivot_name] = gt_uri
86
+
87
+ rank_list = []
88
+ pivot_name_list = []
89
+ with open(pred_path, 'r') as f:
90
+ pred_data = f.readlines()
91
+ for line in pred_data:
92
+ pred_dict = json.loads(line)
93
+ #print(pred_dict.keys())
94
+ pivot_name = pred_dict['pivot_name']
95
+ sorted_match_uri = pred_dict['sorted_match_uri']
96
+ #sorted_match_des = pred_dict['sorted_match_des']
97
+ sorted_sim_matrix = pred_dict['sorted_sim_matrix']
98
+
99
+
100
+ total = len(sorted_match_uri)
101
+ if total == 1:
102
+ continue
103
+
104
+ if pivot_name in gt_dict:
105
+
106
+ gt_uri = gt_dict[pivot_name]
107
+
108
+ try:
109
+ assert gt_uri in sorted_match_uri
110
+ except Exception as e:
111
+ #print(e)
112
+ continue
113
+
114
+ pivot_name_list.append(pivot_name)
115
+ count_hist_list.append(total)
116
+ rank = sorted_match_uri.index(gt_uri) +1
117
+
118
+ rank_list.append(rank)
119
+ #print(rank,'/',total)
120
+
121
+ all_rank_list.append(rank_list)
122
+
123
+ mean_recip, recip_list = reciprocal_rank(rank_list)
124
+
125
+ all_recip_list.extend(recip_list)
126
+ permap_recip_list.append(recip_list)
127
+
128
+ d = {'pivot': pivot_name_list + ['AVG'], 'rank':rank_list + [' '] ,'recip rank': recip_list + [str(mean_recip)]}
129
+ if DETAIL:
130
+ print(pivot_name_list, rank_list, recip_list)
131
+
132
+ if DISPLAY:
133
+ df = pd.DataFrame(data=d)
134
+
135
+ display(df)
136
+
137
+
138
+
139
+ print('all mrr, micro', np.mean(all_recip_list))
140
+
141
+
142
+ if DETAIL:
143
+
144
+ len(rank_list)
145
+
146
+
147
+
148
+ print(recall_at_k_all_map(all_rank_list, k = 1))
149
+ print(recall_at_k_all_map(all_rank_list, k = 2))
150
+ print(recall_at_k_all_map(all_rank_list, k = 5))
151
+ print(recall_at_k_all_map(all_rank_list, k = 10))
152
+
153
+
154
+ print(prediction_path_list)
155
+
156
+
157
+ prec_list_1 = recall_at_k_permap(all_rank_list, k = 1)
158
+ prec_list_2 = recall_at_k_permap(all_rank_list, k = 2)
159
+ prec_list_5 = recall_at_k_permap(all_rank_list, k = 5)
160
+ prec_list_10 = recall_at_k_permap(all_rank_list, k = 10)
161
+
162
+ if DETAIL:
163
+
164
+ print(np.mean(prec_list_1))
165
+ print(prec_list_1)
166
+ print('\n')
167
+
168
+ print(np.mean(prec_list_2))
169
+ print(prec_list_2)
170
+ print('\n')
171
+
172
+ print(np.mean(prec_list_5))
173
+ print(prec_list_5)
174
+ print('\n')
175
+
176
+ print(np.mean(prec_list_10))
177
+ print(prec_list_10)
178
+ print('\n')
179
+
180
+
181
+
182
+
183
+
184
+
185
+
186
+
187
+ import pandas as pd
188
+
189
+
190
+
191
+ map_name_list = [name.split('.json')[0].split('USGS-')[1] for name in prediction_path_list]
192
+ d = {'map_name': map_name_list,'recall@1': prec_list_1, 'recall@2': prec_list_2, 'recall@5': prec_list_5, 'recall@10': prec_list_10 }
193
+ df = pd.DataFrame(data=d)
194
+
195
+
196
+ if DETAIL:
197
+ print(df)
198
+
199
+
200
+
201
+
202
+
203
+ category = ['15-CA','30-CA','60-CA']
204
+ col_1 = [np.mean(prec_list_1[0:4]), np.mean(prec_list_1[4:9]), np.mean(prec_list_1[9:])]
205
+ col_2 = [np.mean(prec_list_2[0:4]), np.mean(prec_list_2[4:9]), np.mean(prec_list_2[9:])]
206
+ col_3 = [np.mean(prec_list_5[0:4]), np.mean(prec_list_5[4:9]), np.mean(prec_list_5[9:])]
207
+ col_4 = [np.mean(prec_list_10[0:4]), np.mean(prec_list_10[4:9]), np.mean(prec_list_10[9:])]
208
+
209
+
210
+
211
+ mrr_15 = permap_recip_list[0] + permap_recip_list[1] + permap_recip_list[2] + permap_recip_list[3]
212
+ mrr_30 = permap_recip_list[4] + permap_recip_list[5] + permap_recip_list[6] + permap_recip_list[7] + permap_recip_list[8]
213
+ mrr_60 = permap_recip_list[9] + permap_recip_list[10] + permap_recip_list[11] + permap_recip_list[12] + permap_recip_list[13]
214
+
215
+
216
+
217
+ column_5 = [np.mean(mrr_15), np.mean(mrr_30), np.mean(mrr_60)]
218
+
219
+
220
+ d = {'map set': category, 'mrr': column_5, 'prec@1': col_1, 'prec@2': col_2, 'prec@5': col_3, 'prec@10': col_4 }
221
+ df = pd.DataFrame(data=d)
222
+
223
+ print(df)
224
+
225
+
226
+
227
+
228
+ print('all mrr, micro', np.mean(all_recip_list))
229
+
230
+ print('\n')
231
+
232
+
233
+
234
+ print(recall_at_k_all_map(all_rank_list, k = 1))
235
+ print(recall_at_k_all_map(all_rank_list, k = 2))
236
+ print(recall_at_k_all_map(all_rank_list, k = 5))
237
+ print(recall_at_k_all_map(all_rank_list, k = 10))
238
+
239
+
240
+
241
+
242
+ if DISPLAY:
243
+
244
+ import seaborn
245
+
246
+ p = seaborn.histplot(data = count_hist_list, color = 'blue', alpha=0.2)
247
+ p.set_xlabel("Number of Candidates")
248
+ p.set_title("Candidate Distribution in USGS")
249
+
250
+
251
+
252
+
253
+ len(count_hist_list)
254
+
255
+
256
+
257
+
258
+
259
+
260
+
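A toy worked example of the micro-averaged metrics computed above, with hand-checkable numbers (two maps, ground-truth ranks 1 and 3 on the first map and 2 on the second):

import numpy as np

all_rank_list = [[1, 3], [2]]

flat = [r for sub in all_rank_list for r in sub]  # [1, 3, 2]
recall_at_1 = np.mean(np.array(flat) <= 1)        # 1/3 ~= 0.333
recall_at_2 = np.mean(np.array(flat) <= 2)        # 2/3 ~= 0.667
mrr = np.mean([1.0 / r for r in flat])            # (1 + 1/3 + 1/2) / 3 ~= 0.611

print(recall_at_1, recall_at_2, mrr)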
models/spabert/experiments/entity_matching/src/linking_ablation.py ADDED
@@ -0,0 +1,228 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import sys
5
+ import os
6
+ import numpy as np
7
+ import pdb
8
+ import json
9
+ import scipy.spatial as sp
10
+ import argparse
11
+
12
+
13
+ import torch
14
+ from torch.utils.data import DataLoader
15
+
16
+ from transformers import AdamW
17
+ from transformers import BertTokenizer
18
+ from tqdm import tqdm # for our progress bar
19
+
20
+ sys.path.append('../../../')
21
+ from datasets.usgs_os_sample_loader import USGS_MapDataset
22
+ from datasets.wikidata_sample_loader import Wikidata_Geocoord_Dataset, Wikidata_Random_Dataset
23
+ from models.spatial_bert_model import SpatialBertModel
24
+ from models.spatial_bert_model import SpatialBertConfig
25
+ from utils.find_closest import find_ref_closest_match, sort_ref_closest_match
26
+ from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv
27
+ from utils.baseline_utils import get_baseline_model
28
+
29
+ from transformers import BertModel
30
+
31
+ sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
32
+ from dataset_loader import SpatialDataset
33
+ from osm_sample_loader import PbfMapDataset
34
+
35
+
36
+
37
+ MODEL_OPTIONS = ['spatial_bert-base','spatial_bert-large', 'bert-base','bert-large','roberta-base','roberta-large',
38
+ 'spanbert-base','spanbert-large','luke-base','luke-large',
39
+ 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
40
+
41
+
42
+ CANDSET_MODES = ['all_map'] # candidate set is constructed based on all maps or one map
43
+
44
+ def recall_at_k(rank_list, k = 1):
45
+
46
+ total_query = len(rank_list)
47
+ recall = np.sum(np.array(rank_list)<=k)
48
+ recall = 1.0 * recall / total_query
49
+
50
+ return recall
51
+
52
+ def reciprocal_rank(all_rank_list):
53
+
54
+ recip_list = [1./rank for rank in all_rank_list]
55
+ mean_recip = np.mean(recip_list)
56
+
57
+ return mean_recip, recip_list
58
+
59
+ def link_to_itself(source_embedding_ogc_list, target_embedding_ogc_list):
60
+
61
+ source_emb_list = [source_dict['emb'] for source_dict in source_embedding_ogc_list]
62
+ source_ogc_list = [source_dict['ogc_fid'] for source_dict in source_embedding_ogc_list]
63
+
64
+ target_emb_list = [target_dict['emb'] for target_dict in target_embedding_ogc_list]
65
+ target_ogc_list = [target_dict['ogc_fid'] for target_dict in target_embedding_ogc_list]
66
+
67
+ rank_list = []
68
+ for source_emb, source_ogc in zip(source_emb_list, source_ogc_list):
69
+ sim_matrix = 1 - sp.distance.cdist(np.array(target_emb_list), np.array([source_emb]), 'cosine')
70
+ closest_match_ogc = sort_ref_closest_match(sim_matrix, target_ogc_list)
71
+
72
+ closest_match_ogc = [a[0] for a in closest_match_ogc]
73
+ rank = closest_match_ogc.index(source_ogc) +1
74
+ rank_list.append(rank)
75
+
76
+
77
+ mean_recip, recip_list = reciprocal_rank(rank_list)
78
+ r1 = recall_at_k(rank_list, k = 1)
79
+ r5 = recall_at_k(rank_list, k = 5)
80
+ r10 = recall_at_k(rank_list, k = 10)
81
+
82
+ return mean_recip , r1, r5, r10
83
+
84
+ def get_embedding_and_ogc(dataset, model_name, model):
85
+ dict_list = []
86
+
87
+ for source in dataset:
88
+ if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
89
+ source_emb = get_spatialbert_embedding(source, model)
90
+ else:
91
+ source_emb = get_bert_embedding(source, model)
92
+
93
+ source_dict = {}
94
+ source_dict['emb'] = source_emb
95
+ source_dict['ogc_fid'] = source['ogc_fid']
96
+ #wikidata_dict['wikidata_des_list'] = [wikidata_cand['description']]
97
+
98
+ dict_list.append(source_dict)
99
+
100
+ return dict_list
101
+
102
+
103
+ def entity_linking_func(args):
104
+
105
+ model_name = args.model_name
106
+ candset_mode = args.candset_mode
107
+
108
+ distance_norm_factor = args.distance_norm_factor
109
+ spatial_dist_fill= args.spatial_dist_fill
110
+ sep_between_neighbors = args.sep_between_neighbors
111
+
112
+ spatial_bert_weight_dir = args.spatial_bert_weight_dir
113
+ spatial_bert_weight_name = args.spatial_bert_weight_name
114
+
115
+ if_no_spatial_distance = args.no_spatial_distance
116
+ random_remove_neighbor = args.random_remove_neighbor
117
+
118
+
119
+ assert model_name in MODEL_OPTIONS
120
+ assert candset_mode in CANDSET_MODES
121
+
122
+
123
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
124
+
125
+
126
+
127
+ if model_name == 'spatial_bert-base':
128
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
129
+
130
+ config = SpatialBertConfig()
131
+ model = SpatialBertModel(config)
132
+
133
+ model.to(device)
134
+ model.eval()
135
+
136
+ # load pretrained weights
137
+ weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
138
+ model = load_spatial_bert_pretrained_weights(model, weight_path)
139
+
140
+ elif model_name == 'spatial_bert-large':
141
+ tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
142
+
143
+ config = SpatialBertConfig(hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24)
144
+ model = SpatialBertModel(config)
145
+
146
+ model.to(device)
147
+ model.eval()
148
+
149
+ # load pretrained weights
150
+ weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
151
+ model = load_spatial_bert_pretrained_weights(model, weight_path)
152
+
153
+ else:
154
+ model, tokenizer = get_baseline_model(model_name)
155
+ model.to(device)
156
+ model.eval()
157
+
158
+ source_file_path = '../data/osm-point-minnesota-full.json'
159
+ source_dataset = PbfMapDataset(data_file_path = source_file_path,
160
+ tokenizer = tokenizer,
161
+ max_token_len = 512,
162
+ distance_norm_factor = distance_norm_factor,
163
+ spatial_dist_fill = spatial_dist_fill,
164
+ with_type = False,
165
+ sep_between_neighbors = sep_between_neighbors,
166
+ mode = None,
167
+ random_remove_neighbor = random_remove_neighbor,
168
+ )
169
+
170
+ target_dataset = PbfMapDataset(data_file_path = source_file_path,
171
+ tokenizer = tokenizer,
172
+ max_token_len = 512,
173
+ distance_norm_factor = distance_norm_factor,
174
+ spatial_dist_fill = spatial_dist_fill,
175
+ with_type = False,
176
+ sep_between_neighbors = sep_between_neighbors,
177
+ mode = None,
178
+ random_remove_neighbor = 0., # keep all
179
+ )
180
+
181
+ # process candidates for each phrase
182
+
183
+
184
+ source_embedding_ogc_list = get_embedding_and_ogc(source_dataset, model_name, model)
185
+ target_embedding_ogc_list = get_embedding_and_ogc(target_dataset, model_name, model)
186
+
187
+
188
+ mean_recip , r1, r5, r10 = link_to_itself(source_embedding_ogc_list, target_embedding_ogc_list)
189
+ print('\n')
190
+ print(random_remove_neighbor, mean_recip , r1, r5, r10)
191
+
192
+
193
+
194
+ def main():
195
+ parser = argparse.ArgumentParser()
196
+ parser.add_argument('--model_name', type=str, default='spatial_bert-base')
197
+ parser.add_argument('--candset_mode', type=str, default='all_map')
198
+
199
+ parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
200
+ parser.add_argument('--spatial_dist_fill', type=float, default = 20)
201
+
202
+ parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
203
+ parser.add_argument('--no_spatial_distance', default=False, action='store_true')
204
+
205
+ parser.add_argument('--spatial_bert_weight_dir', type = str, default = None)
206
+ parser.add_argument('--spatial_bert_weight_name', type = str, default = None)
207
+
208
+ parser.add_argument('--random_remove_neighbor', type = float, default = 0.)
209
+
210
+
211
+ args = parser.parse_args()
212
+ # print('\n')
213
+ # print(args)
214
+ # print('\n')
215
+
216
+ entity_linking_func(args)
217
+
218
+ # CUDA_VISIBLE_DEVICES='1' python3 linking_ablation.py --sep_between_neighbors --model_name='spatial_bert-base' --spatial_bert_weight_dir='/data/zekun/spatial_bert_weights/typing_lr5e-05_sep_bert-base_nofreeze_london_california_bsize12/ep0_iter06000_0.2936/' --spatial_bert_weight_name='keeppos_ep0_iter02000_0.4879.pth' --random_remove_neighbor=0.1
219
+
220
+
221
+ # CUDA_VISIBLE_DEVICES='1' python3 linking_ablation.py --sep_between_neighbors --model_name='spatial_bert-large' --spatial_bert_weight_dir='/data/zekun/spatial_bert_weights/typing_lr1e-06_sep_bert-large_nofreeze_london_california_bsize12/ep2_iter02000_0.3921/' --spatial_bert_weight_name='keeppos_ep8_iter03568_0.2661_val0.2284.pth' --random_remove_neighbor=0.1
222
+
223
+
224
+ if __name__ == '__main__':
225
+
226
+ main()
227
+
228
+
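A toy sketch of the cosine-similarity ranking inside link_to_itself(); np.argsort is used here in place of the repo's sort_ref_closest_match helper (whose implementation is not shown in this diff), and the embeddings are random stand-ins.

import numpy as np
import scipy.spatial as sp

target_embs = np.random.rand(5, 8)                     # 5 candidate entities, 8-dim embeddings
query_emb = target_embs[3] + 0.01 * np.random.rand(8)  # a near-duplicate of entity 3

# Cosine similarity of the query against every candidate, as in link_to_itself().
sim = 1 - sp.distance.cdist(target_embs, query_emb[None, :], 'cosine')[:, 0]

order = np.argsort(-sim)                    # candidate indices, most similar first
rank = int(np.where(order == 3)[0][0]) + 1  # 1-based rank of the true entity
print(rank, 1.0 / rank)                     # rank and its reciprocal (cf. reciprocal_rank())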
models/spabert/experiments/entity_matching/src/unsupervised_wiki_location_allcand.py ADDED
@@ -0,0 +1,329 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import sys
5
+ import os
6
+ import numpy as np
7
+ import pdb
8
+ import json
9
+ import scipy.spatial as sp
10
+ import argparse
11
+
12
+
13
+ import torch
14
+ from torch.utils.data import DataLoader
15
+ #from transformers.models.bert.modeling_bert import BertForMaskedLM
16
+
17
+ from transformers import AdamW
18
+ from transformers import BertTokenizer
19
+ from tqdm import tqdm # for our progress bar
20
+
21
+ sys.path.append('../../../')
22
+ from datasets.usgs_os_sample_loader import USGS_MapDataset
23
+ from datasets.wikidata_sample_loader import Wikidata_Geocoord_Dataset, Wikidata_Random_Dataset
24
+ from models.spatial_bert_model import SpatialBertModel
25
+ from models.spatial_bert_model import SpatialBertConfig
26
+ #from models.spatial_bert_model import SpatialBertForMaskedLM
27
+ from utils.find_closest import find_ref_closest_match, sort_ref_closest_match
28
+ from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv
29
+ from utils.baseline_utils import get_baseline_model
30
+
31
+ from transformers import BertModel
32
+
33
+
34
+ MODEL_OPTIONS = ['spatial_bert-base','spatial_bert-large', 'bert-base','bert-large','roberta-base','roberta-large',
35
+ 'spanbert-base','spanbert-large','luke-base','luke-large',
36
+ 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
37
+
38
+ MAP_TYPES = ['usgs']
39
+ CANDSET_MODES = ['all_map'] # candidate set is constructed based on all maps or one map
40
+
41
+
42
+ def disambiguify(model, model_name, usgs_dataset, wikidata_dict_list, candset_mode = 'all_map', if_use_distance = True, select_indices = None):
43
+
44
+ if select_indices is None:
45
+ select_indices = range(0, len(usgs_dataset))
46
+
47
+
48
+ assert(candset_mode in ['all_map','per_map'])
49
+
50
+ wikidata_emb_list = [wikidata_dict['wikidata_emb_list'] for wikidata_dict in wikidata_dict_list]
51
+ wikidata_uri_list = [wikidata_dict['wikidata_uri_list'] for wikidata_dict in wikidata_dict_list]
52
+ #wikidata_des_list = [wikidata_dict['wikidata_des_list'] for wikidata_dict in wikidata_dict_list]
53
+
54
+ if candset_mode == 'all_map':
55
+ wikidata_emb_list = [item for sublist in wikidata_emb_list for item in sublist] # flatten
56
+ wikidata_uri_list = [item for sublist in wikidata_uri_list for item in sublist] # flatten
57
+ #wikidata_des_list = [item for sublist in wikidata_des_list for item in sublist] # flatten
58
+
59
+
60
+ ret_list = []
61
+ for i in select_indices:
62
+
63
+ if candset_mode == 'per_map':
64
+ usgs_entity = usgs_dataset[i]
65
+ wikidata_emb_list = wikidata_emb_list[i]
66
+ wikidata_uri_list = wikidata_uri_list[i]
67
+ #wikidata_des_list = wikidata_des_list[i]
68
+
69
+ elif candset_mode == 'all_map':
70
+ usgs_entity = usgs_dataset[i]
71
+ else:
72
+ raise NotImplementedError
73
+
74
+ if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
75
+ usgs_emb = get_spatialbert_embedding(usgs_entity, model, use_distance = if_use_distance)
76
+ else:
77
+ usgs_emb = get_bert_embedding(usgs_entity, model)
78
+
79
+
80
+ sim_matrix = 1 - sp.distance.cdist(np.array(wikidata_emb_list), np.array([usgs_emb]), 'cosine')
81
+
82
+ closest_match_uri = sort_ref_closest_match(sim_matrix, wikidata_uri_list)
83
+ #closest_match_des = sort_ref_closest_match(sim_matrix, wikidata_des_list)
84
+
85
+
86
+ sorted_sim_matrix = np.sort(sim_matrix, axis = 0)[::-1] # descending order
87
+
88
+ ret_dict = dict()
89
+ ret_dict['pivot_name'] = usgs_entity['pivot_name']
90
+ ret_dict['sorted_match_uri'] = [a[0] for a in closest_match_uri]
91
+ #ret_dict['sorted_match_des'] = [a[0] for a in closest_match_des]
92
+ ret_dict['sorted_sim_matrix'] = [a[0] for a in sorted_sim_matrix]
93
+
94
+ ret_list.append(ret_dict)
95
+
96
+ return ret_list
97
+
98
+ def entity_linking_func(args):
99
+
100
+ model_name = args.model_name
101
+ map_type = args.map_type
102
+ candset_mode = args.candset_mode
103
+
104
+ usgs_distance_norm_factor = args.usgs_distance_norm_factor
105
+ spatial_dist_fill= args.spatial_dist_fill
106
+ sep_between_neighbors = args.sep_between_neighbors
107
+
108
+ spatial_bert_weight_dir = args.spatial_bert_weight_dir
109
+ spatial_bert_weight_name = args.spatial_bert_weight_name
110
+
111
+ if_no_spatial_distance = args.no_spatial_distance
112
+
113
+
114
+ assert model_name in MODEL_OPTIONS
115
+ assert map_type in MAP_TYPES
116
+ assert candset_mode in CANDSET_MODES
117
+
118
+
119
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
120
+
121
+ if args.out_dir is None:
122
+
123
+ if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
124
+
125
+ if sep_between_neighbors:
126
+ spatialbert_output_dir_str = 'dnorm' + str(usgs_distance_norm_factor ) + '_distfill' + str(spatial_dist_fill) + '_sep'
127
+ else:
128
+ spatialbert_output_dir_str = 'dnorm' + str(usgs_distance_norm_factor ) + '_distfill' + str(spatial_dist_fill) + '_nosep'
129
+
130
+
131
+ checkpoint_ep = spatial_bert_weight_name.split('_')[3]
132
+ checkpoint_iter = spatial_bert_weight_name.split('_')[4]
133
+ loss_val = spatial_bert_weight_name.split('_')[5][:-4]
134
+
135
+ if if_no_spatial_distance:
136
+ linking_prediction_dir = 'linking_prediction_dir/abalation_no_distance/'
137
+ else:
138
+ linking_prediction_dir = 'linking_prediction_dir'
139
+
140
+ if model_name == 'spatial_bert-base':
141
+ out_dir = os.path.join('/data2/zekun/', linking_prediction_dir, spatialbert_output_dir_str) + '/' + map_type + '-' + model_name + '-' + checkpoint_ep + '-' + checkpoint_iter + '-' + loss_val
142
+ elif model_name == 'spatial_bert-large':
143
+
144
+ freeze_str = spatial_bert_weight_dir.split('/')[-2].split('_')[1] # either 'freeze' or 'nofreeze'
145
+ out_dir = os.path.join('/data2/zekun/', linking_prediction_dir, spatialbert_output_dir_str) + '/' + map_type + '-' + model_name + '-' + checkpoint_ep + '-' + checkpoint_iter + '-' + loss_val + '-' + freeze_str
146
+
147
+
148
+
149
+ else:
150
+ out_dir = '/data2/zekun/baseline_linking_prediction_dir/' + map_type + '-' + model_name
151
+
152
+ else:
153
+ out_dir = args.out_dir
154
+
155
+ print('out_dir', out_dir)
156
+
157
+ if not os.path.isdir(out_dir):
158
+ os.makedirs(out_dir)
159
+
160
+
161
+ if model_name == 'spatial_bert-base':
162
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
163
+
164
+ config = SpatialBertConfig()
165
+ model = SpatialBertModel(config)
166
+
167
+ model.to(device)
168
+ model.eval()
169
+
170
+ # load pretrained weights
171
+ weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
172
+ model = load_spatial_bert_pretrained_weights(model, weight_path)
173
+
174
+ elif model_name == 'spatial_bert-large':
175
+ tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
176
+
177
+ config = SpatialBertConfig(hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24)
178
+ model = SpatialBertModel(config)
179
+
180
+ model.to(device)
181
+ model.eval()
182
+
183
+ # load pretrained weights
184
+ weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
185
+ model = load_spatial_bert_pretrained_weights(model, weight_path)
186
+
187
+ else:
188
+ model, tokenizer = get_baseline_model(model_name)
189
+ model.to(device)
190
+ model.eval()
191
+
192
+
193
+ if map_type == 'usgs':
194
+ map_name_list = ['USGS-15-CA-brawley-e1957-s1957-p1961',
195
+ 'USGS-15-CA-paloalto-e1899-s1895-rp1911',
196
+ 'USGS-15-CA-capesanmartin-e1921-s1917',
197
+ 'USGS-15-CA-sanfrancisco-e1899-s1892-rp1911',
198
+ 'USGS-30-CA-dardanelles-e1898-s1891-rp1912',
199
+ 'USGS-30-CA-holtville-e1907-s1905-rp1946',
200
+ 'USGS-30-CA-indiospecial-e1904-s1901-rp1910',
201
+ 'USGS-30-CA-lompoc-e1943-s1903-ap1941-rv1941',
202
+ 'USGS-30-CA-sanpedro-e1943-rv1944',
203
+ 'USGS-60-CA-alturas-e1892-rp1904',
204
+ 'USGS-60-CA-amboy-e1942',
205
+ 'USGS-60-CA-amboy-e1943-rv1943',
206
+ 'USGS-60-CA-modoclavabed-e1886-s1884',
207
+ 'USGS-60-CA-saltonsea-e1943-ap1940-rv1942']
208
+
209
+ print('processing wikidata...')
210
+
211
+ wikidata_dict_list = []
212
+
213
+ wikidata_random30k = Wikidata_Random_Dataset(
214
+ data_file_path = '../data_processing/wikidata_sample30k/wikidata_30k_neighbor_reformat.json',
215
+ #neighbor_file_path = '../data_processing/wikidata_sample30k/wikidata_30k_neighbor.json',
216
+ tokenizer = tokenizer,
217
+ max_token_len = 512,
218
+ distance_norm_factor = 0.0001,
219
+ spatial_dist_fill=100,
220
+ sep_between_neighbors = sep_between_neighbors,
221
+ )
222
+
223
+ # process candidates for each phrase
224
+ for wikidata_cand in wikidata_random30k:
225
+ if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
226
+ wikidata_emb = get_spatialbert_embedding(wikidata_cand, model)
227
+ else:
228
+ wikidata_emb = get_bert_embedding(wikidata_cand, model)
229
+
230
+ wikidata_dict = {}
231
+ wikidata_dict['wikidata_emb_list'] = [wikidata_emb]
232
+ wikidata_dict['wikidata_uri_list'] = [wikidata_cand['uri']]
233
+
234
+ wikidata_dict_list.append(wikidata_dict)
235
+
236
+
237
+
238
+ for map_name in map_name_list:
239
+
240
+ print(map_name)
241
+
242
+ wikidata_dict_per_map = {}
243
+ wikidata_dict_per_map['wikidata_emb_list'] = []
244
+ wikidata_dict_per_map['wikidata_uri_list'] = []
245
+
246
+ wikidata_dataset_permap = Wikidata_Geocoord_Dataset(
247
+ data_file_path = '../data_processing/outputs/wikidata_reformat/wikidata_' + map_name + '.json',
248
+ tokenizer = tokenizer,
249
+ max_token_len = 512,
250
+ distance_norm_factor = 0.0001,
251
+ spatial_dist_fill=100,
252
+ sep_between_neighbors = sep_between_neighbors)
253
+
254
+
255
+
256
+ for i in range(0, len(wikidata_dataset_permap)):
257
+ # get all candidates for phrases within the map
258
+ wikidata_candidates = wikidata_dataset_permap[i] # dataset for each map, list of [cand for each phrase]
259
+
260
+
261
+ # process candidates for each phrase
262
+ for wikidata_cand in wikidata_candidates:
263
+ if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
264
+ wikidata_emb = get_spatialbert_embedding(wikidata_cand, model)
265
+ else:
266
+ wikidata_emb = get_bert_embedding(wikidata_cand, model)
267
+
268
+ wikidata_dict_per_map['wikidata_emb_list'].append(wikidata_emb)
269
+ wikidata_dict_per_map['wikidata_uri_list'].append(wikidata_cand['uri'])
270
+
271
+
272
+ wikidata_dict_list.append(wikidata_dict_per_map)
273
+
274
+
275
+
276
+ for map_name in map_name_list:
277
+
278
+ print(map_name)
279
+
280
+
281
+ usgs_dataset = USGS_MapDataset(
282
+ data_file_path = '../data_processing/outputs/alignment_dir/map_' + map_name + '.json',
283
+ tokenizer = tokenizer,
284
+ distance_norm_factor = usgs_distance_norm_factor,
285
+ spatial_dist_fill = spatial_dist_fill,
286
+ sep_between_neighbors = sep_between_neighbors)
287
+
288
+
289
+ ret_list = disambiguify(model, model_name, usgs_dataset, wikidata_dict_list, candset_mode= candset_mode, if_use_distance = not if_no_spatial_distance, select_indices = None)
290
+
291
+ write_to_csv(out_dir, map_name, ret_list)
292
+
293
+ print('Done')
294
+
295
+
296
+ def main():
297
+ parser = argparse.ArgumentParser()
298
+ parser.add_argument('--model_name', type=str, default='spatial_bert-base')
299
+ parser.add_argument('--out_dir', type=str, default=None)
300
+ parser.add_argument('--map_type', type=str, default='usgs')
301
+ parser.add_argument('--candset_mode', type=str, default='all_map')
302
+
303
+ parser.add_argument('--usgs_distance_norm_factor', type=float, default = 1)
304
+ parser.add_argument('--spatial_dist_fill', type=float, default = 100)
305
+
306
+ parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
307
+ parser.add_argument('--no_spatial_distance', default=False, action='store_true')
308
+
309
+ parser.add_argument('--spatial_bert_weight_dir', type = str, default = None)
310
+ parser.add_argument('--spatial_bert_weight_name', type = str, default = None)
311
+
312
+ args = parser.parse_args()
313
+ print('\n')
314
+ print(args)
315
+ print('\n')
316
+
317
+ # if out_dir is not None and does not exist, create it
318
+ if args.out_dir is not None and not os.path.isdir(args.out_dir):
319
+ os.makedirs(args.out_dir)
320
+
321
+ entity_linking_func(args)
322
+
323
+
324
+
325
+ if __name__ == '__main__':
326
+
327
+ main()
328
+
329
+
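The candidate ranking inside disambiguify reduces to a cosine-similarity search: one query embedding against a stack of candidate embeddings, sorted in descending order. Below is a minimal, self-contained sketch of that step; the toy embeddings and Wikidata URIs are invented for illustration, and `sp` is assumed to be `scipy.spatial`, as the `sp.distance.cdist` call above suggests.

import numpy as np
from scipy import spatial as sp

# toy candidate embeddings (3 candidates, 4 dimensions) and invented URIs
wikidata_emb_list = np.random.rand(3, 4)
wikidata_uri_list = ['http://www.wikidata.org/entity/Q101',
                     'http://www.wikidata.org/entity/Q102',
                     'http://www.wikidata.org/entity/Q103']
usgs_emb = np.random.rand(4)  # embedding of one map entity

# same computation as in disambiguify: similarity = 1 - cosine distance
sim_matrix = 1 - sp.distance.cdist(wikidata_emb_list, usgs_emb[np.newaxis, :], 'cosine')

# rank candidates from most to least similar
order = np.argsort(-sim_matrix[:, 0])
sorted_uris = [wikidata_uri_list[i] for i in order]
sorted_scores = sim_matrix[order, 0]
print(list(zip(sorted_uris, sorted_scores)))

The script delegates the sorting to sort_ref_closest_match from utils.find_closest; the argsort above is assumed to reproduce the same descending ordering.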
models/spabert/experiments/semantic_typing/__init__.py ADDED
File without changes
models/spabert/experiments/semantic_typing/data_processing/merge_osm_json.py ADDED
@@ -0,0 +1,97 @@
1
+ import os
2
+ import json
3
+ import math
4
+ import glob
5
+ import re
6
+ import pdb
7
+
8
+ '''
9
+ NO LONGER NEEDED
10
+
11
+ Process the California, London, and Minnesota OSM data and prepare pseudo-sentences with spatial context
12
+
13
+ Load the raw output files generated by SQL
14
+ Unify the JSON by changing the structure of the dictionary
15
+ Save the output into two files, one for training + validation and the other for testing
16
+ '''
17
+
18
+ region_list = ['california','london','minnesota']
19
+
20
+ input_json_dir = '../data/sql_output/sub_files/'
21
+ output_json_dir = '../data/sql_output/'
22
+
23
+ for region_name in region_list:
24
+ file_list = glob.glob(os.path.join(input_json_dir, 'spatialbert-osm-point-' + region_name + '*.json'))
25
+ file_list = sorted(file_list)
26
+ print('found %d files for region %s' % (len(file_list), region_name))
27
+
28
+
29
+ num_test_files = int(math.ceil(len(file_list) * 0.2))
30
+ num_train_val_files = len(file_list) - num_test_files
31
+
32
+ print('%d files for train-val' % num_train_val_files)
33
+ print('%d files for test' % num_test_files)
34
+
35
+ train_val_output_path = os.path.join(output_json_dir + 'osm-point-' + region_name + '_train_val.json')
36
+ test_output_path = os.path.join(output_json_dir + 'osm-point-' + region_name + '_test.json')
37
+
38
+ # refresh the file
39
+ with open(train_val_output_path, 'w') as f:
40
+ pass
41
+ with open(test_output_path, 'w') as f:
42
+ pass
43
+
44
+ for idx in range(len(file_list)):
45
+
46
+ if idx < num_train_val_files:
47
+ output_path = train_val_output_path
48
+ else:
49
+ output_path = test_output_path
50
+
51
+ file_path = file_list[idx]
52
+
53
+ print(file_path)
54
+
55
+ with open(file_path, 'r') as f:
56
+ data = f.readlines()
57
+
58
+ line = data[0]
59
+
60
+ line = re.sub(r'\n', '', line)
61
+ line = re.sub(r'\\n', '', line)
62
+ line = re.sub(r'\\+', '', line)
63
+ line = re.sub(r'\+', '', line)
64
+
65
+ line_dict_list = json.loads(line)
66
+
67
+
68
+ for line_dict in line_dict_list:
69
+
70
+ line_dict = line_dict['json_build_object']
71
+
72
+ if not line_dict['name'][0].isalpha(): # discard record if the first char is not an English letter
73
+ continue
74
+
75
+ neighbor_name_list = line_dict['neighbor_info'][0]['name_list']
76
+ neighbor_geom_list = line_dict['neighbor_info'][0]['geometry_list']
77
+
78
+ assert(len(neighbor_name_list) == len(neighbor_geom_list))
79
+
80
+ temp_dict = \
81
+ {'info':{'name':line_dict['name'],
82
+ 'geometry':{'coordinates':line_dict['geometry']},
83
+ 'class':line_dict['class']
84
+ },
85
+ 'neighbor_info':{'name_list': neighbor_name_list,
86
+ 'geometry_list': neighbor_geom_list
87
+ }
88
+ }
89
+
90
+ with open(output_path, 'a') as f:
91
+ json.dump(temp_dict, f)
92
+ f.write('\n')
93
+
94
+
95
+
96
+
97
+
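Each line that merge_osm_json.py appends is one JSON object with an `info` block for the pivot entity and a `neighbor_info` block holding the names and geometries of its spatial context. A hand-written example of that record structure follows; every field value below is invented for illustration.

import json

# one output record of merge_osm_json.py (values are invented for illustration)
record = {
    'info': {
        'name': 'Example Cafe',
        'geometry': {'coordinates': [-118.28, 34.02]},
        'class': 'food',  # placeholder class label
    },
    'neighbor_info': {
        'name_list': ['Example Library', 'Example Park'],
        'geometry_list': [[-118.29, 34.02], [-118.28, 34.03]],
    },
}
print(json.dumps(record))  # the script writes one such object per line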
models/spabert/experiments/semantic_typing/src/__init__.py ADDED
File without changes
models/spabert/experiments/semantic_typing/src/run_baseline_test.py ADDED
@@ -0,0 +1,82 @@
1
+ import os
2
+ import pdb
3
+ import argparse
4
+ import numpy as np
5
+ import time
6
+
7
+ MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large',
8
+ 'spanbert-base','spanbert-large','luke-base','luke-large',
9
+ 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
10
+
11
+ def execute_command(command, if_print_command):
12
+ t1 = time.time()
13
+
14
+ if if_print_command:
15
+ print(command)
16
+ os.system(command)
17
+
18
+ t2 = time.time()
19
+ time_usage = t2 - t1
20
+ return time_usage
21
+
22
+ def run_test(args):
23
+ weight_dir = args.weight_dir
24
+ backbone_option = args.backbone_option
25
+ gpu_id = str(args.gpu_id)
26
+ if_print_command = args.print_command
27
+ sep_between_neighbors = args.sep_between_neighbors
28
+
29
+ assert backbone_option in MODEL_OPTIONS
30
+
31
+ if sep_between_neighbors:
32
+ sep_str = '_sep'
33
+ else:
34
+ sep_str = ''
35
+
36
+ if 'large' in backbone_option:
37
+ checkpoint_dir = os.path.join(weight_dir, 'typing_lr1e-06_%s_nofreeze%s_london_california_bsize12'% (backbone_option, sep_str))
38
+ else:
39
+ checkpoint_dir = os.path.join(weight_dir, 'typing_lr5e-05_%s_nofreeze%s_london_california_bsize12'% (backbone_option, sep_str))
40
+ weight_files = os.listdir(checkpoint_dir)
41
+
42
+ val_loss_list = [weight_file.split('_')[-1] for weight_file in weight_files]
43
+ min_loss_weight = weight_files[np.argmin(val_loss_list)]
44
+
45
+ checkpoint_path = os.path.join(checkpoint_dir, min_loss_weight)
46
+
47
+ if sep_between_neighbors:
48
+ command = 'CUDA_VISIBLE_DEVICES=%s python3 test_cls_baseline.py --sep_between_neighbors --backbone_option=%s --batch_size=8 --with_type --checkpoint_path=%s ' % (gpu_id, backbone_option, checkpoint_path)
49
+ else:
50
+ command = 'CUDA_VISIBLE_DEVICES=%s python3 test_cls_baseline.py --backbone_option=%s --batch_size=8 --with_type --checkpoint_path=%s ' % (gpu_id, backbone_option, checkpoint_path)
51
+
52
+
53
+ execute_command(command, if_print_command)
54
+
55
+
56
+
57
+
58
+ def main():
59
+ parser = argparse.ArgumentParser()
60
+
61
+ parser.add_argument('--weight_dir', type=str, default='/data2/zekun/spatial_bert_baseline_weights/')
62
+ parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
63
+ parser.add_argument('--backbone_option', type=str, default=None)
64
+ parser.add_argument('--gpu_id', type=int, default=0) # id of the GPU to run on
65
+
66
+ parser.add_argument('--print_command', default=False, action='store_true')
67
+
68
+
69
+ args = parser.parse_args()
70
+ print('\n')
71
+ print(args)
72
+ print('\n')
73
+
74
+ run_test(args)
75
+
76
+
77
+ if __name__ == '__main__':
78
+
79
+ main()
80
+
81
+
82
+
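run_test picks the checkpoint with the smallest validation loss by slicing the loss out of each file name with split('_')[-1], which yields strings such as 'val0.2871.pth'; np.argmin then compares those strings lexicographically rather than numerically. Below is a sketch of a numerically explicit variant, using the ep<epoch>_iter<iter>_<loss>_val<loss>.pth naming produced by the training scripts; the file names themselves are hypothetical.

import numpy as np

# hypothetical checkpoint names following the ep<epoch>_iter<iter>_<loss>_val<loss>.pth pattern
weight_files = ['ep0_iter02000_0.5121_val0.4987.pth',
                'ep1_iter04000_0.3310_val0.2871.pth',
                'ep2_iter06000_0.2904_val0.3012.pth']

def val_loss(fname):
    # take the trailing 'val<loss>.pth' token and convert it to a float
    return float(fname.split('_')[-1].replace('val', '').replace('.pth', ''))

min_loss_weight = weight_files[int(np.argmin([val_loss(f) for f in weight_files]))]
print(min_loss_weight)  # ep1_iter04000_0.3310_val0.2871.pth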
models/spabert/experiments/semantic_typing/src/test_cls_ablation_spatialbert.py ADDED
@@ -0,0 +1,209 @@
1
+ import os
2
+ import sys
3
+ from transformers import RobertaTokenizer, BertTokenizer
4
+ from tqdm import tqdm # for our progress bar
5
+ from transformers import AdamW
6
+
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ import torch.nn.functional as F
10
+
11
+ sys.path.append('../../../')
12
+ from models.spatial_bert_model import SpatialBertModel
13
+ from models.spatial_bert_model import SpatialBertConfig
14
+ from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
15
+ from datasets.osm_sample_loader import PbfMapDataset
16
+ from datasets.const import *
17
+ from transformers.models.bert.modeling_bert import BertForMaskedLM
18
+
19
+ from sklearn.metrics import label_ranking_average_precision_score
20
+ from sklearn.metrics import precision_recall_fscore_support
21
+ import numpy as np
22
+ import argparse
23
+ from sklearn.preprocessing import LabelEncoder
24
+ import pdb
25
+
26
+
27
+ DEBUG = False
28
+ torch.backends.cudnn.deterministic = True
29
+ torch.backends.cudnn.benchmark = False
30
+ torch.manual_seed(42)
31
+ torch.cuda.manual_seed_all(42)
32
+
33
+
34
+ def testing(args):
35
+
36
+ max_token_len = args.max_token_len
37
+ batch_size = args.batch_size
38
+ num_workers = args.num_workers
39
+ distance_norm_factor = args.distance_norm_factor
40
+ spatial_dist_fill=args.spatial_dist_fill
41
+ with_type = args.with_type
42
+ sep_between_neighbors = args.sep_between_neighbors
43
+ checkpoint_path = args.checkpoint_path
44
+ if_no_spatial_distance = args.no_spatial_distance
45
+
46
+ bert_option = args.bert_option
47
+ num_neighbor_limit = args.num_neighbor_limit
48
+
49
+
50
+ london_file_path = '../data/sql_output/osm-point-london-typing.json'
51
+ california_file_path = '../data/sql_output/osm-point-california-typing.json'
52
+
53
+
54
+ if bert_option == 'bert-base':
55
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
56
+ config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(CLASS_9_LIST))
57
+ elif bert_option == 'bert-large':
58
+ tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
59
+ config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24,num_semantic_types=len(CLASS_9_LIST))
60
+ else:
61
+ raise NotImplementedError
62
+
63
+
64
+ model = SpatialBertForSemanticTyping(config)
65
+
66
+
67
+ label_encoder = LabelEncoder()
68
+ label_encoder.fit(CLASS_9_LIST)
69
+
70
+
71
+
72
+ london_dataset = PbfMapDataset(data_file_path = london_file_path,
73
+ tokenizer = tokenizer,
74
+ max_token_len = max_token_len,
75
+ distance_norm_factor = distance_norm_factor,
76
+ spatial_dist_fill = spatial_dist_fill,
77
+ with_type = with_type,
78
+ sep_between_neighbors = sep_between_neighbors,
79
+ label_encoder = label_encoder,
80
+ num_neighbor_limit = num_neighbor_limit,
81
+ mode = 'test',)
82
+
83
+ california_dataset = PbfMapDataset(data_file_path = california_file_path,
84
+ tokenizer = tokenizer,
85
+ max_token_len = max_token_len,
86
+ distance_norm_factor = distance_norm_factor,
87
+ spatial_dist_fill = spatial_dist_fill,
88
+ with_type = with_type,
89
+ sep_between_neighbors = sep_between_neighbors,
90
+ label_encoder = label_encoder,
91
+ num_neighbor_limit = num_neighbor_limit,
92
+ mode = 'test')
93
+
94
+ test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])
95
+
96
+
97
+
98
+ test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers,
99
+ shuffle=False, pin_memory=True, drop_last=False)
100
+
101
+
102
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
103
+ model.to(device)
104
+
105
+
106
+ model.load_state_dict(torch.load(checkpoint_path))
107
+
108
+ model.eval()
109
+
110
+
111
+
112
+ print('start testing...')
113
+
114
+
115
+ # setup loop with TQDM and dataloader
116
+ loop = tqdm(test_loader, leave=True)
117
+
118
+
119
+ mrr_total = 0.
120
+ prec_total = 0.
121
+ sample_cnt = 0
122
+
123
+ gt_list = []
124
+ pred_list = []
125
+
126
+ for batch in loop:
127
+ # initialize calculated gradients (from prev step)
128
+
129
+ # pull all tensor batches required for training
130
+ input_ids = batch['pseudo_sentence'].to(device)
131
+ attention_mask = batch['attention_mask'].to(device)
132
+ position_list_x = batch['norm_lng_list'].to(device)
133
+ position_list_y = batch['norm_lat_list'].to(device)
134
+ sent_position_ids = batch['sent_position_ids'].to(device)
135
+
136
+ #labels = batch['pseudo_sentence'].to(device)
137
+ labels = batch['pivot_type'].to(device)
138
+ pivot_lens = batch['pivot_token_len'].to(device)
139
+
140
+ outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
141
+ position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)
142
+
143
+
144
+ onehot_labels = F.one_hot(labels, num_classes=9)
145
+
146
+ gt_list.extend(onehot_labels.cpu().detach().numpy())
147
+ pred_list.extend(outputs.logits.cpu().detach().numpy())
148
+
149
+ #pdb.set_trace()
150
+ mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
151
+ mrr_total += mrr * input_ids.shape[0]
152
+ sample_cnt += input_ids.shape[0]
153
+
154
+ precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average=None)
155
+
156
+ precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average='micro')
157
+
158
+ # print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
159
+ # print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
160
+ # print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores])
161
+ # print('supports:\n', supports)
162
+ print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))
163
+
164
+
165
+
166
+ def main():
167
+
168
+ parser = argparse.ArgumentParser()
169
+
170
+ parser.add_argument('--max_token_len', type=int, default=300)
171
+ parser.add_argument('--batch_size', type=int, default=12)
172
+ parser.add_argument('--num_workers', type=int, default=5)
173
+
174
+ parser.add_argument('--num_neighbor_limit', type=int, default = None)
175
+
176
+ parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
177
+ parser.add_argument('--spatial_dist_fill', type=float, default = 20)
178
+
179
+
180
+ parser.add_argument('--with_type', default=False, action='store_true')
181
+ parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
182
+ parser.add_argument('--no_spatial_distance', default=False, action='store_true')
183
+
184
+ parser.add_argument('--bert_option', type=str, default='bert-base')
185
+ parser.add_argument('--prediction_save_dir', type=str, default=None)
186
+
187
+ parser.add_argument('--checkpoint_path', type=str, default=None)
188
+
189
+
190
+
191
+ args = parser.parse_args()
192
+ print('\n')
193
+ print(args)
194
+ print('\n')
195
+
196
+
197
+ # out_dir not None, and out_dir does not exist, then create out_dir
198
+ if args.prediction_save_dir is not None and not os.path.isdir(args.prediction_save_dir):
199
+ os.makedirs(args.prediction_save_dir)
200
+
201
+ testing(args)
202
+
203
+
204
+
205
+ if __name__ == '__main__':
206
+
207
+ main()
208
+
209
+
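The per-batch MRR in the test scripts is sklearn's label ranking average precision over one-hot ground truth and raw logits, accumulated with a batch-size weight and divided by the sample count at the end. A standalone sketch of that metric on toy arrays:

import numpy as np
from sklearn.metrics import label_ranking_average_precision_score

# toy batch: 2 samples, 9 classes (one-hot ground truth vs. unnormalized logits)
onehot_labels = np.zeros((2, 9))
onehot_labels[0, 3] = 1
onehot_labels[1, 0] = 1
logits = np.random.rand(2, 9)
logits[0, 3] += 1.0  # push the correct class towards the top for sample 0

mrr = label_ranking_average_precision_score(onehot_labels, logits)

# accumulation pattern used in the test loop
mrr_total, sample_cnt = 0.0, 0
mrr_total += mrr * onehot_labels.shape[0]
sample_cnt += onehot_labels.shape[0]
print(mrr_total / sample_cnt)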
models/spabert/experiments/semantic_typing/src/test_cls_baseline.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ import sys
3
+ from tqdm import tqdm # for our progress bar
4
+ import numpy as np
5
+ import argparse
6
+ from sklearn.preprocessing import LabelEncoder
7
+ import pdb
8
+
9
+
10
+ import torch
11
+ from torch.utils.data import DataLoader
12
+ from transformers import AdamW
13
+ import torch.nn.functional as F
14
+
15
+ sys.path.append('../../../')
16
+ from datasets.osm_sample_loader import PbfMapDataset
17
+ from datasets.const import *
18
+ from utils.baseline_utils import get_baseline_model
19
+ from models.baseline_typing_model import BaselineForSemanticTyping
20
+
21
+ from sklearn.metrics import label_ranking_average_precision_score
22
+ from sklearn.metrics import precision_recall_fscore_support
23
+
24
+ torch.backends.cudnn.deterministic = True
25
+ torch.backends.cudnn.benchmark = False
26
+ torch.manual_seed(42)
27
+ torch.cuda.manual_seed_all(42)
28
+
29
+
30
+
31
+ MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large',
32
+ 'spanbert-base','spanbert-large','luke-base','luke-large',
33
+ 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
34
+
35
+ def testing(args):
36
+
37
+ num_workers = args.num_workers
38
+ batch_size = args.batch_size
39
+ max_token_len = args.max_token_len
40
+
41
+ distance_norm_factor = args.distance_norm_factor
42
+ spatial_dist_fill=args.spatial_dist_fill
43
+ with_type = args.with_type
44
+ sep_between_neighbors = args.sep_between_neighbors
45
+ freeze_backbone = args.freeze_backbone
46
+
47
+
48
+ backbone_option = args.backbone_option
49
+
50
+ checkpoint_path = args.checkpoint_path
51
+
52
+ assert(backbone_option in MODEL_OPTIONS)
53
+
54
+
55
+ london_file_path = '../data/sql_output/osm-point-london-typing.json'
56
+ california_file_path = '../data/sql_output/osm-point-california-typing.json'
57
+
58
+
59
+
60
+ label_encoder = LabelEncoder()
61
+ label_encoder.fit(CLASS_9_LIST)
62
+
63
+
64
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
65
+
66
+ backbone_model, tokenizer = get_baseline_model(backbone_option)
67
+ model = BaselineForSemanticTyping(backbone_model, backbone_model.config.hidden_size, len(CLASS_9_LIST))
68
+
69
+ model.load_state_dict(torch.load(checkpoint_path) ) #, strict = False # load sentence position embedding weights as well
70
+
71
+ model.to(device)
72
+ model.eval() # evaluation mode for testing
73
+
74
+
75
+
76
+ london_dataset = PbfMapDataset(data_file_path = london_file_path,
77
+ tokenizer = tokenizer,
78
+ max_token_len = max_token_len,
79
+ distance_norm_factor = distance_norm_factor,
80
+ spatial_dist_fill = spatial_dist_fill,
81
+ with_type = with_type,
82
+ sep_between_neighbors = sep_between_neighbors,
83
+ label_encoder = label_encoder,
84
+ mode = 'test')
85
+
86
+
87
+ california_dataset = PbfMapDataset(data_file_path = california_file_path,
88
+ tokenizer = tokenizer,
89
+ max_token_len = max_token_len,
90
+ distance_norm_factor = distance_norm_factor,
91
+ spatial_dist_fill = spatial_dist_fill,
92
+ with_type = with_type,
93
+ sep_between_neighbors = sep_between_neighbors,
94
+ label_encoder = label_encoder,
95
+ mode = 'test')
96
+
97
+ test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])
98
+
99
+
100
+ test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers,
101
+ shuffle=False, pin_memory=True, drop_last=False)
102
+
103
+
104
+
105
+
106
+
107
+ print('start testing...')
108
+
109
+ # setup loop with TQDM and dataloader
110
+ loop = tqdm(test_loader, leave=True)
111
+
112
+ mrr_total = 0.
113
+ prec_total = 0.
114
+ sample_cnt = 0
115
+
116
+ gt_list = []
117
+ pred_list = []
118
+
119
+ for batch in loop:
120
+ # initialize calculated gradients (from prev step)
121
+
122
+ # pull all tensor batches required for training
123
+ input_ids = batch['pseudo_sentence'].to(device)
124
+ attention_mask = batch['attention_mask'].to(device)
125
+ position_ids = batch['sent_position_ids'].to(device)
126
+
127
+ #labels = batch['pseudo_sentence'].to(device)
128
+ labels = batch['pivot_type'].to(device)
129
+ pivot_lens = batch['pivot_token_len'].to(device)
130
+
131
+ outputs = model(input_ids, attention_mask = attention_mask, position_ids = position_ids,
132
+ labels = labels, pivot_len_list = pivot_lens)
133
+
134
+
135
+ onehot_labels = F.one_hot(labels, num_classes=9)
136
+
137
+ gt_list.extend(onehot_labels.cpu().detach().numpy())
138
+ pred_list.extend(outputs.logits.cpu().detach().numpy())
139
+
140
+ mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
141
+ mrr_total += mrr * input_ids.shape[0]
142
+ sample_cnt += input_ids.shape[0]
143
+
144
+ precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average=None)
145
+ precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average='micro')
146
+ print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
147
+ print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
148
+ print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores])
149
+ print('supports:\n', supports)
150
+ print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))
151
+
152
+ #pdb.set_trace()
153
+ #print(mrr_total/sample_cnt)
154
+
155
+
156
+
157
+ def main():
158
+
159
+ parser = argparse.ArgumentParser()
160
+ parser.add_argument('--num_workers', type=int, default=5)
161
+ parser.add_argument('--batch_size', type=int, default=12)
162
+ parser.add_argument('--max_token_len', type=int, default=300)
163
+
164
+
165
+ parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
166
+ parser.add_argument('--spatial_dist_fill', type=float, default = 20)
167
+
168
+ parser.add_argument('--with_type', default=False, action='store_true')
169
+ parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
170
+ parser.add_argument('--freeze_backbone', default=False, action='store_true')
171
+
172
+ parser.add_argument('--backbone_option', type=str, default='bert-base')
173
+ parser.add_argument('--checkpoint_path', type=str, default=None)
174
+
175
+
176
+ args = parser.parse_args()
177
+ print('\n')
178
+ print(args)
179
+ print('\n')
180
+
181
+
182
+ testing(args)
183
+
184
+
185
+ if __name__ == '__main__':
186
+
187
+ main()
188
+
189
+
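The final metrics reported by the test scripts take an argmax over the accumulated one-hot labels and logits and hand the resulting class indices to precision_recall_fscore_support, once per class and once micro-averaged. A toy illustration of those two calls:

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# toy accumulated rows: one-hot ground truth and corresponding logits (3 classes)
gt_list = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]])
pred_list = np.array([[0.9, 0.1, 0.0], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7], [0.6, 0.3, 0.1]])

y_true = np.argmax(gt_list, axis=1)
y_pred = np.argmax(pred_list, axis=1)

precisions, recalls, fscores, supports = precision_recall_fscore_support(y_true, y_pred, average=None)
precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='micro')
print(precisions, recalls, fscores, supports)  # per-class arrays
print(precision, recall, f1)                   # micro-averaged scalars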
models/spabert/experiments/semantic_typing/src/test_cls_spatialbert.py ADDED
@@ -0,0 +1,214 @@
1
+ import os
2
+ import sys
3
+ from transformers import RobertaTokenizer, BertTokenizer
4
+ from tqdm import tqdm # for our progress bar
5
+ from transformers import AdamW
6
+
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+ import torch.nn.functional as F
10
+
11
+ sys.path.append('../../../')
12
+ from models.spatial_bert_model import SpatialBertModel
13
+ from models.spatial_bert_model import SpatialBertConfig
14
+ from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
15
+ from datasets.osm_sample_loader import PbfMapDataset
16
+ from datasets.const import *
17
+ from transformers.models.bert.modeling_bert import BertForMaskedLM
18
+
19
+ from sklearn.metrics import label_ranking_average_precision_score
20
+ from sklearn.metrics import precision_recall_fscore_support
21
+ import numpy as np
22
+ import argparse
23
+ from sklearn.preprocessing import LabelEncoder
24
+ import pdb
25
+
26
+
27
+ DEBUG = False
28
+
29
+ torch.backends.cudnn.deterministic = True
30
+ torch.backends.cudnn.benchmark = False
31
+ torch.manual_seed(42)
32
+ torch.cuda.manual_seed_all(42)
33
+
34
+
35
+ def testing(args):
36
+
37
+ max_token_len = args.max_token_len
38
+ batch_size = args.batch_size
39
+ num_workers = args.num_workers
40
+ distance_norm_factor = args.distance_norm_factor
41
+ spatial_dist_fill=args.spatial_dist_fill
42
+ with_type = args.with_type
43
+ sep_between_neighbors = args.sep_between_neighbors
44
+ checkpoint_path = args.checkpoint_path
45
+ if_no_spatial_distance = args.no_spatial_distance
46
+
47
+ bert_option = args.bert_option
48
+
49
+
50
+
51
+ if args.num_classes == 9:
52
+ london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing.json'
53
+ california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing.json'
54
+ TYPE_LIST = CLASS_9_LIST
55
+ type_key_str = 'class'
56
+ elif args.num_classes == 74:
57
+ london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing-ranking.json'
58
+ california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing-ranking.json'
59
+ TYPE_LIST = CLASS_74_LIST
60
+ type_key_str = 'fine_class'
61
+ else:
62
+ raise NotImplementedError
63
+
64
+ if bert_option == 'bert-base':
65
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
66
+ config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(TYPE_LIST))
67
+ elif bert_option == 'bert-large':
68
+ tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
69
+ config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24,num_semantic_types=len(TYPE_LIST))
70
+ else:
71
+ raise NotImplementedError
72
+
73
+
74
+ model = SpatialBertForSemanticTyping(config)
75
+
76
+
77
+ label_encoder = LabelEncoder()
78
+ label_encoder.fit(TYPE_LIST)
79
+
80
+ london_dataset = PbfMapDataset(data_file_path = london_file_path,
81
+ tokenizer = tokenizer,
82
+ max_token_len = max_token_len,
83
+ distance_norm_factor = distance_norm_factor,
84
+ spatial_dist_fill = spatial_dist_fill,
85
+ with_type = with_type,
86
+ type_key_str = type_key_str,
87
+ sep_between_neighbors = sep_between_neighbors,
88
+ label_encoder = label_encoder,
89
+ mode = 'test')
90
+
91
+ california_dataset = PbfMapDataset(data_file_path = california_file_path,
92
+ tokenizer = tokenizer,
93
+ max_token_len = max_token_len,
94
+ distance_norm_factor = distance_norm_factor,
95
+ spatial_dist_fill = spatial_dist_fill,
96
+ with_type = with_type,
97
+ type_key_str = type_key_str,
98
+ sep_between_neighbors = sep_between_neighbors,
99
+ label_encoder = label_encoder,
100
+ mode = 'test')
101
+
102
+ test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])
103
+
104
+
105
+
106
+ test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers,
107
+ shuffle=False, pin_memory=True, drop_last=False)
108
+
109
+
110
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
111
+ model.to(device)
112
+
113
+ model.load_state_dict(torch.load(checkpoint_path))
114
+
115
+ model.eval()
116
+
117
+
118
+
119
+ print('start testing...')
120
+
121
+
122
+ # setup loop with TQDM and dataloader
123
+ loop = tqdm(test_loader, leave=True)
124
+
125
+
126
+ mrr_total = 0.
127
+ prec_total = 0.
128
+ sample_cnt = 0
129
+
130
+ gt_list = []
131
+ pred_list = []
132
+
133
+ for batch in loop:
134
+ # initialize calculated gradients (from prev step)
135
+
136
+ # pull all tensor batches required for training
137
+ input_ids = batch['pseudo_sentence'].to(device)
138
+ attention_mask = batch['attention_mask'].to(device)
139
+ position_list_x = batch['norm_lng_list'].to(device)
140
+ position_list_y = batch['norm_lat_list'].to(device)
141
+ sent_position_ids = batch['sent_position_ids'].to(device)
142
+
143
+ #labels = batch['pseudo_sentence'].to(device)
144
+ labels = batch['pivot_type'].to(device)
145
+ pivot_lens = batch['pivot_token_len'].to(device)
146
+
147
+ outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
148
+ position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)
149
+
150
+
151
+ onehot_labels = F.one_hot(labels, num_classes=len(TYPE_LIST))
152
+
153
+ gt_list.extend(onehot_labels.cpu().detach().numpy())
154
+ pred_list.extend(outputs.logits.cpu().detach().numpy())
155
+
156
+ mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
157
+ mrr_total += mrr * input_ids.shape[0]
158
+ sample_cnt += input_ids.shape[0]
159
+
160
+ precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average=None)
161
+ # print('precisions:\n', precisions)
162
+ # print('recalls:\n', recalls)
163
+ # print('fscores:\n', fscores)
164
+ # print('supports:\n', supports)
165
+ precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average='micro')
166
+ print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
167
+ print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
168
+ print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores])
169
+ print('supports:\n', supports)
170
+ print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))
171
+
172
+
173
+
174
+
175
+ def main():
176
+
177
+ parser = argparse.ArgumentParser()
178
+
179
+ parser.add_argument('--max_token_len', type=int, default=512)
180
+ parser.add_argument('--batch_size', type=int, default=12)
181
+ parser.add_argument('--num_workers', type=int, default=5)
182
+
183
+ parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
184
+ parser.add_argument('--spatial_dist_fill', type=float, default = 100)
185
+ parser.add_argument('--num_classes', type=int, default = 9)
186
+
187
+ parser.add_argument('--with_type', default=False, action='store_true')
188
+ parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
189
+ parser.add_argument('--no_spatial_distance', default=False, action='store_true')
190
+
191
+ parser.add_argument('--bert_option', type=str, default='bert-base')
192
+ parser.add_argument('--prediction_save_dir', type=str, default=None)
193
+
194
+ parser.add_argument('--checkpoint_path', type=str, default=None)
195
+
196
+
197
+ args = parser.parse_args()
198
+ print('\n')
199
+ print(args)
200
+ print('\n')
201
+
202
+
203
+ # if prediction_save_dir is not None and does not exist, create it
204
+ if args.prediction_save_dir is not None and not os.path.isdir(args.prediction_save_dir):
205
+ os.makedirs(args.prediction_save_dir)
206
+
207
+ testing(args)
208
+
209
+
210
+ if __name__ == '__main__':
211
+
212
+ main()
213
+
214
+
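Class names are mapped to integer ids with sklearn's LabelEncoder and then one-hot encoded with torch.nn.functional.one_hot before scoring, as in the test loops above. A minimal sketch; the three class names are placeholders rather than the actual CLASS_9_LIST entries.

import torch
import torch.nn.functional as F
from sklearn.preprocessing import LabelEncoder

TYPE_LIST = ['education', 'natural', 'transportation']  # placeholder class names

label_encoder = LabelEncoder()
label_encoder.fit(TYPE_LIST)

# class names -> integer ids -> one-hot rows
labels = torch.tensor(label_encoder.transform(['natural', 'education']))
onehot_labels = F.one_hot(labels, num_classes=len(TYPE_LIST))
print(onehot_labels)  # tensor([[0, 1, 0], [1, 0, 0]])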
models/spabert/experiments/semantic_typing/src/train_cls_baseline.py ADDED
@@ -0,0 +1,227 @@
1
+ import os
2
+ import sys
3
+ from tqdm import tqdm # for our progress bar
4
+ import numpy as np
5
+ import argparse
6
+ from sklearn.preprocessing import LabelEncoder
7
+ import pdb
8
+
9
+
10
+ import torch
11
+ from torch.utils.data import DataLoader
12
+ from transformers import AdamW
13
+
14
+ sys.path.append('../../../')
15
+ from datasets.osm_sample_loader import PbfMapDataset
16
+ from datasets.const import *
17
+ from utils.baseline_utils import get_baseline_model
18
+ from models.baseline_typing_model import BaselineForSemanticTyping
19
+
20
+
21
+ MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large',
22
+ 'spanbert-base','spanbert-large','luke-base','luke-large',
23
+ 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
24
+
25
+ def training(args):
26
+
27
+ num_workers = args.num_workers
28
+ batch_size = args.batch_size
29
+ epochs = args.epochs
30
+ lr = args.lr #1e-7 # 5e-5
31
+ save_interval = args.save_interval
32
+ max_token_len = args.max_token_len
33
+ distance_norm_factor = args.distance_norm_factor
34
+ spatial_dist_fill=args.spatial_dist_fill
35
+ with_type = args.with_type
36
+ sep_between_neighbors = args.sep_between_neighbors
37
+ freeze_backbone = args.freeze_backbone
38
+
39
+
40
+ backbone_option = args.backbone_option
41
+
42
+ assert(backbone_option in MODEL_OPTIONS)
43
+
44
+
45
+ london_file_path = '../data/sql_output/osm-point-london-typing.json'
46
+ california_file_path = '../data/sql_output/osm-point-california-typing.json'
47
+
48
+ if args.model_save_dir is None:
49
+ freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze'
50
+ sep_pathstr = '_sep' if sep_between_neighbors else '_nosep'
51
+ model_save_dir = '/data2/zekun/spatial_bert_baseline_weights/typing_lr' + str("{:.0e}".format(lr)) +'_'+backbone_option+ freeze_pathstr + sep_pathstr + '_london_california_bsize' + str(batch_size)
52
+
53
+ if not os.path.isdir(model_save_dir):
54
+ os.makedirs(model_save_dir)
55
+ else:
56
+ model_save_dir = args.model_save_dir
57
+
58
+ print('model_save_dir', model_save_dir)
59
+ print('\n')
60
+
61
+
62
+ label_encoder = LabelEncoder()
63
+ label_encoder.fit(CLASS_9_LIST)
64
+
65
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
66
+
67
+ backbone_model, tokenizer = get_baseline_model(backbone_option)
68
+ model = BaselineForSemanticTyping(backbone_model, backbone_model.config.hidden_size, len(CLASS_9_LIST))
69
+
70
+ model.to(device)
71
+ model.train()
72
+
73
+ london_train_val_dataset = PbfMapDataset(data_file_path = london_file_path,
74
+ tokenizer = tokenizer,
75
+ max_token_len = max_token_len,
76
+ distance_norm_factor = distance_norm_factor,
77
+ spatial_dist_fill = spatial_dist_fill,
78
+ with_type = with_type,
79
+ sep_between_neighbors = sep_between_neighbors,
80
+ label_encoder = label_encoder,
81
+ mode = 'train')
82
+
83
+ percent_80 = int(len(london_train_val_dataset) * 0.8)
84
+ london_train_dataset, london_val_dataset = torch.utils.data.random_split(london_train_val_dataset, [percent_80, len(london_train_val_dataset) - percent_80])
85
+
86
+ california_train_val_dataset = PbfMapDataset(data_file_path = california_file_path,
87
+ tokenizer = tokenizer,
88
+ max_token_len = max_token_len,
89
+ distance_norm_factor = distance_norm_factor,
90
+ spatial_dist_fill = spatial_dist_fill,
91
+ with_type = with_type,
92
+ sep_between_neighbors = sep_between_neighbors,
93
+ label_encoder = label_encoder,
94
+ mode = 'train')
95
+ percent_80 = int(len(california_train_val_dataset) * 0.8)
96
+ california_train_dataset, california_val_dataset = torch.utils.data.random_split(california_train_val_dataset, [percent_80, len(california_train_val_dataset) - percent_80])
97
+
98
+ train_dataset = torch.utils.data.ConcatDataset([london_train_dataset, california_train_dataset])
99
+ val_dataset = torch.utils.data.ConcatDataset([london_val_dataset, california_val_dataset])
100
+
101
+
102
+
103
+ train_loader = DataLoader(train_dataset, batch_size= batch_size, num_workers=num_workers,
104
+ shuffle=True, pin_memory=True, drop_last=True)
105
+ val_loader = DataLoader(val_dataset, batch_size= batch_size, num_workers=num_workers,
106
+ shuffle=False, pin_memory=True, drop_last=False)
107
+
108
+
109
+
110
+
111
+
112
+
113
+
114
+ # initialize optimizer
115
+ optim = AdamW(model.parameters(), lr = lr)
116
+
117
+ print('start training...')
118
+
119
+ for epoch in range(epochs):
120
+ # setup loop with TQDM and dataloader
121
+ loop = tqdm(train_loader, leave=True)
122
+ iter = 0
123
+ for batch in loop:
124
+ # initialize calculated gradients (from prev step)
125
+ optim.zero_grad()
126
+ # pull all tensor batches required for training
127
+ input_ids = batch['pseudo_sentence'].to(device)
128
+ attention_mask = batch['attention_mask'].to(device)
129
+ position_ids = batch['sent_position_ids'].to(device)
130
+
131
+ #labels = batch['pseudo_sentence'].to(device)
132
+ labels = batch['pivot_type'].to(device)
133
+ pivot_lens = batch['pivot_token_len'].to(device)
134
+
135
+ outputs = model(input_ids, attention_mask = attention_mask, position_ids = position_ids,
136
+ labels = labels, pivot_len_list = pivot_lens)
137
+
138
+
139
+ loss = outputs.loss
140
+ loss.backward()
141
+ optim.step()
142
+
143
+ loop.set_description(f'Epoch {epoch}')
144
+ loop.set_postfix({'loss':loss.item()})
145
+
146
+
147
+ iter += 1
148
+
149
+ if iter % save_interval == 0 or iter == loop.total:
150
+ loss_valid = validating(val_loader, model, device)
151
+ print('validation loss', loss_valid)
152
+
153
+ save_path = os.path.join(model_save_dir, 'ep'+str(epoch) + '_iter'+ str(iter).zfill(5) \
154
+ + '_' +str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) +'.pth' )
155
+
156
+ torch.save(model.state_dict(), save_path)
157
+ print('saving model checkpoint to', save_path)
158
+
159
+
160
+
161
+ def validating(val_loader, model, device):
162
+
163
+ with torch.no_grad():
164
+
165
+ loss_valid = 0
166
+ loop = tqdm(val_loader, leave=True)
167
+
168
+ for batch in loop:
169
+ input_ids = batch['pseudo_sentence'].to(device)
170
+ attention_mask = batch['attention_mask'].to(device)
171
+ position_ids = batch['sent_position_ids'].to(device)
172
+
173
+ labels = batch['pivot_type'].to(device)
174
+ pivot_lens = batch['pivot_token_len'].to(device)
175
+
176
+ outputs = model(input_ids, attention_mask = attention_mask, position_ids = position_ids,
177
+ labels = labels, pivot_len_list = pivot_lens)
178
+
179
+ loss_valid += outputs.loss
180
+
181
+ loss_valid /= len(val_loader)
182
+
183
+ return loss_valid
184
+
185
+
186
+ def main():
187
+
188
+ parser = argparse.ArgumentParser()
189
+ parser.add_argument('--num_workers', type=int, default=5)
190
+ parser.add_argument('--batch_size', type=int, default=12)
191
+ parser.add_argument('--epochs', type=int, default=10)
192
+ parser.add_argument('--save_interval', type=int, default=2000)
193
+ parser.add_argument('--max_token_len', type=int, default=300)
194
+
195
+
196
+ parser.add_argument('--lr', type=float, default = 5e-5)
197
+ parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
198
+ parser.add_argument('--spatial_dist_fill', type=float, default = 20)
199
+
200
+ parser.add_argument('--with_type', default=False, action='store_true')
201
+ parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
202
+ parser.add_argument('--freeze_backbone', default=False, action='store_true')
203
+
204
+ parser.add_argument('--backbone_option', type=str, default='bert-base')
205
+ parser.add_argument('--model_save_dir', type=str, default=None)
206
+
207
+
208
+
209
+ args = parser.parse_args()
210
+ print('\n')
211
+ print(args)
212
+ print('\n')
213
+
214
+
215
+ # if model_save_dir is not None and does not exist, create it
216
+ if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir):
217
+ os.makedirs(args.model_save_dir)
218
+
219
+ training(args)
220
+
221
+
222
+
223
+ if __name__ == '__main__':
224
+
225
+ main()
226
+
227
+
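Both training scripts assemble their data the same way: an 80/20 random_split per region followed by a ConcatDataset across regions. The pattern in isolation, with small TensorDatasets standing in for the London and California PbfMapDataset instances:

import torch
from torch.utils.data import TensorDataset, ConcatDataset, random_split, DataLoader

# stand-ins for the London / California PbfMapDataset instances
london = TensorDataset(torch.arange(100).float().unsqueeze(1))
california = TensorDataset(torch.arange(60).float().unsqueeze(1))

def split_80_20(dataset):
    n_train = int(len(dataset) * 0.8)
    return random_split(dataset, [n_train, len(dataset) - n_train])

london_train, london_val = split_80_20(london)
cal_train, cal_val = split_80_20(california)

train_dataset = ConcatDataset([london_train, cal_train])
val_dataset = ConcatDataset([london_val, cal_val])

train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=12, shuffle=False, drop_last=False)
print(len(train_dataset), len(val_dataset))  # 128 32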
models/spabert/experiments/semantic_typing/src/train_cls_spatialbert.py ADDED
@@ -0,0 +1,276 @@
1
+ import os
2
+ import sys
3
+ from transformers import RobertaTokenizer, BertTokenizer
4
+ from tqdm import tqdm # for our progress bar
5
+ from transformers import AdamW
6
+
7
+ import torch
8
+ from torch.utils.data import DataLoader
9
+
10
+ sys.path.append('../../../')
11
+ from models.spatial_bert_model import SpatialBertModel
12
+ from models.spatial_bert_model import SpatialBertConfig
13
+ from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
14
+ from datasets.osm_sample_loader import PbfMapDataset
15
+ from datasets.const import *
16
+
17
+ from transformers.models.bert.modeling_bert import BertForMaskedLM
18
+
19
+ import numpy as np
20
+ import argparse
21
+ from sklearn.preprocessing import LabelEncoder
22
+ import pdb
23
+
24
+
25
+ DEBUG = False
26
+
27
+
28
+ def training(args):
29
+
30
+ num_workers = args.num_workers
31
+ batch_size = args.batch_size
32
+ epochs = args.epochs
33
+ lr = args.lr #1e-7 # 5e-5
34
+ save_interval = args.save_interval
35
+ max_token_len = args.max_token_len
36
+ distance_norm_factor = args.distance_norm_factor
37
+ spatial_dist_fill=args.spatial_dist_fill
38
+ with_type = args.with_type
39
+ sep_between_neighbors = args.sep_between_neighbors
40
+ freeze_backbone = args.freeze_backbone
41
+ mlm_checkpoint_path = args.mlm_checkpoint_path
42
+
43
+ if_no_spatial_distance = args.no_spatial_distance
44
+
45
+
46
+ bert_option = args.bert_option
47
+
48
+ assert bert_option in ['bert-base','bert-large']
49
+
50
+ if args.num_classes == 9:
51
+ london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing.json'
52
+ california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing.json'
53
+ TYPE_LIST = CLASS_9_LIST
54
+ type_key_str = 'class'
55
+ elif args.num_classes == 74:
56
+ london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing-ranking.json'
57
+ california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing-ranking.json'
58
+ TYPE_LIST = CLASS_74_LIST
59
+ type_key_str = 'fine_class'
60
+ else:
61
+ raise NotImplementedError
62
+
63
+
64
+ if args.model_save_dir is None:
65
+ checkpoint_basename = os.path.basename(mlm_checkpoint_path)
66
+ checkpoint_prefix = checkpoint_basename.replace("mlm_mem_keeppos_","").strip('.pth')
67
+
68
+ sep_pathstr = '_sep' if sep_between_neighbors else '_nosep'
69
+ freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze'
70
+ if if_no_spatial_distance:
71
+ model_save_dir = '/data2/zekun/spatial_bert_weights_ablation/'
72
+ else:
73
+ model_save_dir = '/data2/zekun/spatial_bert_weights/'
74
+ model_save_dir = os.path.join(model_save_dir, 'typing_lr' + str("{:.0e}".format(lr)) + sep_pathstr +'_'+bert_option+ freeze_pathstr + '_london_california_bsize' + str(batch_size) )
75
+ model_save_dir = os.path.join(model_save_dir, checkpoint_prefix)
76
+
77
+ if not os.path.isdir(model_save_dir):
78
+ os.makedirs(model_save_dir)
79
+ else:
80
+ model_save_dir = args.model_save_dir
81
+
82
+
83
+ print('model_save_dir', model_save_dir)
84
+ print('\n')
85
+
86
+ if bert_option == 'bert-base':
87
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
88
+ config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(TYPE_LIST))
89
+ elif bert_option == 'bert-large':
90
+ tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
91
+ config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24, num_semantic_types=len(TYPE_LIST))
92
+ else:
93
+ raise NotImplementedError
94
+
95
+
96
+
97
+ label_encoder = LabelEncoder()
98
+ label_encoder.fit(TYPE_LIST)
99
+
100
+
101
+ london_train_val_dataset = PbfMapDataset(data_file_path = london_file_path,
102
+ tokenizer = tokenizer,
103
+ max_token_len = max_token_len,
104
+ distance_norm_factor = distance_norm_factor,
105
+ spatial_dist_fill = spatial_dist_fill,
106
+ with_type = with_type,
107
+ type_key_str = type_key_str,
108
+ sep_between_neighbors = sep_between_neighbors,
109
+ label_encoder = label_encoder,
110
+ mode = 'train')
111
+
112
+ percent_80 = int(len(london_train_val_dataset) * 0.8)
113
+ london_train_dataset, london_val_dataset = torch.utils.data.random_split(london_train_val_dataset, [percent_80, len(london_train_val_dataset) - percent_80])
114
+
115
+ california_train_val_dataset = PbfMapDataset(data_file_path = california_file_path,
116
+ tokenizer = tokenizer,
117
+ max_token_len = max_token_len,
118
+ distance_norm_factor = distance_norm_factor,
119
+ spatial_dist_fill = spatial_dist_fill,
120
+ with_type = with_type,
121
+ type_key_str = type_key_str,
122
+ sep_between_neighbors = sep_between_neighbors,
123
+ label_encoder = label_encoder,
124
+ mode = 'train')
125
+ percent_80 = int(len(california_train_val_dataset) * 0.8)
126
+ california_train_dataset, california_val_dataset = torch.utils.data.random_split(california_train_val_dataset, [percent_80, len(california_train_val_dataset) - percent_80])
127
+
128
+ train_dataset = torch.utils.data.ConcatDataset([london_train_dataset, california_train_dataset])
129
+ val_dataset = torch.utils.data.ConcatDataset([london_val_dataset, california_val_dataset])
130
+
131
+
132
+ if DEBUG:
133
+ train_loader = DataLoader(train_dataset, batch_size= batch_size, num_workers=num_workers,
134
+ shuffle=False, pin_memory=True, drop_last=True)
135
+ val_loader = DataLoader(val_dataset, batch_size= batch_size, num_workers=num_workers,
136
+ shuffle=False, pin_memory=True, drop_last=False)
137
+ else:
138
+ train_loader = DataLoader(train_dataset, batch_size= batch_size, num_workers=num_workers,
139
+ shuffle=True, pin_memory=True, drop_last=True)
140
+ val_loader = DataLoader(val_dataset, batch_size= batch_size, num_workers=num_workers,
141
+ shuffle=False, pin_memory=True, drop_last=False)
142
+
143
+
144
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
145
+
146
+ model = SpatialBertForSemanticTyping(config)
147
+ model.to(device)
148
+
149
+
150
+ model.load_state_dict(torch.load(mlm_checkpoint_path), strict = False)
151
+
152
+ model.train()
153
+
154
+
155
+
156
+ # initialize optimizer
157
+ optim = AdamW(model.parameters(), lr = lr)
158
+
159
+ print('start training...')
160
+
161
+ for epoch in range(epochs):
162
+ # setup loop with TQDM and dataloader
163
+ loop = tqdm(train_loader, leave=True)
164
+ iter = 0
165
+ for batch in loop:
166
+ # initialize calculated gradients (from prev step)
167
+ optim.zero_grad()
168
+ # pull all tensor batches required for training
169
+ input_ids = batch['pseudo_sentence'].to(device)
170
+ attention_mask = batch['attention_mask'].to(device)
171
+ position_list_x = batch['norm_lng_list'].to(device)
172
+ position_list_y = batch['norm_lat_list'].to(device)
173
+ sent_position_ids = batch['sent_position_ids'].to(device)
174
+
175
+ #labels = batch['pseudo_sentence'].to(device)
176
+ labels = batch['pivot_type'].to(device)
177
+ pivot_lens = batch['pivot_token_len'].to(device)
178
+
179
+ outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
180
+ position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)
181
+
182
+
183
+ loss = outputs.loss
184
+ loss.backward()
185
+ optim.step()
186
+
187
+ loop.set_description(f'Epoch {epoch}')
188
+ loop.set_postfix({'loss':loss.item()})
189
+
190
+ if DEBUG:
191
+ print('ep'+str(epoch)+'_' + '_iter'+ str(iter).zfill(5), loss.item() )
192
+
193
+ iter += 1
194
+
195
+ if iter % save_interval == 0 or iter == loop.total:
196
+ loss_valid = validating(val_loader, model, device)
197
+
198
+ save_path = os.path.join(model_save_dir, 'keeppos_ep'+str(epoch) + '_iter'+ str(iter).zfill(5) \
199
+ + '_' +str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) +'.pth' )
200
+
201
+ torch.save(model.state_dict(), save_path)
202
+ print('validation loss', loss_valid)
203
+ print('saving model checkpoint to', save_path)
204
+
205
+ def validating(val_loader, model, device):
206
+
207
+ with torch.no_grad():
208
+
209
+ loss_valid = 0
210
+ loop = tqdm(val_loader, leave=True)
211
+
212
+ for batch in loop:
213
+ input_ids = batch['pseudo_sentence'].to(device)
214
+ attention_mask = batch['attention_mask'].to(device)
215
+ position_list_x = batch['norm_lng_list'].to(device)
216
+ position_list_y = batch['norm_lat_list'].to(device)
217
+ sent_position_ids = batch['sent_position_ids'].to(device)
218
+
219
+
220
+ labels = batch['pivot_type'].to(device)
221
+ pivot_lens = batch['pivot_token_len'].to(device)
222
+
223
+ outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
224
+ position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)
225
+
226
+ loss_valid += outputs.loss
227
+
228
+ loss_valid /= len(val_loader)
229
+
230
+ return loss_valid
+
+
+ def main():
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--num_workers', type=int, default=5)
+     parser.add_argument('--batch_size', type=int, default=12)
+     parser.add_argument('--epochs', type=int, default=10)
+     parser.add_argument('--save_interval', type=int, default=2000)
+     parser.add_argument('--max_token_len', type=int, default=512)
+
+     parser.add_argument('--lr', type=float, default=5e-5)
+     parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
+     parser.add_argument('--spatial_dist_fill', type=float, default=100)
+     parser.add_argument('--num_classes', type=int, default=9)
+
+     parser.add_argument('--with_type', default=False, action='store_true')
+     parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
+     parser.add_argument('--freeze_backbone', default=False, action='store_true')
+     parser.add_argument('--no_spatial_distance', default=False, action='store_true')
+
+     parser.add_argument('--bert_option', type=str, default='bert-base')
+     parser.add_argument('--model_save_dir', type=str, default=None)
+
+     parser.add_argument('--mlm_checkpoint_path', type=str, default=None)
+
+     args = parser.parse_args()
+     print('\n')
+     print(args)
+     print('\n')
+
+     # create model_save_dir if it was given and does not exist yet
+     if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir):
+         os.makedirs(args.model_save_dir)
+
+     training(args)
+
+
+ if __name__ == '__main__':
+     main()
+
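
A hypothetical way to drive this script; the script file name, checkpoint path, and output directory are placeholders, while the flag names come from the argparse definition above:

    import sys

    sys.argv = [
        'train_typing.py',                      # placeholder script name
        '--batch_size', '12',
        '--epochs', '10',
        '--lr', '5e-5',
        '--num_classes', '9',
        '--sep_between_neighbors',
        '--mlm_checkpoint_path', 'weights/mlm_pretrained.pth',   # placeholder path
        '--model_save_dir', 'weights/typing_checkpoints',        # placeholder path
    ]
    main()
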
models/spabert/models/__init__.py ADDED
File without changes
models/spabert/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes)

models/spabert/models/__pycache__/spatial_bert_model.cpython-310.pyc ADDED
Binary file (20.8 kB)

models/spabert/models/baseline_typing_model.py ADDED
@@ -0,0 +1,106 @@
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn import CrossEntropyLoss, MSELoss
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+
+ class PivotEntityPooler(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, hidden_states, pivot_len_list):
+         # "Pool" the model by taking the hidden states of the pivot-entity tokens
+         # (positions 1 .. pivot_len, i.e. right after [CLS]) and averaging them.
+
+         bsize = hidden_states.shape[0]
+
+         tensor_list = []
+         for i in torch.arange(0, bsize):
+             pivot_token_full = hidden_states[i, 1:pivot_len_list[i] + 1]                  # (pivot_len, hidden)
+             pivot_token_tensor = torch.mean(torch.unsqueeze(pivot_token_full, 0), dim=1)  # (1, hidden)
+             tensor_list.append(pivot_token_tensor)
+
+         batch_pivot_tensor = torch.cat(tensor_list, dim=0)                                # (batch, hidden)
+
+         return batch_pivot_tensor
+
+
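
A small sketch of what the pooler computes, using made-up shapes (batch of 2, sequence length 6, hidden size 4):

    import torch

    pooler = PivotEntityPooler()
    hidden_states = torch.randn(2, 6, 4)      # (batch, seq_len, hidden)
    pivot_len_list = torch.tensor([3, 1])     # pivot entity spans 3 tokens and 1 token
    pooled = pooler(hidden_states, pivot_len_list)
    print(pooled.shape)                       # torch.Size([2, 4])
    # pooled[0] equals hidden_states[0, 1:4].mean(dim=0); position 0 ([CLS]) is skipped
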
+ class BaselineTypingHead(nn.Module):
+     def __init__(self, hidden_size, num_semantic_types):
+         super().__init__()
+         self.dense = nn.Linear(hidden_size, hidden_size)
+         self.activation = nn.Tanh()
+
+         self.seq_relationship = nn.Linear(hidden_size, num_semantic_types)
+
+     def forward(self, pivot_pooled_output):
+
+         pivot_pooled_output = self.dense(pivot_pooled_output)
+         pivot_pooled_output = self.activation(pivot_pooled_output)
+
+         seq_relationship_score = self.seq_relationship(pivot_pooled_output)
+         return seq_relationship_score
+
+
+ class BaselineForSemanticTyping(nn.Module):
+     """Baseline typing model: a plain BERT-style backbone plus a pivot-entity typing head."""
+
+     def __init__(self, backbone, hidden_size, num_semantic_types):
+         super().__init__()
+
+         self.backbone = backbone
+         self.pivot_pooler = PivotEntityPooler()
+         self.num_semantic_types = num_semantic_types
+
+         self.cls = BaselineTypingHead(hidden_size, num_semantic_types)
+
+     def forward(
+         self,
+         input_ids=None,
+         position_ids=None,
+         pivot_len_list=None,
+         attention_mask=None,
+         head_mask=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=True,
+     ):
+
+         outputs = self.backbone(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+         )
+
+         # last hidden states -> average over the pivot tokens -> typing logits
+         sequence_output = outputs[0]
+         pooled_output = self.pivot_pooler(sequence_output, pivot_len_list)
+
+         type_prediction_score = self.cls(pooled_output)
+
+         typing_loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             typing_loss = loss_fct(type_prediction_score.view(-1, self.num_semantic_types), labels.view(-1))
+
+         if not return_dict:
+             output = (type_prediction_score,) + outputs[2:]
+             return ((typing_loss,) + output) if typing_loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=typing_loss,
+             logits=type_prediction_score,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
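
A minimal usage sketch, assuming a Hugging Face bert-base-uncased backbone (hidden size 768) and the 9 semantic-type classes used by the fine-tuning script above:

    import torch
    from transformers import BertModel

    backbone = BertModel.from_pretrained('bert-base-uncased')
    model = BaselineForSemanticTyping(backbone, hidden_size=768, num_semantic_types=9)

    input_ids = torch.randint(0, backbone.config.vocab_size, (2, 16))  # dummy token ids
    attention_mask = torch.ones_like(input_ids)
    pivot_len_list = torch.tensor([3, 2])                              # pivot token counts per example
    labels = torch.tensor([0, 4])                                      # dummy semantic-type labels

    out = model(input_ids, attention_mask=attention_mask,
                pivot_len_list=pivot_len_list, labels=labels)
    print(out.loss, out.logits.shape)                                  # logits: (2, 9)
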