Update app.py
Browse files
app.py
CHANGED
@@ -102,11 +102,12 @@ def fine_tune_classification_model(train_loader):
|
|
102 |
|
103 |
# Update the classifier layer to match the number of labels
|
104 |
if hasattr(model, 'classifier'):
|
105 |
-
model.classifier
|
|
|
|
|
|
|
106 |
else:
|
107 |
-
|
108 |
-
model.classifier = torch.nn.Linear(model.config.num_labels, 3) # Update according to available layers
|
109 |
-
|
110 |
|
111 |
model.train()
|
112 |
|
|
|
102 |
|
103 |
# Update the classifier layer to match the number of labels
|
104 |
if hasattr(model, 'classifier'):
|
105 |
+
if isinstance(model.classifier, torch.nn.Sequential):
|
106 |
+
model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, 3) # Assuming 3 output classes
|
107 |
+
else:
|
108 |
+
model.classifier = torch.nn.Linear(model.classifier.in_features, 3) # In case it's a Linear layer directly
|
109 |
else:
|
110 |
+
print("Classifier layer not found")
|
|
|
|
|
111 |
|
112 |
model.train()
|
113 |
|