from dataclasses import dataclass
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import bias_gelu_impl
from mamba_config import MambaConfig


class MLP(nn.Module):
    def __init__(
        self, config: MambaConfig, is_expert: bool = False, layer_idx=None
    ):
        super().__init__()
        self.config: MambaConfig = config
        self.layer = layer_idx

        ffn_hidden_size_1 = self.config.ffn_hidden_size
        ffn_hidden_size_2 = self.config.ffn_hidden_size

        # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        if self.config.gated_linear_unit:
            ffn_hidden_size_1 *= 2

        self.linear_fc1 = nn.Linear(self.config.hidden_size, ffn_hidden_size_1, bias=self.config.add_bias_linear, device=self.config.device)
        self.linear_fc1.is_expert = is_expert

        if self.config.gated_linear_unit:
            # Split the doubled projection into two halves: the activation of
            # one half gates the other (GLU-style gating).
            def glu(x):
                x = torch.chunk(x, 2, dim=-1)
                return self.config.activation_func(x[0]) * x[1]

            self.activation_func = glu
        else:
            self.activation_func = self.config.activation_func

        self.linear_fc2 = nn.Linear(ffn_hidden_size_2, self.config.hidden_size, bias=self.config.add_bias_linear, device=self.config.device)

    def forward(self, hidden_states, inference_params=None):
        # Project up, apply the (possibly gated) activation, project back down.
        intermediate = self.linear_fc1(hidden_states)
        intermediate = self.activation_func(intermediate)
        output = self.linear_fc2(intermediate)
        return output
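

# --- Minimal usage sketch (illustration only, not part of the module) ---
# Assumes MambaConfig exposes the fields read above (hidden_size,
# ffn_hidden_size, gated_linear_unit, activation_func, add_bias_linear,
# device) as constructor arguments; the field values below are hypothetical.
if __name__ == "__main__":
    config = MambaConfig(
        hidden_size=512,
        ffn_hidden_size=2048,
        gated_linear_unit=True,    # doubles linear_fc1's output width
        activation_func=F.silu,    # gated variant becomes SwiGLU-style
        add_bias_linear=False,
        device="cpu",
    )
    mlp = MLP(config, layer_idx=0)
    x = torch.randn(2, 16, config.hidden_size)  # (batch, seq_len, hidden_size)
    y = mlp(x)
    print(y.shape)  # expected: torch.Size([2, 16, 512])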