### copy from LIMoE
import os
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
import torch.nn.functional as F
import numpy as np
from transformers.activations import ACT2FN
from .adapter import Adapter
from collections import OrderedDict
from copy import deepcopy

#-------------------#
# MoE


class MLP(nn.Module):
    """Two-layer feed-forward network that returns log-probabilities.

    Computes fc1 -> GELU -> dropout(0.1) -> fc2 -> LogSoftmax over dim 1.
    Note: ``MoE`` below builds its experts from ``Adapter``, not from this class.
    """

    def __init__(self, input_size: int, output_size: int, hidden_size: int):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(0.1)
        self.activation = ACT2FN["gelu"]
        self.log_soft = nn.LogSoftmax(1)
        self.apply(self.init_weights)

    def init_weights(self, m: nn.Module, std=1e-3):
        """Initialise Linear layers with small clamped normals; reset LayerNorm."""
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, std=std)
            torch.nn.init.normal_(m.bias, std=std)
            # Clip samples to [-2*std, 2*std] to avoid rare large initial weights.
            m.weight.data = torch.clamp(m.weight.data, min=-2 * std, max=2 * std)
            m.bias.data = torch.clamp(m.bias.data, min=-2 * std, max=2 * std)
        elif isinstance(m, nn.LayerNorm):
            m.bias.data.zero_()
            m.weight.data.fill_(1.0)

    def forward(self, x):
        out = self.fc1(x)
        out = self.activation(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return self.log_soft(out)


class SparseDispatcher(object):
    """Helper for implementing a mixture of experts.

    Creates per-expert input minibatches (``dispatch``) and merges the expert
    outputs back into one batch-aligned tensor (``combine``). Outputs of the
    experts assigned to the same batch element are summed, weighted by the
    corresponding gate values.

    ``gates`` is a ``[batch_size, num_experts]`` tensor: batch element ``b`` is
    sent to expert ``e`` iff ``gates[b, e] != 0``. Gates must be 2-D (the
    constructor unpacks ``torch.nonzero(gates)`` into exactly two columns).

    Example::

        dispatcher = SparseDispatcher(num_experts, gates)
        expert_inputs = dispatcher.dispatch(inputs)
        expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
        outputs = dispatcher.combine(expert_outputs)

    which sets ``output[b] = sum_i(gates[b, i] * experts[i](inputs[b]))`` while
    only feeding expert ``i`` the rows where ``gates[b, i] > 0``.
    """

    def __init__(self, num_experts, gates):
        """Create a SparseDispatcher.

        Args:
            num_experts: number of experts (columns of ``gates``).
            gates: tensor of shape ``[batch_size, num_experts]``.
        """
        self._gates = gates
        self._num_experts = num_experts
        # torch.nonzero returns the (batch, expert) coordinates of non-zero
        # gates in row-major order; sort(0) sorts each column independently and
        # also returns the permutation that sorts the expert column.
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        # Expert id of each routed entry, in ascending expert order: [nnz, 1].
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # Batch row of each routed entry, grouped by expert: [nnz].
        self._batch_index = sorted_experts[index_sorted_experts[:, 1], 0]
        # Number of rows each expert receives (split sizes for dispatch).
        self._part_sizes = (gates > 0).sum(0).tolist()
        # Gate value of each routed entry: [nnz, 1].
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """Create one input tensor per expert.

        Args:
            inp: tensor whose first dimension is the batch dimension.

        Returns:
            A tuple of ``num_experts`` tensors; entry ``i`` holds the rows of
            ``inp`` whose gate for expert ``i`` is non-zero. A size-1 dim 1 is
            squeezed away.
        """
        # Repeat each row once per expert it is routed to, grouped by expert,
        # so a plain split by part sizes yields the per-expert batches.
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """Sum the expert outputs per batch element, weighted by the gates.

        Args:
            expert_out: list of ``num_experts`` tensors, each
                ``[expert_batch_size_i, ...]`` with identical trailing dims.
            multiply_by_gates: if False, outputs are summed unweighted.

        Returns:
            A float32 tensor of shape ``[batch_size, ...]``.
        """
        stitched = torch.cat(expert_out, 0)
        if multiply_by_gates:
            # Reshape gates [nnz, 1] -> [nnz, 1, ..., 1] to broadcast over
            # however many trailing dims the expert outputs have.
            gate_shape = (-1,) + (1,) * (stitched.dim() - 1)
            stitched = stitched.mul(self._nonzero_gates.reshape(gate_shape))
        out_shape = (self._gates.size(0),) + tuple(stitched.shape[1:])
        zeros = torch.zeros(out_shape, dtype=torch.float32, device=stitched.device)
        # Scatter-add each routed row back onto its source batch row; rows
        # processed by several experts accumulate their weighted outputs.
        return zeros.index_add(0, self._batch_index, stitched.float())

    def expert_to_gates(self):
        """Gate values corresponding to the examples in the per-expert tensors.

        Returns:
            A tuple of ``num_experts`` tensors of shape ``[expert_batch_size_i, 1]``.
        """
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)


