Matt commited on
Commit
3900116
1 Parent(s): a68ca4b

Preparations for transition to in-library checkpoint

Browse files
Files changed (4) hide show
  1. README.md +2 -4
  2. config.json +9 -1
  3. configuration_RW.py +92 -20
  4. modelling_RW.py → modeling_RW.py +400 -244
README.md CHANGED
@@ -16,7 +16,7 @@ license: apache-2.0
16
 
17
  *Paper coming soon 😊.*
18
 
19
-
20
 
21
  # Call for Proposals : Falcon 40B - World's Top Ranked AI Model Empowers Exceptional Use Cases with Training Compute Power in Call for Proposals
22
 
@@ -57,7 +57,6 @@ pipeline = transformers.pipeline(
57
  model=model,
58
  tokenizer=tokenizer,
59
  torch_dtype=torch.bfloat16,
60
- trust_remote_code=True,
61
  device_map="auto",
62
  )
63
  sequences = pipeline(
@@ -128,7 +127,6 @@ pipeline = transformers.pipeline(
128
  model=model,
129
  tokenizer=tokenizer,
130
  torch_dtype=torch.bfloat16,
131
- trust_remote_code=True,
132
  device_map="auto",
133
  )
134
  sequences = pipeline(
@@ -269,4 +267,4 @@ To learn more about the pretraining dataset, see the 📓 [RefinedWeb paper](htt
269
  Falcon-40B is made available under the Apache 2.0 license.
270
 
271
  ## Contact
272
- falconllm@tii.ae
 
16
 
17
  *Paper coming soon 😊.*
18
 
19
+ ⚠️ Falcon is now available as a core model in the `transformers` library! To use the in-library version, please install the latest version of `transformers` with `pip install git+https://github.com/ huggingface/transformers.git`, then simply remove the `trust_remote_code=True` argument from `from_pretrained()`.
20
 
21
  # Call for Proposals : Falcon 40B - World's Top Ranked AI Model Empowers Exceptional Use Cases with Training Compute Power in Call for Proposals
22
 
 
57
  model=model,
58
  tokenizer=tokenizer,
59
  torch_dtype=torch.bfloat16,
 
60
  device_map="auto",
61
  )
62
  sequences = pipeline(
 
127
  model=model,
128
  tokenizer=tokenizer,
129
  torch_dtype=torch.bfloat16,
 
130
  device_map="auto",
131
  )
132
  sequences = pipeline(
 
267
  Falcon-40B is made available under the Apache 2.0 license.
268
 
269
  ## Contact
270
+ falconllm@tii.ae
config.json CHANGED
@@ -5,6 +5,14 @@
5
  "FalconForCausalLM"
6
  ],
7
  "attention_dropout": 0.0,
 
 
 
 
 
 
 
 
8
  "bias": false,
9
  "bos_token_id": 11,
10
  "eos_token_id": 11,
@@ -22,4 +30,4 @@
22
  "transformers_version": "4.27.4",
23
  "use_cache": true,
24
  "vocab_size": 65024
25
- }
 
5
  "FalconForCausalLM"
6
  ],
7
  "attention_dropout": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_RW.RWConfig",
10
+ "AutoModel": "modeling_RW.RWModel",
11
+ "AutoModelForSequenceClassification": "modeling_RW.RWForSequenceClassification",
12
+ "AutoModelForTokenClassification": "modeling_RW.RWForTokenClassification",
13
+ "AutoModelForQuestionAnswering": "modeling_RW.RWForQuestionAnswering",
14
+ "AutoModelForCausalLM": "modeling_RW.RWForCausalLM"
15
+ },
16
  "bias": false,
17
  "bos_token_id": 11,
18
  "eos_token_id": 11,
 
30
  "transformers_version": "4.27.4",
31
  "use_cache": true,
32
  "vocab_size": 65024
33
+ }
configuration_RW.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -12,63 +12,135 @@
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
- """ Bloom configuration"""
16
  from transformers.configuration_utils import PretrainedConfig
17
  from transformers.utils import logging
18
 
19
 
20
  logger = logging.get_logger(__name__)
21
 
 
 
 
 
 
22
 
23
  class RWConfig(PretrainedConfig):
24
- model_type = "RefinedWeb"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  keys_to_ignore_at_inference = ["past_key_values"]
26
- attribute_map = {
27
- "num_hidden_layers": "n_layer",
28
- "num_attention_heads": "n_head",
29
- }
30
 
31
  def __init__(
32
  self,
33
- vocab_size=250880,
34
- hidden_size=64,
35
- n_layer=2,
36
- n_head=8,
37
  layer_norm_epsilon=1e-5,
38
  initializer_range=0.02,
39
  use_cache=True,
40
- bos_token_id=1,
41
- eos_token_id=2,
42
- apply_residual_connection_post_layernorm=False,
43
  hidden_dropout=0.0,
44
  attention_dropout=0.0,
45
- n_head_kv=None,
46
  alibi=False,
 
 
 
 
 
 
47
  **kwargs,
48
  ):
49
  self.vocab_size = vocab_size
50
  # Backward compatibility with n_embed kwarg
51
  n_embed = kwargs.pop("n_embed", None)
52
  self.hidden_size = hidden_size if n_embed is None else n_embed
53
- self.n_layer = n_layer
54
- self.n_head = n_head
55
  self.layer_norm_epsilon = layer_norm_epsilon
56
  self.initializer_range = initializer_range
57
  self.use_cache = use_cache
58
- self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
59
  self.hidden_dropout = hidden_dropout
60
  self.attention_dropout = attention_dropout
61
 
62
  self.bos_token_id = bos_token_id
63
  self.eos_token_id = eos_token_id
64
- self.n_head_kv = n_head if n_head_kv is None else n_head_kv
65
  self.alibi = alibi
 
 
 
 
66
 
67
  super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
68
 
69
  @property
70
  def head_dim(self):
71
- return self.hidden_size // self.n_head
72
 
73
  @property
74
  def rotary(self):
 
1
  # coding=utf-8
2
+ # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
+ """ Falcon configuration"""
16
  from transformers.configuration_utils import PretrainedConfig
17
  from transformers.utils import logging
18
 
19
 
20
  logger = logging.get_logger(__name__)
21
 
22
+ FALCON_PRETRAINED_CONFIG_ARCHIVE_MAP = {
23
+ "tiiuae/falcon-40b": "https://huggingface.co/tiiuae/falcon-40b/resolve/main/config.json",
24
+ "tiiuae/falcon-7b": "https://huggingface.co/tiiuae/falcon-7b/resolve/main/config.json",
25
+ }
26
+
27
 
28
  class RWConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`FalconModel`]. It is used to instantiate a Falcon
