amaye15 commited on
Commit
cdfccc4
·
verified ·
1 Parent(s): c386892

Update modeling_aimv2.py

Browse files
Files changed (1) hide show
  1. modeling_aimv2.py +3 -18
modeling_aimv2.py CHANGED
@@ -222,7 +222,7 @@ class AIMv2Model(AIMv2PretrainedModel):
222
  hidden_states=hidden_states,
223
  )
224
 
225
-
226
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
227
  def __init__(self, config: AIMv2Config):
228
  super().__init__(config)
@@ -310,9 +310,9 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
310
  hidden_states=outputs.hidden_states,
311
  # attentions=outputs.attentions,
312
  )
 
313
 
314
 
315
- '''
316
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
317
  def __init__(self, config: AIMv2Config):
318
  super().__init__(config)
@@ -338,15 +338,10 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
338
  output_hidden_states: Optional[bool] = None,
339
  return_dict: Optional[bool] = None,
340
  ) -> Union[tuple, ImageClassifierOutput]:
341
- print("Forward pass initiated")
342
- print(f"Input pixel_values shape: {pixel_values.shape if pixel_values is not None else 'None'}")
343
- print(f"Head mask provided: {head_mask is not None}")
344
- print(f"Labels provided: {labels is not None}")
345
 
346
  return_dict = (
347
  return_dict if return_dict is not None else self.config.use_return_dict
348
  )
349
- print(f"Using return_dict: {return_dict}")
350
 
351
  # Call base model
352
  outputs = self.aimv2(
@@ -356,33 +351,23 @@ class AIMv2ForImageClassification(AIMv2PretrainedModel):
356
  return_dict=return_dict,
357
  )
358
  sequence_output = outputs[0]
359
- print(f"Sequence output shape: {sequence_output.shape}")
360
-
361
  # Classifier head
362
  logits = self.classifier(sequence_output[:, 0, :])
363
- print(f"Logits shape: {logits.shape}")
364
- print(f"Logits shape: {logits}")
365
 
366
  loss = None
367
  if labels is not None:
368
  labels = labels.to(logits.device)
369
- print(f"Labels shape: {labels.shape}")
370
- print(f"Labels shape: {labels}")
371
-
372
  # Always use cross-entropy loss
373
  loss_fct = CrossEntropyLoss()
374
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
375
- print(f"Loss computed: {loss.item()}")
376
 
377
  if not return_dict:
378
  output = (logits,) + outputs[1:]
379
- print("Returning as tuple")
380
  return ((loss,) + output) if loss is not None else output
381
 
382
- print("Returning as ImageClassifierOutput")
383
  return ImageClassifierOutput(
384
  loss=loss,
385
  logits=logits,
386
  hidden_states=outputs.hidden_states,
387
  )
388
- '''
 
222
  hidden_states=hidden_states,
223
  )
224
 
225
+ '''
226
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
227
  def __init__(self, config: AIMv2Config):
228
  super().__init__(config)
 
310
  hidden_states=outputs.hidden_states,
311
  # attentions=outputs.attentions,
312
  )
313
+ '''
314
 
315
 
 
316
  class AIMv2ForImageClassification(AIMv2PretrainedModel):
317
  def __init__(self, config: AIMv2Config):
318
  super().__init__(config)
 
338
  output_hidden_states: Optional[bool] = None,
339
  return_dict: Optional[bool] = None,
340
  ) -> Union[tuple, ImageClassifierOutput]:
 
 
 
 
341
 
342
  return_dict = (
343
  return_dict if return_dict is not None else self.config.use_return_dict
344
  )
 
345
 
346
  # Call base model
347
  outputs = self.aimv2(
 
351
  return_dict=return_dict,
352
  )
353
  sequence_output = outputs[0]
 
 
354
  # Classifier head
355
  logits = self.classifier(sequence_output[:, 0, :])
 
 
356
 
357
  loss = None
358
  if labels is not None:
359
  labels = labels.to(logits.device)
 
 
 
360
  # Always use cross-entropy loss
361
  loss_fct = CrossEntropyLoss()
362
  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
363
 
364
  if not return_dict:
365
  output = (logits,) + outputs[1:]
 
366
  return ((loss,) + output) if loss is not None else output
367
 
 
368
  return ImageClassifierOutput(
369
  loss=loss,
370
  logits=logits,
371
  hidden_states=outputs.hidden_states,
372
  )
373
+