Zilin Zhu commited on
Commit
c042756
·
1 Parent(s): 980d254
Files changed (1) hide show
  1. modeling_gpt2_summarize.py +1 -1
modeling_gpt2_summarize.py CHANGED
@@ -327,7 +327,7 @@ class GPT2Attention(nn.Module):
327
 
328
  if layer_past is not None:
329
  past_key, past_value = layer_past
330
- key = torch.cat((past_key, key), dim=-2)
331
  value = torch.cat((past_value, value), dim=-2)
332
 
333
  if use_cache is True:
 
327
 
328
  if layer_past is not None:
329
  past_key, past_value = layer_past
330
+ key = torch.cat((past_key, key), dim=-1)
331
  value = torch.cat((past_value, value), dim=-2)
332
 
333
  if use_cache is True: