PatrickHaller committed
Commit 5060d5b
1 Parent(s): 8cf7fe2

Upload modeling_xlstm.py with huggingface_hub

Files changed (1)
  1. modeling_xlstm.py +297 -0
modeling_xlstm.py ADDED
@@ -0,0 +1,297 @@
+ from typing import Optional, Sequence, Tuple, Union
+
+ import torch
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+ from xlstm.components.init import small_init_init_
+ from xlstm.utils import WeightDecayOptimGroupMixin
+ from xlstm.xlstm_block_stack import xLSTMBlockStack as _xLSTMBlockStack
+
+ from .configuration_xlstm import xLSTMConfig
+
+
+ class xLSTMPreTrainedModel(PreTrainedModel):
+     """Base class for all models."""
+
+     config_class = xLSTMConfig
+
+
+ class xLSTMBlockStack(_xLSTMBlockStack):
+     """Small wrapper to expose hidden states."""
+
+     def forward(
+         self, x: torch.Tensor, **kwargs
+     ) -> Tuple[torch.Tensor, Sequence[torch.Tensor]]:
+         hidden_states = ()
+         for block in self.blocks:
+             x = block(x, **kwargs)
+             hidden_states += (x,)
+
+         x = self.post_blocks_norm(x)
+
+         return x, hidden_states
+
+
+ class xLSTMModel(xLSTMPreTrainedModel):
+     def __init__(self, config: xLSTMConfig):
+         super().__init__(config)
+         self.config = config
+
+         self.token_embedding = nn.Embedding(
+             num_embeddings=config.vocab_size, embedding_dim=config.embedding_dim
+         )
+         _config = config.to_xlstm_config()
+
+         self.emb_dropout = (
+             nn.Dropout(_config.dropout)
+             if _config.add_embedding_dropout
+             else nn.Identity()
+         )
+
+         self.xlstm_block_stack = xLSTMBlockStack(config=_config)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutput]:
+         token_embedding = self.token_embedding(input_ids)
+         x = self.emb_dropout(token_embedding)
+         x, hidden_states = self.xlstm_block_stack(x)
+
+         if output_hidden_states:
+             # Prepend the embedding output so hidden_states follows the usual
+             # transformers convention (embeddings + one entry per block).
+             hidden_states = (token_embedding,) + hidden_states
+
+         if not return_dict:
+             return x, hidden_states
+
+         return BaseModelOutput(
+             last_hidden_state=x,
+             hidden_states=hidden_states if output_hidden_states else None,
+         )
+
+
+ class xLSTMForCausalLM(xLSTMPreTrainedModel, WeightDecayOptimGroupMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config: xLSTMConfig, **kwargs):
+         super().__init__(config)
+         self.config = config
+         self.vocab_size = config.vocab_size
+
+         self.model = xLSTMModel(config)
+
+         self.lm_head = nn.Linear(
+             in_features=config.embedding_dim,
+             out_features=config.vocab_size,
+             bias=False,
+         )
+
+         self.post_init()
+         # TODO: Add option for up-projection
+
+     def get_input_embeddings(self):
+         return self.model.token_embedding
+
+     def set_input_embeddings(self, value: nn.Module):
+         self.model.token_embedding = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, value):
+         self.lm_head = value
+
+     def reset_parameters(self):
+         self.model.xlstm_block_stack.reset_parameters()
+
+         small_init_init_(
+             self.get_input_embeddings().weight, dim=self.config.embedding_dim
+         )
+
+         if not self.config.tie_word_embeddings:
+             small_init_init_(
+                 self.get_output_embeddings().weight, dim=self.config.embedding_dim
+             )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         labels: Optional[torch.LongTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ):
+         output = self.model(
+             input_ids,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_state = output[0]
+
+         logits = self.lm_head(hidden_state)
+         logits = logits.float()
+
+         loss = None
+
+         if labels is not None:
+             # Shift so that tokens < n predict token n (standard causal LM loss).
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+
+             loss_fct = nn.CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + output[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             hidden_states=output.hidden_states,
+         )
+
+     def step(
+         self,
+         idx: torch.Tensor,
+         state: Optional[dict[str, dict[str, tuple[torch.Tensor, ...]]]] = None,
+         **kwargs,
+     ) -> tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]:
+         # Single recurrent inference step; the embedding, dropout and block
+         # stack live on the wrapped xLSTMModel.
+         x = self.model.token_embedding(idx)
+         x = self.model.emb_dropout(x)
+         x, state = self.model.xlstm_block_stack.step(x, state=state, **kwargs)
+         logits = self.lm_head(x)
+         return logits, state
+
+     def _create_weight_decay_optim_groups(
+         self, **kwargs
+     ) -> tuple[Sequence[nn.Parameter], Sequence[nn.Parameter]]:
+         weight_decay, no_weight_decay = super()._create_weight_decay_optim_groups(
+             **kwargs
+         )
+         # Remove the token embedding and add it to the correct group, according to the config.
+         weight_decay = list(weight_decay)
+         removed = 0
+         for idx in range(len(weight_decay)):
+             if weight_decay[idx - removed] is self.get_input_embeddings().weight:
+                 weight_decay.pop(idx - removed)
+                 removed += 1
+         weight_decay = tuple(weight_decay)
+
+         # TODO: Fix this
+         # if self.config.weight_decay_on_embedding:
+         if True:
+             weight_decay += (self.get_input_embeddings().weight,)
+         else:
+             no_weight_decay += (self.get_input_embeddings().weight,)
+
+         return weight_decay, no_weight_decay
+
+     def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
+         # Note: this creates a fresh, randomly initialized embedding table and
+         # does not copy over existing weights or resize the LM head.
+         new_embeddings = nn.Embedding(
+             new_num_tokens, self.model.token_embedding.embedding_dim
+         )
+         self.model.token_embedding = new_embeddings.to(self.device)
+         return new_embeddings
+
+     def tie_weights(self):
+         self.get_output_embeddings().weight = self.get_input_embeddings().weight
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         **kwargs,
+     ):
+         model_inputs = {
+             "input_ids": input_ids.to(self.device),
+         }
+         return model_inputs
+
+
+ class xLSTMForSequenceClassification(xLSTMPreTrainedModel):
+
+     def __init__(self, config: xLSTMConfig, **kwargs):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+         self.model = xLSTMModel(config)
+         self.classifier = nn.Linear(config.embedding_dim, config.num_labels, bias=False)
+
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         labels: Optional[torch.LongTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         output = self.model(
+             input_ids,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_state = output[0]
+
+         logits = self.classifier(hidden_state)
+         batch_size = input_ids.shape[0]
+
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+
+         # Pool the logits at the last non-padding token of each sequence.
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+         loss = None
+
+         if labels is not None:
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(pooled_logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(pooled_logits, labels)
+
+         if not return_dict:
+             output = (pooled_logits,) + output[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             hidden_states=output.hidden_states,
+         )
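
For reference, a minimal usage sketch of this remote-code modeling file, assuming the hosting repository also ships configuration_xlstm.py, a tokenizer, and an auto_map entry in config.json pointing at xLSTMForCausalLM. The repository id below is a placeholder, not a real model.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-namespace/your-xlstm-model"  # placeholder repository id

# trust_remote_code is required so transformers loads modeling_xlstm.py from
# the repository instead of a built-in architecture.
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("The xLSTM architecture", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=inputs["input_ids"], return_dict=True)

print(out.logits.shape)  # (batch, sequence_length, vocab_size)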