Wendyellé Abubakrh Alban NYANTUDRE commited on
Commit
88b5dc0
1 Parent(s): c49c7f5

finally deleted .git file from resemble-enhance

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. resemble-enhance/README.md +95 -0
  2. resemble-enhance/config/denoiser.yaml +2 -0
  3. resemble-enhance/config/enhancer_stage1.yaml +4 -0
  4. resemble-enhance/config/enhancer_stage2.yaml +8 -0
  5. resemble-enhance/resemble_enhance/__init__.py +0 -0
  6. resemble-enhance/resemble_enhance/common.py +55 -0
  7. resemble-enhance/resemble_enhance/data/__init__.py +48 -0
  8. resemble-enhance/resemble_enhance/data/dataset.py +171 -0
  9. resemble-enhance/resemble_enhance/data/distorter/__init__.py +1 -0
  10. resemble-enhance/resemble_enhance/data/distorter/base.py +104 -0
  11. resemble-enhance/resemble_enhance/data/distorter/custom.py +85 -0
  12. resemble-enhance/resemble_enhance/data/distorter/distorter.py +32 -0
  13. resemble-enhance/resemble_enhance/data/distorter/sox.py +176 -0
  14. resemble-enhance/resemble_enhance/data/utils.py +43 -0
  15. resemble-enhance/resemble_enhance/denoiser/__init__.py +0 -0
  16. resemble-enhance/resemble_enhance/denoiser/__main__.py +30 -0
  17. resemble-enhance/resemble_enhance/denoiser/denoiser.py +181 -0
  18. resemble-enhance/resemble_enhance/denoiser/hparams.py +9 -0
  19. resemble-enhance/resemble_enhance/denoiser/inference.py +29 -0
  20. resemble-enhance/resemble_enhance/denoiser/train.py +112 -0
  21. resemble-enhance/resemble_enhance/denoiser/unet.py +144 -0
  22. resemble-enhance/resemble_enhance/enhancer/__init__.py +0 -0
  23. resemble-enhance/resemble_enhance/enhancer/__main__.py +129 -0
  24. resemble-enhance/resemble_enhance/enhancer/download.py +30 -0
  25. resemble-enhance/resemble_enhance/enhancer/enhancer.py +195 -0
  26. resemble-enhance/resemble_enhance/enhancer/hparams.py +23 -0
  27. resemble-enhance/resemble_enhance/enhancer/inference.py +41 -0
  28. resemble-enhance/resemble_enhance/enhancer/lcfm/__init__.py +2 -0
  29. resemble-enhance/resemble_enhance/enhancer/lcfm/cfm.py +372 -0
  30. resemble-enhance/resemble_enhance/enhancer/lcfm/irmae.py +123 -0
  31. resemble-enhance/resemble_enhance/enhancer/lcfm/lcfm.py +152 -0
  32. resemble-enhance/resemble_enhance/enhancer/lcfm/wn.py +147 -0
  33. resemble-enhance/resemble_enhance/enhancer/train.py +143 -0
  34. resemble-enhance/resemble_enhance/enhancer/univnet/__init__.py +1 -0
  35. resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py +5 -0
  36. resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py +95 -0
  37. resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py +49 -0
  38. resemble-enhance/resemble_enhance/enhancer/univnet/amp.py +101 -0
  39. resemble-enhance/resemble_enhance/enhancer/univnet/discriminator.py +210 -0
  40. resemble-enhance/resemble_enhance/enhancer/univnet/lvcnet.py +281 -0
  41. resemble-enhance/resemble_enhance/enhancer/univnet/mrstft.py +128 -0
  42. resemble-enhance/resemble_enhance/enhancer/univnet/univnet.py +94 -0
  43. resemble-enhance/resemble_enhance/hparams.py +128 -0
  44. resemble-enhance/resemble_enhance/inference.py +163 -0
  45. resemble-enhance/resemble_enhance/melspec.py +61 -0
  46. resemble-enhance/resemble_enhance/utils/__init__.py +5 -0
  47. resemble-enhance/resemble_enhance/utils/control.py +26 -0
  48. resemble-enhance/resemble_enhance/utils/distributed.py +96 -0
  49. resemble-enhance/resemble_enhance/utils/engine.py +145 -0
  50. resemble-enhance/resemble_enhance/utils/logging.py +38 -0
