File size: 6,088 Bytes
2d8d1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a895ef9
 
 
 
4c81121
2d8d1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import torch.nn as nn
import numpy as np
from transformers import PreTrainedModel, PretrainedConfig, AutoModelForCausalLM, AutoConfig
import logging
# Module-level logging setup.
# NOTE(review): basicConfig configures the *root* logger as a side effect of importing
# this module; that is usually left to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# NOTE(review): enables autograd anomaly detection globally on import. PyTorch documents
# this mode as a debugging aid that slows execution; consider disabling it outside debugging.
torch.autograd.set_detect_anomaly(True)

class JudgeXLConfig(PretrainedConfig):
    """Hyper-parameters for the JudgeXL model.

    Args:
        vocab_size: size of the token vocabulary.
        hidden_size: width of the token embeddings and transformer layers.
        max_len: number of positions in the sinusoidal positional-encoding table.
        n_layer: number of transformer blocks.
        n_head: attention heads per transformer block.
        ff_expansion_factor: feed-forward inner width as a multiple of ``hidden_size``.
        rnn_units: hidden units per direction of the LSTM.
        num_labels: stored on the config as ``num_labels``.
        dropout: dropout probability.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.

    ``is_decoder`` is always set to ``True`` after the base initialiser has run.
    """

    model_type = "judge-xl"

    def __init__(
        self,
        vocab_size=50276,
        hidden_size=768,
        max_len=256,
        n_layer=12,
        n_head=12,
        ff_expansion_factor=4,
        rnn_units=768,
        num_labels=5,
        dropout=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Vocabulary / embedding geometry.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_len = max_len
        # Transformer stack.
        self.n_layer = n_layer
        self.n_head = n_head
        self.ff_expansion_factor = ff_expansion_factor
        # Recurrent head and task settings.
        self.rnn_units = rnn_units
        self.num_labels = num_labels
        self.dropout = dropout
        # Assigned after the base initialiser, so it overrides any `is_decoder` in kwargs.
        self.is_decoder = True

class CustomEmbedding(nn.Module):
    """Token-embedding lookup wrapping ``nn.Embedding``.

    Args:
        vocab_size: number of rows in the embedding table; must be an ``int``.
        hidden_size: width of each embedding vector; must be an ``int``.

    Raises:
        TypeError: if either size is not an ``int``.
    """

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        # Route the former debug print through logging so it can be silenced.
        logging.getLogger(__name__).debug(
            "vocab_size: %s, hidden_size: %s", vocab_size, hidden_size
        )
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (isinstance(vocab_size, int) and isinstance(hidden_size, int)):
            raise TypeError(
                f"Expected integers, but got vocab_size={type(vocab_size)} and hidden_size={type(hidden_size)}"
            )
        self.embedding = nn.Embedding(vocab_size, hidden_size)

    def forward(self, inputs):
        """Return the embedding vectors for the integer ids in ``inputs``."""
        return self.embedding(inputs)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input.

    The table is stored in the ``pe`` buffer with shape ``(max_len, 1, n_embd)``:
    even feature columns hold ``sin(pos * f_k)`` and odd columns ``cos(pos * f_k)``.
    ``forward`` adds the first ``x.size(0)`` rows and broadcasts over dim 1, so ``x``
    is expected to be ``(seq_len, batch, n_embd)``.

    Args:
        n_embd: feature width; odd values are supported.
        max_len: maximum sequence length the table covers.
    """

    def __init__(self, n_embd, max_len=5000):
        super().__init__()
        self.n_embd = n_embd
        self.max_len = max_len
        pe = torch.zeros(max_len, n_embd)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, n_embd, 2).float() * -(np.log(10000.0) / n_embd))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd n_embd there is one fewer cosine column than there are
        # frequencies; trim so the assignment shapes match.
        pe[:, 1::2] = torch.cos(position * div_term[: n_embd // 2])
        self.register_buffer('pe', pe.unsqueeze(1))  # (max_len, 1, n_embd)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        Raises:
            ValueError: if ``x.size(0)`` exceeds ``max_len``.
        """
        seq_len = x.size(0)
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={self.max_len}"
            )
        return x + self.pe[:seq_len, :]

