ZihanWang314 committed on
Commit
d3eed00
1 Parent(s): 3399ce9

Update modeling_deepseek.py

Browse files

Convert dtype in ESFT so that trainable experts in fp32 can be aggregated with frozen experts in bf16.

Files changed (1) hide show
  1. modeling_deepseek.py +8 -1
modeling_deepseek.py CHANGED
@@ -388,7 +388,14 @@ class DeepseekV2MLP(nn.Module):
388
  self.act_fn = ACT2FN[config.hidden_act]
389
 
390
  def forward(self, x):
391
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
 
 
 
 
 
392
  return down_proj
393
 
394
 
 
388
  self.act_fn = ACT2FN[config.hidden_act]
389
 
390
  def forward(self, x):
391
+ # convert dtype in ESFT so trainable experts of fp32 can be aggregated with frozen experts of bf16
392
+ if x.dtype != self.up_proj.weight.dtype:
393
+ xdtype = x.dtype
394
+ x = x.to(self.up_proj.weight.dtype)
395
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
396
+ down_proj = down_proj.to(xdtype)
397
+ else:
398
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
399
  return down_proj
400
 
401