Wendyellé Abubakrh Alban NYANTUDRE
commited on
Commit
•
88b5dc0
1
Parent(s):
c49c7f5
finally deleted .git file from resemble-enhance
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- resemble-enhance/README.md +95 -0
- resemble-enhance/config/denoiser.yaml +2 -0
- resemble-enhance/config/enhancer_stage1.yaml +4 -0
- resemble-enhance/config/enhancer_stage2.yaml +8 -0
- resemble-enhance/resemble_enhance/__init__.py +0 -0
- resemble-enhance/resemble_enhance/common.py +55 -0
- resemble-enhance/resemble_enhance/data/__init__.py +48 -0
- resemble-enhance/resemble_enhance/data/dataset.py +171 -0
- resemble-enhance/resemble_enhance/data/distorter/__init__.py +1 -0
- resemble-enhance/resemble_enhance/data/distorter/base.py +104 -0
- resemble-enhance/resemble_enhance/data/distorter/custom.py +85 -0
- resemble-enhance/resemble_enhance/data/distorter/distorter.py +32 -0
- resemble-enhance/resemble_enhance/data/distorter/sox.py +176 -0
- resemble-enhance/resemble_enhance/data/utils.py +43 -0
- resemble-enhance/resemble_enhance/denoiser/__init__.py +0 -0
- resemble-enhance/resemble_enhance/denoiser/__main__.py +30 -0
- resemble-enhance/resemble_enhance/denoiser/denoiser.py +181 -0
- resemble-enhance/resemble_enhance/denoiser/hparams.py +9 -0
- resemble-enhance/resemble_enhance/denoiser/inference.py +29 -0
- resemble-enhance/resemble_enhance/denoiser/train.py +112 -0
- resemble-enhance/resemble_enhance/denoiser/unet.py +144 -0
- resemble-enhance/resemble_enhance/enhancer/__init__.py +0 -0
- resemble-enhance/resemble_enhance/enhancer/__main__.py +129 -0
- resemble-enhance/resemble_enhance/enhancer/download.py +30 -0
- resemble-enhance/resemble_enhance/enhancer/enhancer.py +195 -0
- resemble-enhance/resemble_enhance/enhancer/hparams.py +23 -0
- resemble-enhance/resemble_enhance/enhancer/inference.py +41 -0
- resemble-enhance/resemble_enhance/enhancer/lcfm/__init__.py +2 -0
- resemble-enhance/resemble_enhance/enhancer/lcfm/cfm.py +372 -0
- resemble-enhance/resemble_enhance/enhancer/lcfm/irmae.py +123 -0
- resemble-enhance/resemble_enhance/enhancer/lcfm/lcfm.py +152 -0
- resemble-enhance/resemble_enhance/enhancer/lcfm/wn.py +147 -0
- resemble-enhance/resemble_enhance/enhancer/train.py +143 -0
- resemble-enhance/resemble_enhance/enhancer/univnet/__init__.py +1 -0
- resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py +5 -0
- resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py +95 -0
- resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py +49 -0
- resemble-enhance/resemble_enhance/enhancer/univnet/amp.py +101 -0
- resemble-enhance/resemble_enhance/enhancer/univnet/discriminator.py +210 -0
- resemble-enhance/resemble_enhance/enhancer/univnet/lvcnet.py +281 -0
- resemble-enhance/resemble_enhance/enhancer/univnet/mrstft.py +128 -0
- resemble-enhance/resemble_enhance/enhancer/univnet/univnet.py +94 -0
- resemble-enhance/resemble_enhance/hparams.py +128 -0
- resemble-enhance/resemble_enhance/inference.py +163 -0
- resemble-enhance/resemble_enhance/melspec.py +61 -0
- resemble-enhance/resemble_enhance/utils/__init__.py +5 -0
- resemble-enhance/resemble_enhance/utils/control.py +26 -0
- resemble-enhance/resemble_enhance/utils/distributed.py +96 -0
- resemble-enhance/resemble_enhance/utils/engine.py +145 -0
- resemble-enhance/resemble_enhance/utils/logging.py +38 -0
resemble-enhance/README.md
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Resemble Enhance
|
2 |
+
|
3 |
+
[![PyPI](https://img.shields.io/pypi/v/resemble-enhance.svg)](https://pypi.org/project/resemble-enhance/)
|
4 |
+
[![Hugging Face Space](https://img.shields.io/badge/Hugging%20Face%20%F0%9F%A4%97-Space-yellow)](https://huggingface.co/spaces/ResembleAI/resemble-enhance)
|
5 |
+
[![License](https://img.shields.io/github/license/resemble-ai/Resemble-Enhance.svg)](https://github.com/resemble-ai/resemble-enhance/blob/main/LICENSE)
|
6 |
+
[![Webpage](https://img.shields.io/badge/Webpage-Online-brightgreen)](https://www.resemble.ai/enhance/)
|
7 |
+
|
8 |
+
https://github.com/resemble-ai/resemble-enhance/assets/660224/bc3ec943-e795-4646-b119-cce327c810f1
|
9 |
+
|
10 |
+
Resemble Enhance is an AI-powered tool that aims to improve the overall quality of speech by performing denoising and enhancement. It consists of two modules: a denoiser, which separates speech from noisy audio, and an enhancer, which further boosts the perceptual audio quality by restoring audio distortions and extending the audio bandwidth. Both models are trained on high-quality 44.1kHz speech data, so that your speech is enhanced with high quality.
|
11 |
+
|
12 |
+
## Usage
|
13 |
+
|
14 |
+
### Installation
|
15 |
+
|
16 |
+
Install the stable version:
|
17 |
+
|
18 |
+
```bash
|
19 |
+
pip install resemble-enhance --upgrade
|
20 |
+
```
|
21 |
+
|
22 |
+
Or try the latest pre-release version:
|
23 |
+
|
24 |
+
```bash
|
25 |
+
pip install resemble-enhance --upgrade --pre
|
26 |
+
```
|
27 |
+
|
28 |
+
### Enhance
|
29 |
+
|
30 |
+
```
|
31 |
+
resemble_enhance in_dir out_dir
|
32 |
+
```
|
33 |
+
|
34 |
+
### Denoise only
|
35 |
+
|
36 |
+
```
|
37 |
+
resemble_enhance in_dir out_dir --denoise_only
|
38 |
+
```
|
39 |
+
|
40 |
+
### Web Demo
|
41 |
+
|
42 |
+
We provide a web demo built with Gradio. You can try it out [here](https://huggingface.co/spaces/ResembleAI/resemble-enhance), or run it locally:
|
43 |
+
|
44 |
+
```
|
45 |
+
python app.py
|
46 |
+
```
|
47 |
+
|
48 |
+
## Train your own model
|
49 |
+
|
50 |
+
### Data Preparation
|
51 |
+
|
52 |
+
You need to prepare a foreground speech dataset and a background non-speech dataset. In addition, you need to prepare a RIR dataset ([examples](https://github.com/RoyJames/room-impulse-responses)).
|
53 |
+
|
54 |
+
```bash
|
55 |
+
data
|
56 |
+
├── fg
|
57 |
+
│ ├── 00001.wav
|
58 |
+
│ └── ...
|
59 |
+
├── bg
|
60 |
+
│ ├── 00001.wav
|
61 |
+
│ └── ...
|
62 |
+
└── rir
|
63 |
+
├── 00001.npy
|
64 |
+
└── ...
|
65 |
+
```
|
66 |
+
|
67 |
+
### Training
|
68 |
+
|
69 |
+
#### Denoiser Warmup
|
70 |
+
|
71 |
+
Although the denoiser is trained jointly with the enhancer, it is recommended to run a warmup training of the denoiser first.
|
72 |
+
|
73 |
+
```bash
|
74 |
+
python -m resemble_enhance.denoiser.train --yaml config/denoiser.yaml runs/denoiser
|
75 |
+
```
|
76 |
+
|
77 |
+
#### Enhancer
|
78 |
+
|
79 |
+
Then, you can train the enhancer in two stages. The first stage trains the autoencoder and vocoder, and the second stage trains the latent conditional flow matching (CFM) model.
|
80 |
+
|
81 |
+
##### Stage 1
|
82 |
+
|
83 |
+
```bash
|
84 |
+
python -m resemble_enhance.enhancer.train --yaml config/enhancer_stage1.yaml runs/enhancer_stage1
|
85 |
+
```
|
86 |
+
|
87 |
+
##### Stage 2
|
88 |
+
|
89 |
+
```bash
|
90 |
+
python -m resemble_enhance.enhancer.train --yaml config/enhancer_stage2.yaml runs/enhancer_stage2
|
91 |
+
```
|
92 |
+
|
93 |
+
## Blog
|
94 |
+
|
95 |
+
Learn more on our [website](https://www.resemble.ai/enhance/)!
|
resemble-enhance/config/denoiser.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
batch_size_per_gpu: 32
|
2 |
+
training_seconds: 3.0
|
resemble-enhance/config/enhancer_stage1.yaml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
lcfm_training_mode: ae
|
2 |
+
load_fg_only: true
|
3 |
+
batch_size_per_gpu: 16
|
4 |
+
denoiser_run_dir: runs/denoiser
|
resemble-enhance/config/enhancer_stage2.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
lcfm_training_mode: cfm
|
2 |
+
batch_size_per_gpu: 32
|
3 |
+
training_seconds: 3.0
|
4 |
+
gan_training_start_step: null
|
5 |
+
lcfm_z_scale: 6
|
6 |
+
praat_augment_prob: 0.2
|
7 |
+
denoiser_run_dir: runs/denoiser
|
8 |
+
enhancer_stage1_run_dir: runs/enhancer_stage1
|
resemble-enhance/resemble_enhance/__init__.py
ADDED
File without changes
|
resemble-enhance/resemble_enhance/common.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor, nn
|
5 |
+
|
6 |
+
logger = logging.getLogger(__name__)
|
7 |
+
|
8 |
+
|
9 |
+
class Normalizer(nn.Module):
    """Normalize tensors with scalar running statistics.

    A single running mean and a single running variance (both 0-dim
    buffers) are tracked with an exponential moving average. Until the
    first update the buffers hold NaN, and the public ``running_mean`` /
    ``running_std`` properties fall back to 0 and 1 respectively, which
    makes ``forward`` and ``inverse`` the identity.

    Args:
        momentum: Weight given to the newest batch statistic in the EMA.
        eps: Constant added to the variance before taking the square root.
    """

    def __init__(self, momentum=0.01, eps=1e-9):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        self.running_mean_unsafe: Tensor
        self.running_var_unsafe: Tensor
        # NaN is the "no statistics collected yet" marker checked by `started`.
        self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
        self.register_buffer("running_var_unsafe", torch.full([], torch.nan))

    @property
    def started(self):
        """Whether at least one update has replaced the NaN placeholder mean."""
        return not torch.isnan(self.running_mean_unsafe)

    @property
    def running_mean(self):
        """Running mean, or zeros when no update has happened yet."""
        if not self.started:
            return torch.zeros_like(self.running_mean_unsafe)
        return self.running_mean_unsafe

    @property
    def running_std(self):
        """Running standard deviation, or ones when no update has happened yet."""
        if not self.started:
            return torch.ones_like(self.running_var_unsafe)
        return (self.running_var_unsafe + self.eps).sqrt()

    @torch.no_grad()
    def _ema(self, a: Tensor, x: Tensor):
        """Blend the old value ``a`` with the new observation ``x``."""
        return (1 - self.momentum) * a + self.momentum * x

    @torch.no_grad()
    def update_(self, x):
        """Update the running statistics from ``x``.

        Runs entirely under ``torch.no_grad()``. Without it, the very first
        update would store ``x.mean()`` / ``x.var()`` together with their
        autograd graph in the buffers, so gradients would flow through the
        statistics on the first step only.
        """
        if not self.started:
            # First batch: initialise directly from the batch statistics.
            self.running_mean_unsafe = x.mean()
            self.running_var_unsafe = x.var()
        else:
            self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
            # The variance is measured around the already-updated running mean.
            self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())

    def forward(self, x: Tensor, update=True):
        """Normalize ``x``; in training mode (and if ``update``) refresh the stats first.

        When an update happens, ``self.stats`` is set to a dict with the
        current ``mean`` and ``std`` as Python floats.
        """
        if self.training and update:
            self.update_(x)
            self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
        x = (x - self.running_mean) / self.running_std
        return x

    def inverse(self, x: Tensor):
        """Undo ``forward`` using the current running statistics."""
        return x * self.running_std + self.running_mean
|
resemble-enhance/resemble_enhance/data/__init__.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import random
|
3 |
+
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
|
6 |
+
from ..hparams import HParams
|
7 |
+
from .dataset import Dataset
|
8 |
+
from .utils import mix_fg_bg, rglob_audio_files
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
    """Build the training and validation datasets from ``hp.fg_dir``.

    The discovered audio paths are shuffled with a dedicated, seeded RNG;
    the last ``val_size`` paths form the validation set and the rest the
    training set.

    Returns:
        A ``(train_ds, val_ds)`` tuple.
    """
    audio_paths = rglob_audio_files(hp.fg_dir)
    logger.info(f"Found {len(audio_paths)} audio files in {hp.fg_dir}")

    # Seeded private RNG keeps the split reproducible across runs.
    rng = random.Random(seed)
    rng.shuffle(audio_paths)

    held_in = audio_paths[:-val_size]
    held_out = audio_paths[-val_size:]

    train_ds = Dataset(held_in, hp, training=True, mode=mode)
    val_ds = Dataset(held_out, hp, training=False, mode=mode)

    logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")

    return train_ds, val_ds
|
27 |
+
|
28 |
+
|
29 |
+
def create_dataloaders(hp: HParams, mode):
    """Create the training and validation ``DataLoader`` objects.

    The training loader shuffles, uses ``hp.batch_size_per_gpu`` and drops
    the last incomplete batch; the validation loader yields single,
    unshuffled samples. Both use ``hp.nj`` workers.

    Returns:
        A ``(train_dl, val_dl)`` tuple.
    """
    train_ds, val_ds = _create_datasets(hp=hp, mode=mode)

    train_options = dict(
        batch_size=hp.batch_size_per_gpu,
        shuffle=True,
        num_workers=hp.nj,
        drop_last=True,
        collate_fn=train_ds.collate_fn,
    )
    val_options = dict(
        batch_size=1,
        shuffle=False,
        num_workers=hp.nj,
        drop_last=False,
        collate_fn=val_ds.collate_fn,
    )

    train_dl = DataLoader(train_ds, **train_options)
    val_dl = DataLoader(val_ds, **val_options)
    return train_dl, val_dl
|
resemble-enhance/resemble_enhance/data/dataset.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import random
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torchaudio
|
8 |
+
import torchaudio.functional as AF
|
9 |
+
from torch.nn.utils.rnn import pad_sequence
|
10 |
+
from torch.utils.data import Dataset as DatasetBase
|
11 |
+
|
12 |
+
from ..hparams import HParams
|
13 |
+
from .distorter import Distorter
|
14 |
+
from .utils import rglob_audio_files
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
def _normalize(x):
|
20 |
+
return x / (np.abs(x).max() + 1e-7)
|
21 |
+
|
22 |
+
|
23 |
+
def _collate(batch, key, tensor=True, pad=True):
|
24 |
+
l = [d[key] for d in batch]
|
25 |
+
if l[0] is None:
|
26 |
+
return None
|
27 |
+
if tensor:
|
28 |
+
l = [torch.from_numpy(x) for x in l]
|
29 |
+
if pad:
|
30 |
+
assert tensor, "Can't pad non-tensor"
|
31 |
+
l = pad_sequence(l, batch_first=True)
|
32 |
+
return l
|
33 |
+
|
34 |
+
|
35 |
+
def praat_augment(wav, sr):
    """Apply Praat's "Change gender" command with random formant/pitch-range settings.

    Args:
        wav: 1-D waveform.
        sr: Sample rate passed to ``parselmouth.Sound``.

    Returns:
        The first channel of the processed sound as a float32 array.

    Raises:
        ImportError: If ``parselmouth`` cannot be imported.
    """
    try:
        import parselmouth
    except ImportError:
        raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")

    # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
    # https://github.com/YannickJadoul/Parselmouth/issues/68
    # note that this function may hang if the praat version is 0.4.3
    assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"

    source = parselmouth.Sound(wav, sr)

    # Draw order matters for reproducibility: formant shift first, then pitch range.
    formant_shift_ratio = random.uniform(1.1, 1.5)
    pitch_range_factor = random.uniform(0.5, 2.0)

    changed = parselmouth.praat.call(
        source,
        "Change gender",
        75,
        600,
        formant_shift_ratio,
        0,
        pitch_range_factor,
        1.0,
    )

    channels = np.array(changed.values)
    return channels[0].astype(np.float32)
|
50 |
+
|
51 |
+
|
52 |
+
class Dataset(DatasetBase):
    """Dataset yielding foreground/background waveforms and their distorted versions.

    Each item is a dict with the keys ``fg_wav``, ``bg_wav``, ``fg_dwav`` and
    ``bg_dwav``; the last three are ``None`` when ``hp.load_fg_only`` is set.
    """

    def __init__(
        self,
        fg_paths: list[Path],
        hp: HParams,
        training=True,
        max_retries=100,
        silent_fg_prob=0.01,
        mode=False,
    ):
        """
        Args:
            fg_paths: Foreground audio files.
            hp: Hyper-parameters (reads ``bg_dir``, ``fg_dir``, ``wav_rate``, ...).
            training: Enables random cropping, augmentation and random backgrounds.
            max_retries: Attempts made by ``__getitem__`` before giving up.
            silent_fg_prob: Probability of an all-zero foreground during training.
            mode: Must be ``"enhancer"`` or ``"denoiser"``.
        """
        super().__init__()

        assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"

        self.hp = hp
        self.fg_paths = fg_paths
        self.bg_paths = rglob_audio_files(hp.bg_dir)

        if len(self.fg_paths) == 0:
            raise ValueError(f"No foreground audio files found in {hp.fg_dir}")

        if len(self.bg_paths) == 0:
            raise ValueError(f"No background audio files found in {hp.bg_dir}")

        logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")

        self.training = training
        self.max_retries = max_retries
        self.silent_fg_prob = silent_fg_prob

        self.mode = mode
        self.distorter = Distorter(hp, training=training, mode=mode)

    def _load_wav(self, path, length=None, random_crop=True):
        """Load, resample to ``hp.wav_rate``, downmix, crop/pad and peak-normalize a file.

        When ``length`` is ``None`` in training mode it defaults to
        ``hp.training_seconds`` worth of samples; in evaluation mode the
        full file is kept.
        """
        loaded, native_rate = torchaudio.load(path)

        resampled = AF.resample(
            waveform=loaded,
            orig_freq=native_rate,
            new_freq=self.hp.wav_rate,
            lowpass_filter_width=64,
            rolloff=0.9475937167399596,
            resampling_method="sinc_interp_kaiser",
            beta=14.769656459379492,
        )

        wav = resampled.float().numpy()

        # Downmix (channels, T) to mono.
        if wav.ndim == 2:
            wav = wav.mean(axis=0)

        if length is None and self.training:
            length = int(self.hp.training_seconds * self.hp.wav_rate)

        if length is not None:
            offset = random.randint(0, max(0, len(wav) - length)) if random_crop else 0
            wav = wav[offset : offset + length]

            # Zero-pad clips that are shorter than the requested length.
            shortfall = length - len(wav)
            if shortfall > 0:
                wav = np.pad(wav, (0, shortfall))

        return _normalize(wav)

    def _getitem_unsafe(self, index: int):
        """Build one sample; may raise on unreadable files (see ``__getitem__``)."""
        hp = self.hp
        fg_path = self.fg_paths[index]

        if self.training and random.random() < self.silent_fg_prob:
            # Occasionally train on pure silence as the foreground.
            fg_wav = np.zeros(int(hp.training_seconds * hp.wav_rate), dtype=np.float32)
        else:
            fg_wav = self._load_wav(fg_path)
            if random.random() < hp.praat_augment_prob and self.training:
                fg_wav = praat_augment(fg_wav, hp.wav_rate)

        bg_wav = fg_dwav = bg_dwav = None

        if not hp.load_fg_only:
            fg_dwav = _normalize(self.distorter(fg_wav, hp.wav_rate)).astype(np.float32)
            if self.training:
                bg_path = random.choice(self.bg_paths)
            else:
                # Deterministic for validation
                bg_path = self.bg_paths[index % len(self.bg_paths)]
            bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
            bg_dwav = _normalize(self.distorter(bg_wav, hp.wav_rate)).astype(np.float32)

        return dict(
            fg_wav=fg_wav,
            bg_wav=bg_wav,
            fg_dwav=fg_dwav,
            bg_dwav=bg_dwav,
        )

    def __getitem__(self, index: int):
        """Return a sample, retrying with a random index on failure.

        Raises:
            RuntimeError: When the final attempt (``max_retries``) also fails.
        """
        final_attempt = self.max_retries - 1
        for attempt in range(self.max_retries):
            try:
                return self._getitem_unsafe(index)
            except Exception as e:
                if attempt == final_attempt:
                    raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
                logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
                # Fall back to a different, randomly chosen sample.
                index = np.random.randint(0, len(self))

    def __len__(self):
        """Number of foreground files."""
        return len(self.fg_paths)

    @staticmethod
    def collate_fn(batch):
        """Collate samples into a dict of padded batch tensors (or ``None`` per field)."""
        fields = {
            "fg_wavs": "fg_wav",
            "bg_wavs": "bg_wav",
            "fg_dwavs": "fg_dwav",
            "bg_dwavs": "bg_dwav",
        }
        return {batched: _collate(batch, single) for batched, single in fields.items()}
|
resemble-enhance/resemble_enhance/data/distorter/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .distorter import Distorter
|
resemble-enhance/resemble_enhance/data/distorter/base.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import time
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
_DEBUG = bool(os.environ.get("DEBUG", False))
|
10 |
+
|
11 |
+
|
12 |
+
class Effect:
    """Base class for waveform effects; subclasses implement ``apply``."""

    def apply(self, wav: np.ndarray, sr: int):
        """
        Transform a waveform. Must be overridden.

        Args:
            wav: (T)
            sr: sample rate
        Returns:
            wav: (T) with the same sample rate of `sr`
        """
        raise NotImplementedError

    def __call__(self, wav: np.ndarray, sr: int):
        """
        Run ``apply`` and check that the 1-D shape is preserved.

        When the module-level ``_DEBUG`` flag is set, the elapsed time of the
        effect is printed.

        Args:
            wav: (T)
            sr: sample rate
        Returns:
            wav: (T) with the same sample rate of `sr`
        """
        assert len(wav.shape) == 1, wav.shape

        started_at = time.time() if _DEBUG else None

        expected_shape = wav.shape
        assert wav.ndim == 1, f"{self}: Expected wav.ndim == 1, got {wav.ndim}."
        wav = self.apply(wav, sr)
        assert expected_shape == wav.shape, f"{self}: {expected_shape} != {wav.shape}."

        if started_at is not None:
            finished_at = time.time()
            print(f"{self.__class__.__name__}: {finished_at - started_at:.3f} sec")

        return wav
|
48 |
+
|
49 |
+
|
50 |
+
class Chain(Effect):
    """Apply the given effects one after another, in order."""

    def __init__(self, *effects):
        super().__init__()

        self.effects = effects

    def apply(self, wav, sr):
        """Feed ``wav`` through every effect; an empty chain returns it unchanged."""
        result = wav
        for stage in self.effects:
            result = stage(result, sr)
        return result
|
60 |
+
|
61 |
+
|
62 |
+
class Maybe(Effect):
    """Apply the wrapped effect with probability ``prob``.

    With the module-level ``_DEBUG`` flag set, the probability is forced to 1.
    """

    def __init__(self, prob, effect):
        super().__init__()

        self.prob = prob
        self.effect = effect

        if _DEBUG:
            warnings.warn("DEBUG mode is on. Maybe -> Must.")
            self.prob = 1

    def apply(self, wav, sr):
        """Return ``wav`` untouched when the random draw exceeds ``prob``."""
        skip = random.random() > self.prob
        return wav if skip else self.effect(wav, sr)
|
77 |
+
|
78 |
+
|
79 |
+
class Choice(Effect):
    """Apply one effect picked with ``np.random.choice``.

    Extra keyword arguments (e.g. ``p``) are forwarded to ``np.random.choice``.
    """

    def __init__(self, *effects, **kwargs):
        super().__init__()
        self.effects = effects
        self.kwargs = kwargs

    def apply(self, wav, sr):
        """Pick an effect for this call and run it on ``wav``."""
        picked = np.random.choice(self.effects, **self.kwargs)
        return picked(wav, sr)
|
87 |
+
|
88 |
+
|
89 |
+
class Permutation(Effect):
    """Apply a random ordered subset of the given effects.

    Args:
        *effects: Candidate effects.
        n: Number of effects to apply per call. When ``None``, the count is
            drawn from ``Binomial(len(effects), 0.5)`` on every call.
    """

    def __init__(self, *effects, n: int | None = None):
        super().__init__()
        self.effects = effects
        self.n = n

    def apply(self, wav, sr):
        """Chain ``n`` distinct effects in a uniformly random order.

        Raises:
            ValueError: If ``n`` is negative or larger than the number of effects.
        """
        if self.n is None:
            n = np.random.binomial(len(self.effects), 0.5)
        else:
            n = self.n
        if n == 0:
            return wav
        # random.sample returns an ordered sample without replacement, which is
        # uniformly distributed over all n-permutations. This avoids building
        # the full list of len(effects)!/(len(effects)-n)! permutations per call.
        effects = random.sample(self.effects, int(n))
        return Chain(*effects)(wav, sr)
|
resemble-enhance/resemble_enhance/data/distorter/custom.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import random
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from functools import cached_property
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import librosa
|
8 |
+
import numpy as np
|
9 |
+
from scipy import signal
|
10 |
+
|
11 |
+
from ..utils import walk_paths
|
12 |
+
from .base import Effect
|
13 |
+
|
14 |
+
_logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
class RandomRIR(Effect):
    """Convolve the waveform with a room impulse response loaded from ``rir_dir``.

    The effect is a no-op when ``rir_dir`` is ``None`` or no RIR file is found.
    """

    rir_dir: Path | None
    rir_rate: int = 44_000
    rir_suffix: str = ".npy"
    deterministic: bool = False

    @cached_property
    def rir_paths(self):
        """RIR files under ``rir_dir`` with ``rir_suffix`` (computed once, then cached)."""
        if self.rir_dir is None:
            return []
        return list(walk_paths(self.rir_dir, self.rir_suffix))

    def _sample_rir(self):
        """Load one RIR: the first path if ``deterministic``, else a random one."""
        candidates = self.rir_paths
        if len(candidates) == 0:
            return None

        chosen = candidates[0] if self.deterministic else random.choice(candidates)

        rir = np.squeeze(np.load(chosen))
        assert isinstance(rir, np.ndarray)

        return rir

    def apply(self, wav, sr):
        """Resample to ``rir_rate``, convolve with an RIR, resample back and fix the length."""
        # ref: https://github.com/haoheliu/voicefixer_main/blob/b06e07c945ac1d309b8a57ddcd599ca376b98cd9/dataloaders/augmentation/magical_effects.py#L158

        if len(self.rir_paths) == 0:
            return wav

        original_length = len(wav)

        wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
        rir = self._sample_rir()

        wav = signal.convolve(wav, rir, mode="same")

        # Rescale only when the convolved signal gets close to full scale.
        peak = np.max(np.abs(wav))
        if peak > 0.99:
            wav = (wav / peak) * 0.98

        wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")

        if abs(original_length - len(wav)) > 10:
            _logger.warning(f"length mismatch: {original_length} vs {len(wav)}")

        # Pad or trim so that the output length matches the input length.
        missing = original_length - len(wav)
        if missing > 0:
            wav = np.pad(wav, (0, missing))
        elif missing < 0:
            wav = wav[:original_length]

        return wav
|
72 |
+
|
73 |
+
|
74 |
+
class RandomGaussianNoise(Effect):
    """Blend the waveform with Gaussian noise of equal energy.

    Args:
        alpha_range: ``(low, high)`` range from which the signal weight
            ``alpha`` is drawn uniformly; the noise gets weight ``1 - alpha``.
    """

    def __init__(self, alpha_range=(0.8, 1)):
        super().__init__()
        self.alpha_range = alpha_range

    def apply(self, wav, sr):
        """Return ``wav * alpha + noise * (1 - alpha)`` with energy-matched noise."""
        raw_noise = np.random.randn(*wav.shape)
        # Scale the noise so that its energy equals the energy of the signal.
        energy_ratio = np.sum(wav**2) / np.sum(raw_noise**2)
        noise = raw_noise * np.sqrt(energy_ratio)
        alpha = random.uniform(*self.alpha_range)
        return wav * alpha + noise * (1 - alpha)
|
resemble-enhance/resemble_enhance/data/distorter/distorter.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ...hparams import HParams
|
2 |
+
from .base import Chain, Choice, Permutation
|
3 |
+
from .custom import RandomGaussianNoise, RandomRIR
|
4 |
+
|
5 |
+
|
6 |
+
class Distorter(Chain):
    """Distortion pipeline used to degrade audio.

    In evaluation mode a fixed chain (deterministic RIR followed by
    deterministic reverb) is used. In training mode a random permutation of
    effects is used: always for ``mode == "denoiser"``, otherwise with
    probability 0.8 (the remaining 0.2 leaves the audio clean).
    """

    def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"):
        # Lazy import
        from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb

        if not training:
            super().__init__(
                RandomRIR(hp.rir_dir, deterministic=True),
                RandomReverb(deterministic=True),
            )
            return

        band_limit = Choice(
            RandomLowpassDistorter(),
            RandomBandpassDistorter(),
        )
        permutation = Permutation(
            RandomRIR(hp.rir_dir),
            RandomReverb(),
            RandomGaussianNoise(),
            RandomOverdrive(),
            RandomEqualizer(),
            band_limit,
        )

        if mode == "denoiser":
            super().__init__(permutation)
        else:
            # 80%: distortion, 20%: clean
            super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2]))
|
resemble-enhance/resemble_enhance/data/distorter/sox.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import warnings
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
try:
|
11 |
+
import augment
|
12 |
+
except ImportError:
|
13 |
+
raise ImportError(
|
14 |
+
"augment is not installed, please install it first using:"
|
15 |
+
"\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e"
|
16 |
+
)
|
17 |
+
|
18 |
+
from .base import Effect
|
19 |
+
|
20 |
+
_logger = logging.getLogger(__name__)
|
21 |
+
_DEBUG = bool(os.environ.get("DEBUG", False))
|
22 |
+
|
23 |
+
|
24 |
+
class AttachableEffect(Effect):
    """Effect that is expressed by attaching steps to an ``augment.EffectChain``."""

    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
        """Attach this effect to ``chain`` and return the chain. Must be overridden."""
        raise NotImplementedError

    def apply(self, wav: np.ndarray, sr: int):
        """Build a fresh chain, attach this effect and run it on ``wav``."""
        chain = self.attach(augment.EffectChain())
        batched = torch.from_numpy(wav).unsqueeze(0).float()  # (1, T)
        processed = chain.apply(batched, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
        return processed.numpy()[0]  # (T,)
|
35 |
+
|
36 |
+
|
37 |
+
class SoxEffect(AttachableEffect):
    """Attach the chain method named ``effect_name`` with the stored arguments."""

    def __init__(self, effect_name: str, *args, **kwargs):
        self.effect_name = effect_name
        self.args = args
        self.kwargs = kwargs

    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
        """Call ``chain.<effect_name>(*args, **kwargs)`` and return its result.

        Raises:
            ValueError: If ``chain`` has no attribute named ``effect_name``.
        """
        name = self.effect_name
        _logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")
        if not hasattr(chain, name):
            raise ValueError(f"EffectChain has no attribute {self.effect_name}")
        attach_effect = getattr(chain, name)
        return attach_effect(*self.args, **self.kwargs)
|
48 |
+
|
49 |
+
|
50 |
+
class Maybe(AttachableEffect):
    """
    Attach the wrapped effect with probability ``prob``.

    With the module-level ``_DEBUG`` flag set, the probability is forced to 1.
    """

    def __init__(self, prob: float, effect: AttachableEffect):
        self.prob = prob
        self.effect = effect
        if _DEBUG:
            warnings.warn("DEBUG mode is on. Maybe -> Must.")
            self.prob = 1

    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
        """Return ``chain`` unchanged when the random draw exceeds ``prob``."""
        skip = random.random() > self.prob
        return chain if skip else self.effect.attach(chain)
|
66 |
+
|
67 |
+
|
68 |
+
class Chain(AttachableEffect):
    """
    Attach several effects to the same chain, in order.
    """

    def __init__(self, *effects: AttachableEffect):
        self.effects = effects

    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
        """Thread ``chain`` through every effect's ``attach`` and return the result."""
        current = chain
        for member in self.effects:
            current = member.attach(current)
        return current
|
80 |
+
|
81 |
+
|
82 |
+
class Choice(AttachableEffect):
    """
    Attach exactly one of the given effects, picked uniformly at random.
    """

    def __init__(self, *effects: AttachableEffect):
        self.effects = effects

    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
        """Pick an effect with ``random.choice`` and attach it to ``chain``."""
        picked = random.choice(self.effects)
        return picked.attach(chain)
|
92 |
+
|
93 |
+
|
94 |
+
class Generator:
    """Base class for callables that produce an effect argument as a string."""

    def __call__(self) -> str:
        """Produce the next value. Must be overridden by subclasses."""
        raise NotImplementedError
|
97 |
+
|
98 |
+
|
99 |
+
class Uniform(Generator):
    """Generate ``str(random.uniform(low, high))`` on every call."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def __call__(self) -> str:
        """Draw a new float between ``low`` and ``high`` and return it as a string."""
        drawn = random.uniform(self.low, self.high)
        return str(drawn)
|
106 |
+
|
107 |
+
|
108 |
+
class Randint(Generator):
    """Generate ``str(random.randint(low, high))`` on every call (bounds inclusive)."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def __call__(self) -> str:
        """Draw a new integer in ``[low, high]`` and return it as a string."""
        drawn = random.randint(self.low, self.high)
        return str(drawn)
|
115 |
+
|
116 |
+
|
117 |
+
class Concat(Generator):
    """Concatenate literal strings and the outputs of other generators."""

    def __init__(self, *parts: Generator | str):
        self.parts = parts

    def __call__(self):
        """Evaluate every non-string part and join all pieces without a separator."""
        pieces = []
        for part in self.parts:
            # Plain strings are used verbatim; anything else is called.
            pieces.append(part if isinstance(part, str) else part())
        return "".join(pieces)
|
123 |
+
|
124 |
+
|
125 |
+
class RandomLowpassDistorter(SoxEffect):
    """``sinc`` effect with the arguments ``-n <Randint(50, 200)> -<Uniform(low, high)>``."""

    def __init__(self, low=2000, high=16000):
        taps = Randint(50, 200)
        # The leading "-" presumably marks the value as a low-pass cutoff for SoX's sinc - confirm.
        cutoff = Concat("-", Uniform(low, high))
        super().__init__("sinc", "-n", taps, cutoff)
|
128 |
+
|
129 |
+
|
130 |
+
class RandomBandpassDistorter(SoxEffect):
    """
    SoX ``sinc`` effect with a random tap count and a random ``start-stop`` band.

    ``start`` is drawn from [low, high] and the band width from
    [min_width, max_width]; both draws are integers.
    """

    def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
        band = partial(self._fn, low, high, min_width, max_width)
        taps = Randint(50, 200)
        super().__init__("sinc", "-n", taps, band)

    @staticmethod
    def _fn(low, high, min_width, max_width):
        """Return a ``"start-stop"`` frequency range string."""
        start = random.randint(low, high)
        width = random.randint(min_width, max_width)
        stop = start + width
        return f"{start}-{stop}"
|
139 |
+
|
140 |
+
|
141 |
+
class RandomEqualizer(SoxEffect):
    """
    SoX ``equalizer`` effect with random centre frequency, Q width and gain.

    The frequency is uniform in [low, high]; the width is an integer in
    [q_low, q_high] suffixed with ``q``; the gain is an integer in
    [db_low, db_high].
    """

    def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
        def _width():
            return f"{random.randint(q_low, q_high)}q"

        def _gain():
            return random.randint(db_low, db_high)

        super().__init__("equalizer", Uniform(low, high), _width, _gain)
|
149 |
+
|
150 |
+
|
151 |
+
class RandomOverdrive(SoxEffect):
    """SoX ``overdrive`` effect with uniformly random gain and colour arguments."""

    def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
        gain = Uniform(gain_low, gain_high)
        colour = Uniform(colour_low, colour_high)
        super().__init__("overdrive", gain, colour)
|
154 |
+
|
155 |
+
|
156 |
+
class RandomReverb(Chain):
    """
    SoX ``reverb`` with three numeric arguments, followed by ``channels 1``.

    Each argument is drawn uniformly from [0, 100], or fixed at 50 when
    ``deterministic`` is true.
    """

    def __init__(self, deterministic=False):
        # Three independent generators, one per reverb argument.
        reverb_args = [Uniform(50, 50) if deterministic else Uniform(0, 100) for _ in range(3)]
        reverb = SoxEffect("reverb", *reverb_args)
        to_mono = SoxEffect("channels", 1)
        super().__init__(reverb, to_mono)
|
167 |
+
|
168 |
+
|
169 |
+
class Flanger(SoxEffect):
    """SoX ``flanger`` effect; no extra arguments are passed."""

    def __init__(self):
        super().__init__("flanger")
|
172 |
+
|
173 |
+
|
174 |
+
class Phaser(SoxEffect):
    """SoX ``phaser`` effect; no extra arguments are passed."""

    def __init__(self):
        super().__init__("phaser")
|
resemble-enhance/resemble_enhance/data/utils.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
from torch import Tensor
|
5 |
+
|
6 |
+
|
7 |
+
def walk_paths(root, suffix):
    """
    Recursively yield every non-directory entry under ``root`` whose
    ``Path.suffix`` equals ``suffix`` (e.g. ``".wav"``).

    Directories are descended into as they are encountered, in the order
    ``Path.iterdir`` reports them.
    """
    for entry in Path(root).iterdir():
        if not entry.is_dir():
            if entry.suffix == suffix:
                yield entry
            continue
        yield from walk_paths(entry, suffix)
|
13 |
+
|
14 |
+
|
15 |
+
def rglob_audio_files(path: Path):
    """Return all ``.wav`` files followed by all ``.flac`` files found under ``path``."""
    found = []
    for extension in (".wav", ".flac"):
        found.extend(walk_paths(path, extension))
    return found
|
17 |
+
|
18 |
+
|
19 |
+
def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
|
20 |
+
"""
|
21 |
+
Args:
|
22 |
+
fg: (b, t)
|
23 |
+
bg: (b, t)
|
24 |
+
"""
|
25 |
+
assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"
|
26 |
+
fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps)
|
27 |
+
bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps)
|
28 |
+
|
29 |
+
fg_energy = fg.pow(2).sum(dim=-1, keepdim=True)
|
30 |
+
bg_energy = bg.pow(2).sum(dim=-1, keepdim=True)
|
31 |
+
|
32 |
+
fg = fg / (fg_energy + eps).sqrt()
|
33 |
+
bg = bg / (bg_energy + eps).sqrt()
|
34 |
+
|
35 |
+
if callable(alpha):
|
36 |
+
alpha = alpha()
|
37 |
+
|
38 |
+
assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"
|
39 |
+
|
40 |
+
mx = alpha * fg + (1 - alpha) * bg
|
41 |
+
mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps)
|
42 |
+
|
43 |
+
return mx
|
resemble-enhance/resemble_enhance/denoiser/__init__.py
ADDED
File without changes
|
resemble-enhance/resemble_enhance/denoiser/__main__.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
|
7 |
+
from .inference import denoise
|
8 |
+
|
9 |
+
|
10 |
+
@torch.inference_mode()
def main():
    """
    Denoise every audio file under ``in_dir`` and write the results to ``out_dir``.

    Files are found recursively by suffix and saved at the same relative path
    under ``out_dir``. Only the first channel of each input is denoised.
    If ``--device cuda`` is requested (the default) but CUDA is unavailable,
    the run falls back to CPU instead of failing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
    parser.add_argument("out_dir", type=Path, help="Output folder")
    parser.add_argument("--run_dir", type=Path, default="runs/denoiser", help="Path to run folder")
    parser.add_argument("--suffix", type=str, default=".wav", help="File suffix")
    parser.add_argument("--device", type=str, default="cuda", help="Device")
    args = parser.parse_args()

    device = args.device
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available but --device is set to cuda, using CPU instead")
        device = "cpu"

    for path in args.in_dir.glob(f"**/*{args.suffix}"):
        print(f"Processing {path} ..")
        dwav, sr = torchaudio.load(path)
        # dwav is indexed on its first axis: only that first entry is processed.
        hwav, sr = denoise(dwav[0], sr, args.run_dir, device)
        out_path = args.out_dir / path.relative_to(args.in_dir)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        torchaudio.save(out_path, hwav[None], sr)
|
27 |
+
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
main()
|
resemble-enhance/resemble_enhance/denoiser/denoiser.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import Tensor, nn
|
6 |
+
|
7 |
+
from ..melspec import MelSpectrogram
|
8 |
+
from .hparams import HParams
|
9 |
+
from .unet import UNet
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
def _normalize(x: Tensor) -> Tensor:
|
15 |
+
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
|
16 |
+
|
17 |
+
|
18 |
+
class Denoiser(nn.Module):
    """
    Waveform denoiser operating on the STFT of the input.

    A UNet takes the stacked (magnitude, cos(phase), sin(phase)) of the mixed
    signal and predicts a magnitude mask plus a phase rotation; the masked and
    rotated spectrum is converted back to a waveform with the inverse STFT.
    """

    @property
    def stft_cfg(self) -> dict:
        """STFT keyword arguments derived from ``hp.hop_size`` (window = 4 hops)."""
        hop_size = self.hp.hop_size
        return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4)

    @property
    def n_fft(self):
        """FFT size used by the STFT."""
        return self.stft_cfg["n_fft"]

    @property
    def eps(self):
        """Small constant used to keep ``_magphase`` away from a zero magnitude."""
        return 1e-7

    def __init__(self, hp: HParams):
        super().__init__()
        self.hp = hp
        # 3 channels in (mag, cos, sin) and 3 channels out (mask, real, imag).
        self.net = UNet(input_dim=3, output_dim=3)
        self.mel_fn = MelSpectrogram(hp)

        # Non-persistent buffer used in forward() as the target of ``.to(...)``,
        # i.e. it tracks the module's current device/dtype.
        self.dummy: Tensor
        self.register_buffer("dummy", torch.zeros(1), persistent=False)

    def to_mel(self, x: Tensor, drop_last=True):
        """
        Compute mel spectrograms with ``self.mel_fn``.

        Args:
            x: (b t), wavs
            drop_last: if true, the last frame is removed from the result
        Returns:
            o: (b c t), mels
        """
        if drop_last:
            return self.mel_fn(x)[..., :-1]  # (b d t)
        return self.mel_fn(x)

    def _stft(self, x):
        """
        STFT of ``x`` split into magnitude and unit-phase components.

        The last frame is dropped, and the outputs are returned in the dtype
        and on the device of the input.

        Args:
            x: (b t)
        Returns:
            mag: (b f t) in [0, inf)
            cos: (b f t) in [-1, 1]
            sin: (b f t) in [-1, 1]
        """
        dtype = x.dtype
        device = x.device

        # MPS tensors are processed on the CPU and moved back afterwards
        # (presumably because the STFT path is not supported on MPS -- confirm).
        if x.is_mps:
            x = x.cpu()

        window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
        s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True)  # (b f t+1)

        s = s[..., :-1]  # (b f t)

        mag = s.abs()  # (b f t)

        phi = s.angle()  # (b f t)
        cos = phi.cos()  # (b f t)
        sin = phi.sin()  # (b f t)

        mag = mag.to(dtype=dtype, device=device)
        cos = cos.to(dtype=dtype, device=device)
        sin = sin.to(dtype=dtype, device=device)

        return mag, cos, sin

    def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor):
        """
        Inverse of ``_stft``: rebuild a waveform from magnitude and phase parts.

        NaNs in the output are logged and replaced by zeros.

        Args:
            mag: (b f t) in [0, inf)
            cos: (b f t) in [-1, 1]
            sin: (b f t) in [-1, 1]
        Returns:
            x: (b t)
        """
        device = mag.device
        dtype = mag.dtype

        # Same CPU detour for MPS inputs as in _stft.
        if mag.is_mps:
            mag = mag.cpu()
            cos = cos.cpu()
            sin = sin.cpu()

        real = mag * cos  # (b f t)
        imag = mag * sin  # (b f t)

        s = torch.complex(real, imag)  # (b f t)

        if s.isnan().any():
            logger.warning("NaN detected in ISTFT input.")

        # Re-append one frame (a copy of the last one) to make up for the
        # frame dropped in _stft.
        s = F.pad(s, (0, 1), "replicate")  # (b f t+1)

        window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
        x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False)

        if x.isnan().any():
            logger.warning("NaN detected in ISTFT output, set to zero.")
            x = torch.where(x.isnan(), torch.zeros_like(x), x)

        x = x.to(dtype=dtype, device=device)

        return x

    def _magphase(self, real, imag):
        """Return (mag, cos, sin) of ``real + i*imag``; eps keeps mag strictly positive."""
        mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt()
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor):
        """
        Run the UNet and turn its three output channels into a mask and a phase residual.

        Args:
            mag: (b f t)
            cos: (b f t)
            sin: (b f t)
        Returns (NOTE: sin comes before cos in the returned tuple):
            mag_mask: (b f t) in [0, 1], magnitude mask
            sin_res: (b f t) in [-1, 1], phase residual
            cos_res: (b f t) in [-1, 1], phase residual
        """
        x = torch.stack([mag, cos, sin], dim=1)  # (b 3 f t)
        mag_mask, real, imag = self.net(x).unbind(1)  # (b 3 f t)
        mag_mask = mag_mask.sigmoid()  # (b f t)
        real = real.tanh()  # (b f t)
        imag = imag.tanh()  # (b f t)
        _, cos_res, sin_res = self._magphase(real, imag)  # (b f t)
        # forward() unpacks in this same (mask, sin, cos) order.
        return mag_mask, sin_res, cos_res

    def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
        """
        Apply the magnitude mask and rotate the phase by the residual angle.

        The cos/sin outputs follow the angle-addition identities.
        Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf
        """
        sep_mag = F.relu(mag * mag_mask)
        sep_cos = cos * cos_res - sin * sin_res
        sep_sin = sin * cos_res + cos * sin_res
        return sep_mag, sep_cos, sep_sin

    def forward(self, x: Tensor, y: Tensor | None = None):
        """
        Denoise ``x``; when a target ``y`` is given, also store the L1 loss.

        Both inputs are peak-normalized before use. The loss is written to
        ``self.losses`` as ``{"l1": ...}`` and only when ``y`` is provided.

        Args:
            x: (b t), a mixed audio
            y: (b t), a fg audio
        Returns:
            o: (b t), the predicted waveform, padded to the length of ``x``
        """
        assert x.dim() == 2, f"Expected (b t), got {x.size()}"
        x = x.to(self.dummy)
        x = _normalize(x)

        if y is not None:
            assert y.dim() == 2, f"Expected (b t), got {y.size()}"
            y = y.to(self.dummy)
            y = _normalize(y)

        mag, cos, sin = self._stft(x)  # each (b f t)
        mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
        sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)

        o = self._istft(sep_mag, sep_cos, sep_sin)

        # Pad the ISTFT output on the right so it has the same length as x.
        npad = x.shape[-1] - o.shape[-1]
        o = F.pad(o, (0, npad))

        if y is not None:
            self.losses = dict(l1=F.l1_loss(o, y))

        return o
|
resemble-enhance/resemble_enhance/denoiser/hparams.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
from ..hparams import HParams as HParamsBase
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass(frozen=True)
class HParams(HParamsBase):
    """Denoiser hyper-parameters: the fields of the shared base plus those declared below."""

    # Batch size used on each GPU.
    batch_size_per_gpu: int = 128
    # Probability threshold compared against random.random() in the denoiser
    # training step to pick the "fg_dwavs" batch entry over "fg_wavs".
    distort_prob: float = 0.5
|
resemble-enhance/resemble_enhance/denoiser/inference.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from functools import cache
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from ..inference import inference
|
7 |
+
from .train import Denoiser, HParams
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
@cache
def load_denoiser(run_dir, device):
    """
    Build a ``Denoiser`` and, when ``run_dir`` is given, load its checkpoint.

    With ``run_dir=None`` a freshly constructed model with default
    hyper-parameters is returned as-is. Otherwise hyper-parameters and the
    ``ds/G/default/mp_rank_00_model_states.pt`` weights are read from
    ``run_dir`` and the model is put in eval mode on ``device``.
    Results are cached per ``(run_dir, device)`` pair.
    """
    if run_dir is None:
        return Denoiser(HParams())

    model = Denoiser(HParams.load(run_dir))
    ckpt_path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["module"])
    return model.eval().to(device)
|
24 |
+
|
25 |
+
|
26 |
+
@torch.inference_mode()
def denoise(dwav, sr, run_dir, device):
    """Run ``inference`` on ``dwav`` with the (cached) denoiser loaded from ``run_dir``."""
    return inference(model=load_denoiser(run_dir, device), dwav=dwav, sr=sr, device=device)
|
resemble-enhance/resemble_enhance/denoiser/train.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import random
|
3 |
+
from functools import partial
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import soundfile
|
7 |
+
import torch
|
8 |
+
from deepspeed import DeepSpeedConfig
|
9 |
+
from torch import Tensor
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from ..data import create_dataloaders, mix_fg_bg
|
13 |
+
from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
|
14 |
+
from ..utils.distributed import is_local_leader
|
15 |
+
from .denoiser import Denoiser
|
16 |
+
from .hparams import HParams
|
17 |
+
|
18 |
+
|
19 |
+
def load_G(run_dir: Path, hp: HParams | None = None, training=True):
    """
    Create the denoiser engine and restore its checkpoint from ``run_dir / "ds" / "G"``.

    Hyper-parameters are loaded from ``run_dir`` when ``hp`` is not given.
    When ``training`` is false, optimizer and LR-scheduler states are not loaded.
    """
    if hp is None:
        hp = HParams.load(run_dir)
        assert isinstance(hp, HParams)

    ds_config = DeepSpeedConfig(hp.deepspeed_config)
    engine = Engine(model=Denoiser(hp), config_class=ds_config, ckpt_dir=run_dir / "ds" / "G")

    if not training:
        engine.load_checkpoint(load_optimizer_states=False, load_lr_scheduler_states=False)
    else:
        engine.load_checkpoint()

    return engine
|
30 |
+
|
31 |
+
|
32 |
+
def save_wav(path: Path, wav: Tensor, rate: int):
    """Write ``wav`` to ``path`` at sample rate ``rate`` using ``soundfile``."""
    samples = wav.detach().cpu().numpy()
    soundfile.write(path, samples, samplerate=rate)
|
35 |
+
|
36 |
+
|
37 |
+
def main():
    """
    Command-line entry point for denoiser training.

    Parses the run directory, optional YAML file and device, loads the
    hyper-parameters, builds the data loaders and runs a ``TrainLoop`` for
    ``hp.max_steps`` steps.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("run_dir", type=Path)
    parser.add_argument("--yaml", type=Path, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    setup_logging(args.run_dir)
    hp = HParams.load(args.run_dir, yaml=args.yaml)

    # Only the local leader process saves and prints the hyper-parameters.
    if is_local_leader():
        hp.save_if_not_exists(args.run_dir)
        hp.print()

    train_dl, val_dl = create_dataloaders(hp, mode="denoiser")

    def feed_G(engine: Engine, batch: dict[str, Tensor]):
        """One training step: mix foreground with background and run the engine."""
        # The mixing weight is sampled from hp.mix_alpha_range on each call.
        alpha_fn = lambda: random.uniform(*hp.mix_alpha_range)
        # With probability hp.distort_prob the "fg_dwavs" entry is used as
        # both the mixture foreground and the training target.
        if random.random() < hp.distort_prob:
            fg_wavs = batch["fg_dwavs"]
        else:
            fg_wavs = batch["fg_wavs"]
        mx_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"], alpha=alpha_fn)
        pred = engine(mx_dwavs, fg_wavs)
        losses = engine.gather_attribute("losses", prefix="losses")
        return pred, losses

    @torch.no_grad()
    def eval_fn(engine: Engine, eval_dir, n_saved=10):
        """Save input/prediction/target wavs and a mel plot for up to ``n_saved`` validation batches."""
        model = engine.module
        # NOTE(review): the model is not switched back to train mode here;
        # presumably TrainLoop takes care of that -- confirm.
        model.eval()

        step = engine.global_step

        for i, batch in enumerate(tqdm(val_dl), 1):
            batch = tree_map(lambda x: x.to(args.device) if isinstance(x, Tensor) else x, batch)

            fg_dwavs = batch["fg_dwavs"]  # 1 t
            # Default alpha (0.5) is used for evaluation mixtures.
            mx_dwavs = mix_fg_bg(fg_dwavs, batch["bg_dwavs"])
            pred_fg_dwavs = model(mx_dwavs)  # 1 t

            mx_mels = model.to_mel(mx_dwavs)  # 1 c t
            fg_mels = model.to_mel(fg_dwavs)  # 1 c t
            pred_fg_mels = model.to_mel(pred_fg_dwavs)  # 1 c t

            rate = model.hp.wav_rate
            # Output files are named by global step and batch index.
            get_path = lambda suffix: eval_dir / f"step_{step:08}_{i:03}{suffix}"

            save_wav(get_path("_input.wav"), mx_dwavs[0], rate=rate)
            save_wav(get_path("_predict.wav"), pred_fg_dwavs[0], rate=rate)
            save_wav(get_path("_target.wav"), fg_dwavs[0], rate=rate)

            save_mels(
                get_path(".png"),
                cond_mel=mx_mels[0].cpu().numpy(),
                pred_mel=pred_fg_mels[0].cpu().numpy(),
                targ_mel=fg_mels[0].cpu().numpy(),
            )

            if i >= n_saved:
                break

    train_loop = TrainLoop(
        run_dir=args.run_dir,
        train_dl=train_dl,
        load_G=partial(load_G, hp=hp),
        device=args.device,
        feed_G=feed_G,
        eval_fn=eval_fn,
    )

    train_loop.run(max_steps=hp.max_steps)
|
109 |
+
|
110 |
+
|
111 |
+
if __name__ == "__main__":
|
112 |
+
main()
|
resemble-enhance/resemble_enhance/denoiser/unet.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class PreactResBlock(nn.Sequential):
    """
    Pre-activation residual block: two (GroupNorm -> GELU -> 3x3 Conv) stages
    whose output is added to the input.

    Args:
        dim: number of input and output channels. One normalization group is
            used per 16 channels, with at least one group so that widths
            below 16 can still be constructed.
    """

    def __init__(self, dim):
        # dim // 16 is 0 for dim < 16, which GroupNorm cannot be built with;
        # for dim >= 16 this is the same group count as before.
        num_groups = max(1, dim // 16)
        super().__init__(
            nn.GroupNorm(num_groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.GroupNorm(num_groups, dim),
            nn.GELU(),
            nn.Conv2d(dim, dim, 3, padding=1),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the sequential stages applied to ``x``."""
        return x + super().forward(x)
|
18 |
+
|
19 |
+
|
20 |
+
class UNetBlock(nn.Module):
    """
    One UNet stage: optional upsampling, a 3x3 projection conv, two residual
    blocks and optional downsampling.

    ``scale_factor > 1`` resizes the input before the convolutions,
    ``scale_factor < 1`` resizes the output after them, and ``1`` does neither.
    """

    def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
        super().__init__()
        out_channels = input_dim if output_dim is None else output_dim
        self.pre_conv = nn.Conv2d(input_dim, out_channels, 3, padding=1)
        self.res_block1 = PreactResBlock(out_channels)
        self.res_block2 = PreactResBlock(out_channels)

        # Both resamplers default to the same pass-through module.
        passthrough = nn.Identity()
        self.downsample = passthrough
        self.upsample = passthrough
        if scale_factor > 1:
            self.upsample = nn.Upsample(scale_factor=scale_factor)
        elif scale_factor < 1:
            self.downsample = nn.Upsample(scale_factor=scale_factor)

    def forward(self, x, h=None):
        """
        Args:
            x: (b c h w), last output
            h: (b c h w), skip output, added to ``x`` after upsampling
        Returns:
            o: (b c h w), output (after downsampling)
            s: (b c h w), skip output (before downsampling)
        """
        x = self.upsample(x)
        if h is not None:
            assert x.shape == h.shape, f"{x.shape} != {h.shape}"
            x = x + h
        skip = self.res_block2(self.res_block1(self.pre_conv(x)))
        return self.downsample(skip), skip
|
51 |
+
|
52 |
+
|
53 |
+
class UNet(nn.Module):
    """
    2-D UNet with ``num_blocks`` down/up stages and ``num_middle_blocks``
    bottleneck stages. The channel width doubles at every encoder stage,
    starting from ``hidden_dim``.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)

        # widths[d] is the channel count at depth d (0 = full resolution).
        widths = [hidden_dim * 2**depth for depth in range(num_blocks + 1)]

        encoders = []
        for depth in range(num_blocks):
            encoders.append(UNetBlock(input_dim=widths[depth], output_dim=widths[depth + 1], scale_factor=0.5))
        self.encoder_blocks = nn.ModuleList(encoders)

        middles = []
        for _ in range(num_middle_blocks):
            middles.append(UNetBlock(input_dim=widths[num_blocks]))
        self.middle_blocks = nn.ModuleList(middles)

        decoders = []
        for depth in reversed(range(num_blocks)):
            decoders.append(UNetBlock(input_dim=widths[depth + 1], output_dim=widths[depth], scale_factor=2))
        self.decoder_blocks = nn.ModuleList(decoders)

        self.head = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_dim, output_dim, 1),
        )

    @property
    def scale_factor(self):
        """Total spatial reduction of the encoder (2 per encoder block)."""
        return 2 ** len(self.encoder_blocks)

    def pad_to_fit(self, x):
        """
        Pad the bottom and right of ``x`` so that height and width are
        multiples of ``scale_factor``.

        Args:
            x: (b c h w), input
        Returns:
            x: (b c h' w'), padded input
        """
        factor = self.scale_factor
        extra_h = -x.shape[2] % factor
        extra_w = -x.shape[3] % factor
        return F.pad(x, (0, extra_w, 0, extra_h))

    def forward(self, x):
        """
        Args:
            x: (b c h w), input
        Returns:
            o: (b c h w), output, cropped back to the input's height and width
        """
        in_shape = x.shape

        feat = self.input_proj(self.pad_to_fit(x))

        skips = []
        for encoder in self.encoder_blocks:
            feat, skip = encoder(feat)
            skips.append(skip)

        for middle in self.middle_blocks:
            feat, _ = middle(feat)

        # Decoder stages consume the skips from deepest to shallowest.
        for decoder, skip in zip(self.decoder_blocks, reversed(skips)):
            feat, _ = decoder(feat, skip)

        out = self.head(feat)
        return out[..., : in_shape[2], : in_shape[3]]

    def test(self, shape=(3, 512, 256)):
        """Print MACs and parameter count for an input of ``shape`` using ``ptflops``."""
        import ptflops

        macs, params = ptflops.get_model_complexity_info(
            self,
            shape,
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )

        print(f"macs: {macs}")
        print(f"params: {params}")
|
136 |
+
|
137 |
+
|
138 |
+
def main():
    """Build a 3-in/3-out UNet and print its complexity report."""
    UNet(3, 3).test()
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == "__main__":
|
144 |
+
main()
|
resemble-enhance/resemble_enhance/enhancer/__init__.py
ADDED
File without changes
|
resemble-enhance/resemble_enhance/enhancer/__main__.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torchaudio
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from .inference import denoise, enhance
|
11 |
+
|
12 |
+
|
13 |
+
@torch.inference_mode()
def main():
    """
    Enhance (or only denoise) every audio file under ``in_dir`` into ``out_dir``.

    Files are found recursively by suffix, averaged over their first axis and
    written to the same relative path under ``out_dir``. With
    ``--parallel_mode`` the file order is shuffled and files whose output
    already exists are skipped. The final summary reports the number of files
    actually processed in this run, not counting skipped ones.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
    parser.add_argument("out_dir", type=Path, help="Output folder")
    parser.add_argument(
        "--run_dir",
        type=Path,
        default=None,
        help="Path to the enhancer run folder, if None, use the default model",
    )
    parser.add_argument(
        "--suffix",
        type=str,
        default=".wav",
        help="Audio file suffix",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for computation, recommended to use CUDA",
    )
    parser.add_argument(
        "--denoise_only",
        action="store_true",
        help="Only apply denoising without enhancement",
    )
    parser.add_argument(
        "--lambd",
        type=float,
        default=1.0,
        help="Denoise strength for enhancement (0.0 to 1.0)",
    )
    parser.add_argument(
        "--tau",
        type=float,
        default=0.5,
        help="CFM prior temperature (0.0 to 1.0)",
    )
    parser.add_argument(
        "--solver",
        type=str,
        default="midpoint",
        choices=["midpoint", "rk4", "euler"],
        help="Numerical solver to use",
    )
    parser.add_argument(
        "--nfe",
        type=int,
        default=64,
        help="Number of function evaluations",
    )
    parser.add_argument(
        "--parallel_mode",
        action="store_true",
        help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
    )

    args = parser.parse_args()

    device = args.device

    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available but --device is set to cuda, using CPU instead")
        device = "cpu"

    start_time = time.perf_counter()

    run_dir = args.run_dir

    paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))

    if args.parallel_mode:
        random.shuffle(paths)

    if len(paths) == 0:
        print(f"No {args.suffix} files found in the following path: {args.in_dir}")
        return

    pbar = tqdm(paths)
    num_processed = 0

    for path in pbar:
        out_path = args.out_dir / path.relative_to(args.in_dir)
        # In parallel mode another job may already have produced this file.
        if args.parallel_mode and out_path.exists():
            continue
        pbar.set_description(f"Processing {out_path}")
        dwav, sr = torchaudio.load(path)
        dwav = dwav.mean(0)
        if args.denoise_only:
            hwav, sr = denoise(
                dwav=dwav,
                sr=sr,
                device=device,
                run_dir=run_dir,
            )
        else:
            hwav, sr = enhance(
                dwav=dwav,
                sr=sr,
                device=device,
                nfe=args.nfe,
                solver=args.solver,
                lambd=args.lambd,
                tau=args.tau,
                run_dir=run_dir,
            )
        out_path.parent.mkdir(parents=True, exist_ok=True)
        torchaudio.save(out_path, hwav[None], sr)
        num_processed += 1

    # Cool emoji effect saying the job is done
    elapsed_time = time.perf_counter() - start_time
    print(f"🌟 Enhancement done! {num_processed} files processed in {elapsed_time:.2f}s")
|
126 |
+
|
127 |
+
|
128 |
+
if __name__ == "__main__":
|
129 |
+
main()
|
resemble-enhance/resemble_enhance/enhancer/download.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
RUN_NAME = "enhancer_stage2"
|
7 |
+
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
def get_source_url(relpath):
    """Return the HuggingFace download URL for ``relpath`` inside the ``RUN_NAME`` folder."""
    url = f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
    return url
|
13 |
+
|
14 |
+
|
15 |
+
def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
|
16 |
+
if run_dir is None:
|
17 |
+
run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
|
18 |
+
return Path(run_dir) / relpath
|
19 |
+
|
20 |
+
|
21 |
+
def download(run_dir: str | Path | None = None):
    """
    Download the model files that are not yet present and return the run directory.

    Each of the three required files is fetched with
    ``torch.hub.download_url_to_file`` unless its target path already exists.
    """
    for relpath in ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]:
        target = get_target_path(relpath, run_dir=run_dir)
        if not target.exists():
            source = get_source_url(relpath)
            target.parent.mkdir(parents=True, exist_ok=True)
            torch.hub.download_url_to_file(source, str(target))
    return get_target_path("", run_dir=run_dir)
|
resemble-enhance/resemble_enhance/enhancer/enhancer.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
from torch import Tensor, nn
|
7 |
+
from torch.distributions import Beta
|
8 |
+
|
9 |
+
from ..common import Normalizer
|
10 |
+
from ..denoiser.inference import load_denoiser
|
11 |
+
from ..melspec import MelSpectrogram
|
12 |
+
from ..utils.distributed import global_leader_only
|
13 |
+
from ..utils.train_loop import TrainLoop
|
14 |
+
from .hparams import HParams
|
15 |
+
from .lcfm import CFM, IRMAE, LCFM
|
16 |
+
from .univnet import UnivNet
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
def _maybe(fn):
|
22 |
+
def _fn(*args):
|
23 |
+
if args[0] is None:
|
24 |
+
return None
|
25 |
+
return fn(*args)
|
26 |
+
|
27 |
+
return _fn
|
28 |
+
|
29 |
+
|
30 |
+
def _normalize_wav(x: Tensor):
|
31 |
+
return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
|
32 |
+
|
33 |
+
|
34 |
+
class Enhancer(nn.Module):
    """Waveform enhancer built from a mel front end, an LCFM and a vocoder.

    The module owns:
        * ``lcfm``: an ``LCFM`` wrapping an ``IRMAE`` autoencoder and a ``CFM``,
        * ``mel_fn`` / ``normalizer``: mel extraction and normalization,
        * ``vocoder``: a ``UnivNet`` fed with the LCFM output,
        * ``denoiser``: loaded from ``hp.denoiser_run_dir`` (on CPU).
    """

    def __init__(self, hp: HParams):
        super().__init__()
        self.hp = hp

        n_mels = self.hp.num_mels
        vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
        latent_dim = self.hp.lcfm_latent_dim

        self.lcfm = LCFM(
            IRMAE(
                input_dim=n_mels,
                output_dim=vocoder_input_dim,
                latent_dim=latent_dim,
            ),
            CFM(
                cond_dim=n_mels,
                output_dim=self.hp.lcfm_latent_dim,
                solver_nfe=self.hp.cfm_solver_nfe,
                solver_method=self.hp.cfm_solver_method,
                time_mapping_divisor=self.hp.cfm_time_mapping_divisor,
            ),
            z_scale=self.hp.lcfm_z_scale,
        )

        self.lcfm.set_mode_(self.hp.lcfm_training_mode)

        self.mel_fn = MelSpectrogram(hp)
        self.vocoder = UnivNet(self.hp, vocoder_input_dim)
        self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu")
        self.normalizer = Normalizer()

        # Denoiser strength used in eval mode; changed through configurate_().
        self._eval_lambd = 0.0

        self.dummy: Tensor
        self.register_buffer("dummy", torch.zeros(1))

        if self.hp.enhancer_stage1_run_dir is not None:
            pretrained_path = self.hp.enhancer_stage1_run_dir / "ds/G/default/mp_rank_00_model_states.pt"
            self._load_pretrained(pretrained_path)

        logger.info(f"{self.__class__.__name__} summary")
        logger.info(f"{self.summarize()}")

    def _load_pretrained(self, path):
        """Load a checkpoint non-strictly, keeping the current cfm and denoiser weights.

        The checkpoint's ``"module"`` state dict is loaded with ``strict=False``;
        afterwards the CFM and the denoiser are reset to the weights they had
        before the load.
        """
        # Clone is necessary as otherwise it holds a reference to the original model
        cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()}
        denoiser_state_dict = {k: v.clone() for k, v in self.denoiser.state_dict().items()}
        state_dict = torch.load(path, map_location="cpu")["module"]
        self.load_state_dict(state_dict, strict=False)
        self.lcfm.cfm.load_state_dict(cfm_state_dict)  # Reset cfm
        self.denoiser.load_state_dict(denoiser_state_dict)  # Reset denoiser
        logger.info(f"Loaded pretrained model from {path}")

    def summarize(self):
        """Return a markdown table of trainable/total parameter counts per child module."""
        npa_train = lambda m: sum(p.numel() for p in m.parameters() if p.requires_grad)
        npa = lambda m: sum(p.numel() for p in m.parameters())
        rows = []
        for name, module in self.named_children():
            rows.append(dict(name=name, trainable=npa_train(module), total=npa(module)))
        rows.append(dict(name="total", trainable=npa_train(self), total=npa(self)))
        df = pd.DataFrame(rows)
        return df.to_markdown(index=False)

    def to_mel(self, x: Tensor, drop_last=True):
        """
        Args:
            x: (b t), wavs
            drop_last: when True, the last frame along the final axis is dropped
        Returns:
            o: (b c t), mels
        """
        if drop_last:
            return self.mel_fn(x)[..., :-1]  # (b d t)
        return self.mel_fn(x)

    @global_leader_only
    @torch.no_grad()
    def _visualize(self, original_mel, denoised_mel):
        """Save an original-vs-denoised mel plot every 100 global steps.

        Does nothing when no train loop is running or when the current global
        step is not a multiple of 100.
        """
        loop = TrainLoop.get_running_loop()
        if loop is None or loop.global_step % 100 != 0:
            return

        fig = plt.figure(figsize=(6, 6))
        try:
            plt.subplot(211)
            plt.title("Original")
            plt.imshow(original_mel[0].cpu().numpy(), origin="lower", interpolation="none")
            plt.subplot(212)
            plt.title("Denoised")
            plt.imshow(denoised_mel[0].cpu().numpy(), origin="lower", interpolation="none")
            plt.tight_layout()

            path = loop.get_running_loop_viz_path("input", ".png")
            plt.savefig(path, dpi=300)
        finally:
            # Close the figure so that repeated calls during a long training
            # run do not accumulate open matplotlib figures.
            plt.close(fig)

    def _may_denoise(self, x: Tensor, y: Tensor | None = None):
        """Run the denoiser on ``x`` in "cfm" training mode; otherwise return ``x`` as is."""
        if self.hp.lcfm_training_mode == "cfm":
            return self.denoiser(x, y)
        return x

    def configurate_(self, nfe, solver, lambd, tau):
        """
        Args:
            nfe: number of function evaluations
            solver: solver method
            lambd: denoiser strength [0, 1]
            tau: prior temperature [0, 1]
        """
        self.lcfm.cfm.solver.configurate_(nfe, solver)
        self.lcfm.eval_tau_(tau)
        self._eval_lambd = lambd

    def _blend_denoised_mel(self, x: Tensor, z: Tensor | None, x_mel_original: Tensor, lambd):
        """Mix the (detached) denoised mel with the original mel using weight ``lambd``."""
        x_mel_denoised = self.normalizer(self.to_mel(self._may_denoise(x, z)), update=False)
        x_mel_denoised = x_mel_denoised.detach()
        return lambd * x_mel_denoised + (1 - lambd) * x_mel_original

    def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
        """
        Args:
            x: (b t), mix wavs (fg + bg)
            y: (b t), fg clean wavs
            z: (b t), fg distorted wavs
        Returns:
            o: (b t), reconstructed wavs
        """
        assert x.dim() == 2, f"Expected (b t), got {x.size()}"
        assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"

        cfm_mode = self.hp.lcfm_training_mode == "cfm"

        if cfm_mode:
            self.normalizer.eval()

        x = _normalize_wav(x)
        y = _maybe(_normalize_wav)(y)
        z = _maybe(_normalize_wav)(z)

        x_mel_original = self.normalizer(self.to_mel(x), update=False)  # (b d t)

        if not cfm_mode:
            x_mel_denoised = x_mel_original
        elif self.training:
            # Per-sample random denoiser strength, drawn before the denoiser runs.
            lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device)
            lambd = lambd[:, None, None]
            x_mel_denoised = self._blend_denoised_mel(x, z, x_mel_original, lambd)
            self._visualize(x_mel_original, x_mel_denoised)
        elif self._eval_lambd == 0:
            # Zero strength: skip the denoiser entirely.
            x_mel_denoised = x_mel_original
        else:
            x_mel_denoised = self._blend_denoised_mel(x, z, x_mel_original, self._eval_lambd)

        y_mel = _maybe(self.to_mel)(y)  # (b d t)
        y_mel = _maybe(self.normalizer)(y_mel)

        lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original)  # (b d t)

        if lcfm_decoded is None:
            o = None
        else:
            o = self.vocoder(lcfm_decoded, y)

        return o
|
resemble-enhance/resemble_enhance/enhancer/hparams.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from ..hparams import HParams as HParamsBase
|
5 |
+
|
6 |
+
|
7 |
+
@dataclass(frozen=True)
class HParams(HParamsBase):
    """Enhancer-specific hyper-parameters layered on top of ``HParamsBase``."""

    # Solver method, NFE budget and time-mapping divisor forwarded to the CFM
    # module when the Enhancer is constructed.
    cfm_solver_method: str = "midpoint"
    cfm_solver_nfe: int = 64
    cfm_time_mapping_divisor: int = 4
    # NOTE(review): presumably the UnivNet channel count; not used in the
    # visible code, confirm against the vocoder.
    univnet_nc: int = 96

    # Latent size shared by the IRMAE autoencoder and the CFM output.
    lcfm_latent_dim: int = 64
    # LCFM mode value, "ae" or "cfm".
    lcfm_training_mode: str = "ae"
    # Scale factor passed to LCFM as ``z_scale``.
    lcfm_z_scale: float = 5

    # Added to ``num_mels`` to obtain the vocoder input dimension.
    vocoder_extra_dim: int = 32

    # NOTE(review): presumably the step at which GAN training starts; the
    # consumer is not visible here.
    gan_training_start_step: int | None = 5_000
    # When set, the Enhancer loads stage-1 weights from this run directory.
    enhancer_stage1_run_dir: Path | None = None

    # Run directory handed to ``load_denoiser`` by the Enhancer.
    denoiser_run_dir: Path | None = None
|
resemble-enhance/resemble_enhance/enhancer/inference.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from functools import cache
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from ..inference import inference
|
8 |
+
from .download import download
|
9 |
+
from .train import Enhancer, HParams
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
@cache
def load_enhancer(run_dir: str | Path | None, device):
    """Build an ``Enhancer`` from a run directory and load its checkpoint.

    ``run_dir`` is first resolved through ``download``. The hyper-parameters
    and the ``"module"`` state dict are then read from the resolved directory.
    The model is returned in eval mode on ``device``; results are memoized per
    ``(run_dir, device)`` by ``functools.cache``.
    """
    resolved_dir = download(run_dir)
    model = Enhancer(HParams.load(resolved_dir))
    checkpoint_path = resolved_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["module"])
    model.eval()
    model.to(device)
    return model
|
25 |
+
|
26 |
+
|
27 |
+
@torch.inference_mode()
def denoise(dwav, sr, device, run_dir=None):
    """Run only the enhancer's denoiser sub-module on ``dwav`` via ``inference``."""
    denoiser = load_enhancer(run_dir, device).denoiser
    return inference(model=denoiser, dwav=dwav, sr=sr, device=device)
|
31 |
+
|
32 |
+
|
33 |
+
@torch.inference_mode()
def enhance(dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None):
    """Enhance ``dwav`` with the full enhancer model.

    Args:
        dwav: input waveform, forwarded to ``inference``.
        sr: sample rate, forwarded to ``inference``.
        device: device the model is loaded on and run with.
        nfe: number of function evaluations, in (0, 128].
        solver: one of "midpoint", "rk4", "euler".
        lambd: denoiser strength, in [0, 1].
        tau: prior temperature, in [0, 1].
        run_dir: run directory forwarded to ``load_enhancer``.

    Note:
        The (cached) enhancer instance is reconfigured in place on every call.
    """
    allowed_solvers = ("midpoint", "rk4", "euler")
    assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
    assert solver in allowed_solvers, f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
    assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
    assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"

    model = load_enhancer(run_dir, device)
    model.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
    return inference(model=model, dwav=dwav, sr=sr, device=device)
|
resemble-enhance/resemble_enhance/enhancer/lcfm/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .irmae import IRMAE
|
2 |
+
from .lcfm import CFM, LCFM
|
resemble-enhance/resemble_enhance/enhancer/lcfm/cfm.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from functools import partial
|
4 |
+
from typing import Protocol
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import scipy
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import Tensor, nn
|
12 |
+
from tqdm import trange
|
13 |
+
|
14 |
+
from .wn import WN
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class VelocityField(Protocol):
    """Structural type of the callable integrated by ``Solver``.

    The solver always invokes it with keyword arguments ``t`` (current time),
    ``ψt`` (current state) and ``dt`` (step size) and uses the returned value
    as the velocity at that point.
    """

    def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
        ...
|
22 |
+
|
23 |
+
|
24 |
+
class Solver:
    """Fixed-step ODE integrator ("euler", "midpoint" or "rk4") on a warped time grid.

    ``nfe`` is the budget of velocity-field evaluations; the number of
    integration steps is derived from it (see ``n_steps``). Time points are
    taken from ``np.linspace(t0, t1, n_steps + 1)`` passed through
    ``exponential_decay_mapping``. When a train loop is running, intermediate
    states can be recorded into a GIF (best effort, requires ``celluloid``).
    """

    # Velocity-field evaluations spent by one integration step of each method.
    _EVALS_PER_STEP = {"euler": 1, "midpoint": 2, "rk4": 4}

    def __init__(
        self,
        method="midpoint",
        nfe=32,
        viz_name="solver",
        viz_every=100,
        mel_fn=None,
        time_mapping_divisor=4,
        verbose=False,
    ):
        self.configurate_(nfe=nfe, method=method)

        self.verbose = verbose
        self.viz_every = viz_every
        self.viz_name = viz_name

        self._camera = None
        self._mel_fn = mel_fn
        self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)

    def configurate_(self, nfe=None, method=None):
        """Set the NFE budget and the method; ``None`` keeps the current value.

        If the budget cannot pay for a single step of the requested multi-stage
        method (fewer than 2 evaluations for midpoint, fewer than 4 for rk4),
        the solver falls back to euler. Without this, ``n_steps`` would be 0
        and ``solve`` would silently return its input unchanged.
        """
        if nfe is None:
            nfe = self.nfe

        if method is None:
            method = self.method

        # Unknown methods are not rejected here; ``n_steps`` / ``_step`` raise later.
        evals_per_step = self._EVALS_PER_STEP.get(method, 1)
        if evals_per_step > 1 and nfe < evals_per_step:
            logger.warning(f"{nfe} NFE is not supported for {method}, using euler method instead.")
            method = "euler"

        self.nfe = nfe
        self.method = method

    @property
    def time_mapping(self):
        """The time-warping function (``exponential_decay_mapping`` with ``n`` bound)."""
        return self._time_mapping

    @staticmethod
    def exponential_decay_mapping(t, n=4):
        """
        Warp ``t`` with h(t) = (a**t - 1) / (a - 1), where ``a`` solves h(1/n) = 0.5.

        Args:
            t: time value(s) to map
            n: target step
        """

        def h(t, a):
            return (a**t - 1) / (a - 1)

        # Solve h(1/n) = 0.5
        root = scipy.optimize.fsolve(lambda a: h(1 / n, a) - 0.5, x0=0)
        # fsolve returns a 1-element array; index it explicitly because
        # converting an ndim > 0 array with float() is deprecated in NumPy.
        a = float(root[0])

        return h(t, a=a)

    @torch.no_grad()
    def _maybe_camera_snap(self, *, ψt, t):
        """Record the current state as an animation frame if a camera is active."""
        camera = self._camera
        if camera is not None:
            if ψt.shape[1] == 1:
                # Waveform, b 1 t, plot every 100 samples
                plt.subplot(211)
                plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue")
                if self._mel_fn is not None:
                    plt.subplot(212)
                    mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0])
                    plt.imshow(mel, origin="lower", interpolation="none")
            elif ψt.shape[1] == 2:
                # Complex
                plt.subplot(121)
                plt.imshow(
                    ψt.detach().cpu().numpy()[0, 0],
                    origin="lower",
                    interpolation="none",
                )
                plt.subplot(122)
                plt.imshow(
                    ψt.detach().cpu().numpy()[0, 1],
                    origin="lower",
                    interpolation="none",
                )
            else:
                # Spectrogram, b c t
                plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
            ax = plt.gca()
            ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
            camera.snap()

    @staticmethod
    def _euler_step(t, ψt, dt, f: VelocityField):
        return ψt + dt * f(t=t, ψt=ψt, dt=dt)

    @staticmethod
    def _midpoint_step(t, ψt, dt, f: VelocityField):
        return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt)

    @staticmethod
    def _rk4_step(t, ψt, dt, f: VelocityField):
        k1 = f(t=t, ψt=ψt, dt=dt)
        k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt)
        k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt)
        k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt)
        return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

    @property
    def _step(self):
        """Step function for the configured method.

        Raises:
            ValueError: if the method is not "euler", "midpoint" or "rk4".
        """
        if self.method == "euler":
            return self._euler_step
        elif self.method == "midpoint":
            return self._midpoint_step
        elif self.method == "rk4":
            return self._rk4_step
        else:
            raise ValueError(f"Unknown method: {self.method}")

    def get_running_train_loop(self):
        """Return the running ``TrainLoop``, or ``None`` if it cannot be imported."""
        try:
            # Lazy import
            from ...utils.train_loop import TrainLoop

            return TrainLoop.get_running_loop()
        except ImportError:
            return None

    @property
    def visualizing(self):
        """Truthy when the current step should be visualized.

        That is the case when a loop is running, its global step is a multiple
        of ``viz_every`` and the output GIF for this step does not exist yet.
        """
        loop = self.get_running_train_loop()
        if loop is None:
            return
        out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
        return loop.global_step % self.viz_every == 0 and not out_path.exists()

    def _reset_camera(self):
        """Start a new animation camera; visualization is best effort."""
        try:
            from celluloid import Camera

            self._camera = Camera(plt.figure())
        except Exception as exc:
            # Missing celluloid or a failing backend must not break sampling,
            # but KeyboardInterrupt / SystemExit are no longer swallowed.
            logger.debug(f"Solver visualization disabled: {exc}")

    def _maybe_dump_camera(self):
        """Write the recorded frames to a GIF and drop the camera, if one is active."""
        camera = self._camera
        loop = self.get_running_train_loop()
        if camera is not None and loop is not None:
            animation = camera.animate()
            out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
            out_path.parent.mkdir(exist_ok=True, parents=True)
            animation.save(out_path, writer="pillow", fps=4)
            plt.close()
            self._camera = None

    @property
    def n_steps(self):
        """Number of integration steps: ``nfe`` divided by the evaluations per step.

        Raises:
            ValueError: if the method is not "euler", "midpoint" or "rk4".
        """
        try:
            evals_per_step = self._EVALS_PER_STEP[self.method]
        except KeyError:
            raise ValueError(f"Unknown method: {self.method}") from None
        return self.nfe // evals_per_step

    def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
        """Integrate ``f`` from ``t0`` to ``t1`` starting at ``ψ0`` and return the final state."""
        ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1))

        if self.visualizing:
            self._reset_camera()

        if self.verbose:
            steps = trange(self.n_steps, desc="CFM inference")
        else:
            steps = range(self.n_steps)

        ψt = ψ0

        for i in steps:
            dt = ts[i + 1] - ts[i]
            t = ts[i]
            self._maybe_camera_snap(ψt=ψt, t=t)
            ψt = self._step(t=t, ψt=ψt, dt=dt, f=f)

        self._maybe_camera_snap(ψt=ψt, t=ts[-1])

        ψ1 = ψt
        del ψt

        self._maybe_dump_camera()

        return ψ1

    def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
        return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1)