resemble-enhance/README.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Resemble Enhance
2
+
3
+ [![PyPI](https://img.shields.io/pypi/v/resemble-enhance.svg)](https://pypi.org/project/resemble-enhance/)
4
+ [![Hugging Face Space](https://img.shields.io/badge/Hugging%20Face%20%F0%9F%A4%97-Space-yellow)](https://huggingface.co/spaces/ResembleAI/resemble-enhance)
5
+ [![License](https://img.shields.io/github/license/resemble-ai/Resemble-Enhance.svg)](https://github.com/resemble-ai/resemble-enhance/blob/main/LICENSE)
6
+ [![Webpage](https://img.shields.io/badge/Webpage-Online-brightgreen)](https://www.resemble.ai/enhance/)
7
+
8
+ https://github.com/resemble-ai/resemble-enhance/assets/660224/bc3ec943-e795-4646-b119-cce327c810f1
9
+
10
+ Resemble Enhance is an AI-powered tool that aims to improve the overall quality of speech by performing denoising and enhancement. It consists of two modules: a denoiser, which separates speech from a noisy audio, and an enhancer, which further boosts the perceptual audio quality by restoring audio distortions and extending the audio bandwidth. The two models are trained on high-quality 44.1kHz speech data that guarantees the enhancement of your speech with high quality.
11
+
12
+ ## Usage
13
+
14
+ ### Installation
15
+
16
+ Install the stable version:
17
+
18
+ ```bash
19
+ pip install resemble-enhance --upgrade
20
+ ```
21
+
22
+ Or try the latest pre-release version:
23
+
24
+ ```bash
25
+ pip install resemble-enhance --upgrade --pre
26
+ ```
27
+
28
+ ### Enhance
29
+
30
+ ```
31
+ resemble_enhance in_dir out_dir
32
+ ```
33
+
34
+ ### Denoise only
35
+
36
+ ```
37
+ resemble_enhance in_dir out_dir --denoise_only
38
+ ```
39
+
40
+ ### Web Demo
41
+
42
+ We provide a web demo built with Gradio, you can try it out [here](https://huggingface.co/spaces/ResembleAI/resemble-enhance), or also run it locally:
43
+
44
+ ```
45
+ python app.py
46
+ ```
47
+
48
+ ## Train your own model
49
+
50
+ ### Data Preparation
51
+
52
+ You need to prepare a foreground speech dataset and a background non-speech dataset. In addition, you need to prepare a RIR dataset ([examples](https://github.com/RoyJames/room-impulse-responses)).
53
+
54
+ ```bash
55
+ data
56
+ ├── fg
57
+ │   ├── 00001.wav
58
+ │   └── ...
59
+ ├── bg
60
+ │   ├── 00001.wav
61
+ │   └── ...
62
+ └── rir
63
+    ├── 00001.npy
64
+    └── ...
65
+ ```
66
+
67
+ ### Training
68
+
69
+ #### Denoiser Warmup
70
+
71
+ Though the denoiser is trained jointly with the enhancer, it is recommended for a warmup training first.
72
+
73
+ ```bash
74
+ python -m resemble_enhance.denoiser.train --yaml config/denoiser.yaml runs/denoiser
75
+ ```
76
+
77
+ #### Enhancer
78
+
79
+ Then, you can train the enhancer in two stages. The first stage is to train the autoencoder and vocoder. And the second stage is to train the latent conditional flow matching (CFM) model.
80
+
81
+ ##### Stage 1
82
+
83
+ ```bash
84
+ python -m resemble_enhance.enhancer.train --yaml config/enhancer_stage1.yaml runs/enhancer_stage1
85
+ ```
86
+
87
+ ##### Stage 2
88
+
89
+ ```bash
90
+ python -m resemble_enhance.enhancer.train --yaml config/enhancer_stage2.yaml runs/enhancer_stage2
91
+ ```
92
+
93
+ ## Blog
94
+
95
+ Learn more on our [website](https://www.resemble.ai/enhance/)!
resemble-enhance/config/denoiser.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ batch_size_per_gpu: 32
2
+ training_seconds: 3.0
resemble-enhance/config/enhancer_stage1.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ lcfm_training_mode: ae
2
+ load_fg_only: true
3
+ batch_size_per_gpu: 16
4
+ denoiser_run_dir: runs/denoiser
resemble-enhance/config/enhancer_stage2.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ lcfm_training_mode: cfm
2
+ batch_size_per_gpu: 32
3
+ training_seconds: 3.0
4
+ gan_training_start_step: null
5
+ lcfm_z_scale: 6
6
+ praat_augment_prob: 0.2
7
+ denoiser_run_dir: runs/denoiser
8
+ enhancer_stage1_run_dir: runs/enhancer_stage1
resemble-enhance/resemble_enhance/__init__.py ADDED
File without changes
resemble-enhance/resemble_enhance/common.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
+ class Normalizer(nn.Module):
10
+ def __init__(self, momentum=0.01, eps=1e-9):
11
+ super().__init__()
12
+ self.momentum = momentum
13
+ self.eps = eps
14
+ self.running_mean_unsafe: Tensor
15
+ self.running_var_unsafe: Tensor
16
+ self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
17
+ self.register_buffer("running_var_unsafe", torch.full([], torch.nan))
18
+
19
+ @property
20
+ def started(self):
21
+ return not torch.isnan(self.running_mean_unsafe)
22
+
23
+ @property
24
+ def running_mean(self):
25
+ if not self.started:
26
+ return torch.zeros_like(self.running_mean_unsafe)
27
+ return self.running_mean_unsafe
28
+
29
+ @property
30
+ def running_std(self):
31
+ if not self.started:
32
+ return torch.ones_like(self.running_var_unsafe)
33
+ return (self.running_var_unsafe + self.eps).sqrt()
34
+
35
+ @torch.no_grad()
36
+ def _ema(self, a: Tensor, x: Tensor):
37
+ return (1 - self.momentum) * a + self.momentum * x
38
+
39
+ def update_(self, x):
40
+ if not self.started:
41
+ self.running_mean_unsafe = x.mean()
42
+ self.running_var_unsafe = x.var()
43
+ else:
44
+ self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
45
+ self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())
46
+
47
+ def forward(self, x: Tensor, update=True):
48
+ if self.training and update:
49
+ self.update_(x)
50
+ self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
51
+ x = (x - self.running_mean) / self.running_std
52
+ return x
53
+
54
+ def inverse(self, x: Tensor):
55
+ return x * self.running_std + self.running_mean
resemble-enhance/resemble_enhance/data/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+
4
+ from torch.utils.data import DataLoader
5
+
6
+ from ..hparams import HParams
7
+ from .dataset import Dataset
8
+ from .utils import mix_fg_bg, rglob_audio_files
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
14
+ paths = rglob_audio_files(hp.fg_dir)
15
+ logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}")
16
+
17
+ random.Random(seed).shuffle(paths)
18
+ train_paths = paths[:-val_size]
19
+ val_paths = paths[-val_size:]
20
+
21
+ train_ds = Dataset(train_paths, hp, training=True, mode=mode)
22
+ val_ds = Dataset(val_paths, hp, training=False, mode=mode)
23
+
24
+ logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")
25
+
26
+ return train_ds, val_ds
27
+
28
+
29
+ def create_dataloaders(hp: HParams, mode):
30
+ train_ds, val_ds = _create_datasets(hp=hp, mode=mode)
31
+
32
+ train_dl = DataLoader(
33
+ train_ds,
34
+ batch_size=hp.batch_size_per_gpu,
35
+ shuffle=True,
36
+ num_workers=hp.nj,
37
+ drop_last=True,
38
+ collate_fn=train_ds.collate_fn,
39
+ )
40
+ val_dl = DataLoader(
41
+ val_ds,
42
+ batch_size=1,
43
+ shuffle=False,
44
+ num_workers=hp.nj,
45
+ drop_last=False,
46
+ collate_fn=val_ds.collate_fn,
47
+ )
48
+ return train_dl, val_dl
resemble-enhance/resemble_enhance/data/dataset.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchaudio
8
+ import torchaudio.functional as AF
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ from torch.utils.data import Dataset as DatasetBase
11
+
12
+ from ..hparams import HParams
13
+ from .distorter import Distorter
14
+ from .utils import rglob_audio_files
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def _normalize(x):
20
+ return x / (np.abs(x).max() + 1e-7)
21
+
22
+
23
+ def _collate(batch, key, tensor=True, pad=True):
24
+ l = [d[key] for d in batch]
25
+ if l[0] is None:
26
+ return None
27
+ if tensor:
28
+ l = [torch.from_numpy(x) for x in l]
29
+ if pad:
30
+ assert tensor, "Can't pad non-tensor"
31
+ l = pad_sequence(l, batch_first=True)
32
+ return l
33
+
34
+
35
+ def praat_augment(wav, sr):
36
+ try:
37
+ import parselmouth
38
+ except ImportError:
39
+ raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")
40
+ # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
41
+ # https://github.com/YannickJadoul/Parselmouth/issues/68
42
+ # note that this function may hang if the praat version is 0.4.3
43
+ assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"
44
+ sound = parselmouth.Sound(wav, sr)
45
+ formant_shift_ratio = random.uniform(1.1, 1.5)
46
+ pitch_range_factor = random.uniform(0.5, 2.0)
47
+ sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0)
48
+ wav = np.array(sound.values)[0].astype(np.float32)
49
+ return wav
50
+
51
+
52
+ class Dataset(DatasetBase):
53
+ def __init__(
54
+ self,
55
+ fg_paths: list[Path],
56
+ hp: HParams,
57
+ training=True,
58
+ max_retries=100,
59
+ silent_fg_prob=0.01,
60
+ mode=False,
61
+ ):
62
+ super().__init__()
63
+
64
+ assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"
65
+
66
+ self.hp = hp
67
+ self.fg_paths = fg_paths
68
+ self.bg_paths = rglob_audio_files(hp.bg_dir)
69
+
70
+ if len(self.fg_paths) == 0:
71
+ raise ValueError(f"No foreground audio files found in {hp.fg_dir}")
72
+
73
+ if len(self.bg_paths) == 0:
74
+ raise ValueError(f"No background audio files found in {hp.bg_dir}")
75
+
76
+ logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")
77
+
78
+ self.training = training
79
+ self.max_retries = max_retries
80
+ self.silent_fg_prob = silent_fg_prob
81
+
82
+ self.mode = mode
83
+ self.distorter = Distorter(hp, training=training, mode=mode)
84
+
85
+ def _load_wav(self, path, length=None, random_crop=True):
86
+ wav, sr = torchaudio.load(path)
87
+
88
+ wav = AF.resample(
89
+ waveform=wav,
90
+ orig_freq=sr,
91
+ new_freq=self.hp.wav_rate,
92
+ lowpass_filter_width=64,
93
+ rolloff=0.9475937167399596,
94
+ resampling_method="sinc_interp_kaiser",
95
+ beta=14.769656459379492,
96
+ )
97
+
98
+ wav = wav.float().numpy()
99
+
100
+ if wav.ndim == 2:
101
+ wav = np.mean(wav, axis=0)
102
+
103
+ if length is None and self.training:
104
+ length = int(self.hp.training_seconds * self.hp.wav_rate)
105
+
106
+ if length is not None:
107
+ if random_crop:
108
+ start = random.randint(0, max(0, len(wav) - length))
109
+ wav = wav[start : start + length]
110
+ else:
111
+ wav = wav[:length]
112
+
113
+ if length is not None and len(wav) < length:
114
+ wav = np.pad(wav, (0, length - len(wav)))
115
+
116
+ wav = _normalize(wav)
117
+
118
+ return wav
119
+
120
+ def _getitem_unsafe(self, index: int):
121
+ fg_path = self.fg_paths[index]
122
+
123
+ if self.training and random.random() < self.silent_fg_prob:
124
+ fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32)
125
+ else:
126
+ fg_wav = self._load_wav(fg_path)
127
+ if random.random() < self.hp.praat_augment_prob and self.training:
128
+ fg_wav = praat_augment(fg_wav, self.hp.wav_rate)
129
+
130
+ if self.hp.load_fg_only:
131
+ bg_wav = None
132
+ fg_dwav = None
133
+ bg_dwav = None
134
+ else:
135
+ fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32)
136
+ if self.training:
137
+ bg_path = random.choice(self.bg_paths)
138
+ else:
139
+ # Deterministic for validation
140
+ bg_path = self.bg_paths[index % len(self.bg_paths)]
141
+ bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
142
+ bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32)
143
+
144
+ return dict(
145
+ fg_wav=fg_wav,
146
+ bg_wav=bg_wav,
147
+ fg_dwav=fg_dwav,
148
+ bg_dwav=bg_dwav,
149
+ )
150
+
151
+ def __getitem__(self, index: int):
152
+ for i in range(self.max_retries):
153
+ try:
154
+ return self._getitem_unsafe(index)
155
+ except Exception as e:
156
+ if i == self.max_retries - 1:
157
+ raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
158
+ logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
159
+ index = np.random.randint(0, len(self))
160
+
161
+ def __len__(self):
162
+ return len(self.fg_paths)
163
+
164
+ @staticmethod
165
+ def collate_fn(batch):
166
+ return dict(
167
+ fg_wavs=_collate(batch, "fg_wav"),
168
+ bg_wavs=_collate(batch, "bg_wav"),
169
+ fg_dwavs=_collate(batch, "fg_dwav"),
170
+ bg_dwavs=_collate(batch, "bg_dwav"),
171
+ )
resemble-enhance/resemble_enhance/data/distorter/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .distorter import Distorter
resemble-enhance/resemble_enhance/data/distorter/base.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import os
3
+ import random
4
+ import time
5
+ import warnings
6
+
7
+ import numpy as np
8
+
9
+ _DEBUG = bool(os.environ.get("DEBUG", False))
10
+
11
+
12
+ class Effect:
13
+ def apply(self, wav: np.ndarray, sr: int):
14
+ """
15
+ Args:
16
+ wav: (T)
17
+ sr: sample rate
18
+ Returns:
19
+ wav: (T) with the same sample rate of `sr`
20
+ """
21
+ raise NotImplementedError
22
+
23
+ def __call__(self, wav: np.ndarray, sr: int):
24
+ """
25
+ Args:
26
+ wav: (T)
27
+ sr: sample rate
28
+ Returns:
29
+ wav: (T) with the same sample rate of `sr`
30
+ """
31
+ assert len(wav.shape) == 1, wav.shape
32
+
33
+ if _DEBUG:
34
+ start = time.time()
35
+ else:
36
+ start = None
37
+
38
+ shape = wav.shape
39
+ assert wav.ndim == 1, f"{self}: Expected wav.ndim == 1, got {wav.ndim}."
40
+ wav = self.apply(wav, sr)
41
+ assert shape == wav.shape, f"{self}: {shape} != {wav.shape}."
42
+
43
+ if start is not None:
44
+ end = time.time()
45
+ print(f"{self.__class__.__name__}: {end - start:.3f} sec")
46
+
47
+ return wav
48
+
49
+
50
+ class Chain(Effect):
51
+ def __init__(self, *effects):
52
+ super().__init__()
53
+
54
+ self.effects = effects
55
+
56
+ def apply(self, wav, sr):
57
+ for effect in self.effects:
58
+ wav = effect(wav, sr)
59
+ return wav
60
+
61
+
62
+ class Maybe(Effect):
63
+ def __init__(self, prob, effect):
64
+ super().__init__()
65
+
66
+ self.prob = prob
67
+ self.effect = effect
68
+
69
+ if _DEBUG:
70
+ warnings.warn("DEBUG mode is on. Maybe -> Must.")
71
+ self.prob = 1
72
+
73
+ def apply(self, wav, sr):
74
+ if random.random() > self.prob:
75
+ return wav
76
+ return self.effect(wav, sr)
77
+
78
+
79
+ class Choice(Effect):
80
+ def __init__(self, *effects, **kwargs):
81
+ super().__init__()
82
+ self.effects = effects
83
+ self.kwargs = kwargs
84
+
85
+ def apply(self, wav, sr):
86
+ return np.random.choice(self.effects, **self.kwargs)(wav, sr)
87
+
88
+
89
+ class Permutation(Effect):
90
+ def __init__(self, *effects, n: int | None = None):
91
+ super().__init__()
92
+ self.effects = effects
93
+ self.n = n
94
+
95
+ def apply(self, wav, sr):
96
+ if self.n is None:
97
+ n = np.random.binomial(len(self.effects), 0.5)
98
+ else:
99
+ n = self.n
100
+ if n == 0:
101
+ return wav
102
+ perms = itertools.permutations(self.effects, n)
103
+ effects = random.choice(list(perms))
104
+ return Chain(*effects)(wav, sr)
resemble-enhance/resemble_enhance/data/distorter/custom.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+ from dataclasses import dataclass
4
+ from functools import cached_property
5
+ from pathlib import Path
6
+
7
+ import librosa
8
+ import numpy as np
9
+ from scipy import signal
10
+
11
+ from ..utils import walk_paths
12
+ from .base import Effect
13
+
14
+ _logger = logging.getLogger(__name__)
15
+
16
+
17
+ @dataclass
18
+ class RandomRIR(Effect):
19
+ rir_dir: Path | None
20
+ rir_rate: int = 44_000
21
+ rir_suffix: str = ".npy"
22
+ deterministic: bool = False
23
+
24
+ @cached_property
25
+ def rir_paths(self):
26
+ if self.rir_dir is None:
27
+ return []
28
+ return list(walk_paths(self.rir_dir, self.rir_suffix))
29
+
30
+ def _sample_rir(self):
31
+ if len(self.rir_paths) == 0:
32
+ return None
33
+
34
+ if self.deterministic:
35
+ rir_path = self.rir_paths[0]
36
+ else:
37
+ rir_path = random.choice(self.rir_paths)
38
+
39
+ rir = np.squeeze(np.load(rir_path))
40
+ assert isinstance(rir, np.ndarray)
41
+
42
+ return rir
43
+
44
+ def apply(self, wav, sr):
45
+ # ref: https://github.com/haoheliu/voicefixer_main/blob/b06e07c945ac1d309b8a57ddcd599ca376b98cd9/dataloaders/augmentation/magical_effects.py#L158
46
+
47
+ if len(self.rir_paths) == 0:
48
+ return wav
49
+
50
+ length = len(wav)
51
+
52
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
53
+ rir = self._sample_rir()
54
+
55
+ wav = signal.convolve(wav, rir, mode="same")
56
+
57
+ actlev = np.max(np.abs(wav))
58
+ if actlev > 0.99:
59
+ wav = (wav / actlev) * 0.98
60
+
61
+ wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")
62
+
63
+ if abs(length - len(wav)) > 10:
64
+ _logger.warning(f"length mismatch: {length} vs {len(wav)}")
65
+
66
+ if length > len(wav):
67
+ wav = np.pad(wav, (0, length - len(wav)))
68
+ elif length < len(wav):
69
+ wav = wav[:length]
70
+
71
+ return wav
72
+
73
+
74
+ class RandomGaussianNoise(Effect):
75
+ def __init__(self, alpha_range=(0.8, 1)):
76
+ super().__init__()
77
+ self.alpha_range = alpha_range
78
+
79
+ def apply(self, wav, sr):
80
+ noise = np.random.randn(*wav.shape)
81
+ noise_energy = np.sum(noise**2)
82
+ wav_energy = np.sum(wav**2)
83
+ noise = noise * np.sqrt(wav_energy / noise_energy)
84
+ alpha = random.uniform(*self.alpha_range)
85
+ return wav * alpha + noise * (1 - alpha)
resemble-enhance/resemble_enhance/data/distorter/distorter.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ...hparams import HParams
2
+ from .base import Chain, Choice, Permutation
3
+ from .custom import RandomGaussianNoise, RandomRIR
4
+
5
+
6
+ class Distorter(Chain):
7
+ def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"):
8
+ # Lazy import
9
+ from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb
10
+
11
+ if training:
12
+ permutation = Permutation(
13
+ RandomRIR(hp.rir_dir),
14
+ RandomReverb(),
15
+ RandomGaussianNoise(),
16
+ RandomOverdrive(),
17
+ RandomEqualizer(),
18
+ Choice(
19
+ RandomLowpassDistorter(),
20
+ RandomBandpassDistorter(),
21
+ ),
22
+ )
23
+ if mode == "denoiser":
24
+ super().__init__(permutation)
25
+ else:
26
+ # 80%: distortion, 20%: clean
27
+ super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2]))
28
+ else:
29
+ super().__init__(
30
+ RandomRIR(hp.rir_dir, deterministic=True),
31
+ RandomReverb(deterministic=True),
32
+ )
resemble-enhance/resemble_enhance/data/distorter/sox.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+ import warnings
5
+ from functools import partial
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+ try:
11
+ import augment
12
+ except ImportError:
13
+ raise ImportError(
14
+ "augment is not installed, please install it first using:"
15
+ "\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e"
16
+ )
17
+
18
+ from .base import Effect
19
+
20
+ _logger = logging.getLogger(__name__)
21
+ _DEBUG = bool(os.environ.get("DEBUG", False))
22
+
23
+
24
+ class AttachableEffect(Effect):
25
+ def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
26
+ raise NotImplementedError
27
+
28
+ def apply(self, wav: np.ndarray, sr: int):
29
+ chain = augment.EffectChain()
30
+ chain = self.attach(chain)
31
+ tensor = torch.from_numpy(wav)[None].float() # (1, T)
32
+ tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
33
+ wav = tensor.numpy()[0] # (T,)
34
+ return wav
35
+
36
+
37
+ class SoxEffect(AttachableEffect):
38
+ def __init__(self, effect_name: str, *args, **kwargs):
39
+ self.effect_name = effect_name
40
+ self.args = args
41
+ self.kwargs = kwargs
42
+
43
+ def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
44
+ _logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")
45
+ if not hasattr(chain, self.effect_name):
46
+ raise ValueError(f"EffectChain has no attribute {self.effect_name}")
47
+ return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
48
+
49
+
50
+ class Maybe(AttachableEffect):
51
+ """
52
+ Attach an effect with a probability.
53
+ """
54
+
55
+ def __init__(self, prob: float, effect: AttachableEffect):
56
+ self.prob = prob
57
+ self.effect = effect
58
+ if _DEBUG:
59
+ warnings.warn("DEBUG mode is on. Maybe -> Must.")
60
+ self.prob = 1
61
+
62
+ def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
63
+ if random.random() > self.prob:
64
+ return chain
65
+ return self.effect.attach(chain)
66
+
67
+
68
+ class Chain(AttachableEffect):
69
+ """
70
+ Attach a chain of effects.
71
+ """
72
+
73
+ def __init__(self, *effects: AttachableEffect):
74
+ self.effects = effects
75
+
76
+ def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
77
+ for effect in self.effects:
78
+ chain = effect.attach(chain)
79
+ return chain
80
+
81
+
82
+ class Choice(AttachableEffect):
83
+ """
84
+ Attach one of the effects randomly.
85
+ """
86
+
87
+ def __init__(self, *effects: AttachableEffect):
88
+ self.effects = effects
89
+
90
+ def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
91
+ return random.choice(self.effects).attach(chain)
92
+
93
+
94
+ class Generator:
95
+ def __call__(self) -> str:
96
+ raise NotImplementedError
97
+
98
+
99
+ class Uniform(Generator):
100
+ def __init__(self, low, high):
101
+ self.low = low
102
+ self.high = high
103
+
104
+ def __call__(self) -> str:
105
+ return str(random.uniform(self.low, self.high))
106
+
107
+
108
+ class Randint(Generator):
109
+ def __init__(self, low, high):
110
+ self.low = low
111
+ self.high = high
112
+
113
+ def __call__(self) -> str:
114
+ return str(random.randint(self.low, self.high))
115
+
116
+
117
+ class Concat(Generator):
118
+ def __init__(self, *parts: Generator | str):
119
+ self.parts = parts
120
+
121
+ def __call__(self):
122
+ return "".join([part if isinstance(part, str) else part() for part in self.parts])
123
+
124
+
125
+ class RandomLowpassDistorter(SoxEffect):
126
+ def __init__(self, low=2000, high=16000):
127
+ super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)))
128
+
129
+
130
+ class RandomBandpassDistorter(SoxEffect):
131
+ def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
132
+ super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width))
133
+
134
+ @staticmethod
135
+ def _fn(low, high, min_width, max_width):
136
+ start = random.randint(low, high)
137
+ stop = start + random.randint(min_width, max_width)
138
+ return f"{start}-{stop}"
139
+
140
+
141
+ class RandomEqualizer(SoxEffect):
142
+ def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
143
+ super().__init__(
144
+ "equalizer",
145
+ Uniform(low, high),
146
+ lambda: f"{random.randint(q_low, q_high)}q",
147
+ lambda: random.randint(db_low, db_high),
148
+ )
149
+
150
+
151
+ class RandomOverdrive(SoxEffect):
152
+ def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
153
+ super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high))
154
+
155
+
156
+ class RandomReverb(Chain):
157
+ def __init__(self, deterministic=False):
158
+ super().__init__(
159
+ SoxEffect(
160
+ "reverb",
161
+ Uniform(50, 50) if deterministic else Uniform(0, 100),
162
+ Uniform(50, 50) if deterministic else Uniform(0, 100),
163
+ Uniform(50, 50) if deterministic else Uniform(0, 100),
164
+ ),
165
+ SoxEffect("channels", 1),
166
+ )
167
+
168
+
169
+ class Flanger(SoxEffect):
170
+ def __init__(self):
171
+ super().__init__("flanger")
172
+
173
+
174
+ class Phaser(SoxEffect):
175
+ def __init__(self):
176
+ super().__init__("phaser")
resemble-enhance/resemble_enhance/data/utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Callable
3
+
4
+ from torch import Tensor
5
+
6
+
7
+ def walk_paths(root, suffix):
8
+ for path in Path(root).iterdir():
9
+ if path.is_dir():
10
+ yield from walk_paths(path, suffix)
11
+ elif path.suffix == suffix:
12
+ yield path
13
+
14
+
15
+ def rglob_audio_files(path: Path):
16
+ return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
17
+
18
+
19
+ def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
20
+ """
21
+ Args:
22
+ fg: (b, t)
23
+ bg: (b, t)
24
+ """
25
+ assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"
26
+ fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps)
27
+ bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps)
28
+
29
+ fg_energy = fg.pow(2).sum(dim=-1, keepdim=True)
30
+ bg_energy = bg.pow(2).sum(dim=-1, keepdim=True)
31
+
32
+ fg = fg / (fg_energy + eps).sqrt()
33
+ bg = bg / (bg_energy + eps).sqrt()
34
+
35
+ if callable(alpha):
36
+ alpha = alpha()
37
+
38
+ assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"
39
+
40
+ mx = alpha * fg + (1 - alpha) * bg
41
+ mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps)
42
+
43
+ return mx
resemble-enhance/resemble_enhance/denoiser/__init__.py ADDED
File without changes
resemble-enhance/resemble_enhance/denoiser/__main__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import torchaudio
6
+
7
+ from .inference import denoise
8
+
9
+
10
+ @torch.inference_mode()
11
+ def main():
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
14
+ parser.add_argument("out_dir", type=Path, help="Output folder")
15
+ parser.add_argument("--run_dir", type=Path, default="runs/denoiser", help="Path to run folder")
16
+ parser.add_argument("--suffix", type=str, default=".wav", help="File suffix")
17
+ parser.add_argument("--device", type=str, default="cuda", help="Device")
18
+ args = parser.parse_args()
19
+
20
+ for path in args.in_dir.glob(f"**/*{args.suffix}"):
21
+ print(f"Processing {path} ..")
22
+ dwav, sr = torchaudio.load(path)
23
+ hwav, sr = denoise(dwav[0], sr, args.run_dir, args.device)
24
+ out_path = args.out_dir / path.relative_to(args.in_dir)
25
+ out_path.parent.mkdir(parents=True, exist_ok=True)
26
+ torchaudio.save(out_path, hwav[None], sr)
27
+
28
+
29
+ if __name__ == "__main__":
30
+ main()
resemble-enhance/resemble_enhance/denoiser/denoiser.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import Tensor, nn
6
+
7
+ from ..melspec import MelSpectrogram
8
+ from .hparams import HParams
9
+ from .unet import UNet
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def _normalize(x: Tensor) -> Tensor:
15
+ return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
16
+
17
+
18
+ class Denoiser(nn.Module):
19
+ @property
20
+ def stft_cfg(self) -> dict:
21
+ hop_size = self.hp.hop_size
22
+ return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4)
23
+
24
+ @property
25
+ def n_fft(self):
26
+ return self.stft_cfg["n_fft"]
27
+
28
+ @property
29
+ def eps(self):
30
+ return 1e-7
31
+
32
+ def __init__(self, hp: HParams):
33
+ super().__init__()
34
+ self.hp = hp
35
+ self.net = UNet(input_dim=3, output_dim=3)
36
+ self.mel_fn = MelSpectrogram(hp)
37
+
38
+ self.dummy: Tensor
39
+ self.register_buffer("dummy", torch.zeros(1), persistent=False)
40
+
41
+ def to_mel(self, x: Tensor, drop_last=True):
42
+ """
43
+ Args:
44
+ x: (b t), wavs
45
+ Returns:
46
+ o: (b c t), mels
47
+ """
48
+ if drop_last:
49
+ return self.mel_fn(x)[..., :-1] # (b d t)
50
+ return self.mel_fn(x)
51
+
52
+ def _stft(self, x):
53
+ """
54
+ Args:
55
+ x: (b t)
56
+ Returns:
57
+ mag: (b f t) in [0, inf)
58
+ cos: (b f t) in [-1, 1]
59
+ sin: (b f t) in [-1, 1]
60
+ """
61
+ dtype = x.dtype
62
+ device = x.device
63
+
64
+ if x.is_mps:
65
+ x = x.cpu()
66
+
67
+ window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
68
+ s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True) # (b f t+1)
69
+
70
+ s = s[..., :-1] # (b f t)
71
+
72
+ mag = s.abs() # (b f t)
73
+
74
+ phi = s.angle() # (b f t)
75
+ cos = phi.cos() # (b f t)
76
+ sin = phi.sin() # (b f t)
77
+
78
+ mag = mag.to(dtype=dtype, device=device)
79
+ cos = cos.to(dtype=dtype, device=device)
80
+ sin = sin.to(dtype=dtype, device=device)
81
+
82
+ return mag, cos, sin
83
+
84
+ def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor):
85
+ """
86
+ Args:
87
+ mag: (b f t) in [0, inf)
88
+ cos: (b f t) in [-1, 1]
89
+ sin: (b f t) in [-1, 1]
90
+ Returns:
91
+ x: (b t)
92
+ """
93
+ device = mag.device
94
+ dtype = mag.dtype
95
+
96
+ if mag.is_mps:
97
+ mag = mag.cpu()
98
+ cos = cos.cpu()
99
+ sin = sin.cpu()
100
+
101
+ real = mag * cos # (b f t)
102
+ imag = mag * sin # (b f t)
103
+
104
+ s = torch.complex(real, imag) # (b f t)
105
+
106
+ if s.isnan().any():
107
+ logger.warning("NaN detected in ISTFT input.")
108
+
109
+ s = F.pad(s, (0, 1), "replicate") # (b f t+1)
110
+
111
+ window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
112
+ x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False)
113
+
114
+ if x.isnan().any():
115
+ logger.warning("NaN detected in ISTFT output, set to zero.")
116
+ x = torch.where(x.isnan(), torch.zeros_like(x), x)
117
+
118
+ x = x.to(dtype=dtype, device=device)
119
+
120
+ return x
121
+
122
+ def _magphase(self, real, imag):
123
+ mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt()
124
+ cos = real / mag
125
+ sin = imag / mag
126
+ return mag, cos, sin
127
+
128
+ def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor):
129
+ """
130
+ Args:
131
+ mag: (b f t)
132
+ cos: (b f t)
133
+ sin: (b f t)
134
+ Returns:
135
+ mag_mask: (b f t) in [0, 1], magnitude mask
136
+ cos_res: (b f t) in [-1, 1], phase residual
137
+ sin_res: (b f t) in [-1, 1], phase residual
138
+ """
139
+ x = torch.stack([mag, cos, sin], dim=1) # (b 3 f t)
140
+ mag_mask, real, imag = self.net(x).unbind(1) # (b 3 f t)
141
+ mag_mask = mag_mask.sigmoid() # (b f t)
142
+ real = real.tanh() # (b f t)
143
+ imag = imag.tanh() # (b f t)
144
+ _, cos_res, sin_res = self._magphase(real, imag) # (b f t)
145
+ return mag_mask, sin_res, cos_res
146
+
147
+ def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
148
+ """Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf"""
149
+ sep_mag = F.relu(mag * mag_mask)
150
+ sep_cos = cos * cos_res - sin * sin_res
151
+ sep_sin = sin * cos_res + cos * sin_res
152
+ return sep_mag, sep_cos, sep_sin
153
+
154
+ def forward(self, x: Tensor, y: Tensor | None = None):
155
+ """
156
+ Args:
157
+ x: (b t), a mixed audio
158
+ y: (b t), a fg audio
159
+ """
160
+ assert x.dim() == 2, f"Expected (b t), got {x.size()}"
161
+ x = x.to(self.dummy)
162
+ x = _normalize(x)
163
+
164
+ if y is not None:
165
+ assert y.dim() == 2, f"Expected (b t), got {y.size()}"
166
+ y = y.to(self.dummy)
167
+ y = _normalize(y)
168
+
169
+ mag, cos, sin = self._stft(x) # (b 2f t)
170
+ mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
171
+ sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)
172
+
173
+ o = self._istft(sep_mag, sep_cos, sep_sin)
174
+
175
+ npad = x.shape[-1] - o.shape[-1]
176
+ o = F.pad(o, (0, npad))
177
+
178
+ if y is not None:
179
+ self.losses = dict(l1=F.l1_loss(o, y))
180
+
181
+ return o
resemble-enhance/resemble_enhance/denoiser/hparams.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ from ..hparams import HParams as HParamsBase
4
+
5
+
6
+ @dataclass(frozen=True)
7
+ class HParams(HParamsBase):
8
+ batch_size_per_gpu: int = 128
9
+ distort_prob: float = 0.5
resemble-enhance/resemble_enhance/denoiser/inference.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from functools import cache
3
+
4
+ import torch
5
+
6
+ from ..inference import inference
7
+ from .train import Denoiser, HParams
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ @cache
13
+ def load_denoiser(run_dir, device):
14
+ if run_dir is None:
15
+ return Denoiser(HParams())
16
+ hp = HParams.load(run_dir)
17
+ denoiser = Denoiser(hp)
18
+ path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
19
+ state_dict = torch.load(path, map_location="cpu")["module"]
20
+ denoiser.load_state_dict(state_dict)
21
+ denoiser.eval()
22
+ denoiser.to(device)
23
+ return denoiser
24
+
25
+
26
+ @torch.inference_mode()
27
+ def denoise(dwav, sr, run_dir, device):
28
+ denoiser = load_denoiser(run_dir, device)
29
+ return inference(model=denoiser, dwav=dwav, sr=sr, device=device)
resemble-enhance/resemble_enhance/denoiser/train.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+ from functools import partial
4
+ from pathlib import Path
5
+
6
+ import soundfile
7
+ import torch
8
+ from deepspeed import DeepSpeedConfig
9
+ from torch import Tensor
10
+ from tqdm import tqdm
11
+
12
+ from ..data import create_dataloaders, mix_fg_bg
13
+ from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
14
+ from ..utils.distributed import is_local_leader
15
+ from .denoiser import Denoiser
16
+ from .hparams import HParams
17
+
18
+
19
+ def load_G(run_dir: Path, hp: HParams | None = None, training=True):
20
+ if hp is None:
21
+ hp = HParams.load(run_dir)
22
+ assert isinstance(hp, HParams)
23
+ model = Denoiser(hp)
24
+ engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "G")
25
+ if training:
26
+ engine.load_checkpoint()
27
+ else:
28
+ engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False)
29
+ return engine
30
+
31
+
32
+ def save_wav(path: Path, wav: Tensor, rate: int):
33
+ wav = wav.detach().cpu().numpy()
34
+ soundfile.write(path, wav, samplerate=rate)
35
+
36
+
37
+ def main():
38
+ parser = argparse.ArgumentParser()
39
+ parser.add_argument("run_dir", type=Path)
40
+ parser.add_argument("--yaml", type=Path, default=None)
41
+ parser.add_argument("--device", type=str, default="cuda")
42
+ args = parser.parse_args()
43
+
44
+ setup_logging(args.run_dir)
45
+ hp = HParams.load(args.run_dir, yaml=args.yaml)
46
+
47
+ if is_local_leader():
48
+ hp.save_if_not_exists(args.run_dir)
49
+ hp.print()
50
+
51
+ train_dl, val_dl = create_dataloaders(hp, mode="denoiser")
52
+
53
+ def feed_G(engine: Engine, batch: dict[str, Tensor]):
54
+ alpha_fn = lambda: random.uniform(*hp.mix_alpha_range)
55
+ if random.random() < hp.distort_prob:
56
+ fg_wavs = batch["fg_dwavs"]
57
+ else:
58
+ fg_wavs = batch["fg_wavs"]
59
+ mx_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"], alpha=alpha_fn)
60
+ pred = engine(mx_dwavs, fg_wavs)
61
+ losses = engine.gather_attribute("losses", prefix="losses")
62
+ return pred, losses
63
+
64
+ @torch.no_grad()
65
+ def eval_fn(engine: Engine, eval_dir, n_saved=10):
66
+ model = engine.module
67
+ model.eval()
68
+
69
+ step = engine.global_step
70
+
71
+ for i, batch in enumerate(tqdm(val_dl), 1):
72
+ batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch)
73
+
74
+ fg_dwavs = batch["fg_dwavs"] # 1 t
75
+ mx_dwavs = mix_fg_bg(fg_dwavs, batch["bg_dwavs"])
76
+ pred_fg_dwavs = model(mx_dwavs) # 1 t
77
+
78
+ mx_mels = model.to_mel(mx_dwavs) # 1 c t
79
+ fg_mels = model.to_mel(fg_dwavs) # 1 c t
80
+ pred_fg_mels = model.to_mel(pred_fg_dwavs) # 1 c t
81
+
82
+ rate = model.hp.wav_rate
83
+ get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}"
84
+
85
+ save_wav(get_path("_input.wav"), mx_dwavs[0], rate=rate)
86
+ save_wav(get_path("_predict.wav"), pred_fg_dwavs[0], rate=rate)
87
+ save_wav(get_path("_target.wav"), fg_dwavs[0], rate=rate)
88
+
89
+ save_mels(
90
+ get_path(".png"),
91
+ cond_mel=mx_mels[0].cpu().numpy(),
92
+ pred_mel=pred_fg_mels[0].cpu().numpy(),
93
+ targ_mel=fg_mels[0].cpu().numpy(),
94
+ )
95
+
96
+ if i >= n_saved:
97
+ break
98
+
99
+ train_loop = TrainLoop(
100
+ run_dir=args.run_dir,
101
+ train_dl=train_dl,
102
+ load_G=partial(load_G, hp=hp),
103
+ device=args.device,
104
+ feed_G=feed_G,
105
+ eval_fn=eval_fn,
106
+ )
107
+
108
+ train_loop.run(max_steps=hp.max_steps)
109
+
110
+
111
+ if __name__ == "__main__":
112
+ main()
resemble-enhance/resemble_enhance/denoiser/unet.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ from torch import nn
3
+
4
+
5
+ class PreactResBlock(nn.Sequential):
6
+ def __init__(self, dim):
7
+ super().__init__(
8
+ nn.GroupNorm(dim // 16, dim),
9
+ nn.GELU(),
10
+ nn.Conv2d(dim, dim, 3, padding=1),
11
+ nn.GroupNorm(dim // 16, dim),
12
+ nn.GELU(),
13
+ nn.Conv2d(dim, dim, 3, padding=1),
14
+ )
15
+
16
+ def forward(self, x):
17
+ return x + super().forward(x)
18
+
19
+
20
+ class UNetBlock(nn.Module):
21
+ def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
22
+ super().__init__()
23
+ if output_dim is None:
24
+ output_dim = input_dim
25
+ self.pre_conv = nn.Conv2d(input_dim, output_dim, 3, padding=1)
26
+ self.res_block1 = PreactResBlock(output_dim)
27
+ self.res_block2 = PreactResBlock(output_dim)
28
+ self.downsample = self.upsample = nn.Identity()
29
+ if scale_factor > 1:
30
+ self.upsample = nn.Upsample(scale_factor=scale_factor)
31
+ elif scale_factor < 1:
32
+ self.downsample = nn.Upsample(scale_factor=scale_factor)
33
+
34
+ def forward(self, x, h=None):
35
+ """
36
+ Args:
37
+ x: (b c h w), last output
38
+ h: (b c h w), skip output
39
+ Returns:
40
+ o: (b c h w), output
41
+ s: (b c h w), skip output
42
+ """
43
+ x = self.upsample(x)
44
+ if h is not None:
45
+ assert x.shape == h.shape, f"{x.shape} != {h.shape}"
46
+ x = x + h
47
+ x = self.pre_conv(x)
48
+ x = self.res_block1(x)
49
+ x = self.res_block2(x)
50
+ return self.downsample(x), x
51
+
52
+
53
+ class UNet(nn.Module):
54
+ def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
55
+ super().__init__()
56
+ self.input_dim = input_dim
57
+ self.output_dim = output_dim
58
+ self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
59
+ self.encoder_blocks = nn.ModuleList(
60
+ [
61
+ UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5)
62
+ for i in range(num_blocks)
63
+ ]
64
+ )
65
+ self.middle_blocks = nn.ModuleList(
66
+ [UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)]
67
+ )
68
+ self.decoder_blocks = nn.ModuleList(
69
+ [
70
+ UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2)
71
+ for i in reversed(range(num_blocks))
72
+ ]
73
+ )
74
+ self.head = nn.Sequential(
75
+ nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
76
+ nn.GELU(),
77
+ nn.Conv2d(hidden_dim, output_dim, 1),
78
+ )
79
+
80
+ @property
81
+ def scale_factor(self):
82
+ return 2 ** len(self.encoder_blocks)
83
+
84
+ def pad_to_fit(self, x):
85
+ """
86
+ Args:
87
+ x: (b c h w), input
88
+ Returns:
89
+ x: (b c h' w'), padded input
90
+ """
91
+ hpad = (self.scale_factor - x.shape[2] % self.scale_factor) % self.scale_factor
92
+ wpad = (self.scale_factor - x.shape[3] % self.scale_factor) % self.scale_factor
93
+ return F.pad(x, (0, wpad, 0, hpad))
94
+
95
+ def forward(self, x):
96
+ """
97
+ Args:
98
+ x: (b c h w), input
99
+ Returns:
100
+ o: (b c h w), output
101
+ """
102
+ shape = x.shape
103
+
104
+ x = self.pad_to_fit(x)
105
+ x = self.input_proj(x)
106
+
107
+ s_list = []
108
+ for block in self.encoder_blocks:
109
+ x, s = block(x)
110
+ s_list.append(s)
111
+
112
+ for block in self.middle_blocks:
113
+ x, _ = block(x)
114
+
115
+ for block, s in zip(self.decoder_blocks, reversed(s_list)):
116
+ x, _ = block(x, s)
117
+
118
+ x = self.head(x)
119
+ x = x[..., : shape[2], : shape[3]]
120
+
121
+ return x
122
+
123
+ def test(self, shape=(3, 512, 256)):
124
+ import ptflops
125
+
126
+ macs, params = ptflops.get_model_complexity_info(
127
+ self,
128
+ shape,
129
+ as_strings=True,
130
+ print_per_layer_stat=True,
131
+ verbose=True,
132
+ )
133
+
134
+ print(f"macs: {macs}")
135
+ print(f"params: {params}")
136
+
137
+
138
+ def main():
139
+ model = UNet(3, 3)
140
+ model.test()
141
+
142
+
143
+ if __name__ == "__main__":
144
+ main()
resemble-enhance/resemble_enhance/enhancer/__init__.py ADDED
File without changes
resemble-enhance/resemble_enhance/enhancer/__main__.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ import torchaudio
8
+ from tqdm import tqdm
9
+
10
+ from .inference import denoise, enhance
11
+
12
+
13
+ @torch.inference_mode()
14
+ def main():
15
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
16
+ parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
17
+ parser.add_argument("out_dir", type=Path, help="Output folder")
18
+ parser.add_argument(
19
+ "--run_dir",
20
+ type=Path,
21
+ default=None,
22
+ help="Path to the enhancer run folder, if None, use the default model",
23
+ )
24
+ parser.add_argument(
25
+ "--suffix",
26
+ type=str,
27
+ default=".wav",
28
+ help="Audio file suffix",
29
+ )
30
+ parser.add_argument(
31
+ "--device",
32
+ type=str,
33
+ default="cuda",
34
+ help="Device to use for computation, recommended to use CUDA",
35
+ )
36
+ parser.add_argument(
37
+ "--denoise_only",
38
+ action="store_true",
39
+ help="Only apply denoising without enhancement",
40
+ )
41
+ parser.add_argument(
42
+ "--lambd",
43
+ type=float,
44
+ default=1.0,
45
+ help="Denoise strength for enhancement (0.0 to 1.0)",
46
+ )
47
+ parser.add_argument(
48
+ "--tau",
49
+ type=float,
50
+ default=0.5,
51
+ help="CFM prior temperature (0.0 to 1.0)",
52
+ )
53
+ parser.add_argument(
54
+ "--solver",
55
+ type=str,
56
+ default="midpoint",
57
+ choices=["midpoint", "rk4", "euler"],
58
+ help="Numerical solver to use",
59
+ )
60
+ parser.add_argument(
61
+ "--nfe",
62
+ type=int,
63
+ default=64,
64
+ help="Number of function evaluations",
65
+ )
66
+ parser.add_argument(
67
+ "--parallel_mode",
68
+ action="store_true",
69
+ help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
70
+ )
71
+
72
+ args = parser.parse_args()
73
+
74
+ device = args.device
75
+
76
+ if device == "cuda" and not torch.cuda.is_available():
77
+ print("CUDA is not available but --device is set to cuda, using CPU instead")
78
+ device = "cpu"
79
+
80
+ start_time = time.perf_counter()
81
+
82
+ run_dir = args.run_dir
83
+
84
+ paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))
85
+
86
+ if args.parallel_mode:
87
+ random.shuffle(paths)
88
+
89
+ if len(paths) == 0:
90
+ print(f"No {args.suffix} files found in the following path: {args.in_dir}")
91
+ return
92
+
93
+ pbar = tqdm(paths)
94
+
95
+ for path in pbar:
96
+ out_path = args.out_dir / path.relative_to(args.in_dir)
97
+ if args.parallel_mode and out_path.exists():
98
+ continue
99
+ pbar.set_description(f"Processing {out_path}")
100
+ dwav, sr = torchaudio.load(path)
101
+ dwav = dwav.mean(0)
102
+ if args.denoise_only:
103
+ hwav, sr = denoise(
104
+ dwav=dwav,
105
+ sr=sr,
106
+ device=device,
107
+ run_dir=args.run_dir,
108
+ )
109
+ else:
110
+ hwav, sr = enhance(
111
+ dwav=dwav,
112
+ sr=sr,
113
+ device=device,
114
+ nfe=args.nfe,
115
+ solver=args.solver,
116
+ lambd=args.lambd,
117
+ tau=args.tau,
118
+ run_dir=run_dir,
119
+ )
120
+ out_path.parent.mkdir(parents=True, exist_ok=True)
121
+ torchaudio.save(out_path, hwav[None], sr)
122
+
123
+ # Cool emoji effect saying the job is done
124
+ elapsed_time = time.perf_counter() - start_time
125
+ print(f"🌟 Enhancement done! {len(paths)} files processed in {elapsed_time:.2f}s")
126
+
127
+
128
+ if __name__ == "__main__":
129
+ main()
resemble-enhance/resemble_enhance/enhancer/download.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import torch
5
+
6
+ RUN_NAME = "enhancer_stage2"
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def get_source_url(relpath):
12
+ return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
13
+
14
+
15
+ def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
16
+ if run_dir is None:
17
+ run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
18
+ return Path(run_dir) / relpath
19
+
20
+
21
+ def download(run_dir: str | Path | None = None):
22
+ relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]
23
+ for relpath in relpaths:
24
+ path = get_target_path(relpath, run_dir=run_dir)
25
+ if path.exists():
26
+ continue
27
+ url = get_source_url(relpath)
28
+ path.parent.mkdir(parents=True, exist_ok=True)
29
+ torch.hub.download_url_to_file(url, str(path))
30
+ return get_target_path("", run_dir=run_dir)
resemble-enhance/resemble_enhance/enhancer/enhancer.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import matplotlib.pyplot as plt
4
+ import pandas as pd
5
+ import torch
6
+ from torch import Tensor, nn
7
+ from torch.distributions import Beta
8
+
9
+ from ..common import Normalizer
10
+ from ..denoiser.inference import load_denoiser
11
+ from ..melspec import MelSpectrogram
12
+ from ..utils.distributed import global_leader_only
13
+ from ..utils.train_loop import TrainLoop
14
+ from .hparams import HParams
15
+ from .lcfm import CFM, IRMAE, LCFM
16
+ from .univnet import UnivNet
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ def _maybe(fn):
22
+ def _fn(*args):
23
+ if args[0] is None:
24
+ return None
25
+ return fn(*args)
26
+
27
+ return _fn
28
+
29
+
30
+ def _normalize_wav(x: Tensor):
31
+ return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
32
+
33
+
34
+ class Enhancer(nn.Module):
35
+ def __init__(self, hp: HParams):
36
+ super().__init__()
37
+ self.hp = hp
38
+
39
+ n_mels = self.hp.num_mels
40
+ vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
41
+ latent_dim = self.hp.lcfm_latent_dim
42
+
43
+ self.lcfm = LCFM(
44
+ IRMAE(
45
+ input_dim=n_mels,
46
+ output_dim=vocoder_input_dim,
47
+ latent_dim=latent_dim,
48
+ ),
49
+ CFM(
50
+ cond_dim=n_mels,
51
+ output_dim=self.hp.lcfm_latent_dim,
52
+ solver_nfe=self.hp.cfm_solver_nfe,
53
+ solver_method=self.hp.cfm_solver_method,
54
+ time_mapping_divisor=self.hp.cfm_time_mapping_divisor,
55
+ ),
56
+ z_scale=self.hp.lcfm_z_scale,
57
+ )
58
+
59
+ self.lcfm.set_mode_(self.hp.lcfm_training_mode)
60
+
61
+ self.mel_fn = MelSpectrogram(hp)
62
+ self.vocoder = UnivNet(self.hp, vocoder_input_dim)
63
+ self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu")
64
+ self.normalizer = Normalizer()
65
+
66
+ self._eval_lambd = 0.0
67
+
68
+ self.dummy: Tensor
69
+ self.register_buffer("dummy", torch.zeros(1))
70
+
71
+ if self.hp.enhancer_stage1_run_dir is not None:
72
+ pretrained_path = self.hp.enhancer_stage1_run_dir / "ds/G/default/mp_rank_00_model_states.pt"
73
+ self._load_pretrained(pretrained_path)
74
+
75
+ logger.info(f"{self.__class__.__name__} summary")
76
+ logger.info(f"{self.summarize()}")
77
+
78
+ def _load_pretrained(self, path):
79
+ # Clone is necessary as otherwise it holds a reference to the original model
80
+ cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()}
81
+ denoiser_state_dict = {k: v.clone() for k, v in self.denoiser.state_dict().items()}
82
+ state_dict = torch.load(path, map_location="cpu")["module"]
83
+ self.load_state_dict(state_dict, strict=False)
84
+ self.lcfm.cfm.load_state_dict(cfm_state_dict) # Reset cfm
85
+ self.denoiser.load_state_dict(denoiser_state_dict) # Reset denoiser
86
+ logger.info(f"Loaded pretrained model from {path}")
87
+
88
+ def summarize(self):
89
+ npa_train = lambda m: sum(p.numel() for p in m.parameters() if p.requires_grad)
90
+ npa = lambda m: sum(p.numel() for p in m.parameters())
91
+ rows = []
92
+ for name, module in self.named_children():
93
+ rows.append(dict(name=name, trainable=npa_train(module), total=npa(module)))
94
+ rows.append(dict(name="total", trainable=npa_train(self), total=npa(self)))
95
+ df = pd.DataFrame(rows)
96
+ return df.to_markdown(index=False)
97
+
98
+ def to_mel(self, x: Tensor, drop_last=True):
99
+ """
100
+ Args:
101
+ x: (b t), wavs
102
+ Returns:
103
+ o: (b c t), mels
104
+ """
105
+ if drop_last:
106
+ return self.mel_fn(x)[..., :-1] # (b d t)
107
+ return self.mel_fn(x)
108
+
109
+ @global_leader_only
110
+ @torch.no_grad()
111
+ def _visualize(self, original_mel, denoised_mel):
112
+ loop = TrainLoop.get_running_loop()
113
+ if loop is None or loop.global_step % 100 != 0:
114
+ return
115
+
116
+ plt.figure(figsize=(6, 6))
117
+ plt.subplot(211)
118
+ plt.title("Original")
119
+ plt.imshow(original_mel[0].cpu().numpy(), origin="lower", interpolation="none")
120
+ plt.subplot(212)
121
+ plt.title("Denoised")
122
+ plt.imshow(denoised_mel[0].cpu().numpy(), origin="lower", interpolation="none")
123
+ plt.tight_layout()
124
+
125
+ path = loop.get_running_loop_viz_path("input", ".png")
126
+ plt.savefig(path, dpi=300)
127
+
128
+ def _may_denoise(self, x: Tensor, y: Tensor | None = None):
129
+ if self.hp.lcfm_training_mode == "cfm":
130
+ return self.denoiser(x, y)
131
+ return x
132
+
133
+ def configurate_(self, nfe, solver, lambd, tau):
134
+ """
135
+ Args:
136
+ nfe: number of function evaluations
137
+ solver: solver method
138
+ lambd: denoiser strength [0, 1]
139
+ tau: prior temperature [0, 1]
140
+ """
141
+ self.lcfm.cfm.solver.configurate_(nfe, solver)
142
+ self.lcfm.eval_tau_(tau)
143
+ self._eval_lambd = lambd
144
+
145
+ def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
146
+ """
147
+ Args:
148
+ x: (b t), mix wavs (fg + bg)
149
+ y: (b t), fg clean wavs
150
+ z: (b t), fg distorted wavs
151
+ Returns:
152
+ o: (b t), reconstructed wavs
153
+ """
154
+ assert x.dim() == 2, f"Expected (b t), got {x.size()}"
155
+ assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"
156
+
157
+ if self.hp.lcfm_training_mode == "cfm":
158
+ self.normalizer.eval()
159
+
160
+ x = _normalize_wav(x)
161
+ y = _maybe(_normalize_wav)(y)
162
+ z = _maybe(_normalize_wav)(z)
163
+
164
+ x_mel_original = self.normalizer(self.to_mel(x), update=False) # (b d t)
165
+
166
+ if self.hp.lcfm_training_mode == "cfm":
167
+ if self.training:
168
+ lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device)
169
+ lambd = lambd[:, None, None]
170
+ x_mel_denoised = self.normalizer(self.to_mel(self._may_denoise(x, z)), update=False)
171
+ x_mel_denoised = x_mel_denoised.detach()
172
+ x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original
173
+ self._visualize(x_mel_original, x_mel_denoised)
174
+ else:
175
+ lambd = self._eval_lambd
176
+ if lambd == 0:
177
+ x_mel_denoised = x_mel_original
178
+ else:
179
+ x_mel_denoised = self.normalizer(self.to_mel(self._may_denoise(x, z)), update=False)
180
+ x_mel_denoised = x_mel_denoised.detach()
181
+ x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original
182
+ else:
183
+ x_mel_denoised = x_mel_original
184
+
185
+ y_mel = _maybe(self.to_mel)(y) # (b d t)
186
+ y_mel = _maybe(self.normalizer)(y_mel)
187
+
188
+ lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original) # (b d t)
189
+
190
+ if lcfm_decoded is None:
191
+ o = None
192
+ else:
193
+ o = self.vocoder(lcfm_decoded, y)
194
+
195
+ return o
resemble-enhance/resemble_enhance/enhancer/hparams.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+
4
+ from ..hparams import HParams as HParamsBase
5
+
6
+
7
+ @dataclass(frozen=True)
8
+ class HParams(HParamsBase):
9
+ cfm_solver_method: str = "midpoint"
10
+ cfm_solver_nfe: int = 64
11
+ cfm_time_mapping_divisor: int = 4
12
+ univnet_nc: int = 96
13
+
14
+ lcfm_latent_dim: int = 64
15
+ lcfm_training_mode: str = "ae"
16
+ lcfm_z_scale: float = 5
17
+
18
+ vocoder_extra_dim: int = 32
19
+
20
+ gan_training_start_step: int | None = 5_000
21
+ enhancer_stage1_run_dir: Path | None = None
22
+
23
+ denoiser_run_dir: Path | None = None
resemble-enhance/resemble_enhance/enhancer/inference.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from functools import cache
3
+ from pathlib import Path
4
+
5
+ import torch
6
+
7
+ from ..inference import inference
8
+ from .download import download
9
+ from .train import Enhancer, HParams
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ @cache
15
+ def load_enhancer(run_dir: str | Path | None, device):
16
+ run_dir = download(run_dir)
17
+ hp = HParams.load(run_dir)
18
+ enhancer = Enhancer(hp)
19
+ path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
20
+ state_dict = torch.load(path, map_location="cpu")["module"]
21
+ enhancer.load_state_dict(state_dict)
22
+ enhancer.eval()
23
+ enhancer.to(device)
24
+ return enhancer
25
+
26
+
27
+ @torch.inference_mode()
28
+ def denoise(dwav, sr, device, run_dir=None):
29
+ enhancer = load_enhancer(run_dir, device)
30
+ return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
31
+
32
+
33
+ @torch.inference_mode()
34
+ def enhance(dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None):
35
+ assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
36
+ assert solver in ("midpoint", "rk4", "euler"), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
37
+ assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
38
+ assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
39
+ enhancer = load_enhancer(run_dir, device)
40
+ enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
41
+ return inference(model=enhancer, dwav=dwav, sr=sr, device=device)
resemble-enhance/resemble_enhance/enhancer/lcfm/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .irmae import IRMAE
2
+ from .lcfm import CFM, LCFM
resemble-enhance/resemble_enhance/enhancer/lcfm/cfm.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from functools import partial
4
+ from typing import Protocol
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import scipy
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import Tensor, nn
12
+ from tqdm import trange
13
+
14
+ from .wn import WN
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ class VelocityField(Protocol):
20
+ def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
21
+ ...
22
+
23
+
24
+ class Solver:
25
+ def __init__(
26
+ self,
27
+ method="midpoint",
28
+ nfe=32,
29
+ viz_name="solver",
30
+ viz_every=100,
31
+ mel_fn=None,
32
+ time_mapping_divisor=4,
33
+ verbose=False,
34
+ ):
35
+ self.configurate_(nfe=nfe, method=method)
36
+
37
+ self.verbose = verbose
38
+ self.viz_every = viz_every
39
+ self.viz_name = viz_name
40
+
41
+ self._camera = None
42
+ self._mel_fn = mel_fn
43
+ self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)
44
+
45
+ def configurate_(self, nfe=None, method=None):
46
+ if nfe is None:
47
+ nfe = self.nfe
48
+
49
+ if method is None:
50
+ method = self.method
51
+
52
+ if nfe == 1 and method in ("midpoint", "rk4"):
53
+ logger.warning(f"1 NFE is not supported for {method}, using euler method instead.")
54
+ method = "euler"
55
+
56
+ self.nfe = nfe
57
+ self.method = method
58
+
59
+ @property
60
+ def time_mapping(self):
61
+ return self._time_mapping
62
+
63
+ @staticmethod
64
+ def exponential_decay_mapping(t, n=4):
65
+ """
66
+ Args:
67
+ n: target step
68
+ """
69
+
70
+ def h(t, a):
71
+ return (a**t - 1) / (a - 1)
72
+
73
+ # Solve h(1/n) = 0.5
74
+ a = float(scipy.optimize.fsolve(lambda a: h(1 / n, a) - 0.5, x0=0))
75
+
76
+ t = h(t, a=a)
77
+
78
+ return t
79
+
80
+ @torch.no_grad()
81
+ def _maybe_camera_snap(self, *, ψt, t):
82
+ camera = self._camera
83
+ if camera is not None:
84
+ if ψt.shape[1] == 1:
85
+ # Waveform, b 1 t, plot every 100 samples
86
+ plt.subplot(211)
87
+ plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue")
88
+ if self._mel_fn is not None:
89
+ plt.subplot(212)
90
+ mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0])
91
+ plt.imshow(mel, origin="lower", interpolation="none")
92
+ elif ψt.shape[1] == 2:
93
+ # Complex
94
+ plt.subplot(121)
95
+ plt.imshow(
96
+ ψt.detach().cpu().numpy()[0, 0],
97
+ origin="lower",
98
+ interpolation="none",
99
+ )
100
+ plt.subplot(122)
101
+ plt.imshow(
102
+ ψt.detach().cpu().numpy()[0, 1],
103
+ origin="lower",
104
+ interpolation="none",
105
+ )
106
+ else:
107
+ # Spectrogram, b c t
108
+ plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
109
+ ax = plt.gca()
110
+ ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
111
+ camera.snap()
112
+
113
+ @staticmethod
114
+ def _euler_step(t, ψt, dt, f: VelocityField):
115
+ return ψt + dt * f(t=t, ψt=ψt, dt=dt)
116
+
117
+ @staticmethod
118
+ def _midpoint_step(t, ψt, dt, f: VelocityField):
119
+ return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt)
120
+
121
+ @staticmethod
122
+ def _rk4_step(t, ψt, dt, f: VelocityField):
123
+ k1 = f(t=t, ψt=ψt, dt=dt)
124
+ k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt)
125
+ k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt)
126
+ k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt)
127
+ return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
128
+
129
+ @property
130
+ def _step(self):
131
+ if self.method == "euler":
132
+ return self._euler_step
133
+ elif self.method == "midpoint":
134
+ return self._midpoint_step
135
+ elif self.method == "rk4":
136
+ return self._rk4_step
137
+ else:
138
+ raise ValueError(f"Unknown method: {self.method}")
139
+
140
+ def get_running_train_loop(self):
141
+ try:
142
+ # Lazy import
143
+ from ...utils.train_loop import TrainLoop
144
+
145
+ return TrainLoop.get_running_loop()
146
+ except ImportError:
147
+ return None
148
+
149
+ @property
150
+ def visualizing(self):
151
+ loop = self.get_running_train_loop()
152
+ if loop is None:
153
+ return
154
+ out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
155
+ return loop.global_step % self.viz_every == 0 and not out_path.exists()
156
+
157
+ def _reset_camera(self):
158
+ try:
159
+ from celluloid import Camera
160
+
161
+ self._camera = Camera(plt.figure())
162
+ except:
163
+ pass
164
+
165
+ def _maybe_dump_camera(self):
166
+ camera = self._camera
167
+ loop = self.get_running_train_loop()
168
+ if camera is not None and loop is not None:
169
+ animation = camera.animate()
170
+ out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
171
+ out_path.parent.mkdir(exist_ok=True, parents=True)
172
+ animation.save(out_path, writer="pillow", fps=4)
173
+ plt.close()
174
+ self._camera = None
175
+
176
+ @property
177
+ def n_steps(self):
178
+ n = self.nfe
179
+ if self.method == "euler":
180
+ pass
181
+ elif self.method == "midpoint":
182
+ n //= 2
183
+ elif self.method == "rk4":
184
+ n //= 4
185
+ else:
186
+ raise ValueError(f"Unknown method: {self.method}")
187
+ return n
188
+
189
+ def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
190
+ ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1))
191
+
192
+ if self.visualizing:
193
+ self._reset_camera()
194
+
195
+ if self.verbose:
196
+ steps = trange(self.n_steps, desc="CFM inference")
197
+ else:
198
+ steps = range(self.n_steps)
199
+
200
+ ψt = ψ0
201
+
202
+ for i in steps:
203
+ dt = ts[i + 1] - ts[i]
204
+ t = ts[i]
205
+ self._maybe_camera_snap(ψt=ψt, t=t)
206
+ ψt = self._step(t=t, ψt=ψt, dt=dt, f=f)
207
+
208
+ self._maybe_camera_snap(ψt=ψt, t=ts[-1])
209
+
210
+ ψ1 = ψt
211
+ del ψt
212
+
213
+ self._maybe_dump_camera()
214
+
215
+ return ψ1
216
+
217
+ def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
218
+ return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1)
219
+
220
+
221
+ class SinusodialTimeEmbedding(nn.Module):
222
+ def __init__(self, d_embed):
223
+ super().__init__()
224
+ self.d_embed = d_embed
225
+ assert d_embed % 2 == 0
226
+
227
+ def forward(self, t):
228
+ t = t.unsqueeze(-1) # ... 1
229
+ p = torch.linspace(0, 4, self.d_embed // 2).to(t)
230
+ while p.dim() < t.dim():
231
+ p = p.unsqueeze(0) # ... d/2
232
+ sin = torch.sin(t * 10**p)
233
+ cos = torch.cos(t * 10**p)
234
+ return torch.cat([sin, cos], dim=-1)
235
+
236
+
237
+ @dataclass(eq=False)
238
+ class CFM(nn.Module):
239
+ """
240
+ This mixin is for general diffusion models.
241
+
242
+ ψ0 stands for the gaussian noise, and ψ1 is the data point.
243
+
244
+ Here we follow the CFM style:
245
+ The generation process (reverse process) is from t=0 to t=1.
246
+ The forward process is from t=1 to t=0.
247
+ """
248
+
249
+ cond_dim: int
250
+ output_dim: int
251
+ time_emb_dim: int = 128
252
+ viz_name: str = "cfm"
253
+ solver_nfe: int = 32
254
+ solver_method: str = "midpoint"
255
+ time_mapping_divisor: int = 4
256
+
257
+ def __post_init__(self):
258
+ super().__init__()
259
+ self.solver = Solver(
260
+ viz_name=self.viz_name,
261
+ viz_every=1,
262
+ nfe=self.solver_nfe,
263
+ method=self.solver_method,
264
+ time_mapping_divisor=self.time_mapping_divisor,
265
+ )
266
+ self.emb = SinusodialTimeEmbedding(self.time_emb_dim)
267
+ self.net = WN(
268
+ input_dim=self.output_dim,
269
+ output_dim=self.output_dim,
270
+ local_dim=self.cond_dim,
271
+ global_dim=self.time_emb_dim,
272
+ )
273
+
274
+ def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
275
+ """
276
+ Perturb ψ1 to ψt.
277
+ """
278
+ raise NotImplementedError
279
+
280
+ def _sample_ψ0(self, x: Tensor):
281
+ """
282
+ Args:
283
+ x: (b c t), which implies the shape of ψ0
284
+ """
285
+ shape = list(x.shape)
286
+ shape[1] = self.output_dim
287
+ if self.training:
288
+ g = None
289
+ else:
290
+ g = torch.Generator(device=x.device)
291
+ g.manual_seed(0) # deterministic sampling during eval
292
+ ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g)
293
+ return ψ0
294
+
295
+ @property
296
+ def sigma(self):
297
+ return 1e-4
298
+
299
+ def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor):
300
+ """
301
+ Eq (22)
302
+ """
303
+ while t.dim() < ψ1.dim():
304
+ t = t.unsqueeze(-1)
305
+ μ = t * ψ1 + (1 - t) * ψ0
306
+ return μ + torch.randn_like(μ) * self.sigma
307
+
308
+ def _to_u(self, *, ψ1, ψ0: Tensor):
309
+ """
310
+ Eq (21)
311
+ """
312
+ return ψ1 - ψ0
313
+
314
+ def _to_v(self, *, ψt, x, t: float | Tensor):
315
+ """
316
+ Args:
317
+ ψt: (b c t)
318
+ x: (b c t)
319
+ t: (b)
320
+ Returns:
321
+ v: (b c t)
322
+ """
323
+ if isinstance(t, (float, int)):
324
+ t = torch.full(ψt.shape[:1], t).to(ψt)
325
+ t = t.clamp(0, 1) # [0, 1)
326
+ g = self.emb(t) # (b d)
327
+ v = self.net(ψt, l=x, g=g)
328
+ return v
329
+
330
+ def compute_losses(self, x, y, ψ0) -> dict:
331
+ """
332
+ Args:
333
+ x: (b c t)
334
+ y: (b c t)
335
+ Returns:
336
+ losses: dict
337
+ """
338
+ t = torch.rand(len(x), device=x.device, dtype=x.dtype)
339
+ t = self.solver.time_mapping(t)
340
+
341
+ if ψ0 is None:
342
+ ψ0 = self._sample_ψ0(x)
343
+
344
+ ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0)
345
+
346
+ v = self._to_v(ψt=ψt, t=t, x=x)
347
+ u = self._to_u(ψ1=y, ψ0=ψ0)
348
+
349
+ losses = dict(l1=F.l1_loss(v, u))
350
+
351
+ return losses
352
+
353
+ @torch.inference_mode()
354
+ def sample(self, x, ψ0=None, t0=0.0):
355
+ """
356
+ Args:
357
+ x: (b c t)
358
+ Returns:
359
+ y: (b ... t)
360
+ """
361
+ if ψ0 is None:
362
+ ψ0 = self._sample_ψ0(x)
363
+ f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x)
364
+ ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
365
+ return ψ1
366
+
367
+ def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
368
+ if y is None:
369
+ y = self.sample(x, ψ0=ψ0, t0=t0)
370
+ else:
371
+ self.losses = self.compute_losses(x, y, ψ0=ψ0)
372
+ return y
resemble-enhance/resemble_enhance/enhancer/lcfm/irmae.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import dataclass
3
+
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import Tensor, nn
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+ from ...common import Normalizer
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ @dataclass
15
+ class IRMAEOutput:
16
+ latent: Tensor # latent vector
17
+ decoded: Tensor | None # decoder output, include extra dim
18
+
19
+
20
+ class ResBlock(nn.Sequential):
21
+ def __init__(self, channels, dilations=[1, 2, 4, 8]):
22
+ wn = weight_norm
23
+ super().__init__(
24
+ nn.GroupNorm(32, channels),
25
+ nn.GELU(),
26
+ wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
27
+ nn.GroupNorm(32, channels),
28
+ nn.GELU(),
29
+ wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
30
+ nn.GroupNorm(32, channels),
31
+ nn.GELU(),
32
+ wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
33
+ nn.GroupNorm(32, channels),
34
+ nn.GELU(),
35
+ wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
36
+ )
37
+
38
+ def forward(self, x: Tensor):
39
+ return x + super().forward(x)
40
+
41
+
42
+ class IRMAE(nn.Module):
43
+ def __init__(
44
+ self,
45
+ input_dim,
46
+ output_dim,
47
+ latent_dim,
48
+ hidden_dim=1024,
49
+ num_irms=4,
50
+ ):
51
+ """
52
+ Args:
53
+ input_dim: input dimension
54
+ output_dim: output dimension
55
+ latent_dim: latent dimension
56
+ hidden_dim: hidden layer dimension
57
+ num_irm_matrics: number of implicit rank minimization matrices
58
+ norm: normalization layer
59
+ """
60
+ self.input_dim = input_dim
61
+ super().__init__()
62
+
63
+ self.encoder = nn.Sequential(
64
+ nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
65
+ *[ResBlock(hidden_dim) for _ in range(4)],
66
+ # Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
67
+ *[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)],
68
+ nn.Tanh(),
69
+ )
70
+
71
+ self.decoder = nn.Sequential(
72
+ nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
73
+ *[ResBlock(hidden_dim) for _ in range(4)],
74
+ nn.Conv1d(hidden_dim, output_dim, 1),
75
+ )
76
+
77
+ self.head = nn.Sequential(
78
+ nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
79
+ nn.GELU(),
80
+ nn.Conv1d(hidden_dim, input_dim, 1),
81
+ )
82
+
83
+ self.estimator = Normalizer()
84
+
85
+ def encode(self, x):
86
+ """
87
+ Args:
88
+ x: (b c t) tensor
89
+ """
90
+ z = self.encoder(x) # (b c t)
91
+ _ = self.estimator(z) # Estimate the glboal mean and std of z
92
+ self.stats = {}
93
+ self.stats["z_mean"] = z.mean().item()
94
+ self.stats["z_std"] = z.std().item()
95
+ self.stats["z_abs_68"] = z.abs().quantile(0.6827).item()
96
+ self.stats["z_abs_95"] = z.abs().quantile(0.9545).item()
97
+ self.stats["z_abs_99"] = z.abs().quantile(0.9973).item()
98
+ return z
99
+
100
+ def decode(self, z):
101
+ """
102
+ Args:
103
+ z: (b c t) tensor
104
+ """
105
+ return self.decoder(z)
106
+
107
+ def forward(self, x, skip_decoding=False):
108
+ """
109
+ Args:
110
+ x: (b c t) tensor
111
+ skip_decoding: if True, skip the decoding step
112
+ """
113
+ z = self.encode(x) # q(z|x)
114
+
115
+ if skip_decoding:
116
+ # This speeds up the training in cfm only mode
117
+ decoded = None
118
+ else:
119
+ decoded = self.decode(z) # p(x|z)
120
+ predicted = self.head(decoded)
121
+ self.losses = dict(mse=F.mse_loss(predicted, x))
122
+
123
+ return IRMAEOutput(latent=z, decoded=decoded)
resemble-enhance/resemble_enhance/enhancer/lcfm/lcfm.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from enum import Enum
3
+
4
+ import matplotlib.pyplot as plt
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor, nn
8
+
9
+ from .cfm import CFM
10
+ from .irmae import IRMAE, IRMAEOutput
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def freeze_(module):
16
+ for p in module.parameters():
17
+ p.requires_grad_(False)
18
+
19
+
20
+ class LCFM(nn.Module):
21
+ class Mode(Enum):
22
+ AE = "ae"
23
+ CFM = "cfm"
24
+
25
+ def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0):
26
+ super().__init__()
27
+ self.ae = ae
28
+ self.cfm = cfm
29
+ self.z_scale = z_scale
30
+ self._mode = None
31
+ self._eval_tau = 0.5
32
+
33
+ @property
34
+ def mode(self):
35
+ return self._mode
36
+
37
+ def set_mode_(self, mode):
38
+ mode = self.Mode(mode)
39
+ self._mode = mode
40
+
41
+ if mode == mode.AE:
42
+ freeze_(self.cfm)
43
+ logger.info("Freeze cfm")
44
+ elif mode == mode.CFM:
45
+ freeze_(self.ae)
46
+ logger.info("Freeze ae (encoder and decoder)")
47
+ else:
48
+ raise ValueError(f"Unknown training mode: {mode}")
49
+
50
+ def get_running_train_loop(self):
51
+ try:
52
+ # Lazy import
53
+ from ...utils.train_loop import TrainLoop
54
+
55
+ return TrainLoop.get_running_loop()
56
+ except ImportError:
57
+ return None
58
+
59
+ @property
60
+ def global_step(self):
61
+ loop = self.get_running_train_loop()
62
+ if loop is None:
63
+ return None
64
+ return loop.global_step
65
+
66
+ @torch.no_grad()
67
+ def _visualize(self, x, y, y_):
68
+ loop = self.get_running_train_loop()
69
+ if loop is None:
70
+ return
71
+
72
+ plt.subplot(221)
73
+ plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
74
+ plt.title("GT")
75
+
76
+ plt.subplot(222)
77
+ y_ = y_[:, : y.shape[1]]
78
+ plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
79
+ plt.title("Posterior")
80
+
81
+ plt.subplot(223)
82
+ z_ = self.cfm(x)
83
+ y__ = self.ae.decode(z_)
84
+ y__ = y__[:, : y.shape[1]]
85
+ plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
86
+ plt.title("C-Prior")
87
+ del y__
88
+
89
+ plt.subplot(224)
90
+ z_ = torch.randn_like(z_)
91
+ y__ = self.ae.decode(z_)
92
+ y__ = y__[:, : y.shape[1]]
93
+ plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
94
+ plt.title("Prior")
95
+ del z_, y__
96
+
97
+ path = loop.make_current_step_viz_path("recon", ".png")
98
+ path.parent.mkdir(exist_ok=True, parents=True)
99
+ plt.tight_layout()
100
+ plt.savefig(path, dpi=500)
101
+ plt.close()
102
+
103
+ def _scale(self, z: Tensor):
104
+ return z * self.z_scale
105
+
106
+ def _unscale(self, z: Tensor):
107
+ return z / self.z_scale
108
+
109
+ def eval_tau_(self, tau):
110
+ self._eval_tau = tau
111
+
112
+ def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None):
113
+ """
114
+ Args:
115
+ x: (b d t), condition mel
116
+ y: (b d t), target mel
117
+ ψ0: (b d t), starting mel
118
+ """
119
+ if self.mode == self.Mode.CFM:
120
+ self.ae.eval() # Always set to eval when training cfm
121
+
122
+ if ψ0 is not None:
123
+ ψ0 = self._scale(self.ae.encode(ψ0))
124
+ if self.training:
125
+ tau = torch.rand_like(ψ0[:, :1, :1])
126
+ else:
127
+ tau = self._eval_tau
128
+ ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0
129
+
130
+ if y is None:
131
+ if self.mode == self.Mode.AE:
132
+ with torch.no_grad():
133
+ training = self.ae.training
134
+ self.ae.eval()
135
+ z = self.ae.encode(x)
136
+ self.ae.train(training)
137
+ else:
138
+ z = self._unscale(self.cfm(x, ψ0=ψ0))
139
+
140
+ h = self.ae.decode(z)
141
+ else:
142
+ ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)
143
+
144
+ if self.mode == self.Mode.CFM:
145
+ _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
146
+
147
+ h = ae_output.decoded
148
+
149
+ if h is not None and self.global_step is not None and self.global_step % 100 == 0:
150
+ self._visualize(x[:1], y[:1], h[:1])
151
+
152
+ return h
resemble-enhance/resemble_enhance/enhancer/lcfm/wn.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ @torch.jit.script
11
+ def _fused_tanh_sigmoid(h):
12
+ a, b = h.chunk(2, dim=1)
13
+ h = a.tanh() * b.sigmoid()
14
+ return h
15
+
16
+
17
+ class WNLayer(nn.Module):
18
+ """
19
+ A DiffWave-like WN
20
+ """
21
+
22
+ def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
23
+ super().__init__()
24
+
25
+ local_output_dim = hidden_dim * 2
26
+
27
+ if global_dim is not None:
28
+ self.gconv = nn.Conv1d(global_dim, hidden_dim, 1)
29
+
30
+ if local_dim is not None:
31
+ self.lconv = nn.Conv1d(local_dim, local_output_dim, 1)
32
+
33
+ self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same")
34
+
35
+ self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1)
36
+
37
+ def forward(self, z, l, g):
38
+ identity = z
39
+
40
+ if g is not None:
41
+ if g.dim() == 2:
42
+ g = g.unsqueeze(-1)
43
+ z = z + self.gconv(g)
44
+
45
+ z = self.dconv(z)
46
+
47
+ if l is not None:
48
+ z = z + self.lconv(l)
49
+
50
+ z = _fused_tanh_sigmoid(z)
51
+
52
+ h = self.out(z)
53
+
54
+ z, s = h.chunk(2, dim=1)
55
+
56
+ o = (z + identity) / math.sqrt(2)
57
+
58
+ return o, s
59
+
60
+
61
+ class WN(nn.Module):
62
+ def __init__(
63
+ self,
64
+ input_dim,
65
+ output_dim,
66
+ local_dim=None,
67
+ global_dim=None,
68
+ n_layers=30,
69
+ kernel_size=3,
70
+ dilation_cycle=5,
71
+ hidden_dim=512,
72
+ ):
73
+ super().__init__()
74
+ assert kernel_size % 2 == 1
75
+ assert hidden_dim % 2 == 0
76
+
77
+ self.input_dim = input_dim
78
+ self.hidden_dim = hidden_dim
79
+ self.local_dim = local_dim
80
+ self.global_dim = global_dim
81
+
82
+ self.start = nn.Conv1d(input_dim, hidden_dim, 1)
83
+ if local_dim is not None:
84
+ self.local_norm = nn.InstanceNorm1d(local_dim)
85
+
86
+ self.layers = nn.ModuleList(
87
+ [
88
+ WNLayer(
89
+ hidden_dim=hidden_dim,
90
+ local_dim=local_dim,
91
+ global_dim=global_dim,
92
+ kernel_size=kernel_size,
93
+ dilation=2 ** (i % dilation_cycle),
94
+ )
95
+ for i in range(n_layers)
96
+ ]
97
+ )
98
+
99
+ self.end = nn.Conv1d(hidden_dim, output_dim, 1)
100
+
101
+ def forward(self, z, l=None, g=None):
102
+ """
103
+ Args:
104
+ z: input (b c t)
105
+ l: local condition (b c t)
106
+ g: global condition (b d)
107
+ """
108
+ z = self.start(z)
109
+
110
+ if l is not None:
111
+ l = self.local_norm(l)
112
+
113
+ # Skips
114
+ s_list = []
115
+
116
+ for layer in self.layers:
117
+ z, s = layer(z, l, g)
118
+ s_list.append(s)
119
+
120
+ s_list = torch.stack(s_list, dim=0).sum(dim=0)
121
+ s_list = s_list / math.sqrt(len(self.layers))
122
+
123
+ o = self.end(s_list)
124
+
125
+ return o
126
+
127
+ def summarize(self, length=100):
128
+ from ptflops import get_model_complexity_info
129
+
130
+ x = torch.randn(1, self.input_dim, length)
131
+
132
+ macs, params = get_model_complexity_info(
133
+ self,
134
+ (self.input_dim, length),
135
+ as_strings=True,
136
+ print_per_layer_stat=True,
137
+ verbose=True,
138
+ )
139
+
140
+ print(f"Input shape: {x.shape}")
141
+ print(f"Computational complexity: {macs}")
142
+ print(f"Number of parameters: {params}")
143
+
144
+
145
+ if __name__ == "__main__":
146
+ model = WN(input_dim=64, output_dim=64)
147
+ model.summarize()
resemble-enhance/resemble_enhance/enhancer/train.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+ from functools import partial
4
+ from pathlib import Path
5
+
6
+ import soundfile
7
+ import torch
8
+ from deepspeed import DeepSpeedConfig
9
+ from torch import Tensor
10
+ from tqdm import tqdm
11
+
12
+ from ..data import create_dataloaders, mix_fg_bg
13
+ from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
14
+ from ..utils.distributed import is_local_leader
15
+ from .enhancer import Enhancer
16
+ from .hparams import HParams
17
+ from .univnet.discriminator import Discriminator
18
+
19
+
20
+ def load_G(run_dir: Path, hp: HParams | None = None, training=True):
21
+ if hp is None:
22
+ hp = HParams.load(run_dir)
23
+ assert isinstance(hp, HParams)
24
+ model = Enhancer(hp)
25
+ engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "G")
26
+ if training:
27
+ engine.load_checkpoint()
28
+ else:
29
+ engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False)
30
+ return engine
31
+
32
+
33
+ def load_D(run_dir: Path, hp: HParams):
34
+ if hp is None:
35
+ hp = HParams.load(run_dir)
36
+ assert isinstance(hp, HParams)
37
+ model = Discriminator(hp)
38
+ engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "D")
39
+ engine.load_checkpoint()
40
+ return engine
41
+
42
+
43
+ def save_wav(path: Path, wav: Tensor, rate: int):
44
+ wav = wav.detach().cpu().numpy()
45
+ soundfile.write(path, wav, samplerate=rate)
46
+
47
+
48
+ def main():
49
+ parser = argparse.ArgumentParser()
50
+ parser.add_argument("run_dir", type=Path)
51
+ parser.add_argument("--yaml", type=Path, default=None)
52
+ parser.add_argument("--device", type=str, default="cuda")
53
+ args = parser.parse_args()
54
+
55
+ setup_logging(args.run_dir)
56
+ hp = HParams.load(args.run_dir, yaml=args.yaml)
57
+
58
+ if is_local_leader():
59
+ hp.save_if_not_exists(args.run_dir)
60
+ hp.print()
61
+
62
+ train_dl, val_dl = create_dataloaders(hp, mode="enhancer")
63
+
64
+ def feed_G(engine: Engine, batch: dict[str, Tensor]):
65
+ if hp.lcfm_training_mode == "ae":
66
+ pred = engine(batch["fg_wavs"], batch["fg_wavs"])
67
+ elif hp.lcfm_training_mode == "cfm":
68
+ alpha_fn = lambda: random.uniform(*hp.mix_alpha_range)
69
+ mx_dwavs = mix_fg_bg(batch["fg_dwavs"], batch["bg_dwavs"], alpha=alpha_fn)
70
+ pred = engine(mx_dwavs, batch["fg_wavs"], batch["fg_dwavs"])
71
+ else:
72
+ raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
73
+ losses = engine.gather_attribute("losses")
74
+ return pred, losses
75
+
76
+ def feed_D(engine: Engine, batch: dict | None, fake: Tensor):
77
+ if batch is None:
78
+ losses = engine(fake=fake)
79
+ else:
80
+ losses = engine(fake=fake, real=batch["fg_wavs"])
81
+ return losses
82
+
83
+ @torch.no_grad()
84
+ def eval_fn(engine: Engine, eval_dir, n_saved=10):
85
+ assert isinstance(hp, HParams)
86
+
87
+ model = engine.module
88
+ model.eval()
89
+
90
+ step = engine.global_step
91
+
92
+ for i, batch in enumerate(tqdm(val_dl), 1):
93
+ batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch)
94
+
95
+ fg_wavs = batch["fg_wavs"] # 1 t
96
+
97
+ if hp.lcfm_training_mode == "ae":
98
+ in_dwavs = fg_wavs
99
+ elif hp.lcfm_training_mode == "cfm":
100
+ in_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"])
101
+ else:
102
+ raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
103
+
104
+ pred_fg_wavs = model(in_dwavs) # 1 t
105
+
106
+ in_mels = model.to_mel(in_dwavs) # 1 c t
107
+ fg_mels = model.to_mel(fg_wavs) # 1 c t
108
+ pred_fg_mels = model.to_mel(pred_fg_wavs) # 1 c t
109
+
110
+ rate = model.hp.wav_rate
111
+ get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}"
112
+
113
+ save_wav(get_path("_input.wav"), in_dwavs[0], rate=rate)
114
+ save_wav(get_path("_predict.wav"), pred_fg_wavs[0], rate=rate)
115
+ save_wav(get_path("_target.wav"), fg_wavs[0], rate=rate)
116
+
117
+ save_mels(
118
+ get_path(".png"),
119
+ cond_mel=in_mels[0].cpu().numpy(),
120
+ pred_mel=pred_fg_mels[0].cpu().numpy(),
121
+ targ_mel=fg_mels[0].cpu().numpy(),
122
+ )
123
+
124
+ if i >= n_saved:
125
+ break
126
+
127
+ train_loop = TrainLoop(
128
+ run_dir=args.run_dir,
129
+ train_dl=train_dl,
130
+ load_G=partial(load_G, hp=hp),
131
+ load_D=partial(load_D, hp=hp),
132
+ device=args.device,
133
+ feed_G=feed_G,
134
+ feed_D=feed_D,
135
+ eval_fn=eval_fn,
136
+ gan_training_start_step=hp.gan_training_start_step,
137
+ )
138
+
139
+ train_loop.run(max_steps=hp.max_steps)
140
+
141
+
142
+ if __name__ == "__main__":
143
+ main()
resemble-enhance/resemble_enhance/enhancer/univnet/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .univnet import UnivNet
resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ from .filter import *
5
+ from .resample import *
resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import math
8
+
9
+ if 'sinc' in dir(torch):
10
+ sinc = torch.sinc
11
+ else:
12
+ # This code is adopted from adefossez's julius.core.sinc under the MIT License
13
+ # https://adefossez.github.io/julius/julius/core.html
14
+ # LICENSE is in incl_licenses directory.
15
+ def sinc(x: torch.Tensor):
16
+ """
17
+ Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18
+ __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19
+ """
20
+ return torch.where(x == 0,
21
+ torch.tensor(1., device=x.device, dtype=x.dtype),
22
+ torch.sin(math.pi * x) / math.pi / x)
23
+
24
+
25
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26
+ # https://adefossez.github.io/julius/julius/lowpass.html
27
+ # LICENSE is in incl_licenses directory.
28
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
29
+ even = (kernel_size % 2 == 0)
30
+ half_size = kernel_size // 2
31
+
32
+ #For kaiser window
33
+ delta_f = 4 * half_width
34
+ A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
35
+ if A > 50.:
36
+ beta = 0.1102 * (A - 8.7)
37
+ elif A >= 21.:
38
+ beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
39
+ else:
40
+ beta = 0.
41
+ window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
42
+
43
+ # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
44
+ if even:
45
+ time = (torch.arange(-half_size, half_size) + 0.5)
46
+ else:
47
+ time = torch.arange(kernel_size) - half_size
48
+ if cutoff == 0:
49
+ filter_ = torch.zeros_like(time)
50
+ else:
51
+ filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
52
+ # Normalize filter to have sum = 1, otherwise we will have a small leakage
53
+ # of the constant component in the input signal.
54
+ filter_ /= filter_.sum()
55
+ filter = filter_.view(1, 1, kernel_size)
56
+
57
+ return filter
58
+
59
+
60
+ class LowPassFilter1d(nn.Module):
61
+ def __init__(self,
62
+ cutoff=0.5,
63
+ half_width=0.6,
64
+ stride: int = 1,
65
+ padding: bool = True,
66
+ padding_mode: str = 'replicate',
67
+ kernel_size: int = 12):
68
+ # kernel_size should be even number for stylegan3 setup,
69
+ # in this implementation, odd number is also possible.
70
+ super().__init__()
71
+ if cutoff < -0.:
72
+ raise ValueError("Minimum cutoff must be larger than zero.")
73
+ if cutoff > 0.5:
74
+ raise ValueError("A cutoff above 0.5 does not make sense.")
75
+ self.kernel_size = kernel_size
76
+ self.even = (kernel_size % 2 == 0)
77
+ self.pad_left = kernel_size // 2 - int(self.even)
78
+ self.pad_right = kernel_size // 2
79
+ self.stride = stride
80
+ self.padding = padding
81
+ self.padding_mode = padding_mode
82
+ filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
83
+ self.register_buffer("filter", filter)
84
+
85
+ #input [B, C, T]
86
+ def forward(self, x):
87
+ _, C, _ = x.shape
88
+
89
+ if self.padding:
90
+ x = F.pad(x, (self.pad_left, self.pad_right),
91
+ mode=self.padding_mode)
92
+ out = F.conv1d(x, self.filter.expand(C, -1, -1),
93
+ stride=self.stride, groups=C)
94
+
95
+ return out
resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from .filter import LowPassFilter1d
7
+ from .filter import kaiser_sinc_filter1d
8
+
9
+
10
+ class UpSample1d(nn.Module):
11
+ def __init__(self, ratio=2, kernel_size=None):
12
+ super().__init__()
13
+ self.ratio = ratio
14
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15
+ self.stride = ratio
16
+ self.pad = self.kernel_size // ratio - 1
17
+ self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18
+ self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19
+ filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
20
+ half_width=0.6 / ratio,
21
+ kernel_size=self.kernel_size)
22
+ self.register_buffer("filter", filter)
23
+
24
+ # x: [B, C, T]
25
+ def forward(self, x):
26
+ _, C, _ = x.shape
27
+
28
+ x = F.pad(x, (self.pad, self.pad), mode='replicate')
29
+ x = self.ratio * F.conv_transpose1d(
30
+ x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
31
+ x = x[..., self.pad_left:-self.pad_right]
32
+
33
+ return x
34
+
35
+
36
+ class DownSample1d(nn.Module):
37
+ def __init__(self, ratio=2, kernel_size=None):
38
+ super().__init__()
39
+ self.ratio = ratio
40
+ self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
41
+ self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
42
+ half_width=0.6 / ratio,
43
+ stride=ratio,
44
+ kernel_size=self.kernel_size)
45
+
46
+ def forward(self, x):
47
+ xx = self.lowpass(x)
48
+
49
+ return xx
resemble-enhance/resemble_enhance/enhancer/univnet/amp.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Refer from https://github.com/NVIDIA/BigVGAN
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import nn
8
+ from torch.nn.utils.parametrizations import weight_norm
9
+
10
+ from .alias_free_torch import DownSample1d, UpSample1d
11
+
12
+
13
+ class SnakeBeta(nn.Module):
14
+ """
15
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
16
+ Shape:
17
+ - Input: (B, C, T)
18
+ - Output: (B, C, T), same shape as the input
19
+ Parameters:
20
+ - alpha - trainable parameter that controls frequency
21
+ - beta - trainable parameter that controls magnitude
22
+ References:
23
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
24
+ https://arxiv.org/abs/2006.08195
25
+ Examples:
26
+ >>> a1 = snakebeta(256)
27
+ >>> x = torch.randn(256)
28
+ >>> x = a1(x)
29
+ """
30
+
31
+ def __init__(self, in_features, alpha=1.0, clamp=(1e-2, 50)):
32
+ """
33
+ Initialization.
34
+ INPUT:
35
+ - in_features: shape of the input
36
+ - alpha - trainable parameter that controls frequency
37
+ - beta - trainable parameter that controls magnitude
38
+ alpha is initialized to 1 by default, higher values = higher-frequency.
39
+ beta is initialized to 1 by default, higher values = higher-magnitude.
40
+ alpha will be trained along with the rest of your model.
41
+ """
42
+ super().__init__()
43
+ self.in_features = in_features
44
+ self.log_alpha = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
45
+ self.log_beta = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
46
+ self.clamp = clamp
47
+
48
+ def forward(self, x):
49
+ """
50
+ Forward pass of the function.
51
+ Applies the function to the input elementwise.
52
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
53
+ """
54
+ alpha = self.log_alpha.exp().clamp(*self.clamp)
55
+ alpha = alpha[None, :, None]
56
+
57
+ beta = self.log_beta.exp().clamp(*self.clamp)
58
+ beta = beta[None, :, None]
59
+
60
+ x = x + (1.0 / beta) * (x * alpha).sin().pow(2)
61
+
62
+ return x
63
+
64
+
65
+ class UpActDown(nn.Module):
66
+ def __init__(
67
+ self,
68
+ act,
69
+ up_ratio: int = 2,
70
+ down_ratio: int = 2,
71
+ up_kernel_size: int = 12,
72
+ down_kernel_size: int = 12,
73
+ ):
74
+ super().__init__()
75
+ self.up_ratio = up_ratio
76
+ self.down_ratio = down_ratio
77
+ self.act = act
78
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
79
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
80
+
81
+ def forward(self, x):
82
+ # x: [B,C,T]
83
+ x = self.upsample(x)
84
+ x = self.act(x)
85
+ x = self.downsample(x)
86
+ return x
87
+
88
+
89
+ class AMPBlock(nn.Sequential):
90
+ def __init__(self, channels, *, kernel_size=3, dilations=(1, 3, 5)):
91
+ super().__init__(*(self._make_layer(channels, kernel_size, d) for d in dilations))
92
+
93
+ def _make_layer(self, channels, kernel_size, dilation):
94
+ return nn.Sequential(
95
+ weight_norm(nn.Conv1d(channels, channels, kernel_size, dilation=dilation, padding="same")),
96
+ UpActDown(act=SnakeBeta(channels)),
97
+ weight_norm(nn.Conv1d(channels, channels, kernel_size, padding="same")),
98
+ )
99
+
100
+ def forward(self, x):
101
+ return x + super().forward(x)
resemble-enhance/resemble_enhance/enhancer/univnet/discriminator.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import Tensor, nn
6
+ from torch.nn.utils.parametrizations import weight_norm
7
+
8
+ from ..hparams import HParams
9
+ from .mrstft import get_stft_cfgs
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ class PeriodNetwork(nn.Module):
15
+ def __init__(self, period):
16
+ super().__init__()
17
+ self.period = period
18
+ wn = weight_norm
19
+ self.convs = nn.ModuleList(
20
+ [
21
+ wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))),
22
+ wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))),
23
+ wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
24
+ wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
25
+ wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))),
26
+ ]
27
+ )
28
+ self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
29
+
30
+ def forward(self, x):
31
+ """
32
+ Args:
33
+ x: [B, 1, T]
34
+ """
35
+ assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}."
36
+
37
+ # 1d to 2d
38
+ b, c, t = x.shape
39
+ if t % self.period != 0: # pad first
40
+ n_pad = self.period - (t % self.period)
41
+ x = F.pad(x, (0, n_pad), "reflect")
42
+ t = t + n_pad
43
+ x = x.view(b, c, t // self.period, self.period)
44
+
45
+ for l in self.convs:
46
+ x = l(x)
47
+ x = F.leaky_relu(x, 0.2)
48
+ x = self.conv_post(x)
49
+ x = torch.flatten(x, 1, -1)
50
+
51
+ return x
52
+
53
+
54
+ class SpecNetwork(nn.Module):
55
+ def __init__(self, stft_cfg: dict):
56
+ super().__init__()
57
+ wn = weight_norm
58
+ self.stft_cfg = stft_cfg
59
+ self.convs = nn.ModuleList(
60
+ [
61
+ wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
62
+ wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
63
+ wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
64
+ wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
65
+ wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
66
+ ]
67
+ )
68
+ self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
69
+
70
+ def forward(self, x):
71
+ """
72
+ Args:
73
+ x: [B, 1, T]
74
+ """
75
+ x = self.spectrogram(x)
76
+ x = x.unsqueeze(1)
77
+ for l in self.convs:
78
+ x = l(x)
79
+ x = F.leaky_relu(x, 0.2)
80
+ x = self.conv_post(x)
81
+ x = x.flatten(1, -1)
82
+ return x
83
+
84
+ def spectrogram(self, x):
85
+ """
86
+ Args:
87
+ x: [B, 1, T]
88
+ """
89
+ x = x.squeeze(1)
90
+ dtype = x.dtype
91
+ stft_cfg = dict(self.stft_cfg)
92
+ x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg)
93
+ mag = x.norm(p=2, dim=-1) # [B, F, TT]
94
+ mag = mag.to(dtype) # [B, F, TT]
95
+ return mag
96
+
97
+
98
+ class MD(nn.ModuleList):
99
+ def __init__(self, l: list):
100
+ super().__init__([self._create_network(x) for x in l])
101
+ self._loss_type = None
102
+
103
+ def loss_type_(self, loss_type):
104
+ self._loss_type = loss_type
105
+
106
+ def _create_network(self, _):
107
+ raise NotImplementedError
108
+
109
+ def _forward_each(self, d, x, y):
110
+ assert self._loss_type is not None, "loss_type is not set."
111
+ loss_type = self._loss_type
112
+
113
+ if loss_type == "hinge":
114
+ if y == 0:
115
+ # d(x) should be small -> -1
116
+ loss = F.relu(1 + d(x)).mean()
117
+ elif y == 1:
118
+ # d(x) should be large -> 1
119
+ loss = F.relu(1 - d(x)).mean()
120
+ else:
121
+ raise ValueError(f"Invalid y: {y}")
122
+ elif loss_type == "wgan":
123
+ if y == 0:
124
+ loss = d(x).mean()
125
+ elif y == 1:
126
+ loss = -d(x).mean()
127
+ else:
128
+ raise ValueError(f"Invalid y: {y}")
129
+ else:
130
+ raise ValueError(f"Invalid loss_type: {loss_type}")
131
+
132
+ return loss
133
+
134
+ def forward(self, x, y) -> Tensor:
135
+ losses = [self._forward_each(d, x, y) for d in self]
136
+ return torch.stack(losses).mean()
137
+
138
+
139
+ class MPD(MD):
140
+ def __init__(self):
141
+ super().__init__([2, 3, 7, 13, 17])
142
+
143
+ def _create_network(self, period):
144
+ return PeriodNetwork(period)
145
+
146
+
147
+ class MRD(MD):
148
+ def __init__(self, stft_cfgs):
149
+ super().__init__(stft_cfgs)
150
+
151
+ def _create_network(self, stft_cfg):
152
+ return SpecNetwork(stft_cfg)
153
+
154
+
155
+ class Discriminator(nn.Module):
156
+ @property
157
+ def wav_rate(self):
158
+ return self.hp.wav_rate
159
+
160
+ def __init__(self, hp: HParams):
161
+ super().__init__()
162
+ self.hp = hp
163
+ self.stft_cfgs = get_stft_cfgs(hp)
164
+ self.mpd = MPD()
165
+ self.mrd = MRD(self.stft_cfgs)
166
+ self.dummy_float: Tensor
167
+ self.register_buffer("dummy_float", torch.zeros(0), persistent=False)
168
+
169
+ def loss_type_(self, loss_type):
170
+ self.mpd.loss_type_(loss_type)
171
+ self.mrd.loss_type_(loss_type)
172
+
173
+ def forward(self, fake, real=None):
174
+ """
175
+ Args:
176
+ fake: [B T]
177
+ real: [B T]
178
+ """
179
+ fake = fake.to(self.dummy_float)
180
+
181
+ if real is None:
182
+ self.loss_type_("wgan")
183
+ else:
184
+ length_difference = (fake.shape[-1] - real.shape[-1]) / real.shape[-1]
185
+ assert length_difference < 0.05, f"length_difference should be smaller than 5%"
186
+
187
+ self.loss_type_("hinge")
188
+ real = real.to(self.dummy_float)
189
+
190
+ fake = fake[..., : real.shape[-1]]
191
+ real = real[..., : fake.shape[-1]]
192
+
193
+ losses = {}
194
+
195
+ assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}."
196
+ assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}."
197
+
198
+ fake = fake.unsqueeze(1)
199
+
200
+ if real is None:
201
+ losses["mpd"] = self.mpd(fake, 1)
202
+ losses["mrd"] = self.mrd(fake, 1)
203
+ else:
204
+ real = real.unsqueeze(1)
205
+ losses["mpd_fake"] = self.mpd(fake, 0)
206
+ losses["mpd_real"] = self.mpd(real, 1)
207
+ losses["mrd_fake"] = self.mrd(fake, 0)
208
+ losses["mrd_real"] = self.mrd(real, 1)
209
+
210
+ return losses
resemble-enhance/resemble_enhance/enhancer/univnet/lvcnet.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ refer from https://github.com/zceng/LVCNet """
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+ from .amp import AMPBlock
10
+
11
+
12
+ class KernelPredictor(torch.nn.Module):
13
+ """Kernel predictor for the location-variable convolutions"""
14
+
15
+ def __init__(
16
+ self,
17
+ cond_channels,
18
+ conv_in_channels,
19
+ conv_out_channels,
20
+ conv_layers,
21
+ conv_kernel_size=3,
22
+ kpnet_hidden_channels=64,
23
+ kpnet_conv_size=3,
24
+ kpnet_dropout=0.0,
25
+ kpnet_nonlinear_activation="LeakyReLU",
26
+ kpnet_nonlinear_activation_params={"negative_slope": 0.1},
27
+ ):
28
+ """
29
+ Args:
30
+ cond_channels (int): number of channel for the conditioning sequence,
31
+ conv_in_channels (int): number of channel for the input sequence,
32
+ conv_out_channels (int): number of channel for the output sequence,
33
+ conv_layers (int): number of layers
34
+ """
35
+ super().__init__()
36
+
37
+ self.conv_in_channels = conv_in_channels
38
+ self.conv_out_channels = conv_out_channels
39
+ self.conv_kernel_size = conv_kernel_size
40
+ self.conv_layers = conv_layers
41
+
42
+ kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
43
+ kpnet_bias_channels = conv_out_channels * conv_layers # l_b
44
+
45
+ self.input_conv = nn.Sequential(
46
+ weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
47
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
48
+ )
49
+
50
+ self.residual_convs = nn.ModuleList()
51
+ padding = (kpnet_conv_size - 1) // 2
52
+ for _ in range(3):
53
+ self.residual_convs.append(
54
+ nn.Sequential(
55
+ nn.Dropout(kpnet_dropout),
56
+ weight_norm(
57
+ nn.Conv1d(
58
+ kpnet_hidden_channels,
59
+ kpnet_hidden_channels,
60
+ kpnet_conv_size,
61
+ padding=padding,
62
+ bias=True,
63
+ )
64
+ ),
65
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
66
+ weight_norm(
67
+ nn.Conv1d(
68
+ kpnet_hidden_channels,
69
+ kpnet_hidden_channels,
70
+ kpnet_conv_size,
71
+ padding=padding,
72
+ bias=True,
73
+ )
74
+ ),
75
+ getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
76
+ )
77
+ )
78
+ self.kernel_conv = weight_norm(
79
+ nn.Conv1d(
80
+ kpnet_hidden_channels,
81
+ kpnet_kernel_channels,
82
+ kpnet_conv_size,
83
+ padding=padding,
84
+ bias=True,
85
+ )
86
+ )
87
+ self.bias_conv = weight_norm(
88
+ nn.Conv1d(
89
+ kpnet_hidden_channels,
90
+ kpnet_bias_channels,
91
+ kpnet_conv_size,
92
+ padding=padding,
93
+ bias=True,
94
+ )
95
+ )
96
+
97
+ def forward(self, c):
98
+ """
99
+ Args:
100
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
101
+ """
102
+ batch, _, cond_length = c.shape
103
+ c = self.input_conv(c)
104
+ for residual_conv in self.residual_convs:
105
+ residual_conv.to(c.device)
106
+ c = c + residual_conv(c)
107
+ k = self.kernel_conv(c)
108
+ b = self.bias_conv(c)
109
+ kernels = k.contiguous().view(
110
+ batch,
111
+ self.conv_layers,
112
+ self.conv_in_channels,
113
+ self.conv_out_channels,
114
+ self.conv_kernel_size,
115
+ cond_length,
116
+ )
117
+ bias = b.contiguous().view(
118
+ batch,
119
+ self.conv_layers,
120
+ self.conv_out_channels,
121
+ cond_length,
122
+ )
123
+
124
+ return kernels, bias
125
+
126
+
127
+ class LVCBlock(torch.nn.Module):
128
+ """the location-variable convolutions"""
129
+
130
+ def __init__(
131
+ self,
132
+ in_channels,
133
+ cond_channels,
134
+ stride,
135
+ dilations=[1, 3, 9, 27],
136
+ lReLU_slope=0.2,
137
+ conv_kernel_size=3,
138
+ cond_hop_length=256,
139
+ kpnet_hidden_channels=64,
140
+ kpnet_conv_size=3,
141
+ kpnet_dropout=0.0,
142
+ add_extra_noise=False,
143
+ downsampling=False,
144
+ ):
145
+ super().__init__()
146
+
147
+ self.add_extra_noise = add_extra_noise
148
+
149
+ self.cond_hop_length = cond_hop_length
150
+ self.conv_layers = len(dilations)
151
+ self.conv_kernel_size = conv_kernel_size
152
+
153
+ self.kernel_predictor = KernelPredictor(
154
+ cond_channels=cond_channels,
155
+ conv_in_channels=in_channels,
156
+ conv_out_channels=2 * in_channels,
157
+ conv_layers=len(dilations),
158
+ conv_kernel_size=conv_kernel_size,
159
+ kpnet_hidden_channels=kpnet_hidden_channels,
160
+ kpnet_conv_size=kpnet_conv_size,
161
+ kpnet_dropout=kpnet_dropout,
162
+ kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
163
+ )
164
+
165
+ if downsampling:
166
+ self.convt_pre = nn.Sequential(
167
+ nn.LeakyReLU(lReLU_slope),
168
+ weight_norm(nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")),
169
+ nn.AvgPool1d(stride, stride),
170
+ )
171
+ else:
172
+ if stride == 1:
173
+ self.convt_pre = nn.Sequential(
174
+ nn.LeakyReLU(lReLU_slope),
175
+ weight_norm(nn.Conv1d(in_channels, in_channels, 1)),
176
+ )
177
+ else:
178
+ self.convt_pre = nn.Sequential(
179
+ nn.LeakyReLU(lReLU_slope),
180
+ weight_norm(
181
+ nn.ConvTranspose1d(
182
+ in_channels,
183
+ in_channels,
184
+ 2 * stride,
185
+ stride=stride,
186
+ padding=stride // 2 + stride % 2,
187
+ output_padding=stride % 2,
188
+ )
189
+ ),
190
+ )
191
+
192
+ self.amp_block = AMPBlock(in_channels)
193
+
194
+ self.conv_blocks = nn.ModuleList()
195
+ for d in dilations:
196
+ self.conv_blocks.append(
197
+ nn.Sequential(
198
+ nn.LeakyReLU(lReLU_slope),
199
+ weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, dilation=d, padding="same")),
200
+ nn.LeakyReLU(lReLU_slope),
201
+ )
202
+ )
203
+
204
+ def forward(self, x, c):
205
+ """forward propagation of the location-variable convolutions.
206
+ Args:
207
+ x (Tensor): the input sequence (batch, in_channels, in_length)
208
+ c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
209
+
210
+ Returns:
211
+ Tensor: the output sequence (batch, in_channels, in_length)
212
+ """
213
+ _, in_channels, _ = x.shape # (B, c_g, L')
214
+
215
+ x = self.convt_pre(x) # (B, c_g, stride * L')
216
+
217
+ # Add one amp block just after the upsampling
218
+ x = self.amp_block(x) # (B, c_g, stride * L')
219
+
220
+ kernels, bias = self.kernel_predictor(c)
221
+
222
+ if self.add_extra_noise:
223
+ # Add extra noise to part of the feature
224
+ a, b = x.chunk(2, dim=1)
225
+ b = b + torch.randn_like(b) * 0.1
226
+ x = torch.cat([a, b], dim=1)
227
+
228
+ for i, conv in enumerate(self.conv_blocks):
229
+ output = conv(x) # (B, c_g, stride * L')
230
+
231
+ k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
232
+ b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
233
+
234
+ output = self.location_variable_convolution(
235
+ output, k, b, hop_size=self.cond_hop_length
236
+ ) # (B, 2 * c_g, stride * L'): LVC
237
+ x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
238
+ output[:, in_channels:, :]
239
+ ) # (B, c_g, stride * L'): GAU
240
+
241
+ return x
242
+
243
+ def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
244
+ """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
245
+ Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
246
+ Args:
247
+ x (Tensor): the input sequence (batch, in_channels, in_length).
248
+ kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
249
+ bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
250
+ dilation (int): the dilation of convolution.
251
+ hop_size (int): the hop_size of the conditioning sequence.
252
+ Returns:
253
+ (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
254
+ """
255
+ batch, _, in_length = x.shape
256
+ batch, _, out_channels, kernel_size, kernel_length = kernel.shape
257
+
258
+ assert in_length == (
259
+ kernel_length * hop_size
260
+ ), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
261
+
262
+ padding = dilation * int((kernel_size - 1) / 2)
263
+ x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
264
+ x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
265
+
266
+ if hop_size < dilation:
267
+ x = F.pad(x, (0, dilation), "constant", 0)
268
+ x = x.unfold(
269
+ 3, dilation, dilation
270
+ ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
271
+ x = x[:, :, :, :, :hop_size]
272
+ x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
273
+ x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
274
+
275
+ o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
276
+ o = o.to(memory_format=torch.channels_last_3d)
277
+ bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
278
+ o = o + bias
279
+ o = o.contiguous().view(batch, out_channels, -1)
280
+
281
+ return o
resemble-enhance/resemble_enhance/enhancer/univnet/mrstft.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2019 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from ..hparams import HParams
12
+
13
+
14
+ def _make_stft_cfg(hop_length, win_length=None):
15
+ if win_length is None:
16
+ win_length = 4 * hop_length
17
+ n_fft = 2 ** (win_length - 1).bit_length()
18
+ return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
19
+
20
+
21
+ def get_stft_cfgs(hp: HParams):
22
+ assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
23
+ return [_make_stft_cfg(h) for h in (100, 256, 512)]
24
+
25
+
26
+ def stft(x, n_fft, hop_length, win_length, window):
27
+ dtype = x.dtype
28
+ x = torch.stft(x.float(), n_fft, hop_length, win_length, window, return_complex=True)
29
+ x = x.abs().to(dtype)
30
+ x = x.transpose(2, 1) # (b f t) -> (b t f)
31
+ return x
32
+
33
+
34
+ class SpectralConvergengeLoss(nn.Module):
35
+ def forward(self, x_mag, y_mag):
36
+ """Calculate forward propagation.
37
+ Args:
38
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
39
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
40
+ Returns:
41
+ Tensor: Spectral convergence loss value.
42
+ """
43
+ return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
44
+
45
+
46
+ class LogSTFTMagnitudeLoss(nn.Module):
47
+ def forward(self, x_mag, y_mag):
48
+ """Calculate forward propagation.
49
+ Args:
50
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
51
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
52
+ Returns:
53
+ Tensor: Log STFT magnitude loss value.
54
+ """
55
+ return F.l1_loss(torch.log1p(x_mag), torch.log1p(y_mag))
56
+
57
+
58
+ class STFTLoss(nn.Module):
59
+ def __init__(self, hp, stft_cfg: dict, window="hann_window"):
60
+ super().__init__()
61
+ self.hp = hp
62
+ self.stft_cfg = stft_cfg
63
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
64
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
65
+ self.register_buffer("window", getattr(torch, window)(stft_cfg["win_length"]), persistent=False)
66
+
67
+ def forward(self, x, y):
68
+ """Calculate forward propagation.
69
+ Args:
70
+ x (Tensor): Predicted signal (B, T).
71
+ y (Tensor): Groundtruth signal (B, T).
72
+ Returns:
73
+ Tensor: Spectral convergence loss value.
74
+ Tensor: Log STFT magnitude loss value.
75
+ """
76
+ stft_cfg = dict(self.stft_cfg)
77
+ x_mag = stft(x, **stft_cfg, window=self.window) # (b t) -> (b t f)
78
+ y_mag = stft(y, **stft_cfg, window=self.window)
79
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
80
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
81
+ return dict(sc=sc_loss, mag=mag_loss)
82
+
83
+
84
+ class MRSTFTLoss(nn.Module):
85
+ def __init__(self, hp: HParams, window="hann_window"):
86
+ """Initialize Multi resolution STFT loss module.
87
+ Args:
88
+ resolutions (list): List of (FFT size, hop size, window length).
89
+ window (str): Window function type.
90
+ """
91
+ super().__init__()
92
+ stft_cfgs = get_stft_cfgs(hp)
93
+ self.stft_losses = nn.ModuleList()
94
+ self.hp = hp
95
+ for c in stft_cfgs:
96
+ self.stft_losses += [STFTLoss(hp, c, window=window)]
97
+
98
+ def forward(self, x, y):
99
+ """Calculate forward propagation.
100
+ Args:
101
+ x (Tensor): Predicted signal (b t).
102
+ y (Tensor): Groundtruth signal (b t).
103
+ Returns:
104
+ Tensor: Multi resolution spectral convergence loss value.
105
+ Tensor: Multi resolution log STFT magnitude loss value.
106
+ """
107
+ assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}."
108
+
109
+ dtype = x.dtype
110
+
111
+ x = x.float()
112
+ y = y.float()
113
+
114
+ # Align length
115
+ x = x[..., : y.shape[-1]]
116
+ y = y[..., : x.shape[-1]]
117
+
118
+ losses = {}
119
+
120
+ for f in self.stft_losses:
121
+ d = f(x, y)
122
+ for k, v in d.items():
123
+ losses.setdefault(k, []).append(v)
124
+
125
+ for k, v in losses.items():
126
+ losses[k] = torch.stack(v, dim=0).mean().to(dtype)
127
+
128
+ return losses
resemble-enhance/resemble_enhance/enhancer/univnet/univnet.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import Tensor, nn
5
+ from torch.nn.utils.parametrizations import weight_norm
6
+
7
+ from ..hparams import HParams
8
+ from .lvcnet import LVCBlock
9
+ from .mrstft import MRSTFTLoss
10
+
11
+
12
+ class UnivNet(nn.Module):
13
+ @property
14
+ def d_noise(self):
15
+ return 128
16
+
17
+ @property
18
+ def strides(self):
19
+ return [7, 5, 4, 3]
20
+
21
+ @property
22
+ def dilations(self):
23
+ return [1, 3, 9, 27]
24
+
25
+ @property
26
+ def nc(self):
27
+ return self.hp.univnet_nc
28
+
29
+ @property
30
+ def scale_factor(self) -> int:
31
+ return self.hp.hop_size
32
+
33
+ def __init__(self, hp: HParams, d_input):
34
+ super().__init__()
35
+ self.d_input = d_input
36
+
37
+ self.hp = hp
38
+
39
+ self.blocks = nn.ModuleList(
40
+ [
41
+ LVCBlock(
42
+ self.nc,
43
+ d_input,
44
+ stride=stride,
45
+ dilations=self.dilations,
46
+ cond_hop_length=hop_length,
47
+ kpnet_conv_size=3,
48
+ )
49
+ for stride, hop_length in zip(self.strides, np.cumprod(self.strides))
50
+ ]
51
+ )
52
+
53
+ self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
54
+
55
+ self.conv_post = nn.Sequential(
56
+ nn.LeakyReLU(0.2),
57
+ weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")),
58
+ nn.Tanh(),
59
+ )
60
+
61
+ self.mrstft = MRSTFTLoss(hp)
62
+
63
+ @property
64
+ def eps(self):
65
+ return 1e-5
66
+
67
+ def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
68
+ """
69
+ Args:
70
+ x: (b c t), acoustic features
71
+ y: (b t), waveform
72
+ Returns:
73
+ z: (b t), waveform
74
+ """
75
+ assert x.ndim == 3, "x must be 3D tensor"
76
+ assert y is None or y.ndim == 2, "y must be 2D tensor"
77
+ assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
78
+ assert npad >= 0, "npad must be positive or zero"
79
+
80
+ x = F.pad(x, (0, npad), "constant", 0)
81
+ z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x)
82
+ z = self.conv_pre(z) # (b c t)
83
+
84
+ for block in self.blocks:
85
+ z = block(z, x) # (b c t)
86
+
87
+ z = self.conv_post(z) # (b 1 t)
88
+ z = z[..., : -self.scale_factor * npad]
89
+ z = z.squeeze(1) # (b t)
90
+
91
+ if y is not None:
92
+ self.losses = self.mrstft(z, y)
93
+
94
+ return z
resemble-enhance/resemble_enhance/hparams.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from dataclasses import asdict, dataclass
3
+ from pathlib import Path
4
+
5
+ from omegaconf import OmegaConf
6
+ from rich.console import Console
7
+ from rich.panel import Panel
8
+ from rich.table import Table
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ console = Console()
13
+
14
+
15
+ def _make_stft_cfg(hop_length, win_length=None):
16
+ if win_length is None:
17
+ win_length = 4 * hop_length
18
+ n_fft = 2 ** (win_length - 1).bit_length()
19
+ return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
20
+
21
+
22
+ def _build_rich_table(rows, columns, title=None):
23
+ table = Table(title=title, header_style=None)
24
+ for column in columns:
25
+ table.add_column(column.capitalize(), justify="left")
26
+ for row in rows:
27
+ table.add_row(*map(str, row))
28
+ return Panel(table, expand=False)
29
+
30
+
31
+ def _rich_print_dict(d, title="Config", key="Key", value="Value"):
32
+ console.print(_build_rich_table(d.items(), [key, value], title))
33
+
34
+
35
+ @dataclass(frozen=True)
36
+ class HParams:
37
+ # Dataset
38
+ fg_dir: Path = Path("data/fg")
39
+ bg_dir: Path = Path("data/bg")
40
+ rir_dir: Path = Path("data/rir")
41
+ load_fg_only: bool = False
42
+ praat_augment_prob: float = 0
43
+
44
+ # Audio settings
45
+ wav_rate: int = 44_100
46
+ n_fft: int = 2048
47
+ win_size: int = 2048
48
+ hop_size: int = 420 # 9.5ms
49
+ num_mels: int = 128
50
+ stft_magnitude_min: float = 1e-4
51
+ preemphasis: float = 0.97
52
+ mix_alpha_range: tuple[float, float] = (0.2, 0.8)
53
+
54
+ # Training
55
+ nj: int = 64
56
+ training_seconds: float = 1.0
57
+ batch_size_per_gpu: int = 16
58
+ min_lr: float = 1e-5
59
+ max_lr: float = 1e-4
60
+ warmup_steps: int = 1000
61
+ max_steps: int = 1_000_000
62
+ gradient_clipping: float = 1.0
63
+
64
+ @property
65
+ def deepspeed_config(self):
66
+ return {
67
+ "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
68
+ "optimizer": {
69
+ "type": "Adam",
70
+ "params": {"lr": float(self.min_lr)},
71
+ },
72
+ "scheduler": {
73
+ "type": "WarmupDecayLR",
74
+ "params": {
75
+ "warmup_min_lr": float(self.min_lr),
76
+ "warmup_max_lr": float(self.max_lr),
77
+ "warmup_num_steps": self.warmup_steps,
78
+ "total_num_steps": self.max_steps,
79
+ "warmup_type": "linear",
80
+ },
81
+ },
82
+ "gradient_clipping": self.gradient_clipping,
83
+ }
84
+
85
+ @property
86
+ def stft_cfgs(self):
87
+ assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
88
+ return [_make_stft_cfg(h) for h in (100, 256, 512)]
89
+
90
+ @classmethod
91
+ def from_yaml(cls, path: Path) -> "HParams":
92
+ logger.info(f"Reading hparams from {path}")
93
+ # First merge to fix types (e.g., str -> Path)
94
+ return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
95
+
96
+ def save_if_not_exists(self, run_dir: Path):
97
+ path = run_dir / "hparams.yaml"
98
+ if path.exists():
99
+ logger.info(f"{path} already exists, not saving")
100
+ return
101
+ path.parent.mkdir(parents=True, exist_ok=True)
102
+ OmegaConf.save(asdict(self), str(path))
103
+
104
+ @classmethod
105
+ def load(cls, run_dir, yaml: Path | None = None):
106
+ hps = []
107
+
108
+ if (run_dir / "hparams.yaml").exists():
109
+ hps.append(cls.from_yaml(run_dir / "hparams.yaml"))
110
+
111
+ if yaml is not None:
112
+ hps.append(cls.from_yaml(yaml))
113
+
114
+ if len(hps) == 0:
115
+ hps.append(cls())
116
+
117
+ for hp in hps[1:]:
118
+ if hp != hps[0]:
119
+ errors = {}
120
+ for k, v in asdict(hp).items():
121
+ if getattr(hps[0], k) != v:
122
+ errors[k] = f"{getattr(hps[0], k)} != {v}"
123
+ raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
124
+
125
+ return hps[0]
126
+
127
+ def print(self):
128
+ _rich_print_dict(asdict(self), title="HParams")
resemble-enhance/resemble_enhance/inference.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import time
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch.nn.utils.parametrize import remove_parametrizations
7
+ from torchaudio.functional import resample
8
+ from torchaudio.transforms import MelSpectrogram
9
+ from tqdm import trange
10
+
11
+ from .hparams import HParams
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ @torch.inference_mode()
17
+ def inference_chunk(model, dwav, sr, device, npad=441):
18
+ assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
19
+ del sr
20
+
21
+ length = dwav.shape[-1]
22
+ abs_max = dwav.abs().max().clamp(min=1e-7)
23
+
24
+ assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
25
+ dwav = dwav.to(device)
26
+ dwav = dwav / abs_max # Normalize
27
+ dwav = F.pad(dwav, (0, npad))
28
+ hwav = model(dwav[None])[0].cpu() # (T,)
29
+ hwav = hwav[:length] # Trim padding
30
+ hwav = hwav * abs_max # Unnormalize
31
+
32
+ return hwav
33
+
34
+
35
+ def compute_corr(x, y):
36
+ return torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(y).conj()).abs()
37
+
38
+
39
+ def compute_offset(chunk1, chunk2, sr=44100):
40
+ """
41
+ Args:
42
+ chunk1: (T,)
43
+ chunk2: (T,)
44
+ Returns:
45
+ offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset)
46
+ """
47
+ hop_length = sr // 200 # 5 ms resolution
48
+ win_length = hop_length * 4
49
+ n_fft = 2 ** (win_length - 1).bit_length()
50
+
51
+ mel_fn = MelSpectrogram(
52
+ sample_rate=sr,
53
+ n_fft=n_fft,
54
+ win_length=win_length,
55
+ hop_length=hop_length,
56
+ n_mels=80,
57
+ f_min=0.0,
58
+ f_max=sr // 2,
59
+ )
60
+
61
+ spec1 = mel_fn(chunk1).log1p()
62
+ spec2 = mel_fn(chunk2).log1p()
63
+
64
+ corr = compute_corr(spec1, spec2) # (F, T)
65
+ corr = corr.mean(dim=0) # (T,)
66
+
67
+ argmax = corr.argmax().item()
68
+
69
+ if argmax > len(corr) // 2:
70
+ argmax -= len(corr)
71
+
72
+ offset = -argmax * hop_length
73
+
74
+ return offset
75
+
76
+
77
+ def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
78
+ signal_length = (len(chunks) - 1) * hop_length + chunk_length
79
+ overlap_length = chunk_length - hop_length
80
+ signal = torch.zeros(signal_length, device=chunks[0].device)
81
+
82
+ fadein = torch.linspace(0, 1, overlap_length, device=chunks[0].device)
83
+ fadein = torch.cat([fadein, torch.ones(hop_length, device=chunks[0].device)])
84
+ fadeout = torch.linspace(1, 0, overlap_length, device=chunks[0].device)
85
+ fadeout = torch.cat([torch.ones(hop_length, device=chunks[0].device), fadeout])
86
+
87
+ for i, chunk in enumerate(chunks):
88
+ start = i * hop_length
89
+ end = start + chunk_length
90
+
91
+ if len(chunk) < chunk_length:
92
+ chunk = F.pad(chunk, (0, chunk_length - len(chunk)))
93
+
94
+ if i > 0:
95
+ pre_region = chunks[i - 1][-overlap_length:]
96
+ cur_region = chunk[:overlap_length]
97
+ offset = compute_offset(pre_region, cur_region, sr=sr)
98
+ start -= offset
99
+ end -= offset
100
+
101
+ if i == 0:
102
+ chunk = chunk * fadeout
103
+ elif i == len(chunks) - 1:
104
+ chunk = chunk * fadein
105
+ else:
106
+ chunk = chunk * fadein * fadeout
107
+
108
+ signal[start:end] += chunk[: len(signal[start:end])]
109
+
110
+ signal = signal[:length]
111
+
112
+ return signal
113
+
114
+
115
+ def remove_weight_norm_recursively(module):
116
+ for _, module in module.named_modules():
117
+ try:
118
+ remove_parametrizations(module, "weight")
119
+ except Exception:
120
+ pass
121
+
122
+
123
+ def inference(model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0):
124
+ remove_weight_norm_recursively(model)
125
+
126
+ hp: HParams = model.hp
127
+
128
+ dwav = resample(
129
+ dwav,
130
+ orig_freq=sr,
131
+ new_freq=hp.wav_rate,
132
+ lowpass_filter_width=64,
133
+ rolloff=0.9475937167399596,
134
+ resampling_method="sinc_interp_kaiser",
135
+ beta=14.769656459379492,
136
+ )
137
+
138
+ del sr # Everything is in hp.wav_rate now
139
+
140
+ sr = hp.wav_rate
141
+
142
+ if torch.cuda.is_available():
143
+ torch.cuda.synchronize()
144
+
145
+ start_time = time.perf_counter()
146
+
147
+ chunk_length = int(sr * chunk_seconds)
148
+ overlap_length = int(sr * overlap_seconds)
149
+ hop_length = chunk_length - overlap_length
150
+
151
+ chunks = []
152
+ for start in trange(0, dwav.shape[-1], hop_length):
153
+ chunks.append(inference_chunk(model, dwav[start : start + chunk_length], sr, device))
154
+
155
+ hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])
156
+
157
+ if torch.cuda.is_available():
158
+ torch.cuda.synchronize()
159
+
160
+ elapsed_time = time.perf_counter() - start_time
161
+ logger.info(f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz")
162
+
163
+ return hwav, sr
resemble-enhance/resemble_enhance/melspec.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
5
+
6
+ from .hparams import HParams
7
+
8
+
9
+ class MelSpectrogram(nn.Module):
10
+ def __init__(self, hp: HParams):
11
+ """
12
+ Torch implementation of Resemble's mel extraction.
13
+ Note that the values are NOT identical to librosa's implementation
14
+ due to floating point precisions.
15
+ """
16
+ super().__init__()
17
+ self.hp = hp
18
+ self.melspec = TorchMelSpectrogram(
19
+ hp.wav_rate,
20
+ n_fft=hp.n_fft,
21
+ win_length=hp.win_size,
22
+ hop_length=hp.hop_size,
23
+ f_min=0,
24
+ f_max=hp.wav_rate // 2,
25
+ n_mels=hp.num_mels,
26
+ power=1,
27
+ normalized=False,
28
+ # NOTE: Folowing librosa's default.
29
+ pad_mode="constant",
30
+ norm="slaney",
31
+ mel_scale="slaney",
32
+ )
33
+ self.register_buffer("stft_magnitude_min", torch.FloatTensor([hp.stft_magnitude_min]))
34
+ self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
35
+ self.preemphasis = hp.preemphasis
36
+ self.hop_size = hp.hop_size
37
+
38
+ def forward(self, wav, pad=True):
39
+ """
40
+ Args:
41
+ wav: [B, T]
42
+ """
43
+ device = wav.device
44
+ if wav.is_mps:
45
+ wav = wav.cpu()
46
+ self.to(wav.device)
47
+ if self.preemphasis > 0:
48
+ wav = torch.nn.functional.pad(wav, [1, 0], value=0)
49
+ wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
50
+ mel = self.melspec(wav)
51
+ mel = self._amp_to_db(mel)
52
+ mel_normed = self._normalize(mel)
53
+ assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size # Sanity check
54
+ mel_normed = mel_normed.to(device)
55
+ return mel_normed # (M, T)
56
+
57
+ def _normalize(self, s, headroom_db=15):
58
+ return (s - self.min_level_db) / (-self.min_level_db + headroom_db)
59
+
60
+ def _amp_to_db(self, x):
61
+ return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20
resemble-enhance/resemble_enhance/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .distributed import global_leader_only
2
+ from .engine import Engine, gather_attribute
3
+ from .logging import setup_logging
4
+ from .train_loop import TrainLoop, is_global_leader
5
+ from .utils import save_mels, tree_map
resemble-enhance/resemble_enhance/utils/control.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import selectors
3
+ import sys
4
+ from functools import cache
5
+
6
+ from .distributed import global_leader_only
7
+
8
+ _logger = logging.getLogger(__name__)
9
+
10
+
11
+ @cache
12
+ def _get_stdin_selector():
13
+ selector = selectors.DefaultSelector()
14
+ selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
15
+ return selector
16
+
17
+
18
+ @global_leader_only(boardcast_return=True)
19
+ def non_blocking_input():
20
+ s = ""
21
+ selector = _get_stdin_selector()
22
+ events = selector.select(timeout=0)
23
+ for key, _ in events:
24
+ s: str = key.fileobj.readline().strip()
25
+ _logger.info(f'Get stdin "{s}".')
26
+ return s
resemble-enhance/resemble_enhance/utils/distributed.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import socket
3
+ from functools import cache, partial, wraps
4
+ from typing import Callable
5
+
6
+ import deepspeed
7
+ import torch
8
+ from deepspeed.accelerator import get_accelerator
9
+ from torch.distributed import broadcast_object_list
10
+
11
+
12
+ def get_free_port():
13
+ sock = socket.socket()
14
+ sock.bind(("", 0))
15
+ return sock.getsockname()[1]
16
+
17
+
18
+ @cache
19
+ def fix_unset_envs():
20
+ envs = dict(RANK="0", WORLD_SIZE="1", MASTER_ADDR="localhost", MASTER_PORT=str(get_free_port()), LOCAL_RANK="0")
21
+
22
+ for key in envs:
23
+ value = os.getenv(key)
24
+ if value is not None:
25
+ return
26
+
27
+ for key, value in envs.items():
28
+ os.environ[key] = value
29
+
30
+
31
+ @cache
32
+ def init_distributed():
33
+ fix_unset_envs()
34
+ deepspeed.init_distributed(get_accelerator().communication_backend_name())
35
+ torch.cuda.set_device(local_rank())
36
+
37
+
38
+ def local_rank():
39
+ return int(os.getenv("LOCAL_RANK", 0))
40
+
41
+
42
+ def global_rank():
43
+ return int(os.getenv("RANK", 0))
44
+
45
+
46
+ def is_local_leader():
47
+ return local_rank() == 0
48
+
49
+
50
+ def is_global_leader():
51
+ return global_rank() == 0
52
+
53
+
54
+ def leader_only(leader_only_type, fn: Callable | None = None, boardcast_return=False) -> Callable:
55
+ """
56
+ Args:
57
+ fn: The function to decorate
58
+ boardcast_return: Whether to boardcast the return value to all processes
59
+ (may cause deadlock if the function calls another decorated function)
60
+ """
61
+
62
+ def wrapper(fn):
63
+ if hasattr(fn, "__leader_only_type__"):
64
+ raise RuntimeError(f"Function {fn.__name__} has already been decorated with {fn.__leader_only_type__}")
65
+
66
+ fn.__leader_only_type__ = leader_only_type
67
+
68
+ if leader_only_type == "local":
69
+ guard_fn = is_local_leader
70
+ elif leader_only_type == "global":
71
+ guard_fn = is_global_leader
72
+ else:
73
+ raise ValueError(f"Unknown leader_only_type: {leader_only_type}")
74
+
75
+ @wraps(fn)
76
+ def wrapped(*args, **kwargs):
77
+ if boardcast_return:
78
+ init_distributed()
79
+ obj_list = [None]
80
+ if guard_fn():
81
+ ret = fn(*args, **kwargs)
82
+ obj_list[0] = ret
83
+ if boardcast_return:
84
+ broadcast_object_list(obj_list, src=0)
85
+ return obj_list[0]
86
+
87
+ return wrapped
88
+
89
+ if fn is None:
90
+ return wrapper
91
+
92
+ return wrapper(fn)
93
+
94
+
95
+ local_leader_only = partial(leader_only, "local")
96
+ global_leader_only = partial(leader_only, "global")
resemble-enhance/resemble_enhance/utils/engine.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+ from functools import cache, partial
4
+ from typing import Callable, TypeVar
5
+
6
+ import deepspeed
7
+ import pandas as pd
8
+ from deepspeed.accelerator import get_accelerator
9
+ from deepspeed.runtime.engine import DeepSpeedEngine
10
+ from deepspeed.runtime.utils import clip_grad_norm_
11
+ from torch import nn
12
+
13
+ from .distributed import fix_unset_envs
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ T = TypeVar("T")
18
+
19
+
20
+ def flatten_dict(d):
21
+ records = pd.json_normalize(d, sep="/").to_dict(orient="records")
22
+ return records[0] if records else {}
23
+
24
+
25
+ def _get_named_modules(module, attrname, sep="/"):
26
+ for name, module in module.named_modules():
27
+ name = name.replace(".", sep)
28
+ if hasattr(module, attrname):
29
+ yield name, module
30
+
31
+
32
+ def gather_attribute(module, attrname, delete=True, prefix=None):
33
+ ret = {}
34
+ for name, module in _get_named_modules(module, attrname):
35
+ ret[name] = getattr(module, attrname)
36
+ if delete:
37
+ try:
38
+ delattr(module, attrname)
39
+ except Exception as e:
40
+ raise RuntimeError(f"{name} {module} {attrname}") from e
41
+ if prefix:
42
+ ret = {prefix: ret}
43
+ ret = flatten_dict(ret)
44
+ # remove consecutive /
45
+ ret = {re.sub(r"\/+", "/", k): v for k, v in ret.items()}
46
+ return ret
47
+
48
+
49
+ def dispatch_attribute(module, attrname, value, filter_fn: Callable[[nn.Module], bool] | None = None):
50
+ for _, module in _get_named_modules(module, attrname):
51
+ if filter_fn is None or filter_fn(module):
52
+ setattr(module, attrname, value)
53
+
54
+
55
+ @cache
56
+ def update_deepspeed_logger():
57
+ logger = logging.getLogger("DeepSpeed")
58
+ logger.setLevel(logging.WARNING)
59
+
60
+
61
+ @cache
62
+ def init_distributed():
63
+ update_deepspeed_logger()
64
+ fix_unset_envs()
65
+ deepspeed.init_distributed(get_accelerator().communication_backend_name())
66
+
67
+
68
+ def _try_each(*fns, e=None):
69
+ if len(fns) == 0:
70
+ raise RuntimeError("All functions failed")
71
+
72
+ head, *tails = fns
73
+
74
+ try:
75
+ return head()
76
+ except Exception as e:
77
+ logger.warning(f"Tried {head} but failed: {e}, trying next")
78
+ return _try_each(*tails)
79
+
80
+
81
+ class Engine(DeepSpeedEngine):
82
+ def __init__(self, *args, ckpt_dir, **kwargs):
83
+ init_distributed()
84
+ super().__init__(args=None, *args, **kwargs)
85
+ self._ckpt_dir = ckpt_dir
86
+ self._frozen_params = set()
87
+ self._fp32_grad_norm = None
88
+
89
+ @property
90
+ def path(self):
91
+ return self._ckpt_dir
92
+
93
+ def freeze_(self):
94
+ for p in self.module.parameters():
95
+ if p.requires_grad:
96
+ p.requires_grad_(False)
97
+ self._frozen_params.add(p)
98
+
99
+ def unfreeze_(self):
100
+ for p in self._frozen_params:
101
+ p.requires_grad_(True)
102
+ self._frozen_params.clear()
103
+
104
+ @property
105
+ def global_step(self):
106
+ return self.global_steps
107
+
108
+ def gather_attribute(self, *args, **kwargs):
109
+ return gather_attribute(self.module, *args, **kwargs)
110
+
111
+ def dispatch_attribute(self, *args, **kwargs):
112
+ return dispatch_attribute(self.module, *args, **kwargs)
113
+
114
+ def clip_fp32_gradients(self):
115
+ self._fp32_grad_norm = clip_grad_norm_(
116
+ parameters=self.module.parameters(),
117
+ max_norm=self.gradient_clipping(),
118
+ mpu=self.mpu,
119
+ )
120
+
121
+ def get_grad_norm(self):
122
+ grad_norm = self.get_global_grad_norm()
123
+ if grad_norm is None:
124
+ grad_norm = self._fp32_grad_norm
125
+ return grad_norm
126
+
127
+ def save_checkpoint(self, *args, **kwargs):
128
+ if not self._ckpt_dir.exists():
129
+ self._ckpt_dir.mkdir(parents=True, exist_ok=True)
130
+ super().save_checkpoint(save_dir=self._ckpt_dir, *args, **kwargs)
131
+ logger.info(f"Saved checkpoint to {self._ckpt_dir}")
132
+
133
+ def load_checkpoint(self, *args, **kwargs):
134
+ fn = partial(super().load_checkpoint, *args, load_dir=self._ckpt_dir, **kwargs)
135
+ return _try_each(
136
+ lambda: fn(),
137
+ lambda: fn(load_optimizer_states=False),
138
+ lambda: fn(load_lr_scheduler_states=False),
139
+ lambda: fn(load_optimizer_states=False, load_lr_scheduler_states=False),
140
+ lambda: fn(
141
+ load_optimizer_states=False,
142
+ load_lr_scheduler_states=False,
143
+ load_module_strict=False,
144
+ ),
145
+ )
resemble-enhance/resemble_enhance/utils/logging.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ from rich.logging import RichHandler
5
+
6
+ from .distributed import global_leader_only
7
+
8
+
9
+ @global_leader_only
10
+ def setup_logging(run_dir):
11
+ handlers = []
12
+ stdout_handler = RichHandler()
13
+ stdout_handler.setLevel(logging.INFO)
14
+ handlers.append(stdout_handler)
15
+
16
+ if run_dir is not None:
17
+ filename = Path(run_dir) / f"log.txt"
18
+ filename.parent.mkdir(parents=True, exist_ok=True)
19
+ file_handler = logging.FileHandler(filename, mode="a")
20
+ file_handler.setLevel(logging.DEBUG)
21
+ handlers.append(file_handler)
22
+
23
+ # Update all existing loggers
24
+ for name in ["DeepSpeed"]:
25
+ logger = logging.getLogger(name)
26
+ if isinstance(logger, logging.Logger):
27
+ for handler in list(logger.handlers):
28
+ logger.removeHandler(handler)
29
+ for handler in handlers:
30
+ logger.addHandler(handler)
31
+
32
+ # Set the default logger
33
+ logging.basicConfig(
34
+ level=logging.getLevelName("INFO"),
35
+ format="%(message)s",
36
+ datefmt="[%X]",
37
+ handlers=handlers,
38
+ )