rrivera1849 committed
Commit a2d0e7b · 1 Parent(s): 376029b

Upload LUAR

Files changed (1)
  1. model.py +10 -5
model.py CHANGED
@@ -44,12 +44,13 @@ class LUAR(PreTrainedModel):
     def mean_pooling(self, token_embeddings, attention_mask):
         """Mean Pooling as described in the SBERT paper.
         """
-        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
+        # input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).float()
+        input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).type(token_embeddings.type())
         sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum')
         sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9)
         return sum_embeddings / sum_mask
 
-    def get_episode_embeddings(self, input_ids, attention_mask):
+    def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False):
         """Computes the Author Embedding.
         """
         B, E, _ = attention_mask.shape
@@ -61,7 +62,8 @@ class LUAR(PreTrainedModel):
             input_ids=input_ids,
             attention_mask=attention_mask,
             return_dict=True,
-            output_hidden_states=True
+            output_hidden_states=True,
+            output_attentions=output_attentions,
         )
 
         # at this point, we're embedding individual "comments"
@@ -74,11 +76,14 @@ class LUAR(PreTrainedModel):
 
         episode_embeddings = self.linear(episode_embeddings)
 
+        if output_attentions:
+            return episode_embeddings, outputs["attentions"]
+
         return episode_embeddings
 
-    def forward(self, input_ids, attention_mask):
+    def forward(self, input_ids, attention_mask, output_attentions=False):
         """Calculates a fixed-length feature vector for a batch of episode samples.
         """
-        output = self.get_episode_embeddings(input_ids, attention_mask)
+        output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions)
 
         return output
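
The commit makes two changes: the mean-pooling mask is now cast to the same dtype as the token embeddings instead of being hard-coded to float32 (presumably so the model behaves correctly when the backbone is run in half precision), and an output_attentions flag is threaded from forward through get_episode_embeddings into the underlying transformer so callers can retrieve the attention maps alongside the author embedding. Below is a minimal usage sketch, not part of the commit: the repo id, tokenizer settings, and tensor shapes are illustrative assumptions, and it presumes the checkpoint is loaded as custom code with trust_remote_code=True.

import torch
from transformers import AutoModel, AutoTokenizer

# Hypothetical repo id for illustration; substitute the actual LUAR checkpoint.
repo_id = "rrivera1849/LUAR-MUD"
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True).eval()

# LUAR embeds an "episode" of texts per author; inputs are shaped
# (batch, episode_length, max_tokens), matching `B, E, _ = attention_mask.shape`.
texts = ["first post written by an author", "second post by the same author"]
tok = tokenizer(texts, padding="max_length", truncation=True,
                max_length=64, return_tensors="pt")
input_ids = tok["input_ids"].unsqueeze(0)            # (1, 2, 64)
attention_mask = tok["attention_mask"].unsqueeze(0)  # (1, 2, 64)

with torch.no_grad():
    # Default call: one fixed-length author embedding per episode.
    embedding = model(input_ids, attention_mask)
    # New in this commit: also return the backbone's attention maps.
    embedding, attentions = model(input_ids, attention_mask, output_attentions=True)

print(embedding.shape)   # (1, embedding_dim)
print(len(attentions))   # one attention tensor per transformer layer

Note that with output_attentions=True the model returns a tuple rather than a single tensor, so existing callers are unaffected only as long as the flag stays at its default of False.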