windy2612 commited on
Commit
243445e
1 Parent(s): 97d34bc

Update Predict.py

Browse files
Files changed (1) hide show
  1. Predict.py +1 -1
Predict.py CHANGED
@@ -14,7 +14,7 @@ swap_space = {v : k for k, v in answer_space.items()}
14
 
15
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
  model = BaseModel().to(device)
17
- model.load_state_dict(checkpoint['model_state_dict'])
18
 
19
  def generate_caption(image, question):
20
  if isinstance(image, np.ndarray):
 
14
 
15
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
  model = BaseModel().to(device)
17
+ model.load_state_dict(checkpoint['model_state_dict'], map_locaption = device)
18
 
19
  def generate_caption(image, question):
20
  if isinstance(image, np.ndarray):