unpairedelectron07 commited on
Commit
70c420e
·
verified ·
1 Parent(s): eb6b37d

Upload 3 files

Browse files
audiocraft/environment.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Provides cluster and tools configuration across clusters (slurm, dora, utilities).
9
+ """
10
+
11
+ import logging
12
+ import os
13
+ from pathlib import Path
14
+ import re
15
+ import typing as tp
16
+
17
+ import omegaconf
18
+
19
+ from .utils.cluster import _guess_cluster_type
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class AudioCraftEnvironment:
26
+ """Environment configuration for teams and clusters.
27
+
28
+ AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
29
+ or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
30
+ provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
31
+ allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
32
+ map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
33
+
34
+ The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
35
+ Use the following environment variables to specify the cluster, team or configuration:
36
+
37
+ AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
38
+ cannot be inferred automatically.
39
+ AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
40
+ If not set, configuration is read from config/teams.yaml.
41
+ AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
42
+ Cluster configuration are shared across teams to match compute allocation,
43
+ specify your cluster configuration in the configuration file under a key mapping
44
+ your team name.
45
+ """
46
+ _instance = None
47
+ DEFAULT_TEAM = "default"
48
+
49
+ def __init__(self) -> None:
50
+ """Loads configuration."""
51
+ self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
52
+ cluster_type = _guess_cluster_type()
53
+ cluster = os.getenv(
54
+ "AUDIOCRAFT_CLUSTER", cluster_type.value
55
+ )
56
+ logger.info("Detecting cluster type %s", cluster_type)
57
+
58
+ self.cluster: str = cluster
59
+
60
+ config_path = os.getenv(
61
+ "AUDIOCRAFT_CONFIG",
62
+ Path(__file__)
63
+ .parent.parent.joinpath("config/teams", self.team)
64
+ .with_suffix(".yaml"),
65
+ )
66
+ self.config = omegaconf.OmegaConf.load(config_path)
67
+ self._dataset_mappers = []
68
+ cluster_config = self._get_cluster_config()
69
+ if "dataset_mappers" in cluster_config:
70
+ for pattern, repl in cluster_config["dataset_mappers"].items():
71
+ regex = re.compile(pattern)
72
+ self._dataset_mappers.append((regex, repl))
73
+
74
+ def _get_cluster_config(self) -> omegaconf.DictConfig:
75
+ assert isinstance(self.config, omegaconf.DictConfig)
76
+ return self.config[self.cluster]
77
+
78
+ @classmethod
79
+ def instance(cls):
80
+ if cls._instance is None:
81
+ cls._instance = cls()
82
+ return cls._instance
83
+
84
+ @classmethod
85
+ def reset(cls):
86
+ """Clears the environment and forces a reload on next invocation."""
87
+ cls._instance = None
88
+
89
+ @classmethod
90
+ def get_team(cls) -> str:
91
+ """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
92
+ If not defined, defaults to "labs".
93
+ """
94
+ return cls.instance().team
95
+
96
+ @classmethod
97
+ def get_cluster(cls) -> str:
98
+ """Gets the detected cluster.
99
+ This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
100
+ """
101
+ return cls.instance().cluster
102
+
103
+ @classmethod
104
+ def get_dora_dir(cls) -> Path:
105
+ """Gets the path to the dora directory for the current team and cluster.
106
+ Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
107
+ """
108
+ cluster_config = cls.instance()._get_cluster_config()
109
+ dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
110
+ logger.warning(f"Dora directory: {dora_dir}")
111
+ return Path(dora_dir)
112
+
113
+ @classmethod
114
+ def get_reference_dir(cls) -> Path:
115
+ """Gets the path to the reference directory for the current team and cluster.
116
+ Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
117
+ """
118
+ cluster_config = cls.instance()._get_cluster_config()
119
+ return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
120
+
121
+ @classmethod
122
+ def get_slurm_exclude(cls) -> tp.Optional[str]:
123
+ """Get the list of nodes to exclude for that cluster."""
124
+ cluster_config = cls.instance()._get_cluster_config()
125
+ return cluster_config.get("slurm_exclude")
126
+
127
+ @classmethod
128
+ def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
129
+ """Gets the requested partitions for the current team and cluster as a comma-separated string.
130
+
131
+ Args:
132
+ partition_types (list[str], optional): partition types to retrieve. Values must be
133
+ from ['global', 'team']. If not provided, the global partition is returned.
134
+ """
135
+ if not partition_types:
136
+ partition_types = ["global"]
137
+
138
+ cluster_config = cls.instance()._get_cluster_config()
139
+ partitions = [
140
+ cluster_config["partitions"][partition_type]
141
+ for partition_type in partition_types
142
+ ]
143
+ return ",".join(partitions)
144
+
145
+ @classmethod
146
+ def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
147
+ """Converts reference placeholder in path with configured reference dir to resolve paths.
148
+
149
+ Args:
150
+ path (str or Path): Path to resolve.
151
+ Returns:
152
+ Path: Resolved path.
153
+ """
154
+ path = str(path)
155
+
156
+ if path.startswith("//reference"):
157
+ reference_dir = cls.get_reference_dir()
158
+ logger.warn(f"Reference directory: {reference_dir}")
159
+ assert (
160
+ reference_dir.exists() and reference_dir.is_dir()
161
+ ), f"Reference directory does not exist: {reference_dir}."
162
+ path = re.sub("^//reference", str(reference_dir), path)
163
+
164
+ return Path(path)
165
+
166
+ @classmethod
167
+ def apply_dataset_mappers(cls, path: str) -> str:
168
+ """Applies dataset mapping regex rules as defined in the configuration.
169
+ If no rules are defined, the path is returned as-is.
170
+ """
171
+ instance = cls.instance()
172
+
173
+ for pattern, repl in instance._dataset_mappers:
174
+ path = pattern.sub(repl, path)
175
+
176
+ return path
audiocraft/py.typed ADDED
File without changes
audiocraft/train.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Entry point for dora to launch solvers for running training loops.
9
+ See more info on how to use dora: https://github.com/facebookresearch/dora
10
+ """
11
+
12
+ import logging
13
+ import multiprocessing
14
+ import os
15
+ from pathlib import Path
16
+ import sys
17
+ import typing as tp
18
+
19
+ from dora import git_save, hydra_main, XP
20
+ import flashy
21
+ import hydra
22
+ import omegaconf
23
+
24
+ from .environment import AudioCraftEnvironment
25
+ from .utils.cluster import get_slurm_parameters
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ def resolve_config_dset_paths(cfg):
31
+ """Enable Dora to load manifest from git clone repository."""
32
+ # manifest files for the different splits
33
+ for key, value in cfg.datasource.items():
34
+ if isinstance(value, str):
35
+ cfg.datasource[key] = git_save.to_absolute_path(value)
36
+
37
+
38
+ def get_solver(cfg):
39
+ from . import solvers
40
+ # Convert batch size to batch size for each GPU
41
+ assert cfg.dataset.batch_size % flashy.distrib.world_size() == 0
42
+ cfg.dataset.batch_size //= flashy.distrib.world_size()
43
+ for split in ['train', 'valid', 'evaluate', 'generate']:
44
+ if hasattr(cfg.dataset, split) and hasattr(cfg.dataset[split], 'batch_size'):
45
+ assert cfg.dataset[split].batch_size % flashy.distrib.world_size() == 0
46
+ cfg.dataset[split].batch_size //= flashy.distrib.world_size()
47
+ resolve_config_dset_paths(cfg)
48
+ solver = solvers.get_solver(cfg)
49
+ return solver
50
+
51
+
52
+ def get_solver_from_xp(xp: XP, override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
53
+ restore: bool = True, load_best: bool = True,
54
+ ignore_state_keys: tp.List[str] = [], disable_fsdp: bool = True):
55
+ """Given a XP, return the Solver object.
56
+
57
+ Args:
58
+ xp (XP): Dora experiment for which to retrieve the solver.
59
+ override_cfg (dict or None): If not None, should be a dict used to
60
+ override some values in the config of `xp`. This will not impact
61
+ the XP signature or folder. The format is different
62
+ than the one used in Dora grids, nested keys should actually be nested dicts,
63
+ not flattened, e.g. `{'optim': {'batch_size': 32}}`.
64
+ restore (bool): If `True` (the default), restore state from the last checkpoint.
65
+ load_best (bool): If `True` (the default), load the best state from the checkpoint.
66
+ ignore_state_keys (list[str]): List of sources to ignore when loading the state, e.g. `optimizer`.
67
+ disable_fsdp (bool): if True, disables FSDP entirely. This will
68
+ also automatically skip loading the EMA. For solver specific
69
+ state sources, like the optimizer, you might want to
70
+ use along `ignore_state_keys=['optimizer']`. Must be used with `load_best=True`.
71
+ """
72
+ logger.info(f"Loading solver from XP {xp.sig}. "
73
+ f"Overrides used: {xp.argv}")
74
+ cfg = xp.cfg
75
+ if override_cfg is not None:
76
+ cfg = omegaconf.OmegaConf.merge(cfg, omegaconf.DictConfig(override_cfg))
77
+ if disable_fsdp and cfg.fsdp.use:
78
+ cfg.fsdp.use = False
79
+ assert load_best is True
80
+ # ignoring some keys that were FSDP sharded like model, ema, and best_state.
81
+ # fsdp_best_state will be used in that case. When using a specific solver,
82
+ # one is responsible for adding the relevant keys, e.g. 'optimizer'.
83
+ # We could make something to automatically register those inside the solver, but that
84
+ # seem overkill at this point.
85
+ ignore_state_keys = ignore_state_keys + ['model', 'ema', 'best_state']
86
+
87
+ try:
88
+ with xp.enter():
89
+ solver = get_solver(cfg)
90
+ if restore:
91
+ solver.restore(load_best=load_best, ignore_state_keys=ignore_state_keys)
92
+ return solver
93
+ finally:
94
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
95
+
96
+
97
+ def get_solver_from_sig(sig: str, *args, **kwargs):
98
+ """Return Solver object from Dora signature, i.e. to play with it from a notebook.
99
+ See `get_solver_from_xp` for more information.
100
+ """
101
+ xp = main.get_xp_from_sig(sig)
102
+ return get_solver_from_xp(xp, *args, **kwargs)
103
+
104
+
105
+ def init_seed_and_system(cfg):
106
+ import numpy as np
107
+ import torch
108
+ import random
109
+ from audiocraft.modules.transformer import set_efficient_attention_backend
110
+
111
+ multiprocessing.set_start_method(cfg.mp_start_method)
112
+ logger.debug('Setting mp start method to %s', cfg.mp_start_method)
113
+ random.seed(cfg.seed)
114
+ np.random.seed(cfg.seed)
115
+ # torch also initialize cuda seed if available
116
+ torch.manual_seed(cfg.seed)
117
+ torch.set_num_threads(cfg.num_threads)
118
+ os.environ['MKL_NUM_THREADS'] = str(cfg.num_threads)
119
+ os.environ['OMP_NUM_THREADS'] = str(cfg.num_threads)
120
+ logger.debug('Setting num threads to %d', cfg.num_threads)
121
+ set_efficient_attention_backend(cfg.efficient_attention_backend)
122
+ logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
123
+ if 'SLURM_JOB_ID' in os.environ:
124
+ tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
125
+ if tmpdir.exists():
126
+ logger.info("Changing tmpdir to %s", tmpdir)
127
+ os.environ['TMPDIR'] = str(tmpdir)
128
+
129
+
130
+ @hydra_main(config_path='../config', config_name='config', version_base='1.1')
131
+ def main(cfg):
132
+ init_seed_and_system(cfg)
133
+
134
+ # Setup logging both to XP specific folder, and to stderr.
135
+ log_name = '%s.log.{rank}' % cfg.execute_only if cfg.execute_only else 'solver.log.{rank}'
136
+ flashy.setup_logging(level=str(cfg.logging.level).upper(), log_name=log_name)
137
+ # Initialize distributed training, no need to specify anything when using Dora.
138
+ flashy.distrib.init()
139
+ solver = get_solver(cfg)
140
+ if cfg.show:
141
+ solver.show()
142
+ return
143
+
144
+ if cfg.execute_only:
145
+ assert cfg.execute_inplace or cfg.continue_from is not None, \
146
+ "Please explicitly specify the checkpoint to continue from with continue_from=<sig_or_path> " + \
147
+ "when running with execute_only or set execute_inplace to True."
148
+ solver.restore(replay_metrics=False) # load checkpoint
149
+ solver.run_one_stage(cfg.execute_only)
150
+ return
151
+
152
+ return solver.run()
153
+
154
+
155
+ main.dora.dir = AudioCraftEnvironment.get_dora_dir()
156
+ main._base_cfg.slurm = get_slurm_parameters(main._base_cfg.slurm)
157
+
158
+ if main.dora.shared is not None and not os.access(main.dora.shared, os.R_OK):
159
+ print("No read permission on dora.shared folder, ignoring it.", file=sys.stderr)
160
+ main.dora.shared = None
161
+
162
+ if __name__ == '__main__':
163
+ main()