|
import logging |
|
from dataclasses import asdict, dataclass |
|
from pathlib import Path |
|
|
|
from omegaconf import OmegaConf |
|
from rich.console import Console |
|
from rich.panel import Panel |
|
from rich.table import Table |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shared rich console that the printing helpers in this module write to.
console = Console()
|
|
|
|
|
def _make_stft_cfg(hop_length, win_length=None): |
|
if win_length is None: |
|
win_length = 4 * hop_length |
|
n_fft = 2 ** (win_length - 1).bit_length() |
|
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length) |
|
|
|
|
|
def _build_rich_table(rows, columns, title=None):
    """Lay out rows in a rich ``Table`` and wrap it in a non-expanding ``Panel``.

    Args:
        rows: Iterable of row iterables; every cell is converted with ``str``.
        columns: Iterable of column names; each one is capitalized for the header.
        title: Optional table title.

    Returns:
        The ``Panel`` containing the populated table.
    """
    grid = Table(title=title, header_style=None)

    for heading in columns:
        grid.add_column(heading.capitalize(), justify="left")

    for record in rows:
        cells = [str(cell) for cell in record]
        grid.add_row(*cells)

    return Panel(grid, expand=False)
|
|
|
|
|
def _rich_print_dict(d, title="Config", key="Key", value="Value"):
    """Print a mapping to the module console as a two-column table.

    Args:
        d: Mapping whose items become the table rows.
        title: Title shown above the table.
        key: Header for the first column.
        value: Header for the second column.
    """
    headers = [key, value]
    panel = _build_rich_table(d.items(), headers, title)
    console.print(panel)
|
|
|
|
|
@dataclass(frozen=True) |
|
class HParams: |
|
|
|
fg_dir: Path = Path("data/fg") |
|
bg_dir: Path = Path("data/bg") |
|
rir_dir: Path = Path("data/rir") |
|
load_fg_only: bool = False |
|
praat_augment_prob: float = 0 |
|
|
|
|
|
wav_rate: int = 44_100 |
|
n_fft: int = 2048 |
|
win_size: int = 2048 |
|
hop_size: int = 420 |
|
num_mels: int = 128 |
|
stft_magnitude_min: float = 1e-4 |
|
preemphasis: float = 0.97 |
|
mix_alpha_range: tuple[float, float] = (0.2, 0.8) |
|
|
|
|
|
nj: int = 64 |
|
training_seconds: float = 1.0 |
|
batch_size_per_gpu: int = 16 |
|
min_lr: float = 1e-5 |
|
max_lr: float = 1e-4 |
|
warmup_steps: int = 1000 |
|
max_steps: int = 1_000_000 |
|
gradient_clipping: float = 1.0 |
|
|
|
@property |
|
def deepspeed_config(self): |
|
return { |
|
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu, |
|
"optimizer": { |
|
"type": "Adam", |
|
"params": {"lr": float(self.min_lr)}, |
|
}, |
|
"scheduler": { |
|
"type": "WarmupDecayLR", |
|
"params": { |
|
"warmup_min_lr": float(self.min_lr), |
|
"warmup_max_lr": float(self.max_lr), |
|
"warmup_num_steps": self.warmup_steps, |
|
"total_num_steps": self.max_steps, |
|
"warmup_type": "linear", |
|
}, |
|
}, |
|
"gradient_clipping": self.gradient_clipping, |
|
} |
|
|
|
@property |
|
def stft_cfgs(self): |
|
assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}" |
|
return [_make_stft_cfg(h) for h in (100, 256, 512)] |
|
|
|
@classmethod |
|
def from_yaml(cls, path: Path) -> "HParams": |
|
logger.info(f"Reading hparams from {path}") |
|
|
|
return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path)))) |
|
|
|
def save_if_not_exists(self, run_dir: Path): |
|
path = run_dir / "hparams.yaml" |
|
if path.exists(): |
|
logger.info(f"{path} already exists, not saving") |
|
return |
|
path.parent.mkdir(parents=True, exist_ok=True) |
|
OmegaConf.save(asdict(self), str(path)) |
|
|
|
@classmethod |
|
def load(cls, run_dir, yaml: Path | None = None): |
|
hps = [] |
|
|
|
if (run_dir / "hparams.yaml").exists(): |
|
hps.append(cls.from_yaml(run_dir / "hparams.yaml")) |
|
|
|
if yaml is not None: |
|
hps.append(cls.from_yaml(yaml)) |
|
|
|
if len(hps) == 0: |
|
hps.append(cls()) |
|
|
|
for hp in hps[1:]: |
|
if hp != hps[0]: |
|
errors = {} |
|
for k, v in asdict(hp).items(): |
|
if getattr(hps[0], k) != v: |
|
errors[k] = f"{getattr(hps[0], k)} != {v}" |
|
raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}") |
|
|
|
return hps[0] |
|
|
|
def print(self): |
|
_rich_print_dict(asdict(self), title="HParams") |
|
|