wzk1015 favor123 commited on
Commit
a4c0f7d
1 Parent(s): 82e212b

Update modeling_internlm2_ve.py (#6)

Browse files

- Update modeling_internlm2_ve.py (40e5dc42cdbd98aafc20738caadd02a89eb63e3e)


Co-authored-by: Gen Luo <favor123@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_internlm2_ve.py +11 -13
modeling_internlm2_ve.py CHANGED
@@ -689,20 +689,18 @@ class InternLM2DecoderLayer(nn.Module):
689
  hidden_states = self.ffn_norm(hidden_states)
690
 
691
  if past_key_value is None:
692
- """
693
- *************************
694
- maybe faster
695
- ***************************
696
- """
697
- ##############################################################################################################
698
- # dim=hidden_states.shape[-1]
699
- # visual_token_mask=visual_token_mask.repeat(1,1,dim)
700
- # if visual_token_mask.any():
701
- # hidden_states[visual_token_mask] = self.feed_forward_ve(hidden_states[visual_token_mask].reshape(-1,dim)).reshape(-1)
702
- # if (~visual_token_mask).any():
703
- # hidden_states[~visual_token_mask] = self.feed_forward(hidden_states[~visual_token_mask].reshape(-1,dim)).reshape(-1)
704
  ##############################################################################################################
705
- hidden_states = self.feed_forward(hidden_states)*(1.-visual_token_mask)+ self.feed_forward_ve(hidden_states)*visual_token_mask
706
  else:
707
  hidden_states = self.feed_forward(hidden_states)
708
 
 
689
  hidden_states = self.ffn_norm(hidden_states)
690
 
691
  if past_key_value is None:
692
+ ##########################################--modified by luogen--##############################################
693
+ if self.training:
694
+ hidden_states = self.feed_forward(hidden_states)*(1.-visual_token_mask)+ self.feed_forward_ve(hidden_states)*visual_token_mask
695
+ else:
696
+ dim=hidden_states.shape[-1]
697
+ visual_token_mask=visual_token_mask.repeat(1,1,dim).bool()
698
+ non_visual_token_mask=~visual_token_mask
699
+ if visual_token_mask.any():
700
+ hidden_states[visual_token_mask] = self.feed_forward_ve(hidden_states[visual_token_mask].reshape(-1,dim)).reshape(-1)
701
+ if (non_visual_token_mask).any():
702
+ hidden_states[non_visual_token_mask] = self.feed_forward(hidden_states[non_visual_token_mask].reshape(-1,dim)).reshape(-1)
 
703
  ##############################################################################################################
 
704
  else:
705
  hidden_states = self.feed_forward(hidden_states)
706