rrivera1849 committed on
Commit 8191fb5 · 1 Parent(s): 47450e4

Upload LUAR

Files changed (1):
  1. model.py +9 -5
model.py CHANGED
@@ -16,11 +16,15 @@ class SelfAttention(nn.Module):
         super(SelfAttention, self).__init__()
 
     def forward(self, k, q, v):
-        d_k = q.size(-1)
-        scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
-        p_attn = F.softmax(scores, dim=-1)
+        if hasattr(F, "scaled_dot_product_attention") and torch.cuda.is_available():
+            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
+                return F.scaled_dot_product_attention(k, q, v)
+        else:
+            d_k = q.size(-1)
+            scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
+            p_attn = F.softmax(scores, dim=-1)
 
-        return torch.matmul(p_attn, v)
+            return torch.matmul(p_attn, v)
 
 class LUAR(PreTrainedModel):
     """Defines the LUAR model.
@@ -85,4 +89,4 @@ class LUAR(PreTrainedModel):
         """
         output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions)
 
-        return output
+        return output
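
For reference: the change routes SelfAttention.forward through PyTorch's fused F.scaled_dot_product_attention kernel when CUDA is available, keeping the original manual softmax computation as the fallback. A minimal sketch (assuming PyTorch 2.x; tensor shapes are illustrative, not taken from the commit) checking that the two paths agree under the commit's (k, q, v) argument order:

import math

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim).
k = torch.randn(2, 4, 8, 16)
q = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# Manual path, as in the fallback branch: softmax(k @ q^T / sqrt(d_k)) @ v.
d_k = q.size(-1)
scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
manual = torch.matmul(F.softmax(scores, dim=-1), v)

# Fused path, as in the new branch. F.scaled_dot_product_attention takes
# (query, key, value), so passing (k, q, v) treats k as the query, which
# matches the manual computation above.
fused = F.scaled_dot_product_attention(k, q, v)

assert torch.allclose(manual, fused, atol=1e-5)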