|
219 |
+
|
220 |
+
|
221 |
+
class SinusodialTimeEmbedding(nn.Module):
    """Sinusoidal embedding of a time value.

    The frequencies are ``10**p`` for ``d_embed // 2`` exponents ``p`` evenly
    spaced in [0, 4]. The output concatenates the sines and then the cosines
    along a new last dimension of size ``d_embed`` (which must be even).
    """

    def __init__(self, d_embed):
        super().__init__()
        assert d_embed % 2 == 0
        self.d_embed = d_embed

    def forward(self, t):
        t = t.unsqueeze(-1)  # ... 1
        exponents = torch.linspace(0, 4, self.d_embed // 2).to(t)
        # Left-pad with singleton dims so the exponents broadcast against t.
        for _ in range(t.dim() - exponents.dim()):
            exponents = exponents.unsqueeze(0)  # ... d/2
        phase = t * 10**exponents
        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
|
235 |
+
|
236 |
+
|
237 |
+
@dataclass(eq=False)
|
238 |
+
class CFM(nn.Module):
|
239 |
+
"""
|
240 |
+
This mixin is for general diffusion models.
|
241 |
+
|
242 |
+
ψ0 stands for the gaussian noise, and ψ1 is the data point.
|
243 |
+
|
244 |
+
Here we follow the CFM style:
|
245 |
+
The generation process (reverse process) is from t=0 to t=1.
|
246 |
+
The forward process is from t=1 to t=0.
|
247 |
+
"""
|
248 |
+
|
249 |
+
cond_dim: int
|
250 |
+
output_dim: int
|
251 |
+
time_emb_dim: int = 128
|
252 |
+
viz_name: str = "cfm"
|
253 |
+
solver_nfe: int = 32
|
254 |
+
solver_method: str = "midpoint"
|
255 |
+
time_mapping_divisor: int = 4
|
256 |
+
|
257 |
+
def __post_init__(self):
|
258 |
+
super().__init__()
|
259 |
+
self.solver = Solver(
|
260 |
+
viz_name=self.viz_name,
|
261 |
+
viz_every=1,
|
262 |
+
nfe=self.solver_nfe,
|
263 |
+
method=self.solver_method,
|
264 |
+
time_mapping_divisor=self.time_mapping_divisor,
|
265 |
+
)
|
266 |
+
self.emb = SinusodialTimeEmbedding(self.time_emb_dim)
|
267 |
+
self.net = WN(
|
268 |
+
input_dim=self.output_dim,
|
269 |
+
output_dim=self.output_dim,
|
270 |
+
local_dim=self.cond_dim,
|
271 |
+
global_dim=self.time_emb_dim,
|
272 |
+
)
|
273 |
+
|
274 |
+
def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
|
275 |
+
"""
|
276 |
+
Perturb ψ1 to ψt.
|
277 |
+
"""
|
278 |
+
raise NotImplementedError
|
279 |
+
|
280 |
+
def _sample_ψ0(self, x: Tensor):
|
281 |
+
"""
|
282 |
+
Args:
|
283 |
+
x: (b c t), which implies the shape of ψ0
|
284 |
+
"""
|
285 |
+
shape = list(x.shape)
|
286 |
+
shape[1] = self.output_dim
|
287 |
+
if self.training:
|
288 |
+
g = None
|
289 |
+
else:
|
290 |
+
g = torch.Generator(device=x.device)
|
291 |
+
g.manual_seed(0) # deterministic sampling during eval
|
292 |
+
ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g)
|
293 |
+
return ψ0
|
294 |
+
|
295 |
+
@property
|
296 |
+
def sigma(self):
|
297 |
+
return 1e-4
|
298 |
+
|
299 |
+
def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor):
|
300 |
+
"""
|
301 |
+
Eq (22)
|
302 |
+
"""
|
303 |
+
while t.dim() < ψ1.dim():
|
304 |
+
t = t.unsqueeze(-1)
|
305 |
+
μ = t * ψ1 + (1 - t) * ψ0
|
306 |
+
return μ + torch.randn_like(μ) * self.sigma
|
307 |
+
|
308 |
+
def _to_u(self, *, ψ1, ψ0: Tensor):
|
309 |
+
"""
|
310 |
+
Eq (21)
|
311 |
+
"""
|
312 |
+
return ψ1 - ψ0
|
313 |
+
|
314 |
+
def _to_v(self, *, ψt, x, t: float | Tensor):
|
315 |
+
"""
|
316 |
+
Args:
|
317 |
+
ψt: (b c t)
|
318 |
+
x: (b c t)
|
319 |
+
t: (b)
|
320 |
+
Returns:
|
321 |
+
v: (b c t)
|
322 |
+
"""
|
323 |
+
if isinstance(t, (float, int)):
|
324 |
+
t = torch.full(ψt.shape[:1], t).to(ψt)
|
325 |
+
t = t.clamp(0, 1) # [0, 1)
|
326 |
+
g = self.emb(t) # (b d)
|
327 |
+
v = self.net(ψt, l=x, g=g)
|
328 |
+
return v
|
329 |
+
|
330 |
+
def compute_losses(self, x, y, ψ0) -> dict:
|
331 |
+
"""
|
332 |
+
Args:
|
333 |
+
x: (b c t)
|
334 |
+
y: (b c t)
|
335 |
+
Returns:
|
336 |
+
losses: dict
|
337 |
+
"""
|
338 |
+
t = torch.rand(len(x), device=x.device, dtype=x.dtype)
|
339 |
+
t = self.solver.time_mapping(t)
|
340 |
+
|
341 |
+
if ψ0 is None:
|
342 |
+
ψ0 = self._sample_ψ0(x)
|
343 |
+
|
344 |
+
ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0)
|
345 |
+
|
346 |
+
v = self._to_v(ψt=ψt, t=t, x=x)
|
347 |
+
u = self._to_u(ψ1=y, ψ0=ψ0)
|
348 |
+
|
349 |
+
losses = dict(l1=F.l1_loss(v, u))
|
350 |
+
|
351 |
+
return losses
|
352 |
+
|
353 |
+
@torch.inference_mode()
|
354 |
+
def sample(self, x, ψ0=None, t0=0.0):
|
355 |
+
"""
|
356 |
+
Args:
|
357 |
+
x: (b c t)
|
358 |
+
Returns:
|
359 |
+
y: (b ... t)
|
360 |
+
"""
|
361 |
+
if ψ0 is None:
|
362 |
+
ψ0 = self._sample_ψ0(x)
|
363 |
+
f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x)
|
364 |
+
ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
|
365 |
+
return ψ1
|
366 |
+
|
367 |
+
def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
|
368 |
+
if y is None:
|
369 |
+
y = self.sample(x, ψ0=ψ0, t0=t0)
|
370 |
+
else:
|
371 |
+
self.losses = self.compute_losses(x, y, ψ0=ψ0)
|
372 |
+
return y
|
resemble-enhance/resemble_enhance/enhancer/lcfm/irmae.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import Tensor, nn
|
7 |
+
from torch.nn.utils.parametrizations import weight_norm
|
8 |
+
|
9 |
+
from ...common import Normalizer
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
class IRMAEOutput:
    """Result of an ``IRMAE`` forward pass."""

    latent: Tensor  # latent vector
    decoded: Tensor | None  # decoder output, include extra dim; None when decoding was skipped
