import logging
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Union

from omegaconf import OmegaConf
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

logger = logging.getLogger(__name__)
console = Console()
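

# Build one STFT configuration from a hop length: win_length defaults to 4x the
# hop, and n_fft is rounded up to the smallest power of two >= win_length.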
def _make_stft_cfg(hop_length, win_length=None):
    if win_length is None:
        win_length = 4 * hop_length
    n_fft = 2 ** (win_length - 1).bit_length()
    return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
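

# Render (rows, columns) as a rich Table and wrap it in a Panel for printing.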
def _build_rich_table(rows, columns, title=None):
    table = Table(title=title, header_style=None)
    for column in columns:
        table.add_column(column.capitalize(), justify="left")
    for row in rows:
        table.add_row(*map(str, row))
    return Panel(table, expand=False)


def _rich_print_dict(d, title="Config", key="Key", value="Value"):
    console.print(_build_rich_table(d.items(), [key, value], title))
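

# Frozen dataclass holding all hyperparameters; the defaults below can be
# overridden from a YAML file via OmegaConf (see from_yaml / load).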
@dataclass(frozen=True)
class HParams:
    # Dataset
    fg_dir: Path = Path("data/fg")
    bg_dir: Path = Path("data/bg")
    rir_dir: Path = Path("data/rir")
    load_fg_only: bool = False
    praat_augment_prob: float = 0

    # Audio settings
    wav_rate: int = 44_100
    n_fft: int = 2048
    win_size: int = 2048
    hop_size: int = 420  # 9.5ms
    num_mels: int = 128
    stft_magnitude_min: float = 1e-4
    preemphasis: float = 0.97
    mix_alpha_range: tuple[float, float] = (0.2, 0.8)

    # Training
    nj: int = 64
    training_seconds: float = 1.0
    batch_size_per_gpu: int = 16
    min_lr: float = 1e-5
    max_lr: float = 1e-4
    warmup_steps: int = 1000
    max_steps: int = 1_000_000
    gradient_clipping: float = 1.0
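
    # DeepSpeed engine config assembled from the fields above: Adam optimizer,
    # WarmupDecayLR schedule with linear warmup, and gradient clipping.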
    @property
    def deepspeed_config(self):
        return {
            "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
            "optimizer": {
                "type": "Adam",
                "params": {"lr": float(self.min_lr)},
            },
            "scheduler": {
                "type": "WarmupDecayLR",
                "params": {
                    "warmup_min_lr": float(self.min_lr),
                    "warmup_max_lr": float(self.max_lr),
                    "warmup_num_steps": self.warmup_steps,
                    "total_num_steps": self.max_steps,
                    "warmup_type": "linear",
                },
            },
            "gradient_clipping": self.gradient_clipping,
        }
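
    # Multi-resolution STFT configs at hop lengths of 100/256/512 samples; the
    # derivation assumes a 44.1 kHz sample rate, hence the assert below.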
    @property
    def stft_cfgs(self):
        assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
        return [_make_stft_cfg(h) for h in (100, 256, 512)]
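
    # Construct an HParams instance from a YAML file by merging it onto the
    # defaults, letting OmegaConf coerce value types (e.g. str -> Path).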
    @classmethod
    def from_yaml(cls, path: Path) -> "HParams":
        logger.info(f"Reading hparams from {path}")
        # First merge to fix types (e.g., str -> Path)
        return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
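
    # Persist the current hparams to <run_dir>/hparams.yaml unless one already exists.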
    def save_if_not_exists(self, run_dir: Path):
        path = run_dir / "hparams.yaml"
        if path.exists():
            logger.info(f"{path} already exists, not saving")
            return
        path.parent.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(asdict(self), str(path))
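
    # Resolve hparams for a run: read <run_dir>/hparams.yaml and/or an explicit
    # YAML path, fall back to defaults, and raise if the sources disagree.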
    @classmethod
    def load(cls, run_dir, yaml: Union[Path, None] = None):
        hps = []

        if (run_dir / "hparams.yaml").exists():
            hps.append(cls.from_yaml(run_dir / "hparams.yaml"))

        if yaml is not None:
            hps.append(cls.from_yaml(yaml))

        if len(hps) == 0:
            hps.append(cls())

        for hp in hps[1:]:
            if hp != hps[0]:
                errors = {}
                for k, v in asdict(hp).items():
                    if getattr(hps[0], k) != v:
                        errors[k] = f"{getattr(hps[0], k)} != {v}"
                raise ValueError(
                    f"Found inconsistent hparams: {errors}, consider deleting {run_dir}"
                )

        return hps[0]

    def print(self):
        _rich_print_dict(asdict(self), title="HParams")
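

# Example usage (illustrative sketch; "runs/my_run" is a hypothetical run directory):
#
#   hp = HParams.load(Path("runs/my_run"))
#   hp.save_if_not_exists(Path("runs/my_run"))
#   hp.print()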