Update modeling_llama.py
modeling_llama.py  (+18 −10)  CHANGED
@@ -166,14 +166,17 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin,
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-
-
-
-
+    cos_q = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin_q = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+
+    cos_k = cos[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin_k = sin[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
+    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
     return q_embed, k_embed

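The net effect of this hunk: queries are still rotated at their own `position_ids`, while keys are now rotated with a separate `key_position_ids` that spans the full (cached + current) key length. Below is a minimal, self-contained sketch on toy tensors; `rotate_half` and the function body mirror the diff, but the cos/sin table construction and the toy shapes are stand-ins, not code from this repository.

# Minimal sketch of the reworked apply_rotary_pos_emb, run on toy tensors.
# rotate_half and the function body mirror the diff; the rotary table below
# is a stand-in for the model's rotary_emb module.
import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos_q = cos[position_ids].unsqueeze(1)      # [bs, 1, q_len, dim]
    sin_q = sin[position_ids].unsqueeze(1)      # [bs, 1, q_len, dim]
    cos_k = cos[key_position_ids].unsqueeze(1)  # [bs, 1, kv_seq_len, dim]
    sin_k = sin[key_position_ids].unsqueeze(1)  # [bs, 1, kv_seq_len, dim]
    q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
    return q_embed, k_embed

bsz, num_heads, head_dim = 1, 2, 8
kv_seq_len, q_len = 5, 1  # single-token decode step: 4 cached keys + 1 new key

# Stand-in rotary table with the usual [1, 1, seq_len, dim] layout.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(kv_seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]

q = torch.randn(bsz, num_heads, q_len, head_dim)        # only the new token's query
k = torch.randn(bsz, num_heads, kv_seq_len, head_dim)   # cached + new keys, un-rotated

position_ids = torch.tensor([[kv_seq_len - 1]])           # the query sits at the last position
key_position_ids = torch.arange(kv_seq_len).unsqueeze(0)  # keys cover positions 0 .. kv_seq_len-1

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 1, 8]) torch.Size([1, 2, 5, 8])

The asymmetry matters during cached decoding, where `q_len` (here 1) and `kv_seq_len` differ.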
@@ -282,28 +285,33 @@ class LlamaAttention(nn.Module):

         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-
+            cache_key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        else:
+            cache_key_states = key_states

         if pack_cos_sin is not None:
             cos, sin = pack_cos_sin.to(query_states.device)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         key_position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states,
+        query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)

         if past_key_value is not None:
             # reuse k, v, self_attention
+            # key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)

-        past_key_value = (
+        past_key_value = (cache_key_states, value_states) if use_cache else None

-
+        use_flashattn = self.config.use_flashattn and is_flash_attn_available()

         if self.log_scale:
             log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
                 torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
             query_states = query_states * log_n
-
+
+
+        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] and not use_flashattn:
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
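Reading of the second hunk: the cache entry now holds the keys before rotation (`cache_key_states`), and RoPE is re-applied to the full key sequence on every forward pass with `key_position_ids = torch.arange(kv_seq_len)`, which is what the new `apply_rotary_pos_emb` signature enables; the rotated `key_states` are no longer what gets cached. The log-scale branch that the hunk leaves in place is simple arithmetic: queries are scaled by log(kv_seq_len) / log(max_position_embeddings), so attention logits grow slowly once the context exceeds `max_position_embeddings`. A tiny sketch of that factor, assuming a 4096-token `max_position_embeddings` purely for illustration:

# Numeric illustration of the log-scaled attention kept in this hunk: queries are
# multiplied by log(kv_seq_len) / log(max_position_embeddings) before the QK^T matmul.
# The 4096-token base window below is an assumption for illustration only.
import math

max_position_embeddings = 4096
for kv_seq_len in (1024, 4096, 8192, 16384):
    log_n = math.log(kv_seq_len) / math.log(max_position_embeddings)
    print(f"kv_seq_len={kv_seq_len:6d}  log_n={log_n:.3f}")
# kv_seq_len=  1024  log_n=0.833
# kv_seq_len=  4096  log_n=1.000
# kv_seq_len=  8192  log_n=1.083
# kv_seq_len= 16384  log_n=1.167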