wuzhiying2023
commited on
Commit
•
c6f8592
1
Parent(s):
c21a0af
fix NormHead eval
Browse files- modeling_baichuan.py +1 -0
modeling_baichuan.py
CHANGED
@@ -511,6 +511,7 @@ class NormHead(nn.Module):
|
|
511 |
def forward(self, hidden_states):
|
512 |
if self.training:
|
513 |
norm_weight = nn.functional.normalize(self.weight)
|
|
|
514 |
elif self.first_flag:
|
515 |
self.first_flag = False
|
516 |
self.weight.data = nn.functional.normalize(self.weight)
|
|
|
511 |
def forward(self, hidden_states):
|
512 |
if self.training:
|
513 |
norm_weight = nn.functional.normalize(self.weight)
|
514 |
+
self.first_flag = True
|
515 |
elif self.first_flag:
|
516 |
self.first_flag = False
|
517 |
self.weight.data = nn.functional.normalize(self.weight)
|