wuzhiying2023 commited on
Commit
c6f8592
1 Parent(s): c21a0af

fix NormHead eval

Browse files
Files changed (1) hide show
  1. modeling_baichuan.py +1 -0
modeling_baichuan.py CHANGED
@@ -511,6 +511,7 @@ class NormHead(nn.Module):
511
  def forward(self, hidden_states):
512
  if self.training:
513
  norm_weight = nn.functional.normalize(self.weight)
 
514
  elif self.first_flag:
515
  self.first_flag = False
516
  self.weight.data = nn.functional.normalize(self.weight)
 
511
  def forward(self, hidden_states):
512
  if self.training:
513
  norm_weight = nn.functional.normalize(self.weight)
514
+ self.first_flag = True
515
  elif self.first_flag:
516
  self.first_flag = False
517
  self.weight.data = nn.functional.normalize(self.weight)