class MoE(nn.Module):
    """Sparsely gated mixture of ``Adapter`` experts.

    Args:
        noisy_gating: add learned, input-dependent Gaussian noise to the gate
            logits during training.
        ds_factor: passed to each ``Adapter`` as its first argument.
        num_experts: number of experts.
        moe_input_size: feature size of the gating input ``x``; also passed to
            each ``Adapter``.
        top_k: how many experts each sample is routed to.
        dropout: dropout passed to each ``Adapter``.
        gating: ``'linear'`` (weight matrix) or ``'cosine'`` (``CosineTopKGate``).
        routing: ``'random'`` sends each training sample to one random expert and
            uses the weight-averaged expert at inference; any other value uses
            the learned top-k router.
        layer_id: stored as ``self.layer_id``.

    Raises:
        ValueError: if ``routing != 'random'`` and ``gating`` is not
            ``'linear'`` or ``'cosine'``.
    """

    def __init__(self,
                 noisy_gating=True,
                 ds_factor=8.0,
                 num_experts=4,
                 moe_input_size=768,
                 top_k=2,
                 dropout=0.1,
                 gating='linear',
                 routing=None,
                 layer_id=0
                 ):
        super().__init__()
        self.noisy_gating = noisy_gating
        self.num_experts = num_experts
        self.input_size = moe_input_size
        self.k = top_k
        self.layer_id = layer_id
        self.gating = gating
        self.experts = nn.ModuleList(
            [Adapter(ds_factor, moe_input_size, dropout=dropout) for _ in range(self.num_experts)])
        self.routing = routing
        # Cached weight-averaged expert used for random-routing inference.
        self.infer_expert = None
        if self.routing != 'random':
            if gating == 'linear':
                self.w_gate = nn.Parameter(torch.zeros(self.input_size, num_experts), requires_grad=True)
            elif gating == 'cosine':
                self.w_gate = CosineTopKGate(self.input_size, self.num_experts)
            else:
                raise ValueError(f"unknown gating type: {gating!r} (expected 'linear' or 'cosine')")
        self.w_noise = nn.Parameter(torch.zeros(self.input_size, self.num_experts), requires_grad=True)

        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(-1)
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        assert(self.k <= self.num_experts)

    def cv_squared(self, x):
        """Squared coefficient of variation (variance / mean**2) of ``x``.

        Used as a loss that pushes a positive distribution towards uniform.
        A 2-D input is first summed over dim 0 so the statistic is taken over
        the last (expert) axis. Inputs with at most one element (e.g. a single
        expert) return 0.

        Args:
            x: a 1-D or 2-D tensor.

        Returns:
            A 0-dim float tensor on ``x``'s device.
        """
        eps = 1e-10
        # Reduce first: the single-expert check must look at the expert axis,
        # not at the leading axis of an un-reduced 2-D input.
        if x.dim() == 2:
            x = x.sum(dim=0)
        if x.numel() <= 1:
            return torch.zeros((), device=x.device)
        x = x.float()
        return x.var() / (x.mean() ** 2 + eps)

    def _gates_to_load(self, gates):
        """Number of examples routed to each expert (count of gates > 0 along dim 0)."""
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
        """Probability that each value lands in the top k under fresh noise.

        Differentiable surrogate for the per-expert load, so a load-balancing
        loss can be backpropagated (see https://arxiv.org/abs/1701.06538).

        Args:
            clean_values: logits without noise, ``[..., n]``.
            noisy_values: ``clean_values`` plus noise, ``[..., n]``.
            noise_stddev: noise standard deviation, ``[..., n]``.
            noisy_top_values: top-m noisy logits sorted descending, ``[..., m]``
                with ``m >= k + 1``.

        Returns:
            A tensor shaped like ``clean_values``.
        """
        # Index k holds the (k+1)-th largest noisy value: an entry currently in
        # the top k would drop out if its noise pushed it below this.
        threshold_if_in = noisy_top_values[..., self.k:self.k + 1]
        # Index k-1 holds the k-th largest: an entry outside the top k would
        # have to exceed it to get in.
        threshold_if_out = noisy_top_values[..., self.k - 1:self.k]
        is_in = torch.gt(noisy_values, threshold_if_in)
        normal = Normal(self.mean.to(noise_stddev.device), self.std.to(noise_stddev.device))
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        return torch.where(is_in, prob_if_in, prob_if_out)

    def random_k_gating(self, features, train):
        """Random single-expert routing.

        In training, all of ``features`` go through one uniformly chosen expert;
        otherwise the outputs of all experts are averaged.
        """
        if train:
            idx = int(torch.randint(0, self.num_experts, (1,)).item())
            return self.experts[idx](features)
        outputs = [expert(features) for expert in self.experts]
        return torch.stack(outputs, dim=0).mean(dim=0)

    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
        """Noisy top-k gating. See paper: https://arxiv.org/abs/1701.06538.

        Args:
            x: input tensor with shape ``[batch_size, input_size]``.
            train: a boolean - noise is only added at training time.
            noise_epsilon: floor added to the learned noise std.

        Returns:
            gates: tensor with the shape of the logits (``[batch_size, num_experts]``
                for 2-D ``x``), non-zero only at each row's top-k experts.
            load: per-expert load - a differentiable estimate when noise is
                used and ``k < num_experts``, otherwise the routed-example count.
        """
        if self.gating == 'linear':
            clean_logits = x @ self.w_gate
        elif self.gating == 'cosine':
            clean_logits = self.w_gate(x)

        use_noise = self.noisy_gating and train
        if use_noise:
            noise_stddev = self.softplus(x @ self.w_noise) + noise_epsilon
            noisy_logits = clean_logits + torch.randn_like(clean_logits) * noise_stddev
            logits = noisy_logits
        else:
            logits = clean_logits

        # logits [..., n]: one routing score per expert. Keep one extra
        # candidate beyond k: its value is the threshold _prob_in_top_k needs.
        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=-1)
        top_k_logits = top_logits[..., :self.k]
        top_k_indices = top_indices[..., :self.k]
        top_k_gates = self.softmax(top_k_logits)

        # Give the softmaxed weights to the selected experts; all others get 0.
        gates = torch.zeros_like(logits).scatter(-1, top_k_indices, top_k_gates)

        if use_noise and self.k < self.num_experts:
            load = self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits).sum(0)
        else:
            load = self._gates_to_load(gates)
        return gates, load

    def _random_route(self, x, frame_features):
        """Send each sample to one uniformly chosen expert with gate weight 1."""
        batch_size = x.shape[0]
        gates = torch.zeros(batch_size, self.num_experts)
        random_idx = torch.randint(0, self.num_experts, (batch_size,))
        gates[torch.arange(batch_size), random_idx] = 1
        gates = gates.to(x.device)
        dispatcher = SparseDispatcher(self.num_experts, gates)
        expert_inputs = dispatcher.dispatch(frame_features)
        expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]
        return dispatcher.combine(expert_outputs)

    def _merge_experts(self):
        """Return a copy of expert 0 whose weights are the mean of all experts' weights."""
        states = [expert.state_dict() for expert in self.experts]
        merged = OrderedDict(
            (key, sum(state[key] / self.num_experts for state in states))
            for key in states[0]
        )
        expert = deepcopy(self.experts[0])
        expert.load_state_dict(merged)
        return expert

    def forward(self, x, frame_features, train=True, loss_coef=1e-2):
        """Route ``frame_features`` through the experts.

        Args:
            x: gating input, ``[batch_size, input_size]`` (a 1-D tensor is
                treated as a batch of one for learned routing). With
                ``routing == 'random'`` only its batch size and device are used.
            frame_features: tensor dispatched to the experts; its first
                dimension must match the batch size of ``x``.
            train: a boolean - enables gating noise / random expert selection.
            loss_coef: multiplier on the load-balancing loss.

        Returns:
            ``(y, loss, load)``:
                - ``y`` is the gate-weighted sum of expert outputs, or the merged
                  expert's output at random-routing inference.
                - ``loss`` is the load-balancing loss, to be added to the
                  training loss.
                - ``load`` is the per-expert load.
                - ``loss`` and ``load`` are None for random routing.
        """
        if self.routing == 'random':
            if train:
                # The experts are about to be trained, so any cached averaged
                # expert would be stale at the next inference call.
                self.infer_expert = None
                y = self._random_route(x, frame_features)
            else:
                if self.infer_expert is None:
                    self.infer_expert = self._merge_experts()
                y = self.infer_expert(frame_features)
            return y, None, None

        if x.dim() == 1:
            x = x.unsqueeze(0)
        gates, load = self.noisy_top_k_gating(x, train)
        # Load-balancing loss: penalise uneven total gate mass and uneven load.
        importance = gates.sum(dim=0)
        loss = (self.cv_squared(importance) + self.cv_squared(load)) * loss_coef

        dispatcher = SparseDispatcher(self.num_experts, gates)
        # Per-expert inputs: only the rows routed to that expert.
        expert_inputs = dispatcher.dispatch(frame_features)
        expert_outputs = [self.experts[i](expert_inputs[i]) for i in range(self.num_experts)]
        y = dispatcher.combine(expert_outputs)
        return y, loss, load


class CosineTopKGate(torch.nn.Module):
    """Router scoring inputs by scaled cosine similarity to learned expert vectors.

    Inputs are projected to ``proj_dim`` and L2-normalised along the feature
    (last) axis. They are compared against column-normalised ``sim_matrix``
    prototypes. The result is scaled by ``exp(temperature)``, with the
    temperature clamped to at most ``log(100)``.
    """

    def __init__(self, model_dim, num_global_experts, proj_dim=256, init_t=0.5):
        super().__init__()
        self.temperature = torch.nn.Parameter(torch.log(torch.full([1], 1.0 / init_t)), requires_grad=True)
        self.cosine_projector = torch.nn.Linear(model_dim, proj_dim)
        self.sim_matrix = torch.nn.Parameter(torch.randn(size=(proj_dim, num_global_experts)), requires_grad=True)
        self.clamp_max = torch.log(torch.tensor(1. / 0.01)).item()
        torch.nn.init.normal_(self.sim_matrix, 0, 0.01)

    def forward(self, x):
        # Normalise along the feature axis (last dim) so inputs with extra
        # leading dims (e.g. [batch, seq, dim]) are handled per token.
        projected = F.normalize(self.cosine_projector(x), dim=-1)
        prototypes = F.normalize(self.sim_matrix, dim=0)
        logits = torch.matmul(projected, prototypes)
        logit_scale = torch.clamp(self.temperature, max=self.clamp_max).exp()
        return logits * logit_scale


# Example usage (x must be 2-D: SparseDispatcher only unpacks [batch, num_experts] gates):
#   model = MoE()
#   inputs = torch.randn((32, 768))
#   frame_features = torch.randn((32, 10, 768))
#   y, loss, load = model(inputs, frame_features)