yixin1121 commited on
Commit
85af14c
1 Parent(s): 544e91d

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. Infer.py +1 -1
Infer.py CHANGED
@@ -29,7 +29,7 @@ def get_mask(lengths, max_length):
29
 
30
  class Infer:
31
  def __init__(self, device):
32
- pretrained_ckpt = torch.load("ckpts/model.pth")
33
  args = pretrained_ckpt['args']
34
  args.n_ans = 2
35
  args.max_tokens = 256
 
29
 
30
  class Infer:
31
  def __init__(self, device):
32
+ pretrained_ckpt = torch.load("ckpts/model.pth", map_location="cpu")
33
  args = pretrained_ckpt['args']
34
  args.n_ans = 2
35
  args.max_tokens = 256