Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -75,7 +75,7 @@ def process_entity(batch, model, device):
|
|
75 |
spaBERT_embedding = outputs.hidden_states[-1].to(device)
|
76 |
|
77 |
# Extract the [CLS] token embedding (first token)
|
78 |
-
spaBERT_embedding =
|
79 |
|
80 |
#pivot_token_len = batch['pivot_token_len'].item()
|
81 |
#pivot_embeddings = embeddings[:, :pivot_token_len, :]
|
@@ -87,11 +87,11 @@ spaBERT_embeddings = []
|
|
87 |
for i, batch in enumerate(data_loader):
|
88 |
if i >= 2: # Stop after processing 3 batches
|
89 |
break
|
90 |
-
|
91 |
-
|
92 |
|
93 |
-
st.write("SpaBERT Embedding shape:",
|
94 |
-
st.write("SpaBERT Embedding:",
|
95 |
|
96 |
|
97 |
|
|
|
75 |
spaBERT_embedding = outputs.hidden_states[-1].to(device)
|
76 |
|
77 |
# Extract the [CLS] token embedding (first token)
|
78 |
+
spaBERT_embedding = spaBERT_embedding[:, 0, :].detach() # [batch_size, hidden_size]
|
79 |
|
80 |
#pivot_token_len = batch['pivot_token_len'].item()
|
81 |
#pivot_embeddings = embeddings[:, :pivot_token_len, :]
|
|
|
87 |
for i, batch in enumerate(data_loader):
|
88 |
if i >= 2: # Stop after processing 3 batches
|
89 |
break
|
90 |
+
spaBERT_embedding, input_ids = process_entity(batch, spaBERT_model, device)
|
91 |
+
spaBERT_embeddings.append(spaBERT_embedding)
|
92 |
|
93 |
+
st.write("SpaBERT Embedding shape:", spaBERT_embedding[0].shape)
|
94 |
+
st.write("SpaBERT Embedding:", spaBERT_embedding[0])
|
95 |
|
96 |
|
97 |
|