Upload folder using huggingface_hub
Browse files
Infer.py
CHANGED
@@ -29,7 +29,7 @@ def get_mask(lengths, max_length):
|
|
29 |
|
30 |
class Infer:
|
31 |
def __init__(self, device):
|
32 |
-
pretrained_ckpt = torch.load("ckpts/model.pth")
|
33 |
args = pretrained_ckpt['args']
|
34 |
args.n_ans = 2
|
35 |
args.max_tokens = 256
|
|
|
29 |
|
30 |
class Infer:
|
31 |
def __init__(self, device):
|
32 |
+
pretrained_ckpt = torch.load("ckpts/model.pth", map_location="cpu")
|
33 |
args = pretrained_ckpt['args']
|
34 |
args.n_ans = 2
|
35 |
args.max_tokens = 256
|