Spaces:
Runtime error
Runtime error
Update inference_2.py
Browse files- inference_2.py +3 -3
inference_2.py
CHANGED
@@ -10,7 +10,7 @@ from models import image
|
|
10 |
|
11 |
from onnx2pytorch import ConvertModel
|
12 |
|
13 |
-
onnx_model = onnx.load('
|
14 |
pytorch_model = ConvertModel(onnx_model)
|
15 |
|
16 |
torch.manual_seed(42)
|
@@ -65,14 +65,14 @@ def get_args(parser):
|
|
65 |
|
66 |
def load_img_modality_model(args):
|
67 |
rgb_encoder = pytorch_model
|
68 |
-
ckpt = torch.load('
|
69 |
rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict = True)
|
70 |
rgb_encoder.eval()
|
71 |
return rgb_encoder
|
72 |
|
73 |
def load_spec_modality_model(args):
|
74 |
spec_encoder = image.RawNet(args)
|
75 |
-
ckpt = torch.load('
|
76 |
spec_encoder.load_state_dict(ckpt['spec_encoder'], strict = True)
|
77 |
spec_encoder.eval()
|
78 |
return spec_encoder
|
|
|
10 |
|
11 |
from onnx2pytorch import ConvertModel
|
12 |
|
13 |
+
onnx_model = onnx.load('models/efficientnet.onnx')
|
14 |
pytorch_model = ConvertModel(onnx_model)
|
15 |
|
16 |
torch.manual_seed(42)
|
|
|
65 |
|
66 |
def load_img_modality_model(args):
|
67 |
rgb_encoder = pytorch_model
|
68 |
+
ckpt = torch.load('models/model.pth', map_location = torch.device('cpu'))
|
69 |
rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict = True)
|
70 |
rgb_encoder.eval()
|
71 |
return rgb_encoder
|
72 |
|
73 |
def load_spec_modality_model(args):
|
74 |
spec_encoder = image.RawNet(args)
|
75 |
+
ckpt = torch.load('models/model.pth', map_location = torch.device('cpu'))
|
76 |
spec_encoder.load_state_dict(ckpt['spec_encoder'], strict = True)
|
77 |
spec_encoder.eval()
|
78 |
return spec_encoder
|