Abdullah-Nazhat committed
Commit 6a18f40 · verified · 1 Parent(s): 746885f

Update linearizer.py

Files changed (1)
  1. linearizer.py +6 -6
linearizer.py CHANGED
@@ -7,7 +7,7 @@ import math
 
 
 
-# helper functions
+
 
 def default(val, default_val):
     return val if val is not None else default_val
@@ -18,7 +18,7 @@ def init_(tensor):
     tensor.uniform_(-std, std)
     return tensor
 
-# helper classes
+
 
 class Residual(nn.Module):
     def __init__(self, fn):
@@ -110,25 +110,25 @@ class LinformerSelfAttention(nn.Module):
 
         kv_projs = (self.proj_k, self.proj_v if not self.share_kv else self.proj_k)
 
-        # project keys and values along the sequence length dimension to k
+
 
         keys, values = map(proj_seq_len, zip((keys, values), kv_projs))
 
-        # merge head into batch for queries and key / values
+
 
         queries = queries.reshape(b, n, h, -1).transpose(1, 2)
 
         merge_key_values = lambda t: t.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1)
         keys, values = map(merge_key_values, (keys, values))
 
-        # attention
+
 
         dots = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * (d_h ** -0.5)
         attn = dots.softmax(dim=-1)
         attn = self.dropout(attn)
         out = torch.einsum('bhnk,bhkd->bhnd', attn, values)
 
-        # split heads
+
         out = out.transpose(1, 2).reshape(b, n, -1)
         return self.to_out(out)
 
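For reference, since this commit only strips inline comments, below is a minimal standalone sketch of the attention that the unchanged lines in the last hunk compute: keys and values are projected along the sequence-length dimension from n down to k, then ordinary scaled dot-product attention runs over the shortened length. The function name, its signature, and the single shared projection matrix are illustrative assumptions, not part of linearizer.py.

import torch

def linformer_style_attention(queries, keys, values, proj, heads):
    # queries, keys, values: (batch, seq_len, heads * dim_head)
    # proj: (seq_len, k) learned projection that compresses the sequence axis
    b, n, dim = queries.shape
    h = heads
    d_h = dim // h
    k = proj.shape[1]

    # project keys and values along the sequence length dimension: (b, n, d) -> (b, k, d)
    proj_seq_len = lambda t: torch.einsum('bnd,nk->bkd', t, proj)
    keys, values = map(proj_seq_len, (keys, values))

    # split heads for the queries: (b, n, h*d_h) -> (b, h, n, d_h)
    queries = queries.reshape(b, n, h, d_h).transpose(1, 2)

    # split heads for keys/values; the expand is a no-op here but mirrors the
    # diff, where it also covers a single key/value head shared across query heads
    merge_key_values = lambda t: t.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1)
    keys, values = map(merge_key_values, (keys, values))

    # attention over the compressed length k instead of the full length n
    dots = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * (d_h ** -0.5)
    attn = dots.softmax(dim=-1)
    out = torch.einsum('bhnk,bhkd->bhnd', attn, values)

    # merge heads back: (b, h, n, d_h) -> (b, n, h*d_h)
    return out.transpose(1, 2).reshape(b, n, -1)

# example shapes: batch 2, sequence length 128 compressed to 32, 4 heads
x = torch.randn(2, 128, 64)
proj = torch.randn(128, 32)
out = linformer_style_attention(x, x, x, proj, heads=4)
print(out.shape)  # torch.Size([2, 128, 64])

Because the softmax runs over the projected length k rather than n, the attention cost scales linearly in sequence length, which is what the einsum patterns in the hunk above implement.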