31
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
32
+ defaults will yield a similar configuration to that of the
33
+ [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) architecture.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 65024):
41
+ Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`FalconModel`]
43
+ hidden_size (`int`, *optional*, defaults to 4544):
44
+ Dimension of the hidden representations.
45
+ num_hidden_layers (`int`, *optional*, defaults to 32):
46
+ Number of hidden layers in the Transformer decoder.
47
+ num_attention_heads (`int`, *optional*, defaults to 71):
48
+ Number of attention heads for each attention layer in the Transformer encoder.
49
+ initializer_range (`float`, *optional*, defaults to 0.02):
50
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
51
+ use_cache (`bool`, *optional*, defaults to `True`):
52
+ Whether the model should return the last key/values attentions (not used by all models). Only relevant if
53
+ `config.is_decoder=True`.
54
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
55
+ The epsilon used by the layer normalization layers.
56
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
57
+ The dropout probability for MLP layers.
58
+ attention_dropout (`float`, *optional*, defaults to 0.0):
59
+ The dropout probability for attention layers.
60
+ num_kv_heads (`int`, *optional*):
61
+ Number of key-value heads to use per attention layer. If unset, defaults to the same value as
62
+ `num_attention_heads`.
63
+ alibi (`bool`, *optional*, defaults to `False`):
64
+ Whether to use ALiBi positional biases during self-attention.
65
+ new_decoder_architecture (`bool`, *optional*, defaults to `False`):
66
+ Whether to use the new (Falcon-40B) decoder architecture. If `True`, the `multi_query` and `parallel_attn`
67
+ arguments are ignored, as the new decoder always uses parallel attention.
68
+ multi_query (`bool`, *optional*, defaults to `True`):
69
+ Whether to use multi-query attention in the decoder. Ignored when `new_decoder_architecture` is `True`.
70
+ parallel_attn (`bool`, *optional*, defaults to `True`):
71
+ Whether to compute attention in parallel with the feedforward layer. If False, they are consecutive
72
+ instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
73
+ bias (`bool`, *optional*, defaults to `False`):
74
+ Whether to use bias on Linear layers.
75
+ bos_token_id (`int`, *optional*, defaults to 11):
76
+ The id of the "beginning-of-sequence" token.
77
+ eos_token_id (`int`, *optional*, defaults to 11):
78
+ The id of the "end-of-sequence" token.
79
+
80
+ Example:
81
+
82
+ ```python
83
+ >>> from transformers import FalconModel, RWConfig
84
+
85
+ >>> # Initializing a small (2-layer) Falcon configuration
86
+ >>> configuration = RWConfig(num_hidden_layers=2)
87
+
88
+ >>> # Initializing a model from the small configuration
89
+ >>> model = FalconModel(configuration)
90
+
91
+ >>> # Accessing the model configuration
92
+ >>> configuration = model.config
93
+ ```"""
94
+ model_type = "falcon"
95
  keys_to_ignore_at_inference = ["past_key_values"]
 
 
 
 
96
 
97
  def __init__(
98
  self,
99
+ vocab_size=65024,
100
+ hidden_size=4544,
101
+ num_hidden_layers=32,
102
+ num_attention_heads=71,
103
  layer_norm_epsilon=1e-5,
104
  initializer_range=0.02,
105
  use_cache=True,
 
 
 
106
  hidden_dropout=0.0,
107
  attention_dropout=0.0,
108
+ num_kv_heads=None,
109
  alibi=False,
110
+ new_decoder_architecture=False,
111
+ multi_query=True,
112
+ parallel_attn=True,
113
+ bias=False,
114
+ bos_token_id=11,
115
+ eos_token_id=11,
116
  **kwargs,
117
  ):
118
  self.vocab_size = vocab_size
119
  # Backward compatibility with n_embed kwarg
120
  n_embed = kwargs.pop("n_embed", None)
121
  self.hidden_size = hidden_size if n_embed is None else n_embed
122
+ self.num_hidden_layers = num_hidden_layers
123
+ self.num_attention_heads = num_attention_heads
124
  self.layer_norm_epsilon = layer_norm_epsilon
125
  self.initializer_range = initializer_range
126
  self.use_cache = use_cache
 
127
  self.hidden_dropout = hidden_dropout
128
  self.attention_dropout = attention_dropout
129
 
130
  self.bos_token_id = bos_token_id
131
  self.eos_token_id = eos_token_id
132
+ self.num_kv_heads = num_attention_heads if num_kv_heads is None else num_kv_heads
133
  self.alibi = alibi
134
+ self.new_decoder_architecture = new_decoder_architecture
135
+ self.multi_query = multi_query # Ignored when new_decoder_architecture is True
136
+ self.parallel_attn = parallel_attn
137
+ self.bias = bias
138
 
139
  super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
140
 
141
  @property
142
  def head_dim(self):
143
+ return self.hidden_size // self.num_attention_heads
144
 
145
  @property
146
  def rotary(self):
modelling_RW.py → modeling_RW.py RENAMED
@@ -1,9 +1,20 @@
1
- # port of models described in RW
2
- # We use the bloom model as a starting point for these model.
3
- # Please refer to the bloom models for usage instructions.
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import math
6
- import warnings
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
@@ -20,59 +31,60 @@ from transformers.modeling_outputs import (
20
  TokenClassifierOutput,
21
  )
22
  from transformers.modeling_utils import PreTrainedModel
23
- from transformers.utils import logging
24
  from .configuration_RW import RWConfig
25
 
 
26
  logger = logging.get_logger(__name__)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
29
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
30
- class Linear(nn.Linear):
31
  def forward(self, input: torch.Tensor) -> torch.Tensor:
32
- ret = input @ self.weight.T
33
  if self.bias is None:
34
- return ret
35
- else:
36
- return ret + self.bias
37
-
38
 
39
- from einops import rearrange
40
 
41
  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
42
  def rotate_half(x):
43
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
44
- return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
45
 
46
 
47
- class RotaryEmbedding(torch.nn.Module):
48
  """Implementation of RotaryEmbedding from GPT-NeoX.
