Update modeling_phi.py
Browse files- modeling_phi.py +2 -2
modeling_phi.py
CHANGED
@@ -296,8 +296,8 @@ class MoE(nn.Module):
|
|
296 |
config: PretrainedConfig,
|
297 |
):
|
298 |
super().__init__()
|
299 |
-
self.mlp = nn.ModuleList([MLP(config) for i in range(config.
|
300 |
-
self.gate = nn.Linear(config.n_embd, config.
|
301 |
self.num_experts_per_tok = config.num_experts_per_tok
|
302 |
|
303 |
def forward(self, x):
|
|
|
296 |
config: PretrainedConfig,
|
297 |
):
|
298 |
super().__init__()
|
299 |
+
self.mlp = nn.ModuleList([MLP(config) for i in range(config.num_local_experts)])
|
300 |
+
self.gate = nn.Linear(config.n_embd, config.num_local_experts, bias=False)
|
301 |
self.num_experts_per_tok = config.num_experts_per_tok
|
302 |
|
303 |
def forward(self, x):
|