# lstm_v1/modeling_lstm.py
from typing import Optional
import torch
from torch import nn
from transformers import PreTrainedModel, GenerationMixin, AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
from configuration_lstm import LstmConfig


class MLP(nn.Module):
def __init__(self, config: LstmConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.SiLU()
    def forward(self, x):
        # SwiGLU-style gating: SiLU(gate_proj(x)) scales up_proj(x) elementwise,
        # then the result is projected back down to the hidden size.
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


class LstmLayer(nn.Module):
def __init__(self, config: LstmConfig):
super().__init__()
self.lstm = nn.LSTM(config.hidden_size, config.hidden_size, num_layers=1, batch_first=True, bias=False)
self.mlp = MLP(config)
self.input_ln = nn.RMSNorm((config.hidden_size,), eps=1e-6)
self.post_ln = nn.RMSNorm((config.hidden_size,), eps=1e-6)
    def forward(self, hidden_states):
        # Pre-norm residual block 1: RMSNorm -> LSTM -> residual add.
        lstm_part = self.input_ln(hidden_states)
        lstm_part, _ = self.lstm(lstm_part)
        hidden_states = hidden_states + lstm_part
        # Pre-norm residual block 2: RMSNorm -> gated MLP -> residual add.
        mlp_part = self.post_ln(hidden_states)
        mlp_part = self.mlp(mlp_part)
        return hidden_states + mlp_part


class LstmPreTrainedModel(PreTrainedModel):
config_class = LstmConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LstmLayer"]
def _init_weights(self, module):
std = self.config.initializer_range
gain = self.config.initializer_gain
if isinstance(module, nn.Linear):
#nn.init.normal_(module.weight.data, std=std)
nn.init.kaiming_uniform_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight.data, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.RMSNorm):
module.weight.data.fill_(0.4)
elif isinstance(module, nn.LSTM):
for name, param in module.named_parameters():
if "weight" in name:
nn.init.xavier_uniform_(param, gain=gain)
elif "bias" in name:
with torch.no_grad():
param.zero_()


class LstmModel(LstmPreTrainedModel):
def __init__(self, config: LstmConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
            [LstmLayer(config) for _ in range(config.num_hidden_layers)]
)
self.norm = nn.RMSNorm((config.hidden_size,), eps=1e-6)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> BaseModelOutputWithNoAttention:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
for block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
)
else:
hidden_states = block(hidden_states)
last_hidden_state = self.norm(hidden_states)
return BaseModelOutputWithNoAttention(
last_hidden_state=last_hidden_state
)


class LstmForCausalLM(LstmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: LstmConfig):
super().__init__(config)
self.model = LstmModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
        input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**kwargs,
):
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        hidden_states = self.model(input_ids=input_ids, inputs_embeds=inputs_embeds).last_hidden_state
        # Only compute logits for the last `num_logits_to_keep` positions (0 keeps all positions).
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutput(
loss=loss,
logits=logits,
)


# Register the custom classes with the Auto* factories and mark them for
# auto_map export so the checkpoint can be loaded with trust_remote_code=True.
AutoConfig.register("lstm", LstmConfig)
AutoModel.register(LstmConfig, LstmModel)
AutoModelForCausalLM.register(LstmConfig, LstmForCausalLM)
LstmConfig.register_for_auto_class()
LstmModel.register_for_auto_class("AutoModel")
LstmForCausalLM.register_for_auto_class("AutoModelForCausalLM")
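

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the uploaded checkpoint): it
# builds a tiny randomly initialized model, runs a forward pass with a
# language-modeling loss, and then greedy generation via GenerationMixin.
# The keyword arguments passed to LstmConfig below (vocab_size, hidden_size,
# intermediate_size, num_hidden_layers, pad_token_id, initializer_range,
# initializer_gain) are assumed to be accepted by its __init__; they match the
# attributes read in this file, but adjust them to the real signature in
# configuration_lstm.py. Generation relies on GenerationMixin's default
# prepare_inputs_for_generation (recent transformers releases), since this
# model does not define its own.
if __name__ == "__main__":
    config = LstmConfig(
        vocab_size=1000,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        pad_token_id=0,
        initializer_range=0.02,
        initializer_gain=1.0,
    )
    model = LstmForCausalLM(config).eval()

    # Forward pass with labels: the built-in causal-LM loss shifts them internally.
    input_ids = torch.randint(0, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids, labels=input_ids)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))

    # Greedy decoding; without an eos_token_id this runs for exactly max_new_tokens steps.
    generated = model.generate(input_ids, max_new_tokens=5, do_sample=False)
    print("generated shape:", tuple(generated.shape))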