import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
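

# A minimal configuration sketch (an assumption, not part of the original snippet):
# OBILanguageModel reads vocab_size, block_size, hidden_size, num_attention_heads,
# num_hidden_layers and hidden_dropout_prob from its config, so something like the
# hypothetical OBIConfig below is assumed to exist. Default values are placeholders.
from transformers import PretrainedConfig


class OBIConfig(PretrainedConfig):
    model_type = "obi"  # hypothetical identifier, not taken from the original code

    def __init__(
        self,
        vocab_size=32000,
        block_size=256,
        hidden_size=256,
        num_attention_heads=4,
        num_hidden_layers=4,
        hidden_dropout_prob=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dropout_prob = hidden_dropout_prob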


class OBILanguageModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Learned token and positional embeddings.
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding_table = nn.Embedding(config.block_size, config.hidden_size)
        # Encoder-decoder Transformer; batch_first=True so inputs/outputs are (batch, seq, hidden).
        self.transformer = nn.Transformer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            num_encoder_layers=config.num_hidden_layers,
            num_decoder_layers=config.num_hidden_layers,
            dim_feedforward=4 * config.hidden_size,
            dropout=config.hidden_dropout_prob,
            activation='gelu',
            batch_first=True,
        )
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # Token embeddings plus learned positional embeddings; positions are built on
        # the same device as the input ids rather than hard-coded to CPU.
        tok_emb = self.token_embedding_table(idx)                                    # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, C)
        x = tok_emb + pos_emb                                                         # (B, T, C)
        # Causal mask so no position can attend to future tokens (autoregressive LM).
        causal_mask = nn.Transformer.generate_square_subsequent_mask(T).to(idx.device)
        x = self.transformer(x, x, src_mask=causal_mask, tgt_mask=causal_mask, memory_mask=causal_mask)
        x = self.ln1(x)
        x = self.ln2(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), targets.view(-1))

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        # Autoregressively sample max_new_tokens tokens, feeding back at most
        # block_size tokens of context on each step.
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # logits at the last position only
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
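

# Usage sketch (illustrative only; it relies on the hypothetical OBIConfig above,
# and the tiny hyperparameters are assumptions, not values from the original code).
if __name__ == "__main__":
    config = OBIConfig(vocab_size=1000, block_size=64, hidden_size=128,
                       num_attention_heads=4, num_hidden_layers=2)
    model = OBILanguageModel(config)

    # Dummy batch of token ids and next-token targets.
    idx = torch.randint(0, config.vocab_size, (2, 16))
    targets = torch.randint(0, config.vocab_size, (2, 16))
    logits, loss = model(idx, targets)
    print(logits.shape, loss.item())  # torch.Size([2, 16, 1000]) and a scalar loss

    # Sample a short continuation from a single start token.
    generated = model.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=10)
    print(generated.shape)  # torch.Size([1, 11])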