"""
Patches to support multipack for mixtral
"""
from typing import Tuple

import torch


def patch_mixtral_moe_forward_zero3() -> None:
    import torch.nn.functional as F

    def mlp_forward(self, hidden_states):
        # SwiGLU-style expert MLP: w2(act_fn(w1(x)) * w3(x))
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
            hidden_states
        )
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states

    def moe_forward(
        self, hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        topk_weight, topk_idx = torch.topk(
            routing_weights, self.top_k, dim=-1, sorted=False
        )
        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
        # cast the routing weights back to the input dtype
        topk_weight = topk_weight.to(hidden_states.dtype)

        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
        y = torch.empty_like(hidden_states)
        flat_topk_idx = topk_idx.view(-1)
        # Call every expert on every step (possibly on an empty token slice),
        # so that DeepSpeed ZeRO-3 gathers each expert's parameters in
        # lockstep across ranks instead of skipping unrouted experts.
        for i in range(self.num_experts):
            expert = self.experts[i]
            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
        y = (
            y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
        ).sum(dim=1)
        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

    from transformers.models.mixtral.modeling_mixtral import (
        MixtralBLockSparseTop2MLP,
        MixtralSparseMoeBlock,
    )

    MixtralBLockSparseTop2MLP.forward = mlp_forward
    MixtralSparseMoeBlock.forward = moe_forward
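

# Usage sketch (illustrative only, not executed by this module): apply the
# patch before the first forward pass under DeepSpeed ZeRO-3 so the
# replacement MoE forward is in place when the sharded parameters are first
# gathered. The model id below is just an example.
#
#     from transformers import MixtralForCausalLM
#
#     patch_mixtral_moe_forward_zero3()
#     model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")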