moshel commited on
Commit
8855229
·
1 Parent(s): 163a85b
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -5,6 +5,11 @@ import torch
5
  import torchvision
6
 
7
  model = torch.load('v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt', map_location=torch.device('cpu'))
 
 
 
 
 
8
 
9
  import requests
10
  from PIL import Image
 
5
  import torchvision
6
 
7
  model = torch.load('v4-epoch=19-val_loss=0.6964-val_accuracy=0.8964.ckpt', map_location=torch.device('cpu'))
8
+ state_dict = checkpoint["state_dict"]
9
+ model_weights = state_dict
10
+ for key in list(model_weights):
11
+ model_weights[key.replace("backbone.", "")] = model_weights.pop(key)
12
+ model.load_state_dict(model_weights).eval()
13
 
14
  import requests
15
  from PIL import Image