Zilin Zhu committed
Commit 980d254 · 1 Parent(s): d544be0
config.json ADDED
@@ -0,0 +1,35 @@
+ {
+ "activation_function": "quick_gelu",
+ "architectures": [
+ "GPT2LMHeadCustomModel"
+ ],
+ "auto_map": {
+ "AutoModelForCausalLM": "modeling_gpt2_summarize.GPT2LMHeadCustomModel"
+ },
+ "attn_pdrop": 0,
+ "bos_token_id": 50256,
+ "embd_pdrop": 0,
+ "eos_token_id": 50256,
+ "initializer_range": 0.02,
+ "layer_norm_epsilon": 1e-05,
+ "model_type": "gpt2",
+ "n_embd": 2048,
+ "n_head": 16,
+ "n_inner": null,
+ "n_layer": 24,
+ "n_positions": 2048,
+ "reorder_and_upcast_attn": false,
+ "resid_pdrop": 0,
+ "scale_attn_by_inverse_layer_idx": false,
+ "scale_attn_weights": true,
+ "summary_activation": null,
+ "summary_first_dropout": 0.1,
+ "summary_proj_to_labels": true,
+ "summary_type": "cls_index",
+ "summary_use_proj": true,
+ "tie_word_embeddings": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.25.1",
+ "use_cache": true,
+ "vocab_size": 50304
+ }
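The `auto_map` entry above points `AutoModelForCausalLM` at the `GPT2LMHeadCustomModel` class shipped in `modeling_gpt2_summarize.py`, so the checkpoint is meant to be loaded with remote code enabled. A minimal loading sketch, assuming `repo_id` is replaced with this repository's actual Hub path (the value below is a placeholder):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<this-repo-id>"  # placeholder, not the real repository path

# trust_remote_code=True lets transformers import GPT2LMHeadCustomModel from
# modeling_gpt2_summarize.py, as declared in auto_map above.
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# model_type is still "gpt2", so the stock GPT-2 tokenizer applies; vocab_size
# (50304) is larger than the tokenizer's 50257 entries, presumably padded for
# efficiency, so the extra embedding rows are simply never indexed.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
```

Otherwise the config describes a 24-layer, 16-head model with 2048-dimensional hidden states, a 2048-token context window, and attention/embedding/residual dropout all set to 0.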
modeling_gpt2_summarize.py ADDED
@@ -0,0 +1,1074 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ PyTorch OpenAI GPT-2 model in Learning to Summarize with Human Feedback.
18
+ https://openai.com/blog/learning-to-summarize-with-human-feedback/
19
+ https://arxiv.org/abs/2009.01325
20
+ https://github.com/openai/summarize-from-feedback
21
+ """
22
+
23
+ import math
24
+ from typing import Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+ from torch.cuda.amp import autocast
30
+ from torch.nn import CrossEntropyLoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPastAndCrossAttentions,
35
+ CausalLMOutputWithCrossAttentions,
36
+ )
37
+ from transformers.modeling_utils import PreTrainedModel
38
+ from transformers.utils import (
39
+ add_code_sample_docstrings,
40
+ add_start_docstrings,
41
+ add_start_docstrings_to_model_forward,
42
+ logging,
43
+ )
44
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
45
+ from transformers import GPT2Config
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices
46
+
47
+ import numpy as np
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ _CHECKPOINT_FOR_DOC = "gpt2"
52
+ _CONFIG_FOR_DOC = "GPT2Config"
53
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
54
+
55
+
56
+ class Conv1D(nn.Module):
57
+ """
58
+ 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
59
+
60
+ Basically works like a linear layer but the weights are transposed.
61
+
62
+ Args:
63
+ nf (`int`): The number of output features.
64
+ nx (`int`): The number of input features.
65
+ """
66
+
67
+ def __init__(self, nf, nx, bias=True):
68
+ super().__init__()
69
+ self.nf = nf
70
+ w = torch.empty(nx, nf)
71
+ nn.init.normal_(w, std=0.02)
72
+ self.weight = nn.Parameter(w)
73
+ if bias:
74
+ self.bias = nn.Parameter(torch.zeros(nf))
75
+ else:
76
+ self.bias = None
77
+
78
+ def forward(self, x):
79
+ size_out = x.size()[:-1] + (self.nf,)
80
+ if self.bias is not None:
81
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
82
+ else:
83
+ x = torch.mm(x.view(-1, x.size(-1)), self.weight)
84
+ x = x.view(size_out)
85
+ return x
86
+
87
+
88
+ def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1, bias: bool = True) -> Conv1D:
89
+ """
90
+ Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights
91
+ are transposed.
92
+
93
+ Used to remove heads.
94
+
95
+ Args:
96
+ layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
97
+ index (`torch.LongTensor`): The indices to keep in the layer.
98
+ dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
99
+
100
+ Returns:
101
+ [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
102
+ """
103
+ index = index.to(layer.weight.device)
104
+ W = layer.weight.index_select(dim, index).clone().detach()
105
+ if bias:
106
+ if dim == 0:
107
+ b = layer.bias.clone().detach()
108
+ else:
109
+ b = layer.bias[index].clone().detach()
110
+ new_size = list(layer.weight.size())
111
+ new_size[dim] = len(index)
112
+ new_layer = Conv1D(new_size[1], new_size[0], bias).to(layer.weight.device)
113
+ new_layer.weight.requires_grad = False
114
+ new_layer.weight.copy_(W.contiguous())
115
+ new_layer.weight.requires_grad = True
116
+ if bias:
117
+ new_layer.bias.requires_grad = False
118
+ new_layer.bias.copy_(b.contiguous())
119
+ new_layer.bias.requires_grad = True
120
+ return new_layer
121
+
122
+
123
+ class GPT2Attention(nn.Module):
124
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
125
+ super().__init__()
126
+
127
+ max_positions = config.max_position_embeddings
128
+ self.register_buffer(
129
+ "bias",
130
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
131
+ 1, 1, max_positions, max_positions
132
+ ),
133
+ )
134
+ self.register_buffer("masked_bias", torch.tensor(-1e4))
135
+
136
+ self.embed_dim = config.hidden_size
137
+ self.num_heads = config.num_attention_heads
138
+ self.head_dim = self.embed_dim // self.num_heads
139
+ self.split_size = self.embed_dim
140
+ if self.head_dim * self.num_heads != self.embed_dim:
141
+ raise ValueError(
142
+ f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
143
+ f" {self.num_heads})."
144
+ )
145
+
146
+ self.scale_attn_weights = config.scale_attn_weights
147
+ self.is_cross_attention = is_cross_attention
148
+
149
+ # Layer-wise attention scaling, reordering, and upcasting
150
+ self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
151
+ self.layer_idx = layer_idx
152
+ self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
153
+
154
+ if self.is_cross_attention:
155
+ raise NotImplementedError("should not enter this path.")
156
+ self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
157
+ self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
158
+ else:
159
+ #self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
160
+ self.q_proj = Conv1D(self.embed_dim, self.embed_dim)
161
+ self.k_proj = Conv1D(self.embed_dim, self.embed_dim, bias=False)
162
+ self.v_proj = Conv1D(self.embed_dim, self.embed_dim)
163
+ self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
164
+
165
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
166
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
167
+
168
+ self.pruned_heads = set()
169
+
170
+ def prune_heads(self, heads):
171
+ if len(heads) == 0:
172
+ return
173
+ heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
+
+ # Prune conv1d layers. q/k/v are separate projections here (not the fused c_attn),
+ # so each one is pruned with the per-head index directly.
+ self.q_proj = prune_conv1d_layer(self.q_proj, index, dim=1)
+ self.k_proj = prune_conv1d_layer(self.k_proj, index, dim=1, bias=False)
+ self.v_proj = prune_conv1d_layer(self.v_proj, index, dim=1)
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
182
+
183
+ # Update hyper params
184
+ self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
185
+ self.num_heads = self.num_heads - len(heads)
186
+ self.pruned_heads = self.pruned_heads.union(heads)
187
+
188
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
189
+ if self.scale_attn_weights:
190
+ # Pre-divide by fp16_stability_scale to prevent fp16 overflow
191
+ softmax_scale = 1.0 / np.sqrt(np.sqrt(query.size(-1)))
192
+ query = query * softmax_scale
193
+ key = key * softmax_scale
194
+
195
+ attn_weights = torch.matmul(query, key)
196
+
197
+ if not self.is_cross_attention:
198
+ # if only "normal" attention layer implements causal mask
199
+ query_length, key_length = query.size(-2), key.size(-1)
200
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
201
+ mask_value = torch.finfo(attn_weights.dtype).min
202
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
203
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
204
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
205
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
206
+
207
+ if attention_mask is not None:
208
+ # Apply the attention mask
209
+ attn_weights = attn_weights + attention_mask
210
+
211
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
212
+
213
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
214
+ attn_weights = attn_weights.type(value.dtype)
215
+ attn_weights = self.attn_dropout(attn_weights)
216
+
217
+ # Mask heads if we want to
218
+ if head_mask is not None:
219
+ attn_weights = attn_weights * head_mask
220
+
221
+ attn_output = torch.matmul(attn_weights, value)
222
+
223
+ return attn_output, attn_weights
224
+
225
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
226
+ # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
227
+ bsz, num_heads, q_seq_len, dk = query.size()
228
+ _, _, k_seq_len, _ = key.size()
229
+
230
+ # Preallocate attn_weights for `baddbmm`
231
+ attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
232
+
233
+ # Compute Scale Factor
234
+ scale_factor = 1.0
235
+ if self.scale_attn_weights:
236
+ scale_factor /= float(value.size(-1)) ** 0.5
237
+
238
+ if self.scale_attn_by_inverse_layer_idx:
239
+ scale_factor /= float(self.layer_idx + 1)
240
+
241
+ # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
242
+ with autocast(enabled=False):
243
+ q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
244
+ attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
245
+ attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
246
+
247
+ if not self.is_cross_attention:
248
+ # if only "normal" attention layer implements causal mask
249
+ query_length, key_length = query.size(-2), key.size(-2)
250
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
251
+ mask_value = torch.finfo(attn_weights.dtype).min
252
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
253
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
254
+ mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
255
+ attn_weights = torch.where(causal_mask, attn_weights, mask_value)
256
+
257
+ if attention_mask is not None:
258
+ # Apply the attention mask
259
+ attn_weights = attn_weights + attention_mask
260
+
261
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
262
+
263
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
264
+ if attn_weights.dtype != torch.float32:
265
+ raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
266
+ attn_weights = attn_weights.type(value.dtype)
267
+ attn_weights = self.attn_dropout(attn_weights)
268
+
269
+ # Mask heads if we want to
270
+ if head_mask is not None:
271
+ attn_weights = attn_weights * head_mask
272
+
273
+ attn_output = torch.matmul(attn_weights, value)
274
+
275
+ return attn_output, attn_weights
276
+
277
+ def _split_heads(self, tensor, num_heads, attn_head_size, k=False):
278
+ """
279
+ Splits hidden_size dim into attn_head_size and num_heads
280
+ """
281
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
282
+ tensor = tensor.view(new_shape)
283
+ if k:
284
+ return tensor.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
285
+ else:
286
+ return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
287
+
288
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
289
+ """
290
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
291
+ """
292
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
293
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
294
+ return tensor.view(new_shape)
295
+
296
+ def forward(
297
+ self,
298
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
299
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
300
+ attention_mask: Optional[torch.FloatTensor] = None,
301
+ head_mask: Optional[torch.FloatTensor] = None,
302
+ encoder_hidden_states: Optional[torch.Tensor] = None,
303
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
304
+ use_cache: Optional[bool] = False,
305
+ output_attentions: Optional[bool] = False,
306
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
307
+ if encoder_hidden_states is not None:
308
+ raise NotImplementedError("should not enter this path.")
309
+ if not hasattr(self, "q_attn"):
310
+ raise ValueError(
311
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
312
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
313
+ )
314
+
315
+ query = self.q_attn(hidden_states)
316
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
317
+ attention_mask = encoder_attention_mask
318
+ else:
319
+ #query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
320
+ query = self.q_proj(hidden_states)
321
+ key = self.k_proj(hidden_states)
322
+ value = self.v_proj(hidden_states)
323
+
324
+ query = self._split_heads(query, self.num_heads, self.head_dim)
325
+ key = self._split_heads(key, self.num_heads, self.head_dim, k=True)
326
+ value = self._split_heads(value, self.num_heads, self.head_dim)
327
+
328
+ if layer_past is not None:
329
+ past_key, past_value = layer_past
330
+ key = torch.cat((past_key, key), dim=-2)
331
+ value = torch.cat((past_value, value), dim=-2)
332
+
333
+ if use_cache is True:
334
+ present = (key, value)
335
+ else:
336
+ present = None
337
+
338
+ if self.reorder_and_upcast_attn:
339
+ raise NotImplementedError("should not enter this path.")
340
+ attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
341
+ else:
342
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
343
+
344
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
345
+ attn_output = self.c_proj(attn_output)
346
+ attn_output = self.resid_dropout(attn_output)
347
+
348
+ outputs = (attn_output, present)
349
+ if output_attentions:
350
+ outputs += (attn_weights,)
351
+
352
+ return outputs # a, present, (attentions)
353
+
354
+
355
+ class GPT2MLP(nn.Module):
356
+ def __init__(self, intermediate_size, config):
357
+ super().__init__()
358
+ embed_dim = config.hidden_size
359
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
360
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
361
+ self.act = ACT2FN[config.activation_function]
362
+ self.dropout = nn.Dropout(config.resid_pdrop)
363
+
364
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
365
+ hidden_states = self.c_fc(hidden_states)
366
+ hidden_states = self.act(hidden_states)
367
+ hidden_states = self.c_proj(hidden_states)
368
+ hidden_states = self.dropout(hidden_states)
369
+ return hidden_states
370
+
371
+
372
+ class GPT2Block(nn.Module):
373
+ def __init__(self, config, layer_idx=None):
374
+ super().__init__()
375
+ hidden_size = config.hidden_size
376
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
377
+
378
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
379
+ self.attn = GPT2Attention(config, layer_idx=layer_idx)
380
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
381
+
382
+ if config.add_cross_attention:
383
+ self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
384
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
385
+
386
+ self.mlp = GPT2MLP(inner_dim, config)
387
+
388
+ def forward(
389
+ self,
390
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
391
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
392
+ attention_mask: Optional[torch.FloatTensor] = None,
393
+ head_mask: Optional[torch.FloatTensor] = None,
394
+ encoder_hidden_states: Optional[torch.Tensor] = None,
395
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
396
+ use_cache: Optional[bool] = False,
397
+ output_attentions: Optional[bool] = False,
398
+ ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
399
+ inputs = hidden_states
400
+ residual = hidden_states
401
+ hidden_states = self.ln_1(hidden_states)
402
+ attn_outputs = self.attn(
403
+ hidden_states,
404
+ layer_past=layer_past,
405
+ attention_mask=attention_mask,
406
+ head_mask=head_mask,
407
+ use_cache=use_cache,
408
+ output_attentions=output_attentions,
409
+ )
410
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
411
+ outputs = attn_outputs[1:]
412
+ # residual connection
413
+ hidden_states = attn_output + residual
414
+
415
+ if encoder_hidden_states is not None:
416
+ # add one self-attention block for cross-attention
417
+ if not hasattr(self, "crossattention"):
418
+ raise ValueError(
419
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
420
+ "cross-attention layers by setting `config.add_cross_attention=True`"
421
+ )
422
+ residual = hidden_states
423
+ hidden_states = self.ln_cross_attn(hidden_states)
424
+ cross_attn_outputs = self.crossattention(
425
+ hidden_states,
426
+ attention_mask=attention_mask,
427
+ head_mask=head_mask,
428
+ encoder_hidden_states=encoder_hidden_states,
429
+ encoder_attention_mask=encoder_attention_mask,
430
+ output_attentions=output_attentions,
431
+ )
432
+ attn_output = cross_attn_outputs[0]
433
+ # residual connection
434
+ hidden_states = residual + attn_output
435
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
436
+
437
+ residual = hidden_states
438
+ hidden_states = self.ln_2(hidden_states)
439
+ feed_forward_hidden_states = self.mlp(hidden_states)
440
+ # we use unnormalized inputs to all functions for residuals
441
+ # 1.0 here is `res_scale`
442
+ hidden_states = inputs + 1.0 * (attn_output + feed_forward_hidden_states)
443
+
444
+ if use_cache:
445
+ outputs = (hidden_states,) + outputs
446
+ else:
447
+ outputs = (hidden_states,) + outputs[1:]
448
+
449
+ return outputs # hidden_states, present, (attentions, cross_attentions)
450
+
451
+
452
+ class GPT2PreTrainedModel(PreTrainedModel):
453
+ """
454
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
455
+ models.
456
+ """
457
+
458
+ config_class = GPT2Config
459
+ base_model_prefix = "transformer"
460
+ is_parallelizable = True
461
+ supports_gradient_checkpointing = True
462
+ _no_split_modules = ["GPT2Block"]
463
+
464
+ def __init__(self, *inputs, **kwargs):
465
+ super().__init__(*inputs, **kwargs)
466
+
467
+ def _init_weights(self, module):
468
+ """Initialize the weights."""
469
+ if isinstance(module, (nn.Linear, Conv1D)):
470
+ # Slightly different from the TF version which uses truncated_normal for initialization
471
+ # cf https://github.com/pytorch/pytorch/pull/5617
472
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
473
+ if module.bias is not None:
474
+ module.bias.data.zero_()
475
+ elif isinstance(module, nn.Embedding):
476
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
477
+ if module.padding_idx is not None:
478
+ module.weight.data[module.padding_idx].zero_()
479
+ elif isinstance(module, nn.LayerNorm):
480
+ module.bias.data.zero_()
481
+ module.weight.data.fill_(1.0)
482
+
483
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
484
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
485
+ # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
486
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
487
+ #
488
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
489
+ for name, p in module.named_parameters():
490
+ if name == "c_proj.weight":
491
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
492
+ p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
493
+
494
+ def _set_gradient_checkpointing(self, module, value=False):
495
+ if isinstance(module, GPT2Model):
496
+ module.gradient_checkpointing = value
497
+
498
+
499
+ GPT2_START_DOCSTRING = r"""
500
+
501
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
502
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
503
+ etc.)
504
+
505
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
506
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
507
+ and behavior.
508
+
509
+ Parameters:
510
+ config ([`GPT2Config`]): Model configuration class with all the parameters of the model.
511
+ Initializing with a config file does not load the weights associated with the model, only the
512
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
513
+ """
514
+
515
+ GPT2_INPUTS_DOCSTRING = r"""
516
+ Args:
517
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
518
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
519
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
520
+ sequence tokens in the vocabulary.
521
+
522
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
523
+ `input_ids`.
524
+
525
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
526
+ [`PreTrainedTokenizer.__call__`] for details.
527
+
528
+ [What are input IDs?](../glossary#input-ids)
529
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
530
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
531
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
532
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
533
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
534
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
535
+
536
+ - 1 for tokens that are **not masked**,
537
+ - 0 for tokens that are **masked**.
538
+
539
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
540
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
541
+ `len(past_key_values) + len(input_ids)`
542
+
543
+ [What are attention masks?](../glossary#attention-mask)
544
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
545
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
546
+ 1]`:
547
+
548
+ - 0 corresponds to a *sentence A* token,
549
+ - 1 corresponds to a *sentence B* token.
550
+
551
+ [What are token type IDs?](../glossary#token-type-ids)
552
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
553
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
554
+ config.max_position_embeddings - 1]`.
555
+
556
+ [What are position IDs?](../glossary#position-ids)
557
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
558
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
559
+
560
+ - 1 indicates the head is **not masked**,
561
+ - 0 indicates the head is **masked**.
562
+
563
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
564
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
565
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
566
+ model's internal embedding lookup matrix.
567
+
568
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
569
+ `past_key_values`).
570
+ use_cache (`bool`, *optional*):
571
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
572
+ `past_key_values`).
573
+ output_attentions (`bool`, *optional*):
574
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
575
+ tensors for more detail.
576
+ output_hidden_states (`bool`, *optional*):
577
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
578
+ more detail.
579
+ return_dict (`bool`, *optional*):
580
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
581
+ """
582
+ PARALLELIZE_DOCSTRING = r"""
583
+ This is an experimental feature and is subject to change at a moment's notice.
584
+
585
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
586
+ it will evenly distribute blocks across all devices.
587
+
588
+ Args:
589
+ device_map (`Dict[int, list]`, optional, defaults to None):
590
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
591
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
592
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
593
+ following number of attention modules:
594
+
595
+ - gpt2: 12
596
+ - gpt2-medium: 24
597
+ - gpt2-large: 36
598
+ - gpt2-xl: 48
599
+
600
+ Example:
601
+
602
+ ```python
603
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
604
+ model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
605
+ device_map = {
606
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
607
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
608
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
609
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
610
+ }
611
+ model.parallelize(device_map)
612
+ ```
613
+ """
614
+ DEPARALLELIZE_DOCSTRING = r"""
615
+ Moves the model to cpu from a model parallel state.
616
+
617
+ Example:
618
+
619
+ ```python
620
+ # On a 4 GPU machine with gpt2-large:
621
+ model = GPT2LMHeadModel.from_pretrained("gpt2-large")
622
+ device_map = {
623
+ 0: [0, 1, 2, 3, 4, 5, 6, 7],
624
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
625
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
626
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
627
+ }
628
+ model.parallelize(device_map) # Splits the model across several devices
629
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
630
+ ```
631
+ """
632
+
633
+
634
+ @add_start_docstrings(
635
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
636
+ GPT2_START_DOCSTRING,
637
+ )
638
+ class GPT2Model(GPT2PreTrainedModel):
639
+ _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
640
+
641
+ def __init__(self, config):
642
+ super().__init__(config)
643
+
644
+ self.embed_dim = config.hidden_size
645
+
646
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
647
+ self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
648
+
649
+ self.drop = nn.Dropout(config.embd_pdrop)
650
+ self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
651
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
652
+
653
+ # Model parallel
654
+ self.model_parallel = False
655
+ self.device_map = None
656
+ self.gradient_checkpointing = False
657
+
658
+ # Initialize weights and apply final processing
659
+ self.post_init()
660
+
661
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
662
+ def parallelize(self, device_map=None):
663
+ # Check validity of device_map
664
+ self.device_map = (
665
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
666
+ )
667
+ assert_device_map(self.device_map, len(self.h))
668
+ self.model_parallel = True
669
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
670
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
671
+ self.wte = self.wte.to(self.first_device)
672
+ self.wpe = self.wpe.to(self.first_device)
673
+ # Load onto devices
674
+ for k, v in self.device_map.items():
675
+ for block in v:
676
+ cuda_device = "cuda:" + str(k)
677
+ self.h[block] = self.h[block].to(cuda_device)
678
+ # ln_f to last
679
+ self.ln_f = self.ln_f.to(self.last_device)
680
+
681
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
682
+ def deparallelize(self):
683
+ self.model_parallel = False
684
+ self.device_map = None
685
+ self.first_device = "cpu"
686
+ self.last_device = "cpu"
687
+ self.wte = self.wte.to("cpu")
688
+ self.wpe = self.wpe.to("cpu")
689
+ for index in range(len(self.h)):
690
+ self.h[index] = self.h[index].to("cpu")
691
+ self.ln_f = self.ln_f.to("cpu")
692
+ torch.cuda.empty_cache()
693
+
694
+ def get_input_embeddings(self):
695
+ return self.wte
696
+
697
+ def set_input_embeddings(self, new_embeddings):
698
+ self.wte = new_embeddings
699
+
700
+ def _prune_heads(self, heads_to_prune):
701
+ """
702
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
703
+ """
704
+ for layer, heads in heads_to_prune.items():
705
+ self.h[layer].attn.prune_heads(heads)
706
+
707
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
708
+ @add_code_sample_docstrings(
709
+ processor_class=_TOKENIZER_FOR_DOC,
710
+ checkpoint=_CHECKPOINT_FOR_DOC,
711
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
712
+ config_class=_CONFIG_FOR_DOC,
713
+ )
714
+ def forward(
715
+ self,
716
+ input_ids: Optional[torch.LongTensor] = None,
717
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
718
+ attention_mask: Optional[torch.FloatTensor] = None,
719
+ token_type_ids: Optional[torch.LongTensor] = None,
720
+ position_ids: Optional[torch.LongTensor] = None,
721
+ head_mask: Optional[torch.FloatTensor] = None,
722
+ inputs_embeds: Optional[torch.FloatTensor] = None,
723
+ encoder_hidden_states: Optional[torch.Tensor] = None,
724
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
725
+ use_cache: Optional[bool] = None,
726
+ output_attentions: Optional[bool] = None,
727
+ output_hidden_states: Optional[bool] = None,
728
+ return_dict: Optional[bool] = None,
729
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
730
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
731
+ output_hidden_states = (
732
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
733
+ )
734
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
735
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
736
+
737
+ if input_ids is not None and inputs_embeds is not None:
738
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
739
+ elif input_ids is not None:
740
+ input_shape = input_ids.size()
741
+ input_ids = input_ids.view(-1, input_shape[-1])
742
+ batch_size = input_ids.shape[0]
743
+ elif inputs_embeds is not None:
744
+ input_shape = inputs_embeds.size()[:-1]
745
+ batch_size = inputs_embeds.shape[0]
746
+ else:
747
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
748
+
749
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
750
+
751
+ if token_type_ids is not None:
752
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
753
+ if position_ids is not None:
754
+ position_ids = position_ids.view(-1, input_shape[-1])
755
+
756
+ if past_key_values is None:
757
+ past_length = 0
758
+ past_key_values = tuple([None] * len(self.h))
759
+ else:
760
+ past_length = past_key_values[0][0].size(-2)
761
+ if position_ids is None:
762
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
763
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
764
+
765
+ # GPT2Attention mask.
766
+ if attention_mask is not None:
767
+ if batch_size <= 0:
768
+ raise ValueError("batch_size has to be defined and > 0")
769
+ attention_mask = attention_mask.view(batch_size, -1)
770
+ # We create a 3D attention mask from a 2D tensor mask.
771
+ # Sizes are [batch_size, 1, 1, to_seq_length]
772
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
773
+ # this attention mask is more simple than the triangular masking of causal attention
774
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
775
+ attention_mask = attention_mask[:, None, None, :]
776
+
777
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
778
+ # masked positions, this operation will create a tensor which is 0.0 for
779
+ # positions we want to attend and the dtype's smallest value for masked positions.
780
+ # Since we are adding it to the raw scores before the softmax, this is
781
+ # effectively the same as removing these entirely.
782
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
783
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
784
+
785
+ # If a 2D or 3D attention mask is provided for the cross-attention
786
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
787
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
788
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
789
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
790
+ if encoder_attention_mask is None:
791
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
792
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
793
+ else:
794
+ encoder_attention_mask = None
795
+
796
+ # Prepare head mask if needed
797
+ # 1.0 in head_mask indicate we keep the head
798
+ # attention_probs has shape bsz x n_heads x N x N
799
+ # head_mask has shape n_layer x batch x n_heads x N x N
800
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
801
+
802
+ if inputs_embeds is None:
803
+ inputs_embeds = self.wte(input_ids)
804
+ position_embeds = self.wpe(position_ids)
805
+ hidden_states = inputs_embeds + position_embeds
806
+
807
+ if token_type_ids is not None:
808
+ token_type_embeds = self.wte(token_type_ids)
809
+ hidden_states = hidden_states + token_type_embeds
810
+
811
+ hidden_states = self.drop(hidden_states)
812
+
813
+ output_shape = input_shape + (hidden_states.size(-1),)
814
+
815
+ presents = () if use_cache else None
816
+ all_self_attentions = () if output_attentions else None
817
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
818
+ all_hidden_states = () if output_hidden_states else None
819
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
820
+
821
+ # Model parallel
822
+ if self.model_parallel:
823
+ torch.cuda.set_device(hidden_states.device)
824
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
825
+ if layer_past is not None:
826
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
827
+ # Ensure that attention_mask is always on the same device as hidden_states
828
+ if attention_mask is not None:
829
+ attention_mask = attention_mask.to(hidden_states.device)
830
+ if isinstance(head_mask, torch.Tensor):
831
+ head_mask = head_mask.to(hidden_states.device)
832
+ if output_hidden_states:
833
+ all_hidden_states = all_hidden_states + (hidden_states,)
834
+
835
+ if self.gradient_checkpointing and self.training:
836
+
837
+ if use_cache:
838
+ logger.warning(
839
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
840
+ )
841
+ use_cache = False
842
+
843
+ def create_custom_forward(module):
844
+ def custom_forward(*inputs):
845
+ # None for past_key_value
846
+ return module(*inputs, use_cache, output_attentions)
847
+
848
+ return custom_forward
849
+
850
+ outputs = torch.utils.checkpoint.checkpoint(
851
+ create_custom_forward(block),
852
+ hidden_states,
853
+ None,
854
+ attention_mask,
855
+ head_mask[i],
856
+ encoder_hidden_states,
857
+ encoder_attention_mask,
858
+ )
859
+ else:
860
+ outputs = block(
861
+ hidden_states,
862
+ layer_past=layer_past,
863
+ attention_mask=attention_mask,
864
+ head_mask=head_mask[i],
865
+ encoder_hidden_states=encoder_hidden_states,
866
+ encoder_attention_mask=encoder_attention_mask,
867
+ use_cache=use_cache,
868
+ output_attentions=output_attentions,
869
+ )
870
+
871
+ hidden_states = outputs[0]
872
+ if use_cache is True:
873
+ presents = presents + (outputs[1],)
874
+
875
+ if output_attentions:
876
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
877
+ if self.config.add_cross_attention:
878
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
879
+
880
+ # Model Parallel: If it's the last layer for that device, put things on the next device
881
+ if self.model_parallel:
882
+ for k, v in self.device_map.items():
883
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
884
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
885
+
886
+ hidden_states = self.ln_f(hidden_states)
887
+
888
+ hidden_states = hidden_states.view(output_shape)
889
+ # Add last hidden state
890
+ if output_hidden_states:
891
+ all_hidden_states = all_hidden_states + (hidden_states,)
892
+
893
+ if not return_dict:
894
+ return tuple(
895
+ v
896
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
897
+ if v is not None
898
+ )
899
+
900
+ return BaseModelOutputWithPastAndCrossAttentions(
901
+ last_hidden_state=hidden_states,
902
+ past_key_values=presents,
903
+ hidden_states=all_hidden_states,
904
+ attentions=all_self_attentions,
905
+ cross_attentions=all_cross_attentions,
906
+ )
907
+
908
+
909
+ @add_start_docstrings(
910
+ """
911
+ The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
912
+ embeddings).
913
+ """,
914
+ GPT2_START_DOCSTRING,
915
+ )
916
+ class GPT2LMHeadCustomModel(GPT2PreTrainedModel):
917
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
918
+
919
+ def __init__(self, config):
920
+ super().__init__(config)
921
+ self.transformer = GPT2Model(config)
922
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
923
+
924
+ # Model parallel
925
+ self.model_parallel = False
926
+ self.device_map = None
927
+
928
+ # Initialize weights and apply final processing
929
+ self.post_init()
930
+
931
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
932
+ def parallelize(self, device_map=None):
933
+ self.device_map = (
934
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
935
+ if device_map is None
936
+ else device_map
937
+ )
938
+ assert_device_map(self.device_map, len(self.transformer.h))
939
+ self.transformer.parallelize(self.device_map)
940
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
941
+ self.model_parallel = True
942
+
943
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
944
+ def deparallelize(self):
945
+ self.transformer.deparallelize()
946
+ self.transformer = self.transformer.to("cpu")
947
+ self.lm_head = self.lm_head.to("cpu")
948
+ self.model_parallel = False
949
+ torch.cuda.empty_cache()
950
+
951
+ def get_output_embeddings(self):
952
+ return self.lm_head
953
+
954
+ def set_output_embeddings(self, new_embeddings):
955
+ self.lm_head = new_embeddings
956
+
957
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
958
+ token_type_ids = kwargs.get("token_type_ids", None)
959
+ # only last token for inputs_ids if past is defined in kwargs
960
+ if past:
961
+ input_ids = input_ids[:, -1].unsqueeze(-1)
962
+ if token_type_ids is not None:
963
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
964
+
965
+ attention_mask = kwargs.get("attention_mask", None)
966
+ position_ids = kwargs.get("position_ids", None)
967
+
968
+ if attention_mask is not None and position_ids is None:
969
+ # create position_ids on the fly for batch generation
970
+ position_ids = attention_mask.long().cumsum(-1) - 1
971
+ position_ids.masked_fill_(attention_mask == 0, 1)
972
+ if past:
973
+ position_ids = position_ids[:, -1].unsqueeze(-1)
974
+ else:
975
+ position_ids = None
976
+ return {
977
+ "input_ids": input_ids,
978
+ "past_key_values": past,
979
+ "use_cache": kwargs.get("use_cache"),
980
+ "position_ids": position_ids,
981
+ "attention_mask": attention_mask,
982
+ "token_type_ids": token_type_ids,
983
+ }
984
+
985
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
986
+ @add_code_sample_docstrings(
987
+ processor_class=_TOKENIZER_FOR_DOC,
988
+ checkpoint=_CHECKPOINT_FOR_DOC,
989
+ output_type=CausalLMOutputWithCrossAttentions,
990
+ config_class=_CONFIG_FOR_DOC,
991
+ )
992
+ def forward(
993
+ self,
994
+ input_ids: Optional[torch.LongTensor] = None,
995
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
996
+ attention_mask: Optional[torch.FloatTensor] = None,
997
+ token_type_ids: Optional[torch.LongTensor] = None,
998
+ position_ids: Optional[torch.LongTensor] = None,
999
+ head_mask: Optional[torch.FloatTensor] = None,
1000
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1001
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1002
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1003
+ labels: Optional[torch.LongTensor] = None,
1004
+ use_cache: Optional[bool] = None,
1005
+ output_attentions: Optional[bool] = None,
1006
+ output_hidden_states: Optional[bool] = None,
1007
+ return_dict: Optional[bool] = None,
1008
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1009
+ r"""
1010
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1011
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1012
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1013
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1014
+ """
1015
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1016
+
1017
+ transformer_outputs = self.transformer(
1018
+ input_ids,
1019
+ past_key_values=past_key_values,
1020
+ attention_mask=attention_mask,
1021
+ token_type_ids=token_type_ids,
1022
+ position_ids=position_ids,
1023
+ head_mask=head_mask,
1024
+ inputs_embeds=inputs_embeds,
1025
+ encoder_hidden_states=encoder_hidden_states,
1026
+ encoder_attention_mask=encoder_attention_mask,
1027
+ use_cache=use_cache,
1028
+ output_attentions=output_attentions,
1029
+ output_hidden_states=output_hidden_states,
1030
+ return_dict=return_dict,
1031
+ )
1032
+ hidden_states = transformer_outputs[0]
1033
+
1034
+ # Set device for model parallelism
1035
+ if self.model_parallel:
1036
+ torch.cuda.set_device(self.transformer.first_device)
1037
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
1038
+
1039
+ lm_logits = self.lm_head(hidden_states)
1040
+
1041
+ loss = None
1042
+ if labels is not None:
1043
+ # Shift so that tokens < n predict n
1044
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1045
+ shift_labels = labels[..., 1:].contiguous()
1046
+ # Flatten the tokens
1047
+ loss_fct = CrossEntropyLoss()
1048
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1049
+
1050
+ if not return_dict:
1051
+ output = (lm_logits,) + transformer_outputs[1:]
1052
+ return ((loss,) + output) if loss is not None else output
1053
+
1054
+ return CausalLMOutputWithCrossAttentions(
1055
+ loss=loss,
1056
+ logits=lm_logits,
1057
+ past_key_values=transformer_outputs.past_key_values,
1058
+ hidden_states=transformer_outputs.hidden_states,
1059
+ attentions=transformer_outputs.attentions,
1060
+ cross_attentions=transformer_outputs.cross_attentions,
1061
+ )
1062
+
1063
+ @staticmethod
1064
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
1065
+ """
1066
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1067
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1068
+ beam_idx at every generation step.
1069
+ """
1070
+ return tuple(
1071
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
1072
+ for layer_past in past
1073
+ )
1074
+
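Compared with the stock `modeling_gpt2.py`, the attention module above replaces the fused `c_attn` projection with separate `q_proj`, `k_proj` (bias-free), and `v_proj` layers, and pre-scales queries and keys by `1/sqrt(sqrt(d_k))` before the matmul, matching the fp16-stability trick from the summarize-from-feedback code referenced in the module docstring. The checkpoint shards below therefore store `q_proj`/`k_proj`/`v_proj` tensors instead of `c_attn`. A hedged sketch of how a fused GPT-2 `c_attn` tensor could be mapped into this layout, purely illustrative and not part of this repository (the helper name and the decision to drop the key bias are assumptions):

```python
import torch

def split_fused_c_attn(c_attn_weight: torch.Tensor, c_attn_bias: torch.Tensor, embed_dim: int) -> dict:
    # Conv1D stores its weight as (n_in, n_out), so a stock GPT-2 c_attn weight has
    # shape (embed_dim, 3 * embed_dim); split along dim=1 into the q, k, v chunks.
    q_w, k_w, v_w = c_attn_weight.split(embed_dim, dim=1)
    q_b, k_b, v_b = c_attn_bias.split(embed_dim, dim=0)
    return {
        "q_proj.weight": q_w.clone(), "q_proj.bias": q_b.clone(),
        "k_proj.weight": k_w.clone(),  # k_proj is defined with bias=False, so k_b is dropped
        "v_proj.weight": v_w.clone(), "v_proj.bias": v_b.clone(),
    }
```

Dropping the key bias leaves the attention distribution unchanged, since for a given query the bias contributes the same offset to every key logit and cancels in the softmax.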
pytorch_model-00001-of-00006.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d491a6bded42bb3ade1fcef6167d7c38d93709ab72d7b2519d48cb083680f63a
+ size 978625381
pytorch_model-00002-of-00006.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c6f2981feff090a7da4e5a724459c3a7d24c4e0d57dad2fd3144533e6c476879
+ size 960982501
pytorch_model-00003-of-00006.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fe90003bed77d15cae1eeb5c3dd976e46774d0b7623eb6de0cb00cd8508b75d3
+ size 994535645
pytorch_model-00004-of-00006.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1fea5e1bf6bd1a6aa79fbe2cf00011b41c396f30dc03e601bac1da6a1db45fdf
+ size 994561235
pytorch_model-00005-of-00006.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0531a9cd7c51e05fb4d191a97362051b1f4f3e460ff3dfaaaad1df10f48abb3
+ size 956795165
pytorch_model-00006-of-00006.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3043425123753b7f43c288367f2c2c01ab485163e542ca6d355c7fed1abd9d87
+ size 890475405
pytorch_model.bin.index.json ADDED
@@ -0,0 +1,420 @@
+ {
+ "metadata": {
+ "total_size": 5775835232
+ },
+ "weight_map": {
+ "lm_head.weight": "pytorch_model-00006-of-00006.bin",
+ "transformer.h.0.attn.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.attn.c_proj.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.attn.c_proj.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.attn.k_proj.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.attn.masked_bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.attn.q_proj.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.attn.q_proj.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.attn.v_proj.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.attn.v_proj.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.ln_1.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.ln_1.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.ln_2.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.ln_2.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.mlp.c_fc.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.mlp.c_fc.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.mlp.c_proj.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.0.mlp.c_proj.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.attn.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.attn.c_proj.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.attn.c_proj.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.attn.k_proj.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.attn.masked_bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.attn.q_proj.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.attn.q_proj.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.attn.v_proj.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.attn.v_proj.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.ln_1.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.ln_1.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.ln_2.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.ln_2.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.mlp.c_fc.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.mlp.c_fc.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.mlp.c_proj.bias": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.1.mlp.c_proj.weight": "pytorch_model-00001-of-00006.bin",
+ "transformer.h.10.attn.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.attn.c_proj.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.attn.c_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.attn.k_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.attn.masked_bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.attn.q_proj.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.attn.q_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.attn.v_proj.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.attn.v_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.ln_1.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.ln_1.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.ln_2.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.ln_2.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.mlp.c_fc.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.mlp.c_fc.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.mlp.c_proj.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.10.mlp.c_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.attn.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.attn.c_proj.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.attn.c_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.attn.k_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.attn.masked_bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.attn.q_proj.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.attn.q_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.attn.v_proj.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.attn.v_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.ln_1.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.ln_1.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.ln_2.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.ln_2.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.mlp.c_fc.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.mlp.c_fc.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.mlp.c_proj.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.11.mlp.c_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.12.attn.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.12.attn.c_proj.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.12.attn.c_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.12.attn.k_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.12.attn.masked_bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.12.attn.q_proj.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.12.attn.q_proj.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.12.attn.v_proj.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.12.attn.v_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.12.ln_1.bias": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.12.ln_1.weight": "pytorch_model-00003-of-00006.bin",
+ "transformer.h.12.ln_2.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.12.ln_2.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.12.mlp.c_fc.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.12.mlp.c_fc.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.12.mlp.c_proj.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.12.mlp.c_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.attn.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.attn.c_proj.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.attn.c_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.attn.k_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.attn.masked_bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.attn.q_proj.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.attn.q_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.attn.v_proj.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.attn.v_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.ln_1.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.ln_1.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.ln_2.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.ln_2.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.mlp.c_fc.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.mlp.c_fc.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.mlp.c_proj.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.13.mlp.c_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.attn.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.attn.c_proj.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.attn.c_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.attn.k_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.attn.masked_bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.attn.q_proj.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.attn.q_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.attn.v_proj.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.attn.v_proj.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.ln_1.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.ln_1.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.ln_2.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.ln_2.weight": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.mlp.c_fc.bias": "pytorch_model-00004-of-00006.bin",
+ "transformer.h.14.mlp.c_fc.weight": "pytorch_model-00004-of-00006.bin",
124
+ "transformer.h.14.mlp.c_proj.bias": "pytorch_model-00004-of-00006.bin",
125
+ "transformer.h.14.mlp.c_proj.weight": "pytorch_model-00004-of-00006.bin",
126
+ "transformer.h.15.attn.bias": "pytorch_model-00004-of-00006.bin",
127
+ "transformer.h.15.attn.c_proj.bias": "pytorch_model-00004-of-00006.bin",
128
+ "transformer.h.15.attn.c_proj.weight": "pytorch_model-00004-of-00006.bin",
129
+ "transformer.h.15.attn.k_proj.weight": "pytorch_model-00004-of-00006.bin",
130
+ "transformer.h.15.attn.masked_bias": "pytorch_model-00004-of-00006.bin",
131
+ "transformer.h.15.attn.q_proj.bias": "pytorch_model-00004-of-00006.bin",
132
+ "transformer.h.15.attn.q_proj.weight": "pytorch_model-00004-of-00006.bin",
133
+ "transformer.h.15.attn.v_proj.bias": "pytorch_model-00004-of-00006.bin",
134
+ "transformer.h.15.attn.v_proj.weight": "pytorch_model-00004-of-00006.bin",
135
+ "transformer.h.15.ln_1.bias": "pytorch_model-00004-of-00006.bin",
136
+ "transformer.h.15.ln_1.weight": "pytorch_model-00004-of-00006.bin",
137
+ "transformer.h.15.ln_2.bias": "pytorch_model-00004-of-00006.bin",
138
+ "transformer.h.15.ln_2.weight": "pytorch_model-00004-of-00006.bin",
139
+ "transformer.h.15.mlp.c_fc.bias": "pytorch_model-00004-of-00006.bin",
140
+ "transformer.h.15.mlp.c_fc.weight": "pytorch_model-00004-of-00006.bin",
141
+ "transformer.h.15.mlp.c_proj.bias": "pytorch_model-00004-of-00006.bin",
142
+ "transformer.h.15.mlp.c_proj.weight": "pytorch_model-00004-of-00006.bin",
143
+ "transformer.h.16.attn.bias": "pytorch_model-00004-of-00006.bin",
144
+ "transformer.h.16.attn.c_proj.bias": "pytorch_model-00004-of-00006.bin",
145
+ "transformer.h.16.attn.c_proj.weight": "pytorch_model-00004-of-00006.bin",
146
+ "transformer.h.16.attn.k_proj.weight": "pytorch_model-00004-of-00006.bin",
147
+ "transformer.h.16.attn.masked_bias": "pytorch_model-00004-of-00006.bin",
148
+ "transformer.h.16.attn.q_proj.bias": "pytorch_model-00004-of-00006.bin",
149
+ "transformer.h.16.attn.q_proj.weight": "pytorch_model-00004-of-00006.bin",
150
+ "transformer.h.16.attn.v_proj.bias": "pytorch_model-00004-of-00006.bin",
151
+ "transformer.h.16.attn.v_proj.weight": "pytorch_model-00004-of-00006.bin",
152
+ "transformer.h.16.ln_1.bias": "pytorch_model-00004-of-00006.bin",
153
+ "transformer.h.16.ln_1.weight": "pytorch_model-00004-of-00006.bin",
154
+ "transformer.h.16.ln_2.bias": "pytorch_model-00004-of-00006.bin",
155
+ "transformer.h.16.ln_2.weight": "pytorch_model-00004-of-00006.bin",
156
+ "transformer.h.16.mlp.c_fc.bias": "pytorch_model-00004-of-00006.bin",
157
+ "transformer.h.16.mlp.c_fc.weight": "pytorch_model-00004-of-00006.bin",
158
+ "transformer.h.16.mlp.c_proj.bias": "pytorch_model-00004-of-00006.bin",
159
+ "transformer.h.16.mlp.c_proj.weight": "pytorch_model-00004-of-00006.bin",
160
+ "transformer.h.17.attn.bias": "pytorch_model-00004-of-00006.bin",
161
+ "transformer.h.17.attn.c_proj.bias": "pytorch_model-00005-of-00006.bin",
162
+ "transformer.h.17.attn.c_proj.weight": "pytorch_model-00005-of-00006.bin",
163
+ "transformer.h.17.attn.k_proj.weight": "pytorch_model-00005-of-00006.bin",
164
+ "transformer.h.17.attn.masked_bias": "pytorch_model-00004-of-00006.bin",
165
+ "transformer.h.17.attn.q_proj.bias": "pytorch_model-00005-of-00006.bin",
166
+ "transformer.h.17.attn.q_proj.weight": "pytorch_model-00005-of-00006.bin",
167
+ "transformer.h.17.attn.v_proj.bias": "pytorch_model-00005-of-00006.bin",
168
+ "transformer.h.17.attn.v_proj.weight": "pytorch_model-00005-of-00006.bin",
169
+ "transformer.h.17.ln_1.bias": "pytorch_model-00004-of-00006.bin",
170
+ "transformer.h.17.ln_1.weight": "pytorch_model-00004-of-00006.bin",
171
+ "transformer.h.17.ln_2.bias": "pytorch_model-00005-of-00006.bin",
172
+ "transformer.h.17.ln_2.weight": "pytorch_model-00005-of-00006.bin",
173
+ "transformer.h.17.mlp.c_fc.bias": "pytorch_model-00005-of-00006.bin",
174
+ "transformer.h.17.mlp.c_fc.weight": "pytorch_model-00005-of-00006.bin",
175
+ "transformer.h.17.mlp.c_proj.bias": "pytorch_model-00005-of-00006.bin",
176
+ "transformer.h.17.mlp.c_proj.weight": "pytorch_model-00005-of-00006.bin",
177
+ "transformer.h.18.attn.bias": "pytorch_model-00005-of-00006.bin",
178
+ "transformer.h.18.attn.c_proj.bias": "pytorch_model-00005-of-00006.bin",
179
+ "transformer.h.18.attn.c_proj.weight": "pytorch_model-00005-of-00006.bin",
180
+ "transformer.h.18.attn.k_proj.weight": "pytorch_model-00005-of-00006.bin",
181
+ "transformer.h.18.attn.masked_bias": "pytorch_model-00005-of-00006.bin",
182
+ "transformer.h.18.attn.q_proj.bias": "pytorch_model-00005-of-00006.bin",
183
+ "transformer.h.18.attn.q_proj.weight": "pytorch_model-00005-of-00006.bin",
184
+ "transformer.h.18.attn.v_proj.bias": "pytorch_model-00005-of-00006.bin",
185
+ "transformer.h.18.attn.v_proj.weight": "pytorch_model-00005-of-00006.bin",
186
+ "transformer.h.18.ln_1.bias": "pytorch_model-00005-of-00006.bin",
187
+ "transformer.h.18.ln_1.weight": "pytorch_model-00005-of-00006.bin",
188
+ "transformer.h.18.ln_2.bias": "pytorch_model-00005-of-00006.bin",
189
+ "transformer.h.18.ln_2.weight": "pytorch_model-00005-of-00006.bin",
190
+ "transformer.h.18.mlp.c_fc.bias": "pytorch_model-00005-of-00006.bin",
191
+ "transformer.h.18.mlp.c_fc.weight": "pytorch_model-00005-of-00006.bin",
192
+ "transformer.h.18.mlp.c_proj.bias": "pytorch_model-00005-of-00006.bin",
193
+ "transformer.h.18.mlp.c_proj.weight": "pytorch_model-00005-of-00006.bin",
194
+ "transformer.h.19.attn.bias": "pytorch_model-00005-of-00006.bin",
195
+ "transformer.h.19.attn.c_proj.bias": "pytorch_model-00005-of-00006.bin",
196
+ "transformer.h.19.attn.c_proj.weight": "pytorch_model-00005-of-00006.bin",
197
+ "transformer.h.19.attn.k_proj.weight": "pytorch_model-00005-of-00006.bin",
198
+ "transformer.h.19.attn.masked_bias": "pytorch_model-00005-of-00006.bin",
199
+ "transformer.h.19.attn.q_proj.bias": "pytorch_model-00005-of-00006.bin",
200
+ "transformer.h.19.attn.q_proj.weight": "pytorch_model-00005-of-00006.bin",
201
+ "transformer.h.19.attn.v_proj.bias": "pytorch_model-00005-of-00006.bin",
202
+ "transformer.h.19.attn.v_proj.weight": "pytorch_model-00005-of-00006.bin",
203
+ "transformer.h.19.ln_1.bias": "pytorch_model-00005-of-00006.bin",
204
+ "transformer.h.19.ln_1.weight": "pytorch_model-00005-of-00006.bin",
205
+ "transformer.h.19.ln_2.bias": "pytorch_model-00005-of-00006.bin",
206
+ "transformer.h.19.ln_2.weight": "pytorch_model-00005-of-00006.bin",
207
+ "transformer.h.19.mlp.c_fc.bias": "pytorch_model-00005-of-00006.bin",
208
+ "transformer.h.19.mlp.c_fc.weight": "pytorch_model-00005-of-00006.bin",
209
+ "transformer.h.19.mlp.c_proj.bias": "pytorch_model-00005-of-00006.bin",
210
+ "transformer.h.19.mlp.c_proj.weight": "pytorch_model-00005-of-00006.bin",
211
+ "transformer.h.2.attn.bias": "pytorch_model-00001-of-00006.bin",
212
+ "transformer.h.2.attn.c_proj.bias": "pytorch_model-00001-of-00006.bin",
213
+ "transformer.h.2.attn.c_proj.weight": "pytorch_model-00001-of-00006.bin",
214
+ "transformer.h.2.attn.k_proj.weight": "pytorch_model-00001-of-00006.bin",
215
+ "transformer.h.2.attn.masked_bias": "pytorch_model-00001-of-00006.bin",
216
+ "transformer.h.2.attn.q_proj.bias": "pytorch_model-00001-of-00006.bin",
217
+ "transformer.h.2.attn.q_proj.weight": "pytorch_model-00001-of-00006.bin",
218
+ "transformer.h.2.attn.v_proj.bias": "pytorch_model-00001-of-00006.bin",
219
+ "transformer.h.2.attn.v_proj.weight": "pytorch_model-00001-of-00006.bin",
220
+ "transformer.h.2.ln_1.bias": "pytorch_model-00001-of-00006.bin",
221
+ "transformer.h.2.ln_1.weight": "pytorch_model-00001-of-00006.bin",
222
+ "transformer.h.2.ln_2.bias": "pytorch_model-00001-of-00006.bin",
223
+ "transformer.h.2.ln_2.weight": "pytorch_model-00001-of-00006.bin",
224
+ "transformer.h.2.mlp.c_fc.bias": "pytorch_model-00001-of-00006.bin",
225
+ "transformer.h.2.mlp.c_fc.weight": "pytorch_model-00001-of-00006.bin",
226
+ "transformer.h.2.mlp.c_proj.bias": "pytorch_model-00002-of-00006.bin",
227
+ "transformer.h.2.mlp.c_proj.weight": "pytorch_model-00002-of-00006.bin",
228
+ "transformer.h.20.attn.bias": "pytorch_model-00005-of-00006.bin",
229
+ "transformer.h.20.attn.c_proj.bias": "pytorch_model-00005-of-00006.bin",
230
+ "transformer.h.20.attn.c_proj.weight": "pytorch_model-00005-of-00006.bin",
231
+ "transformer.h.20.attn.k_proj.weight": "pytorch_model-00005-of-00006.bin",
232
+ "transformer.h.20.attn.masked_bias": "pytorch_model-00005-of-00006.bin",
233
+ "transformer.h.20.attn.q_proj.bias": "pytorch_model-00005-of-00006.bin",
234
+ "transformer.h.20.attn.q_proj.weight": "pytorch_model-00005-of-00006.bin",
235
+ "transformer.h.20.attn.v_proj.bias": "pytorch_model-00005-of-00006.bin",
236
+ "transformer.h.20.attn.v_proj.weight": "pytorch_model-00005-of-00006.bin",
237
+ "transformer.h.20.ln_1.bias": "pytorch_model-00005-of-00006.bin",
238
+ "transformer.h.20.ln_1.weight": "pytorch_model-00005-of-00006.bin",
239
+ "transformer.h.20.ln_2.bias": "pytorch_model-00005-of-00006.bin",
240
+ "transformer.h.20.ln_2.weight": "pytorch_model-00005-of-00006.bin",
241
+ "transformer.h.20.mlp.c_fc.bias": "pytorch_model-00005-of-00006.bin",
242
+ "transformer.h.20.mlp.c_fc.weight": "pytorch_model-00005-of-00006.bin",
243
+ "transformer.h.20.mlp.c_proj.bias": "pytorch_model-00005-of-00006.bin",
244
+ "transformer.h.20.mlp.c_proj.weight": "pytorch_model-00005-of-00006.bin",
245
+ "transformer.h.21.attn.bias": "pytorch_model-00005-of-00006.bin",
246
+ "transformer.h.21.attn.c_proj.bias": "pytorch_model-00005-of-00006.bin",
247
+ "transformer.h.21.attn.c_proj.weight": "pytorch_model-00005-of-00006.bin",
248
+ "transformer.h.21.attn.k_proj.weight": "pytorch_model-00005-of-00006.bin",
249
+ "transformer.h.21.attn.masked_bias": "pytorch_model-00005-of-00006.bin",
250
+ "transformer.h.21.attn.q_proj.bias": "pytorch_model-00005-of-00006.bin",
251
+ "transformer.h.21.attn.q_proj.weight": "pytorch_model-00005-of-00006.bin",
252
+ "transformer.h.21.attn.v_proj.bias": "pytorch_model-00005-of-00006.bin",
253
+ "transformer.h.21.attn.v_proj.weight": "pytorch_model-00005-of-00006.bin",
254
+ "transformer.h.21.ln_1.bias": "pytorch_model-00005-of-00006.bin",
255
+ "transformer.h.21.ln_1.weight": "pytorch_model-00005-of-00006.bin",
256
+ "transformer.h.21.ln_2.bias": "pytorch_model-00005-of-00006.bin",
257
+ "transformer.h.21.ln_2.weight": "pytorch_model-00005-of-00006.bin",
258
+ "transformer.h.21.mlp.c_fc.bias": "pytorch_model-00005-of-00006.bin",
259
+ "transformer.h.21.mlp.c_fc.weight": "pytorch_model-00005-of-00006.bin",
260
+ "transformer.h.21.mlp.c_proj.bias": "pytorch_model-00006-of-00006.bin",
261
+ "transformer.h.21.mlp.c_proj.weight": "pytorch_model-00006-of-00006.bin",
262
+ "transformer.h.22.attn.bias": "pytorch_model-00006-of-00006.bin",
263
+ "transformer.h.22.attn.c_proj.bias": "pytorch_model-00006-of-00006.bin",
264
+ "transformer.h.22.attn.c_proj.weight": "pytorch_model-00006-of-00006.bin",
265
+ "transformer.h.22.attn.k_proj.weight": "pytorch_model-00006-of-00006.bin",
266
+ "transformer.h.22.attn.masked_bias": "pytorch_model-00006-of-00006.bin",
267
+ "transformer.h.22.attn.q_proj.bias": "pytorch_model-00006-of-00006.bin",
268
+ "transformer.h.22.attn.q_proj.weight": "pytorch_model-00006-of-00006.bin",
269
+ "transformer.h.22.attn.v_proj.bias": "pytorch_model-00006-of-00006.bin",
270
+ "transformer.h.22.attn.v_proj.weight": "pytorch_model-00006-of-00006.bin",
271
+ "transformer.h.22.ln_1.bias": "pytorch_model-00006-of-00006.bin",
272
+ "transformer.h.22.ln_1.weight": "pytorch_model-00006-of-00006.bin",
273
+ "transformer.h.22.ln_2.bias": "pytorch_model-00006-of-00006.bin",
274
+ "transformer.h.22.ln_2.weight": "pytorch_model-00006-of-00006.bin",
275
+ "transformer.h.22.mlp.c_fc.bias": "pytorch_model-00006-of-00006.bin",
276
+ "transformer.h.22.mlp.c_fc.weight": "pytorch_model-00006-of-00006.bin",
277
+ "transformer.h.22.mlp.c_proj.bias": "pytorch_model-00006-of-00006.bin",
278
+ "transformer.h.22.mlp.c_proj.weight": "pytorch_model-00006-of-00006.bin",
279
+ "transformer.h.23.attn.bias": "pytorch_model-00006-of-00006.bin",
280
+ "transformer.h.23.attn.c_proj.bias": "pytorch_model-00006-of-00006.bin",
281
+ "transformer.h.23.attn.c_proj.weight": "pytorch_model-00006-of-00006.bin",
282
+ "transformer.h.23.attn.k_proj.weight": "pytorch_model-00006-of-00006.bin",
283
+ "transformer.h.23.attn.masked_bias": "pytorch_model-00006-of-00006.bin",
284
+ "transformer.h.23.attn.q_proj.bias": "pytorch_model-00006-of-00006.bin",
285
+ "transformer.h.23.attn.q_proj.weight": "pytorch_model-00006-of-00006.bin",
286
+ "transformer.h.23.attn.v_proj.bias": "pytorch_model-00006-of-00006.bin",
287
+ "transformer.h.23.attn.v_proj.weight": "pytorch_model-00006-of-00006.bin",
288
+ "transformer.h.23.ln_1.bias": "pytorch_model-00006-of-00006.bin",
289
+ "transformer.h.23.ln_1.weight": "pytorch_model-00006-of-00006.bin",
290
+ "transformer.h.23.ln_2.bias": "pytorch_model-00006-of-00006.bin",
291
+ "transformer.h.23.ln_2.weight": "pytorch_model-00006-of-00006.bin",
292
+ "transformer.h.23.mlp.c_fc.bias": "pytorch_model-00006-of-00006.bin",
293
+ "transformer.h.23.mlp.c_fc.weight": "pytorch_model-00006-of-00006.bin",
294
+ "transformer.h.23.mlp.c_proj.bias": "pytorch_model-00006-of-00006.bin",
295
+ "transformer.h.23.mlp.c_proj.weight": "pytorch_model-00006-of-00006.bin",
296
+ "transformer.h.3.attn.bias": "pytorch_model-00002-of-00006.bin",
297
+ "transformer.h.3.attn.c_proj.bias": "pytorch_model-00002-of-00006.bin",
298
+ "transformer.h.3.attn.c_proj.weight": "pytorch_model-00002-of-00006.bin",
299
+ "transformer.h.3.attn.k_proj.weight": "pytorch_model-00002-of-00006.bin",
300
+ "transformer.h.3.attn.masked_bias": "pytorch_model-00002-of-00006.bin",
301
+ "transformer.h.3.attn.q_proj.bias": "pytorch_model-00002-of-00006.bin",
302
+ "transformer.h.3.attn.q_proj.weight": "pytorch_model-00002-of-00006.bin",
303
+ "transformer.h.3.attn.v_proj.bias": "pytorch_model-00002-of-00006.bin",
304
+ "transformer.h.3.attn.v_proj.weight": "pytorch_model-00002-of-00006.bin",
305
+ "transformer.h.3.ln_1.bias": "pytorch_model-00002-of-00006.bin",
306
+ "transformer.h.3.ln_1.weight": "pytorch_model-00002-of-00006.bin",
307
+ "transformer.h.3.ln_2.bias": "pytorch_model-00002-of-00006.bin",
308
+ "transformer.h.3.ln_2.weight": "pytorch_model-00002-of-00006.bin",
309
+ "transformer.h.3.mlp.c_fc.bias": "pytorch_model-00002-of-00006.bin",
310
+ "transformer.h.3.mlp.c_fc.weight": "pytorch_model-00002-of-00006.bin",
311
+ "transformer.h.3.mlp.c_proj.bias": "pytorch_model-00002-of-00006.bin",
312
+ "transformer.h.3.mlp.c_proj.weight": "pytorch_model-00002-of-00006.bin",
313
+ "transformer.h.4.attn.bias": "pytorch_model-00002-of-00006.bin",
314
+ "transformer.h.4.attn.c_proj.bias": "pytorch_model-00002-of-00006.bin",
315
+ "transformer.h.4.attn.c_proj.weight": "pytorch_model-00002-of-00006.bin",
316
+ "transformer.h.4.attn.k_proj.weight": "pytorch_model-00002-of-00006.bin",
317
+ "transformer.h.4.attn.masked_bias": "pytorch_model-00002-of-00006.bin",
318
+ "transformer.h.4.attn.q_proj.bias": "pytorch_model-00002-of-00006.bin",
319
+ "transformer.h.4.attn.q_proj.weight": "pytorch_model-00002-of-00006.bin",
320
+ "transformer.h.4.attn.v_proj.bias": "pytorch_model-00002-of-00006.bin",
321
+ "transformer.h.4.attn.v_proj.weight": "pytorch_model-00002-of-00006.bin",
322
+ "transformer.h.4.ln_1.bias": "pytorch_model-00002-of-00006.bin",
323
+ "transformer.h.4.ln_1.weight": "pytorch_model-00002-of-00006.bin",
324
+ "transformer.h.4.ln_2.bias": "pytorch_model-00002-of-00006.bin",
325
+ "transformer.h.4.ln_2.weight": "pytorch_model-00002-of-00006.bin",
326
+ "transformer.h.4.mlp.c_fc.bias": "pytorch_model-00002-of-00006.bin",
327
+ "transformer.h.4.mlp.c_fc.weight": "pytorch_model-00002-of-00006.bin",
328
+ "transformer.h.4.mlp.c_proj.bias": "pytorch_model-00002-of-00006.bin",
329
+ "transformer.h.4.mlp.c_proj.weight": "pytorch_model-00002-of-00006.bin",
330
+ "transformer.h.5.attn.bias": "pytorch_model-00002-of-00006.bin",
331
+ "transformer.h.5.attn.c_proj.bias": "pytorch_model-00002-of-00006.bin",
332
+ "transformer.h.5.attn.c_proj.weight": "pytorch_model-00002-of-00006.bin",
333
+ "transformer.h.5.attn.k_proj.weight": "pytorch_model-00002-of-00006.bin",
334
+ "transformer.h.5.attn.masked_bias": "pytorch_model-00002-of-00006.bin",
335
+ "transformer.h.5.attn.q_proj.bias": "pytorch_model-00002-of-00006.bin",
336
+ "transformer.h.5.attn.q_proj.weight": "pytorch_model-00002-of-00006.bin",
337
+ "transformer.h.5.attn.v_proj.bias": "pytorch_model-00002-of-00006.bin",
338
+ "transformer.h.5.attn.v_proj.weight": "pytorch_model-00002-of-00006.bin",
339
+ "transformer.h.5.ln_1.bias": "pytorch_model-00002-of-00006.bin",
340
+ "transformer.h.5.ln_1.weight": "pytorch_model-00002-of-00006.bin",
341
+ "transformer.h.5.ln_2.bias": "pytorch_model-00002-of-00006.bin",
342
+ "transformer.h.5.ln_2.weight": "pytorch_model-00002-of-00006.bin",
343
+ "transformer.h.5.mlp.c_fc.bias": "pytorch_model-00002-of-00006.bin",
344
+ "transformer.h.5.mlp.c_fc.weight": "pytorch_model-00002-of-00006.bin",
345
+ "transformer.h.5.mlp.c_proj.bias": "pytorch_model-00002-of-00006.bin",
346
+ "transformer.h.5.mlp.c_proj.weight": "pytorch_model-00002-of-00006.bin",
347
+ "transformer.h.6.attn.bias": "pytorch_model-00002-of-00006.bin",
348
+ "transformer.h.6.attn.c_proj.bias": "pytorch_model-00002-of-00006.bin",
349
+ "transformer.h.6.attn.c_proj.weight": "pytorch_model-00002-of-00006.bin",
350
+ "transformer.h.6.attn.k_proj.weight": "pytorch_model-00002-of-00006.bin",
351
+ "transformer.h.6.attn.masked_bias": "pytorch_model-00002-of-00006.bin",
352
+ "transformer.h.6.attn.q_proj.bias": "pytorch_model-00002-of-00006.bin",
353
+ "transformer.h.6.attn.q_proj.weight": "pytorch_model-00002-of-00006.bin",
354
+ "transformer.h.6.attn.v_proj.bias": "pytorch_model-00002-of-00006.bin",
355
+ "transformer.h.6.attn.v_proj.weight": "pytorch_model-00002-of-00006.bin",
356
+ "transformer.h.6.ln_1.bias": "pytorch_model-00002-of-00006.bin",
357
+ "transformer.h.6.ln_1.weight": "pytorch_model-00002-of-00006.bin",
358
+ "transformer.h.6.ln_2.bias": "pytorch_model-00002-of-00006.bin",
359
+ "transformer.h.6.ln_2.weight": "pytorch_model-00002-of-00006.bin",
360
+ "transformer.h.6.mlp.c_fc.bias": "pytorch_model-00002-of-00006.bin",
361
+ "transformer.h.6.mlp.c_fc.weight": "pytorch_model-00002-of-00006.bin",
362
+ "transformer.h.6.mlp.c_proj.bias": "pytorch_model-00002-of-00006.bin",
363
+ "transformer.h.6.mlp.c_proj.weight": "pytorch_model-00002-of-00006.bin",
364
+ "transformer.h.7.attn.bias": "pytorch_model-00002-of-00006.bin",
365
+ "transformer.h.7.attn.c_proj.bias": "pytorch_model-00002-of-00006.bin",
366
+ "transformer.h.7.attn.c_proj.weight": "pytorch_model-00002-of-00006.bin",
367
+ "transformer.h.7.attn.k_proj.weight": "pytorch_model-00002-of-00006.bin",
368
+ "transformer.h.7.attn.masked_bias": "pytorch_model-00002-of-00006.bin",
369
+ "transformer.h.7.attn.q_proj.bias": "pytorch_model-00002-of-00006.bin",
370
+ "transformer.h.7.attn.q_proj.weight": "pytorch_model-00002-of-00006.bin",
371
+ "transformer.h.7.attn.v_proj.bias": "pytorch_model-00002-of-00006.bin",
372
+ "transformer.h.7.attn.v_proj.weight": "pytorch_model-00002-of-00006.bin",
373
+ "transformer.h.7.ln_1.bias": "pytorch_model-00002-of-00006.bin",
374
+ "transformer.h.7.ln_1.weight": "pytorch_model-00002-of-00006.bin",
375
+ "transformer.h.7.ln_2.bias": "pytorch_model-00002-of-00006.bin",
376
+ "transformer.h.7.ln_2.weight": "pytorch_model-00002-of-00006.bin",
377
+ "transformer.h.7.mlp.c_fc.bias": "pytorch_model-00003-of-00006.bin",
378
+ "transformer.h.7.mlp.c_fc.weight": "pytorch_model-00003-of-00006.bin",
379
+ "transformer.h.7.mlp.c_proj.bias": "pytorch_model-00003-of-00006.bin",
380
+ "transformer.h.7.mlp.c_proj.weight": "pytorch_model-00003-of-00006.bin",
381
+ "transformer.h.8.attn.bias": "pytorch_model-00003-of-00006.bin",
382
+ "transformer.h.8.attn.c_proj.bias": "pytorch_model-00003-of-00006.bin",
383
+ "transformer.h.8.attn.c_proj.weight": "pytorch_model-00003-of-00006.bin",
384
+ "transformer.h.8.attn.k_proj.weight": "pytorch_model-00003-of-00006.bin",
385
+ "transformer.h.8.attn.masked_bias": "pytorch_model-00003-of-00006.bin",
386
+ "transformer.h.8.attn.q_proj.bias": "pytorch_model-00003-of-00006.bin",
387
+ "transformer.h.8.attn.q_proj.weight": "pytorch_model-00003-of-00006.bin",
388
+ "transformer.h.8.attn.v_proj.bias": "pytorch_model-00003-of-00006.bin",
389
+ "transformer.h.8.attn.v_proj.weight": "pytorch_model-00003-of-00006.bin",
390
+ "transformer.h.8.ln_1.bias": "pytorch_model-00003-of-00006.bin",
391
+ "transformer.h.8.ln_1.weight": "pytorch_model-00003-of-00006.bin",
392
+ "transformer.h.8.ln_2.bias": "pytorch_model-00003-of-00006.bin",
393
+ "transformer.h.8.ln_2.weight": "pytorch_model-00003-of-00006.bin",
394
+ "transformer.h.8.mlp.c_fc.bias": "pytorch_model-00003-of-00006.bin",
395
+ "transformer.h.8.mlp.c_fc.weight": "pytorch_model-00003-of-00006.bin",
396
+ "transformer.h.8.mlp.c_proj.bias": "pytorch_model-00003-of-00006.bin",
397
+ "transformer.h.8.mlp.c_proj.weight": "pytorch_model-00003-of-00006.bin",
398
+ "transformer.h.9.attn.bias": "pytorch_model-00003-of-00006.bin",
399
+ "transformer.h.9.attn.c_proj.bias": "pytorch_model-00003-of-00006.bin",
400
+ "transformer.h.9.attn.c_proj.weight": "pytorch_model-00003-of-00006.bin",
401
+ "transformer.h.9.attn.k_proj.weight": "pytorch_model-00003-of-00006.bin",
402
+ "transformer.h.9.attn.masked_bias": "pytorch_model-00003-of-00006.bin",
403
+ "transformer.h.9.attn.q_proj.bias": "pytorch_model-00003-of-00006.bin",
404
+ "transformer.h.9.attn.q_proj.weight": "pytorch_model-00003-of-00006.bin",
405
+ "transformer.h.9.attn.v_proj.bias": "pytorch_model-00003-of-00006.bin",
406
+ "transformer.h.9.attn.v_proj.weight": "pytorch_model-00003-of-00006.bin",
407
+ "transformer.h.9.ln_1.bias": "pytorch_model-00003-of-00006.bin",
408
+ "transformer.h.9.ln_1.weight": "pytorch_model-00003-of-00006.bin",
409
+ "transformer.h.9.ln_2.bias": "pytorch_model-00003-of-00006.bin",
410
+ "transformer.h.9.ln_2.weight": "pytorch_model-00003-of-00006.bin",
411
+ "transformer.h.9.mlp.c_fc.bias": "pytorch_model-00003-of-00006.bin",
412
+ "transformer.h.9.mlp.c_fc.weight": "pytorch_model-00003-of-00006.bin",
413
+ "transformer.h.9.mlp.c_proj.bias": "pytorch_model-00003-of-00006.bin",
414
+ "transformer.h.9.mlp.c_proj.weight": "pytorch_model-00003-of-00006.bin",
415
+ "transformer.ln_f.bias": "pytorch_model-00006-of-00006.bin",
416
+ "transformer.ln_f.weight": "pytorch_model-00006-of-00006.bin",
417
+ "transformer.wpe.weight": "pytorch_model-00001-of-00006.bin",
418
+ "transformer.wte.weight": "pytorch_model-00001-of-00006.bin"
419
+ }
420
+ }
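
The file above is the standard sharded-checkpoint index: a weight map in which each parameter name points at the shard file that stores it, so a loader only has to open the shards it needs. As a minimal sketch of how such an index is consumed, assuming the usual `metadata`/`weight_map` layout and an index file named `pytorch_model.bin.index.json` in the working directory (the path is illustrative, not part of this commit):

import json
import torch

# Read the index and collect the set of shard files it references.
with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)
weight_map = index["weight_map"]  # parameter name -> shard filename

# Load each shard once and merge the tensors into a single state dict.
state_dict = {}
for shard_file in sorted(set(weight_map.values())):
    shard = torch.load(shard_file, map_location="cpu")
    state_dict.update(shard)

# Sanity check: every entry listed in the index should now be present.
missing = [name for name in weight_map if name not in state_dict]
assert not missing, f"missing tensors: {missing[:5]}"

In practice `from_pretrained` resolves the shards from this index automatically; the sketch only makes the mapping explicit.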