training: fix dtype mismatch in DeepseekV2MoE output when training
#6
by
Jack477
- opened
- modeling_deepseek.py +1 -0
modeling_deepseek.py
CHANGED
@@ -577,6 +577,7 @@ class DeepseekV2MoE(nn.Module):
|
|
577 |
for i, expert in enumerate(self.experts):
|
578 |
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
|
579 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
|
|
580 |
y = y.view(*orig_shape)
|
581 |
y = AddAuxiliaryLoss.apply(y, aux_loss)
|
582 |
else:
|
|
|
577 |
for i, expert in enumerate(self.experts):
|
578 |
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
|
579 |
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
580 |
+
y = y.type(hidden_states.dtype)
|
581 |
y = y.view(*orig_shape)
|
582 |
y = AddAuxiliaryLoss.apply(y, aux_loss)
|
583 |
else:
|