jsunn-y committed on
Commit
6d75398
1 Parent(s): dde65c9

added the model file

Files changed (1)
  1. model.py +1090 -0
model.py ADDED
@@ -0,0 +1,1090 @@
1
+ # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified forward-pass implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py
16
+ import math
17
+ from dataclasses import dataclass
18
+ from typing import Optional, Tuple, Union, Dict
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import torch.utils.checkpoint
24
+ from torch.nn import CrossEntropyLoss
25
+ from transformers.activations import ACT2FN
26
+ from transformers.cache_utils import Cache, DynamicCache
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPast as _BaseModelOutputWithPast,
29
+ )
30
+ from transformers.modeling_outputs import (
31
+ CausalLMOutputWithPast as _CausalLMOutputWithPast,
32
+ )
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import logging
35
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
36
+
37
+ from .adapter import ParallelAdapterLayer, ProjectionMLP
38
+ from .config import ProGenConfig, ProGenConditionalConfig
39
+ from ..utils import exists
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ @dataclass
44
+ class BaseModelOutputWithPast(_BaseModelOutputWithPast):
45
+ inputs: Optional[Union[torch.LongTensor, torch.FloatTensor]] = None
46
+
47
+
48
+ @dataclass
49
+ class CausalLMOutputWithPast(_CausalLMOutputWithPast):
50
+ all_losses: Optional[torch.FloatTensor] = None
51
+ inputs: Optional[Union[torch.LongTensor, torch.FloatTensor]] = None
52
+
53
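+ # Rotary position embeddings (RoPE): fixed_pos_embedding builds the sin/cos tables
+ # for the rotary dimension, rotate_every_two swaps and negates adjacent channel pairs,
+ # and apply_rotary_pos_emb rotates each pair by a position-dependent angle so that
+ # attention scores depend on relative token offsets.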
+ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
54
+ dim = x.shape[-1]
55
+ if seq_len is None:
56
+ seq_len = x.shape[seq_dim]
57
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
58
+ sinusoid_inp = (
59
+ torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
60
+ )
61
+ return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
62
+
63
+
64
+ def rotate_every_two(x):
65
+ x1 = x[:, :, :, ::2]
66
+ x2 = x[:, :, :, 1::2]
67
+ x = torch.stack((-x2, x1), axis=-1)
68
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
69
+
70
+
71
+ def apply_rotary_pos_emb(x, sincos, offset=0):
72
+ sin, cos = map(
73
+ lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos
74
+ )
75
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
76
+ return (x * cos) + (rotate_every_two(x) * sin)
77
+
78
+
79
+ class ProGenAttention(nn.Module):
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.config = config
83
+
84
+ max_positions = config.max_position_embeddings
85
+ self.register_buffer(
86
+ "bias",
87
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
88
+ 1, 1, max_positions, max_positions
89
+ ),
90
+ )
91
+ self.register_buffer("masked_bias", torch.tensor(-1e9))
92
+
93
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
94
+ self.attn_pdrop = config.attn_pdrop
95
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
96
+
97
+ self.embed_dim = config.hidden_size
98
+ self.num_attention_heads = config.num_attention_heads
99
+ self.head_dim = self.embed_dim // self.num_attention_heads
100
+ if self.head_dim * self.num_attention_heads != self.embed_dim:
101
+ raise ValueError(
102
+ f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
103
+ )
104
+ self.scale_attn = math.sqrt(self.head_dim)
105
+ self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
106
+
107
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
108
+ self.rotary_dim = None
109
+ if config.rotary_dim is not None:
110
+ self.rotary_dim = config.rotary_dim
111
+
112
+ def _split_heads(self, x, n_head, dim_head, mp_num):
113
+ reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
114
+ reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
115
+ return reshaped
116
+
117
+ def _naive_attn(
118
+ self,
119
+ query,
120
+ key,
121
+ value,
122
+ attention_mask=None,
123
+ ):
124
+ # compute causal mask from causal mask buffer
125
+ batch_size, query_length, key_length = query.size(0), query.size(-2), key.size(-2)
126
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
127
+
128
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) / self.scale_attn
129
+ attn_weights = torch.where(
130
+ causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)
131
+ )
132
+
133
+ if attention_mask is not None:
134
+ # Apply the attention mask
135
+ attn_weights = attn_weights + attention_mask
136
+
137
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
138
+ attn_weights = self.attn_dropout(attn_weights)
139
+ attn_output = torch.matmul(attn_weights, value)
140
+
141
+ expected_size = (batch_size, self.num_attention_heads, query_length, self.head_dim)
142
+ if attn_output.size() != expected_size:
143
+ raise ValueError(
144
+ f"`attn_output` should be of size {expected_size}, but is {attn_output.size()}"
145
+ )
146
+
147
+ attn_output = attn_output.transpose(1, 2).contiguous()
148
+ attn_output = attn_output.reshape(batch_size, query_length, self.embed_dim)
149
+ return attn_output, attn_weights
150
+
151
+ def _sdpa_attn(
152
+ self,
153
+ query,
154
+ key,
155
+ value,
156
+ attention_mask=None,
157
+ ):
158
+ bsz, q_len = query.shape[0], query.shape[2]
159
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
160
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
161
+ if query.device.type == "cuda" and attention_mask is not None:
162
+ query = query.contiguous()
163
+ key = key.contiguous()
164
+ value = value.contiguous()
165
+
166
+ attn_output = F.scaled_dot_product_attention(
167
+ query,
168
+ key,
169
+ value,
170
+ attn_mask=attention_mask,
171
+ dropout_p=self.attn_pdrop if self.training else 0.0,
172
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
173
+ is_causal=q_len > 1,
174
+ scale=1 / self.scale_attn,
175
+ )
176
+
177
+ attn_output = attn_output.transpose(1, 2).contiguous()
178
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
179
+ return attn_output, None
180
+
181
+ def forward(
182
+ self,
183
+ hidden_states,
184
+ attention_mask=None,
185
+ layer_past=None,
186
+ use_cache=False,
187
+ output_attentions=False,
188
+ ):
189
+ qkv = self.qkv_proj(hidden_states)
190
+ # TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic
191
+ # mp_num = 4
192
+ mp_num = 8
193
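+ # qkv_proj packs Q, K and V for all heads into one matrix, laid out across mp_num
+ # model-parallel shards; reshape to (..., mp_num, 3 * local_dim) so each shard can be
+ # split into its query, value and key slices below.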
+ qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
194
+
195
+ local_dim = self.head_dim * self.num_attention_heads // mp_num
196
+ query, value, key = torch.split(qkv_split, local_dim, dim=-1)
197
+ query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
198
+ key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
199
+
200
+ value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
201
+ value = value.permute(0, 2, 1, 3)
202
+
203
+ seq_len = key.shape[1]
204
+ offset = 0
205
+
206
+ if layer_past is not None:
207
+ offset = layer_past[0].shape[-2]
208
+ seq_len += offset
209
+
210
+ if self.rotary_dim is not None:
211
+ k_rot = key[:, :, :, : self.rotary_dim]
212
+ k_pass = key[:, :, :, self.rotary_dim :]
213
+
214
+ q_rot = query[:, :, :, : self.rotary_dim]
215
+ q_pass = query[:, :, :, self.rotary_dim :]
216
+
217
+ sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
218
+ k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
219
+ q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
220
+
221
+ key = torch.cat([k_rot, k_pass], dim=-1)
222
+ query = torch.cat([q_rot, q_pass], dim=-1)
223
+ else:
224
+ sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
225
+ key = apply_rotary_pos_emb(key, sincos, offset=offset)
226
+ query = apply_rotary_pos_emb(query, sincos, offset=offset)
227
+
228
+ key = key.permute(0, 2, 1, 3)
229
+ query = query.permute(0, 2, 1, 3)
230
+
231
+ if layer_past is not None:
232
+ past_key = layer_past[0]
233
+ past_value = layer_past[1]
234
+ key = torch.cat((past_key, key), dim=-2)
235
+ value = torch.cat((past_value, value), dim=-2)
236
+
237
+ if use_cache is True:
238
+ present = (key, value)
239
+ else:
240
+ present = None
241
+
242
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
243
+ # so the input hidden states may get silently upcast to float32. Hence, we need to
244
+ # cast them back to float16 just to be sure everything works as expected.
245
+
246
+ input_dtype = query.dtype
247
+ if torch.is_autocast_enabled():
248
+ target_dtype = torch.get_autocast_gpu_dtype()
249
+ # Handle the case where the model is quantized
250
+ elif hasattr(self.config, "_pre_quantization_dtype"):
251
+ target_dtype = self.config._pre_quantization_dtype
252
+ else:
253
+ target_dtype = self.qkv_proj.weight.dtype  # fallback to the projection weight dtype; known to be problematic in some setups, but this branch is rarely reached
254
+
255
+ if input_dtype != target_dtype:
256
+ logger.warning_once(
257
+ f"The input hidden states seems to be silently casted in {input_dtype}. "
258
+ f"This might be because you have upcasted embedding or layer norm layers "
259
+ f"in {input_dtype}. We will cast back the input in {target_dtype}."
260
+ )
261
+ query = query.to(target_dtype)
262
+ key = key.to(target_dtype)
263
+ value = value.to(target_dtype)
264
+
265
+ # compute self-attention: V x Softmax(QK^T)
266
+ if output_attentions:
267
+ attn_output, attn_weights = self._naive_attn(query, key, value, attention_mask)
268
+ else:
269
+ attn_output, attn_weights = self._sdpa_attn(query, key, value, None)
270
+ attn_output = self.out_proj(attn_output)
271
+ attn_output = self.resid_dropout(attn_output)
272
+
273
+ outputs = (attn_output, present)
274
+ if output_attentions:
275
+ outputs += (attn_weights,)
276
+
277
+ return outputs
278
+
279
+
280
+ class ProGenMLP(nn.Module):
281
+ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
282
+ super().__init__()
283
+ embed_dim = config.n_embd
284
+
285
+ self.fc_in = nn.Linear(embed_dim, intermediate_size)
286
+ self.fc_out = nn.Linear(intermediate_size, embed_dim)
287
+
288
+ self.act = ACT2FN[config.activation_function]
289
+ self.dropout = nn.Dropout(config.resid_pdrop)
290
+
291
+ def forward(self, hidden_states):
292
+ hidden_states = self.fc_in(hidden_states)
293
+ hidden_states = self.act(hidden_states)
294
+ hidden_states = self.fc_out(hidden_states)
295
+ hidden_states = self.dropout(hidden_states)
296
+ return hidden_states
297
+
298
+
299
+ class ProGenBlock(nn.Module):
300
+ def __init__(self, config):
301
+ super().__init__()
302
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
303
+ self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
304
+ self.attn = ProGenAttention(config)
305
+ self.mlp = ProGenMLP(inner_dim, config)
306
+
307
+ def forward(
308
+ self,
309
+ hidden_states,
310
+ layer_past=None,
311
+ attention_mask=None,
312
+ head_mask=None,
313
+ adapter_layer=None,
314
+ adapter_dropout=None,
315
+ adapter_input=None,
316
+ use_cache=False,
317
+ output_attentions=False,
318
+ ):
319
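+ # GPT-J-style parallel block: attention and MLP both read the same LayerNormed
+ # input, and their outputs are summed together with the residual.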
+ residual = hidden_states
320
+ hidden_states = self.ln_1(hidden_states)
321
+ attn_outputs = self.attn(
322
+ hidden_states,
323
+ layer_past=layer_past,
324
+ attention_mask=attention_mask,
325
+ use_cache=use_cache,
326
+ output_attentions=output_attentions,
327
+ )
328
+ attn_output = attn_outputs[0]
329
+ outputs = attn_outputs[1:]
330
+
331
+ feed_forward_hidden_states = self.mlp(hidden_states)
332
+
333
+ ### addition of adapter layer ###
334
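+ # When an adapter is provided, it sees the combined attention + MLP update together
+ # with the (projected) condition encoding, and its (dropped-out) output is added to
+ # that update before the residual connection.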
+ if exists(adapter_layer) and exists(adapter_dropout) and exists(
335
+ adapter_input):
336
+
337
+ hidden_states_update = attn_output + feed_forward_hidden_states
338
+ adapter_out = adapter_layer(hidden_states_update, adapter_input)
339
+ adapter_out = adapter_dropout(adapter_out)
340
+ hidden_states_update = hidden_states_update + adapter_out
341
+
342
+ hidden_states = hidden_states_update + residual
343
+ else:
344
+ hidden_states = attn_output + feed_forward_hidden_states + residual
345
+ ### end of addition of adapter layer ###
346
+
347
+ if use_cache:
348
+ outputs = (hidden_states,) + outputs
349
+ else:
350
+ outputs = (hidden_states,) + outputs[1:]
351
+
352
+ return outputs
353
+
354
+
355
+ class ProGenPreTrainedModel(PreTrainedModel):
356
+ """An abstract class to handle weights initialization and a simple interface for downloading
357
+ and loading pretrained models."""
358
+
359
+ config_class = ProGenConfig
360
+ base_model_prefix = "transformer"
361
+ is_parallelizable = True
362
+ _no_split_modules = ["ProGenBlock"]
363
+
364
+ def __init__(self, *inputs, **kwargs):
365
+ super().__init__(*inputs, **kwargs)
366
+
367
+ def _init_weights(self, module):
368
+ """Initialize the weights."""
369
+ if isinstance(module, (nn.Linear,)):
370
+ # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
371
+ # cf https://github.com/pytorch/pytorch/pull/5617
372
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
373
+ if module.bias is not None:
374
+ module.bias.data.zero_()
375
+ elif isinstance(module, nn.Embedding):
376
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
377
+ if module.padding_idx is not None:
378
+ module.weight.data[module.padding_idx].zero_()
379
+ elif isinstance(module, nn.LayerNorm):
380
+ module.bias.data.zero_()
381
+ module.weight.data.fill_(1.0)
382
+
383
+ class ModularProGenModel(ProGenPreTrainedModel):
384
+
385
+ def __init__(self, config):
386
+ super().__init__(config)
387
+
388
+ self.embed_dim = config.n_embd
389
+ self.vocab_size = config.vocab_size
390
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
391
+ self.drop = nn.Dropout(config.embd_pdrop)
392
+ self.h = nn.ModuleList(
393
+ [ProGenBlock(config) for _ in range(config.n_layer)])
394
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
395
+ self.rotary_dim = min(config.rotary_dim,
396
+ config.n_ctx // config.num_attention_heads)
397
+ self.init_weights()
398
+
399
+ def get_input_embeddings(self):
400
+ return self.wte
401
+
402
+ def set_input_embeddings(self, new_embeddings):
403
+ self.wte = new_embeddings
404
+
405
+ def forward_prep(
406
+ self,
407
+ input_ids=None,
408
+ past_key_values=None,
409
+ attention_mask=None,
410
+ token_type_ids=None,
411
+ position_ids=None,
412
+ head_mask=None,
413
+ inputs_embeds=None,
414
+ use_cache=None,
415
+ output_attentions=None,
416
+ output_hidden_states=None,
417
+ return_dict=None,
418
+ ):
419
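+ # Normalize the call arguments: resolve config defaults, validate the
+ # input_ids/inputs_embeds choice, derive position_ids from the past length,
+ # expand the 2D attention mask to an additive 4D mask, and prepare the head mask.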
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
420
+ output_hidden_states = (output_hidden_states
421
+ if output_hidden_states is not None else
422
+ self.config.output_hidden_states)
423
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
424
+
425
+ if getattr(self.config, "gradient_checkpointing",
426
+ False) and self.training:
427
+ #print('using gradient checkpointing')
428
+ if use_cache:
429
+ use_cache = False
430
+
431
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
432
+
433
+ if input_ids is not None and inputs_embeds is not None:
434
+ raise ValueError(
435
+ "You cannot specify both input_ids and inputs_embeds at the same time"
436
+ )
437
+ elif input_ids is not None:
438
+ input_shape = input_ids.size()
439
+ input_ids = input_ids.view(-1, input_shape[-1])
440
+ batch_size = input_ids.shape[0]
441
+ elif inputs_embeds is not None:
442
+ input_shape = inputs_embeds.size()[:-1]
443
+ batch_size = inputs_embeds.shape[0]
444
+ else:
445
+ raise ValueError(
446
+ "You have to specify either input_ids or inputs_embeds")
447
+
448
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
449
+
450
+ if token_type_ids is not None:
451
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
452
+
453
+ if position_ids is not None:
454
+ position_ids = position_ids.view(-1, input_shape[-1])
455
+
456
+ if past_key_values is None:
457
+ past_length = 0
458
+ past_key_values = tuple([None] * len(self.h))
459
+ else:
460
+ past_length = past_key_values[0][0].size(-2)
461
+
462
+ if position_ids is None:
463
+ position_ids = torch.arange(past_length,
464
+ input_shape[-1] + past_length,
465
+ dtype=torch.long,
466
+ device=device)
467
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
468
+
469
+ # Attention mask.
470
+ if attention_mask is not None:
471
+ assert batch_size > 0, "batch_size has to be defined and > 0"
472
+ attention_mask = attention_mask.view(batch_size, -1)
473
+ # We create a 3D attention mask from a 2D tensor mask.
474
+ # Sizes are [batch_size, 1, 1, to_seq_length]
475
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
476
+ # this attention mask is more simple than the triangular masking of causal attention
477
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
478
+ attention_mask = attention_mask[:, None, None, :]
479
+
480
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
481
+ # masked positions, this operation will create a tensor which is 0.0 for
482
+ # positions we want to attend and -10000.0 for masked positions.
483
+ # Since we are adding it to the raw scores before the softmax, this is
484
+ # effectively the same as removing these entirely.
485
+ attention_mask = attention_mask.to(
486
+ dtype=self.dtype) # fp16 compatibility
487
+ attention_mask = (1.0 - attention_mask) * -10000.0
488
+
489
+ # Prepare head mask if needed
490
+ # 1.0 in head_mask indicate we keep the head
491
+ # attention_probs has shape bsz x num_attention_heads x N x N
492
+ # head_mask has shape n_layer x batch x num_attention_heads x N x N
493
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
494
+
495
+ return input_ids, attention_mask, head_mask, position_ids, token_type_ids, inputs_embeds, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict
496
+
497
+ def forward_embed(
498
+ self,
499
+ input_ids=None,
500
+ token_type_ids=None,
501
+ inputs_embeds=None,
502
+ ):
503
+ if inputs_embeds is None:
504
+ inputs_embeds = self.wte(input_ids)
505
+
506
+ hidden_states = inputs_embeds
507
+
508
+ if token_type_ids is not None:
509
+ token_type_embeds = self.wte(token_type_ids)
510
+ hidden_states = hidden_states + token_type_embeds
511
+
512
+ hidden_states = self.drop(hidden_states)
513
+
514
+ return hidden_states
515
+
516
+ def forward_layer(
517
+ self,
518
+ hidden_states,
519
+ layer_i,
520
+ layer_past=None,
521
+ attention_mask=None,
522
+ head_mask=None,
523
+ adapter_layer=None,
524
+ adapter_dropout=None,
525
+ adapter_input=None,
526
+ use_cache=None,
527
+ output_attentions=None,
528
+ ):
529
+ if getattr(self.config, "gradient_checkpointing",
530
+ False) and self.training:
531
+ if use_cache:
532
+ logger.warning(
533
+ "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
534
+ "`use_cache=False`...")
535
+ use_cache = False
536
+
537
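+ # Wrap the block so torch.utils.checkpoint can re-run its forward during the
+ # backward pass; use_cache and output_attentions are captured as constants.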
+ def create_custom_forward(module):
538
+
539
+ def custom_forward(*inputs):
540
+ # None for past_key_value
541
+ return module(*inputs, use_cache, output_attentions)
542
+
543
+ return custom_forward
544
+
545
+ outputs = torch.utils.checkpoint.checkpoint(
546
+ create_custom_forward(self.h[layer_i]),
547
+ hidden_states,
548
+ None,
549
+ attention_mask,
550
+ head_mask[layer_i],
551
+ adapter_layer,
552
+ adapter_dropout,
553
+ adapter_input,
554
+ )
555
+ else:
556
+ outputs = self.h[layer_i](
557
+ hidden_states,
558
+ layer_past=layer_past,
559
+ attention_mask=attention_mask,
560
+ head_mask=head_mask[layer_i],
561
+ adapter_layer=adapter_layer,
562
+ adapter_dropout=adapter_dropout,
563
+ adapter_input=adapter_input,
564
+ use_cache=use_cache,
565
+ output_attentions=output_attentions,
566
+ )
567
+
568
+ hidden_states = outputs[0]
569
+
570
+ if use_cache:
571
+ presents = (outputs[1], )
572
+ else:
573
+ presents = None
574
+
575
+ if output_attentions:
576
+ self_attentions = outputs[2 if use_cache else 1]
577
+ else:
578
+ self_attentions = None
579
+
580
+ return hidden_states, presents, self_attentions
581
+
582
+ def forward_layers(
583
+ self,
584
+ hidden_states,
585
+ past_key_values=None,
586
+ attention_mask=None,
587
+ head_mask=None,
588
+ use_cache=None,
589
+ output_attentions=None,
590
+ output_hidden_states=None,
591
+ ):
592
+ all_presents = () if use_cache else None
593
+ all_self_attentions = () if output_attentions else None
594
+ all_hidden_states = () if output_hidden_states else None
595
+ for i in range(self.config.n_layer):
596
+ if output_hidden_states:
597
+ all_hidden_states = all_hidden_states + (hidden_states, )
598
+
599
+ hidden_states, presents, self_attentions = self.forward_layer(
600
+ hidden_states,
601
+ i,
602
+ layer_past=past_key_values[i]
603
+ if past_key_values is not None else None,
604
+ attention_mask=attention_mask,
605
+ head_mask=head_mask,
606
+ use_cache=use_cache,
607
+ output_attentions=output_attentions,
608
+ )
609
+
610
+ if use_cache is True:
611
+ all_presents = all_presents + presents
612
+ if output_attentions:
613
+ all_self_attentions = all_self_attentions + (self_attentions, )
614
+
615
+ return hidden_states, all_presents, all_self_attentions, all_hidden_states
616
+
617
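+ # Full forward pass: forward_prep -> forward_embed -> forward_layers -> final layer norm.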
+ def forward(
618
+ self,
619
+ input_ids=None,
620
+ past_key_values=None,
621
+ attention_mask=None,
622
+ token_type_ids=None,
623
+ position_ids=None,
624
+ head_mask=None,
625
+ inputs_embeds=None,
626
+ use_cache=None,
627
+ output_attentions=None,
628
+ output_hidden_states=None,
629
+ return_dict=None,
630
+ ):
631
+ input_shape = input_ids.size()
632
+ input_ids, attention_mask, head_mask, position_ids, token_type_ids, inputs_embeds, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict = self.forward_prep(
633
+ input_ids=input_ids,
634
+ past_key_values=past_key_values,
635
+ attention_mask=attention_mask,
636
+ token_type_ids=token_type_ids,
637
+ position_ids=position_ids,
638
+ head_mask=head_mask,
639
+ inputs_embeds=inputs_embeds,
640
+ use_cache=use_cache,
641
+ output_attentions=output_attentions,
642
+ output_hidden_states=output_hidden_states,
643
+ return_dict=return_dict,
644
+ )
645
+
646
+ hidden_states = self.forward_embed(
647
+ input_ids=input_ids,
648
+ token_type_ids=token_type_ids,
649
+ inputs_embeds=inputs_embeds,
650
+ )
651
+
652
+ hidden_states, all_presents, all_self_attentions, all_hidden_states = self.forward_layers(
653
+ hidden_states=hidden_states,
654
+ past_key_values=past_key_values,
655
+ attention_mask=attention_mask,
656
+ head_mask=head_mask,
657
+ use_cache=use_cache,
658
+ output_attentions=output_attentions,
659
+ output_hidden_states=output_hidden_states,
660
+ )
661
+
662
+ hidden_states = self.ln_f(hidden_states)
663
+
664
+ output_shape = input_shape + (hidden_states.size(-1), )
665
+ hidden_states = hidden_states.view(*output_shape)
666
+ # Add last hidden state
667
+ if output_hidden_states:
668
+ all_hidden_states = all_hidden_states + (hidden_states, )
669
+
670
+ if not return_dict:
671
+ return tuple(v for v in [
672
+ hidden_states, all_presents, all_hidden_states,
673
+ all_self_attentions
674
+ ] if v is not None)
675
+
676
+ return BaseModelOutputWithPast(
677
+ last_hidden_state=hidden_states,
678
+ past_key_values=all_presents,
679
+ hidden_states=all_hidden_states,
680
+ attentions=all_self_attentions,
681
+ )
682
+
683
+ class ModularProGenForCausalLM(ProGenPreTrainedModel):
684
+ _keys_to_ignore_on_load_missing = [
685
+ r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"
686
+ ]
687
+
688
+ def __init__(self, config):
689
+ super().__init__(config)
690
+
691
+ self.transformer = ModularProGenModel(config)
692
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
693
+ self.init_weights()
694
+
695
+ def get_output_embeddings(self):
696
+ return None
697
+
698
+ def set_output_embeddings(self, new_embeddings):
699
+ return
700
+
701
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
702
+ token_type_ids = kwargs.get("token_type_ids", None)
703
+ # only last token for inputs_ids if past is defined in kwargs
704
+ if past:
705
+ input_ids = input_ids[:, -1].unsqueeze(-1)
706
+ if token_type_ids is not None:
707
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
708
+
709
+ attention_mask = kwargs.get("attention_mask", None)
710
+ position_ids = kwargs.get("position_ids", None)
711
+
712
+ if attention_mask is not None and position_ids is None:
713
+ # create position_ids on the fly for batch generation
714
+ position_ids = attention_mask.long().cumsum(-1) - 1
715
+ position_ids.masked_fill_(attention_mask == 0, 1)
716
+ if past:
717
+ position_ids = position_ids[:, -1].unsqueeze(-1)
718
+ else:
719
+ position_ids = None
720
+ return {
721
+ "input_ids": input_ids,
722
+ "past_key_values": past,
723
+ "use_cache": kwargs.get("use_cache"),
724
+ "position_ids": position_ids,
725
+ "attention_mask": attention_mask,
726
+ "token_type_ids": token_type_ids,
727
+ }
728
+
729
+ def forward(
730
+ self,
731
+ input_ids=None,
732
+ past_key_values=None,
733
+ attention_mask=None,
734
+ token_type_ids=None,
735
+ position_ids=None,
736
+ head_mask=None,
737
+ inputs_embeds=None,
738
+ labels=None,
739
+ use_cache=None,
740
+ output_attentions=None,
741
+ output_hidden_states=None,
742
+ return_dict=None,
743
+ ):
744
+ r"""
745
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
746
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
747
+ ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
748
+ ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
749
+ """
750
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
751
+
752
+ transformer_outputs = self.transformer(
753
+ input_ids,
754
+ past_key_values=past_key_values,
755
+ attention_mask=attention_mask,
756
+ token_type_ids=token_type_ids,
757
+ position_ids=position_ids,
758
+ head_mask=head_mask,
759
+ inputs_embeds=inputs_embeds,
760
+ use_cache=use_cache,
761
+ output_attentions=output_attentions,
762
+ output_hidden_states=output_hidden_states,
763
+ return_dict=return_dict,
764
+ )
765
+ hidden_states = transformer_outputs[0]
766
+
767
+ # make sure sampling in fp16 works correctly and
768
+ # compute loss in fp32 to match with mesh-tf version
769
+ # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
770
+ lm_logits = self.lm_head(hidden_states).to(torch.float32)
771
+
772
+ loss = None
773
+ if labels is not None:
774
+ # Shift so that tokens < n predict n
775
+ shift_logits = lm_logits[..., :-1, :].contiguous()
776
+ shift_labels = labels[..., 1:].contiguous()
777
+ # Flatten the tokens
778
+ loss_fct = CrossEntropyLoss()
779
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
780
+ shift_labels.view(-1))
781
+
782
+ loss = loss.to(hidden_states.dtype)
783
+
784
+ if not return_dict:
785
+ output = (lm_logits, ) + transformer_outputs[1:]
786
+ return ((loss, ) + output) if loss is not None else output
787
+
788
+ return CausalLMOutputWithPast(
789
+ loss=loss,
790
+ logits=lm_logits,
791
+ past_key_values=transformer_outputs.past_key_values,
792
+ hidden_states=transformer_outputs.hidden_states,
793
+ attentions=transformer_outputs.attentions,
794
+ )
795
+
796
+ @staticmethod
797
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]],
798
+ beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
799
+ """
800
+ This function is used to re-order the :obj:`past_key_values` cache if
801
+ :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
802
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
803
+ """
804
+ return tuple(
805
+ tuple(
806
+ past_state.index_select(0, beam_idx.to(past_state.device))
807
+ for past_state in layer_past) for layer_past in past)
808
+
809
+
810
+ class ProgenConditional(ProGenPreTrainedModel): #nn.Module
811
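+ # ProCALM-style conditional model: a frozen pretrained ProGen2 causal LM plus trainable
+ # projection MLPs (keyed by condition type, shared across layers or one per layer) and a
+ # parallel adapter per transformer block that injects the projected condition encodings.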
+ def __init__(self, config: ProGenConditionalConfig):
812
+ super().__init__(config)
813
+
814
+ #self.model = ModularProGenForCausalLM.from_pretrained(pretrained_model_name_or_path=config.pretrained_model_dir, config=config)
815
+ self.model = ModularProGenForCausalLM.from_pretrained("jsunn-y/ProCALM", subfolder="progen2-base", config=config, cache_dir=config.pretrained_model_dir)
816
+ self.model.requires_grad_(False) #freeze the pretrained model by default
817
+
818
+ self.config = config
819
+
820
+ self.projection_mlps = torch.nn.ModuleDict() #conditioning encoders
821
+ if config.adapter_shared_projection == True:
822
+ n_projection_mlps = 1 #sharing a projector
823
+ else:
824
+ n_projection_mlps = len(self.model.transformer.h) #having a projector for every layer
825
+
826
+ for key, input_dim in config.encoding_dimensions.items():
827
+ adapter_projection_layers = nn.ModuleList()
828
+ for i in range(n_projection_mlps):
829
+ if config.adapter_projection_nlayers is None:
830
+ projection_mlp = torch.nn.Linear(input_dim, config.adapter_c_s)
831
+ else:
832
+ projection_mlp = ProjectionMLP(input_dim=input_dim, c_s=config.adapter_c_s, num_layers=config.adapter_projection_nlayers)
833
+ adapter_projection_layers.append(projection_mlp)
834
+
835
+ self.projection_mlps[key] = adapter_projection_layers
836
+
837
+ #if using a shared adapter, append an extra MLP to process the summed input
838
+ #not necessary if you have a separate adapter for each layer
839
+ #this one is always nonlinear and uses two layers
840
+ if (config.conditions_shared_adapter == True) and (len(config.encoding_dimensions.values()) >=2):
841
+ adapter_projection_layers = nn.ModuleList()
842
+ for i in range(n_projection_mlps):
843
+ projection_mlp = ProjectionMLP(input_dim=config.adapter_c_s, c_s=config.adapter_c_s, num_layers=2)
844
+ adapter_projection_layers.append(projection_mlp)
845
+
846
+ self.projection_mlps["combination"] = adapter_projection_layers
847
+
848
+ #initialize the adapter layers
849
+ self.adapter_layers = torch.nn.ModuleList()
850
+ if config.conditions_shared_adapter == False:
851
+ keys = config.encoding_dimensions.keys()
852
+ else:
853
+ keys = ["joint"]
854
+ n_parallel = len(keys)
855
+
856
+ for i in range(len(self.model.transformer.h)):
857
+ parallel_adapter_layer = ParallelAdapterLayer(
858
+ n_parallel=n_parallel,
859
+ c_s=config.adapter_c_s,
860
+ c_h=config.n_embd,
861
+ adapter_summation=config.adapter_summation,
862
+ weight_init=config.adapter_weight_init,
863
+ adapter_nlayers=config.adapter_nlayers,
864
+ )
865
+ adapter_dropout = torch.nn.Dropout(config.adapter_dropout)
866
+ self.adapter_layers.append(nn.ModuleList([parallel_adapter_layer, adapter_dropout]))
867
+
868
+ def prepare_inputs_for_generation(self, input_ids, condition_encodings: Optional[Dict[str, torch.Tensor]] = None, past=None, **kwargs):
869
+ """
870
+ Overrides prepare_inputs_for_generation (HF-compatible) to additionally build the adapter input from the condition encodings.
871
+ """
872
+ token_type_ids = kwargs.get("token_type_ids", None)
873
+ # only last token for inputs_ids if past is defined in kwargs
874
+ past = kwargs.get("past_key_values", past)
875
+ if past:
876
+ input_ids = input_ids[:, -1].unsqueeze(-1)
877
+ if token_type_ids is not None:
878
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
879
+
880
+ attention_mask = kwargs.get("attention_mask", None)
881
+ position_ids = kwargs.get("position_ids", None)
882
+
883
+ if attention_mask is not None and position_ids is None:
884
+ # create position_ids on the fly for batch generation
885
+ position_ids = attention_mask.long().cumsum(-1) - 1
886
+ position_ids.masked_fill_(attention_mask == 0, 1)
887
+ if past:
888
+ position_ids = position_ids[:, -1].unsqueeze(-1)
889
+ else:
890
+ position_ids = None
891
+
892
+ adapter_input = {}
893
+ for key, condition_encoding in condition_encodings.items():
894
+ if condition_encoding is not None:
895
+ single_adapter_input = condition_encoding.repeat(input_ids.shape[0], input_ids.shape[1], 1)
896
+ else:
897
+ single_adapter_input = None
898
+ adapter_input[key] = single_adapter_input
899
+
900
+ return {
901
+ "input_ids": input_ids,
902
+ "past_key_values": past,
903
+ "position_ids": position_ids,
904
+ "attention_mask": attention_mask,
905
+ "token_type_ids": token_type_ids,
906
+ "adapter_input": adapter_input,
907
+ }
908
+
909
+ @staticmethod
910
+ def _reorder_cache(past_key_values, beam_idx):
911
+ if isinstance(past_key_values, Cache):
912
+ return past_key_values.reorder_cache(beam_idx)
913
+
914
+ reordered_past = ()
915
+ for layer_past in past_key_values:
916
+ reordered_past += (
917
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
918
+ )
919
+ return DynamicCache.from_legacy_cache(reordered_past)
920
+
921
+ def forward(
922
+ self,
923
+ input_ids=None,
924
+ past_key_values=None,
925
+ attention_mask=None,
926
+ token_type_ids=None,
927
+ position_ids=None,
928
+ head_mask=None,
929
+ inputs_embeds=None,
930
+ labels=None,
931
+ use_cache=None,
932
+ output_attentions=None,
933
+ output_hidden_states=None,
934
+ return_dict=None,
935
+ adapter_input=None,
936
+ ):
937
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
938
+
939
+ input_shape = input_ids.size()
940
+
941
+ input_ids, attention_mask, head_mask, position_ids, token_type_ids, inputs_embeds, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict = self.model.transformer.forward_prep(
942
+ input_ids=input_ids,
943
+ past_key_values=past_key_values,
944
+ attention_mask=attention_mask,
945
+ token_type_ids=token_type_ids,
946
+ position_ids=position_ids,
947
+ head_mask=head_mask,
948
+ inputs_embeds=inputs_embeds,
949
+ use_cache=use_cache,
950
+ output_attentions=output_attentions,
951
+ output_hidden_states=output_hidden_states,
952
+ return_dict=return_dict,
953
+ )
954
+
955
+ hidden_states = self.model.transformer.forward_embed(
956
+ input_ids=input_ids,
957
+ token_type_ids=token_type_ids,
958
+ inputs_embeds=inputs_embeds,
959
+ )
960
+
961
+ all_presents = () if use_cache else None
962
+ all_self_attentions = () if output_attentions else None
963
+ all_hidden_states = () if output_hidden_states else None
964
+
965
+ # Project each condition encoding to the adapter dimension (adapter_c_s).
966
+ # If a single projection is shared across all layers, project once here;
967
+ # otherwise projection happens per layer inside the loop below.
968
+ if self.config.adapter_shared_projection == True:
969
+ encoded_adapter_input = ()
970
+ #if you're sharing an adapter and doing joint conditioning
971
+ if len(adapter_input.keys()) >= 2 and self.config.conditions_shared_adapter == True:
972
+ summed_adapter_input = torch.zeros(input_shape[0], input_shape[1], self.config.adapter_c_s).to(input_ids.device)
973
+ for key, single_adapter_input in adapter_input.items():
974
+ projected_adapter_input = self.projection_mlps[key][0](single_adapter_input)
975
+ summed_adapter_input += projected_adapter_input
976
+
977
+ #combine the inputs and pass through one
978
+ key = "combination"
979
+ summed_adapter_input = self.projection_mlps[key][0](summed_adapter_input)
980
+ encoded_adapter_input = (summed_adapter_input, )
981
+
982
+ #if you're not sharing an adapter (with or without multiple conditions)
983
+ else:
984
+ for key, value in adapter_input.items():
985
+ summed_adapter_input = self.projection_mlps[key][0](value)
986
+ encoded_adapter_input = encoded_adapter_input + (summed_adapter_input, )
987
+ encoded_adapter_input = torch.stack(encoded_adapter_input, dim=0)
988
+
989
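+ # Run every transformer block; when the projection is not shared, recompute the
+ # projected condition encodings with that layer's projection MLPs, and pass the
+ # layer's adapter (and dropout) into forward_layer.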
+ for i in range(len(self.model.transformer.h)):
990
+ #if not sharing a projection layer
991
+ if self.config.adapter_shared_projection == False:
992
+ encoded_adapter_input = ()
993
+ #if you're sharing an adapter and doing joint conditioning
994
+ if len(adapter_input.keys()) >= 2 and self.config.conditions_shared_adapter == True:
995
+ summed_adapter_input = torch.zeros(input_shape[0], input_shape[1], self.config.adapter_c_s).to(input_ids.device)
996
+ for key, single_adapter_input in adapter_input.items():
997
+ projected_adapter_input = self.projection_mlps[key][i](single_adapter_input)
998
+ summed_adapter_input += projected_adapter_input
999
+
1000
+ #combine the inputs and pass through one more mlp
1001
+ key = "combination"
1002
+ summed_adapter_input = self.projection_mlps[key][i](summed_adapter_input)
1003
+ encoded_adapter_input = (summed_adapter_input, )
1004
+
1005
+ #if you're not sharing an adapter (with or without multiple conditions)
1006
+ else:
1007
+ for key, value in adapter_input.items():
1008
+ summed_adapter_input = self.projection_mlps[key][i](value)
1009
+ encoded_adapter_input = encoded_adapter_input + (summed_adapter_input, )
1010
+ encoded_adapter_input = torch.stack(encoded_adapter_input, dim=0)
1011
+
1012
+ if output_hidden_states:
1013
+ all_hidden_states = all_hidden_states + (hidden_states, )
1014
+
1015
+ hidden_states, presents, self_attentions = self.model.transformer.forward_layer(
1016
+ hidden_states=hidden_states,
1017
+ layer_i=i,
1018
+ layer_past=past_key_values[i] if past_key_values[i] is not None else None,
1019
+ attention_mask=attention_mask,
1020
+ head_mask=head_mask,
1021
+ use_cache=use_cache,
1022
+ output_attentions=output_attentions,
1023
+ adapter_layer=self.adapter_layers[i][0],
1024
+ adapter_dropout=self.adapter_layers[i][1],
1025
+ adapter_input=encoded_adapter_input,
1026
+ )
1027
+
1028
+ if use_cache is True:
1029
+ all_presents = all_presents + presents
1030
+ if output_attentions:
1031
+ all_self_attentions = all_self_attentions + (self_attentions, )
1032
+
1033
+ hidden_states = self.model.transformer.ln_f(hidden_states)
1034
+
1035
+ output_shape = input_shape + (hidden_states.size(-1), )
1036
+ hidden_states = hidden_states.view(*output_shape)
1037
+
1038
+ if output_hidden_states:
1039
+ all_hidden_states = all_hidden_states + (hidden_states, )
1040
+
1041
+ if not return_dict:
1042
+ return tuple(v for v in [
1043
+ hidden_states, all_presents, all_hidden_states,
1044
+ all_self_attentions
1045
+ ] if v is not None)
1046
+
1047
+ transformer_outputs = BaseModelOutputWithPast(
1048
+ last_hidden_state=hidden_states,
1049
+ past_key_values=all_presents,
1050
+ hidden_states=all_hidden_states,
1051
+ attentions=all_self_attentions,
1052
+ )
1053
+
1054
+ hidden_states = transformer_outputs[0]
1055
+
1056
+ # make sure sampling in fp16 works correctly and
1057
+ # compute loss in fp32 to match with mesh-tf version
1058
+ # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
1059
+ lm_logits = self.model.lm_head(hidden_states).to(torch.float32)
1060
+
1061
+ loss = None
1062
+ all_losses = None
1063
+ if labels is not None:
1064
+ # Shift so that tokens < n predict n
1065
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1066
+ shift_labels = labels[..., 1:].contiguous()
1067
+
1068
+ # also compute the unreduced (per-token) losses so that per-sample losses can be recovered downstream
1069
+ loss_fct = CrossEntropyLoss(ignore_index=0, reduction='none')
1070
+ all_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
1071
+ shift_labels.view(-1))
1072
+ all_losses = all_losses.to(hidden_states.dtype)
1073
+
1074
+ #still output the mean reduced loss
1075
+ loss_fct = CrossEntropyLoss(ignore_index=0)
1076
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
1077
+ shift_labels.view(-1))
1078
+
1079
+ if not return_dict:
1080
+ output = (lm_logits, ) + transformer_outputs[1:]
1081
+ return ((loss, ) + output) if loss is not None else output
1082
+
1083
+ return CausalLMOutputWithPast(
1084
+ loss=loss,
1085
+ all_losses=all_losses,
1086
+ logits=lm_logits,
1087
+ past_key_values=transformer_outputs.past_key_values,
1088
+ hidden_states=transformer_outputs.hidden_states,
1089
+ attentions=transformer_outputs.attentions,
1090
+ )
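+
+ # Illustrative usage sketch (the condition key "ec", the tensor shapes, and the config
+ # arguments below are placeholder assumptions, not defined in this file):
+ #
+ #   config = ProGenConditionalConfig(...)
+ #   model = ProgenConditional(config)
+ #   enc = {"ec": torch.randn(config.encoding_dimensions["ec"])}
+ #   batch = model.prepare_inputs_for_generation(
+ #       input_ids, condition_encodings=enc, attention_mask=torch.ones_like(input_ids))
+ #   out = model(**batch, labels=input_ids)  # CausalLMOutputWithPast with loss and all_losses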