|
18 |
+
|
19 |
+
|
20 |
+
class ResBlock(nn.Sequential):
|
21 |
+
def __init__(self, channels, dilations=[1, 2, 4, 8]):
|
22 |
+
wn = weight_norm
|
23 |
+
super().__init__(
|
24 |
+
nn.GroupNorm(32, channels),
|
25 |
+
nn.GELU(),
|
26 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
|
27 |
+
nn.GroupNorm(32, channels),
|
28 |
+
nn.GELU(),
|
29 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
|
30 |
+
nn.GroupNorm(32, channels),
|
31 |
+
nn.GELU(),
|
32 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
|
33 |
+
nn.GroupNorm(32, channels),
|
34 |
+
nn.GELU(),
|
35 |
+
wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
|
36 |
+
)
|
37 |
+
|
38 |
+
def forward(self, x: Tensor):
|
39 |
+
return x + super().forward(x)
|
40 |
+
|
41 |
+
|
42 |
+
class IRMAE(nn.Module):
    """Convolutional autoencoder with a chain of linear 1x1 convs before the latent.

    Besides encoder and decoder, a small ``head`` maps the decoder output back
    to the input dimension; it is used for the reconstruction loss stored in
    ``self.losses``.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        latent_dim,
        hidden_dim=1024,
        num_irms=4,
    ):
        """
        Args:
            input_dim: input dimension
            output_dim: output dimension
            latent_dim: latent dimension
            hidden_dim: hidden layer dimension
            num_irms: number of implicit rank minimization matrices
        """
        super().__init__()
        self.input_dim = input_dim

        encoder_layers = [nn.Conv1d(input_dim, hidden_dim, 3, padding="same")]
        encoder_layers.extend(ResBlock(hidden_dim) for _ in range(4))
        # Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
        for i in range(num_irms):
            in_channels = hidden_dim if i == 0 else latent_dim
            encoder_layers.append(nn.Conv1d(in_channels, latent_dim, 1, bias=False))
        encoder_layers.append(nn.Tanh())
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [nn.Conv1d(latent_dim, hidden_dim, 3, padding="same")]
        decoder_layers.extend(ResBlock(hidden_dim) for _ in range(4))
        decoder_layers.append(nn.Conv1d(hidden_dim, output_dim, 1))
        self.decoder = nn.Sequential(*decoder_layers)

        self.head = nn.Sequential(
            nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
            nn.GELU(),
            nn.Conv1d(hidden_dim, input_dim, 1),
        )

        self.estimator = Normalizer()

    def encode(self, x):
        """
        Encode ``x`` and refresh ``self.stats`` with summary statistics of the latent.

        Args:
            x: (b c t) tensor
        """
        z = self.encoder(x)  # (b c t)
        _ = self.estimator(z)  # Estimate the global mean and std of z
        z_abs = z.abs()
        self.stats = {
            "z_mean": z.mean().item(),
            "z_std": z.std().item(),
            "z_abs_68": z_abs.quantile(0.6827).item(),
            "z_abs_95": z_abs.quantile(0.9545).item(),
            "z_abs_99": z_abs.quantile(0.9973).item(),
        }
        return z

    def decode(self, z):
        """
        Args:
            z: (b c t) tensor
        """
        return self.decoder(z)

    def forward(self, x, skip_decoding=False):
        """
        Args:
            x: (b c t) tensor
            skip_decoding: if True, skip the decoding step
        """
        z = self.encode(x)  # q(z|x)

        decoded = None
        if not skip_decoding:
            # Skipping this branch speeds up the training in cfm only mode.
            decoded = self.decode(z)  # p(x|z)
            predicted = self.head(decoded)
            self.losses = dict(mse=F.mse_loss(predicted, x))

        return IRMAEOutput(latent=z, decoded=decoded)
|
resemble-enhance/resemble_enhance/enhancer/lcfm/lcfm.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from enum import Enum
|
3 |
+
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch import Tensor, nn
|
8 |
+
|
9 |
+
from .cfm import CFM
|
10 |
+
from .irmae import IRMAE, IRMAEOutput
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
def freeze_(module):
    """Disable gradients for every parameter of ``module`` (in place)."""
    for param in module.parameters():
        param.requires_grad = False
|
18 |
+
|
19 |
+
|
20 |
+
class LCFM(nn.Module):
    """Latent conditional flow matching: an autoencoder (``ae``) plus a CFM in its latent space.

    The CFM operates on latents multiplied by ``z_scale``; its samples are
    divided by ``z_scale`` again before being decoded.

    In ``Mode.AE`` the CFM is frozen and the autoencoder is trained; in
    ``Mode.CFM`` the autoencoder is frozen and the CFM is trained.
    """

    class Mode(Enum):
        AE = "ae"
        CFM = "cfm"

    def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0):
        super().__init__()
        self.ae = ae
        self.cfm = cfm
        self.z_scale = z_scale
        self._mode = None
        self._eval_tau = 0.5

    @property
    def mode(self):
        return self._mode

    def set_mode_(self, mode):
        """Select the training mode ("ae" or "cfm") and freeze the other sub-module."""
        mode = self.Mode(mode)
        self._mode = mode

        # Compare against the enum class (as forward() does) rather than
        # reaching the members through another member.
        if mode == self.Mode.AE:
            freeze_(self.cfm)
            logger.info("Freeze cfm")
        elif mode == self.Mode.CFM:
            freeze_(self.ae)
            logger.info("Freeze ae (encoder and decoder)")
        else:
            raise ValueError(f"Unknown training mode: {mode}")

    def get_running_train_loop(self):
        """Return the running ``TrainLoop``, or ``None`` if it cannot be imported."""
        try:
            # Lazy import
            from ...utils.train_loop import TrainLoop

            return TrainLoop.get_running_loop()
        except ImportError:
            return None

    @property
    def global_step(self):
        """Global step of the running train loop, or ``None`` without a loop."""
        loop = self.get_running_train_loop()
        if loop is None:
            return None
        return loop.global_step

    @torch.no_grad()
    def _visualize(self, x, y, y_):
        """Save a 2x2 figure: target, posterior decode, CFM-sample decode, random-latent decode."""
        loop = self.get_running_train_loop()
        if loop is None:
            return

        plt.subplot(221)
        plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
        plt.title("GT")

        plt.subplot(222)
        y_ = y_[:, : y.shape[1]]
        plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
        plt.title("Posterior")

        plt.subplot(223)
        # The CFM produces scaled latents; undo the scaling before decoding,
        # exactly as forward() does.
        z_ = self._unscale(self.cfm(x))
        y__ = self.ae.decode(z_)
        y__ = y__[:, : y.shape[1]]
        plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
        plt.title("C-Prior")
        del y__

        plt.subplot(224)
        z_ = torch.randn_like(z_)
        y__ = self.ae.decode(z_)
        y__ = y__[:, : y.shape[1]]
        plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
        plt.title("Prior")
        del z_, y__

        path = loop.make_current_step_viz_path("recon", ".png")
        path.parent.mkdir(exist_ok=True, parents=True)
        plt.tight_layout()
        plt.savefig(path, dpi=500)
        plt.close()

    def _scale(self, z: Tensor):
        return z * self.z_scale

    def _unscale(self, z: Tensor):
        return z / self.z_scale

    def eval_tau_(self, tau):
        """Set the noise mixing weight applied to ψ0 in eval mode."""
        self._eval_tau = tau

    def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None):
        """
        Args:
            x: (b d t), condition mel
            y: (b d t), target mel
            ψ0: (b d t), starting mel
        Returns:
            The decoder output, or None when decoding is skipped (CFM mode with a target).
        """
        if self.mode == self.Mode.CFM:
            self.ae.eval()  # Always set to eval when training cfm

        if ψ0 is not None:
            ψ0 = self._scale(self.ae.encode(ψ0))
            if self.training:
                tau = torch.rand_like(ψ0[:, :1, :1])
            else:
                tau = self._eval_tau
            ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0

        if y is None:
            if self.mode == self.Mode.AE:
                with torch.no_grad():
                    training = self.ae.training
                    self.ae.eval()
                    try:
                        z = self.ae.encode(x)
                    finally:
                        # Restore the previous train/eval flag even if encoding fails.
                        self.ae.train(training)
            else:
                z = self._unscale(self.cfm(x, ψ0=ψ0))

            h = self.ae.decode(z)
        else:
            ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)

            if self.mode == self.Mode.CFM:
                _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)

            h = ae_output.decoded

        # The visualization needs a target mel; without one (pure sampling)
        # indexing ``y`` would raise a TypeError.
        if h is not None and y is not None:
            step = self.global_step
            if step is not None and step % 100 == 0:
                self._visualize(x[:1], y[:1], h[:1])

        return h
|
resemble-enhance/resemble_enhance/enhancer/lcfm/wn.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
|
10 |
+
@torch.jit.script
|
11 |
+
def _fused_tanh_sigmoid(h):
|
12 |
+
a, b = h.chunk(2, dim=1)
|
13 |
+
h = a.tanh() * b.sigmoid()
|
14 |
+
return h
|
15 |
+
|
16 |
+
|
17 |
+
class WNLayer(nn.Module):
    """
    A DiffWave-like WN layer.

    A dilated convolution with optional global (``g``) and local (``l``)
    conditioning, followed by a gated tanh/sigmoid activation. The layer
    returns a residual output and a skip output.
    """

    def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
        super().__init__()

        gated_dim = hidden_dim * 2

        # Conditioning projections exist only when the matching dim is given.
        if global_dim is not None:
            self.gconv = nn.Conv1d(global_dim, hidden_dim, 1)

        if local_dim is not None:
            self.lconv = nn.Conv1d(local_dim, gated_dim, 1)

        self.dconv = nn.Conv1d(hidden_dim, gated_dim, kernel_size, dilation=dilation, padding="same")

        self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1)

    def forward(self, z, l, g):
        """Return ``(o, s)``: the residual output scaled by 1/sqrt(2) and the skip output."""
        residual = z

        if g is not None:
            # A 2-D global condition gets a trailing axis so it broadcasts over time.
            g = g.unsqueeze(-1) if g.dim() == 2 else g
            z = z + self.gconv(g)

        z = self.dconv(z)

        if l is not None:
            z = z + self.lconv(l)

        gated = _fused_tanh_sigmoid(z)

        z, s = self.out(gated).chunk(2, dim=1)

        o = (z + residual) / math.sqrt(2)

        return o, s
|
59 |
+
|
60 |
+
|
61 |
+
class WN(nn.Module):
    """
    Stack of ``WNLayer`` blocks with cyclic dilations.

    The input is projected to ``hidden_dim`` channels, run through
    ``n_layers`` layers whose dilation is ``2 ** (i % dilation_cycle)``, and
    the skip outputs of all layers are summed, scaled by
    ``1 / sqrt(n_layers)`` and projected to ``output_dim`` channels.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        local_dim=None,
        global_dim=None,
        n_layers=30,
        kernel_size=3,
        dilation_cycle=5,
        hidden_dim=512,
    ):
        super().__init__()
        assert kernel_size % 2 == 1
        assert hidden_dim % 2 == 0

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.local_dim = local_dim
        self.global_dim = global_dim

        self.start = nn.Conv1d(input_dim, hidden_dim, 1)
        if local_dim is not None:
            # The local condition is instance-normalised once, before the stack.
            self.local_norm = nn.InstanceNorm1d(local_dim)

        dilations = [2 ** (index % dilation_cycle) for index in range(n_layers)]
        self.layers = nn.ModuleList(
            [
                WNLayer(
                    hidden_dim=hidden_dim,
                    local_dim=local_dim,
                    global_dim=global_dim,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
                for dilation in dilations
            ]
        )

        self.end = nn.Conv1d(hidden_dim, output_dim, 1)

    def forward(self, z, l=None, g=None):
        """
        Args:
            z: input (b c t)
            l: local condition (b c t)
            g: global condition (b d)
        """
        z = self.start(z)

        if l is not None:
            l = self.local_norm(l)

        # Collect the skip output of every layer.
        skips = []
        for layer in self.layers:
            z, skip = layer(z, l, g)
            skips.append(skip)

        # Sum the skips and rescale by sqrt(number of layers).
        skip_sum = torch.stack(skips, dim=0).sum(dim=0)
        skip_sum = skip_sum / math.sqrt(len(self.layers))

        return self.end(skip_sum)

    def summarize(self, length=100):
        """Print the complexity report produced by ptflops for an input of ``length`` frames."""
        from ptflops import get_model_complexity_info

        x = torch.randn(1, self.input_dim, length)

        macs, params = get_model_complexity_info(
            self,
            (self.input_dim, length),
            as_strings=True,
            print_per_layer_stat=True,
            verbose=True,
        )

        for message in (
            f"Input shape: {x.shape}",
            f"Computational complexity: {macs}",
            f"Number of parameters: {params}",
        ):
            print(message)
|
143 |
+
|
144 |
+
|
145 |
+
if __name__ == "__main__":
    # Smoke run: build a 64-in / 64-out WN with default settings and print its summary.
    WN(input_dim=64, output_dim=64).summarize()
|
resemble-enhance/resemble_enhance/enhancer/train.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import random
|
3 |
+
from functools import partial
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import soundfile
|
7 |
+
import torch
|
8 |
+
from deepspeed import DeepSpeedConfig
|
9 |
+
from torch import Tensor
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from ..data import create_dataloaders, mix_fg_bg
|
13 |
+
from ..utils import Engine, TrainLoop, save_mels, setup_logging, tree_map
|
14 |
+
from ..utils.distributed import is_local_leader
|
15 |
+
from .enhancer import Enhancer
|
16 |
+
from .hparams import HParams
|
17 |
+
from .univnet.discriminator import Discriminator
|
18 |
+
|
19 |
+
|
20 |
+
def load_G(run_dir: Path, hp: HParams | None = None, training=True):
    """
    Build the generator (``Enhancer``) engine and load its checkpoint.

    Args:
        run_dir: run directory; checkpoints are read from ``run_dir / "ds" / "G"``
        hp: hyper-parameters; loaded from ``run_dir`` when omitted
        training: when False, optimizer and LR-scheduler states are not loaded
    Returns:
        The ``Engine`` wrapping the enhancer.
    """
    hp = HParams.load(run_dir) if hp is None else hp
    assert isinstance(hp, HParams)

    engine = Engine(
        model=Enhancer(hp),
        config_class=DeepSpeedConfig(hp.deepspeed_config),
        ckpt_dir=run_dir / "ds" / "G",
    )

    # Outside of training only the model weights are needed.
    load_kwargs = {} if training else {"load_optimizer_states": False, "load_lr_scheduler_states": False}
    engine.load_checkpoint(**load_kwargs)
    return engine
|
31 |
+
|
32 |
+
|
33 |
+
def load_D(run_dir: Path, hp: HParams | None = None):
    """
    Build the discriminator engine and load its checkpoint.

    Args:
        run_dir: run directory; checkpoints are read from ``run_dir / "ds" / "D"``
        hp: hyper-parameters; loaded from ``run_dir`` when omitted. The body
            already handled ``None``, so the parameter now has the matching
            ``None`` default, mirroring ``load_G``.
    Returns:
        The ``Engine`` wrapping the discriminator.
    """
    if hp is None:
        hp = HParams.load(run_dir)
    assert isinstance(hp, HParams)
    model = Discriminator(hp)
    engine = Engine(model=model, config_class=DeepSpeedConfig(hp.deepspeed_config), ckpt_dir=run_dir / "ds" / "D")
    engine.load_checkpoint()
    return engine
|
41 |
+
|
42 |
+
|
43 |
+
def save_wav(path: Path, wav: Tensor, rate: int):
    """Detach ``wav``, move it to the CPU and write it to ``path`` at sample rate ``rate``."""
    samples = wav.detach().cpu().numpy()
    soundfile.write(path, samples, samplerate=rate)
|
46 |
+
|
47 |
+
|
48 |
+
def main():
    """
    Command-line entry point for enhancer training.

    Parses ``run_dir``, ``--yaml`` and ``--device``, loads the hyper-parameters,
    creates the dataloaders and hands the generator/discriminator feed
    functions and the evaluation function to ``TrainLoop``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("run_dir", type=Path)
    parser.add_argument("--yaml", type=Path, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    setup_logging(args.run_dir)
    hp = HParams.load(args.run_dir, yaml=args.yaml)

    # Only the local leader persists and prints the hyper-parameters.
    if is_local_leader():
        hp.save_if_not_exists(args.run_dir)
        hp.print()

    train_dl, val_dl = create_dataloaders(hp, mode="enhancer")

    def feed_G(engine: Engine, batch: dict[str, Tensor]):
        """Run one generator forward pass and return (prediction, gathered losses)."""
        mode = hp.lcfm_training_mode
        if mode == "ae":
            # Autoencoder mode: the clean foreground is both input and target.
            pred = engine(batch["fg_wavs"], batch["fg_wavs"])
        elif mode == "cfm":

            def sample_alpha():
                return random.uniform(*hp.mix_alpha_range)

            mixed_dwavs = mix_fg_bg(batch["fg_dwavs"], batch["bg_dwavs"], alpha=sample_alpha)
            pred = engine(mixed_dwavs, batch["fg_wavs"], batch["fg_dwavs"])
        else:
            raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")
        return pred, engine.gather_attribute("losses")

    def feed_D(engine: Engine, batch: dict | None, fake: Tensor):
        """Run the discriminator on ``fake`` alone, or on ``fake`` and the real foreground."""
        if batch is None:
            return engine(fake=fake)
        return engine(fake=fake, real=batch["fg_wavs"])

    @torch.no_grad()
    def eval_fn(engine: Engine, eval_dir, n_saved=10):
        """Save input/prediction/target wavs and a mel plot for up to ``n_saved`` validation batches."""
        assert isinstance(hp, HParams)

        model = engine.module
        model.eval()

        step = engine.global_step

        def to_device(item):
            return item.to(args.device) if isinstance(item, Tensor) else item

        for i, batch in enumerate(tqdm(val_dl), 1):
            batch = tree_map(to_device, batch)

            fg_wavs = batch["fg_wavs"]  # 1 t

            mode = hp.lcfm_training_mode
            if mode == "ae":
                in_dwavs = fg_wavs
            elif mode == "cfm":
                in_dwavs = mix_fg_bg(fg_wavs, batch["bg_dwavs"])
            else:
                raise ValueError(f"Unknown training mode: {hp.lcfm_training_mode}")

            pred_fg_wavs = model(in_dwavs)  # 1 t

            in_mels = model.to_mel(in_dwavs)  # 1 c t
            fg_mels = model.to_mel(fg_wavs)  # 1 c t
            pred_fg_mels = model.to_mel(pred_fg_wavs)  # 1 c t

            rate = model.hp.wav_rate

            # All artefacts of one sample share the same step/index prefix.
            prefix = f"step_{step:08}_{i:03}"

            save_wav(eval_dir / f"{prefix}_input.wav", in_dwavs[0], rate=rate)
            save_wav(eval_dir / f"{prefix}_predict.wav", pred_fg_wavs[0], rate=rate)
            save_wav(eval_dir / f"{prefix}_target.wav", fg_wavs[0], rate=rate)

            save_mels(
                eval_dir / f"{prefix}.png",
                cond_mel=in_mels[0].cpu().numpy(),
                pred_mel=pred_fg_mels[0].cpu().numpy(),
                targ_mel=fg_mels[0].cpu().numpy(),
            )

            if i >= n_saved:
                break

    train_loop = TrainLoop(
        run_dir=args.run_dir,
        train_dl=train_dl,
        load_G=partial(load_G, hp=hp),
        load_D=partial(load_D, hp=hp),
        device=args.device,
        feed_G=feed_G,
        feed_D=feed_D,
        eval_fn=eval_fn,
        gan_training_start_step=hp.gan_training_start_step,
    )

    train_loop.run(max_steps=hp.max_steps)
|
140 |
+
|
141 |
+
|
142 |
+
# Start training only when this module is executed as a script.
if __name__ == "__main__":
    main()
|
resemble-enhance/resemble_enhance/enhancer/univnet/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .univnet import UnivNet
|
resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
from .filter import *
|
5 |
+
from .resample import *
|
resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import math
|
8 |
+
|
9 |
+
if 'sinc' in dir(torch):
    sinc = torch.sinc
else:
    # This code is adapted from adefossez's julius.core.sinc under the MIT License
    # https://adefossez.github.io/julius/julius/core.html
    # LICENSE is in incl_licenses directory.
    def sinc(x: torch.Tensor):
        """
        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
        """
        return torch.where(x == 0,
                           torch.tensor(1., device=x.device, dtype=x.dtype),
                           torch.sin(math.pi * x) / math.pi / x)


# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """
    Build a Kaiser-windowed sinc low-pass filter.

    Args:
        cutoff: cutoff frequency; ``0`` yields an all-zero filter
        half_width: transition half-width, used to choose the Kaiser ``beta``
        kernel_size: number of taps (even or odd)
    Returns:
        A floating-point tensor of shape ``[1, 1, kernel_size]``. For a
        non-zero cutoff the taps are normalised to sum to 1.
    """
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2

    # For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = (torch.arange(-half_size, half_size) + 0.5)
    else:
        time = torch.arange(kernel_size) - half_size

    if cutoff == 0:
        # Take dtype/shape from the window rather than from `time`: for odd
        # kernel sizes `time` is an integer tensor, which would give an
        # integer-typed filter.
        filter_ = torch.zeros_like(window)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ = filter_ / filter_.sum()

    # Reshape outside the branch so that every path returns a [1, 1, k] tensor.
    return filter_.view(1, 1, kernel_size)
|
58 |
+
|
59 |
+
|
60 |
+
class LowPassFilter1d(nn.Module):
    """
    Depthwise 1-D low-pass filter built from ``kaiser_sinc_filter1d``.

    The same kernel is applied to every channel independently; the input is
    optionally padded by ``kernel_size - 1`` samples in total before filtering.
    """

    def __init__(self,
                 cutoff=0.5,
                 half_width=0.6,
                 stride: int = 1,
                 padding: bool = True,
                 padding_mode: str = 'replicate',
                 kernel_size: int = 12):
        # kernel_size should be even number for stylegan3 setup,
        # in this implementation, odd number is also possible.
        super().__init__()
        if cutoff < 0.:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        is_even = kernel_size % 2 == 0
        half = kernel_size // 2

        self.kernel_size = kernel_size
        self.even = is_even
        # An even kernel pads one sample less on the left.
        self.pad_left = half - int(is_even)
        self.pad_right = half
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer("filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size))

    # input [B, C, T]
    def forward(self, x):
        _, n_channels, _ = x.shape

        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)

        # One copy of the kernel per channel, applied as a grouped convolution.
        depthwise_kernel = self.filter.expand(n_channels, -1, -1)
        return F.conv1d(x, depthwise_kernel, stride=self.stride, groups=n_channels)
|
resemble-enhance/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
2 |
+
# LICENSE is in incl_licenses directory.
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from .filter import LowPassFilter1d
|
7 |
+
from .filter import kaiser_sinc_filter1d
|
8 |
+
|
9 |
+
|
10 |
+
class UpSample1d(nn.Module):
    """
    Upsample the last axis by ``ratio`` with a Kaiser-windowed sinc kernel.

    The input is replicate-padded, expanded with a depthwise transposed
    convolution of stride ``ratio`` and then cropped by ``pad_left`` /
    ``pad_right`` samples.
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        if kernel_size is None:
            kernel_size = int(6 * ratio // 2) * 2

        self.ratio = ratio
        self.kernel_size = kernel_size
        self.stride = ratio
        self.pad = kernel_size // ratio - 1

        # Amount cropped from each side after the transposed convolution.
        base_crop = self.pad * self.stride
        overlap = kernel_size - self.stride
        self.pad_left = base_crop + overlap // 2
        self.pad_right = base_crop + (overlap + 1) // 2

        self.register_buffer(
            "filter",
            kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=kernel_size),
        )

    # x: [B, C, T]
    def forward(self, x):
        _, n_channels, _ = x.shape

        padded = F.pad(x, (self.pad, self.pad), mode='replicate')
        depthwise_kernel = self.filter.expand(n_channels, -1, -1)
        # Scale by the ratio after the transposed convolution.
        upsampled = self.ratio * F.conv_transpose1d(
            padded, depthwise_kernel, stride=self.stride, groups=n_channels)

        return upsampled[..., self.pad_left:-self.pad_right]
|
34 |
+
|
35 |
+
|
36 |
+
class DownSample1d(nn.Module):
    """Downsample the last axis by ``ratio`` using a strided ``LowPassFilter1d``."""

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        if kernel_size is None:
            kernel_size = int(6 * ratio // 2) * 2

        self.ratio = ratio
        self.kernel_size = kernel_size
        # Filtering and decimation happen in one strided convolution.
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=kernel_size,
        )

    def forward(self, x):
        return self.lowpass(x)
|
resemble-enhance/resemble_enhance/enhancer/univnet/amp.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Reference: https://github.com/NVIDIA/BigVGAN
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn.utils.parametrizations import weight_norm
|
9 |
+
|
10 |
+
from .alias_free_torch import DownSample1d, UpSample1d
|
11 |
+
|
12 |
+
|
13 |
+
class SnakeBeta(nn.Module):
    """
    A modified Snake activation with separate per-channel parameters for the
    frequency (alpha) and the magnitude (beta) of the periodic component.

    Both parameters are stored in log space (``log_alpha``, ``log_beta``) and
    are exponentiated and clamped to ``clamp`` on every forward pass.

    Shape:
        - Input: (B, C, T) with C == in_features
        - Output: (B, C, T), same shape as the input
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> act = SnakeBeta(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = act(x)
    """

    def __init__(self, in_features, alpha=1.0, clamp=(1e-2, 50)):
        """
        Args:
            in_features: number of channels; one alpha/beta pair per channel
            alpha: initial value of alpha; beta is initialised to the same
                value. Higher alpha = higher frequency.
            clamp: (min, max) range applied to alpha and beta in ``forward``
        """
        super().__init__()
        self.in_features = in_features

        # Both trainable parameters start from log(alpha).
        log_init = math.log(alpha)
        self.log_alpha = nn.Parameter(torch.zeros(in_features) + log_init)
        self.log_beta = nn.Parameter(torch.zeros(in_features) + log_init)
        self.clamp = clamp

    def forward(self, x):
        """
        Apply the activation elementwise:
        SnakeBeta := x + 1/b * sin^2(x * a)
        """
        # Reshape the per-channel vectors to (1, C, 1) so they broadcast over (B, C, T).
        alpha = self.log_alpha.exp().clamp(*self.clamp)[None, :, None]
        beta = self.log_beta.exp().clamp(*self.clamp)[None, :, None]

        periodic = (x * alpha).sin().pow(2)
        return x + (1.0 / beta) * periodic
|
63 |
+
|
64 |
+
|
65 |
+
class UpActDown(nn.Module):
    """Wrap an activation so that it is applied between an upsampling and a downsampling step."""

    def __init__(
        self,
        act,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = act
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x):
        # x: [B,C,T]
        # upsample -> activation -> downsample, in that order
        for stage in (self.upsample, self.act, self.downsample):
            x = stage(x)
        return x
|
87 |
+
|
88 |
+
|
89 |
+
class AMPBlock(nn.Sequential):
    """
    Residual block made of one sub-layer per dilation.

    Every sub-layer is ``dilated conv -> UpActDown(SnakeBeta) -> conv``; the
    sub-layers are chained and the block input is added to the final result.
    """

    def __init__(self, channels, *, kernel_size=3, dilations=(1, 3, 5)):
        sub_layers = [self._make_layer(channels, kernel_size, dilation) for dilation in dilations]
        super().__init__(*sub_layers)

    def _make_layer(self, channels, kernel_size, dilation):
        """Create the conv / activation / conv sub-layer for one dilation."""
        dilated_conv = weight_norm(nn.Conv1d(channels, channels, kernel_size, dilation=dilation, padding="same"))
        activation = UpActDown(act=SnakeBeta(channels))
        plain_conv = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding="same"))
        return nn.Sequential(dilated_conv, activation, plain_conv)

    def forward(self, x):
        # Residual connection around the whole chain of sub-layers.
        transformed = super().forward(x)
        return x + transformed