49
- This implementation is design to operate on queries and keys that are compatible with
50
- [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
51
  """
52
 
53
- def __init__(
54
- self,
55
- head_dim: int,
56
- base=10000,
57
- ):
58
  super().__init__()
59
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
60
  self.register_buffer("inv_freq", inv_freq, persistent=False)
61
  self.head_dim = head_dim
62
- self.seq_len_cached = None
63
- self.batch_size_cached = None
64
  self.cos_cached: torch.Tensor | None = None
65
  self.sin_cached: torch.Tensor | None = None
66
 
67
- def cos_sin(
68
- self,
69
- seq_len: int,
70
- device="cuda",
71
- dtype=torch.bfloat16,
72
- ) -> torch.Tensor:
73
- if seq_len != self.seq_len_cached:
74
- self.seq_len_cached = seq_len
75
- t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
76
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
77
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
78
 
@@ -85,36 +97,46 @@ class RotaryEmbedding(torch.nn.Module):
85
  self.cos_cached = self.cos_cached.type(dtype)
86
  self.sin_cached = self.sin_cached.type(dtype)
87
 
88
- return self.cos_cached, self.sin_cached
 
 
 
89
 
90
- def forward(self, q, k):
91
- batch, seq_len, head_dim = q.shape
92
- cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
- return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
 
95
 
96
  def _make_causal_mask(
97
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
98
  ) -> torch.BoolTensor:
 
 
 
 
 
99
  batch_size, target_length = input_ids_shape
100
- mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
101
- # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
102
- seq_ids = torch.arange(target_length, device=device)
103
- mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
104
-
105
- if past_key_values_length > 0:
106
- mask[:, :past_key_values_length] = False
107
 
 
 
 
 
 
 
108
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
109
  return expanded_mask
110
 
111
 
112
- def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
113
- batch_size, src_length = mask.shape
114
- tgt_length = tgt_length if tgt_length is not None else src_length
 
 
 
115
 
116
  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
117
- return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
118
 
119
 
120
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -145,18 +167,32 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc
145
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
146
 
147
 
 
148
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  out = F.dropout(x, p=prob, training=training)
150
  out = residual + out
151
  return out
152
 
153
 
154
- class Attention(nn.Module):
155
  def __init__(self, config: RWConfig):
156
  super().__init__()
157
 
158
  self.hidden_size = config.hidden_size
159
- self.num_heads = config.n_head
160
  self.head_dim = self.hidden_size // self.num_heads
161
  self.split_size = self.hidden_size
162
  self.hidden_dropout = config.hidden_dropout
@@ -167,59 +203,62 @@ class Attention(nn.Module):
167
  f" {self.num_heads})."
168
  )
169
 
170
- self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
171
 
172
  # Layer-wise attention scaling
173
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
174
  self.beta = self.inv_norm_factor
175
-
176
- self.query_key_value = Linear(
177
- self.hidden_size,
178
- (config.n_head_kv * 2 + config.n_head) * self.head_dim,
179
- bias=config.bias,
180
- )
181
- self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
 
 
 
182
  self.attention_dropout = nn.Dropout(config.attention_dropout)
183
- self.num_kv = config.n_head_kv
184
 
185
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
186
  """
187
- Split the last dimension into (num_heads, head_dim), results share same memory
188
- storage as `fused_qkv`
189
 
190
  Args:
191
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
192
 
193
  Returns:
194
- query: [batch_size, seq_length, num_heads, head_dim]
195
- key: [batch_size, seq_length, num_heads, head_dim]
196
  value: [batch_size, seq_length, num_heads, head_dim]
197
  """
198
- batch, seq_len, _ = fused_qkv.shape
199
- qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv + 2, 64)
200
- q = qkv[:, :, :, :-2]
201
- k = qkv[:, :, :, [-2]]
202
- v = qkv[:, :, :, [-1]]
203
- k = torch.broadcast_to(k, q.shape)
204
- v = torch.broadcast_to(v, q.shape)
205
-
206
- q, k, v = [
207
- rearrange(
208
- x,
209
- "batch seq_len group num_heads head_dim ->\
210
- batch seq_len (group num_heads) head_dim",
211
- head_dim=self.head_dim,
212
- )
213
- for x in [q, k, v]
214
- ]
215
- return q, k, v
 
216
 
 
217
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
218
  """
219
  Merge heads together over the last dimenstion
220
 
221
  Args:
222
- x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
223
 
224
  Returns:
225
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
@@ -242,7 +281,7 @@ class Attention(nn.Module):
242
  def forward(
243
  self,
244
  hidden_states: torch.Tensor,
245
- alibi: torch.Tensor,
246
  attention_mask: torch.Tensor,
247
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
248
  head_mask: Optional[torch.Tensor] = None,
@@ -250,106 +289,120 @@ class Attention(nn.Module):
250
  output_attentions: bool = False,
251
  ):
252
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
253
-
254
  # 3 x [batch_size, seq_length, num_heads, head_dim]
255
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
256
 
257
- batch_size, q_length, _, _ = query_layer.shape
258
 
259
- query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
260
  key_layer = key_layer.transpose(1, 2).reshape(
261
- batch_size * self.num_heads,
262
- q_length,
263
  self.head_dim,
264
  )
265
- value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
266
 
267
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
268
 
269
  if layer_past is not None:
270
  past_key, past_value = layer_past
271
  # concatenate along seq_length dimension:
272
- # - key: [batch_size * self.num_heads, head_dim, kv_length]
273
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
274
  key_layer = torch.cat((past_key, key_layer), dim=1)
275
  value_layer = torch.cat((past_value, value_layer), dim=1)
276
 
277
  _, kv_length, _ = key_layer.shape
278
-
279
- if use_cache is True:
280
  present = (key_layer, value_layer)
281
  else:
282
  present = None
283
 
 
 
 
 
 
 
284
  if alibi is None:
285
- query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
286
- key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
287
- value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
 
 
288
 
289
- attn_output = F.scaled_dot_product_attention(
290
- query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
291
- )
 
 
 
 
 
 
292
 
293
- x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
294
- x = x.permute(0, 2, 1, 3)
295
- attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
296
 
297
  output_tensor = self.dense(attn_output)
298
 
299
- outputs = (output_tensor, present)
300
- assert not output_attentions # not supported.
301
- return outputs
 
 
302
  else:
303
- attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
304
- matmul_result = query_layer @ key_layer.transpose(-1, -2)
305
 
306
  # change view to [batch_size, num_heads, q_length, kv_length]
307
- attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
308
 
309
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
310
  input_dtype = attention_scores.dtype
311
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
312
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
313
  attention_scores = attention_scores.to(torch.float32)
314
- # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
315
- attention_probs = F.softmax(
316
- (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor
317
- + attention_mask_float,
318
- dim=-1,
319
- dtype=hidden_states.dtype,
320
- )
321
  # [batch_size, num_heads, q_length, kv_length]
322
  attention_probs = self.attention_dropout(attention_probs)
323
 
324
  if head_mask is not None:
325
  attention_probs = attention_probs * head_mask
326
 
327
- # change view [batch_size x num_heads, q_length, kv_length]
328
- attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
329
 
330
  # matmul: [batch_size * num_heads, q_length, head_dim]
331
- context_layer = attention_probs_reshaped @ value_layer
332
 
333
  # change view [batch_size, num_heads, q_length, head_dim]
334
  context_layer = self._merge_heads(context_layer)
335
 
336
  output_tensor = self.dense(context_layer)
337
 
338
- outputs = (output_tensor, present)
339
  if output_attentions:
340
- outputs += (attention_probs,)
341
-
342
- return outputs
343
 
344
 
345
- class MLP(nn.Module):
346
  def __init__(self, config: RWConfig):
347
  super().__init__()
348
  hidden_size = config.hidden_size
349
 
350
- self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
351
  self.act = nn.GELU()
352
- self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
353
  self.hidden_dropout = config.hidden_dropout
354
 
355
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -358,43 +411,47 @@ class MLP(nn.Module):
358
  return x
359
 
360
 
361
- class DecoderLayer(nn.Module):
362
  def __init__(self, config: RWConfig):
363
  super().__init__()
364
  hidden_size = config.hidden_size
365
-
366
- self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
367
- self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
368
-
369
- self.num_heads = config.n_head
370
- self.self_attention = Attention(config)
371
-
372
- self.mlp = MLP(config)
373
-
374
- self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
375
  self.hidden_dropout = config.hidden_dropout
376
-
377
  self.config = config
378
 
 
 
 
 
 
 
 
 
 
 
379
  def forward(
380
  self,
381
  hidden_states: torch.Tensor,
382
- alibi: torch.Tensor,
383
  attention_mask: torch.Tensor,
384
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
385
  head_mask: Optional[torch.Tensor] = None,
386
  use_cache: bool = False,
387
  output_attentions: bool = False,
388
  ):
389
-
390
- ln_attn = self.ln_attn(hidden_states)
391
- ln_mlp = self.ln_mlp(hidden_states)
392
-
393
  residual = hidden_states
394
 
 
 
 
 
 
 
395
  # Self attention.
396
  attn_outputs = self.self_attention(
397
- ln_attn,
398
  layer_past=layer_past,
399
  attention_mask=attention_mask,
400
  alibi=alibi,
@@ -405,14 +462,24 @@ class DecoderLayer(nn.Module):
405
 
406
  attention_output = attn_outputs[0]
407
 
 
 
 
 
 
 
 
 
 
408
  outputs = attn_outputs[1:]
409
 
410
  # MLP.
411
- mlp_output = self.mlp(ln_mlp)
412
 
413
- output = dropout_add(
414
- mlp_output + attention_output, residual, self.config.hidden_dropout, training=self.training
415
- )
 
416
 
417
  if use_cache:
418
  outputs = (output,) + outputs
@@ -422,8 +489,77 @@ class DecoderLayer(nn.Module):
422
  return outputs # hidden_states, present, attentions
423
 
424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
  class RWPreTrainedModel(PreTrainedModel):
426
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
427
  """
428
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
429
  models.
@@ -432,14 +568,14 @@ class RWPreTrainedModel(PreTrainedModel):
432
  config_class = RWConfig
433
  base_model_prefix = "transformer"
434
  supports_gradient_checkpointing = True
435
- _no_split_modules = ["DecoderLayer"]
436
 
437
  def __init__(self, *inputs, **kwargs):
438
  super().__init__(*inputs, **kwargs)
439
 
440
  def _init_weights(self, module: nn.Module):
441
  """Initialize the weights."""
442
- if isinstance(module, nn.Linear) or isinstance(module, Linear):
443
  # Slightly different from the TF version which uses truncated_normal for initialization
444
  # cf https://github.com/pytorch/pytorch/pull/5617
445
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -453,26 +589,28 @@ class RWPreTrainedModel(PreTrainedModel):
453
  module.bias.data.zero_()
454
  module.weight.data.fill_(1.0)
455
 
 
456
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
457
  if isinstance(module, RWModel):
458
  module.gradient_checkpointing = value
459
 
460
  @staticmethod
461
- def _convert_to_standard_cache(
462
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
463
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
464
  """
465
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
466
  num_heads, ...]))
