jadechoghari
commited on
Commit
•
fd1a078
1
Parent(s):
eb26dea
Update pipeline.py
Browse files- pipeline.py +3 -6
pipeline.py
CHANGED
@@ -7,7 +7,7 @@ import torch
|
|
7 |
|
8 |
class MOSDiffusionPipeline(DiffusionPipeline):
|
9 |
|
10 |
-
def __init__(self,
|
11 |
"""
|
12 |
Initialize the MOS Diffusion pipeline and download the necessary files/folders.
|
13 |
|
@@ -22,15 +22,12 @@ class MOSDiffusionPipeline(DiffusionPipeline):
|
|
22 |
|
23 |
self.base_folder = base_folder if base_folder else os.getcwd()
|
24 |
self.repo_id = "jadechoghari/qa-mdt"
|
25 |
-
self.config_yaml =
|
26 |
-
self.list_inference = list_inference
|
27 |
self.reload_from_ckpt = reload_from_ckpt
|
28 |
config_yaml_path = os.path.join(self.config_yaml)
|
29 |
self.configs = self.load_yaml(config_yaml_path)
|
30 |
if self.reload_from_ckpt is not None:
|
31 |
self.configs["reload_from_ckpt"] = self.reload_from_ckpt
|
32 |
-
|
33 |
-
self.dataset_key = build_dataset_json_from_list(self.list_inference)
|
34 |
self.exp_name = os.path.basename(self.config_yaml.split(".")[0])
|
35 |
self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))
|
36 |
|
@@ -95,5 +92,5 @@ class MOSDiffusionPipeline(DiffusionPipeline):
|
|
95 |
# Example of how to use the pipeline
|
96 |
if __name__ == "__main__":
|
97 |
pipe = MOSDiffusionPipeline()
|
98 |
-
result = pipe("
|
99 |
print(result)
|
|
|
7 |
|
8 |
class MOSDiffusionPipeline(DiffusionPipeline):
|
9 |
|
10 |
+
def __init__(self, reload_from_ckpt=None, base_folder=None):
|
11 |
"""
|
12 |
Initialize the MOS Diffusion pipeline and download the necessary files/folders.
|
13 |
|
|
|
22 |
|
23 |
self.base_folder = base_folder if base_folder else os.getcwd()
|
24 |
self.repo_id = "jadechoghari/qa-mdt"
|
25 |
+
self.config_yaml = "./qa_mdt/audioldm_train/config/mos_as_token/qa_mdt.yaml"
|
|
|
26 |
self.reload_from_ckpt = reload_from_ckpt
|
27 |
config_yaml_path = os.path.join(self.config_yaml)
|
28 |
self.configs = self.load_yaml(config_yaml_path)
|
29 |
if self.reload_from_ckpt is not None:
|
30 |
self.configs["reload_from_ckpt"] = self.reload_from_ckpt
|
|
|
|
|
31 |
self.exp_name = os.path.basename(self.config_yaml.split(".")[0])
|
32 |
self.exp_group_name = os.path.basename(os.path.dirname(self.config_yaml))
|
33 |
|
|
|
92 |
# Example of how to use the pipeline
|
93 |
if __name__ == "__main__":
|
94 |
pipe = MOSDiffusionPipeline()
|
95 |
+
result = pipe("A modern synthesizer creating futuristic soundscapes.")
|
96 |
print(result)
|