amaye15 committed on
Commit
b0a61c5
·
verified ·
1 Parent(s): c8eff26

Update modeling_aimv2.py

Browse files
Files changed (1) hide show
  1. modeling_aimv2.py +20 -2
modeling_aimv2.py CHANGED
@@ -309,6 +309,12 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
309
  '''
310
 
311
 
 
 
 
 
 
 
312
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
313
  def __init__(self, config: AIMv2Config):
314
  super().__init__(config)
@@ -334,34 +340,46 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
334
  output_hidden_states: Optional[bool] = None,
335
  return_dict: Optional[bool] = None,
336
  ) -> Union[tuple, ImageClassifierOutput]:
337
-
 
 
 
 
338
  return_dict = (
339
  return_dict if return_dict is not None else self.config.use_return_dict
340
  )
 
341
 
 
342
  outputs = self.aimv2(
343
  pixel_values,
344
  mask=head_mask,
345
  output_hidden_states=output_hidden_states,
346
  return_dict=return_dict,
347
  )
348
-
349
  sequence_output = outputs[0]
 
350
 
 
351
  logits = self.classifier(sequence_output[:, 0, :])
 
352
 
353
  loss = None
354
  if labels is not None:
355
  labels = labels.to(logits.device)
 
356
 
357
  # Always use cross-entropy loss
358
  loss_fct = CrossEntropyLoss()
359
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
360
 
361
  if not return_dict:
362
  output = (logits,) + outputs[1:]
 
363
  return ((loss,) + output) if loss is not None else output
364
 
 
365
  return ImageClassifierOutput(
366
  loss=loss,
367
  logits=logits,
 
309
  '''
310
 
311
 
312
+ import logging
313
+
314
+ # Setup logging
315
+ logging.basicConfig(level=logging.DEBUG)
316
+ logger = logging.getLogger(__name__)
317
+
318
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
319
  def __init__(self, config: AIMv2Config):
320
  super().__init__(config)
 
340
  output_hidden_states: Optional[bool] = None,
341
  return_dict: Optional[bool] = None,
342
  ) -> Union[tuple, ImageClassifierOutput]:
343
+ logger.debug("Forward pass initiated")
344
+ logger.debug(f"Input pixel_values shape: {pixel_values.shape if pixel_values is not None else 'None'}")
345
+ logger.debug(f"Head mask provided: {head_mask is not None}")
346
+ logger.debug(f"Labels provided: {labels is not None}")
347
+
348
  return_dict = (
349
  return_dict if return_dict is not None else self.config.use_return_dict
350
  )
351
+ logger.debug(f"Using return_dict: {return_dict}")
352
 
353
+ # Call base model
354
  outputs = self.aimv2(
355
  pixel_values,
356
  mask=head_mask,
357
  output_hidden_states=output_hidden_states,
358
  return_dict=return_dict,
359
  )
 
360
  sequence_output = outputs[0]
361
+ logger.debug(f"Sequence output shape: {sequence_output.shape}")
362
 
363
+ # Classifier head
364
  logits = self.classifier(sequence_output[:, 0, :])
365
+ logger.debug(f"Logits shape: {logits.shape}")
366
 
367
  loss = None
368
  if labels is not None:
369
  labels = labels.to(logits.device)
370
+ logger.debug(f"Labels shape: {labels.shape}")
371
 
372
  # Always use cross-entropy loss
373
  loss_fct = CrossEntropyLoss()
374
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
375
+ logger.debug(f"Loss computed: {loss.item()}")
376
 
377
  if not return_dict:
378
  output = (logits,) + outputs[1:]
379
+ logger.debug("Returning as tuple")
380
  return ((loss,) + output) if loss is not None else output
381
 
382
+ logger.debug("Returning as ImageClassifierOutput")
383
  return ImageClassifierOutput(
384
  loss=loss,
385
  logits=logits,