467
  """
468
- batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
 
 
 
469
  num_heads = batch_size_times_num_heads // batch_size
470
- # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
471
- # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
472
  return tuple(
473
  (
474
- layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
475
- layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
476
  )
477
  for layer_past in past_key_value
478
  )
@@ -481,32 +619,35 @@ class RWPreTrainedModel(PreTrainedModel):
481
  def _convert_to_rw_cache(
482
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
483
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
484
- batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
485
  batch_size_times_num_heads = batch_size * num_heads
486
- # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
487
- # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
488
  return tuple(
489
  (
490
- layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
491
- layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
492
  )
493
  for layer_past in past_key_value
494
  )
495
 
496
 
 
 
 
 
497
  class RWModel(RWPreTrainedModel):
498
  def __init__(self, config: RWConfig):
499
  super().__init__(config)
500
 
501
  self.embed_dim = config.hidden_size
502
- self.num_heads = config.n_head
503
- self.alibi = config.alibi
504
 
505
  # Embedding + LN Embedding
506
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
507
 
508
  # Transformer blocks
509
- self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
510
 
511
  # Final Layer Norm
512
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -519,22 +660,31 @@ class RWModel(RWPreTrainedModel):
519
  def get_input_embeddings(self):
520
  return self.word_embeddings
521
 
 
522
  def _prepare_attn_mask(
523
- self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
524
  ) -> torch.BoolTensor:
525
- # create causal mask
526
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
 
 
 
 
 
 
 
 
527
  combined_attention_mask = None
528
  device = attention_mask.device
529
- _, src_length = input_shape
530
 
531
- if src_length > 1:
532
  combined_attention_mask = _make_causal_mask(
533
  input_shape, device=device, past_key_values_length=past_key_values_length
534
  )
535
 
536
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
537
- expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
538
  combined_attention_mask = (
539
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
540
  )
@@ -544,6 +694,12 @@ class RWModel(RWPreTrainedModel):
544
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
545
  self.word_embeddings = new_embeddings
546
 
 
 
 
 
 
 
547
  def forward(
548
  self,
549
  input_ids: Optional[torch.LongTensor] = None,
@@ -555,18 +711,7 @@ class RWModel(RWPreTrainedModel):
555
  output_attentions: Optional[bool] = None,
556
  output_hidden_states: Optional[bool] = None,
557
  return_dict: Optional[bool] = None,
558
- **deprecated_arguments,
559
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
560
- if deprecated_arguments.pop("position_ids", False) is not False:
561
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
562
- warnings.warn(
563
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
564
- " passing `position_ids`.",
565
- FutureWarning,
566
- )
567
- if len(deprecated_arguments) > 0:
568
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
569
-
570
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
571
  output_hidden_states = (
572
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -585,12 +730,14 @@ class RWModel(RWPreTrainedModel):
585
 
586
  if past_key_values is None:
587
  past_key_values = tuple([None] * len(self.h))
 
 
588
 
589
  # Prepare head mask if needed
590
  # 1.0 in head_mask indicate we keep the head
591
  # attention_probs has shape batch_size x num_heads x N x N
592
  # head_mask has shape n_layer x batch x num_heads x N x N
593
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
594
 
595
  if inputs_embeds is None:
596
  inputs_embeds = self.word_embeddings(input_ids)
@@ -602,17 +749,15 @@ class RWModel(RWPreTrainedModel):
602
  all_hidden_states = () if output_hidden_states else None
603
 
604
  # Compute alibi tensor: check build_alibi_tensor documentation
605
- seq_length_with_past = seq_length
606
  past_key_values_length = 0
607
  if past_key_values[0] is not None:
608
- past_key_values_length = past_key_values[0][0].shape[2]
609
- seq_length_with_past = seq_length_with_past + past_key_values_length
610
  if attention_mask is None:
611
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
612
  else:
613
  attention_mask = attention_mask.to(hidden_states.device)
614
 
615
- if self.alibi:
616
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
617
  else:
618
  alibi = None
@@ -624,12 +769,10 @@ class RWModel(RWPreTrainedModel):
624
  )
625
 
626
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
627
-
628
  if output_hidden_states:
629
  all_hidden_states = all_hidden_states + (hidden_states,)
630
 
631
  if self.gradient_checkpointing and self.training:
632
-
633
  if use_cache:
634
  logger.warning(
635
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -674,6 +817,9 @@ class RWModel(RWPreTrainedModel):
674
  if output_hidden_states:
675
  all_hidden_states = all_hidden_states + (hidden_states,)
676
 
 
 
 
677
  if not return_dict:
678
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
679
 
@@ -685,8 +831,12 @@ class RWModel(RWPreTrainedModel):
685
  )
686
 
687
 
 
 
 
 
688
  class RWForCausalLM(RWPreTrainedModel):
689
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
690
 
691
  def __init__(self, config: RWConfig):
692
  super().__init__(config)
@@ -705,25 +855,26 @@ class RWForCausalLM(RWPreTrainedModel):
705
  def prepare_inputs_for_generation(
706
  self,
707
  input_ids: torch.LongTensor,
708
- past: Optional[torch.Tensor] = None,
709
  attention_mask: Optional[torch.Tensor] = None,
710
  **kwargs,
711
  ) -> dict:
712
- # only last token for input_ids if past is not None
713
- if past:
714
- input_ids = input_ids[:, -1].unsqueeze(-1)
715
-
716
- # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
717
- if past[0][0].shape[0] == input_ids.shape[0]:
718
- past = self._convert_to_rw_cache(past)
719
 
720
  return {
721
  "input_ids": input_ids,
722
- "past_key_values": past,
723
  "use_cache": kwargs.get("use_cache"),
724
  "attention_mask": attention_mask,
725
  }
726
 
 
 
 
 
 
 
727
  def forward(
728
  self,
729
  input_ids: Optional[torch.LongTensor] = None,
@@ -736,7 +887,6 @@ class RWForCausalLM(RWPreTrainedModel):
736
  output_attentions: Optional[bool] = None,
737
  output_hidden_states: Optional[bool] = None,
738
  return_dict: Optional[bool] = None,
739
- **deprecated_arguments,
740
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
741
  r"""
742
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -744,15 +894,6 @@ class RWForCausalLM(RWPreTrainedModel):
744
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
745
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
746
  """
747
- if deprecated_arguments.pop("position_ids", False) is not False:
748
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
749
- warnings.warn(
750
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
751
- " passing `position_ids`.",
752
- FutureWarning,
753
- )
754
- if len(deprecated_arguments) > 0:
755
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
756
 
757
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
758
 
@@ -805,7 +946,6 @@ class RWForCausalLM(RWPreTrainedModel):
805
 
806
  Output shares the same memory storage as `past`.
807
  """
808
- standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
809
 
810
  # Get a copy of `beam_idx` on all the devices where we need those indices.
811
  device_to_beam_idx = {
@@ -816,14 +956,27 @@ class RWForCausalLM(RWPreTrainedModel):
816
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
817
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
818
  )
