import os
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F

from mamba_config import MambaConfig
from mlp import MLP
def sinkhorn(cost, tol=0.0001):
    """Sinkhorn-based MoE routing function.

    Iteratively rescales the (token, expert) cost matrix so that its rows and
    columns are approximately balanced, which encourages an even assignment of
    tokens across experts.
    """
    cost = torch.exp(2.0 * cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = 1 / (cost.size(1) * torch.sum(cost, 0))

    eps = 1e-8
    error = 1e9
    d1_old = d1
    # Alternate row and column rescaling until the column scaling converges.
    while error > tol:
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)
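# Illustrative sketch (comments only, not executed on import): applying
# sinkhorn() to a router-logit matrix of shape [num_tokens, num_experts].
# The sizes below are made up for the example.
#
#     logits = torch.randn(8, 4)       # 8 tokens routed across 4 experts
#     balanced = sinkhorn(logits)      # same shape, rows/columns rebalanced
#     top1 = balanced.argmax(dim=1)    # load-balanced top-1 expert per token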

class SwitchMLP(nn.Module):
    """
    Top-1 Mixture of Experts layer: routes each input token to one of N MLP "experts".
    Currently supports Sinkhorn-based expert routing.
    """
    def __init__(self, config: MambaConfig, layer_idx=None):
        super().__init__()

        self.layer = layer_idx
        self.config: MambaConfig = config
        # Per-layer expert counts can be overridden via config.mamba_moe_layers;
        # otherwise fall back to the global num_moe_experts.
        if config.mamba_moe_layers:
            self.num_moe_experts = int(config.mamba_moe_layers[layer_idx - 1][-1])
        else:
            self.num_moe_experts = self.config.num_moe_experts
        # The router produces one logit per expert for every token.
        self.router = torch.nn.Linear(self.config.hidden_size, self.num_moe_experts)
        self.add_bias = config.add_bias_linear
        self.routing = config.routing_mode
        self.route_algo = sinkhorn
        self.router_activation = torch.sigmoid

        self.num_local_experts = self.num_moe_experts
        self.local_expert_indices = list(range(self.num_local_experts))

        # One MLP expert module per local expert index.
        self.local_experts = torch.nn.ModuleList()
        for _ in range(self.num_local_experts):
            expert = MLP(self.config, is_expert=True, layer_idx=layer_idx)
            self.local_experts.append(expert)
    def gather_indices(self, local_indices):
        # Identity here: all experts are local, so the indices are already global.
        return local_indices
    def forward(self, hidden_states, inference_params=None):
        hidden_shape = hidden_states.shape
        # Router logits: one score per expert for every token.
        route = self.router(hidden_states)
        route = route.view(-1, self.num_moe_experts)

        if self.routing == 'sinkhorn':
            route = self.router_activation(route)
            max_prob, max_ind = torch.max(route, dim=1)
        else:
            route = torch.softmax(route, dim=1)
            max_prob, max_ind = torch.max(route, dim=1)

        max_prob = torch.unsqueeze(max_prob, 1)
        # Flatten (batch, seq, hidden) -> (tokens, hidden) for per-token dispatch.
        hidden_states = hidden_states.view(-1, hidden_shape[-1])

        global_hidden_states = hidden_states
        global_indices = max_ind
        output_total = torch.zeros_like(global_hidden_states)

        # Send each token through its top-1 expert and scatter the results back.
        for expert_num, expert in enumerate(self.local_experts):
            local_expert_index = self.local_expert_indices[expert_num]
            local_indices = (global_indices == local_expert_index).nonzero()
            hidden = global_hidden_states[local_indices, :]
            output = expert(hidden)
            output_total[local_indices, :] = output

        # Scale each token's expert output by its routing probability and
        # restore the original (batch, seq, hidden) shape.
        output_total = output_total * max_prob
        output_total = output_total.view(hidden_shape)

        return output_total
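# Usage sketch (comments only, not executed on import). The MambaConfig fields
# shown are assumptions beyond those referenced above; consult mamba_config.py
# for the full set of required fields.
#
#     config = MambaConfig(...)             # must provide hidden_size, num_moe_experts,
#                                           # routing_mode, add_bias_linear, ...
#     moe = SwitchMLP(config, layer_idx=1)
#     x = torch.randn(2, 16, config.hidden_size)   # (batch, seq_len, hidden)
#     y = moe(x)                                    # same shape as x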