jadechoghari
commited on
Commit
•
9b35b7c
1
Parent(s):
2a97bc5
Update pipeline.py
Browse files- pipeline.py +2 -3
pipeline.py
CHANGED
@@ -7,7 +7,7 @@ import torch
|
|
7 |
|
8 |
class MOSDiffusionPipeline(DiffusionPipeline):
|
9 |
|
10 |
-
def __init__(self, reload_from_ckpt=
|
11 |
"""
|
12 |
Initialize the MOS Diffusion pipeline and download the necessary files/folders.
|
13 |
|
@@ -26,8 +26,7 @@ class MOSDiffusionPipeline(DiffusionPipeline):
|
|
26 |
self.reload_from_ckpt = reload_from_ckpt
|
27 |
config_yaml_path = os.path.join(self.config_yaml)
|
28 |
self.configs = self.load_yaml(config_yaml_path)
|
29 |
-
|
30 |
-
self.configs["reload_from_ckpt"] = self.reload_from_ckpt
|
31 |
self.exp_name = os.path.basename(self.config_yaml.split(".")[0])
|
32 |
self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))
|
33 |
|
|
|
7 |
|
8 |
class MOSDiffusionPipeline(DiffusionPipeline):
|
9 |
|
10 |
+
def __init__(self, reload_from_ckpt="./qa-mdt/checkpoint_389999.ckpt", base_folder=None):
|
11 |
"""
|
12 |
Initialize the MOS Diffusion pipeline and download the necessary files/folders.
|
13 |
|
|
|
26 |
self.reload_from_ckpt = reload_from_ckpt
|
27 |
config_yaml_path = os.path.join(self.config_yaml)
|
28 |
self.configs = self.load_yaml(config_yaml_path)
|
29 |
+
self.configs["reload_from_ckpt"] = self.reload_from_ckpt
|
|
|
30 |
self.exp_name = os.path.basename(self.config_yaml.split(".")[0])
|
31 |
self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))
|
32 |
|