Fine-tuning fails

#10
by maorcatmyheritage - opened

Fine-tuning fails with the following error: AttributeError: 'NoneType' object has no attribute 'get_seq_length'

I followed the original Idefics2 fine-tuning guide and just replaced Idefics2ForConditionalGeneration with Idefics3ForConditionalGeneration and the model name HuggingFaceM4/idefics2-8b with HuggingFaceM4/Idefics3-8B-Llama3.
I installed transformers from the following PR, as mentioned in the model card: pip install git+https://github.com/huggingface/transformers.git@refs/pull/32473/head
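For reference, here is a minimal sketch of the loading code with that swap applied; the class and checkpoint names come from the model card, everything else just mirrors the Idefics2 tutorial:

import torch
from transformers import AutoProcessor, Idefics3ForConditionalGeneration

model_id = "HuggingFaceM4/Idefics3-8B-Llama3"

processor = AutoProcessor.from_pretrained(model_id)
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # optional, needs flash-attn installed
)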

Installed packages:

torch                                    2.3.1
accelerate                               0.33.0
flash-attn                               2.6.3

Here is the full trace:

AttributeError                            Traceback (most recent call last)
Cell In[6], line 1
----> 1 trainer.train()

File /opt/conda/envs/llama/lib/python3.11/site-packages/transformers/trainer.py:1964, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1962         hf_hub_utils.enable_progress_bars()
   1963 else:
-> 1964     return inner_training_loop(
   1965         args=args,
   1966         resume_from_checkpoint=resume_from_checkpoint,
   1967         trial=trial,
   1968         ignore_keys_for_eval=ignore_keys_for_eval,
   1969     )

File /opt/conda/envs/llama/lib/python3.11/site-packages/transformers/trainer.py:2305, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2302     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   2304 with self.accelerator.accumulate(model):
-> 2305     tr_loss_step = self.training_step(model, inputs)
   2307 if (
   2308     args.logging_nan_inf_filter
   2309     and not is_torch_xla_available()
   2310     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2311 ):
   2312     # if loss is nan or inf simply add the average of previous logged losses
   2313     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/envs/llama/lib/python3.11/site-packages/transformers/trainer.py:3361, in Trainer.training_step(self, model, inputs)
   3358     return loss_mb.reduce_mean().detach().to(self.args.device)
   3360 with self.compute_loss_context_manager():
-> 3361     loss = self.compute_loss(model, inputs)
   3363 del inputs
   3364 if (
   3365     self.args.torch_empty_cache_steps is not None
   3366     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3367 ):

File /opt/conda/envs/llama/lib/python3.11/site-packages/transformers/trainer.py:3408, in Trainer.compute_loss(self, model, inputs, return_outputs)
   3406 else:
   3407     labels = None
-> 3408 outputs = model(**inputs)
   3409 # Save past state if it exists
   3410 # TODO: this needs to be fixed and made cleaner later.
   3411 if self.args.past_index >= 0:

File /opt/conda/envs/llama/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /opt/conda/envs/llama/lib/python3.11/site-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    818 def forward(*args, **kwargs):
--> 819     return model_forward(*args, **kwargs)

File /opt/conda/envs/llama/lib/python3.11/site-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    806 def __call__(self, *args, **kwargs):
--> 807     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /opt/conda/envs/llama/lib/python3.11/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File /opt/conda/envs/llama/lib/python3.11/site-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    818 def forward(*args, **kwargs):
--> 819     return model_forward(*args, **kwargs)

File /opt/conda/envs/llama/lib/python3.11/site-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    806 def __call__(self, *args, **kwargs):
--> 807     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /opt/conda/envs/llama/lib/python3.11/site-packages/torch/amp/autocast_mode.py:16, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     13 @functools.wraps(func)
     14 def decorate_autocast(*args, **kwargs):
     15     with autocast_instance:
---> 16         return func(*args, **kwargs)

File /opt/conda/envs/llama/lib/python3.11/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/llama/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:1179, in Idefics3ForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1176 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1178 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1179 outputs = self.model(
   1180     input_ids=input_ids,
   1181     attention_mask=attention_mask,
   1182     position_ids=position_ids,
   1183     past_key_values=past_key_values,
   1184     inputs_embeds=inputs_embeds,
   1185     pixel_values=pixel_values,
   1186     pixel_attention_mask=pixel_attention_mask,
   1187     image_hidden_states=image_hidden_states,
   1188     use_cache=use_cache,
   1189     output_attentions=output_attentions,
   1190     output_hidden_states=output_hidden_states,
   1191     return_dict=return_dict,
   1192 )
   1194 hidden_states = outputs[0]
   1195 logits = self.lm_head(hidden_states)

File /opt/conda/envs/llama/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/llama/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /opt/conda/envs/llama/lib/python3.11/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/llama/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:953, in Idefics3Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, use_cache, output_attentions, output_hidden_states, return_dict)
    951 past_seen_tokens = 0
    952 if use_cache:
--> 953     past_seen_tokens = past_key_values.get_seq_length()
    955 if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
    956     raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")

AttributeError: 'NoneType' object has no attribute 'get_seq_length'

What am I doing wrong?

FYI, there is a new tutorial that should work: https://x.com/mervenoyann/status/1821605881815147004
Maybe you can try that one?

Thanks @HugoLaurencon, I'll give it a try.

It looks like the PR branch broke at some point. Reverting to commit e1b7c0a05ab65e4ddb62a407fe12f8ec13a916f0 solved this specific issue, as reported here: https://github.com/merveenoyan/smol-vision/issues/6
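For anyone hitting the same thing, pinning that commit should look something like this (assuming pip can fetch the bare SHA from GitHub): pip install git+https://github.com/huggingface/transformers.git@e1b7c0a05ab65e4ddb62a407fe12f8ec13a916f0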

Now I'm facing an OOM error when running on an AWS A10G machine (24 GB of GPU memory). Fine-tuning Idefics2 used to work fine on this machine.

Here is the full trace:

OutOfMemoryError                          Traceback (most recent call last)
Cell In[9], line 1
----> 1 trainer.train()

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/transformers/trainer.py:1948, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1946         hf_hub_utils.enable_progress_bars()
   1947 else:
-> 1948     return inner_training_loop(
   1949         args=args,
   1950         resume_from_checkpoint=resume_from_checkpoint,
   1951         trial=trial,
   1952         ignore_keys_for_eval=ignore_keys_for_eval,
   1953     )

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/transformers/trainer.py:2289, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2286     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
   2288 with self.accelerator.accumulate(model):
-> 2289     tr_loss_step = self.training_step(model, inputs)
   2291 if (
   2292     args.logging_nan_inf_filter
   2293     and not is_torch_xla_available()
   2294     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2295 ):
   2296     # if loss is nan or inf simply add the average of previous logged losses
   2297     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/transformers/trainer.py:3328, in Trainer.training_step(self, model, inputs)
   3325     return loss_mb.reduce_mean().detach().to(self.args.device)
   3327 with self.compute_loss_context_manager():
-> 3328     loss = self.compute_loss(model, inputs)
   3330 del inputs
   3331 if (
   3332     self.args.torch_empty_cache_steps is not None
   3333     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3334 ):

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/transformers/trainer.py:3373, in Trainer.compute_loss(self, model, inputs, return_outputs)
   3371 else:
   3372     labels = None
-> 3373 outputs = model(**inputs)
   3374 # Save past state if it exists
   3375 # TODO: this needs to be fixed and made cleaner later.
   3376 if self.args.past_index >= 0:

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    818 def forward(*args, **kwargs):
--> 819     return model_forward(*args, **kwargs)

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    806 def __call__(self, *args, **kwargs):
--> 807     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/torch/amp/autocast_mode.py:43, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     40 @functools.wraps(func)
     41 def decorate_autocast(*args, **kwargs):
     42     with autocast_instance:
---> 43         return func(*args, **kwargs)

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/transformers/models/idefics3/modeling_idefics3.py:1205, in Idefics3ForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1203     # Flatten the tokens
   1204     loss_fct = CrossEntropyLoss()
-> 1205     loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
   1207 if not return_dict:
   1208     output = (logits,) + outputs[1:]

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/torch/nn/modules/loss.py:1188, in CrossEntropyLoss.forward(self, input, target)
   1187 def forward(self, input: Tensor, target: Tensor) -> Tensor:
-> 1188     return F.cross_entropy(input, target, weight=self.weight,
   1189                            ignore_index=self.ignore_index, reduction=self.reduction,
   1190                            label_smoothing=self.label_smoothing)

File /opt/conda/envs/idefics3/lib/python3.12/site-packages/torch/nn/functional.py:3104, in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3102 if size_average is not None or reduce is not None:
   3103     reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3104 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)

OutOfMemoryError: CUDA out of memory. Tried to allocate 446.00 MiB. GPU 0 has a total capacity of 22.03 GiB of which 224.88 MiB is free. Including non-PyTorch memory, this process has 21.81 GiB memory in use. Of the allocated memory 19.53 GiB is allocated by PyTorch, and 609.11 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Solved it by lowering the value of longest_edge to 364:

processor = AutoProcessor.from_pretrained(
    model_id, size={"longest_edge": 364}
)

Then I switched from 8-bit to 4-bit quantization and managed to increase the size to 728 (2 * 364):


processor = AutoProcessor.from_pretrained(
    model_id, size={"longest_edge": 2 * 364}
)

...
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)
...
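To make the elided part explicit, here is a hedged sketch of how such a config is typically passed when loading the model (quantization_config is the standard transformers argument; the LoRA/Trainer setup from the tutorial is omitted, and bitsandbytes must be installed):

import torch
from transformers import BitsAndBytesConfig, Idefics3ForConditionalGeneration

model_id = "HuggingFaceM4/Idefics3-8B-Llama3"

# 4-bit NF4 quantization with bf16 compute, as in the snippet above
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)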

Nice! I was going to suggest reducing the value of longest_edge, yes.
It really depends on your application: if you don't need strong OCR performance, then 2 * 364 can be OK, especially if the model is fully fine-tuned at that resolution.
364 is a bit low, though, so I would suggest trying both approaches to be sure.

Actually, the main task is OCR, so I'm going to try a stronger machine and then use a higher resolution. Thanks @HugoLaurencon for your very helpful insights!

maorcatmyheritage changed discussion status to closed
