Update Predict.py
Browse files- Predict.py +1 -1
Predict.py
CHANGED
@@ -14,7 +14,7 @@ swap_space = {v : k for k, v in answer_space.items()}
|
|
14 |
|
15 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
model = BaseModel().to(device)
|
17 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
18 |
|
19 |
def generate_caption(image, question):
|
20 |
if isinstance(image, np.ndarray):
|
|
|
14 |
|
15 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
model = BaseModel().to(device)
|
17 |
+
model.load_state_dict(checkpoint['model_state_dict'], map_locaption = device)
|
18 |
|
19 |
def generate_caption(image, question):
|
20 |
if isinstance(image, np.ndarray):
|