819
- for layer_past in standardized_past
820
  )
821
- return self._convert_to_rw_cache(reordered_past)
822
 
823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824
  class RWForSequenceClassification(RWPreTrainedModel):
825
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
826
-
827
  def __init__(self, config: RWConfig):
828
  super().__init__(config)
829
  self.num_labels = config.num_labels
@@ -833,6 +986,12 @@ class RWForSequenceClassification(RWPreTrainedModel):
833
  # Initialize weights and apply final processing
834
  self.post_init()
835
 
 
 
 
 
 
 
836
  def forward(
837
  self,
838
  input_ids: Optional[torch.LongTensor] = None,
@@ -845,7 +1004,6 @@ class RWForSequenceClassification(RWPreTrainedModel):
845
  output_attentions: Optional[bool] = None,
846
  output_hidden_states: Optional[bool] = None,
847
  return_dict: Optional[bool] = None,
848
- **deprecated_arguments,
849
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
850
  r"""
851
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -853,15 +1011,6 @@ class RWForSequenceClassification(RWPreTrainedModel):
853
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
854
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
855
  """
856
- if deprecated_arguments.pop("position_ids", False) is not False:
857
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
858
- warnings.warn(
859
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
860
- " passing `position_ids`.",
861
- FutureWarning,
862
- )
863
- if len(deprecated_arguments) > 0:
864
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
865
 
866
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
867
 
@@ -936,17 +1085,22 @@ class RWForSequenceClassification(RWPreTrainedModel):
936
  )
937
 
938
 
 
 
 
 
 
 
 
939
  class RWForTokenClassification(RWPreTrainedModel):
940
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
941
-
942
  def __init__(self, config: RWConfig):
943
  super().__init__(config)
944
  self.num_labels = config.num_labels
945
 
946
  self.transformer = RWModel(config)
947
- if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
948
  classifier_dropout = config.classifier_dropout
949
- elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
950
  classifier_dropout = config.hidden_dropout
951
  else:
952
  classifier_dropout = 0.1
@@ -956,6 +1110,12 @@ class RWForTokenClassification(RWPreTrainedModel):
956
  # Initialize weights and apply final processing
957
  self.post_init()
958
 
 
 
 
 
 
 
959
  def forward(
960
  self,
961
  input_ids: Optional[torch.LongTensor] = None,
@@ -968,7 +1128,6 @@ class RWForTokenClassification(RWPreTrainedModel):
968
  output_attentions: Optional[bool] = None,
969
  output_hidden_states: Optional[bool] = None,
970
  return_dict: Optional[bool] = None,
971
- **deprecated_arguments,
972
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
973
  r"""
974
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -976,15 +1135,6 @@ class RWForTokenClassification(RWPreTrainedModel):
976
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
977
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
978
  """
979
- if deprecated_arguments.pop("position_ids", False) is not False:
980
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
981
- warnings.warn(
982
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
983
- " passing `position_ids`.",
984
- FutureWarning,
985
- )
986
- if len(deprecated_arguments) > 0:
987
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
988
 
989
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
990
 
@@ -1008,7 +1158,9 @@ class RWForTokenClassification(RWPreTrainedModel):
1008
  if labels is not None:
1009
  batch_size, seq_length = labels.shape
1010
  loss_fct = CrossEntropyLoss()
1011
- loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
 
 
1012
 
1013
  if not return_dict:
1014
  output = (logits,) + transformer_outputs[2:]
@@ -1022,9 +1174,14 @@ class RWForTokenClassification(RWPreTrainedModel):
1022
  )
1023
 
1024
 
 
 
 
 
 
 
 
1025
  class RWForQuestionAnswering(RWPreTrainedModel):
1026
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
1027
-
1028
  def __init__(self, config):
1029
  super().__init__(config)
1030
  self.transformer = RWModel(config)
@@ -1033,11 +1190,11 @@ class RWForQuestionAnswering(RWPreTrainedModel):
1033
  # Initialize weights and apply final processing
1034
  self.post_init()
1035
 
 
1036
  def forward(
1037
  self,
1038
  input_ids: Optional[torch.LongTensor] = None,
1039
  attention_mask: Optional[torch.FloatTensor] = None,
1040
- position_ids: Optional[torch.LongTensor] = None,
1041
  head_mask: Optional[torch.FloatTensor] = None,
1042
  inputs_embeds: Optional[torch.FloatTensor] = None,
1043
  start_positions: Optional[torch.LongTensor] = None,
@@ -1061,7 +1218,6 @@ class RWForQuestionAnswering(RWPreTrainedModel):
1061
  outputs = self.transformer(
1062
  input_ids,
1063
  attention_mask=attention_mask,
1064
- position_ids=position_ids,
1065
  head_mask=head_mask,
1066
  inputs_embeds=inputs_embeds,
1067
  output_attentions=output_attentions,
 
1
+ # coding=utf-8
2
+ # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Falcon model."""
16
 
17
  import math
 
18
  from typing import Optional, Tuple, Union
19
 
20
  import torch
 
31
  TokenClassifierOutput,
32
  )
33
  from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
35
  from .configuration_RW import RWConfig
36
 
37
+
38
  logger = logging.get_logger(__name__)
39
 
40
+ FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
41
+ "tiiuae/falcon-40b",
42
+ "tiiuae/falcon-40b-instruct",
43
+ "tiiuae/falcon-7b",
44
+ "tiiuae/falcon-7b-instruct",
45
+ "tiiuae/falcon-rw-7b",
46
+ "tiiuae/falcon-rw-1b",
47
+ ]
48
+ _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
49
+ _CONFIG_FOR_DOC = "RWConfig"
50
+
51
+
52
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
53
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
54
+ class FalconLinear(nn.Linear):
55
  def forward(self, input: torch.Tensor) -> torch.Tensor:
56
+ hidden_states = input @ self.weight.T
57
  if self.bias is None:
58
+ return hidden_states
59
+ return hidden_states + self.bias
 
 
60
 
 
61
 
62
  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
63
  def rotate_half(x):
64
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
65
+ return torch.cat((-x2, x1), dim=-1)
66
 
67
 
68
+ class FalconRotaryEmbedding(nn.Module):
69
  """Implementation of RotaryEmbedding from GPT-NeoX.
70
+ This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
71
+ n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
72
  """
73
 
74
+ def __init__(self, head_dim: int, base=10000):
 
 
 
 
75
  super().__init__()
76
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
77
  self.register_buffer("inv_freq", inv_freq, persistent=False)
78
  self.head_dim = head_dim
79
+ self.seq_len_cached = -1
 
80
  self.cos_cached: torch.Tensor | None = None
81
  self.sin_cached: torch.Tensor | None = None
82
 
83
+ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
84
+ total_length = seq_len + past_key_values_length
85
+ if total_length > self.seq_len_cached:
86
+ self.seq_len_cached = total_length
87
+ t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
 
 
 
 
88
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
89
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
90
 
 
97
  self.cos_cached = self.cos_cached.type(dtype)
98
  self.sin_cached = self.sin_cached.type(dtype)
99
 
100
+ return (
101
+ self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
102
+ self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
103
+ )
104
 
105
+ def forward(self, query, key, past_key_values_length=0):
106
+ batch, seq_len, head_dim = query.shape
107
+ cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
108
+ return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
109
 
110
 
111
  def _make_causal_mask(
112
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
113
  ) -> torch.BoolTensor:
114
+ """
115
+ Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
116
+ just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
117
+ target_length, target_length+past_key_values_length]`.
118
+ """
119
  batch_size, target_length = input_ids_shape
 
 
 
 
 
 
 
120
 
121
+ mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
122
+ # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
123
+ # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
124
+ # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
125
+ past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
126
+ mask = torch.cat([past_mask, mask], dim=-1)
127
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
128
  return expanded_mask