class TransformerXLBlock(nn.Module):
    """Post-LayerNorm transformer layer: self-attention, then a feed-forward MLP.

    Each sub-layer output is added to its input and then layer-normalised.
    ``nn.MultiheadAttention`` is constructed without ``batch_first``, so under PyTorch's
    default it treats ``x`` as ``(seq_len, batch, hidden_size)``.
    """

    def __init__(self, config):
        super().__init__()
        # Creation order kept (attn, ff, ln1, ln2) so seeded initialisation is unchanged.
        self.attn = nn.MultiheadAttention(config.hidden_size, config.n_head, dropout=config.dropout)
        self.ff = FeedForward(config)
        self.ln1 = nn.LayerNorm(config.hidden_size)
        self.ln2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x, mask=None):
        """Apply attention (with optional ``attn_mask``) and the MLP, each with residual + LayerNorm."""
        attended = self.attn(x, x, x, attn_mask=mask)[0]  # drop the attention weights
        hidden = self.ln1(x + attended)
        return self.ln2(hidden + self.ff(hidden))

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    The inner width is ``hidden_size * ff_expansion_factor``; the output width
    is ``hidden_size`` again.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner = width * config.ff_expansion_factor
        self.dense1 = nn.Linear(width, inner)
        self.dense2 = nn.Linear(inner, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Expand, activate, regularise and project ``x`` back to ``hidden_size``."""
        activated = torch.nn.functional.gelu(self.dense1(x))
        return self.dense2(self.dropout(activated))

class JudgeXL(PreTrainedModel):
    """Language model: transformer stack followed by a bidirectional LSTM and a vocabulary head.

    Data flow in ``forward``: token ids ``(batch, seq)`` -> embeddings ->
    sinusoidal positions -> ``n_layer`` transformer blocks -> final LayerNorm ->
    2-layer bidirectional LSTM -> ``fc`` projection -> ``lm_head`` logits
    ``(batch, seq, vocab_size)``.
    """

    config_class = JudgeXLConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_embedding = CustomEmbedding(config.vocab_size, config.hidden_size)
        self.pos_encoding = PositionalEncoding(config.hidden_size, config.max_len)
        self.transformer_blocks = nn.ModuleList([TransformerXLBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.hidden_size)
        self.rnn = nn.LSTM(config.hidden_size, config.rnn_units, num_layers=2, dropout=config.dropout, bidirectional=True, batch_first=True)
        # The bidirectional LSTM emits 2 * rnn_units features. fc projects them back to
        # rnn_units so that lm_head (rnn_units -> vocab_size) receives the width it expects.
        # Previously fc mapped straight to vocab_size, which made every forward() fail
        # with a shape error in lm_head. Checkpoints saved with that old fc shape will
        # not load into this layer.
        self.fc = nn.Linear(config.rnn_units * 2, config.rnn_units)
        self.lm_head = nn.Linear(config.rnn_units, config.vocab_size)
        self.post_init()

    def forward(self, x=None, mask=None, input_ids=None):
        """Compute next-token logits.

        Args:
            x: integer token ids shaped ``(batch, seq)``. This is the layout
                ``generate`` feeds and the ``batch_first=True`` LSTM expects.
            mask: optional ``attn_mask`` passed unchanged to every block's
                ``nn.MultiheadAttention``; ``None`` means no masking.
                NOTE(review): without a causal mask every position attends to later
                tokens, and the bidirectional LSTM sees future tokens regardless.
            input_ids: alias for ``x`` (the key ``prepare_inputs_for_generation`` emits).

        Returns:
            Logits shaped ``(batch, seq, vocab_size)``.

        Raises:
            ValueError: if neither ``x`` nor ``input_ids`` is given.
        """
        if x is None:
            x = input_ids
        if x is None:
            raise ValueError("forward() requires token ids via `x` or `input_ids`")
        h = self.token_embedding(x)  # (batch, seq, hidden)
        # PositionalEncoding indexes positions on dim 0, and nn.MultiheadAttention
        # (built without batch_first) treats dim 0 as the sequence. Switch to
        # (seq, batch, hidden) so positions are per token and attention stays
        # within each example.
        h = h.transpose(0, 1)
        h = self.pos_encoding(h)
        for block in self.transformer_blocks:
            h = block(h, mask=mask)
        h = self.ln_f(h)
        h = h.transpose(0, 1)  # back to (batch, seq, hidden) for the batch_first LSTM
        h, _ = self.rnn(h)  # (batch, seq, 2 * rnn_units)
        h = self.fc(h)  # (batch, seq, rnn_units)
        return self.lm_head(h)  # (batch, seq, vocab_size)

    def init_weights(self):
        """
        Initialize weights for your custom layers using PreTrainedModel's default weight initialization method.
        """
        # Delegates entirely to PreTrainedModel.init_weights.
        super().init_weights()

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """Build ``forward`` kwargs for one generation step.

        The model keeps no key/value cache: ``forward`` takes no past state and
        returns only logits. The full ``input_ids`` are therefore always returned;
        slicing to the last token would discard all context.
        """
        return {"input_ids": input_ids}

    def _reorder_cache(self, past, beam_idx):
        """Reorder each cached tensor along dim 1 to follow ``beam_idx`` (unused: no cache is produced)."""
        return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)

    def generate(self, prompt, max_len=100, tokenizer=None):
        """Greedily decode up to ``max_len`` new tokens after ``prompt``.

        Args:
            prompt: input text, encoded as a single-sequence batch.
            max_len: maximum number of tokens to append.
            tokenizer: tokenizer to use. Defaults to ``self.tokenizer``, which this
                class never sets itself, so callers must attach it or pass one here.

        Returns:
            The decoded prompt plus continuation, with special tokens skipped.

        Raises:
            ValueError: if no tokenizer is available.
        """
        if tokenizer is None:
            tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is None:
            raise ValueError("generate() needs a tokenizer: pass `tokenizer=` or set `model.tokenizer`")
        # Stop on SEP as before; fall back to EOS for tokenizers that define no SEP token.
        stop_id = tokenizer.sep_token_id if tokenizer.sep_token_id is not None else tokenizer.eos_token_id
        self.eval()
        generated = tokenizer(prompt, return_tensors='pt').input_ids.to(self.device)
        window = self.config.max_len  # the positional table only covers this many positions
        with torch.no_grad():
            for _ in range(max_len):
                logits = self.forward(generated[:, -window:])
                # Take the logits of the last position (1, vocab_size); the greedy
                # choice has shape (1, 1) so it can be appended along the sequence dim.
                next_token_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
                generated = torch.cat((generated, next_token_id), dim=1)
                if stop_id is not None and next_token_id.item() == stop_id:
                    break
        return tokenizer.decode(generated[0], skip_special_tokens=True)

# Default-configured model instance, kept at module level so importers can still use `config` / `model`.
config = JudgeXLConfig()
model = JudgeXL(config)

# Register the custom classes with the Auto* factories so the uploaded checkpoint
# records which Auto class maps to them.
JudgeXLConfig.register_for_auto_class(AutoConfig)
JudgeXL.register_for_auto_class(AutoModelForCausalLM)

# Uploading is a network side effect: run it only when this file is executed as a
# script, not every time the module is imported.
if __name__ == "__main__":
    model.push_to_hub("Wonder-Griffin/judge-xl-model")