import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel


class OBILanguageModel(PreTrainedModel):
    """Custom autoregressive language model built on ``nn.Transformer``.

    Token and learned position embeddings are summed, passed through a
    causally-masked ``nn.Transformer`` (the same sequence is used as both
    source and target), normalized, and projected to vocabulary logits.

    The config must provide ``vocab_size``, ``hidden_size``, ``block_size``,
    ``num_attention_heads``, ``num_hidden_layers`` and
    ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__(config)
        # vocab_size is expected to be the length of the SentencePiece vocab.
        self.token_embedding_table = nn.Embedding(
            config.vocab_size, config.hidden_size
        )
        self.position_embedding_table = nn.Embedding(
            config.block_size, config.hidden_size
        )
        self.transformer = nn.Transformer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            num_encoder_layers=config.num_hidden_layers,
            num_decoder_layers=config.num_hidden_layers,
            dim_feedforward=4 * config.hidden_size,
            dropout=config.hidden_dropout_prob,
            activation='gelu',
            # forward() feeds (batch, seq, hidden) tensors; without this flag
            # nn.Transformer would treat the batch axis as the sequence axis.
            batch_first=True,
        )
        # Both norms are kept (although applied back to back) so that
        # state-dict keys stay compatible with existing checkpoints.
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids with shape ``(batch, seq)``;
                ``seq`` must not exceed ``config.block_size``.
            targets: Optional integer tensor of target token ids with the
                same number of elements as ``idx``.

        Returns:
            A ``(logits, loss)`` tuple. ``logits`` has shape
            ``(batch, seq, vocab_size)``; ``loss`` is the cross-entropy over
            all positions, or ``None`` when ``targets`` is ``None``.

        Raises:
            IndexError: If ``seq`` is larger than ``config.block_size``.
        """
        seq_len = idx.size(1)
        if seq_len > self.config.block_size:
            # Fail early with a readable message instead of an opaque
            # embedding lookup error (a device-side assert on CUDA).
            raise IndexError(
                f"sequence length {seq_len} exceeds block_size "
                f"{self.config.block_size}"
            )

        tok_emb = self.token_embedding_table(idx)
        # Build positions on the same device as the input so the model also
        # works when it has been moved off the CPU.
        positions = torch.arange(seq_len, device=idx.device)
        pos_emb = self.position_embedding_table(positions)
        x = tok_emb + pos_emb

        # True marks pairs that must NOT attend: position i may only see
        # positions <= i. The same mask is applied to encoder self-attention,
        # decoder self-attention and cross-attention, because the source and
        # target are the same sequence and any unmasked path would leak
        # future tokens into the prediction.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=idx.device),
            diagonal=1,
        )
        x = self.transformer(
            x,
            x,
            src_mask=causal_mask,
            tgt_mask=causal_mask,
            memory_mask=causal_mask,
        )
        x = self.ln1(x)
        x = self.ln2(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                targets.reshape(-1),
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively.

        Runs without gradient tracking. The module's train/eval mode is not
        changed here; call ``model.eval()`` first to disable dropout while
        sampling.

        Args:
            idx: Integer tensor of prompt token ids, shape ``(batch, seq)``.
            max_new_tokens: Number of tokens to append to every sequence.

        Returns:
            Integer tensor of shape ``(batch, seq + max_new_tokens)``
            containing the prompt followed by the sampled tokens.
        """
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit in the position table.
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            # Keep only the distribution for the final position.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx