Update modeling_t2.py
Browse files- modeling_t2.py +1 -1
modeling_t2.py
CHANGED
@@ -235,7 +235,7 @@ class TransformerAttention(nn.Module):
|
|
235 |
k = torch.cat((past_key, k), dim=-2)
|
236 |
v = torch.cat((past_value, v), dim=-2)
|
237 |
|
238 |
-
cos, sin = self.rotary_emb(v,
|
239 |
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
240 |
|
241 |
if use_cache is True:
|
|
|
235 |
k = torch.cat((past_key, k), dim=-2)
|
236 |
v = torch.cat((past_value, v), dim=-2)
|
237 |
|
238 |
+
cos, sin = self.rotary_emb(v, position_ids)
|
239 |
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
240 |
|
241 |
if use_cache is True:
|