albumentations
Browse files- app.py +4 -2
- requirements.txt +1 -0
app.py
CHANGED
@@ -5,7 +5,9 @@ import torch
|
|
5 |
import torchvision
|
6 |
import timm
|
7 |
|
8 |
-
|
|
|
|
|
9 |
state_dict = checkpoint["state_dict"]
|
10 |
model_weights = state_dict
|
11 |
for key in list(model_weights):
|
@@ -13,7 +15,7 @@ for key in list(model_weights):
|
|
13 |
|
14 |
|
15 |
def get_model():
|
16 |
-
model = timm.create_model('convnext_base.fb_in22k_ft_in1k', pretrained=
|
17 |
|
18 |
return model
|
19 |
|
|
|
5 |
import torchvision
|
6 |
import timm
|
7 |
|
8 |
+
print(timm.__version__)
|
9 |
+
|
10 |
+
checkpoint = torch.load('v5-epoch=19-val_loss=0.1464-val_accuracy=0.9514.ckpt', map_location=torch.device('cpu'))
|
11 |
state_dict = checkpoint["state_dict"]
|
12 |
model_weights = state_dict
|
13 |
for key in list(model_weights):
|
|
|
15 |
|
16 |
|
17 |
def get_model():
|
18 |
+
model = timm.create_model('convnext_base.fb_in22k_ft_in1k', pretrained=False, num_classes=2)
|
19 |
|
20 |
return model
|
21 |
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
torch
|
2 |
torchvision
|
3 |
timm
|
|
|
|
1 |
torch
|
2 |
torchvision
|
3 |
timm
|
4 |
+
albumentations
|