eluzhnica commited on
Commit
8eef914
1 Parent(s): 539960c

Add gradient checkpointing

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +41 -15
modeling_mpt.py CHANGED
@@ -18,7 +18,7 @@ from .configuration_mpt import MPTConfig
18
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
19
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
20
  from .meta_init_context import init_empty_weights
21
- from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
22
  try:
23
  from .flash_attn_triton import flash_attn_func
24
  except:
@@ -30,11 +30,18 @@ class MPTPreTrainedModel(PreTrainedModel):
30
  base_model_prefix = 'model'
31
  _no_split_modules = ['MPTBlock']
32
 
 
 
 
 
 
 
33
  class MPTModel(MPTPreTrainedModel):
34
 
35
  def __init__(self, config: MPTConfig):
36
  config._validate_config()
37
  super().__init__(config)
 
38
  self.attn_impl = config.attn_config['attn_impl']
39
  self.prefix_lm = config.attn_config['prefix_lm']
40
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
@@ -80,7 +87,7 @@ class MPTModel(MPTPreTrainedModel):
80
  def get_input_embeddings(self):
81
  return self.wte
82
 
83
- def set_input_embeddings(self, value: nn.Embedding):
84
  self.wte = value
85
 
86
  @torch.no_grad()
@@ -140,7 +147,7 @@ class MPTModel(MPTPreTrainedModel):
140
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
141
  return attn_bias
142
 
143
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):
144
  return_dict = return_dict if return_dict is not None else self.config.return_dict
145
  use_cache = use_cache if use_cache is not None else self.config.use_cache
146
  if attention_mask is not None:
@@ -156,13 +163,15 @@ class MPTModel(MPTPreTrainedModel):
156
  raise NotImplementedError('MPT does not support training with left padding.')
157
  if self.prefix_lm and prefix_mask is None:
158
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
159
- if inputs_embeds is not None:
160
- raise NotImplementedError('inputs_embeds is not implemented for MPT.')
161
  if self.training:
162
  if self.attn_uses_sequence_id and sequence_id is None:
163
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
164
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
165
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
 
 
 
 
166
  S = input_ids.size(1)
167
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
168
  tok_emb = self.wte(input_ids)
@@ -199,7 +208,27 @@ class MPTModel(MPTPreTrainedModel):
199
  assert all_hidden_states is not None
200
  all_hidden_states = all_hidden_states + (x,)
201
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
202
- (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  if past_key_values is not None:
204
  past_key_values[b_idx] = past_key_value
205
  if output_attentions:
@@ -227,8 +256,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
227
  super().__init__(config)
228
  if not config.tie_word_embeddings:
229
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
230
- print(f'Instantiating an MPTForCausalLM model from {__file__}')
231
- self.transformer: MPTModel = MPTModel(config)
232
  for child in self.transformer.children():
233
  if isinstance(child, torch.nn.ModuleList):
234
  continue
@@ -262,11 +290,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
262
  def get_decoder(self):
263
  return self.transformer
264
 
265
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None):
266
  return_dict = return_dict if return_dict is not None else self.config.return_dict
267
  use_cache = use_cache if use_cache is not None else self.config.use_cache
268
- if inputs_embeds is not None:
269
- raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
270
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
271
  logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
272
  if self.logit_scale is not None:
@@ -275,9 +301,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
275
  logits *= self.logit_scale
276
  loss = None
277
  if labels is not None:
278
- _labels = torch.roll(labels, shifts=-1)
279
- _labels[:, -1] = -100
280
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
281
  return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
282
 
283
  def param_init_fn(self, module):
@@ -320,4 +346,4 @@ class MPTForCausalLM(MPTPreTrainedModel):
320
  reordered_past = []
321
  for layer_past in past_key_values:
322
  reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
323
- return reordered_past
 
18
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
19
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
20
  from .meta_init_context import init_empty_weights
21
+ from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
22
  try:
23
  from .flash_attn_triton import flash_attn_func
24
  except:
 
30
  base_model_prefix = 'model'
31
  _no_split_modules = ['MPTBlock']
32
 
33
+ supports_gradient_checkpointing = True
34
+
35
+ def _set_gradient_checkpointing(self, module, value=False):
36
+ if isinstance(module, MPTModel):
37
+ module.gradient_checkpointing = value
38
+
39
  class MPTModel(MPTPreTrainedModel):
40
 
41
  def __init__(self, config: MPTConfig):
42
  config._validate_config()
43
  super().__init__(config)
44
+ self.gradient_checkpointing = False
45
  self.attn_impl = config.attn_config['attn_impl']
46
  self.prefix_lm = config.attn_config['prefix_lm']
47
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
 
87
  def get_input_embeddings(self):
88
  return self.wte
89
 
90
+ def set_input_embeddings(self, value):
91
  self.wte = value
92
 
93
  @torch.no_grad()
 
147
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
148
  return attn_bias
149
 
150
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
151
  return_dict = return_dict if return_dict is not None else self.config.return_dict
152
  use_cache = use_cache if use_cache is not None else self.config.use_cache
153
  if attention_mask is not None:
 
163
  raise NotImplementedError('MPT does not support training with left padding.')
164
  if self.prefix_lm and prefix_mask is None:
165
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
 
 
166
  if self.training:
167
  if self.attn_uses_sequence_id and sequence_id is None:
168
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
169
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
170
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
171
+ if self.gradient_checkpointing and self.training:
172
+ if use_cache:
173
+ use_cache = False
174
+
175
  S = input_ids.size(1)
176
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
177
  tok_emb = self.wte(input_ids)
 
208
  assert all_hidden_states is not None
209
  all_hidden_states = all_hidden_states + (x,)
210
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
211
+ if self.gradient_checkpointing and self.training:
212
+
213
+ def create_custom_forward(module):
214
+ def custom_forward(*inputs):
215
+ # None for past_key_value
216
+ return module(*inputs)
217
+
218
+ return custom_forward
219
+
220
+ (x, attn_weights, past_key_value) = torch.utils.checkpoint.checkpoint(
221
+ create_custom_forward(block),
222
+ x,
223
+ past_key_value,
224
+ attn_bias,
225
+ attention_mask,
226
+ self.is_causal,
227
+ )
228
+ else:
229
+ (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias,
230
+ attention_mask=attention_mask, is_causal=self.is_causal)
231
+
232
  if past_key_values is not None:
233
  past_key_values[b_idx] = past_key_value
234
  if output_attentions:
 
256
  super().__init__(config)
257
  if not config.tie_word_embeddings:
258
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
259
+ self.transformer = MPTModel(config)
 
260
  for child in self.transformer.children():
261
  if isinstance(child, torch.nn.ModuleList):
262
  continue
 
290
  def get_decoder(self):
291
  return self.transformer
292
 
293
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
294
  return_dict = return_dict if return_dict is not None else self.config.return_dict
295
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
296
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
297
  logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
298
  if self.logit_scale is not None:
 
301
  logits *= self.logit_scale
302
  loss = None
303
  if labels is not None:
304
+ labels = torch.roll(labels, shifts=-1)
305
+ labels[:, -1] = -100
306
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
307
  return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
308
 
309
  def param_init_fn(self, module):
 
346
  reordered_past = []
347
  for layer_past in past_key_values:
348
  reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
349
+ return reordered_past