rrivera1849 committed
Commit · a2d0e7b
1 Parent(s): 376029b
Upload LUAR
model.py CHANGED
@@ -44,12 +44,13 @@ class LUAR(PreTrainedModel):
     def mean_pooling(self, token_embeddings, attention_mask):
         """Mean Pooling as described in the SBERT paper.
         """
-        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
+        # input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
+        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).type(token_embeddings.type())
         sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
         sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
         return sum_embeddings / sum_mask
 
-    def get_episode_embeddings(self, input_ids, attention_mask):
+    def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False):
         """Computes the Author Embedding.
         """
         B, E, _ = attention_mask.shape
@@ -61,7 +62,8 @@ class LUAR(PreTrainedModel):
             input_ids=input_ids,
             attention_mask=attention_mask,
             return_dict=True,
-            output_hidden_states=True
+            output_hidden_states=True,
+            output_attentions=output_attentions,
         )
 
         # at this point, we're embedding individual "comments"
@@ -74,11 +76,14 @@ class LUAR(PreTrainedModel):
 
         episode_embeddings = self.linear(episode_embeddings)
 
+        if output_attentions:
+            return episode_embeddings, outputs["attentions"]
+
         return episode_embeddings
 
-    def forward(self, input_ids, attention_mask):
+    def forward(self, input_ids, attention_mask, output_attentions=False):
         """Calculates a fixed-length feature vector for a batch of episode samples.
         """
-        output = self.get_episode_embeddings(input_ids, attention_mask)
+        output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions)
 
         return output
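After this commit, forward can optionally return the backbone transformer's attention maps alongside the author embedding, and the expanded attention mask in mean_pooling now follows the dtype of the token embeddings instead of being forced to float32, so half-precision activations are not upcast by the mask. Below is a minimal usage sketch of the new output_attentions flag; the repository id, tokenizer choice, sequence length, and episode shape are assumptions for illustration and are not part of this commit.

import torch
from transformers import AutoModel, AutoTokenizer

# Hypothetical checkpoint id for illustration; substitute the actual LUAR repository.
repo_id = "rrivera1849/LUAR-MUD"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# LUAR embeds an "episode" of texts per author: B authors, E comments each, L tokens per comment.
texts = ["first comment by the author", "second comment by the author"]
batch = tokenizer(texts, padding="max_length", truncation=True, max_length=32, return_tensors="pt")

B, E, L = 1, len(texts), 32
input_ids = batch["input_ids"].reshape(B, E, L)
attention_mask = batch["attention_mask"].reshape(B, E, L)

with torch.no_grad():
    # With the new flag, the model returns the attentions of the underlying transformer
    # in addition to the author embedding.
    embeddings, attentions = model(input_ids, attention_mask, output_attentions=True)

print(embeddings.shape)  # (B, embedding_dim)
print(len(attentions))   # typically one attention tensor per transformer layer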