PatrickHaller committed on
Commit 71d07d0 · verified · 1 Parent(s): 3e64418

Upload modeling_xlstm.py with huggingface_hub

Files changed (1)
  1. modeling_xlstm.py +296 -0
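
The commit message indicates the file was pushed with the huggingface_hub client. A minimal sketch of that kind of upload is shown below; the repo_id is a placeholder, since the target repository is not named on this page.

from huggingface_hub import HfApi

# Sketch of an upload like the one this commit records; repo_id is a placeholder.
api = HfApi()
api.upload_file(
    path_or_fileobj="modeling_xlstm.py",
    path_in_repo="modeling_xlstm.py",
    repo_id="PatrickHaller/<repo-name>",  # placeholder, not the actual repository
    commit_message="Upload modeling_xlstm.py with huggingface_hub",
)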
modeling_xlstm.py ADDED
@@ -0,0 +1,296 @@
+ from typing import Optional, Sequence, Tuple, Union
+
+ import torch
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+ from xlstm.components.init import small_init_init_
+ from xlstm.utils import WeightDecayOptimGroupMixin
+ from xlstm.xlstm_block_stack import xLSTMBlockStack as _xLSTMBlockStack
+
+ from .configuration_xlstm import xLSTMConfig
+
+
+ class xLSTMPreTrainedModel(PreTrainedModel):
+     """Base class for all xLSTM models."""
+
+     config_class = xLSTMConfig
+
+
+ class xLSTMBlockStack(_xLSTMBlockStack):
+     """Small wrapper around the xLSTM block stack that also exposes per-block hidden states."""
+
+     def forward(
+         self, x: torch.Tensor, **kwargs
+     ) -> Tuple[torch.Tensor, Sequence[torch.Tensor]]:
+         hidden_states = ()
+         for block in self.blocks:
+             x = block(x, **kwargs)
+             hidden_states += (x,)
+
+         x = self.post_blocks_norm(x)
+
+         return x, hidden_states
+
+
+ class xLSTMModel(xLSTMPreTrainedModel):
+     def __init__(self, config: xLSTMConfig):
+         super().__init__(config)
+         self.config = config
+
+         self.token_embedding = nn.Embedding(
+             num_embeddings=config.vocab_size, embedding_dim=config.embedding_dim
+         )
+         _config = config.to_xlstm_config()
+
+         self.emb_dropout = (
+             nn.Dropout(_config.dropout)
+             if _config.add_embedding_dropout
+             else nn.Identity()
+         )
+
+         self.xlstm_block_stack = xLSTMBlockStack(config=_config)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         token_embedding = self.token_embedding(input_ids)
+         x = self.emb_dropout(token_embedding)
+         x, hidden_states = self.xlstm_block_stack(x)
+
+         if output_hidden_states:
+             hidden_states = (token_embedding,) + hidden_states
+
+         if not return_dict:
+             return x, hidden_states
+
+         return BaseModelOutput(
+             last_hidden_state=x,
+             hidden_states=hidden_states if output_hidden_states else None,
+         )
+
+
+ class xLSTMForCausalLM(xLSTMPreTrainedModel, WeightDecayOptimGroupMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config: xLSTMConfig, **kwargs):
+         super().__init__(config)
+         self.config = config
+         self.vocab_size = config.vocab_size
+
+         self.model = xLSTMModel(config)
+
+         self.lm_head = nn.Linear(
+             in_features=config.embedding_dim,
+             out_features=config.vocab_size,
+             bias=False,
+         )
+
+         self.post_init()
+         # TODO: Add option for up-projection
+
+     def get_input_embeddings(self):
+         return self.model.token_embedding
+
+     def set_input_embeddings(self, value: nn.Module):
+         self.model.token_embedding = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, value):
+         self.lm_head = value
+
+     def reset_parameters(self):
+         self.model.xlstm_block_stack.reset_parameters()
+
+         small_init_init_(
+             self.get_input_embeddings().weight, dim=self.config.embedding_dim
+         )
+
+         if not self.config.tie_word_embeddings:
+             small_init_init_(
+                 self.get_output_embeddings().weight, dim=self.config.embedding_dim
+             )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         labels: Optional[torch.LongTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         # Always request the dict output from the backbone so that
+         # `output.hidden_states` is available below.
+         output = self.model(
+             input_ids,
+             output_hidden_states=output_hidden_states,
+             return_dict=True,
+         )
+
+         hidden_state = output[0]
+
+         logits = self.lm_head(hidden_state)
+         logits = logits.float()
+
+         loss = None
+
+         if labels is not None:
+             # Shift so that tokens < n predict token n.
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+
+             loss_fct = nn.CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + output[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             hidden_states=output.hidden_states,
+         )
+
+     def step(
+         self,
+         idx: torch.Tensor,
+         state: Optional[dict[str, dict[str, tuple[torch.Tensor, ...]]]] = None,
+         **kwargs,
+     ) -> tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]:
+         # Single-step (recurrent) inference through the wrapped backbone.
+         x = self.model.token_embedding(idx)
+         x = self.model.emb_dropout(x)
+         x, state = self.model.xlstm_block_stack.step(x, state=state, **kwargs)
+         logits = self.lm_head(x)
+         return logits, state
+
+     def _create_weight_decay_optim_groups(
+         self, **kwargs
+     ) -> tuple[Sequence[nn.Parameter], Sequence[nn.Parameter]]:
+         weight_decay, no_weight_decay = super()._create_weight_decay_optim_groups(
+             **kwargs
+         )
+         # Remove the token embedding and add it to the correct group, according to the config.
+         weight_decay = list(weight_decay)
+         removed = 0
+         for idx in range(len(weight_decay)):
+             if weight_decay[idx - removed] is self.get_input_embeddings().weight:
+                 weight_decay.pop(idx - removed)
+                 removed += 1
+         weight_decay = tuple(weight_decay)
+
+         # TODO: Fix this
+         # if self.config.weight_decay_on_embedding:
+         if True:
+             weight_decay += (self.get_input_embeddings().weight,)
+         else:
+             no_weight_decay += (self.get_input_embeddings().weight,)
+
+         return weight_decay, no_weight_decay
+
+     def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
+         old_embeddings = self.get_input_embeddings()
+         new_embeddings = nn.Embedding(
+             new_num_tokens, old_embeddings.embedding_dim
+         )
+         # Keep the learned vectors for tokens present in both vocabularies.
+         num_to_copy = min(new_num_tokens, old_embeddings.num_embeddings)
+         new_embeddings.weight.data[:num_to_copy] = old_embeddings.weight.data[:num_to_copy]
+         self.set_input_embeddings(new_embeddings.to(self.device))
+         if self.config.tie_word_embeddings:
+             self.tie_weights()
+         return new_embeddings
+
+     def tie_weights(self):
+         self.get_output_embeddings().weight = self.get_input_embeddings().weight
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         **kwargs,
+     ):
+         model_inputs = {
+             "input_ids": input_ids.to(self.device),
+         }
+         return model_inputs
+
+
+ class xLSTMForSequenceClassification(xLSTMPreTrainedModel):
+
+     def __init__(self, config: xLSTMConfig, **kwargs):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+         self.model = xLSTMModel(config)
+         self.classifier = nn.Linear(config.embedding_dim, config.num_labels, bias=False)
+
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         labels: Optional[torch.LongTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         # Always request the dict output from the backbone so that
+         # `output.hidden_states` is available below.
+         output = self.model(
+             input_ids,
+             output_hidden_states=output_hidden_states,
+             return_dict=True,
+         )
+
+         hidden_state = output[0]
+
+         logits = self.classifier(hidden_state)
+         batch_size = input_ids.shape[0]
+
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # If no pad token is found, use modulo instead of reverse indexing for ONNX compatibility.
+                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+
+         # Pool the logits of the last non-padding token in each sequence.
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+         loss = None
+
+         if labels is not None:
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(pooled_logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(pooled_logits, labels)
+
+         if not return_dict:
+             output = (pooled_logits,) + output[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             hidden_states=output.hidden_states,
+         )
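
Once this file (together with configuration_xlstm.py, which it imports) lives in a model repository, custom modeling code like this is usually consumed through the Auto classes with trust_remote_code=True. The sketch below is illustrative only: the repo_id is a placeholder, and it assumes the repository also ships a tokenizer and a config.json whose auto_map points at these classes.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "PatrickHaller/<repo-name>"  # placeholder, the target repo is not named on this page

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True lets transformers execute modeling_xlstm.py from the repository.
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

input_ids = tokenizer("The xLSTM architecture", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids, return_dict=True)       # CausalLMOutputWithPast
next_token_id = out.logits[0, -1].argmax().item()  # greedy pick of the next token
print(tokenizer.decode([next_token_id]))

For stateful, token-by-token decoding the class also exposes step(); a rough loop (again a sketch, relying on the step() support of the underlying xlstm block stack) could look like this:

state = None
# Feed the prompt one token at a time to build up the recurrent state.
for t in range(input_ids.shape[1]):
    logits, state = model.step(input_ids[:, t : t + 1], state=state)
# Then continue greedily for a few tokens.
token = logits[:, -1:].argmax(-1)
for _ in range(20):
    logits, state = model.step(token, state=state)
    token = logits[:, -1:].argmax(-1)   # shape (batch, 1)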