Guanzheng committed on
Commit 08636e2
1 Parent(s): b0d0cf7

Update modeling_llama.py

Files changed (1)
  1. modeling_llama.py +18 -10
modeling_llama.py CHANGED
@@ -166,14 +166,17 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, q_len, position_ids):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    q_embed = (q * cos[:, :, -q_len:, :]) + (rotate_half(q) * sin[:, :, -q_len:, :])
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    cos_q = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin_q = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+
+    cos_k = cos[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin_k = sin[key_position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
+    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
     return q_embed, k_embed
 
 
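The hunk above replaces the single `position_ids` argument (and the `-q_len:` slicing of `cos`/`sin`) with separate position ids for queries and keys, so the queries and the full cached key sequence are each rotated at their own absolute positions. Below is a minimal, self-contained sketch of how the new signature could be exercised for a single decode step; the toy shapes and the cos/sin table construction are illustrative assumptions, not part of this commit.

import torch


def rotate_half(x):
    # Rotate half the hidden dims of the input (same helper as in the file above).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids):
    # Copy of the updated function: queries and keys are gathered with their
    # own position ids, so a 1-token query can be rotated against a longer cache.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos_q = cos[position_ids].unsqueeze(1)      # [bs, 1, q_len, dim]
    sin_q = sin[position_ids].unsqueeze(1)      # [bs, 1, q_len, dim]
    cos_k = cos[key_position_ids].unsqueeze(1)  # [bs, 1, kv_len, dim]
    sin_k = sin[key_position_ids].unsqueeze(1)  # [bs, 1, kv_len, dim]
    q_embed = (q * cos_q) + (rotate_half(q) * sin_q)
    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
    return q_embed, k_embed


# Toy decode step: one new query token attending over a cache of 8 keys.
bs, heads, head_dim, kv_len, q_len = 1, 2, 8, 8, 1
q = torch.randn(bs, heads, q_len, head_dim)
k = torch.randn(bs, heads, kv_len, head_dim)

# Standard RoPE cos/sin table shaped [1, 1, kv_len, head_dim], built here only
# for the demo (the module gets it from its rotary embedding).
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(kv_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]

position_ids = torch.tensor([[kv_len - 1]])              # query sits at the last position
key_position_ids = torch.arange(kv_len).view(1, kv_len)  # 0 .. kv_len-1 for the whole cache

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids, key_position_ids)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 1, 8]) torch.Size([1, 2, 8, 8])

Because `position_ids` already selects exactly the query's positions, the old `cos[:, :, -q_len:, :]` slicing is no longer needed.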
 
@@ -282,28 +285,33 @@ class LlamaAttention(nn.Module):
 
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            cache_key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        else:
+            cache_key_states = key_states
 
         if pack_cos_sin is not None:
             cos, sin = pack_cos_sin.to(query_states.device)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         key_position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, q_len, key_position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, cache_key_states, cos, sin, position_ids, key_position_ids)
 
         if past_key_value is not None:
             # reuse k, v, self_attention
+            # key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-        past_key_value = (key_states, value_states) if use_cache else None
+        past_key_value = (cache_key_states, value_states) if use_cache else None
 
-        use_flashatn = self.config.use_flashattn and is_flash_attn_available()
+        use_flashattn = self.config.use_flashattn and is_flash_attn_available()
 
         if self.log_scale:
             log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
                     torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
             query_states = query_states * log_n
-        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] or use_flashatn:
+
+
+        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] and not use_flashattn:
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
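In this second hunk the layer caches the pre-rotation keys (`cache_key_states`) and rebuilds `key_position_ids` as 0 .. kv_seq_len-1 on every call, so rotary embeddings are re-applied over the whole key window at each step rather than being stored already rotated. The toy decode loop below sketches that cache flow; the random tensors and the flattened control flow are stand-ins for the module's actual forward pass, included only to illustrate what the new `past_key_value` tuple holds.

import torch

bs, heads, head_dim = 1, 2, 8
past_key_value = None

for step in range(3):
    # New (not yet rotated) key/value for the current token.
    key_states = torch.randn(bs, heads, 1, head_dim)
    value_states = torch.randn(bs, heads, 1, head_dim)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
        # Concatenate the *unrotated* cached keys with the new key.
        cache_key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    else:
        cache_key_states = key_states

    # Key positions always cover the full window; the new query sits at the end.
    key_position_ids = torch.arange(kv_seq_len, dtype=torch.long).unsqueeze(0).view(-1, kv_seq_len)

    # apply_rotary_pos_emb(query_states, cache_key_states, cos, sin,
    #                      position_ids, key_position_ids) would be called here.

    # The cache now stores the unrotated keys, matching the updated tuple.
    past_key_value = (cache_key_states, value_states)
    print(step, cache_key_states.shape, tuple(key_position_ids.shape))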