nimocodes commited on
Commit
99fdb46
1 Parent(s): 4abab34

Update inference_2.py

Browse files
Files changed (1) hide show
  1. inference_2.py +3 -3
inference_2.py CHANGED
@@ -10,7 +10,7 @@ from models import image
10
 
11
  from onnx2pytorch import ConvertModel
12
 
13
- onnx_model = onnx.load('checkpoints/efficientnet.onnx')
14
  pytorch_model = ConvertModel(onnx_model)
15
 
16
  torch.manual_seed(42)
@@ -65,14 +65,14 @@ def get_args(parser):
65
 
66
  def load_img_modality_model(args):
67
  rgb_encoder = pytorch_model
68
- ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
69
  rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict = True)
70
  rgb_encoder.eval()
71
  return rgb_encoder
72
 
73
  def load_spec_modality_model(args):
74
  spec_encoder = image.RawNet(args)
75
- ckpt = torch.load('checkpoints/model.pth', map_location = torch.device('cpu'))
76
  spec_encoder.load_state_dict(ckpt['spec_encoder'], strict = True)
77
  spec_encoder.eval()
78
  return spec_encoder
 
10
 
11
  from onnx2pytorch import ConvertModel
12
 
13
+ onnx_model = onnx.load('models/efficientnet.onnx')
14
  pytorch_model = ConvertModel(onnx_model)
15
 
16
  torch.manual_seed(42)
 
65
 
66
  def load_img_modality_model(args):
67
  rgb_encoder = pytorch_model
68
+ ckpt = torch.load('models/model.pth', map_location = torch.device('cpu'))
69
  rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict = True)
70
  rgb_encoder.eval()
71
  return rgb_encoder
72
 
73
  def load_spec_modality_model(args):
74
  spec_encoder = image.RawNet(args)
75
+ ckpt = torch.load('models/model.pth', map_location = torch.device('cpu'))
76
  spec_encoder.load_state_dict(ckpt['spec_encoder'], strict = True)
77
  spec_encoder.eval()
78
  return spec_encoder