JasonTPhillipsJr commited on
Commit
d914cbe
·
verified ·
1 Parent(s): dc9ff0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -19
app.py CHANGED
@@ -9,40 +9,32 @@ from models.spabert.utils.common_utils import load_spatial_bert_pretrained_weigh
9
 
10
  from PIL import Image
11
 
12
- ##LOAD MODEL SECTION##
 
 
13
  nlp = spacy.load("./models/en_core_web_sm")
14
 
15
  #BERT Initialization Section
16
  bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
17
  bert_model = BertModel.from_pretrained("bert-base-uncased")
 
18
  bert_model.eval()
19
 
20
  #SpaBERT Initialization Section
21
  data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json'
22
- pretrained_model = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
 
23
 
24
- b_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
25
- b_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
26
 
27
  config = SpatialBertConfig()
28
  config.output_hidden_states = True
29
-
30
  spaBERT_model = SpatialBertForMaskedLM(config)
31
- spaBERT_model.load_state_dict(b_model.state_dict(), strict = False)
32
 
33
- pre_trained_model = torch.load(pretrained_model, map_location=torch.device('cpu'))
34
-
35
- model_keys = spaBERT_model.state_dict()
36
- cnt_layers = 0
37
- for key in model_keys:
38
- if key in pre_trained_model:
39
- model_keys[key] = pre_trained_model[key]
40
- cnt_layers += 1
41
- else:
42
- print("No weight for", key)
43
- print(cnt_layers, 'layers loaded')
44
 
45
- spaBERT_model.load_state_dict(model_keys)
46
  spaBERT_model.to(device)
47
  spaBERT_model.eval()
48
 
@@ -69,7 +61,7 @@ COLOR_MAP = {
69
  'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)')
70
  }
71
 
72
- # Display the color key with descriptions
73
  st.write("**Color Key:**")
74
  for label, (color, description) in COLOR_MAP.items():
75
  st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}", unsafe_allow_html=True)
 
9
 
10
  from PIL import Image
11
 
12
+ device = torch.device('cpu')
13
+
14
+ #Spacy Initialization Section
15
  nlp = spacy.load("./models/en_core_web_sm")
16
 
17
  #BERT Initialization Section
18
  bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
19
  bert_model = BertModel.from_pretrained("bert-base-uncased")
20
+ bert_model.to(device)
21
  bert_model.eval()
22
 
23
  #SpaBERT Initialization Section
24
  data_file_path = 'models/spabert/datasets/SPABERT_finetuning_data_combined.json'
25
+ pretrained_model_path = 'models/spabert/datasets/fine-spabert-base-uncased-finetuned-osm-mn.pth'
26
+ pre_trained_model = torch.load(pretrained_model_path, map_location=torch.device('cpu'))
27
 
28
+ #b_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
29
+ #b_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
30
 
31
  config = SpatialBertConfig()
32
  config.output_hidden_states = True
 
33
  spaBERT_model = SpatialBertForMaskedLM(config)
 
34
 
35
+ spaBERT_model.load_state_dict(bert_model.state_dict(), strict = False)
36
+ spaBERT_model.load_state_dict(pre_trained_model, strict=False)
 
 
 
 
 
 
 
 
 
37
 
 
38
  spaBERT_model.to(device)
39
  spaBERT_model.eval()
40
 
 
61
  'GPE': ('green', 'Geopolitical Entities (e.g., countries, cities)')
62
  }
63
 
64
+ # Display the color key
65
  st.write("**Color Key:**")
66
  for label, (color, description) in COLOR_MAP.items():
67
  st.markdown(f"- **{label}**: <span style='color:{color}'>{color}</span> - {description}", unsafe_allow_html=True)