ZihanWang314
committed on
Commit
•
d3eed00
1
Parent(s):
3399ce9
Update modeling_deepseek.py
Browse files
convert dtype in ESFT so trainable experts of fp32 can be aggregated with frozen experts of bf16
- modeling_deepseek.py +8 -1
modeling_deepseek.py
CHANGED
@@ -388,7 +388,14 @@ class DeepseekV2MLP(nn.Module):
|
|
388 |
self.act_fn = ACT2FN[config.hidden_act]
|
389 |
|
390 |
def forward(self, x):
|
391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
392 |
return down_proj
|
393 |
|
394 |
|
|
|
388 |
self.act_fn = ACT2FN[config.hidden_act]
|
389 |
|
390 |
def forward(self, x):
    """Apply the gated MLP to ``x`` and return the result of ``down_proj``.

    When the dtype of ``x`` differs from the dtype of ``up_proj``'s weight,
    ``x`` is cast to the weight dtype before the projections and the output
    is cast back to the original dtype of ``x``. Per the commit note, this
    lets trainable fp32 experts be aggregated with frozen bf16 experts in
    ESFT. When the dtypes already match, no casting is performed at all.
    """
    weight_dtype = self.up_proj.weight.dtype
    input_dtype = x.dtype
    needs_cast = input_dtype != weight_dtype

    if needs_cast:
        # Run the projections in the dtype the expert weights are stored in.
        x = x.to(weight_dtype)

    gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
    down_proj = self.down_proj(gated)

    # Only cast back when the input was cast in, so the matching-dtype path
    # returns exactly what the projections produced.
    return down_proj.to(input_dtype) if needs_cast else down_proj
|
400 |
|
401 |
|