Update pipeline.py
Browse files- pipeline.py +1 -1
pipeline.py
CHANGED
@@ -35,7 +35,7 @@ class PreTrainedPipeline():
|
|
35 |
def __init__(self, path=""):
|
36 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
37 |
self.model = Generator().to(self.device)
|
38 |
-
self.model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
|
39 |
|
40 |
def __call__(self, inputs: str):
|
41 |
"""
|
|
|
35 |
def __init__(self, path=""):
|
36 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
37 |
self.model = Generator().to(self.device)
|
38 |
+
self.model.load_state_dict(torch.load("pytorch_model.bin", map_location=self.device))
|
39 |
|
40 |
def __call__(self, inputs: str):
|
41 |
"""
|