|
resemble-enhance/resemble_enhance/enhancer/univnet/discriminator.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import Tensor, nn
|
6 |
+
from torch.nn.utils.parametrizations import weight_norm
|
7 |
+
|
8 |
+
from ..hparams import HParams
|
9 |
+
from .mrstft import get_stft_cfgs
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
class PeriodNetwork(nn.Module):
|
15 |
+
def __init__(self, period):
|
16 |
+
super().__init__()
|
17 |
+
self.period = period
|
18 |
+
wn = weight_norm
|
19 |
+
self.convs = nn.ModuleList(
|
20 |
+
[
|
21 |
+
wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))),
|
22 |
+
wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))),
|
23 |
+
wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
|
24 |
+
wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
|
25 |
+
wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))),
|
26 |
+
]
|
27 |
+
)
|
28 |
+
self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
"""
|
32 |
+
Args:
|
33 |
+
x: [B, 1, T]
|
34 |
+
"""
|
35 |
+
assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}."
|
36 |
+
|
37 |
+
# 1d to 2d
|
38 |
+
b, c, t = x.shape
|
39 |
+
if t % self.period != 0: # pad first
|
40 |
+
n_pad = self.period - (t % self.period)
|
41 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
42 |
+
t = t + n_pad
|
43 |
+
x = x.view(b, c, t // self.period, self.period)
|
44 |
+
|
45 |
+
for l in self.convs:
|
46 |
+
x = l(x)
|
47 |
+
x = F.leaky_relu(x, 0.2)
|
48 |
+
x = self.conv_post(x)
|
49 |
+
x = torch.flatten(x, 1, -1)
|
50 |
+
|
51 |
+
return x
|
52 |
+
|
53 |
+
|
54 |
+
class SpecNetwork(nn.Module):
|
55 |
+
def __init__(self, stft_cfg: dict):
|
56 |
+
super().__init__()
|
57 |
+
wn = weight_norm
|
58 |
+
self.stft_cfg = stft_cfg
|
59 |
+
self.convs = nn.ModuleList(
|
60 |
+
[
|
61 |
+
wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
|
62 |
+
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
63 |
+
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
64 |
+
wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
|
65 |
+
wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
|
66 |
+
]
|
67 |
+
)
|
68 |
+
self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
"""
|
72 |
+
Args:
|
73 |
+
x: [B, 1, T]
|
74 |
+
"""
|
75 |
+
x = self.spectrogram(x)
|
76 |
+
x = x.unsqueeze(1)
|
77 |
+
for l in self.convs:
|
78 |
+
x = l(x)
|
79 |
+
x = F.leaky_relu(x, 0.2)
|
80 |
+
x = self.conv_post(x)
|
81 |
+
x = x.flatten(1, -1)
|
82 |
+
return x
|
83 |
+
|
84 |
+
def spectrogram(self, x):
|
85 |
+
"""
|
86 |
+
Args:
|
87 |
+
x: [B, 1, T]
|
88 |
+
"""
|
89 |
+
x = x.squeeze(1)
|
90 |
+
dtype = x.dtype
|
91 |
+
stft_cfg = dict(self.stft_cfg)
|
92 |
+
x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg)
|
93 |
+
mag = x.norm(p=2, dim=-1) # [B, F, TT]
|
94 |
+
mag = mag.to(dtype) # [B, F, TT]
|
95 |
+
return mag
|
96 |
+
|
97 |
+
|
98 |
+
class MD(nn.ModuleList):
    """
    Base class for a list of discriminator sub-networks.

    Subclasses implement ``_create_network``; ``forward`` returns the mean of
    the per-network losses for the configured loss type ("hinge" or "wgan").
    """

    def __init__(self, l: list):
        networks = [self._create_network(config) for config in l]
        super().__init__(networks)
        self._loss_type = None

    def loss_type_(self, loss_type):
        """Set the loss type used by subsequent ``forward`` calls."""
        self._loss_type = loss_type

    def _create_network(self, _):
        raise NotImplementedError

    def _forward_each(self, d, x, y):
        """Loss of one sub-network ``d`` on ``x`` for label ``y`` (0 or 1)."""
        assert self._loss_type is not None, "loss_type is not set."
        loss_type = self._loss_type

        if loss_type not in ("hinge", "wgan"):
            raise ValueError(f"Invalid loss_type: {loss_type}")

        # The label is validated before the network is evaluated.
        if y == 0:
            wants_low_score = True
        elif y == 1:
            wants_low_score = False
        else:
            raise ValueError(f"Invalid y: {y}")

        score = d(x)

        if loss_type == "hinge":
            if wants_low_score:
                # d(x) should be small -> -1
                return F.relu(1 + score).mean()
            # d(x) should be large -> 1
            return F.relu(1 - score).mean()

        # wgan
        if wants_low_score:
            return score.mean()
        return -score.mean()

    def forward(self, x, y) -> Tensor:
        per_network = [self._forward_each(network, x, y) for network in self]
        return torch.stack(per_network).mean()
