Reapply cekal/mpt-7b-peft-compatible (#4)
Browse files- Reapply cekal/mpt-7b-peft-compatible (e7704ff001c86e435514b8dce903799e53ab68d4)
Co-authored-by: K <kornfield@users.noreply.huggingface.co>
- modeling_mpt.py +71 -13
modeling_mpt.py
CHANGED
@@ -33,13 +33,19 @@ log = logging.getLogger(__name__)
|
|
33 |
class MPTPreTrainedModel(PreTrainedModel):
|
34 |
config_class = MPTConfig
|
35 |
base_model_prefix = 'model'
|
36 |
-
_no_split_modules = [
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
class MPTModel(MPTPreTrainedModel):
|
39 |
|
40 |
def __init__(self, config: MPTConfig):
|
41 |
config._validate_config()
|
42 |
super().__init__(config)
|
|
|
43 |
self.attn_impl = config.attn_config['attn_impl']
|
44 |
self.prefix_lm = config.attn_config['prefix_lm']
|
45 |
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
|
@@ -146,8 +152,37 @@ class MPTModel(MPTPreTrainedModel):
|
|
146 |
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
|
147 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
148 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
if attention_mask is not None:
|
150 |
attention_mask = attention_mask.bool()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
if prefix_mask is not None:
|
152 |
prefix_mask = prefix_mask.bool()
|
153 |
if not return_dict:
|
@@ -155,8 +190,8 @@ class MPTModel(MPTPreTrainedModel):
|
|
155 |
if output_attentions:
|
156 |
if self.attn_impl != 'torch':
|
157 |
raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
|
158 |
-
if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
|
159 |
-
|
160 |
if self.prefix_lm and prefix_mask is None:
|
161 |
raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
|
162 |
if inputs_embeds is not None:
|
@@ -166,7 +201,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
166 |
raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
|
167 |
elif self.attn_uses_sequence_id is False and sequence_id is not None:
|
168 |
warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
|
169 |
-
S =
|
170 |
assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
|
171 |
tok_emb = self.wte(input_ids)
|
172 |
if self.learned_pos_emb:
|
@@ -180,7 +215,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
180 |
if S + past_position > self.config.max_seq_len:
|
181 |
raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
|
182 |
pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
|
183 |
-
if attention_mask is not None:
|
184 |
pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
|
185 |
pos_emb = self.wpe(pos)
|
186 |
x = tok_emb + pos_emb
|
@@ -196,6 +231,7 @@ class MPTModel(MPTPreTrainedModel):
|
|
196 |
presents = () if use_cache else None
|
197 |
if use_cache and past_key_values is None:
|
198 |
past_key_values = [() for _ in range(self.config.n_layers)]
|
|
|
199 |
all_hidden_states = () if output_hidden_states else None
|
200 |
all_self_attns = () if output_attentions else None
|
201 |
for (b_idx, block) in enumerate(self.blocks):
|
@@ -203,12 +239,34 @@ class MPTModel(MPTPreTrainedModel):
|
|
203 |
assert all_hidden_states is not None
|
204 |
all_hidden_states = all_hidden_states + (x,)
|
205 |
past_key_value = past_key_values[b_idx] if past_key_values is not None else None
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
x = self.norm_f(x)
|
213 |
if output_hidden_states:
|
214 |
assert all_hidden_states is not None
|
@@ -271,7 +329,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
271 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
272 |
if inputs_embeds is not None:
|
273 |
raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
|
274 |
-
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
|
275 |
logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
|
276 |
if self.logit_scale is not None:
|
277 |
if self.logit_scale == 0:
|
@@ -324,4 +382,4 @@ class MPTForCausalLM(MPTPreTrainedModel):
|
|
324 |
reordered_past = []
|
325 |
for layer_past in past_key_values:
|
326 |
reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
|
327 |
-
return reordered_past
|
|
|
33 |
class MPTPreTrainedModel(PreTrainedModel):
|
34 |
config_class = MPTConfig
|
35 |
base_model_prefix = 'model'
|
36 |
+
_no_split_modules = ["MPTBlock"]
|
37 |
+
supports_gradient_checkpointing = True
|
38 |
+
|
39 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
40 |
+
if isinstance(module, MPTModel):
|
41 |
+
module.gradient_checkpointing = value
|
42 |
|
43 |
class MPTModel(MPTPreTrainedModel):
|
44 |
|
45 |
def __init__(self, config: MPTConfig):
|
46 |
config._validate_config()
|
47 |
super().__init__(config)
|
48 |
+
self.gradient_checkpointing = False
|
49 |
self.attn_impl = config.attn_config['attn_impl']
|
50 |
self.prefix_lm = config.attn_config['prefix_lm']
|
51 |
self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
|
|
|
152 |
def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
|
153 |
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
154 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
155 |
+
if self.gradient_checkpointing and self.training:
|
156 |
+
if use_cache:
|
157 |
+
use_cache = False
|
158 |
+
if input_ids is not None and inputs_embeds is not None:
|
159 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
160 |
+
elif input_ids is not None:
|
161 |
+
batch_size, seq_length = input_ids.shape
|
162 |
+
elif inputs_embeds is not None:
|
163 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
164 |
+
else:
|
165 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
166 |
+
|
167 |
+
seq_length_with_past = seq_length
|
168 |
+
past_key_values_length = 0
|
169 |
+
|
170 |
+
if past_key_values is not None:
|
171 |
+
past_key_values_length = past_key_values[0][0].shape[2]
|
172 |
+
seq_length_with_past = seq_length_with_past + past_key_values_length
|
173 |
+
|
174 |
if attention_mask is not None:
|
175 |
attention_mask = attention_mask.bool()
|
176 |
+
else:
|
177 |
+
attention_mask = torch.ones(
|
178 |
+
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
179 |
+
)
|
180 |
+
|
181 |
+
if inputs_embeds is None:
|
182 |
+
tok_emb = self.wte(input_ids)
|
183 |
+
else:
|
184 |
+
tok_emb = inputs_embeds
|
185 |
+
|
186 |
if prefix_mask is not None:
|
187 |
prefix_mask = prefix_mask.bool()
|
188 |
if not return_dict:
|
|
|
190 |
if output_attentions:
|
191 |
if self.attn_impl != 'torch':
|
192 |
raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
|
193 |
+
#if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
|
194 |
+
# raise NotImplementedError('MPT does not support training with left padding.')
|
195 |
if self.prefix_lm and prefix_mask is None:
|
196 |
raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
|
197 |
if inputs_embeds is not None:
|
|
|
201 |
raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
|
202 |
elif self.attn_uses_sequence_id is False and sequence_id is not None:
|
203 |
warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
|
204 |
+
S = seq_length
|
205 |
assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
|
206 |
tok_emb = self.wte(input_ids)
|
207 |
if self.learned_pos_emb:
|
|
|
215 |
if S + past_position > self.config.max_seq_len:
|
216 |
raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
|
217 |
pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
|
218 |
+
if attention_mask is not None and not self.training:
|
219 |
pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
|
220 |
pos_emb = self.wpe(pos)
|
221 |
x = tok_emb + pos_emb
|
|
|
231 |
presents = () if use_cache else None
|
232 |
if use_cache and past_key_values is None:
|
233 |
past_key_values = [() for _ in range(self.config.n_layers)]
|
234 |
+
|
235 |
all_hidden_states = () if output_hidden_states else None
|
236 |
all_self_attns = () if output_attentions else None
|
237 |
for (b_idx, block) in enumerate(self.blocks):
|
|
|
239 |
assert all_hidden_states is not None
|
240 |
all_hidden_states = all_hidden_states + (x,)
|
241 |
past_key_value = past_key_values[b_idx] if past_key_values is not None else None
|
242 |
+
if self.gradient_checkpointing and self.training:
|
243 |
+
|
244 |
+
def create_custom_forward(module):
|
245 |
+
def custom_forward(*inputs):
|
246 |
+
# None for past_key_value
|
247 |
+
return module(*inputs)
|
248 |
+
|
249 |
+
return custom_forward
|
250 |
+
|
251 |
+
(x, past_key_value) = torch.utils.checkpoint.checkpoint(
|
252 |
+
create_custom_forward(block),
|
253 |
+
x,
|
254 |
+
past_key_value,
|
255 |
+
attn_bias,
|
256 |
+
attention_mask,
|
257 |
+
self.is_causal,
|
258 |
+
)
|
259 |
+
if past_key_values is not None:
|
260 |
+
past_key_values[b_idx] = past_key_value
|
261 |
+
else:
|
262 |
+
(x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
|
263 |
+
if presents is not None:
|
264 |
+
presents += (present,)
|
265 |
+
if output_attentions:
|
266 |
+
assert all_self_attns is not None
|
267 |
+
all_self_attns = all_self_attns + (attn_weights,)
|
268 |
+
|
269 |
+
|
270 |
x = self.norm_f(x)
|
271 |
if output_hidden_states:
|
272 |
assert all_hidden_states is not None
|
|
|
329 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
330 |
if inputs_embeds is not None:
|
331 |
raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
|
332 |
+
outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache, inputs_embeds=inputs_embeds)
|
333 |
logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
|
334 |
if self.logit_scale is not None:
|
335 |
if self.logit_scale == 0:
|
|
|
382 |
reordered_past = []
|
383 |
for layer_past in past_key_values:
|
384 |
reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
|
385 |
+
return reordered_past
|