import torch
import torch.nn as nn
import numpy as np
from transformers import PreTrainedModel, PretrainedConfig, AutoModelForCausalLM, AutoConfig
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Debug aid: anomaly detection slows autograd considerably; disable it for real training runs.
torch.autograd.set_detect_anomaly(True)

class JudgeXLConfig(PretrainedConfig):
    model_type = "judge-xl"

    def __init__(self, vocab_size=50276, hidden_size=768, max_len=256, n_layer=12, n_head=12,
                 ff_expansion_factor=4, rnn_units=768, num_labels=5, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_len = max_len
        self.n_layer = n_layer
        self.n_head = n_head
        self.ff_expansion_factor = ff_expansion_factor
        self.rnn_units = rnn_units
        self.num_labels = num_labels
        self.dropout = dropout
        self.is_decoder = True

class CustomEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        logger.debug("vocab_size: %s, hidden_size: %s", vocab_size, hidden_size)
        assert isinstance(vocab_size, int) and isinstance(hidden_size, int), \
            f"Expected integers, but got vocab_size={type(vocab_size)} and hidden_size={type(hidden_size)}"
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    def forward(self, inputs):
        return self.embedding(inputs)

class PositionalEncoding(nn.Module):
    def __init__(self, n_embd, max_len=5000):
        super().__init__()
        self.n_embd = n_embd
        self.max_len = max_len
        pe = torch.zeros(max_len, n_embd)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, n_embd, 2).float() * -(np.log(10000.0) / n_embd))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Keep the buffer as (1, max_len, n_embd) so it broadcasts over batch-first inputs,
        # matching the batch_first convention used by the attention and LSTM layers below.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, n_embd)
        return x + self.pe[:, :x.size(1), :]
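
# Shape check (sketch, not executed on import): with the buffer shaped (1, max_len, n_embd)
# the encoding broadcasts over a batch-first input, e.g.
#   pe = PositionalEncoding(768, max_len=256)
#   pe(torch.zeros(2, 10, 768)).shape   # -> torch.Size([2, 10, 768])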

class TransformerXLBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        # batch_first=True keeps the (batch, seq_len, hidden) layout used by the rest of the model.
        self.attn = nn.MultiheadAttention(config.hidden_size, config.n_head,
                                          dropout=config.dropout, batch_first=True)
        self.ff = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x, mask=None):
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        out1 = self.ln1(x + attn_out)
        ff_out = self.ff(out1)
        return self.ln2(out1 + ff_out)

class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense1 = nn.Linear(config.hidden_size, config.hidden_size * config.ff_expansion_factor)
        self.dense2 = nn.Linear(config.hidden_size * config.ff_expansion_factor, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = torch.nn.functional.gelu(self.dense1(x))
        x = self.dropout(x)
        return self.dense2(x)

class JudgeXL(PreTrainedModel):
    config_class = JudgeXLConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_embedding = CustomEmbedding(config.vocab_size, config.hidden_size)
        self.pos_encoding = PositionalEncoding(config.hidden_size, config.max_len)
        self.transformer_blocks = nn.ModuleList([TransformerXLBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.rnn = nn.LSTM(config.hidden_size, config.rnn_units, num_layers=2,
                           dropout=config.dropout, bidirectional=True, batch_first=True)
        # The bidirectional LSTM doubles the feature size, so a single LM head maps
        # 2 * rnn_units -> vocab_size (stacking two vocabulary projections would mismatch shapes).
        self.lm_head = nn.Linear(config.rnn_units * 2, config.vocab_size)
        self.post_init()

    def forward(self, x, mask=None):
        x = self.token_embedding(x)      # (batch, seq_len, hidden_size)
        x = self.pos_encoding(x)
        for block in self.transformer_blocks:
            x = block(x, mask=mask)
        x = self.ln_f(x)
        x, _ = self.rnn(x)               # (batch, seq_len, 2 * rnn_units)
        return self.lm_head(x)           # (batch, seq_len, vocab_size)
    def init_weights(self):
        """Initialize weights using PreTrainedModel's default initialization."""
        super().init_weights()

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        # Hook for the Hugging Face generation utilities. The forward pass above does not
        # return past_key_values, so the cached branch is effectively a placeholder.
        if past is None:
            return {"input_ids": input_ids}
        return {"input_ids": input_ids[:, -1:], "past_key_values": past}

    def _reorder_cache(self, past, beam_idx):
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
    def generate(self, prompt, tokenizer, max_len=100):
        # Simple greedy decoding loop (this overrides PreTrainedModel.generate).
        # The tokenizer is passed in explicitly because the model does not store one.
        self.eval()
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids
        generated = input_ids
        with torch.no_grad():
            for _ in range(max_len):
                outputs = self.forward(generated)
                next_token_logits = outputs[:, -1, :]  # logits for the last position only
                next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
                generated = torch.cat((generated, next_token_id), dim=1)
                if next_token_id.item() == tokenizer.sep_token_id:
                    break
        generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
        return generated_text

config = JudgeXLConfig()
model = JudgeXL(config)
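# Example (sketch, not executed here): greedy decoding with JudgeXL.generate. A tokenizer
# must be supplied; "gpt2" is only an illustrative choice, not one mandated by the model.
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   print(model.generate("The quick brown fox", tok, max_len=20))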
# Register JudgeXLConfig with AutoConfig
JudgeXLConfig.register_for_auto_class(AutoConfig)
# Register JudgeXL with AutoModelForCausalLM
JudgeXL.register_for_auto_class(AutoModelForCausalLM)
model.push_to_hub("Wonder-Griffin/judge-xl-model")
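# Example (sketch, not executed here): once pushed, the model can be reloaded through the
# Auto classes registered above; trust_remote_code=True is needed because the modeling code
# lives in the Hub repository rather than in transformers itself.
#   reloaded = AutoModelForCausalLM.from_pretrained("Wonder-Griffin/judge-xl-model",
#                                                   trust_remote_code=True)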