129
 
130
 
131
+ def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
132
+ """
133
+ Expands attention_mask from `[batch_size, seq_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
134
+ """
135
+ batch_size, total_length = mask.shape
136
+ seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
137
 
138
  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
139
+ return expanded_mask.expand(batch_size, 1, seq_length, total_length)
140
 
141
 
142
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
 
167
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
168
 
169
 
170
+ # Copied from transformers.models.bloom.modeling_bloom.dropout_add
171
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
172
+ """
173
+ Dropout add function
174
+
175
+ Args:
176
+ x (`torch.tensor`, *required*):
177
+ input tensor
178
+ residual (`torch.tensor`, *required*):
179
+ residual tensor
180
+ prob (`float`, *required*):
181
+ dropout probability
182
+ training (`bool`, *required*):
183
+ training mode
184
+ """
185
  out = F.dropout(x, p=prob, training=training)
186
  out = residual + out
187
  return out
188
 
189
 
190
+ class FalconAttention(nn.Module):
191
  def __init__(self, config: RWConfig):
192
  super().__init__()
193
 
194
  self.hidden_size = config.hidden_size
195
+ self.num_heads = config.num_attention_heads
196
  self.head_dim = self.hidden_size // self.num_heads
197
  self.split_size = self.hidden_size
198
  self.hidden_dropout = config.hidden_dropout
 
203
  f" {self.num_heads})."
204
  )
205
 
206
+ self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
207
 
208
  # Layer-wise attention scaling
209
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
210
  self.beta = self.inv_norm_factor
211
+ if config.new_decoder_architecture:
212
+ qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
213
+ elif config.multi_query:
214
+ qkv_out_dim = self.hidden_size + 2 * self.head_dim
215
+ else:
216
+ qkv_out_dim = 3 * self.hidden_size
217
+ self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
218
+ self.new_decoder_architecture = config.new_decoder_architecture
219
+ self.multi_query = config.multi_query
220
+ self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
221
  self.attention_dropout = nn.Dropout(config.attention_dropout)
222
+ self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
223
 
224
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
225
  """
226
+ Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
 
227
 
228
  Args:
229
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
230
 
231
  Returns:
232
+ query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
 
233
  value: [batch_size, seq_length, num_heads, head_dim]
234
  """
235
+ if self.new_decoder_architecture:
236
+ batch, seq_len, _ = fused_qkv.shape
237
+ qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
238
+ query = qkv[:, :, :, :-2]
239
+ key = qkv[:, :, :, [-2]]
240
+ value = qkv[:, :, :, [-1]]
241
+ key = torch.broadcast_to(key, query.shape)
242
+ value = torch.broadcast_to(value, query.shape)
243
+
244
+ query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
245
+ return query, key, value
246
+ elif not self.multi_query:
247
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
248
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
249
+ return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
250
+ else:
251
+ batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
252
+ fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
253
+ return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
254
 
255
+ # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
256
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
257
  """
258
  Merge heads together over the last dimenstion
259
 
260
  Args:
261
+ x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
262
 
263
  Returns:
264
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
 
281
  def forward(
282
  self,
283
  hidden_states: torch.Tensor,
284
+ alibi: Optional[torch.Tensor],
285
  attention_mask: torch.Tensor,
286
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
287
  head_mask: Optional[torch.Tensor] = None,
 
289
  output_attentions: bool = False,
290
  ):
291
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
292
+ num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
293
  # 3 x [batch_size, seq_length, num_heads, head_dim]
294
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
295
 
296
+ batch_size, query_length, _, _ = query_layer.shape
297
 
298
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
299
  key_layer = key_layer.transpose(1, 2).reshape(
300
+ batch_size * num_kv_heads,
301
+ query_length,
302
  self.head_dim,
303
  )
304
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
305
 
306
+ past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
307
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
308
 
309
  if layer_past is not None:
310
  past_key, past_value = layer_past
311
  # concatenate along seq_length dimension:
312
+ # - key: [batch_size * self.num_heads, kv_length, head_dim]
313
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
314
  key_layer = torch.cat((past_key, key_layer), dim=1)
315
  value_layer = torch.cat((past_value, value_layer), dim=1)
316
 
317
  _, kv_length, _ = key_layer.shape
318
+ if use_cache:
 
319
  present = (key_layer, value_layer)
320
  else:
321
  present = None
322
 
323
+ attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
324
+
325
+ query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
326
+ key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
327
+ value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
328
+
329
  if alibi is None:
330
+ if output_attentions:
331
+ # F.scaled_dot_product_attention doesn't return the attention weights, so we have
332
+ # to do it by hand if we want them
333
+ attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
334
+ attention_scores /= math.sqrt(self.head_dim)
335
 
336
+ attention_scores = F.softmax(
337
+ attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
338
+ )
339
+ attn_output = attention_scores @ value_layer_
340
+ else:
341
+ attn_output = F.scaled_dot_product_attention(
342
+ query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
343
+ )
344
+ attention_scores = None
345
 
346
+ attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
347
+ attn_output = attn_output.permute(0, 2, 1, 3)
348
+ attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
349
 
350
  output_tensor = self.dense(attn_output)
351
 
352
+ if output_attentions:
353
+ return output_tensor, present, attention_scores
354
+ else:
355
+ return output_tensor, present
356
+
357
  else:
358
+ matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
 
359
 
360
  # change view to [batch_size, num_heads, q_length, kv_length]
361
+ attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
362
 
363
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
364
  input_dtype = attention_scores.dtype
365
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
366
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
367
  attention_scores = attention_scores.to(torch.float32)
368
+ # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
369
+ # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
370
+ # equivalent and more performant, but there might be a numerical difference. If you're reading this
371
+ # and you'd like to experiment and maybe file a PR, feel free!
372
+ attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
373
+ attention_logits *= self.inv_norm_factor
374
+ attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
375
  # [batch_size, num_heads, q_length, kv_length]
376
  attention_probs = self.attention_dropout(attention_probs)
377
 
378
  if head_mask is not None:
379
  attention_probs = attention_probs * head_mask
380
 
381
+ # change view [batch_size, num_heads, q_length, kv_length]
382
+ attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
383
 
384
  # matmul: [batch_size * num_heads, q_length, head_dim]
385
+ context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
386
 
387
  # change view [batch_size, num_heads, q_length, head_dim]
388
  context_layer = self._merge_heads(context_layer)
389
 
390
  output_tensor = self.dense(context_layer)
391
 
 
392
  if output_attentions:
393
+ return output_tensor, present, attention_probs
394
+ else:
395
+ return output_tensor, present
396
 
397
 
398
+ class FalconMLP(nn.Module):
399
  def __init__(self, config: RWConfig):
400
  super().__init__()
401
  hidden_size = config.hidden_size
402
 
403
+ self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
404
  self.act = nn.GELU()
405
+ self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
406
  self.hidden_dropout = config.hidden_dropout
407
 
408
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
411
  return x
412
 
413
 
414
+ class FalconDecoderLayer(nn.Module):
415
  def __init__(self, config: RWConfig):
416
  super().__init__()
417
  hidden_size = config.hidden_size
418
+ self.num_heads = config.num_attention_heads
419
+ self.self_attention = FalconAttention(config)
420
+ self.mlp = FalconMLP(config)
 
 
 
 
 
 
 
421
  self.hidden_dropout = config.hidden_dropout
 
422
  self.config = config
423
 
424
+ if config.new_decoder_architecture:
425
+ # The layer norm before self-attention
426
+ self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
427
+ # The layer norm before the MLP
428
+ self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
429
+ else:
430
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
431
+ if not config.parallel_attn:
432
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
433
+
434
  def forward(
435
  self,
436
  hidden_states: torch.Tensor,
437
+ alibi: Optional[torch.Tensor],
438
  attention_mask: torch.Tensor,
439
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
440
  head_mask: Optional[torch.Tensor] = None,
441
  use_cache: bool = False,
442
  output_attentions: bool = False,
443
  ):
 
 
 
 
444
  residual = hidden_states
445
 
446
+ if self.config.new_decoder_architecture:
447
+ attention_layernorm_out = self.ln_attn(hidden_states)
448
+ mlp_layernorm_out = self.ln_mlp(hidden_states)
449
+ else:
450
+ attention_layernorm_out = self.input_layernorm(hidden_states)
451
+
452
  # Self attention.
453
  attn_outputs = self.self_attention(
454
+ attention_layernorm_out,
455
  layer_past=layer_past,
456
  attention_mask=attention_mask,
457
  alibi=alibi,
 
462
 
463
  attention_output = attn_outputs[0]
464
 
465
+ if not self.config.new_decoder_architecture:
466
+ if self.config.parallel_attn:
467
+ mlp_layernorm_out = attention_layernorm_out
468
+ else:
469
+ residual = dropout_add(
470
+ attention_output, residual, self.config.attention_dropout, training=self.training
471
+ )
472
+ mlp_layernorm_out = self.post_attention_layernorm(residual)
473
+
474
  outputs = attn_outputs[1:]
475
 
476
  # MLP.
477
+ mlp_output = self.mlp(mlp_layernorm_out)
478
 
479
+ if self.config.new_decoder_architecture or self.config.parallel_attn:
480
+ mlp_output += attention_output
481
+
482
+ output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
483
 
484
  if use_cache:
485
  outputs = (output,) + outputs
 
489
  return outputs # hidden_states, present, attentions
490
 
491
 
492
+ FALCON_START_DOCSTRING = r"""
493
+
494
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
495
+ library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
496
+
497
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
498
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
499
+ and behavior.
500
+
501
+ Parameters:
502
+ config ([`RWConfig`]): Model configuration class with all the parameters of the model.
503
+ Initializing with a config file does not load the weights associated with the model, only the
504
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
505
+ """
506
+
507
+ FALCON_INPUTS_DOCSTRING = r"""
508
+ Args:
509
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
510
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
511
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
512
+
513
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
514
+ `input_ids`.
515
+
516
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
517
+ [`PreTrainedTokenizer.__call__`] for details.
518
+
519
+ [What are input IDs?](../glossary#input-ids)
520
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
521
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
522
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
523
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
524
+
525
+ Each element of `past_key_values` is a tuple (past_key, past_value):
526
+ - past_key: [batch_size * num_heads, head_dim, kv_length]
527
+ - past_value: [batch_size * num_heads, kv_length, head_dim]
528
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
529
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
530
+
531
+ - 1 for tokens that are **not masked**,
532
+ - 0 for tokens that are **masked**.
533
+
534
+ [What are attention masks?](../glossary#attention-mask)
535
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
536
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
537
+
538
+ - 1 indicates the head is **not masked**,
539
+ - 0 indicates the head is **masked**.
540
+
541
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
542
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
543
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
544
+ model's internal embedding lookup matrix.
545
+
546
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
547
+ `past_key_values`).
548
+ use_cache (`bool`, *optional*):
549
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
550
+ `past_key_values`).
551
+ output_attentions (`bool`, *optional*):
552
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
553
+ tensors for more detail.
554
+ output_hidden_states (`bool`, *optional*):
555
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
556
+ more detail.
557
+ return_dict (`bool`, *optional*):
558
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
559
+ """
560
+
561
+
562
  class RWPreTrainedModel(PreTrainedModel):
 
563
  """