|
137 |
+
|
138 |
+
|
139 |
+
class MPD(MD):
    """Multi-period discriminator: one ``PeriodNetwork`` for each of the periods 2, 3, 7, 13 and 17."""

    def __init__(self):
        periods = [2, 3, 7, 13, 17]
        super().__init__(periods)

    def _create_network(self, period):
        network = PeriodNetwork(period)
        return network
|
145 |
+
|
146 |
+
|
147 |
+
class MRD(MD):
    """Multi-resolution discriminator: one ``SpecNetwork`` per STFT configuration."""

    def __init__(self, stft_cfgs):
        # One sub-network is created for every entry of `stft_cfgs`.
        super().__init__(l=stft_cfgs)

    def _create_network(self, stft_cfg):
        network = SpecNetwork(stft_cfg)
        return network
|
153 |
+
|
154 |
+
|
155 |
+
class Discriminator(nn.Module):
    """
    Combined multi-period (``MPD``) and multi-resolution (``MRD``) discriminator.

    ``forward`` returns a dict of losses. With only ``fake`` given it computes
    "wgan" losses that treat the fake audio as label 1; with ``real`` given it
    computes "hinge" losses for fake (label 0) and real (label 1).
    """

    @property
    def wav_rate(self):
        return self.hp.wav_rate

    def __init__(self, hp: HParams):
        super().__init__()
        self.hp = hp
        self.stft_cfgs = get_stft_cfgs(hp)
        self.mpd = MPD()
        self.mrd = MRD(self.stft_cfgs)
        # Non-persistent empty buffer; inputs are cast with `.to(self.dummy_float)`
        # so they follow the module's dtype and device.
        self.dummy_float: Tensor
        self.register_buffer("dummy_float", torch.zeros(0), persistent=False)

    def loss_type_(self, loss_type):
        """Set the loss type on both sub-discriminators."""
        self.mpd.loss_type_(loss_type)
        self.mrd.loss_type_(loss_type)

    def forward(self, fake, real=None):
        """
        Args:
            fake: [B T]
            real: [B T]
        Returns:
            dict with keys ``mpd``/``mrd`` when ``real`` is None, otherwise
            ``mpd_fake``/``mpd_real``/``mrd_fake``/``mrd_real``.
        """
        fake = fake.to(self.dummy_float)

        if real is None:
            self.loss_type_("wgan")
        else:
            # Compare lengths in both directions: a signed difference would let
            # a `fake` that is much shorter than `real` pass unnoticed.
            length_difference = abs(fake.shape[-1] - real.shape[-1]) / real.shape[-1]
            assert length_difference < 0.05, "length_difference should be smaller than 5%"

            self.loss_type_("hinge")
            real = real.to(self.dummy_float)

            # Crop both signals to the shorter of the two lengths.
            fake = fake[..., : real.shape[-1]]
            real = real[..., : fake.shape[-1]]

        losses = {}

        assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}."
        assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}."

        fake = fake.unsqueeze(1)

        if real is None:
            losses["mpd"] = self.mpd(fake, 1)
            losses["mrd"] = self.mrd(fake, 1)
        else:
            real = real.unsqueeze(1)
            losses["mpd_fake"] = self.mpd(fake, 0)
            losses["mpd_real"] = self.mpd(real, 1)
            losses["mrd_fake"] = self.mrd(fake, 0)
            losses["mrd_real"] = self.mrd(real, 1)

        return losses
