JasonTPhillipsJr commited on
Commit
0cea6d5
·
verified ·
1 Parent(s): 857dba3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -75,7 +75,7 @@ def process_entity(batch, model, device):
75
  spaBERT_embedding = outputs.hidden_states[-1].to(device)
76
 
77
  # Extract the [CLS] token embedding (first token)
78
- spaBERT_embedding = embeddings[:, 0, :].detach() # [batch_size, hidden_size]
79
 
80
  #pivot_token_len = batch['pivot_token_len'].item()
81
  #pivot_embeddings = embeddings[:, :pivot_token_len, :]
@@ -87,11 +87,11 @@ spaBERT_embeddings = []
87
  for i, batch in enumerate(data_loader):
88
  if i >= 2: # Stop after processing 3 batches
89
  break
90
- embeddings, input_ids = process_entity(batch, spaBERT_model, device)
91
- all_embeddings.append(embeddings)
92
 
93
- st.write("SpaBERT Embedding shape:", all_embeddings[0].shape)
94
- st.write("SpaBERT Embedding:", all_embeddings[0])
95
 
96
 
97
 
 
75
  spaBERT_embedding = outputs.hidden_states[-1].to(device)
76
 
77
  # Extract the [CLS] token embedding (first token)
78
+ spaBERT_embedding = spaBERT_embedding[:, 0, :].detach() # [batch_size, hidden_size]
79
 
80
  #pivot_token_len = batch['pivot_token_len'].item()
81
  #pivot_embeddings = embeddings[:, :pivot_token_len, :]
 
87
  for i, batch in enumerate(data_loader):
88
  if i >= 2: # Stop after processing 3 batches
89
  break
90
+ spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
91
+ spaBERT_embeddings.append(spaBERT_embedding)
92
 
93
+ st.write("SpaBERT Embedding shape:", spaBERT_embedding[0].shape)
94
+ st.write("SpaBERT Embedding:", spaBERT_embedding[0])
95
 
96
 
97