Tanusree88 commited on
Commit
7ab14ef
·
verified ·
1 Parent(s): 6871a7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -102,11 +102,12 @@ def fine_tune_classification_model(train_loader):
102
 
103
  # Update the classifier layer to match the number of labels
104
  if hasattr(model, 'classifier'):
105
- model.classifier = torch.nn.Linear(model.classifier.in_features, 3) # Assuming 3 output classes
 
 
 
106
  else:
107
- # Access the linear layer differently if 'classifier' does not exist
108
- model.classifier = torch.nn.Linear(model.config.num_labels, 3) # Update according to available layers
109
-
110
 
111
  model.train()
112
 
 
102
 
103
  # Update the classifier layer to match the number of labels
104
  if hasattr(model, 'classifier'):
105
+ if isinstance(model.classifier, torch.nn.Sequential):
106
+ model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, 3) # Assuming 3 output classes
107
+ else:
108
+ model.classifier = torch.nn.Linear(model.classifier.in_features, 3) # In case it's a Linear layer directly
109
  else:
110
+ print("Classifier layer not found")
 
 
111
 
112
  model.train()
113