|
resemble-enhance/resemble_enhance/enhancer/univnet/lvcnet.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Reference: https://github.com/zceng/LVCNet"""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn.utils.parametrizations import weight_norm
|
8 |
+
|
9 |
+
from .amp import AMPBlock
|
10 |
+
|
11 |
+
|
12 |
+
class KernelPredictor(torch.nn.Module):
|
13 |
+
"""Kernel predictor for the location-variable convolutions"""
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
cond_channels,
|
18 |
+
conv_in_channels,
|
19 |
+
conv_out_channels,
|
20 |
+
conv_layers,
|
21 |
+
conv_kernel_size=3,
|
22 |
+
kpnet_hidden_channels=64,
|
23 |
+
kpnet_conv_size=3,
|
24 |
+
kpnet_dropout=0.0,
|
25 |
+
kpnet_nonlinear_activation="LeakyReLU",
|
26 |
+
kpnet_nonlinear_activation_params={"negative_slope": 0.1},
|
27 |
+
):
|
28 |
+
"""
|
29 |
+
Args:
|
30 |
+
cond_channels (int): number of channel for the conditioning sequence,
|
31 |
+
conv_in_channels (int): number of channel for the input sequence,
|
32 |
+
conv_out_channels (int): number of channel for the output sequence,
|
33 |
+
conv_layers (int): number of layers
|
34 |
+
"""
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.conv_in_channels = conv_in_channels
|
38 |
+
self.conv_out_channels = conv_out_channels
|
39 |
+
self.conv_kernel_size = conv_kernel_size
|
40 |
+
self.conv_layers = conv_layers
|
41 |
+
|
42 |
+
kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w
|
43 |
+
kpnet_bias_channels = conv_out_channels * conv_layers # l_b
|
44 |
+
|
45 |
+
self.input_conv = nn.Sequential(
|
46 |
+
weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
|
47 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
48 |
+
)
|
49 |
+
|
50 |
+
self.residual_convs = nn.ModuleList()
|
51 |
+
padding = (kpnet_conv_size - 1) // 2
|
52 |
+
for _ in range(3):
|
53 |
+
self.residual_convs.append(
|
54 |
+
nn.Sequential(
|
55 |
+
nn.Dropout(kpnet_dropout),
|
56 |
+
weight_norm(
|
57 |
+
nn.Conv1d(
|
58 |
+
kpnet_hidden_channels,
|
59 |
+
kpnet_hidden_channels,
|
60 |
+
kpnet_conv_size,
|
61 |
+
padding=padding,
|
62 |
+
bias=True,
|
63 |
+
)
|
64 |
+
),
|
65 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
66 |
+
weight_norm(
|
67 |
+
nn.Conv1d(
|
68 |
+
kpnet_hidden_channels,
|
69 |
+
kpnet_hidden_channels,
|
70 |
+
kpnet_conv_size,
|
71 |
+
padding=padding,
|
72 |
+
bias=True,
|
73 |
+
)
|
74 |
+
),
|
75 |
+
getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
|
76 |
+
)
|
77 |
+
)
|
78 |
+
self.kernel_conv = weight_norm(
|
79 |
+
nn.Conv1d(
|
80 |
+
kpnet_hidden_channels,
|
81 |
+
kpnet_kernel_channels,
|
82 |
+
kpnet_conv_size,
|
83 |
+
padding=padding,
|
84 |
+
bias=True,
|
85 |
+
)
|
86 |
+
)
|
87 |
+
self.bias_conv = weight_norm(
|
88 |
+
nn.Conv1d(
|
89 |
+
kpnet_hidden_channels,
|
90 |
+
kpnet_bias_channels,
|
91 |
+
kpnet_conv_size,
|
92 |
+
padding=padding,
|
93 |
+
bias=True,
|
94 |
+
)
|
95 |
+
)
|
96 |
+
|
97 |
+
def forward(self, c):
    """Predict the per-layer convolution kernels and biases from the conditioning.

    Args:
        c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

    Returns:
        tuple: ``(kernels, bias)`` where ``kernels`` is viewed as
        (batch, conv_layers, conv_in_channels, conv_out_channels, conv_kernel_size, cond_length)
        and ``bias`` as (batch, conv_layers, conv_out_channels, cond_length).
    """
    n_batch, _, n_frames = c.shape

    hidden = self.input_conv(c)
    for block in self.residual_convs:
        # The residual branch is moved to the activations' device before it is applied.
        block.to(hidden.device)
        hidden = hidden + block(hidden)

    raw_kernels = self.kernel_conv(hidden)
    raw_bias = self.bias_conv(hidden)

    kernel_shape = (
        n_batch,
        self.conv_layers,
        self.conv_in_channels,
        self.conv_out_channels,
        self.conv_kernel_size,
        n_frames,
    )
    bias_shape = (n_batch, self.conv_layers, self.conv_out_channels, n_frames)

    kernels = raw_kernels.contiguous().view(*kernel_shape)
    bias = raw_bias.contiguous().view(*bias_shape)
    return kernels, bias
