from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
)
from torch import nn

from hf.llama import CustomAttentionLLaMa


class MyLLaMaConfig(PretrainedConfig):
    """Configuration for the custom LLaMa wrapper."""

    model_type = "LLaMa"

    def __init__(
        self,
        embed_dim: int = 1536,
        n_layers: int = 24,
        n_heads: int = 24,
        n_chckpnt_segments: int = 24,
        **kwargs,
    ):
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_chckpnt_segments = n_chckpnt_segments
        super().__init__(**kwargs)


class MyLLaMa(PreTrainedModel):
    """Hugging Face PreTrainedModel wrapper around CustomAttentionLLaMa."""

    config_class = MyLLaMaConfig

    def __init__(self, config: MyLLaMaConfig):
        super().__init__(config)
        self.model = CustomAttentionLLaMa(
            config.embed_dim,
            config.n_layers,
            config.n_heads,
            dropout=0,
            n_chckpnt_segments=config.n_chckpnt_segments,
        )

    def forward(self, tensor, labels=None):
        logits = self.model(tensor)["logits"]
        if labels is not None:
            # Assumes logits and labels are already aligned and shaped the way
            # cross_entropy expects (class/vocab dimension at dim 1).
            loss = nn.functional.cross_entropy(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}


# Register the custom config and model so the Auto* factories can resolve
# the "LLaMa" model_type.
AutoConfig.register("LLaMa", MyLLaMaConfig)
AutoModel.register(MyLLaMaConfig, MyLLaMa)
AutoModelForCausalLM.register(MyLLaMaConfig, MyLLaMa)
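

# Minimal usage sketch (not part of the original module): once the classes are
# registered above, the Auto* factories can build and reload the wrapper from a
# config. The "my_llama_ckpt" directory name is a hypothetical placeholder.
if __name__ == "__main__":
    config = MyLLaMaConfig(embed_dim=1536, n_layers=24, n_heads=24)
    model = AutoModelForCausalLM.from_config(config)

    # The save/load round-trip works because the custom config and model
    # classes are registered with AutoConfig and AutoModelForCausalLM, so the
    # "LLaMa" model_type in config.json resolves back to MyLLaMa.
    model.save_pretrained("my_llama_ckpt")
    reloaded = AutoModelForCausalLM.from_pretrained("my_llama_ckpt")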