564
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
565
  models.
 
568
  config_class = RWConfig
569
  base_model_prefix = "transformer"
570
  supports_gradient_checkpointing = True
571
+ _no_split_modules = ["FalconDecoderLayer"]
572
 
573
  def __init__(self, *inputs, **kwargs):
574
  super().__init__(*inputs, **kwargs)
575
 
576
  def _init_weights(self, module: nn.Module):
577
  """Initialize the weights."""
578
+ if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
579
  # Slightly different from the TF version which uses truncated_normal for initialization
580
  # cf https://github.com/pytorch/pytorch/pull/5617
581
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
 
589
  module.bias.data.zero_()
590
  module.weight.data.fill_(1.0)
591
 
592
+ # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->RWModel
593
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
594
  if isinstance(module, RWModel):
595
  module.gradient_checkpointing = value
596
 
597
  @staticmethod
598
+ def _convert_cache_to_standard_format(
599
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
600
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
601
  """
602
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
603
  num_heads, ...]))
604
  """
605
+ batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
606
+ # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
607
+ # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
608
+ # on whether we use multi_query attention.
609
  num_heads = batch_size_times_num_heads // batch_size
 
 
610
  return tuple(
611
  (
612
+ layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
613
+ layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
614
  )
615
  for layer_past in past_key_value
616
  )
 
619
  def _convert_to_rw_cache(
620
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
621
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
622
+ batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
623
  batch_size_times_num_heads = batch_size * num_heads
624
+ # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
 
625
  return tuple(
626
  (
627
+ layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
628
+ layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
629
  )
630
  for layer_past in past_key_value
631
  )
632
 
633
 
634
+ @add_start_docstrings(
635
+ "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
636
+ FALCON_START_DOCSTRING,
637
+ )
638
  class RWModel(RWPreTrainedModel):
639
  def __init__(self, config: RWConfig):
640
  super().__init__(config)
641
 
642
  self.embed_dim = config.hidden_size
643
+ self.num_heads = config.num_attention_heads
644
+ self.use_alibi = config.alibi
645
 
646
  # Embedding + LN Embedding
647
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
648
 
649
  # Transformer blocks
650
+ self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
651
 
652
  # Final Layer Norm
653
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
660
  def get_input_embeddings(self):
661
  return self.word_embeddings
662
 
663
+ @staticmethod
664
  def _prepare_attn_mask(
665
+ attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
666
  ) -> torch.BoolTensor:
667
+ # Create a causal mask
668
+ # The attention mask we receive as input should cover the whole extended sequence, including any past
669
+ # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
670
+ # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
671
+ if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
672
+ raise ValueError(
673
+ "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
674
+ f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
675
+ f" {past_key_values_length}."
676
+ )
677
  combined_attention_mask = None
678
  device = attention_mask.device
679
+ _, seq_length = input_shape
680
 
681
+ if seq_length > 1:
682
  combined_attention_mask = _make_causal_mask(
683
  input_shape, device=device, past_key_values_length=past_key_values_length
684
  )
685
 
686
+ # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
687
+ expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
688
  combined_attention_mask = (
689
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
690
  )
 
694
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
695
  self.word_embeddings = new_embeddings
696
 
697
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
698
+ @add_code_sample_docstrings(
699
+ checkpoint=_CHECKPOINT_FOR_DOC,
700
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
701
+ config_class=_CONFIG_FOR_DOC,
702
+ )
703
  def forward(
704
  self,
705
  input_ids: Optional[torch.LongTensor] = None,
 
711
  output_attentions: Optional[bool] = None,
712
  output_hidden_states: Optional[bool] = None,
713
  return_dict: Optional[bool] = None,
 
714
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
 
 
 
 
 
 
 
 
 
 
715
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
  output_hidden_states = (
717
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
730
 
731
  if past_key_values is None:
732
  past_key_values = tuple([None] * len(self.h))
733
+ else:
734
+ past_key_values = self._convert_to_rw_cache(past_key_values)
735
 
736
  # Prepare head mask if needed
737
  # 1.0 in head_mask indicate we keep the head
738
  # attention_probs has shape batch_size x num_heads x N x N
739
  # head_mask has shape n_layer x batch x num_heads x N x N
740
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
741
 
742
  if inputs_embeds is None:
743
  inputs_embeds = self.word_embeddings(input_ids)
 
749
  all_hidden_states = () if output_hidden_states else None
750
 
751
  # Compute alibi tensor: check build_alibi_tensor documentation
 
752
  past_key_values_length = 0
753
  if past_key_values[0] is not None:
754
+ past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
 
755
  if attention_mask is None:
756
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
757
  else:
758
  attention_mask = attention_mask.to(hidden_states.device)
759
 
760
+ if self.use_alibi:
761
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
762
  else:
763
  alibi = None
 
769
  )
770
 
771
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
772
  if output_hidden_states:
773
  all_hidden_states = all_hidden_states + (hidden_states,)
774
 
775
  if self.gradient_checkpointing and self.training:
 
776
  if use_cache:
777
  logger.warning(
778
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 
817
  if output_hidden_states:
818
  all_hidden_states = all_hidden_states + (hidden_states,)
819
 
820
+ if presents is not None:
821
+ presents = self._convert_cache_to_standard_format(presents, batch_size)
822
+
823
  if not return_dict:
824
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
825
 
 
831
  )
832
 
833
 
834
+ @add_start_docstrings(
835
+ "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
836
+ FALCON_START_DOCSTRING,
837
+ )
838
  class RWForCausalLM(RWPreTrainedModel):
839
+ _tied_weights_keys = ["lm_head.weight"]
840
 
841
  def __init__(self, config: RWConfig):
842
  super().__init__(config)
 
855
  def prepare_inputs_for_generation(
856
  self,
857
  input_ids: torch.LongTensor,
858
+ past_key_values: Optional[torch.Tensor] = None,
859
  attention_mask: Optional[torch.Tensor] = None,
860
  **kwargs,
861
  ) -> dict:
862
+ if past_key_values is not None:
863
+ input_ids = input_ids[:, -1:]
 
 
 
 
 
864
 
865
  return {
866
  "input_ids": input_ids,
867
+ "past_key_values": past_key_values,
868
  "use_cache": kwargs.get("use_cache"),
869
  "attention_mask": attention_mask,
870
  }
871
 
872
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
873
+ @add_code_sample_docstrings(
874
+ checkpoint=_CHECKPOINT_FOR_DOC,
875
+ output_type=CausalLMOutputWithCrossAttentions,
876
+ config_class=_CONFIG_FOR_DOC,
877
+ )
878
  def forward(
879
  self,
880
  input_ids: Optional[torch.LongTensor] = None,
 
887
  output_attentions: Optional[bool] = None,
888
  output_hidden_states: Optional[bool] = None,
889
  return_dict: Optional[bool] = None,
 
890
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
891
  r"""
892
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
894
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
895
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
896
  """
 
 
 
 
 
 
 
 
 
897
 
898
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
899
 
 
946
 
947
  Output shares the same memory storage as `past`.
948
  """
 
949
 
950
  # Get a copy of `beam_idx` on all the devices where we need those indices.
951
  device_to_beam_idx = {
 
956
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
957
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
958
  )
959
+ for layer_past in past
960
  )
961
+ return reordered_past
962
 
963
 
964
+ @add_start_docstrings(
965
+ """
966
+ The Falcon Model transformer with a sequence classification head on top (linear layer).
967
+
968
+ [`RWForSequenceClassification`] uses the last token in order to do the classification, as other causal models
969
+ (e.g. GPT-1) do.
970
+
971
+ Since it does classification on the last token, it requires to know the position of the last token. If a
972
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
973
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
974
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
975
+ each row of the batch).
976
+ """,
977
+ FALCON_START_DOCSTRING,
978
+ )
979
  class RWForSequenceClassification(RWPreTrainedModel):
 
 
980
  def __init__(self, config: RWConfig):
981
  super().__init__(config)
982
  self.num_labels = config.num_labels
 
986
  # Initialize weights and apply final processing
987
  self.post_init()
988
 
989
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
990
+ @add_code_sample_docstrings(
991
+ checkpoint=_CHECKPOINT_FOR_DOC,
992
+ output_type=SequenceClassifierOutputWithPast,
993
+ config_class=_CONFIG_FOR_DOC,
994
+ )
995
  def forward(
996
  self,
997
  input_ids: Optional[torch.LongTensor] = None,
 
1004
  output_attentions: Optional[bool] = None,
1005
  output_hidden_states: Optional[bool] = None,
1006
  return_dict: Optional[bool] = None,
 
1007
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1008
  r"""
1009
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
1011
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1012
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1013
  """
 
 
 
 
 
 
 
 
 
1014
 
1015
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1016
 
 
1085
  )
1086
 
1087
 
1088
+ @add_start_docstrings(
1089
+ """
1090
+ Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1091
+ Named-Entity-Recognition (NER) tasks.
1092
+ """,
1093
+ FALCON_START_DOCSTRING,
1094
+ )
1095
  class RWForTokenClassification(RWPreTrainedModel):
 
 
1096
  def __init__(self, config: RWConfig):
1097
  super().__init__(config)
1098
  self.num_labels = config.num_labels
1099
 
1100
  self.transformer = RWModel(config)
1101
+ if getattr(config, "classifier_dropout", None) is not None:
1102
  classifier_dropout = config.classifier_dropout
1103
+ elif getattr(config, "hidden_dropout", None) is not None:
1104
  classifier_dropout = config.hidden_dropout
1105
  else:
1106
  classifier_dropout = 0.1
 
1110
  # Initialize weights and apply final processing
1111
  self.post_init()
1112
 
1113
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1114
+ @add_code_sample_docstrings(
1115
+ checkpoint=_CHECKPOINT_FOR_DOC,
1116
+ output_type=TokenClassifierOutput,
1117
+ config_class=_CONFIG_FOR_DOC,
1118
+ )
1119
  def forward(
1120
  self,
1121
  input_ids: Optional[torch.LongTensor] = None,
 
1128
  output_attentions: Optional[bool] = None,
1129
  output_hidden_states: Optional[bool] = None,
1130
  return_dict: Optional[bool] = None,
 
1131
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1132
  r"""
1133
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
1135
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1136
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1137
  """
 
 
 
 
 
 
 
 
 
1138
 
1139
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1140
 
 
1158
  if labels is not None:
1159
  batch_size, seq_length = labels.shape
1160
  loss_fct = CrossEntropyLoss()
1161
+ loss = loss_fct(
1162
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1163
+ )
1164
 
1165
  if not return_dict:
1166
  output = (logits,) + transformer_outputs[2:]
 
1174
  )
1175
 
1176
 
1177
+ @add_start_docstrings(
1178
+ """
1179
+ The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1180
+ SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1181
+ """,
1182
+ FALCON_START_DOCSTRING,
1183
+ )
1184
  class RWForQuestionAnswering(RWPreTrainedModel):
 
 
1185
  def __init__(self, config):
1186
  super().__init__(config)
1187
  self.transformer = RWModel(config)
 
1190
  # Initialize weights and apply final processing
1191
  self.post_init()
1192
 
1193
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1194
  def forward(
1195
  self,
1196
  input_ids: Optional[torch.LongTensor] = None,
1197
  attention_mask: Optional[torch.FloatTensor] = None,
 
1198
  head_mask: Optional[torch.FloatTensor] = None,
1199
  inputs_embeds: Optional[torch.FloatTensor] = None,
1200
  start_positions: Optional[torch.LongTensor] = None,
 
1218
  outputs = self.transformer(
1219
  input_ids,
1220
  attention_mask=attention_mask,
 
1221
  head_mask=head_mask,
1222
  inputs_embeds=inputs_embeds,
1223
  output_attentions=output_attentions,