|
from typing import Optional

import torch
from torch import nn
from transformers import PreTrainedModel, GenerationMixin, AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput

from configuration_lstm import LstmConfig
|
|
class MLP(nn.Module):
    def __init__(self, config: LstmConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        # Gated feed-forward: SiLU(gate_proj(x)) * up_proj(x), projected back to hidden_size.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
|
class LstmLayer(nn.Module):
    """Pre-norm residual block: an LSTM token mixer followed by a gated MLP."""

    def __init__(self, config: LstmConfig):
        super().__init__()
        self.lstm = nn.LSTM(
            config.hidden_size, config.hidden_size, num_layers=1, batch_first=True, bias=False
        )
        self.mlp = MLP(config)
        self.input_ln = nn.RMSNorm((config.hidden_size,), eps=1e-6)
        self.post_ln = nn.RMSNorm((config.hidden_size,), eps=1e-6)

    def forward(self, hidden_states):
        # LSTM mixing with a residual connection.
        lstm_part = self.input_ln(hidden_states)
        lstm_part, _ = self.lstm(lstm_part)
        hidden_states = hidden_states + lstm_part

        # MLP with a residual connection.
        mlp_part = self.post_ln(hidden_states)
        mlp_part = self.mlp(mlp_part)
        return hidden_states + mlp_part
|
|
class LstmPreTrainedModel(PreTrainedModel):
    config_class = LstmConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LstmLayer"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        gain = self.config.initializer_gain
        if isinstance(module, nn.Linear):
            nn.init.kaiming_uniform_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight.data, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.RMSNorm):
            module.weight.data.fill_(0.4)
        elif isinstance(module, nn.LSTM):
            for name, param in module.named_parameters():
                if "weight" in name:
                    nn.init.xavier_uniform_(param, gain=gain)
                elif "bias" in name:
                    with torch.no_grad():
                        param.zero_()
|
|
class LstmModel(LstmPreTrainedModel):
    def __init__(self, config: LstmConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LstmLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = nn.RMSNorm((config.hidden_size,), eps=1e-6)
        self.gradient_checkpointing = False

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithNoAttention:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        for block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                )
            else:
                hidden_states = block(hidden_states)

        last_hidden_state = self.norm(hidden_states)
        return BaseModelOutputWithNoAttention(
            last_hidden_state=last_hidden_state
        )
|
|
class LstmForCausalLM(LstmPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: LstmConfig):
        super().__init__(config)
        self.model = LstmModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs,
    ):
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        hidden_states = self.model(input_ids=input_ids, inputs_embeds=inputs_embeds).last_hidden_state

        # num_logits_to_keep == 0 keeps logits for every position (slicing with -0 is a no-op).
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )
|
|
AutoConfig.register("lstm", LstmConfig)
AutoModel.register(LstmConfig, LstmModel)
AutoModelForCausalLM.register(LstmConfig, LstmForCausalLM)

LstmConfig.register_for_auto_class()
LstmModel.register_for_auto_class("AutoModel")
LstmForCausalLM.register_for_auto_class("AutoModelForCausalLM")
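

# Quick smoke test (illustrative sketch, not part of the model definition). The
# LstmConfig fields used below (vocab_size, hidden_size, intermediate_size,
# num_hidden_layers, pad_token_id) are assumed to be accepted by the constructor in
# configuration_lstm.py, with defaults for everything else _init_weights needs
# (initializer_range, initializer_gain); adjust to the actual config signature.
if __name__ == "__main__":
    config = LstmConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        pad_token_id=0,
    )
    model = LstmForCausalLM(config)

    # Thanks to the Auto-class registration above, this equivalent construction also works:
    # model = AutoModelForCausalLM.from_config(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids=input_ids, labels=input_ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 1000])
    print(out.loss)          # scalar language-modeling loss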