|
125 |
+
|
126 |
+
|
127 |
+
class LVCBlock(torch.nn.Module):
|
128 |
+
"""the location-variable convolutions"""
|
129 |
+
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
in_channels,
|
133 |
+
cond_channels,
|
134 |
+
stride,
|
135 |
+
dilations=[1, 3, 9, 27],
|
136 |
+
lReLU_slope=0.2,
|
137 |
+
conv_kernel_size=3,
|
138 |
+
cond_hop_length=256,
|
139 |
+
kpnet_hidden_channels=64,
|
140 |
+
kpnet_conv_size=3,
|
141 |
+
kpnet_dropout=0.0,
|
142 |
+
add_extra_noise=False,
|
143 |
+
downsampling=False,
|
144 |
+
):
|
145 |
+
super().__init__()
|
146 |
+
|
147 |
+
self.add_extra_noise = add_extra_noise
|
148 |
+
|
149 |
+
self.cond_hop_length = cond_hop_length
|
150 |
+
self.conv_layers = len(dilations)
|
151 |
+
self.conv_kernel_size = conv_kernel_size
|
152 |
+
|
153 |
+
self.kernel_predictor = KernelPredictor(
|
154 |
+
cond_channels=cond_channels,
|
155 |
+
conv_in_channels=in_channels,
|
156 |
+
conv_out_channels=2 * in_channels,
|
157 |
+
conv_layers=len(dilations),
|
158 |
+
conv_kernel_size=conv_kernel_size,
|
159 |
+
kpnet_hidden_channels=kpnet_hidden_channels,
|
160 |
+
kpnet_conv_size=kpnet_conv_size,
|
161 |
+
kpnet_dropout=kpnet_dropout,
|
162 |
+
kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
|
163 |
+
)
|
164 |
+
|
165 |
+
if downsampling:
|
166 |
+
self.convt_pre = nn.Sequential(
|
167 |
+
nn.LeakyReLU(lReLU_slope),
|
168 |
+
weight_norm(nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")),
|
169 |
+
nn.AvgPool1d(stride, stride),
|
170 |
+
)
|
171 |
+
else:
|
172 |
+
if stride == 1:
|
173 |
+
self.convt_pre = nn.Sequential(
|
174 |
+
nn.LeakyReLU(lReLU_slope),
|
175 |
+
weight_norm(nn.Conv1d(in_channels, in_channels, 1)),
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
self.convt_pre = nn.Sequential(
|
179 |
+
nn.LeakyReLU(lReLU_slope),
|
180 |
+
weight_norm(
|
181 |
+
nn.ConvTranspose1d(
|
182 |
+
in_channels,
|
183 |
+
in_channels,
|
184 |
+
2 * stride,
|
185 |
+
stride=stride,
|
186 |
+
padding=stride // 2 + stride % 2,
|
187 |
+
output_padding=stride % 2,
|
188 |
+
)
|
189 |
+
),
|
190 |
+
)
|
191 |
+
|
192 |
+
self.amp_block = AMPBlock(in_channels)
|
193 |
+
|
194 |
+
self.conv_blocks = nn.ModuleList()
|
195 |
+
for d in dilations:
|
196 |
+
self.conv_blocks.append(
|
197 |
+
nn.Sequential(
|
198 |
+
nn.LeakyReLU(lReLU_slope),
|
199 |
+
weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, dilation=d, padding="same")),
|
200 |
+
nn.LeakyReLU(lReLU_slope),
|
201 |
+
)
|
202 |
+
)
|
203 |
+
|
204 |
+
def forward(self, x, c):
|
205 |
+
"""forward propagation of the location-variable convolutions.
|
206 |
+
Args:
|
207 |
+
x (Tensor): the input sequence (batch, in_channels, in_length)
|
208 |
+
c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
Tensor: the output sequence (batch, in_channels, in_length)
|
212 |
+
"""
|
213 |
+
_, in_channels, _ = x.shape # (B, c_g, L')
|
214 |
+
|
215 |
+
x = self.convt_pre(x) # (B, c_g, stride * L')
|
216 |
+
|
217 |
+
# Add one amp block just after the upsampling
|
218 |
+
x = self.amp_block(x) # (B, c_g, stride * L')
|
219 |
+
|
220 |
+
kernels, bias = self.kernel_predictor(c)
|
221 |
+
|
222 |
+
if self.add_extra_noise:
|
223 |
+
# Add extra noise to part of the feature
|
224 |
+
a, b = x.chunk(2, dim=1)
|
225 |
+
b = b + torch.randn_like(b) * 0.1
|
226 |
+
x = torch.cat([a, b], dim=1)
|
227 |
+
|
228 |
+
for i, conv in enumerate(self.conv_blocks):
|
229 |
+
output = conv(x) # (B, c_g, stride * L')
|
230 |
+
|
231 |
+
k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length)
|
232 |
+
b = bias[:, i, :, :] # (B, 2 * c_g, cond_length)
|
233 |
+
|
234 |
+
output = self.location_variable_convolution(
|
235 |
+
output, k, b, hop_size=self.cond_hop_length
|
236 |
+
) # (B, 2 * c_g, stride * L'): LVC
|
237 |
+
x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
|
238 |
+
output[:, in_channels:, :]
|
239 |
+
) # (B, c_g, stride * L'): GAU
|
240 |
+
|
241 |
+
return x
|
242 |
+
|
243 |
+
def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
|
244 |
+
"""perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
|
245 |
+
Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
|
246 |
+
Args:
|
247 |
+
x (Tensor): the input sequence (batch, in_channels, in_length).
|
248 |
+
kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
|
249 |
+
bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
|
250 |
+
dilation (int): the dilation of convolution.
|
251 |
+
hop_size (int): the hop_size of the conditioning sequence.
|
252 |
+
Returns:
|
253 |
+
(Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
|
254 |
+
"""
|
255 |
+
batch, _, in_length = x.shape
|
256 |
+
batch, _, out_channels, kernel_size, kernel_length = kernel.shape
|
257 |
+
|
258 |
+
assert in_length == (
|
259 |
+
kernel_length * hop_size
|
260 |
+
), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
|
261 |
+
|
262 |
+
padding = dilation * int((kernel_size - 1) / 2)
|
263 |
+
x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding)
|
264 |
+
x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding)
|
265 |
+
|
266 |
+
if hop_size < dilation:
|
267 |
+
x = F.pad(x, (0, dilation), "constant", 0)
|
268 |
+
x = x.unfold(
|
269 |
+
3, dilation, dilation
|
270 |
+
) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
|
271 |
+
x = x[:, :, :, :, :hop_size]
|
272 |
+
x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
|
273 |
+
x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size)
|
274 |
+
|
275 |
+
o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
|
276 |
+
o = o.to(memory_format=torch.channels_last_3d)
|
277 |
+
bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
|
278 |
+
o = o + bias
|
279 |
+
o = o.contiguous().view(batch, out_channels, -1)
|
280 |
+
|
281 |
+
return o
|
resemble-enhance/resemble_enhance/enhancer/univnet/mrstft.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Copyright 2019 Tomoki Hayashi
|
4 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
5 |
+
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from ..hparams import HParams
|
12 |
+
|
13 |
+
|
14 |
+
def _make_stft_cfg(hop_length, win_length=None):
|
15 |
+
if win_length is None:
|
16 |
+
win_length = 4 * hop_length
|
17 |
+
n_fft = 2 ** (win_length - 1).bit_length()
|
18 |
+
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
19 |
+
|
20 |
+
|
21 |
+
def get_stft_cfgs(hp: HParams):
    """Return the STFT configurations (hop sizes 100, 256 and 512) used by the loss.

    The hop sizes are fixed, so the function asserts that ``hp.wav_rate`` is 44100.
    """
    assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
    hop_lengths = (100, 256, 512)
    return list(map(_make_stft_cfg, hop_lengths))
|
24 |
+
|
25 |
+
|
26 |
+
def stft(x, n_fft, hop_length, win_length, window):
    """Compute the magnitude spectrogram of ``x``.

    The transform itself runs in float32; the magnitude is cast back to the dtype of ``x``
    and returned with the last two axes of ``torch.stft``'s output swapped.
    """
    input_dtype = x.dtype
    spectrum = torch.stft(x.float(), n_fft, hop_length, win_length, window, return_complex=True)
    magnitude = spectrum.abs().to(input_dtype)
    return magnitude.transpose(2, 1)  # (b f t) -> (b t f)
|
32 |
+
|
33 |
+
|
34 |
+
class SpectralConvergengeLoss(nn.Module):
    """Spectral convergence loss between two magnitude spectrograms."""

    def forward(self, x_mag, y_mag):
        """Return ``||y_mag - x_mag||_F / ||y_mag||_F``.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Spectral convergence loss value.
        """
        error_norm = torch.norm(y_mag - x_mag, p="fro")
        target_norm = torch.norm(y_mag, p="fro")
        return error_norm / target_norm
|
44 |
+
|
45 |
+
|
46 |
+
class LogSTFTMagnitudeLoss(nn.Module):
    """L1 distance between ``log1p``-compressed magnitude spectrograms."""

    def forward(self, x_mag, y_mag):
        """Return the mean absolute difference of ``log(1 + magnitude)``.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        log_pred = torch.log1p(x_mag)
        log_true = torch.log1p(y_mag)
        return F.l1_loss(log_pred, log_true)
|
56 |
+
|
57 |
+
|
58 |
+
class STFTLoss(nn.Module):
    """Spectral-convergence and log-magnitude losses at a single STFT resolution."""

    def __init__(self, hp, stft_cfg: dict, window="hann_window"):
        """
        Args:
            hp: hyper-parameters, stored on the module as ``self.hp``.
            stft_cfg (dict): keyword arguments forwarded to ``stft``; ``stft_cfg["win_length"]``
                also sets the size of the analysis window.
            window (str): name of the ``torch`` window factory, e.g. ``"hann_window"``.
        """
        super().__init__()
        self.hp = hp
        self.stft_cfg = stft_cfg
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
        make_window = getattr(torch, window)
        # Non-persistent: the window follows the module's device but is not saved in state_dict.
        self.register_buffer("window", make_window(stft_cfg["win_length"]), persistent=False)

    def forward(self, x, y):
        """Compute both losses between a predicted and a ground-truth waveform.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            dict: ``{"sc": spectral convergence loss, "mag": log STFT magnitude loss}``.
        """
        cfg = dict(self.stft_cfg)
        pred_mag = stft(x, window=self.window, **cfg)  # (b t) -> (b t f)
        true_mag = stft(y, window=self.window, **cfg)
        sc = self.spectral_convergenge_loss(pred_mag, true_mag)
        mag = self.log_stft_magnitude_loss(pred_mag, true_mag)
        return {"sc": sc, "mag": mag}
|
82 |
+
|
83 |
+
|
84 |
+
class MRSTFTLoss(nn.Module):
    """Multi-resolution STFT loss: the mean of ``STFTLoss`` over several resolutions."""

    def __init__(self, hp: HParams, window="hann_window"):
        """Initialize Multi resolution STFT loss module.

        Args:
            hp (HParams): hyper-parameters; the resolutions come from ``get_stft_cfgs(hp)``.
            window (str): Window function type.
        """
        super().__init__()
        resolutions = get_stft_cfgs(hp)
        self.stft_losses = nn.ModuleList(STFTLoss(hp, cfg, window=window) for cfg in resolutions)
        self.hp = hp

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (b t).
            y (Tensor): Groundtruth signal (b t).
        Returns:
            dict: for every loss name produced by the per-resolution losses, the mean of that
            loss over all resolutions, cast back to the dtype of ``x``.
        """
        assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}."

        out_dtype = x.dtype

        # Align length: both signals are cut to the shorter of the two.
        common_length = min(x.shape[-1], y.shape[-1])
        x = x.float()[..., :common_length]
        y = y.float()[..., :common_length]

        per_name = {}
        for stft_loss in self.stft_losses:
            for name, value in stft_loss(x, y).items():
                if name not in per_name:
                    per_name[name] = []
                per_name[name].append(value)

        return {name: torch.stack(values, dim=0).mean().to(out_dtype) for name, values in per_name.items()}
|
resemble-enhance/resemble_enhance/enhancer/univnet/univnet.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch import Tensor, nn
|
5 |
+
from torch.nn.utils.parametrizations import weight_norm
|
6 |
+
|
7 |
+
from ..hparams import HParams
|
8 |
+
from .lvcnet import LVCBlock
|
9 |
+
from .mrstft import MRSTFTLoss
|
10 |
+
|
11 |
+
|
12 |
+
class UnivNet(nn.Module):
|
13 |
+
@property
|
14 |
+
def d_noise(self):
|
15 |
+
return 128
|
16 |
+
|
17 |
+
@property
|
18 |
+
def strides(self):
|
19 |
+
return [7, 5, 4, 3]
|
20 |
+
|
21 |
+
@property
|
22 |
+
def dilations(self):
|
23 |
+
return [1, 3, 9, 27]
|
24 |
+
|
25 |
+
@property
|
26 |
+
def nc(self):
|
27 |
+
return self.hp.univnet_nc
|
28 |
+
|
29 |
+
@property
|
30 |
+
def scale_factor(self) -> int:
|
31 |
+
return self.hp.hop_size
|
32 |
+
|
33 |
+
def __init__(self, hp: HParams, d_input):
|
34 |
+
super().__init__()
|
35 |
+
self.d_input = d_input
|
36 |
+
|
37 |
+
self.hp = hp
|
38 |
+
|
39 |
+
self.blocks = nn.ModuleList(
|
40 |
+
[
|
41 |
+
LVCBlock(
|
42 |
+
self.nc,
|
43 |
+
d_input,
|
44 |
+
stride=stride,
|
45 |
+
dilations=self.dilations,
|
46 |
+
cond_hop_length=hop_length,
|
47 |
+
kpnet_conv_size=3,
|
48 |
+
)
|
49 |
+
for stride, hop_length in zip(self.strides, np.cumprod(self.strides))
|
50 |
+
]
|
51 |
+
)
|
52 |
+
|
53 |
+
self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
|
54 |
+
|
55 |
+
self.conv_post = nn.Sequential(
|
56 |
+
nn.LeakyReLU(0.2),
|
57 |
+
weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")),
|
58 |
+
nn.Tanh(),
|
59 |
+
)
|
60 |
+
|
61 |
+
self.mrstft = MRSTFTLoss(hp)
|
62 |
+
|
63 |
+
@property
|
64 |
+
def eps(self):
|
65 |
+
return 1e-5
|
66 |
+
|
67 |
+
def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
|
68 |
+
"""
|
69 |
+
Args:
|
70 |
+
x: (b c t), acoustic features
|
71 |
+
y: (b t), waveform
|
72 |
+
Returns:
|
73 |
+
z: (b t), waveform
|
74 |
+
"""
|
75 |
+
assert x.ndim == 3, "x must be 3D tensor"
|
76 |
+
assert y is None or y.ndim == 2, "y must be 2D tensor"
|
77 |
+
assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
|
78 |
+
assert npad >= 0, "npad must be positive or zero"
|
79 |
+
|
80 |
+
x = F.pad(x, (0, npad), "constant", 0)
|
81 |
+
z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x)
|
82 |
+
z = self.conv_pre(z) # (b c t)
|
83 |
+
|
84 |
+
for block in self.blocks:
|
85 |
+
z = block(z, x) # (b c t)
|
86 |
+
|
87 |
+
z = self.conv_post(z) # (b 1 t)
|
88 |
+
z = z[..., : -self.scale_factor * npad]
|
89 |
+
z = z.squeeze(1) # (b t)
|
90 |
+
|
91 |
+
if y is not None:
|
92 |
+
self.losses = self.mrstft(z, y)
|
93 |
+
|
94 |
+
return z
|
resemble-enhance/resemble_enhance/hparams.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import asdict, dataclass
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from omegaconf import OmegaConf
|
6 |
+
from rich.console import Console
|
7 |
+
from rich.panel import Panel
|
8 |
+
from rich.table import Table
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
console = Console()
|
13 |
+
|
14 |
+
|
15 |
+
def _make_stft_cfg(hop_length, win_length=None):
|
16 |
+
if win_length is None:
|
17 |
+
win_length = 4 * hop_length
|
18 |
+
n_fft = 2 ** (win_length - 1).bit_length()
|
19 |
+
return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
20 |
+
|
21 |
+
|
22 |
+
def _build_rich_table(rows, columns, title=None):
    """Render ``rows`` as a rich ``Table`` wrapped in a non-expanding ``Panel``.

    Column headers are the capitalized entries of ``columns``; every cell is converted
    with ``str``.
    """
    grid = Table(title=title, header_style=None)
    for heading in columns:
        grid.add_column(heading.capitalize(), justify="left")
    for cells in rows:
        grid.add_row(*(str(cell) for cell in cells))
    return Panel(grid, expand=False)
|
29 |
+
|
30 |
+
|
31 |
+
def _rich_print_dict(d, title="Config", key="Key", value="Value"):
    """Print the items of ``d`` as a two-column table on the module-level console."""
    headers = [key, value]
    panel = _build_rich_table(d.items(), headers, title)
    console.print(panel)
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass(frozen=True)
|
36 |
+
class HParams:
|
37 |
+
# Dataset
|
38 |
+
fg_dir: Path = Path("data/fg")
|
39 |
+
bg_dir: Path = Path("data/bg")
|
40 |
+
rir_dir: Path = Path("data/rir")
|
41 |
+
load_fg_only: bool = False
|
42 |
+
praat_augment_prob: float = 0
|
43 |
+
|
44 |
+
# Audio settings
|
45 |
+
wav_rate: int = 44_100
|
46 |
+
n_fft: int = 2048
|
47 |
+
win_size: int = 2048
|
48 |
+
hop_size: int = 420 # 9.5ms
|
49 |
+
num_mels: int = 128
|
50 |
+
stft_magnitude_min: float = 1e-4
|
51 |
+
preemphasis: float = 0.97
|
52 |
+
mix_alpha_range: tuple[float, float] = (0.2, 0.8)
|
53 |
+
|
54 |
+
# Training
|
55 |
+
nj: int = 64
|
56 |
+
training_seconds: float = 1.0
|
57 |
+
batch_size_per_gpu: int = 16
|
58 |
+
min_lr: float = 1e-5
|
59 |
+
max_lr: float = 1e-4
|
60 |
+
warmup_steps: int = 1000
|
61 |
+
max_steps: int = 1_000_000
|
62 |
+
gradient_clipping: float = 1.0
|
63 |
+
|
64 |
+
@property
|
65 |
+
def deepspeed_config(self):
|
66 |
+
return {
|
67 |
+
"train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
|
68 |
+
"optimizer": {
|
69 |
+
"type": "Adam",
|
70 |
+
"params": {"lr": float(self.min_lr)},
|
71 |
+
},
|
72 |
+
"scheduler": {
|
73 |
+
"type": "WarmupDecayLR",
|
74 |
+
"params": {
|
75 |
+
"warmup_min_lr": float(self.min_lr),
|
76 |
+
"warmup_max_lr": float(self.max_lr),
|
77 |
+
"warmup_num_steps": self.warmup_steps,
|
78 |
+
"total_num_steps": self.max_steps,
|
79 |
+
"warmup_type": "linear",
|
80 |
+
},
|
81 |
+
},
|
82 |
+
"gradient_clipping": self.gradient_clipping,
|
83 |
+
}
|
84 |
+
|
85 |
+
@property
|
86 |
+
def stft_cfgs(self):
|
87 |
+
assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
|
88 |
+
return [_make_stft_cfg(h) for h in (100, 256, 512)]
|
89 |
+
|
90 |
+
@classmethod
|
91 |
+
def from_yaml(cls, path: Path) -> "HParams":
|
92 |
+
logger.info(f"Reading hparams from {path}")
|
93 |
+
# First merge to fix types (e.g., str -> Path)
|
94 |
+
return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
|
95 |
+
|
96 |
+
def save_if_not_exists(self, run_dir: Path):
|
97 |
+
path = run_dir / "hparams.yaml"
|
98 |
+
if path.exists():
|
99 |
+
logger.info(f"{path} already exists, not saving")
|
100 |
+
return
|
101 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
102 |
+
OmegaConf.save(asdict(self), str(path))
|
103 |
+
|
104 |
+
@classmethod
|
105 |
+
def load(cls, run_dir, yaml: Path | None = None):
|
106 |
+
hps = []
|
107 |
+
|
108 |
+
if (run_dir / "hparams.yaml").exists():
|
109 |
+
hps.append(cls.from_yaml(run_dir / "hparams.yaml"))
|
110 |
+
|
111 |
+
if yaml is not None:
|
112 |
+
hps.append(cls.from_yaml(yaml))
|
113 |
+
|
114 |
+
if len(hps) == 0:
|
115 |
+
hps.append(cls())
|
116 |
+
|
117 |
+
for hp in hps[1:]:
|
118 |
+
if hp != hps[0]:
|
119 |
+
errors = {}
|
120 |
+
for k, v in asdict(hp).items():
|
121 |
+
if getattr(hps[0], k) != v:
|
122 |
+
errors[k] = f"{getattr(hps[0], k)} != {v}"
|
123 |
+
raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
|
124 |
+
|
125 |
+
return hps[0]
|
126 |
+
|
127 |
+
def print(self):
|
128 |
+
_rich_print_dict(asdict(self), title="HParams")
|
resemble-enhance/resemble_enhance/inference.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import time
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.nn.utils.parametrize import remove_parametrizations
|
7 |
+
from torchaudio.functional import resample
|
8 |
+
from torchaudio.transforms import MelSpectrogram
|
9 |
+
from tqdm import trange
|
10 |
+
|
11 |
+
from .hparams import HParams
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
@torch.inference_mode()
def inference_chunk(model, dwav, sr, device, npad=441):
    """Run ``model`` on one 1-D waveform chunk.

    The chunk is peak-normalized, right-padded with ``npad`` zeros and fed to the model as a
    batch of one. The output is moved to the CPU, trimmed to the input length and rescaled by
    the original peak.

    Args:
        model: callable with an ``hp.wav_rate`` attribute.
        dwav (Tensor): input waveform, (T,).
        sr (int): sample rate of ``dwav``; must equal ``model.hp.wav_rate``.
        device: device the chunk is moved to before the model is called.
        npad (int): number of zero samples appended before inference.

    Returns:
        Tensor: processed waveform on the CPU, (T,).
    """
    assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"

    n_samples = dwav.shape[-1]
    # Clamp the peak so that an all-zero chunk does not cause a division by zero.
    peak = dwav.abs().max().clamp(min=1e-7)

    assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"

    normalized = dwav.to(device) / peak  # Normalize
    padded = F.pad(normalized, (0, npad))
    output = model(padded.unsqueeze(0))[0].cpu()  # (T,)
    trimmed = output[:n_samples]  # Trim padding
    return trimmed * peak  # Unnormalize
|
33 |
+
|
34 |
+
|
35 |
+
def compute_corr(x, y):
    """Return the magnitude of the circular cross-correlation of ``x`` and ``y``.

    Computed along the last axis via the FFT: ``|ifft(fft(x) * conj(fft(y)))|``.
    """
    x_spectrum = torch.fft.fft(x)
    y_spectrum_conj = torch.fft.fft(y).conj()
    cross_spectrum = x_spectrum * y_spectrum_conj
    return torch.fft.ifft(cross_spectrum).abs()
|
37 |
+
|
38 |
+
|
39 |
+
def compute_offset(chunk1, chunk2, sr=44100):
    """Estimate the time offset between two chunks from their log-mel spectrograms.

    Args:
        chunk1: (T,)
        chunk2: (T,)
        sr: sample rate, used to derive the 5 ms analysis hop.
    Returns:
        offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset)
    """
    hop = sr // 200  # 5 ms resolution
    window = hop * 4
    fft_size = 2 ** (window - 1).bit_length()

    to_mel = MelSpectrogram(
        sample_rate=sr,
        n_fft=fft_size,
        win_length=window,
        hop_length=hop,
        n_mels=80,
        f_min=0.0,
        f_max=sr // 2,
    )

    log_mel1 = to_mel(chunk1).log1p()
    log_mel2 = to_mel(chunk2).log1p()

    # Correlate along time for every mel band, then average over the bands.
    correlation = compute_corr(log_mel1, log_mel2).mean(dim=0)  # (F, T) -> (T,)

    best_lag = correlation.argmax().item()
    n_lags = len(correlation)
    # Lags in the second half of the circular correlation are negative shifts.
    if best_lag > n_lags // 2:
        best_lag = best_lag - n_lags

    return -best_lag * hop
|
75 |
+
|
76 |
+
|
77 |
+
def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
    """Overlap-add a list of processed chunks into one signal using linear cross-fades.

    Consecutive chunks overlap by ``chunk_length - hop_length`` samples. Every chunk except the
    first is shifted by the offset that ``compute_offset`` estimates between the overlapping
    regions, and is faded in over that overlap; every chunk except the last is faded out over
    it. A single chunk is therefore copied unchanged (previously its tail was faded out even
    though nothing overlapped it).

    Args:
        chunks (list[Tensor]): 1-D chunks; a chunk shorter than ``chunk_length`` is zero-padded.
        chunk_length (int): nominal chunk length in samples.
        hop_length (int): distance between the starts of consecutive chunks in samples.
        sr (int): sample rate, forwarded to ``compute_offset``.
        length (int | None): if given, the merged signal is truncated to this many samples.

    Returns:
        Tensor: the merged 1-D signal, on the device of ``chunks[0]``.
    """
    device = chunks[0].device
    last_index = len(chunks) - 1

    signal_length = last_index * hop_length + chunk_length
    overlap_length = chunk_length - hop_length
    signal = torch.zeros(signal_length, device=device)

    fadein = torch.linspace(0, 1, overlap_length, device=device)
    fadein = torch.cat([fadein, torch.ones(hop_length, device=device)])
    fadeout = torch.linspace(1, 0, overlap_length, device=device)
    fadeout = torch.cat([torch.ones(hop_length, device=device), fadeout])

    for i, chunk in enumerate(chunks):
        start = i * hop_length
        end = start + chunk_length

        if len(chunk) < chunk_length:
            chunk = F.pad(chunk, (0, chunk_length - len(chunk)))

        if i > 0:
            # Align this chunk to the tail of the previous one before cross-fading.
            pre_region = chunks[i - 1][-overlap_length:]
            cur_region = chunk[:overlap_length]
            offset = compute_offset(pre_region, cur_region, sr=sr)
            start -= offset
            end -= offset

        # Fade only on the sides that actually overlap a neighbouring chunk.
        if i > 0:
            chunk = chunk * fadein
        if i < last_index:
            chunk = chunk * fadeout

        # The shifted window may run past the end of the signal; clip the chunk to fit.
        signal[start:end] += chunk[: len(signal[start:end])]

    signal = signal[:length]

    return signal
|
113 |
+
|
114 |
+
|
115 |
+
def remove_weight_norm_recursively(module):
|
116 |
+
for _, module in module.named_modules():
|
117 |
+
try:
|
118 |
+
remove_parametrizations(module, "weight")
|
119 |
+
except Exception:
|
120 |
+
pass
|
121 |
+
|
122 |
+
|
123 |
+
def inference(model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0):
    """Run ``model`` over a whole waveform in overlapping chunks.

    Weight normalization is stripped from the model, the input is resampled to
    ``model.hp.wav_rate``, processed chunk by chunk with ``inference_chunk`` and stitched
    back together with ``merge_chunks``. The elapsed time is logged.

    Args:
        model: model exposing ``hp`` (an ``HParams``).
        dwav (Tensor): input waveform.
        sr (int): sample rate of ``dwav``.
        device: device used for inference.
        chunk_seconds (float): chunk duration in seconds.
        overlap_seconds (float): overlap between consecutive chunks in seconds.

    Returns:
        tuple: ``(hwav, sr)`` with the merged output and its sample rate (``hp.wav_rate``).
    """
    remove_weight_norm_recursively(model)

    hp: HParams = model.hp

    dwav = resample(
        dwav,
        orig_freq=sr,
        new_freq=hp.wav_rate,
        lowpass_filter_width=64,
        rolloff=0.9475937167399596,
        resampling_method="sinc_interp_kaiser",
        beta=14.769656459379492,
    )

    # Everything is in hp.wav_rate now
    sr = hp.wav_rate

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.synchronize()

    started_at = time.perf_counter()

    chunk_length = int(sr * chunk_seconds)
    overlap_length = int(sr * overlap_seconds)
    hop_length = chunk_length - overlap_length

    total_length = dwav.shape[-1]
    chunks = [
        inference_chunk(model, dwav[begin : begin + chunk_length], sr, device)
        for begin in trange(0, total_length, hop_length)
    ]

    hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    elapsed_time = time.perf_counter() - started_at
    logger.info(f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz")

    return hwav, sr
|
resemble-enhance/resemble_enhance/melspec.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
|
5 |
+
|
6 |
+
from .hparams import HParams
|
7 |
+
|
8 |
+
|
9 |
+
class MelSpectrogram(nn.Module):
    """Torch implementation of Resemble's mel extraction.

    Note that the values are NOT identical to librosa's implementation
    due to floating point precisions.
    """

    def __init__(self, hp: HParams):
        """
        Args:
            hp (HParams): audio settings (sample rate, FFT/window/hop sizes, number of mels,
                magnitude floor and pre-emphasis coefficient).
        """
        super().__init__()
        self.hp = hp
        self.melspec = TorchMelSpectrogram(
            hp.wav_rate,
            n_fft=hp.n_fft,
            win_length=hp.win_size,
            hop_length=hp.hop_size,
            f_min=0,
            f_max=hp.wav_rate // 2,
            n_mels=hp.num_mels,
            power=1,
            normalized=False,
            # NOTE: Following librosa's default.
            pad_mode="constant",
            norm="slaney",
            mel_scale="slaney",
        )
        magnitude_floor = torch.FloatTensor([hp.stft_magnitude_min])
        self.register_buffer("stft_magnitude_min", magnitude_floor)
        self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
        self.preemphasis = hp.preemphasis
        self.hop_size = hp.hop_size

    def forward(self, wav, pad=True):
        """Compute the normalized log-mel spectrogram of ``wav``.

        Args:
            wav: [B, T]
            pad: when true, assert that the number of frames is ``1 + T // hop_size``.
        """
        original_device = wav.device
        if wav.is_mps:
            # NOTE(review): the source's indentation is lost; moving the module together with the
            # input inside this MPS branch is assumed - confirm against the upstream file.
            wav = wav.cpu()
            self.to(wav.device)

        if self.preemphasis > 0:
            # y[t] = x[t] - preemphasis * x[t - 1], with x[-1] taken as zero.
            shifted = torch.nn.functional.pad(wav, [1, 0], value=0)
            wav = shifted[..., 1:] - self.preemphasis * shifted[..., :-1]

        mel_db = self._amp_to_db(self.melspec(wav))
        mel_normed = self._normalize(mel_db)

        expected_frames = 1 + wav.shape[-1] // self.hop_size
        assert not pad or mel_normed.shape[-1] == expected_frames  # Sanity check

        return mel_normed.to(original_device)

    def _normalize(self, s, headroom_db=15):
        """Map ``min_level_db`` to 0 and ``headroom_db`` to 1."""
        span_db = -self.min_level_db + headroom_db
        return (s - self.min_level_db) / span_db

    def _amp_to_db(self, x):
        """Convert amplitudes to decibels after clamping them to ``hp.stft_magnitude_min``."""
        floored = x.clamp_min(self.hp.stft_magnitude_min)
        return floored.log10() * 20
|
resemble-enhance/resemble_enhance/utils/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .distributed import global_leader_only
|
2 |
+
from .engine import Engine, gather_attribute
|
3 |
+
from .logging import setup_logging
|
4 |
+
from .train_loop import TrainLoop, is_global_leader
|
5 |
+
from .utils import save_mels, tree_map
|
resemble-enhance/resemble_enhance/utils/control.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import selectors
|
3 |
+
import sys
|
4 |
+
from functools import cache
|
5 |
+
|
6 |
+
from .distributed import global_leader_only
|
7 |
+
|
8 |
+
_logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
|
11 |
+
@cache
def _get_stdin_selector():
    """Return a selector watching ``sys.stdin`` for readability.

    The result is cached, so stdin is registered only once per process.
    """
    stdin_selector = selectors.DefaultSelector()
    stdin_selector.register(sys.stdin, selectors.EVENT_READ)
    return stdin_selector
|
16 |
+
|
17 |
+
|
18 |
+
@global_leader_only(boardcast_return=True)
def non_blocking_input():
    """Read one stripped line from stdin if one is ready, without blocking.

    Returns:
        str: the line that was read, or ``""`` when stdin has nothing to read.
    """
    line = ""
    # timeout=0 makes select() poll instead of wait.
    ready = _get_stdin_selector().select(timeout=0)
    for selector_key, _ in ready:
        line = selector_key.fileobj.readline().strip()
        _logger.info(f'Get stdin "{line}".')
    return line
|
resemble-enhance/resemble_enhance/utils/distributed.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import socket
|
3 |
+
from functools import cache, partial, wraps
|
4 |
+
from typing import Callable
|
5 |
+
|
6 |
+
import deepspeed
|
7 |
+
import torch
|
8 |
+
from deepspeed.accelerator import get_accelerator
|
9 |
+
from torch.distributed import broadcast_object_list
|
10 |
+
|
11 |
+
|
12 |
+
def get_free_port():
    """Return a TCP port number that was free at the time of the call.

    The port is found by binding a temporary socket to port 0 and reading back the port that
    was assigned. The socket is closed before returning, so another process could in principle
    claim the port before the caller uses it.
    """
    with socket.socket() as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
|
16 |
+
|
17 |
+
|
18 |
+
@cache
def fix_unset_envs():
    """Provide single-process defaults for the distributed environment variables.

    If any of RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT or LOCAL_RANK is already set, the
    environment is left untouched; otherwise all five are set. Cached, so it runs once.
    """
    defaults = {
        "RANK": "0",
        "WORLD_SIZE": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(get_free_port()),
        "LOCAL_RANK": "0",
    }

    if any(os.getenv(name) is not None for name in defaults):
        return

    os.environ.update(defaults)
|
29 |
+
|
30 |
+
|
31 |
+
@cache
def init_distributed():
    """Initialize DeepSpeed's distributed backend and select this process's CUDA device.

    Cached, so repeated calls are no-ops after the first.
    """
    fix_unset_envs()
    backend_name = get_accelerator().communication_backend_name()
    deepspeed.init_distributed(backend_name)
    device_index = local_rank()
    torch.cuda.set_device(device_index)
|
36 |
+
|
37 |
+
|
38 |
+
def local_rank():
    """Return the ``LOCAL_RANK`` environment variable as an int (0 if unset)."""
    raw = os.environ.get("LOCAL_RANK", 0)
    return int(raw)
|
40 |
+
|
41 |
+
|
42 |
+
def global_rank():
    """Return the ``RANK`` environment variable as an int (0 if unset)."""
    raw = os.environ.get("RANK", 0)
    return int(raw)
|
44 |
+
|
45 |
+
|
46 |
+
def is_local_leader():
    """Tell whether this process has local rank 0."""
    rank = local_rank()
    return rank == 0
|
48 |
+
|
49 |
+
|
50 |
+
def is_global_leader():
    """Tell whether this process has global rank 0."""
    rank = global_rank()
    return rank == 0
|
52 |
+
|
53 |
+
|
54 |
+
def leader_only(leader_only_type, fn: Callable | None = None, boardcast_return=False) -> Callable:
|
55 |
+
"""
|
56 |
+
Args:
|
57 |
+
fn: The function to decorate
|
58 |
+
boardcast_return: Whether to boardcast the return value to all processes
|
59 |
+
(may cause deadlock if the function calls another decorated function)
|
60 |
+
"""
|
61 |
+
|
62 |
+
def wrapper(fn):
|
63 |
+
if hasattr(fn, "__leader_only_type__"):
|
64 |
+
raise RuntimeError(f"Function {fn.__name__} has already been decorated with {fn.__leader_only_type__}")
|
65 |
+
|
66 |
+
fn.__leader_only_type__ = leader_only_type
|
67 |
+
|
68 |
+
if leader_only_type == "local":
|
69 |
+
guard_fn = is_local_leader
|
70 |
+
elif leader_only_type == "global":
|
71 |
+
guard_fn = is_global_leader
|
72 |
+
else:
|
73 |
+
raise ValueError(f"Unknown leader_only_type: {leader_only_type}")
|
74 |
+
|
75 |
+
@wraps(fn)
|
76 |
+
def wrapped(*args, **kwargs):
|
77 |
+
if boardcast_return:
|
78 |
+
init_distributed()
|
79 |
+
obj_list = [None]
|
80 |
+
if guard_fn():
|
81 |
+
ret = fn(*args, **kwargs)
|
82 |
+
obj_list[0] = ret
|
83 |
+
if boardcast_return:
|
84 |
+
broadcast_object_list(obj_list, src=0)
|
85 |
+
return obj_list[0]
|
86 |
+
|
87 |
+
return wrapped
|
88 |
+
|
89 |
+
if fn is None:
|
90 |
+
return wrapper
|
91 |
+
|
92 |
+
return wrapper(fn)
|
93 |
+
|
94 |
+
|
95 |
+
# Decorator (or decorator factory) that runs the function only where is_local_leader() is true.
local_leader_only = partial(leader_only, "local")
|
96 |
+
# Decorator (or decorator factory) that runs the function only where is_global_leader() is true.
global_leader_only = partial(leader_only, "global")
|
resemble-enhance/resemble_enhance/utils/engine.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import re
|
3 |
+
from functools import cache, partial
|
4 |
+
from typing import Callable, TypeVar
|
5 |
+
|
6 |
+
import deepspeed
|
7 |
+
import pandas as pd
|
8 |
+
from deepspeed.accelerator import get_accelerator
|
9 |
+
from deepspeed.runtime.engine import DeepSpeedEngine
|
10 |
+
from deepspeed.runtime.utils import clip_grad_norm_
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
from .distributed import fix_unset_envs
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
T = TypeVar("T")
|
18 |
+
|
19 |
+
|
20 |
+
def flatten_dict(d):
    """Flatten a nested dict into a single level with "/"-joined keys.

    Delegates to ``pandas.json_normalize`` and returns its first record, or
    an empty dict when normalisation produces no records.
    """
    frame = pd.json_normalize(d, sep="/")
    rows = frame.to_dict(orient="records")
    if not rows:
        return {}
    return rows[0]
|
23 |
+
|
24 |
+
|
25 |
+
def _get_named_modules(module, attrname, sep="/"):
    """Yield ``(name, submodule)`` for every submodule that has ``attrname``.

    Names come from ``module.named_modules()`` with the dots replaced by
    ``sep``.
    """
    for dotted_name, submodule in module.named_modules():
        if not hasattr(submodule, attrname):
            continue
        yield dotted_name.replace(".", sep), submodule
|
30 |
+
|
31 |
+
|
32 |
+
def gather_attribute(module, attrname, delete=True, prefix=None):
    """Collect ``attrname`` from every submodule that defines it.

    Args:
        module: Root object whose ``named_modules()`` are scanned.
        attrname: Name of the attribute to collect.
        delete: When true, remove the attribute from each submodule after
            reading it.
        prefix: Optional key under which all collected values are nested
            before flattening.

    Returns:
        A flat dict keyed by "/"-separated paths, with runs of consecutive
        slashes collapsed into one.

    Raises:
        RuntimeError: If deleting the attribute from a submodule fails.
    """
    collected = {}
    for path, submodule in _get_named_modules(module, attrname):
        collected[path] = getattr(submodule, attrname)
        if not delete:
            continue
        try:
            delattr(submodule, attrname)
        except Exception as err:
            raise RuntimeError(f"{path} {submodule} {attrname}") from err

    nested = {prefix: collected} if prefix else collected
    flat = flatten_dict(nested)

    # remove consecutive /
    slash_run = re.compile(r"\/+")
    return {slash_run.sub("/", key): value for key, value in flat.items()}
|
47 |
+
|
48 |
+
|
49 |
+
def dispatch_attribute(module, attrname, value, filter_fn: Callable[[nn.Module], bool] | None = None):
    """Assign ``value`` to ``attrname`` on every submodule that already has it.

    Args:
        module: Root object whose ``named_modules()`` are scanned.
        attrname: Name of the attribute to overwrite.
        value: The value to assign.
        filter_fn: Optional predicate; when given, only submodules for which
            it returns a truthy result are updated.
    """
    for _, target in _get_named_modules(module, attrname):
        if filter_fn is not None and not filter_fn(target):
            continue
        setattr(target, attrname, value)
|
53 |
+
|
54 |
+
|
55 |
+
@cache
def update_deepspeed_logger():
    """Raise the "DeepSpeed" logger's level to WARNING (done once per process)."""
    logging.getLogger("DeepSpeed").setLevel(logging.WARNING)
|
59 |
+
|
60 |
+
|
61 |
+
@cache
def init_distributed():
    """Initialise the DeepSpeed distributed backend once per process.

    Quietens the DeepSpeed logger and fills in missing environment variables
    before initialising DeepSpeed with the accelerator's communication backend.
    """
    update_deepspeed_logger()
    fix_unset_envs()
    backend = get_accelerator().communication_backend_name()
    deepspeed.init_distributed(backend)
|
66 |
+
|
67 |
+
|
68 |
+
def _try_each(*fns, e=None):
    """Call each function in order and return the first successful result.

    Every failure is logged as a warning and the next candidate is tried.

    Args:
        *fns: Zero-argument callables to attempt, in priority order.
        e: Optional exception to report as the cause if no function is given
            or every function fails before recording its own error.

    Returns:
        The return value of the first function that does not raise.

    Raises:
        RuntimeError: If no function succeeds; chained to the most recent
            failure so the underlying error is visible in the traceback.
    """
    last_error = e
    for fn in fns:
        try:
            return fn()
        except Exception as err:
            logger.warning(f"Tried {fn} but failed: {err}, trying next")
            last_error = err

    # Explicit chaining keeps the real failure attached to the final error.
    raise RuntimeError("All functions failed") from last_error
|
79 |
+
|
80 |
+
|
81 |
+
class Engine(DeepSpeedEngine):
    """A ``DeepSpeedEngine`` tied to one checkpoint directory.

    Adds helpers for freezing/unfreezing parameters, gathering and
    dispatching attributes on submodules, fp32 gradient clipping, and
    checkpoint save/load that always target ``ckpt_dir``.
    """

    def __init__(self, *args, ckpt_dir, **kwargs):
        """Initialise distributed state, then the underlying engine.

        Args:
            ckpt_dir: Directory used for saving and loading checkpoints; it is
                used through ``exists()`` and ``mkdir()``, so a path-like
                object such as ``pathlib.Path`` is expected.
            *args, **kwargs: Forwarded to ``DeepSpeedEngine.__init__`` together
                with ``args=None``.
        """
        init_distributed()
        super().__init__(*args, args=None, **kwargs)
        self._ckpt_dir = ckpt_dir
        self._frozen_params = set()
        self._fp32_grad_norm = None

    @property
    def path(self):
        """The checkpoint directory given at construction."""
        return self._ckpt_dir

    def freeze_(self):
        """Disable gradients on all trainable parameters, remembering which ones."""
        for param in self.module.parameters():
            if not param.requires_grad:
                continue
            param.requires_grad_(False)
            self._frozen_params.add(param)

    def unfreeze_(self):
        """Re-enable gradients on the parameters frozen by ``freeze_``."""
        for param in self._frozen_params:
            param.requires_grad_(True)
        self._frozen_params.clear()

    @property
    def global_step(self):
        """Alias for the engine's ``global_steps`` counter."""
        return self.global_steps

    def gather_attribute(self, *args, **kwargs):
        """Run the module-level ``gather_attribute`` on the wrapped module."""
        return gather_attribute(self.module, *args, **kwargs)

    def dispatch_attribute(self, *args, **kwargs):
        """Run the module-level ``dispatch_attribute`` on the wrapped module."""
        return dispatch_attribute(self.module, *args, **kwargs)

    def clip_fp32_gradients(self):
        """Clip the module's gradients and store the resulting norm."""
        norm = clip_grad_norm_(
            parameters=self.module.parameters(),
            max_norm=self.gradient_clipping(),
            mpu=self.mpu,
        )
        self._fp32_grad_norm = norm

    def get_grad_norm(self):
        """Return the global grad norm, falling back to the stored fp32 norm."""
        norm = self.get_global_grad_norm()
        return self._fp32_grad_norm if norm is None else norm

    def save_checkpoint(self, *args, **kwargs):
        """Save a checkpoint into the engine's checkpoint directory.

        The directory is created first when it does not exist yet.
        """
        ckpt_dir = self._ckpt_dir
        if not ckpt_dir.exists():
            ckpt_dir.mkdir(parents=True, exist_ok=True)
        super().save_checkpoint(*args, save_dir=ckpt_dir, **kwargs)
        logger.info(f"Saved checkpoint to {self._ckpt_dir}")

    def load_checkpoint(self, *args, **kwargs):
        """Load a checkpoint from the checkpoint directory, relaxing on failure.

        Attempts are made in order: a full load, then without optimizer
        states, then without LR-scheduler states, then without both, and
        finally without both and with non-strict module loading. The result of
        the first attempt that succeeds is returned.
        """
        load = partial(super().load_checkpoint, *args, load_dir=self._ckpt_dir, **kwargs)
        return _try_each(
            lambda: load(),
            lambda: load(load_optimizer_states=False),
            lambda: load(load_lr_scheduler_states=False),
            lambda: load(load_optimizer_states=False, load_lr_scheduler_states=False),
            lambda: load(
                load_optimizer_states=False,
                load_lr_scheduler_states=False,
                load_module_strict=False,
            ),
        )
|
resemble-enhance/resemble_enhance/utils/logging.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from rich.logging import RichHandler
|
5 |
+
|
6 |
+
from .distributed import global_leader_only
|
7 |
+
|
8 |
+
|
9 |
+
@global_leader_only
def setup_logging(run_dir):
    """Configure console (and optionally file) logging.

    A ``RichHandler`` at INFO level is always installed. When ``run_dir`` is
    not None, a DEBUG-level file handler appending to ``<run_dir>/log.txt`` is
    added as well (the directory is created if needed). The same handlers
    replace those of the "DeepSpeed" logger and are passed to
    ``logging.basicConfig`` for the root logger.

    Args:
        run_dir: Directory for the log file, or None for console-only logging.
    """
    console = RichHandler()
    console.setLevel(logging.INFO)
    handlers = [console]

    if run_dir is not None:
        log_path = Path(run_dir) / "log.txt"
        log_path.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_path, mode="a")
        file_handler.setLevel(logging.DEBUG)
        handlers.append(file_handler)

    # Update all existing loggers
    for name in ["DeepSpeed"]:
        target = logging.getLogger(name)
        if not isinstance(target, logging.Logger):
            continue
        # Iterate over a copy: removeHandler mutates target.handlers.
        for stale in list(target.handlers):
            target.removeHandler(stale)
        for handler in handlers:
            target.addHandler(handler)

    # Set the default logger
    logging.basicConfig(
        level=logging.INFO,
        format="%(message)s",
        datefmt="[%X]",
        handlers=handlers,
    )
|