diff --git a/modules/Enhancer/ResembleEnhance.py b/modules/Enhancer/ResembleEnhance.py
index 1bf974b2ab1ba5b399d545b35b8f2ef5d3e23e6e..adbf23d68cdf85479d69c7c50ed846badd803f94 100644
--- a/modules/Enhancer/ResembleEnhance.py
+++ b/modules/Enhancer/ResembleEnhance.py
@@ -1,13 +1,8 @@
 import os
 from typing import List
-
-try:
-    from resemble_enhance.enhancer.enhancer import Enhancer
-    from resemble_enhance.enhancer.hparams import HParams
-    from resemble_enhance.inference import inference
-except:
-    HParams = dict
-    Enhancer = dict
+from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
+from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
+from modules.repos_static.resemble_enhance.inference import inference
 
 import torch
 
diff --git a/modules/repos_static/__init__.py b/modules/repos_static/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/repos_static/readme.md b/modules/repos_static/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..9b64e0e376a50b3cf2c05071cca74b8807306811
--- /dev/null
+++ b/modules/repos_static/readme.md
@@ -0,0 +1,19 @@
+# repos static
+
+Static copies of third-party repositories, vendored here so they can be imported as `modules.repos_static.*` instead of being installed as pip dependencies.
+
+## resemble_enhance
+
+https://github.com/resemble-ai/resemble-enhance/tree/main
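+
+A minimal usage sketch (illustrative; the file paths and `device` are placeholders, and with `run_dir=None` the default pretrained model is downloaded on first use):
+
+```python
+import torchaudio
+
+from modules.repos_static.resemble_enhance.enhancer.inference import enhance
+
+dwav, sr = torchaudio.load("input.wav")
+hwav, sr = enhance(dwav.mean(0), sr, device="cpu", run_dir=None)
+torchaudio.save("output.wav", hwav[None], sr)
+```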
diff --git a/modules/repos_static/resemble_enhance/__init__.py b/modules/repos_static/resemble_enhance/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/repos_static/resemble_enhance/common.py b/modules/repos_static/resemble_enhance/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfe3980103294a7b57fce918ffa8592f7b935c50
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/common.py
@@ -0,0 +1,63 @@
+import logging
+
+import torch
+from torch import Tensor, nn
+
+logger = logging.getLogger(__name__)
+
+
+class Normalizer(nn.Module):
+    def __init__(self, momentum=0.01, eps=1e-9):
+        super().__init__()
+        self.momentum = momentum
+        self.eps = eps
+        self.running_mean_unsafe: Tensor
+        self.running_var_unsafe: Tensor
+        self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
+        self.register_buffer("running_var_unsafe", torch.full([], torch.nan))
+
+    @property
+    def started(self):
+        return not torch.isnan(self.running_mean_unsafe)
+
+    @property
+    def running_mean(self):
+        if not self.started:
+            return torch.zeros_like(self.running_mean_unsafe)
+        return self.running_mean_unsafe
+
+    @property
+    def running_std(self):
+        if not self.started:
+            return torch.ones_like(self.running_var_unsafe)
+        return (self.running_var_unsafe + self.eps).sqrt()
+
+    @torch.no_grad()
+    def _ema(self, a: Tensor, x: Tensor):
+        return (1 - self.momentum) * a + self.momentum * x
+
+    def update_(self, x):
+        if not self.started:
+            self.running_mean_unsafe = x.mean()
+            self.running_var_unsafe = x.var()
+        else:
+            self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
+            self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())
+
+    def forward(self, x: Tensor, update=True):
+        if self.training and update:
+            self.update_(x)
+        self.stats = dict(mean=self.running_mean.item(), std=self.running_std.item())
+        x = (x - self.running_mean) / self.running_std
+        return x
+
+    def inverse(self, x: Tensor):
+        return x * self.running_std + self.running_mean
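+
+# Usage sketch (illustrative; names are from this file): in training mode,
+# forward() folds each batch into the running statistics via EMA before
+# standardizing, and inverse() maps a standardized tensor back.
+#   normalizer = Normalizer()
+#   normalizer.train()
+#   y = normalizer(x)              # x: any float tensor, e.g. a (b, d, t) mel
+#   x_hat = normalizer.inverse(y)  # approximately recovers x's scale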
diff --git a/modules/repos_static/resemble_enhance/data/__init__.py b/modules/repos_static/resemble_enhance/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ebc6373ce4e90804e2f12828b7d9467a85656e5
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/data/__init__.py
@@ -0,0 +1,48 @@
+import logging
+import random
+
+from torch.utils.data import DataLoader
+
+from ..hparams import HParams
+from .dataset import Dataset
+from .utils import mix_fg_bg, rglob_audio_files
+
+logger = logging.getLogger(__name__)
+
+
+def _create_datasets(hp: HParams, mode, val_size=10, seed=123):
+    paths = rglob_audio_files(hp.fg_dir)
+    logger.info(f"Found {len(paths)} audio files in {hp.fg_dir}")
+
+    random.Random(seed).shuffle(paths)
+    train_paths = paths[:-val_size]
+    val_paths = paths[-val_size:]
+
+    train_ds = Dataset(train_paths, hp, training=True, mode=mode)
+    val_ds = Dataset(val_paths, hp, training=False, mode=mode)
+
+    logger.info(f"Train set: {len(train_ds)} samples - Val set: {len(val_ds)} samples")
+
+    return train_ds, val_ds
+
+
+def create_dataloaders(hp: HParams, mode):
+    train_ds, val_ds = _create_datasets(hp=hp, mode=mode)
+
+    train_dl = DataLoader(
+        train_ds,
+        batch_size=hp.batch_size_per_gpu,
+        shuffle=True,
+        num_workers=hp.nj,
+        drop_last=True,
+        collate_fn=train_ds.collate_fn,
+    )
+    val_dl = DataLoader(
+        val_ds,
+        batch_size=1,
+        shuffle=False,
+        num_workers=hp.nj,
+        drop_last=False,
+        collate_fn=val_ds.collate_fn,
+    )
+    return train_dl, val_dl
diff --git a/modules/repos_static/resemble_enhance/data/dataset.py b/modules/repos_static/resemble_enhance/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ba57c1736367345d171c2fc4feceefbfc25362a
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/data/dataset.py
@@ -0,0 +1,171 @@
+import logging
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+import torchaudio
+import torchaudio.functional as AF
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset as DatasetBase
+
+from ..hparams import HParams
+from .distorter import Distorter
+from .utils import rglob_audio_files
+
+logger = logging.getLogger(__name__)
+
+
+def _normalize(x):
+    return x / (np.abs(x).max() + 1e-7)
+
+
+def _collate(batch, key, tensor=True, pad=True):
+    l = [d[key] for d in batch]
+    if l[0] is None:
+        return None
+    if tensor:
+        l = [torch.from_numpy(x) for x in l]
+    if pad:
+        assert tensor, "Can't pad non-tensor"
+        l = pad_sequence(l, batch_first=True)
+    return l
+
+
+def praat_augment(wav, sr):
+    try:
+        import parselmouth
+    except ImportError:
+        raise ImportError("Please install parselmouth>=0.5.0 to use Praat augmentation")
+    # "praat-parselmouth @ git+https://github.com/YannickJadoul/Parselmouth@0bbcca69705ed73322f3712b19d71bb3694b2540",
+    # https://github.com/YannickJadoul/Parselmouth/issues/68
+    # note that this function may hang if the praat version is 0.4.3
+    assert wav.ndim == 1, f"wav.ndim must be 1 but got {wav.ndim}"
+    sound = parselmouth.Sound(wav, sr)
+    formant_shift_ratio = random.uniform(1.1, 1.5)
+    pitch_range_factor = random.uniform(0.5, 2.0)
+    sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0)
+    wav = np.array(sound.values)[0].astype(np.float32)
+    return wav
+
+
+class Dataset(DatasetBase):
+    def __init__(
+        self,
+        fg_paths: list[Path],
+        hp: HParams,
+        training=True,
+        max_retries=100,
+        silent_fg_prob=0.01,
+        mode=False,
+    ):
+        super().__init__()
+
+        assert mode in ("enhancer", "denoiser"), f"Invalid mode: {mode}"
+
+        self.hp = hp
+        self.fg_paths = fg_paths
+        self.bg_paths = rglob_audio_files(hp.bg_dir)
+
+        if len(self.fg_paths) == 0:
+            raise ValueError(f"No foreground audio files found in {hp.fg_dir}")
+
+        if len(self.bg_paths) == 0:
+            raise ValueError(f"No background audio files found in {hp.bg_dir}")
+
+        logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")
+
+        self.training = training
+        self.max_retries = max_retries
+        self.silent_fg_prob = silent_fg_prob
+
+        self.mode = mode
+        self.distorter = Distorter(hp, training=training, mode=mode)
+
+    def _load_wav(self, path, length=None, random_crop=True):
+        wav, sr = torchaudio.load(path)
+
+        wav = AF.resample(
+            waveform=wav,
+            orig_freq=sr,
+            new_freq=self.hp.wav_rate,
+            lowpass_filter_width=64,
+            rolloff=0.9475937167399596,
+            resampling_method="sinc_interp_kaiser",
+            beta=14.769656459379492,
+        )
+
+        wav = wav.float().numpy()
+
+        if wav.ndim == 2:
+            wav = np.mean(wav, axis=0)
+
+        if length is None and self.training:
+            length = int(self.hp.training_seconds * self.hp.wav_rate)
+
+        if length is not None:
+            if random_crop:
+                start = random.randint(0, max(0, len(wav) - length))
+                wav = wav[start : start + length]
+            else:
+                wav = wav[:length]
+
+        if length is not None and len(wav) < length:
+            wav = np.pad(wav, (0, length - len(wav)))
+
+        wav = _normalize(wav)
+
+        return wav
+
+    def _getitem_unsafe(self, index: int):
+        fg_path = self.fg_paths[index]
+
+        if self.training and random.random() < self.silent_fg_prob:
+            fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32)
+        else:
+            fg_wav = self._load_wav(fg_path)
+            if random.random() < self.hp.praat_augment_prob and self.training:
+                fg_wav = praat_augment(fg_wav, self.hp.wav_rate)
+
+        if self.hp.load_fg_only:
+            bg_wav = None
+            fg_dwav = None
+            bg_dwav = None
+        else:
+            fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32)
+            if self.training:
+                bg_path = random.choice(self.bg_paths)
+            else:
+                # Deterministic for validation
+                bg_path = self.bg_paths[index % len(self.bg_paths)]
+            bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
+            bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32)
+
+        return dict(
+            fg_wav=fg_wav,
+            bg_wav=bg_wav,
+            fg_dwav=fg_dwav,
+            bg_dwav=bg_dwav,
+        )
+
+    def __getitem__(self, index: int):
+        for i in range(self.max_retries):
+            try:
+                return self._getitem_unsafe(index)
+            except Exception as e:
+                if i == self.max_retries - 1:
+                    raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
+                logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
+                index = np.random.randint(0, len(self))
+
+    def __len__(self):
+        return len(self.fg_paths)
+
+    @staticmethod
+    def collate_fn(batch):
+        return dict(
+            fg_wavs=_collate(batch, "fg_wav"),
+            bg_wavs=_collate(batch, "bg_wav"),
+            fg_dwavs=_collate(batch, "fg_dwav"),
+            bg_dwavs=_collate(batch, "bg_dwav"),
+        )
diff --git a/modules/repos_static/resemble_enhance/data/distorter/__init__.py b/modules/repos_static/resemble_enhance/data/distorter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad059fd9af40fbfac1aceebf39fac6a09562c7de
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/data/distorter/__init__.py
@@ -0,0 +1 @@
+from .distorter import Distorter
diff --git a/modules/repos_static/resemble_enhance/data/distorter/base.py b/modules/repos_static/resemble_enhance/data/distorter/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d43d84fa840dd25804d9c5e5dc9673f65fdc5b94
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/data/distorter/base.py
@@ -0,0 +1,113 @@
+import itertools
+import os
+import random
+import time
+import warnings
+
+import numpy as np
+
+_DEBUG = bool(os.environ.get("DEBUG", False))
+
+
+class Effect:
+    def apply(self, wav: np.ndarray, sr: int):
+        """
+        Args:
+            wav: (T)
+            sr: sample rate
+        Returns:
+            wav: (T) with the same sample rate of `sr`
+        """
+        raise NotImplementedError
+
+    def __call__(self, wav: np.ndarray, sr: int):
+        """
+        Args:
+            wav: (T)
+            sr: sample rate
+        Returns:
+            wav: (T) with the same sample rate of `sr`
+        """
+        assert len(wav.shape) == 1, wav.shape
+
+        if _DEBUG:
+            start = time.time()
+        else:
+            start = None
+
+        shape = wav.shape
+        assert wav.ndim == 1, f"{self}: Expected wav.ndim == 1, got {wav.ndim}."
+        wav = self.apply(wav, sr)
+        assert shape == wav.shape, f"{self}: {shape} != {wav.shape}."
+
+        if start is not None:
+            end = time.time()
+            print(f"{self.__class__.__name__}: {end - start:.3f} sec")
+
+        return wav
+
+
+class Chain(Effect):
+    def __init__(self, *effects):
+        super().__init__()
+
+        self.effects = effects
+
+    def apply(self, wav, sr):
+        for effect in self.effects:
+            wav = effect(wav, sr)
+        return wav
+
+
+class Maybe(Effect):
+    def __init__(self, prob, effect):
+        super().__init__()
+
+        self.prob = prob
+        self.effect = effect
+
+        if _DEBUG:
+            warnings.warn("DEBUG mode is on. Maybe -> Must.")
+            self.prob = 1
+
+    def apply(self, wav, sr):
+        if random.random() > self.prob:
+            return wav
+        return self.effect(wav, sr)
+
+
+class Choice(Effect):
+    def __init__(self, *effects, **kwargs):
+        super().__init__()
+        self.effects = effects
+        self.kwargs = kwargs
+
+    def apply(self, wav, sr):
+        return np.random.choice(self.effects, **self.kwargs)(wav, sr)
+
+
+class Permutation(Effect):
+    def __init__(self, *effects, n: int | None = None):
+        super().__init__()
+        self.effects = effects
+        self.n = n
+
+    def apply(self, wav, sr):
+        if self.n is None:
+            n = np.random.binomial(len(self.effects), 0.5)
+        else:
+            n = self.n
+        if n == 0:
+            return wav
+        perms = itertools.permutations(self.effects, n)
+        effects = random.choice(list(perms))
+        return Chain(*effects)(wav, sr)
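+
+
+# Composition sketch (illustrative; effect_a..effect_d stand in for any Effect
+# instances and are not defined here):
+#   distort = Chain(
+#       Maybe(0.5, Choice(effect_a, effect_b)),  # half the time, pick one of two
+#       Permutation(effect_c, effect_d),         # apply a random subset in random order
+#   )
+#   wav = distort(wav, sr)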
diff --git a/modules/repos_static/resemble_enhance/data/distorter/custom.py b/modules/repos_static/resemble_enhance/data/distorter/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..28428f7789cebb2d174c581111711f4d73f6565b
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/data/distorter/custom.py
@@ -0,0 +1,88 @@
+import logging
+import random
+from dataclasses import dataclass
+from functools import cached_property
+from pathlib import Path
+
+import librosa
+import numpy as np
+from scipy import signal
+
+from ..utils import walk_paths
+from .base import Effect
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class RandomRIR(Effect):
+    rir_dir: Path | None
+    rir_rate: int = 44_000
+    rir_suffix: str = ".npy"
+    deterministic: bool = False
+
+    @cached_property
+    def rir_paths(self):
+        if self.rir_dir is None:
+            return []
+        return list(walk_paths(self.rir_dir, self.rir_suffix))
+
+    def _sample_rir(self):
+        if len(self.rir_paths) == 0:
+            return None
+
+        if self.deterministic:
+            rir_path = self.rir_paths[0]
+        else:
+            rir_path = random.choice(self.rir_paths)
+
+        rir = np.squeeze(np.load(rir_path))
+        assert isinstance(rir, np.ndarray)
+
+        return rir
+
+    def apply(self, wav, sr):
+        # ref: https://github.com/haoheliu/voicefixer_main/blob/b06e07c945ac1d309b8a57ddcd599ca376b98cd9/dataloaders/augmentation/magical_effects.py#L158
+
+        if len(self.rir_paths) == 0:
+            return wav
+
+        length = len(wav)
+
+        wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
+        rir = self._sample_rir()
+
+        wav = signal.convolve(wav, rir, mode="same")
+
+        actlev = np.max(np.abs(wav))
+        if actlev > 0.99:
+            wav = (wav / actlev) * 0.98
+
+        wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")
+
+        if abs(length - len(wav)) > 10:
+            _logger.warning(f"length mismatch: {length} vs {len(wav)}")
+
+        if length > len(wav):
+            wav = np.pad(wav, (0, length - len(wav)))
+        elif length < len(wav):
+            wav = wav[:length]
+
+        return wav
+
+
+class RandomGaussianNoise(Effect):
+    def __init__(self, alpha_range=(0.8, 1)):
+        super().__init__()
+        self.alpha_range = alpha_range
+
+    def apply(self, wav, sr):
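+        # Scale white noise to the signal's total energy, then crossfade:
+        # alpha is drawn from alpha_range, so values near 1 keep the signal
+        # almost clean while values near 0.8 add clearly audible noise.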
+        noise = np.random.randn(*wav.shape)
+        noise_energy = np.sum(noise**2)
+        wav_energy = np.sum(wav**2)
+        noise = noise * np.sqrt(wav_energy / noise_energy)
+        alpha = random.uniform(*self.alpha_range)
+        return wav * alpha + noise * (1 - alpha)
diff --git a/modules/repos_static/resemble_enhance/data/distorter/distorter.py b/modules/repos_static/resemble_enhance/data/distorter/distorter.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f787a8cdbf941ae7c8e3ac925d1aa66dad5e978
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/data/distorter/distorter.py
@@ -0,0 +1,32 @@
+from ...hparams import HParams
+from .base import Chain, Choice, Permutation
+from .custom import RandomGaussianNoise, RandomRIR
+
+
+class Distorter(Chain):
+    def __init__(self, hp: HParams, training: bool = False, mode: str = "enhancer"):
+        # Lazy import
+        from .sox import RandomBandpassDistorter, RandomEqualizer, RandomLowpassDistorter, RandomOverdrive, RandomReverb
+
+        if training:
+            permutation = Permutation(
+                RandomRIR(hp.rir_dir),
+                RandomReverb(),
+                RandomGaussianNoise(),
+                RandomOverdrive(),
+                RandomEqualizer(),
+                Choice(
+                    RandomLowpassDistorter(),
+                    RandomBandpassDistorter(),
+                ),
+            )
+            if mode == "denoiser":
+                super().__init__(permutation)
+            else:
+                # 80%: distortion, 20%: clean
+                super().__init__(Choice(permutation, Chain(), p=[0.8, 0.2]))
+        else:
+            super().__init__(
+                RandomRIR(hp.rir_dir, deterministic=True),
+                RandomReverb(deterministic=True),
+            )
diff --git a/modules/repos_static/resemble_enhance/data/distorter/sox.py b/modules/repos_static/resemble_enhance/data/distorter/sox.py
new file mode 100644
index 0000000000000000000000000000000000000000..92a2d74033d33b975141c1ece7ac5619d1dfcc39
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/data/distorter/sox.py
@@ -0,0 +1,176 @@
+import logging
+import os
+import random
+import warnings
+from functools import partial
+
+import numpy as np
+import torch
+
+try:
+    import augment
+except ImportError:
+    raise ImportError(
+        "augment is not installed, please install it first using:"
+        "\npip install git+https://github.com/facebookresearch/WavAugment@54afcdb00ccc852c2f030f239f8532c9562b550e"
+    )
+
+from .base import Effect
+
+_logger = logging.getLogger(__name__)
+_DEBUG = bool(os.environ.get("DEBUG", False))
+
+
+class AttachableEffect(Effect):
+    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
+        raise NotImplementedError
+
+    def apply(self, wav: np.ndarray, sr: int):
+        chain = augment.EffectChain()
+        chain = self.attach(chain)
+        tensor = torch.from_numpy(wav)[None].float()  # (1, T)
+        tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
+        wav = tensor.numpy()[0]  # (T,)
+        return wav
+
+
+class SoxEffect(AttachableEffect):
+    def __init__(self, effect_name: str, *args, **kwargs):
+        self.effect_name = effect_name
+        self.args = args
+        self.kwargs = kwargs
+
+    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
+        _logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")
+        if not hasattr(chain, self.effect_name):
+            raise ValueError(f"EffectChain has no attribute {self.effect_name}")
+        return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
+
+
+class Maybe(AttachableEffect):
+    """
+    Attach an effect with a probability.
+    """
+
+    def __init__(self, prob: float, effect: AttachableEffect):
+        self.prob = prob
+        self.effect = effect
+        if _DEBUG:
+            warnings.warn("DEBUG mode is on. Maybe -> Must.")
+            self.prob = 1
+
+    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
+        if random.random() > self.prob:
+            return chain
+        return self.effect.attach(chain)
+
+
+class Chain(AttachableEffect):
+    """
+    Attach a chain of effects.
+    """
+
+    def __init__(self, *effects: AttachableEffect):
+        self.effects = effects
+
+    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
+        for effect in self.effects:
+            chain = effect.attach(chain)
+        return chain
+
+
+class Choice(AttachableEffect):
+    """
+    Attach one of the effects randomly.
+    """
+
+    def __init__(self, *effects: AttachableEffect):
+        self.effects = effects
+
+    def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
+        return random.choice(self.effects).attach(chain)
+
+
+class Generator:
+    def __call__(self) -> str:
+        raise NotImplementedError
+
+
+class Uniform(Generator):
+    def __init__(self, low, high):
+        self.low = low
+        self.high = high
+
+    def __call__(self) -> str:
+        return str(random.uniform(self.low, self.high))
+
+
+class Randint(Generator):
+    def __init__(self, low, high):
+        self.low = low
+        self.high = high
+
+    def __call__(self) -> str:
+        return str(random.randint(self.low, self.high))
+
+
+class Concat(Generator):
+    def __init__(self, *parts: Generator | str):
+        self.parts = parts
+
+    def __call__(self):
+        return "".join([part if isinstance(part, str) else part() for part in self.parts])
+
+
+class RandomLowpassDistorter(SoxEffect):
+    def __init__(self, low=2000, high=16000):
+        super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)))
+
+
+class RandomBandpassDistorter(SoxEffect):
+    def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
+        super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width))
+
+    @staticmethod
+    def _fn(low, high, min_width, max_width):
+        start = random.randint(low, high)
+        stop = start + random.randint(min_width, max_width)
+        return f"{start}-{stop}"
+
+
+class RandomEqualizer(SoxEffect):
+    def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
+        super().__init__(
+            "equalizer",
+            Uniform(low, high),
+            lambda: f"{random.randint(q_low, q_high)}q",
+            lambda: random.randint(db_low, db_high),
+        )
+
+
+class RandomOverdrive(SoxEffect):
+    def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
+        super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high))
+
+
+class RandomReverb(Chain):
+    def __init__(self, deterministic=False):
+        super().__init__(
+            SoxEffect(
+                "reverb",
+                Uniform(50, 50) if deterministic else Uniform(0, 100),
+                Uniform(50, 50) if deterministic else Uniform(0, 100),
+                Uniform(50, 50) if deterministic else Uniform(0, 100),
+            ),
+            SoxEffect("channels", 1),
+        )
+
+
+class Flanger(SoxEffect):
+    def __init__(self):
+        super().__init__("flanger")
+
+
+class Phaser(SoxEffect):
+    def __init__(self):
+        super().__init__("phaser")
diff --git a/modules/repos_static/resemble_enhance/data/utils.py b/modules/repos_static/resemble_enhance/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..77f59d345b75cac76c6c423c734ae9c70a626590
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/data/utils.py
@@ -0,0 +1,50 @@
+from pathlib import Path
+from typing import Callable
+
+from torch import Tensor
+
+
+def walk_paths(root, suffix):
+    for path in Path(root).iterdir():
+        if path.is_dir():
+            yield from walk_paths(path, suffix)
+        elif path.suffix == suffix:
+            yield path
+
+
+def rglob_audio_files(path: Path):
+    return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
+
+
+def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
+    """
+    Args:
+        fg: (b, t)
+        bg: (b, t)
+    """
+    assert bg.shape == fg.shape, f"bg.shape != fg.shape: {bg.shape} != {fg.shape}"
+    fg = fg / (fg.abs().max(dim=-1, keepdim=True).values + eps)
+    bg = bg / (bg.abs().max(dim=-1, keepdim=True).values + eps)
+
+    fg_energy = fg.pow(2).sum(dim=-1, keepdim=True)
+    bg_energy = bg.pow(2).sum(dim=-1, keepdim=True)
+
+    fg = fg / (fg_energy + eps).sqrt()
+    bg = bg / (bg_energy + eps).sqrt()
+
+    if callable(alpha):
+        alpha = alpha()
+
+    assert 0 <= alpha <= 1, f"alpha must be between 0 and 1: {alpha}"
+
+    mx = alpha * fg + (1 - alpha) * bg
+    mx = mx / (mx.abs().max(dim=-1, keepdim=True).values + eps)
+
+    return mx
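+
+
+# Example (illustrative shapes): both tensors are peak- and energy-normalized
+# before mixing, so alpha directly controls the foreground/background balance.
+#   import torch, random
+#   fg, bg = torch.randn(4, 16000), torch.randn(4, 16000)
+#   mx = mix_fg_bg(fg, bg, alpha=lambda: random.uniform(0.5, 1.0))  # (4, 16000)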
diff --git a/modules/repos_static/resemble_enhance/denoiser/__init__.py b/modules/repos_static/resemble_enhance/denoiser/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/repos_static/resemble_enhance/denoiser/__main__.py b/modules/repos_static/resemble_enhance/denoiser/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..86188661c35d10721c94dc21f88f4babf45f6f7d
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/denoiser/__main__.py
@@ -0,0 +1,30 @@
+import argparse
+from pathlib import Path
+
+import torch
+import torchaudio
+
+from .inference import denoise
+
+
+@torch.inference_mode()
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
+    parser.add_argument("out_dir", type=Path, help="Output folder")
+    parser.add_argument("--run_dir", type=Path, default="runs/denoiser", help="Path to run folder")
+    parser.add_argument("--suffix", type=str, default=".wav", help="File suffix")
+    parser.add_argument("--device", type=str, default="cuda", help="Device")
+    args = parser.parse_args()
+
+    for path in args.in_dir.glob(f"**/*{args.suffix}"):
+        print(f"Processing {path} ..")
+        dwav, sr = torchaudio.load(path)
+        hwav, sr = denoise(dwav[0], sr, args.run_dir, args.device)
+        out_path = args.out_dir / path.relative_to(args.in_dir)
+        out_path.parent.mkdir(parents=True, exist_ok=True)
+        torchaudio.save(out_path, hwav[None], sr)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/modules/repos_static/resemble_enhance/denoiser/denoiser.py b/modules/repos_static/resemble_enhance/denoiser/denoiser.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b1d49cdc257a84073fd43b205f5f497386ce80f
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/denoiser/denoiser.py
@@ -0,0 +1,184 @@
+import logging
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from ..melspec import MelSpectrogram
+from .hparams import HParams
+from .unet import UNet
+
+logger = logging.getLogger(__name__)
+
+
+def _normalize(x: Tensor) -> Tensor:
+    return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
+
+
+class Denoiser(nn.Module):
+    @property
+    def stft_cfg(self) -> dict:
+        hop_size = self.hp.hop_size
+        return dict(hop_length=hop_size, n_fft=hop_size * 4, win_length=hop_size * 4)
+
+    @property
+    def n_fft(self):
+        return self.stft_cfg["n_fft"]
+
+    @property
+    def eps(self):
+        return 1e-7
+
+    def __init__(self, hp: HParams):
+        super().__init__()
+        self.hp = hp
+        self.net = UNet(input_dim=3, output_dim=3)
+        self.mel_fn = MelSpectrogram(hp)
+
+        self.dummy: Tensor
+        self.register_buffer("dummy", torch.zeros(1), persistent=False)
+
+    def to_mel(self, x: Tensor, drop_last=True):
+        """
+        Args:
+            x: (b t), wavs
+        Returns:
+            o: (b c t), mels
+        """
+        if drop_last:
+            return self.mel_fn(x)[..., :-1]  # (b d t)
+        return self.mel_fn(x)
+
+    def _stft(self, x):
+        """
+        Args:
+            x: (b t)
+        Returns:
+            mag: (b f t) in [0, inf)
+            cos: (b f t) in [-1, 1]
+            sin: (b f t) in [-1, 1]
+        """
+        dtype = x.dtype
+        device = x.device
+
+        if x.is_mps:
+            x = x.cpu()
+
+        window = torch.hann_window(self.stft_cfg["win_length"], device=x.device)
+        s = torch.stft(x.float(), **self.stft_cfg, window=window, return_complex=True)  # (b f t+1)
+
+        s = s[..., :-1]  # (b f t)
+
+        mag = s.abs()  # (b f t)
+
+        phi = s.angle()  # (b f t)
+        cos = phi.cos()  # (b f t)
+        sin = phi.sin()  # (b f t)
+
+        mag = mag.to(dtype=dtype, device=device)
+        cos = cos.to(dtype=dtype, device=device)
+        sin = sin.to(dtype=dtype, device=device)
+
+        return mag, cos, sin
+
+    def _istft(self, mag: Tensor, cos: Tensor, sin: Tensor):
+        """
+        Args:
+            mag: (b f t) in [0, inf)
+            cos: (b f t) in [-1, 1]
+            sin: (b f t) in [-1, 1]
+        Returns:
+            x: (b t)
+        """
+        device = mag.device
+        dtype = mag.dtype
+
+        if mag.is_mps:
+            mag = mag.cpu()
+            cos = cos.cpu()
+            sin = sin.cpu()
+
+        real = mag * cos  # (b f t)
+        imag = mag * sin  # (b f t)
+
+        s = torch.complex(real, imag)  # (b f t)
+
+        if s.isnan().any():
+            logger.warning("NaN detected in ISTFT input.")
+
+        s = F.pad(s, (0, 1), "replicate")  # (b f t+1)
+
+        window = torch.hann_window(self.stft_cfg["win_length"], device=s.device)
+        x = torch.istft(s, **self.stft_cfg, window=window, return_complex=False)
+
+        if x.isnan().any():
+            logger.warning("NaN detected in ISTFT output, set to zero.")
+            x = torch.where(x.isnan(), torch.zeros_like(x), x)
+
+        x = x.to(dtype=dtype, device=device)
+
+        return x
+
+    def _magphase(self, real, imag):
+        mag = (real.pow(2) + imag.pow(2) + self.eps).sqrt()
+        cos = real / mag
+        sin = imag / mag
+        return mag, cos, sin
+
+    def _predict(self, mag: Tensor, cos: Tensor, sin: Tensor):
+        """
+        Args:
+            mag: (b f t)
+            cos: (b f t)
+            sin: (b f t)
+        Returns:
+            mag_mask: (b f t) in [0, 1], magnitude mask
+            cos_res: (b f t) in [-1, 1], phase residual
+            sin_res: (b f t) in [-1, 1], phase residual
+        """
+        x = torch.stack([mag, cos, sin], dim=1)  # (b 3 f t)
+        mag_mask, real, imag = self.net(x).unbind(1)  # (b 3 f t)
+        mag_mask = mag_mask.sigmoid()  # (b f t)
+        real = real.tanh()  # (b f t)
+        imag = imag.tanh()  # (b f t)
+        _, cos_res, sin_res = self._magphase(real, imag)  # (b f t)
+        return mag_mask, sin_res, cos_res
+
+    def _separate(self, mag, cos, sin, mag_mask, cos_res, sin_res):
+        """Ref: https://audio-agi.github.io/Separate-Anything-You-Describe/AudioSep_arXiv.pdf"""
+        sep_mag = F.relu(mag * mag_mask)
+        sep_cos = cos * cos_res - sin * sin_res
+        sep_sin = sin * cos_res + cos * sin_res
+        return sep_mag, sep_cos, sep_sin
+
+    def forward(self, x: Tensor, y: Tensor | None = None):
+        """
+        Args:
+            x: (b t), a mixed audio
+            y: (b t), a fg audio
+        """
+        assert x.dim() == 2, f"Expected (b t), got {x.size()}"
+        x = x.to(self.dummy)
+        x = _normalize(x)
+
+        if y is not None:
+            assert y.dim() == 2, f"Expected (b t), got {y.size()}"
+            y = y.to(self.dummy)
+            y = _normalize(y)
+
+        mag, cos, sin = self._stft(x)  # (b 2f t)
+        mag_mask, sin_res, cos_res = self._predict(mag, cos, sin)
+        sep_mag, sep_cos, sep_sin = self._separate(mag, cos, sin, mag_mask, cos_res, sin_res)
+
+        o = self._istft(sep_mag, sep_cos, sep_sin)
+
+        npad = x.shape[-1] - o.shape[-1]
+        o = F.pad(o, (0, npad))
+
+        if y is not None:
+            self.losses = dict(l1=F.l1_loss(o, y))
+
+        return o
diff --git a/modules/repos_static/resemble_enhance/denoiser/hparams.py b/modules/repos_static/resemble_enhance/denoiser/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..72ec1e5680e1f3323406f1206caf7945e0fb7b3b
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/denoiser/hparams.py
@@ -0,0 +1,9 @@
+from dataclasses import dataclass
+
+from ..hparams import HParams as HParamsBase
+
+
+@dataclass(frozen=True)
+class HParams(HParamsBase):
+    batch_size_per_gpu: int = 128
+    distort_prob: float = 0.5
diff --git a/modules/repos_static/resemble_enhance/denoiser/inference.py b/modules/repos_static/resemble_enhance/denoiser/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..9111321baaa428d46fa2d5f789fc437654c50f8b
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/denoiser/inference.py
@@ -0,0 +1,31 @@
+import logging
+from functools import cache
+
+import torch
+
+from ..denoiser.denoiser import Denoiser
+
+from ..inference import inference
+from .hparams import HParams
+
+logger = logging.getLogger(__name__)
+
+
+@cache
+def load_denoiser(run_dir, device):
+    if run_dir is None:
+        return Denoiser(HParams())
+    hp = HParams.load(run_dir)
+    denoiser = Denoiser(hp)
+    path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
+    state_dict = torch.load(path, map_location="cpu")["module"]
+    denoiser.load_state_dict(state_dict)
+    denoiser.eval()
+    denoiser.to(device)
+    return denoiser
+
+
+@torch.inference_mode()
+def denoise(dwav, sr, run_dir, device):
+    denoiser = load_denoiser(run_dir, device)
+    return inference(model=denoiser, dwav=dwav, sr=sr, device=device)
diff --git a/modules/repos_static/resemble_enhance/denoiser/unet.py b/modules/repos_static/resemble_enhance/denoiser/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b8f78309ce03f776c4a6d9f28f1f9763c94ea7a
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/denoiser/unet.py
@@ -0,0 +1,147 @@
+import torch.nn.functional as F
+from torch import nn
+
+
+class PreactResBlock(nn.Sequential):
+    def __init__(self, dim):
+        super().__init__(
+            nn.GroupNorm(dim // 16, dim),
+            nn.GELU(),
+            nn.Conv2d(dim, dim, 3, padding=1),
+            nn.GroupNorm(dim // 16, dim),
+            nn.GELU(),
+            nn.Conv2d(dim, dim, 3, padding=1),
+        )
+
+    def forward(self, x):
+        return x + super().forward(x)
+
+
+class UNetBlock(nn.Module):
+    def __init__(self, input_dim, output_dim=None, scale_factor=1.0):
+        super().__init__()
+        if output_dim is None:
+            output_dim = input_dim
+        self.pre_conv = nn.Conv2d(input_dim, output_dim, 3, padding=1)
+        self.res_block1 = PreactResBlock(output_dim)
+        self.res_block2 = PreactResBlock(output_dim)
+        self.downsample = self.upsample = nn.Identity()
+        if scale_factor > 1:
+            self.upsample = nn.Upsample(scale_factor=scale_factor)
+        elif scale_factor < 1:
+            self.downsample = nn.Upsample(scale_factor=scale_factor)
+
+    def forward(self, x, h=None):
+        """
+        Args:
+            x: (b c h w), last output
+            h: (b c h w), skip output
+        Returns:
+            o: (b c h w), output
+            s: (b c h w), skip output
+        """
+        x = self.upsample(x)
+        if h is not None:
+            assert x.shape == h.shape, f"{x.shape} != {h.shape}"
+            x = x + h
+        x = self.pre_conv(x)
+        x = self.res_block1(x)
+        x = self.res_block2(x)
+        return self.downsample(x), x
+
+
+class UNet(nn.Module):
+    def __init__(self, input_dim, output_dim, hidden_dim=16, num_blocks=4, num_middle_blocks=2):
+        super().__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.input_proj = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+        self.encoder_blocks = nn.ModuleList(
+            [
+                UNetBlock(input_dim=hidden_dim * 2**i, output_dim=hidden_dim * 2 ** (i + 1), scale_factor=0.5)
+                for i in range(num_blocks)
+            ]
+        )
+        self.middle_blocks = nn.ModuleList(
+            [UNetBlock(input_dim=hidden_dim * 2**num_blocks) for _ in range(num_middle_blocks)]
+        )
+        self.decoder_blocks = nn.ModuleList(
+            [
+                UNetBlock(input_dim=hidden_dim * 2 ** (i + 1), output_dim=hidden_dim * 2**i, scale_factor=2)
+                for i in reversed(range(num_blocks))
+            ]
+        )
+        self.head = nn.Sequential(
+            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
+            nn.GELU(),
+            nn.Conv2d(hidden_dim, output_dim, 1),
+        )
+
+    @property
+    def scale_factor(self):
+        return 2 ** len(self.encoder_blocks)
+
+    def pad_to_fit(self, x):
+        """
+        Args:
+            x: (b c h w), input
+        Returns:
+            x: (b c h' w'), padded input
+        """
+        hpad = (self.scale_factor - x.shape[2] % self.scale_factor) % self.scale_factor
+        wpad = (self.scale_factor - x.shape[3] % self.scale_factor) % self.scale_factor
+        return F.pad(x, (0, wpad, 0, hpad))
+
+    def forward(self, x):
+        """
+        Args:
+            x: (b c h w), input
+        Returns:
+            o: (b c h w), output
+        """
+        shape = x.shape
+
+        x = self.pad_to_fit(x)
+        x = self.input_proj(x)
+
+        s_list = []
+        for block in self.encoder_blocks:
+            x, s = block(x)
+            s_list.append(s)
+
+        for block in self.middle_blocks:
+            x, _ = block(x)
+
+        for block, s in zip(self.decoder_blocks, reversed(s_list)):
+            x, _ = block(x, s)
+
+        x = self.head(x)
+        x = x[..., : shape[2], : shape[3]]
+
+        return x
+
+    def test(self, shape=(3, 512, 256)):
+        import ptflops
+
+        macs, params = ptflops.get_model_complexity_info(
+            self,
+            shape,
+            as_strings=True,
+            print_per_layer_stat=True,
+            verbose=True,
+        )
+
+        print(f"macs: {macs}")
+        print(f"params: {params}")
+
+
+def main():
+    model = UNet(3, 3)
+    model.test()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/modules/repos_static/resemble_enhance/enhancer/__init__.py b/modules/repos_static/resemble_enhance/enhancer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/repos_static/resemble_enhance/enhancer/__main__.py b/modules/repos_static/resemble_enhance/enhancer/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1ad5ce68497c73756585009a59ea225c89ab94
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/__main__.py
@@ -0,0 +1,129 @@
+import argparse
+import random
+import time
+from pathlib import Path
+
+import torch
+import torchaudio
+from tqdm import tqdm
+
+from .inference import denoise, enhance
+
+
+@torch.inference_mode()
+def main():
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("in_dir", type=Path, help="Path to input audio folder")
+    parser.add_argument("out_dir", type=Path, help="Output folder")
+    parser.add_argument(
+        "--run_dir",
+        type=Path,
+        default=None,
+        help="Path to the enhancer run folder, if None, use the default model",
+    )
+    parser.add_argument(
+        "--suffix",
+        type=str,
+        default=".wav",
+        help="Audio file suffix",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        help="Device to use for computation, recommended to use CUDA",
+    )
+    parser.add_argument(
+        "--denoise_only",
+        action="store_true",
+        help="Only apply denoising without enhancement",
+    )
+    parser.add_argument(
+        "--lambd",
+        type=float,
+        default=1.0,
+        help="Denoise strength for enhancement (0.0 to 1.0)",
+    )
+    parser.add_argument(
+        "--tau",
+        type=float,
+        default=0.5,
+        help="CFM prior temperature (0.0 to 1.0)",
+    )
+    parser.add_argument(
+        "--solver",
+        type=str,
+        default="midpoint",
+        choices=["midpoint", "rk4", "euler"],
+        help="Numerical solver to use",
+    )
+    parser.add_argument(
+        "--nfe",
+        type=int,
+        default=64,
+        help="Number of function evaluations",
+    )
+    parser.add_argument(
+        "--parallel_mode",
+        action="store_true",
+        help="Shuffle the audio paths and skip the existing ones, enabling multiple jobs to run in parallel",
+    )
+
+    args = parser.parse_args()
+
+    device = args.device
+
+    if device == "cuda" and not torch.cuda.is_available():
+        print("CUDA is not available but --device is set to cuda, using CPU instead")
+        device = "cpu"
+
+    start_time = time.perf_counter()
+
+    run_dir = args.run_dir
+
+    paths = sorted(args.in_dir.glob(f"**/*{args.suffix}"))
+
+    if args.parallel_mode:
+        random.shuffle(paths)
+
+    if len(paths) == 0:
+        print(f"No {args.suffix} files found in the following path: {args.in_dir}")
+        return
+
+    pbar = tqdm(paths)
+
+    for path in pbar:
+        out_path = args.out_dir / path.relative_to(args.in_dir)
+        if args.parallel_mode and out_path.exists():
+            continue
+        pbar.set_description(f"Processing {out_path}")
+        dwav, sr = torchaudio.load(path)
+        dwav = dwav.mean(0)
+        if args.denoise_only:
+            hwav, sr = denoise(
+                dwav=dwav,
+                sr=sr,
+                device=device,
+                run_dir=args.run_dir,
+            )
+        else:
+            hwav, sr = enhance(
+                dwav=dwav,
+                sr=sr,
+                device=device,
+                nfe=args.nfe,
+                solver=args.solver,
+                lambd=args.lambd,
+                tau=args.tau,
+                run_dir=run_dir,
+            )
+        out_path.parent.mkdir(parents=True, exist_ok=True)
+        torchaudio.save(out_path, hwav[None], sr)
+
+    # Cool emoji effect saying the job is done
+    elapsed_time = time.perf_counter() - start_time
+    print(f"🌟 Enhancement done! {len(paths)} files processed in {elapsed_time:.2f}s")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/modules/repos_static/resemble_enhance/enhancer/download.py b/modules/repos_static/resemble_enhance/enhancer/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..614b9a4b4f9a1a10b79f12ca1a25821247ea2a16
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/download.py
@@ -0,0 +1,30 @@
+import logging
+from pathlib import Path
+
+import torch
+
+RUN_NAME = "enhancer_stage2"
+
+logger = logging.getLogger(__name__)
+
+
+def get_source_url(relpath):
+    return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
+
+
+def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
+    if run_dir is None:
+        run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
+    return Path(run_dir) / relpath
+
+
+def download(run_dir: str | Path | None = None):
+    relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]
+    for relpath in relpaths:
+        path = get_target_path(relpath, run_dir=run_dir)
+        if path.exists():
+            continue
+        url = get_source_url(relpath)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        torch.hub.download_url_to_file(url, str(path))
+    return get_target_path("", run_dir=run_dir)
diff --git a/modules/repos_static/resemble_enhance/enhancer/enhancer.py b/modules/repos_static/resemble_enhance/enhancer/enhancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..84cda8b0ad3cb0d99060d27d13638bd5dae2098c
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/enhancer.py
@@ -0,0 +1,185 @@
+import logging
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import torch
+from torch import Tensor, nn
+from torch.distributions import Beta
+
+from ..common import Normalizer
+from ..denoiser.inference import load_denoiser
+from ..melspec import MelSpectrogram
+from .hparams import HParams
+from .lcfm import CFM, IRMAE, LCFM
+from .univnet import UnivNet
+
+logger = logging.getLogger(__name__)
+
+
+def _maybe(fn):
+    def _fn(*args):
+        if args[0] is None:
+            return None
+        return fn(*args)
+
+    return _fn
+
+
+def _normalize_wav(x: Tensor):
+    return x / (x.abs().max(dim=-1, keepdim=True).values + 1e-7)
+
+
+class Enhancer(nn.Module):
+    def __init__(self, hp: HParams):
+        super().__init__()
+        self.hp = hp
+
+        n_mels = self.hp.num_mels
+        vocoder_input_dim = n_mels + self.hp.vocoder_extra_dim
+        latent_dim = self.hp.lcfm_latent_dim
+
+        self.lcfm = LCFM(
+            IRMAE(
+                input_dim=n_mels,
+                output_dim=vocoder_input_dim,
+                latent_dim=latent_dim,
+            ),
+            CFM(
+                cond_dim=n_mels,
+                output_dim=self.hp.lcfm_latent_dim,
+                solver_nfe=self.hp.cfm_solver_nfe,
+                solver_method=self.hp.cfm_solver_method,
+                time_mapping_divisor=self.hp.cfm_time_mapping_divisor,
+            ),
+            z_scale=self.hp.lcfm_z_scale,
+        )
+
+        self.lcfm.set_mode_(self.hp.lcfm_training_mode)
+
+        self.mel_fn = MelSpectrogram(hp)
+        self.vocoder = UnivNet(self.hp, vocoder_input_dim)
+        self.denoiser = load_denoiser(self.hp.denoiser_run_dir, "cpu")
+        self.normalizer = Normalizer()
+
+        self._eval_lambd = 0.0
+
+        self.dummy: Tensor
+        self.register_buffer("dummy", torch.zeros(1))
+
+        if self.hp.enhancer_stage1_run_dir is not None:
+            pretrained_path = (
+                self.hp.enhancer_stage1_run_dir
+                / "ds/G/default/mp_rank_00_model_states.pt"
+            )
+            self._load_pretrained(pretrained_path)
+
+        logger.info(f"{self.__class__.__name__} summary")
+        logger.info(f"{self.summarize()}")
+
+    def _load_pretrained(self, path):
+        # Clone is necessary as otherwise it holds a reference to the original model
+        cfm_state_dict = {k: v.clone() for k, v in self.lcfm.cfm.state_dict().items()}
+        denoiser_state_dict = {
+            k: v.clone() for k, v in self.denoiser.state_dict().items()
+        }
+        state_dict = torch.load(path, map_location="cpu")["module"]
+        self.load_state_dict(state_dict, strict=False)
+        self.lcfm.cfm.load_state_dict(cfm_state_dict)  # Reset cfm
+        self.denoiser.load_state_dict(denoiser_state_dict)  # Reset denoiser
+        logger.info(f"Loaded pretrained model from {path}")
+
+    def summarize(self):
+        npa_train = lambda m: sum(p.numel() for p in m.parameters() if p.requires_grad)
+        npa = lambda m: sum(p.numel() for p in m.parameters())
+        rows = []
+        for name, module in self.named_children():
+            rows.append(dict(name=name, trainable=npa_train(module), total=npa(module)))
+        rows.append(dict(name="total", trainable=npa_train(self), total=npa(self)))
+        df = pd.DataFrame(rows)
+        return df.to_markdown(index=False)
+
+    def to_mel(self, x: Tensor, drop_last=True):
+        """
+        Args:
+            x: (b t), wavs
+        Returns:
+            o: (b c t), mels
+        """
+        if drop_last:
+            return self.mel_fn(x)[..., :-1]  # (b d t)
+        return self.mel_fn(x)
+
+    def _may_denoise(self, x: Tensor, y: Tensor | None = None):
+        if self.hp.lcfm_training_mode == "cfm":
+            return self.denoiser(x, y)
+        return x
+
+    def configurate_(self, nfe, solver, lambd, tau):
+        """
+        Args:
+            nfe: number of function evaluations
+            solver: solver method
+            lambd: denoiser strength [0, 1]
+            tau: prior temperature [0, 1]
+        """
+        self.lcfm.cfm.solver.configurate_(nfe, solver)
+        self.lcfm.eval_tau_(tau)
+        self._eval_lambd = lambd
+
+    def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
+        """
+        Args:
+            x: (b t), mix wavs (fg + bg)
+            y: (b t), fg clean wavs
+            z: (b t), fg distorted wavs
+        Returns:
+            o: (b t), reconstructed wavs
+        """
+        assert x.dim() == 2, f"Expected (b t), got {x.size()}"
+        assert y is None or y.dim() == 2, f"Expected (b t), got {y.size()}"
+
+        if self.hp.lcfm_training_mode == "cfm":
+            self.normalizer.eval()
+
+        x = _normalize_wav(x)
+        y = _maybe(_normalize_wav)(y)
+        z = _maybe(_normalize_wav)(z)
+
+        x_mel_original = self.normalizer(self.to_mel(x), update=False)  # (b d t)
+
+        if self.hp.lcfm_training_mode == "cfm":
+            if self.training:
+                lambd = Beta(0.2, 0.2).sample(x.shape[:1]).to(x.device)
+                lambd = lambd[:, None, None]
+                x_mel_denoised = self.normalizer(
+                    self.to_mel(self._may_denoise(x, z)), update=False
+                )
+                x_mel_denoised = x_mel_denoised.detach()
+                x_mel_denoised = lambd * x_mel_denoised + (1 - lambd) * x_mel_original
+                self._visualize(x_mel_original, x_mel_denoised)
+            else:
+                lambd = self._eval_lambd
+                if lambd == 0:
+                    x_mel_denoised = x_mel_original
+                else:
+                    x_mel_denoised = self.normalizer(
+                        self.to_mel(self._may_denoise(x, z)), update=False
+                    )
+                    x_mel_denoised = x_mel_denoised.detach()
+                    x_mel_denoised = (
+                        lambd * x_mel_denoised + (1 - lambd) * x_mel_original
+                    )
+        else:
+            x_mel_denoised = x_mel_original
+
+        y_mel = _maybe(self.to_mel)(y)  # (b d t)
+        y_mel = _maybe(self.normalizer)(y_mel)
+
+        lcfm_decoded = self.lcfm(x_mel_denoised, y_mel, ψ0=x_mel_original)  # (b d t)
+
+        if lcfm_decoded is None:
+            o = None
+        else:
+            o = self.vocoder(lcfm_decoded, y)
+
+        return o
diff --git a/modules/repos_static/resemble_enhance/enhancer/hparams.py b/modules/repos_static/resemble_enhance/enhancer/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca89bea6f5d7d4ec4f543f8bde88b29dcae69f6a
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/hparams.py
@@ -0,0 +1,23 @@
+from dataclasses import dataclass
+from pathlib import Path
+
+from ..hparams import HParams as HParamsBase
+
+
+@dataclass(frozen=True)
+class HParams(HParamsBase):
+    cfm_solver_method: str = "midpoint"
+    cfm_solver_nfe: int = 64
+    cfm_time_mapping_divisor: int = 4
+    univnet_nc: int = 96
+
+    lcfm_latent_dim: int = 64
+    lcfm_training_mode: str = "ae"
+    lcfm_z_scale: float = 5
+
+    vocoder_extra_dim: int = 32
+
+    gan_training_start_step: int | None = 5_000
+    enhancer_stage1_run_dir: Path | None = None
+
+    denoiser_run_dir: Path | None = None
diff --git a/modules/repos_static/resemble_enhance/enhancer/inference.py b/modules/repos_static/resemble_enhance/enhancer/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..af57a2c7d3e5cc7b08b00f85f0135e881e50fcbe
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/inference.py
@@ -0,0 +1,48 @@
+import logging
+from functools import cache
+from pathlib import Path
+
+import torch
+
+from ..inference import inference
+from .download import download
+from .hparams import HParams
+from .enhancer import Enhancer
+
+logger = logging.getLogger(__name__)
+
+
+@cache
+def load_enhancer(run_dir: str | Path | None, device):
+    run_dir = download(run_dir)
+    hp = HParams.load(run_dir)
+    enhancer = Enhancer(hp)
+    path = run_dir / "ds" / "G" / "default" / "mp_rank_00_model_states.pt"
+    state_dict = torch.load(path, map_location="cpu")["module"]
+    enhancer.load_state_dict(state_dict)
+    enhancer.eval()
+    enhancer.to(device)
+    return enhancer
+
+
+@torch.inference_mode()
+def denoise(dwav, sr, device, run_dir=None):
+    enhancer = load_enhancer(run_dir, device)
+    return inference(model=enhancer.denoiser, dwav=dwav, sr=sr, device=device)
+
+
+@torch.inference_mode()
+def enhance(
+    dwav, sr, device, nfe=32, solver="midpoint", lambd=0.5, tau=0.5, run_dir=None
+):
+    assert 0 < nfe <= 128, f"nfe must be in (0, 128], got {nfe}"
+    assert solver in (
+        "midpoint",
+        "rk4",
+        "euler",
+    ), f"solver must be in ('midpoint', 'rk4', 'euler'), got {solver}"
+    assert 0 <= lambd <= 1, f"lambd must be in [0, 1], got {lambd}"
+    assert 0 <= tau <= 1, f"tau must be in [0, 1], got {tau}"
+    enhancer = load_enhancer(run_dir, device)
+    enhancer.configurate_(nfe=nfe, solver=solver, lambd=lambd, tau=tau)
+    return inference(model=enhancer, dwav=dwav, sr=sr, device=device)
diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eca51c6bc6b2132389ac7ec0380159169a69499
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/__init__.py
@@ -0,0 +1,2 @@
+from .irmae import IRMAE
+from .lcfm import CFM, LCFM
diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5125267b7f32e11c58e4b96bffa3ba1e96fdc4f
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py
@@ -0,0 +1,374 @@
+import logging
+from dataclasses import dataclass
+from functools import partial
+from typing import Protocol
+
+import matplotlib.pyplot as plt
+import numpy as np
+import scipy
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from tqdm import trange
+
+from .wn import WN
+
+logger = logging.getLogger(__name__)
+
+
+class VelocityField(Protocol):
+    def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
+        ...
+
+
+class Solver:
+    def __init__(
+        self,
+        method="midpoint",
+        nfe=32,
+        viz_name="solver",
+        viz_every=100,
+        mel_fn=None,
+        time_mapping_divisor=4,
+        verbose=False,
+    ):
+        self.configurate_(nfe=nfe, method=method)
+
+        self.verbose = verbose
+        self.viz_every = viz_every
+        self.viz_name = viz_name
+
+        self._camera = None
+        self._mel_fn = mel_fn
+        self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)
+
+    def configurate_(self, nfe=None, method=None):
+        if nfe is None:
+            nfe = self.nfe
+
+        if method is None:
+            method = self.method
+
+        if nfe == 1 and method in ("midpoint", "rk4"):
+            logger.warning(f"1 NFE is not supported for {method}, using euler method instead.")
+            method = "euler"
+
+        self.nfe = nfe
+        self.method = method
+
+    @property
+    def time_mapping(self):
+        return self._time_mapping
+
+    @staticmethod
+    def exponential_decay_mapping(t, n=4):
+        """
+        Args:
+            n: target step
+        """
+
+        def h(t, a):
+            return (a**t - 1) / (a - 1)
+
+        # Solve h(1/n) = 0.5
+        a = float(scipy.optimize.fsolve(lambda a: h(1 / n, a) - 0.5, x0=0))
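+        # Since h(1/n) = 0.5, the first 1/n of the uniform solver steps already
+        # cover half of the mapped time, i.e. the step size shrinks towards
+        # t = 1 (the data end of the flow).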
+
+        t = h(t, a=a)
+
+        return t
+
+    @torch.no_grad()
+    def _maybe_camera_snap(self, *, ψt, t):
+        camera = self._camera
+        if camera is not None:
+            if ψt.shape[1] == 1:
+                # Waveform, b 1 t, plot every 100th sample
+                plt.subplot(211)
+                plt.plot(ψt.detach().cpu().numpy()[0, 0, ::100], color="blue")
+                if self._mel_fn is not None:
+                    plt.subplot(212)
+                    mel = self._mel_fn(ψt.detach().cpu().numpy()[0, 0])
+                    plt.imshow(mel, origin="lower", interpolation="none")
+            elif ψt.shape[1] == 2:
+                # Complex
+                plt.subplot(121)
+                plt.imshow(
+                    ψt.detach().cpu().numpy()[0, 0],
+                    origin="lower",
+                    interpolation="none",
+                )
+                plt.subplot(122)
+                plt.imshow(
+                    ψt.detach().cpu().numpy()[0, 1],
+                    origin="lower",
+                    interpolation="none",
+                )
+            else:
+                # Spectrogram, b c t
+                plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
+            ax = plt.gca()
+            ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
+            camera.snap()
+
+    @staticmethod
+    def _euler_step(t, ψt, dt, f: VelocityField):
+        return ψt + dt * f(t=t, ψt=ψt, dt=dt)
+
+    @staticmethod
+    def _midpoint_step(t, ψt, dt, f: VelocityField):
+        return ψt + dt * f(t=t + dt / 2, ψt=ψt + dt * f(t=t, ψt=ψt, dt=dt) / 2, dt=dt)
+
+    @staticmethod
+    def _rk4_step(t, ψt, dt, f: VelocityField):
+        k1 = f(t=t, ψt=ψt, dt=dt)
+        k2 = f(t=t + dt / 2, ψt=ψt + dt * k1 / 2, dt=dt)
+        k3 = f(t=t + dt / 2, ψt=ψt + dt * k2 / 2, dt=dt)
+        k4 = f(t=t + dt, ψt=ψt + dt * k3, dt=dt)
+        return ψt + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
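+    # NFE accounting: euler spends 1 velocity evaluation per step, midpoint 2
+    # and rk4 4, so n_steps (below) divides the nfe budget accordingly and all
+    # methods cost roughly the same for a given nfe.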
+
+    @property
+    def _step(self):
+        if self.method == "euler":
+            return self._euler_step
+        elif self.method == "midpoint":
+            return self._midpoint_step
+        elif self.method == "rk4":
+            return self._rk4_step
+        else:
+            raise ValueError(f"Unknown method: {self.method}")
+
+    def get_running_train_loop(self):
+        try:
+            # Lazy import
+            from ...utils.train_loop import TrainLoop
+
+            return TrainLoop.get_running_loop()
+        except ImportError:
+            return None
+
+    @property
+    def visualizing(self):
+        loop = self.get_running_train_loop()
+        if loop is None:
+            return
+        out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
+        return loop.global_step % self.viz_every == 0 and not out_path.exists()
+
+    def _reset_camera(self):
+        try:
+            from celluloid import Camera
+
+            self._camera = Camera(plt.figure())
+        except ImportError:
+            pass
+
+    def _maybe_dump_camera(self):
+        camera = self._camera
+        loop = self.get_running_train_loop()
+        if camera is not None and loop is not None:
+            animation = camera.animate()
+            out_path = loop.make_current_step_viz_path(self.viz_name, ".gif")
+            out_path.parent.mkdir(exist_ok=True, parents=True)
+            animation.save(out_path, writer="pillow", fps=4)
+            plt.close()
+            self._camera = None
+
+    @property
+    def n_steps(self):
+        n = self.nfe
+        if self.method == "euler":
+            pass
+        elif self.method == "midpoint":
+            n //= 2
+        elif self.method == "rk4":
+            n //= 4
+        else:
+            raise ValueError(f"Unknown method: {self.method}")
+        return n
+
+    def solve(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
+        ts = self._time_mapping(np.linspace(t0, t1, self.n_steps + 1))
+
+        if self.visualizing:
+            self._reset_camera()
+
+        if self.verbose:
+            steps = trange(self.n_steps, desc="CFM inference")
+        else:
+            steps = range(self.n_steps)
+
+        ψt = ψ0
+
+        for i in steps:
+            dt = ts[i + 1] - ts[i]
+            t = ts[i]
+            self._maybe_camera_snap(ψt=ψt, t=t)
+            ψt = self._step(t=t, ψt=ψt, dt=dt, f=f)
+
+        self._maybe_camera_snap(ψt=ψt, t=ts[-1])
+
+        ψ1 = ψt
+        del ψt
+
+        self._maybe_dump_camera()
+
+        return ψ1
+
+    def __call__(self, f: VelocityField, ψ0: Tensor, t0=0.0, t1=1.0):
+        return self.solve(f=f, ψ0=ψ0, t0=t0, t1=t1)
+
+
+class SinusoidalTimeEmbedding(nn.Module):
+    def __init__(self, d_embed):
+        super().__init__()
+        self.d_embed = d_embed
+        assert d_embed % 2 == 0
+
+    def forward(self, t):
+        t = t.unsqueeze(-1)  # ... 1
+        p = torch.linspace(0, 4, self.d_embed // 2).to(t)
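+        # Log-spaced frequencies from 10**0 to 10**4, a sinusoidal embedding
+        # in the spirit of transformer positional encodings.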
+        while p.dim() < t.dim():
+            p = p.unsqueeze(0)  # ... d/2
+        sin = torch.sin(t * 10**p)
+        cos = torch.cos(t * 10**p)
+        return torch.cat([sin, cos], dim=-1)
+
+
+@dataclass(eq=False)
+class CFM(nn.Module):
+    """
+    This mixin is for general diffusion models.
+
+    ψ0 stands for the gaussian noise, and ψ1 is the data point.
+
+    Here we follow the CFM style:
+        The generation process (reverse process) is from t=0 to t=1.
+        The forward process is from t=1 to t=0.
+    """
+
+    cond_dim: int
+    output_dim: int
+    time_emb_dim: int = 128
+    viz_name: str = "cfm"
+    solver_nfe: int = 32
+    solver_method: str = "midpoint"
+    time_mapping_divisor: int = 4
+
+    def __post_init__(self):
+        super().__init__()
+        self.solver = Solver(
+            viz_name=self.viz_name,
+            viz_every=1,
+            nfe=self.solver_nfe,
+            method=self.solver_method,
+            time_mapping_divisor=self.time_mapping_divisor,
+        )
+        self.emb = SinusoidalTimeEmbedding(self.time_emb_dim)
+        self.net = WN(
+            input_dim=self.output_dim,
+            output_dim=self.output_dim,
+            local_dim=self.cond_dim,
+            global_dim=self.time_emb_dim,
+        )
+
+    def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
+        """
+        Perturb ψ1 to ψt.
+        """
+        raise NotImplementedError
+
+    def _sample_ψ0(self, x: Tensor):
+        """
+        Args:
+            x: (b c t), which implies the shape of ψ0
+        """
+        shape = list(x.shape)
+        shape[1] = self.output_dim
+        if self.training:
+            g = None
+        else:
+            g = torch.Generator(device=x.device)
+            g.manual_seed(0)  # deterministic sampling during eval
+        ψ0 = torch.randn(shape, device=x.device, dtype=x.dtype, generator=g)
+        return ψ0
+
+    @property
+    def sigma(self):
+        return 1e-4
+
+    def _to_ψt(self, *, ψ1: Tensor, ψ0: Tensor, t: Tensor):
+        """
+        Eq (22): conditional path ψt = t ψ1 + (1 - t) ψ0 + σ ε, with σ a small smoothing noise.
+        """
+        while t.dim() < ψ1.dim():
+            t = t.unsqueeze(-1)
+        μ = t * ψ1 + (1 - t) * ψ0
+        return μ + torch.randn_like(μ) * self.sigma
+
+    def _to_u(self, *, ψ1, ψ0: Tensor):
+        """
+        Eq (21): target velocity u = ψ1 - ψ0 of the straight conditional path.
+        """
+        return ψ1 - ψ0
+
+    def _to_v(self, *, ψt, x, t: float | Tensor):
+        """
+        Args:
+            ψt: (b c t)
+            x: (b c t)
+            t: (b)
+        Returns:
+            v: (b c t)
+        """
+        if isinstance(t, (float, int)):
+            t = torch.full(ψt.shape[:1], t).to(ψt)
+        t = t.clamp(0, 1)  # keep t in [0, 1]
+        g = self.emb(t)  # (b d)
+        v = self.net(ψt, l=x, g=g)
+        return v
+
+    def compute_losses(self, x, y, ψ0) -> dict:
+        """
+        Args:
+            x: (b c t)
+            y: (b c t)
+        Returns:
+            losses: dict
+        """
+        t = torch.rand(len(x), device=x.device, dtype=x.dtype)
+        t = self.solver.time_mapping(t)
+
+        if ψ0 is None:
+            ψ0 = self._sample_ψ0(x)
+
+        ψt = self._to_ψt(ψ1=y, t=t, ψ0=ψ0)
+
+        v = self._to_v(ψt=ψt, t=t, x=x)
+        u = self._to_u(ψ1=y, ψ0=ψ0)
+
+        losses = dict(l1=F.l1_loss(v, u))
+
+        return losses
+
+    @torch.inference_mode()
+    def sample(self, x, ψ0=None, t0=0.0):
+        """
+        Args:
+            x: (b c t)
+        Returns:
+            y: (b ... t)
+        """
+        if ψ0 is None:
+            ψ0 = self._sample_ψ0(x)
+        f = lambda t, ψt, dt: self._to_v(ψt=ψt, t=t, x=x)
+        ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
+        return ψ1
+
+    def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
+        if y is None:
+            y = self.sample(x, ψ0=ψ0, t0=t0)
+        else:
+            self.losses = self.compute_losses(x, y, ψ0=ψ0)
+        return y
diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71ab0cd8b9f07c3c27ca3877ee79b6510445d1f
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py
@@ -0,0 +1,123 @@
+import logging
+from dataclasses import dataclass
+
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.utils.parametrizations import weight_norm
+
+from ...common import Normalizer
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class IRMAEOutput:
+    latent: Tensor  # latent vector
+    decoded: Tensor | None  # decoder output; None when decoding is skipped
+
+
+class ResBlock(nn.Sequential):
+    def __init__(self, channels, dilations=[1, 2, 4, 8]):
+        wn = weight_norm
+        super().__init__(
+            nn.GroupNorm(32, channels),
+            nn.GELU(),
+            wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[0])),
+            nn.GroupNorm(32, channels),
+            nn.GELU(),
+            wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[1])),
+            nn.GroupNorm(32, channels),
+            nn.GELU(),
+            wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[2])),
+            nn.GroupNorm(32, channels),
+            nn.GELU(),
+            wn(nn.Conv1d(channels, channels, 3, padding="same", dilation=dilations[3])),
+        )
+
+    def forward(self, x: Tensor):
+        return x + super().forward(x)
+
+
+class IRMAE(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        latent_dim,
+        hidden_dim=1024,
+        num_irms=4,
+    ):
+        """
+        Args:
+            input_dim: input dimension
+            output_dim: output dimension
+            latent_dim: latent dimension
+            hidden_dim: hidden layer dimension
+            num_irms: number of implicit rank minimization (bias-free 1x1 conv) layers
+        """
+        super().__init__()
+        self.input_dim = input_dim
+
+        self.encoder = nn.Sequential(
+            nn.Conv1d(input_dim, hidden_dim, 3, padding="same"),
+            *[ResBlock(hidden_dim) for _ in range(4)],
+            # Try to obtain compact representation (https://proceedings.neurips.cc/paper/2020/file/a9078e8653368c9c291ae2f8b74012e7-Paper.pdf)
+            *[nn.Conv1d(hidden_dim if i == 0 else latent_dim, latent_dim, 1, bias=False) for i in range(num_irms)],
+            nn.Tanh(),
+        )
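+        # The stacked bias-free 1x1 convs above are the IRMAE trick: composing
+        # several linear maps implicitly minimizes the rank of the latent code
+        # (see the paper linked above), yielding a more compact representation.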
+
+        self.decoder = nn.Sequential(
+            nn.Conv1d(latent_dim, hidden_dim, 3, padding="same"),
+            *[ResBlock(hidden_dim) for _ in range(4)],
+            nn.Conv1d(hidden_dim, output_dim, 1),
+        )
+
+        self.head = nn.Sequential(
+            nn.Conv1d(output_dim, hidden_dim, 3, padding="same"),
+            nn.GELU(),
+            nn.Conv1d(hidden_dim, input_dim, 1),
+        )
+
+        self.estimator = Normalizer()
+
+    def encode(self, x):
+        """
+        Args:
+            x: (b c t) tensor
+        """
+        z = self.encoder(x)  # (b c t)
+        _ = self.estimator(z)  # Estimate the global mean and std of z
+        self.stats = {}
+        self.stats["z_mean"] = z.mean().item()
+        self.stats["z_std"] = z.std().item()
+        self.stats["z_abs_68"] = z.abs().quantile(0.6827).item()
+        self.stats["z_abs_95"] = z.abs().quantile(0.9545).item()
+        self.stats["z_abs_99"] = z.abs().quantile(0.9973).item()
+        return z
+
+    def decode(self, z):
+        """
+        Args:
+            z: (b c t) tensor
+        """
+        return self.decoder(z)
+
+    def forward(self, x, skip_decoding=False):
+        """
+        Args:
+            x: (b c t) tensor
+            skip_decoding: if True, skip the decoding step
+        """
+        z = self.encode(x)  # q(z|x)
+
+        if skip_decoding:
+            # This speeds up training in CFM-only mode
+            decoded = None
+        else:
+            decoded = self.decode(z)  # p(x|z)
+            predicted = self.head(decoded)
+            self.losses = dict(mse=F.mse_loss(predicted, x))
+
+        return IRMAEOutput(latent=z, decoded=decoded)
diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c2f5f88718e2f42f82e2f4714ea510b4677b450
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py
@@ -0,0 +1,152 @@
+import logging
+from enum import Enum
+
+import matplotlib.pyplot as plt
+import torch
+from torch import Tensor, nn
+
+from .cfm import CFM
+from .irmae import IRMAE, IRMAEOutput
+
+logger = logging.getLogger(__name__)
+
+
+def freeze_(module):
+    for p in module.parameters():
+        p.requires_grad_(False)
+
+
+class LCFM(nn.Module):
+    class Mode(Enum):
+        AE = "ae"
+        CFM = "cfm"
+
+    def __init__(self, ae: IRMAE, cfm: CFM, z_scale: float = 1.0):
+        super().__init__()
+        self.ae = ae
+        self.cfm = cfm
+        self.z_scale = z_scale
+        self._mode = None
+        self._eval_tau = 0.5
+
+    @property
+    def mode(self):
+        return self._mode
+
+    def set_mode_(self, mode):
+        mode = self.Mode(mode)
+        self._mode = mode
+
+        if mode == self.Mode.AE:
+            freeze_(self.cfm)
+            logger.info("Freeze cfm")
+        elif mode == self.Mode.CFM:
+            freeze_(self.ae)
+            logger.info("Freeze ae (encoder and decoder)")
+        else:
+            raise ValueError(f"Unknown training mode: {mode}")
+
+    def get_running_train_loop(self):
+        try:
+            # Lazy import
+            from ...utils.train_loop import TrainLoop
+
+            return TrainLoop.get_running_loop()
+        except ImportError:
+            return None
+
+    @property
+    def global_step(self):
+        loop = self.get_running_train_loop()
+        if loop is None:
+            return None
+        return loop.global_step
+
+    @torch.no_grad()
+    def _visualize(self, x, y, y_):
+        loop = self.get_running_train_loop()
+        if loop is None:
+            return
+
+        plt.subplot(221)
+        plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
+        plt.title("GT")
+
+        plt.subplot(222)
+        y_ = y_[:, : y.shape[1]]
+        plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
+        plt.title("Posterior")
+
+        plt.subplot(223)
+        z_ = self.cfm(x)
+        y__ = self.ae.decode(z_)
+        y__ = y__[:, : y.shape[1]]
+        plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
+        plt.title("C-Prior")
+        del y__
+
+        plt.subplot(224)
+        z_ = torch.randn_like(z_)
+        y__ = self.ae.decode(z_)
+        y__ = y__[:, : y.shape[1]]
+        plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
+        plt.title("Prior")
+        del z_, y__
+
+        path = loop.make_current_step_viz_path("recon", ".png")
+        path.parent.mkdir(exist_ok=True, parents=True)
+        plt.tight_layout()
+        plt.savefig(path, dpi=500)
+        plt.close()
+
+    def _scale(self, z: Tensor):
+        return z * self.z_scale
+
+    def _unscale(self, z: Tensor):
+        return z / self.z_scale
+
+    def eval_tau_(self, tau):
+        self._eval_tau = tau
+
+    def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None):
+        """
+        Args:
+            x: (b d t), condition mel
+            y: (b d t), target mel
+            ψ0: (b d t), starting mel
+        """
+        if self.mode == self.Mode.CFM:
+            self.ae.eval()  # Always set to eval when training cfm
+
+        if ψ0 is not None:
+            ψ0 = self._scale(self.ae.encode(ψ0))
+            if self.training:
+                tau = torch.rand_like(ψ0[:, :1, :1])
+            else:
+                tau = self._eval_tau
+            ψ0 = tau * torch.randn_like(ψ0) + (1 - tau) * ψ0
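+            # tau blends the encoded start latent with noise: tau = 1 starts
+            # from the pure prior (ψ0 is ignored), tau = 0 keeps the encoded
+            # ψ0 unchanged.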
+
+        if y is None:
+            if self.mode == self.Mode.AE:
+                with torch.no_grad():
+                    training = self.ae.training
+                    self.ae.eval()
+                    z = self.ae.encode(x)
+                    self.ae.train(training)
+            else:
+                z = self._unscale(self.cfm(x, ψ0=ψ0))
+
+            h = self.ae.decode(z)
+        else:
+            ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)
+
+            if self.mode == self.Mode.CFM:
+                _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
+
+            h = ae_output.decoded
+
+            if h is not None and self.global_step is not None and self.global_step % 100 == 0:
+                self._visualize(x[:1], y[:1], h[:1])
+
+        return h
diff --git a/modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py b/modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bde173c205bb74f30ed95a9f013b3eb5b2abe5a
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/lcfm/wn.py
@@ -0,0 +1,147 @@
+import logging
+import math
+
+import torch
+import torch.nn as nn
+
+logger = logging.getLogger(__name__)
+
+
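+# WaveNet-style gated activation unit: split the channels in half and gate
+# tanh(a) with sigmoid(b).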
+@torch.jit.script
+def _fused_tanh_sigmoid(h):
+    a, b = h.chunk(2, dim=1)
+    h = a.tanh() * b.sigmoid()
+    return h
+
+
+class WNLayer(nn.Module):
+    """
+    A DiffWave-like WN
+    """
+
+    def __init__(self, hidden_dim, local_dim, global_dim, kernel_size, dilation):
+        super().__init__()
+
+        local_output_dim = hidden_dim * 2
+
+        if global_dim is not None:
+            self.gconv = nn.Conv1d(global_dim, hidden_dim, 1)
+
+        if local_dim is not None:
+            self.lconv = nn.Conv1d(local_dim, local_output_dim, 1)
+
+        self.dconv = nn.Conv1d(hidden_dim, local_output_dim, kernel_size, dilation=dilation, padding="same")
+
+        self.out = nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=1)
+
+    def forward(self, z, l, g):
+        identity = z
+
+        if g is not None:
+            if g.dim() == 2:
+                g = g.unsqueeze(-1)
+            z = z + self.gconv(g)
+
+        z = self.dconv(z)
+
+        if l is not None:
+            z = z + self.lconv(l)
+
+        z = _fused_tanh_sigmoid(z)
+
+        h = self.out(z)
+
+        z, s = h.chunk(2, dim=1)
+
+        o = (z + identity) / math.sqrt(2)
+
+        return o, s
+
+
+class WN(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        local_dim=None,
+        global_dim=None,
+        n_layers=30,
+        kernel_size=3,
+        dilation_cycle=5,
+        hidden_dim=512,
+    ):
+        super().__init__()
+        assert kernel_size % 2 == 1
+        assert hidden_dim % 2 == 0
+
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+        self.local_dim = local_dim
+        self.global_dim = global_dim
+
+        self.start = nn.Conv1d(input_dim, hidden_dim, 1)
+        if local_dim is not None:
+            self.local_norm = nn.InstanceNorm1d(local_dim)
+
+        self.layers = nn.ModuleList(
+            [
+                WNLayer(
+                    hidden_dim=hidden_dim,
+                    local_dim=local_dim,
+                    global_dim=global_dim,
+                    kernel_size=kernel_size,
+                    dilation=2 ** (i % dilation_cycle),
+                )
+                for i in range(n_layers)
+            ]
+        )
+
+        self.end = nn.Conv1d(hidden_dim, output_dim, 1)
+
+    def forward(self, z, l=None, g=None):
+        """
+        Args:
+            z: input (b c t)
+            l: local condition (b c t)
+            g: global condition (b d)
+        """
+        z = self.start(z)
+
+        if l is not None:
+            l = self.local_norm(l)
+
+        # Skips
+        s_list = []
+
+        for layer in self.layers:
+            z, s = layer(z, l, g)
+            s_list.append(s)
+
+        s_list = torch.stack(s_list, dim=0).sum(dim=0)
+        s_list = s_list / math.sqrt(len(self.layers))
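+        # Dividing by sqrt(n_layers) keeps the variance of the summed skip
+        # connections independent of depth (DiffWave convention).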
+
+        o = self.end(s_list)
+
+        return o
+
+    def summarize(self, length=100):
+        from ptflops import get_model_complexity_info
+
+        x = torch.randn(1, self.input_dim, length)
+
+        macs, params = get_model_complexity_info(
+            self,
+            (self.input_dim, length),
+            as_strings=True,
+            print_per_layer_stat=True,
+            verbose=True,
+        )
+
+        print(f"Input shape: {x.shape}")
+        print(f"Computational complexity: {macs}")
+        print(f"Number of parameters: {params}")
+
+
+if __name__ == "__main__":
+    model = WN(input_dim=64, output_dim=64)
+    model.summarize()
diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py b/modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4d2fea066e2e71371c6af840e759f1676380170
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/univnet/__init__.py
@@ -0,0 +1 @@
+from .univnet import UnivNet
diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..270596c8f44f9295026cf308b39151a08dbed85a
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/__init__.py
@@ -0,0 +1,5 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+#   LICENSE is in incl_licenses directory.
+
+from .filter import *
+from .resample import *
diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..5165557d7dcadcb4d07018e13562b22f8c85e91e
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/filter.py
@@ -0,0 +1,95 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+#   LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if 'sinc' in dir(torch):
+    sinc = torch.sinc
+else:
+    # This code is adapted from adefossez's julius.core.sinc under the MIT License
+    # https://adefossez.github.io/julius/julius/core.html
+    #   LICENSE is in incl_licenses directory.
+    def sinc(x: torch.Tensor):
+        """
+        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+        """
+        return torch.where(x == 0,
+                           torch.tensor(1., device=x.device, dtype=x.dtype),
+                           torch.sin(math.pi * x) / math.pi / x)
+
+
+# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+#   LICENSE is in incl_licenses directory.
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
+    even = (kernel_size % 2 == 0)
+    half_size = kernel_size // 2
+
+    # Kaiser window design: the transition width delta_f sets the stopband
+    # attenuation A, which sets the window shape parameter beta (standard
+    # Kaiser filter-design formulas).
+    delta_f = 4 * half_width
+    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+    if A > 50.:
+        beta = 0.1102 * (A - 8.7)
+    elif A >= 21.:
+        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+    else:
+        beta = 0.
+    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+    if even:
+        time = (torch.arange(-half_size, half_size) + 0.5)
+    else:
+        time = torch.arange(kernel_size) - half_size
+    if cutoff == 0:
+        filter_ = torch.zeros_like(time)
+    else:
+        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+        # Normalize filter to have sum = 1, otherwise we will have a small leakage
+        # of the constant component in the input signal.
+        filter_ /= filter_.sum()
+    filter = filter_.view(1, 1, kernel_size)
+
+    return filter
+
+
+class LowPassFilter1d(nn.Module):
+    def __init__(self,
+                 cutoff=0.5,
+                 half_width=0.6,
+                 stride: int = 1,
+                 padding: bool = True,
+                 padding_mode: str = 'replicate',
+                 kernel_size: int = 12):
+        # kernel_size should be even number for stylegan3 setup,
+        # in this implementation, odd number is also possible.
+        super().__init__()
+        if cutoff < 0.:
+            raise ValueError("Cutoff must be non-negative.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+        self.kernel_size = kernel_size
+        self.even = (kernel_size % 2 == 0)
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = stride
+        self.padding = padding
+        self.padding_mode = padding_mode
+        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+        self.register_buffer("filter", filter)
+
+    # input: [B, C, T]
+    def forward(self, x):
+        _, C, _ = x.shape
+
+        if self.padding:
+            x = F.pad(x, (self.pad_left, self.pad_right),
+                      mode=self.padding_mode)
+        out = F.conv1d(x, self.filter.expand(C, -1, -1),
+                       stride=self.stride, groups=C)
+
+        return out
diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fc6e12a9dbaa9ac41bd349b7f1797442052e4f6
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/univnet/alias_free_torch/resample.py
@@ -0,0 +1,49 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+#   LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from torch.nn import functional as F
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
+
+
+class UpSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        self.stride = ratio
+        self.pad = self.kernel_size // ratio - 1
+        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
+                                      half_width=0.6 / ratio,
+                                      kernel_size=self.kernel_size)
+        self.register_buffer("filter", filter)
+
+    # x: [B, C, T]
+    def forward(self, x):
+        _, C, _ = x.shape
+
+        x = F.pad(x, (self.pad, self.pad), mode='replicate')
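+        # Zero-stuffing upsampling: conv_transpose1d with stride=ratio inserts
+        # zeros between samples, the Kaiser-windowed sinc filter removes the
+        # spectral images, and multiplying by `ratio` restores the gain lost
+        # to the inserted zeros.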
+        x = self.ratio * F.conv_transpose1d(
+            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+        x = x[..., self.pad_left:-self.pad_right]
+
+        return x
+
+
+class DownSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
+                                       half_width=0.6 / ratio,
+                                       stride=ratio,
+                                       kernel_size=self.kernel_size)
+
+    def forward(self, x):
+        xx = self.lowpass(x)
+
+        return xx
diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/amp.py b/modules/repos_static/resemble_enhance/enhancer/univnet/amp.py
new file mode 100644
index 0000000000000000000000000000000000000000..469026338771408a24253ae52c8f2f22a6057475
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/univnet/amp.py
@@ -0,0 +1,101 @@
+# Adapted from https://github.com/NVIDIA/BigVGAN
+
+import math
+
+import torch
+import torch.nn as nn
+from torch import nn
+from torch.nn.utils.parametrizations import weight_norm
+
+from .alias_free_torch import DownSample1d, UpSample1d
+
+
+class SnakeBeta(nn.Module):
+    """
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version of Snake, based on the
+          paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+          https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = SnakeBeta(256)
+        >>> x = torch.randn(2, 256, 100)  # (B, C, T)
+        >>> y = a1(x)
+    """
+
+    def __init__(self, in_features, alpha=1.0, clamp=(1e-2, 50)):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            beta is initialized to 1 by default, higher values = higher-magnitude.
+            alpha will be trained along with the rest of your model.
+        """
+        super().__init__()
+        self.in_features = in_features
+        self.log_alpha = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
+        self.log_beta = nn.Parameter(torch.zeros(in_features) + math.log(alpha))
+        self.clamp = clamp
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta ∶= x + 1/b * sin^2 (xa)
+        """
+        alpha = self.log_alpha.exp().clamp(*self.clamp)
+        alpha = alpha[None, :, None]
+
+        beta = self.log_beta.exp().clamp(*self.clamp)
+        beta = beta[None, :, None]
+
+        x = x + (1.0 / beta) * (x * alpha).sin().pow(2)
+
+        return x
+
+
+class UpActDown(nn.Module):
+    def __init__(
+        self,
+        act,
+        up_ratio: int = 2,
+        down_ratio: int = 2,
+        up_kernel_size: int = 12,
+        down_kernel_size: int = 12,
+    ):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = act
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+    def forward(self, x):
+        # x: [B,C,T]
+        x = self.upsample(x)
+        x = self.act(x)
+        x = self.downsample(x)
+        return x
+
+
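+# AMP (anti-aliased multi-periodicity) block in the style of BigVGAN: each
+# layer is conv -> anti-aliased SnakeBeta activation (upsample, activate,
+# downsample) -> conv, all wrapped in a residual connection.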
+class AMPBlock(nn.Sequential):
+    def __init__(self, channels, *, kernel_size=3, dilations=(1, 3, 5)):
+        super().__init__(*(self._make_layer(channels, kernel_size, d) for d in dilations))
+
+    def _make_layer(self, channels, kernel_size, dilation):
+        return nn.Sequential(
+            weight_norm(nn.Conv1d(channels, channels, kernel_size, dilation=dilation, padding="same")),
+            UpActDown(act=SnakeBeta(channels)),
+            weight_norm(nn.Conv1d(channels, channels, kernel_size, padding="same")),
+        )
+
+    def forward(self, x):
+        return x + super().forward(x)
diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py b/modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef3bd2552ea6f7f654c72737e079ce3239835d68
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/univnet/discriminator.py
@@ -0,0 +1,210 @@
+import logging
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.utils.parametrizations import weight_norm
+
+from ..hparams import HParams
+from .mrstft import get_stft_cfgs
+
+logger = logging.getLogger(__name__)
+
+
+class PeriodNetwork(nn.Module):
+    def __init__(self, period):
+        super().__init__()
+        self.period = period
+        wn = weight_norm
+        self.convs = nn.ModuleList(
+            [
+                wn(nn.Conv2d(1, 64, (5, 1), (3, 1), padding=(2, 0))),
+                wn(nn.Conv2d(64, 128, (5, 1), (3, 1), padding=(2, 0))),
+                wn(nn.Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
+                wn(nn.Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
+                wn(nn.Conv2d(512, 1024, (5, 1), 1, padding=(2, 0))),
+            ]
+        )
+        self.conv_post = wn(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        """
+        Args:
+            x: [B, 1, T]
+        """
+        assert x.dim() == 3, f"(B, 1, T) is expected, but got {x.shape}."
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
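+        # After the reshape, dim 2 indexes samples that are `period` apart, so
+        # the (5, 1) kernels below convolve across same-phase samples of each
+        # period (the multi-period discriminator idea from HiFi-GAN).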
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.2)
+        x = self.conv_post(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x
+
+
+class SpecNetwork(nn.Module):
+    def __init__(self, stft_cfg: dict):
+        super().__init__()
+        wn = weight_norm
+        self.stft_cfg = stft_cfg
+        self.convs = nn.ModuleList(
+            [
+                wn(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
+                wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+                wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+                wn(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
+                wn(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
+            ]
+        )
+        self.conv_post = wn(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))
+
+    def forward(self, x):
+        """
+        Args:
+            x: [B, 1, T]
+        """
+        x = self.spectrogram(x)
+        x = x.unsqueeze(1)
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, 0.2)
+        x = self.conv_post(x)
+        x = x.flatten(1, -1)
+        return x
+
+    def spectrogram(self, x):
+        """
+        Args:
+            x: [B, 1, T]
+        """
+        x = x.squeeze(1)
+        dtype = x.dtype
+        stft_cfg = dict(self.stft_cfg)
+        x = torch.stft(x.float(), center=False, return_complex=False, **stft_cfg)
+        mag = x.norm(p=2, dim=-1)  # [B, F, TT]
+        mag = mag.to(dtype)  # [B, F, TT]
+        return mag
+
+
+class MD(nn.ModuleList):
+    def __init__(self, l: list):
+        super().__init__([self._create_network(x) for x in l])
+        self._loss_type = None
+
+    def loss_type_(self, loss_type):
+        self._loss_type = loss_type
+
+    def _create_network(self, _):
+        raise NotImplementedError
+
+    def _forward_each(self, d, x, y):
+        assert self._loss_type is not None, "loss_type is not set."
+        loss_type = self._loss_type
+
+        if loss_type == "hinge":
+            if y == 0:
+                # d(x) should be small -> -1
+                loss = F.relu(1 + d(x)).mean()
+            elif y == 1:
+                # d(x) should be large -> 1
+                loss = F.relu(1 - d(x)).mean()
+            else:
+                raise ValueError(f"Invalid y: {y}")
+        elif loss_type == "wgan":
+            if y == 0:
+                loss = d(x).mean()
+            elif y == 1:
+                loss = -d(x).mean()
+            else:
+                raise ValueError(f"Invalid y: {y}")
+        else:
+            raise ValueError(f"Invalid loss_type: {loss_type}")
+
+        return loss
+
+    def forward(self, x, y) -> Tensor:
+        losses = [self._forward_each(d, x, y) for d in self]
+        return torch.stack(losses).mean()
+
+
+class MPD(MD):
+    def __init__(self):
+        super().__init__([2, 3, 7, 13, 17])
+
+    def _create_network(self, period):
+        return PeriodNetwork(period)
+
+
+class MRD(MD):
+    def __init__(self, stft_cfgs):
+        super().__init__(stft_cfgs)
+
+    def _create_network(self, stft_cfg):
+        return SpecNetwork(stft_cfg)
+
+
+class Discriminator(nn.Module):
+    @property
+    def wav_rate(self):
+        return self.hp.wav_rate
+
+    def __init__(self, hp: HParams):
+        super().__init__()
+        self.hp = hp
+        self.stft_cfgs = get_stft_cfgs(hp)
+        self.mpd = MPD()
+        self.mrd = MRD(self.stft_cfgs)
+        self.dummy_float: Tensor
+        self.register_buffer("dummy_float", torch.zeros(0), persistent=False)
+
+    def loss_type_(self, loss_type):
+        self.mpd.loss_type_(loss_type)
+        self.mrd.loss_type_(loss_type)
+
+    def forward(self, fake, real=None):
+        """
+        Args:
+            fake: [B T]
+            real: [B T]
+        """
+        fake = fake.to(self.dummy_float)
+
+        if real is None:
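+            # No reference signal: this is the generator update, which uses
+            # the WGAN-style loss (-d(fake).mean()); discriminator updates
+            # take the hinge branch below.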
+            self.loss_type_("wgan")
+        else:
+            length_difference = (fake.shape[-1] - real.shape[-1]) / real.shape[-1]
+            assert length_difference < 0.05, f"length difference {length_difference:.1%} should be smaller than 5%"
+
+            self.loss_type_("hinge")
+            real = real.to(self.dummy_float)
+
+            fake = fake[..., : real.shape[-1]]
+            real = real[..., : fake.shape[-1]]
+
+        losses = {}
+
+        assert fake.dim() == 2, f"(B, T) is expected, but got {fake.shape}."
+        assert real is None or real.dim() == 2, f"(B, T) is expected, but got {real.shape}."
+
+        fake = fake.unsqueeze(1)
+
+        if real is None:
+            losses["mpd"] = self.mpd(fake, 1)
+            losses["mrd"] = self.mrd(fake, 1)
+        else:
+            real = real.unsqueeze(1)
+            losses["mpd_fake"] = self.mpd(fake, 0)
+            losses["mpd_real"] = self.mpd(real, 1)
+            losses["mrd_fake"] = self.mrd(fake, 0)
+            losses["mrd_real"] = self.mrd(real, 1)
+
+        return losses
diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py b/modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..da56619090206c45fece0bc2c70e8fd3d2513704
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/univnet/lvcnet.py
@@ -0,0 +1,281 @@
+""" refer from https://github.com/zceng/LVCNet """
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.parametrizations import weight_norm
+
+from .amp import AMPBlock
+
+
+class KernelPredictor(torch.nn.Module):
+    """Kernel predictor for the location-variable convolutions"""
+
+    def __init__(
+        self,
+        cond_channels,
+        conv_in_channels,
+        conv_out_channels,
+        conv_layers,
+        conv_kernel_size=3,
+        kpnet_hidden_channels=64,
+        kpnet_conv_size=3,
+        kpnet_dropout=0.0,
+        kpnet_nonlinear_activation="LeakyReLU",
+        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
+    ):
+        """
+        Args:
+            cond_channels (int): number of channels of the conditioning sequence
+            conv_in_channels (int): number of channels of the input sequence
+            conv_out_channels (int): number of channels of the output sequence
+            conv_layers (int): number of layers
+        """
+        super().__init__()
+
+        self.conv_in_channels = conv_in_channels
+        self.conv_out_channels = conv_out_channels
+        self.conv_kernel_size = conv_kernel_size
+        self.conv_layers = conv_layers
+
+        kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers  # l_w
+        kpnet_bias_channels = conv_out_channels * conv_layers  # l_b
+
+        self.input_conv = nn.Sequential(
+            weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)),
+            getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
+        )
+
+        self.residual_convs = nn.ModuleList()
+        padding = (kpnet_conv_size - 1) // 2
+        for _ in range(3):
+            self.residual_convs.append(
+                nn.Sequential(
+                    nn.Dropout(kpnet_dropout),
+                    weight_norm(
+                        nn.Conv1d(
+                            kpnet_hidden_channels,
+                            kpnet_hidden_channels,
+                            kpnet_conv_size,
+                            padding=padding,
+                            bias=True,
+                        )
+                    ),
+                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
+                    weight_norm(
+                        nn.Conv1d(
+                            kpnet_hidden_channels,
+                            kpnet_hidden_channels,
+                            kpnet_conv_size,
+                            padding=padding,
+                            bias=True,
+                        )
+                    ),
+                    getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
+                )
+            )
+        self.kernel_conv = weight_norm(
+            nn.Conv1d(
+                kpnet_hidden_channels,
+                kpnet_kernel_channels,
+                kpnet_conv_size,
+                padding=padding,
+                bias=True,
+            )
+        )
+        self.bias_conv = weight_norm(
+            nn.Conv1d(
+                kpnet_hidden_channels,
+                kpnet_bias_channels,
+                kpnet_conv_size,
+                padding=padding,
+                bias=True,
+            )
+        )
+
+    def forward(self, c):
+        """
+        Args:
+            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
+        """
+        batch, _, cond_length = c.shape
+        c = self.input_conv(c)
+        for residual_conv in self.residual_convs:
+            residual_conv.to(c.device)
+            c = c + residual_conv(c)
+        k = self.kernel_conv(c)
+        b = self.bias_conv(c)
+        kernels = k.contiguous().view(
+            batch,
+            self.conv_layers,
+            self.conv_in_channels,
+            self.conv_out_channels,
+            self.conv_kernel_size,
+            cond_length,
+        )
+        bias = b.contiguous().view(
+            batch,
+            self.conv_layers,
+            self.conv_out_channels,
+            cond_length,
+        )
+
+        return kernels, bias
+
+
+class LVCBlock(torch.nn.Module):
+    """the location-variable convolutions"""
+
+    def __init__(
+        self,
+        in_channels,
+        cond_channels,
+        stride,
+        dilations=[1, 3, 9, 27],
+        lReLU_slope=0.2,
+        conv_kernel_size=3,
+        cond_hop_length=256,
+        kpnet_hidden_channels=64,
+        kpnet_conv_size=3,
+        kpnet_dropout=0.0,
+        add_extra_noise=False,
+        downsampling=False,
+    ):
+        super().__init__()
+
+        self.add_extra_noise = add_extra_noise
+
+        self.cond_hop_length = cond_hop_length
+        self.conv_layers = len(dilations)
+        self.conv_kernel_size = conv_kernel_size
+
+        self.kernel_predictor = KernelPredictor(
+            cond_channels=cond_channels,
+            conv_in_channels=in_channels,
+            conv_out_channels=2 * in_channels,
+            conv_layers=len(dilations),
+            conv_kernel_size=conv_kernel_size,
+            kpnet_hidden_channels=kpnet_hidden_channels,
+            kpnet_conv_size=kpnet_conv_size,
+            kpnet_dropout=kpnet_dropout,
+            kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope},
+        )
+
+        if downsampling:
+            self.convt_pre = nn.Sequential(
+                nn.LeakyReLU(lReLU_slope),
+                weight_norm(nn.Conv1d(in_channels, in_channels, 2 * stride + 1, padding="same")),
+                nn.AvgPool1d(stride, stride),
+            )
+        else:
+            if stride == 1:
+                self.convt_pre = nn.Sequential(
+                    nn.LeakyReLU(lReLU_slope),
+                    weight_norm(nn.Conv1d(in_channels, in_channels, 1)),
+                )
+            else:
+                self.convt_pre = nn.Sequential(
+                    nn.LeakyReLU(lReLU_slope),
+                    weight_norm(
+                        nn.ConvTranspose1d(
+                            in_channels,
+                            in_channels,
+                            2 * stride,
+                            stride=stride,
+                            padding=stride // 2 + stride % 2,
+                            output_padding=stride % 2,
+                        )
+                    ),
+                )
+
+        self.amp_block = AMPBlock(in_channels)
+
+        self.conv_blocks = nn.ModuleList()
+        for d in dilations:
+            self.conv_blocks.append(
+                nn.Sequential(
+                    nn.LeakyReLU(lReLU_slope),
+                    weight_norm(nn.Conv1d(in_channels, in_channels, conv_kernel_size, dilation=d, padding="same")),
+                    nn.LeakyReLU(lReLU_slope),
+                )
+            )
+
+    def forward(self, x, c):
+        """forward propagation of the location-variable convolutions.
+        Args:
+            x (Tensor): the input sequence (batch, in_channels, in_length)
+            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
+
+        Returns:
+            Tensor: the output sequence (batch, in_channels, in_length)
+        """
+        _, in_channels, _ = x.shape  # (B, c_g, L')
+
+        x = self.convt_pre(x)  # (B, c_g, stride * L')
+
+        # Add one amp block just after the upsampling
+        x = self.amp_block(x)  # (B, c_g, stride * L')
+
+        kernels, bias = self.kernel_predictor(c)
+
+        if self.add_extra_noise:
+            # Add extra noise to part of the feature
+            a, b = x.chunk(2, dim=1)
+            b = b + torch.randn_like(b) * 0.1
+            x = torch.cat([a, b], dim=1)
+
+        for i, conv in enumerate(self.conv_blocks):
+            output = conv(x)  # (B, c_g, stride * L')
+
+            k = kernels[:, i, :, :, :, :]  # (B, 2 * c_g, c_g, kernel_size, cond_length)
+            b = bias[:, i, :, :]  # (B, 2 * c_g, cond_length)
+
+            output = self.location_variable_convolution(
+                output, k, b, hop_size=self.cond_hop_length
+            )  # (B, 2 * c_g, stride * L'): LVC
+            x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh(
+                output[:, in_channels:, :]
+            )  # (B, c_g, stride * L'): GAU
+
+        return x
+
+    def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256):
+        """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl.
+        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.
+        Args:
+            x (Tensor): the input sequence (batch, in_channels, in_length).
+            kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length)
+            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
+            dilation (int): the dilation of convolution.
+            hop_size (int): the hop_size of the conditioning sequence.
+        Returns:
+            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
+        """
+        batch, _, in_length = x.shape
+        batch, _, out_channels, kernel_size, kernel_length = kernel.shape
+
+        assert in_length == (
+            kernel_length * hop_size
+        ), f"length of (x, kernel) is not matched, {in_length} != {kernel_length} * {hop_size}"
+
+        padding = dilation * int((kernel_size - 1) / 2)
+        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
+        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)
+
+        if hop_size < dilation:
+            x = F.pad(x, (0, dilation), "constant", 0)
+        x = x.unfold(
+            3, dilation, dilation
+        )  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
+        x = x[:, :, :, :, :hop_size]
+        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
+        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)
+
+        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
+        o = o.to(memory_format=torch.channels_last_3d)
+        bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d)
+        o = o + bias
+        o = o.contiguous().view(batch, out_channels, -1)
+
+        return o
diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py b/modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce95b43269c17ff05736bc338220e59345524309
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/univnet/mrstft.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+#  MIT License (https://opensource.org/licenses/MIT)
+
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..hparams import HParams
+
+
+def _make_stft_cfg(hop_length, win_length=None):
+    if win_length is None:
+        win_length = 4 * hop_length
+    n_fft = 2 ** (win_length - 1).bit_length()
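+    # e.g. hop_length=100 -> win_length=400 -> n_fft=512; hop_length=256 ->
+    # win_length=1024 -> n_fft=1024 (next power of two of win_length).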
+    return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
+
+
+def get_stft_cfgs(hp: HParams):
+    assert hp.wav_rate == 44100, f"wav_rate must be 44100, got {hp.wav_rate}"
+    return [_make_stft_cfg(h) for h in (100, 256, 512)]
+
+
+def stft(x, n_fft, hop_length, win_length, window):
+    dtype = x.dtype
+    x = torch.stft(x.float(), n_fft, hop_length, win_length, window, return_complex=True)
+    x = x.abs().to(dtype)
+    x = x.transpose(2, 1)  # (b f t) -> (b t f)
+    return x
+
+
+class SpectralConvergenceLoss(nn.Module):
+    def forward(self, x_mag, y_mag):
+        """Calculate forward propagation.
+        Args:
+            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+        Returns:
+            Tensor: Spectral convergence loss value.
+        """
+        return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
+
+
+class LogSTFTMagnitudeLoss(nn.Module):
+    def forward(self, x_mag, y_mag):
+        """Calculate forward propagation.
+        Args:
+            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
+            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
+        Returns:
+            Tensor: Log STFT magnitude loss value.
+        """
+        return F.l1_loss(torch.log1p(x_mag), torch.log1p(y_mag))
+
+
+class STFTLoss(nn.Module):
+    def __init__(self, hp, stft_cfg: dict, window="hann_window"):
+        super().__init__()
+        self.hp = hp
+        self.stft_cfg = stft_cfg
+        self.spectral_convergence_loss = SpectralConvergenceLoss()
+        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+        self.register_buffer("window", getattr(torch, window)(stft_cfg["win_length"]), persistent=False)
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+        Returns:
+            dict: spectral convergence ("sc") and log STFT magnitude ("mag") losses.
+        """
+        stft_cfg = dict(self.stft_cfg)
+        x_mag = stft(x, **stft_cfg, window=self.window)  # (b t) -> (b t f)
+        y_mag = stft(y, **stft_cfg, window=self.window)
+        sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
+        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+        return dict(sc=sc_loss, mag=mag_loss)
+
+
+class MRSTFTLoss(nn.Module):
+    def __init__(self, hp: HParams, window="hann_window"):
+        """Initialize Multi resolution STFT loss module.
+        Args:
+            hp (HParams): hyperparameters from which the STFT resolutions are derived.
+            window (str): Window function type.
+        """
+        super().__init__()
+        stft_cfgs = get_stft_cfgs(hp)
+        self.stft_losses = nn.ModuleList()
+        self.hp = hp
+        for c in stft_cfgs:
+            self.stft_losses += [STFTLoss(hp, c, window=window)]
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+        Args:
+            x (Tensor): Predicted signal (b t).
+            y (Tensor): Groundtruth signal (b t).
+        Returns:
+            dict: multi-resolution "sc" and "mag" losses, each averaged over resolutions.
+        """
+        assert x.dim() == 2 and y.dim() == 2, f"(b t) is expected, but got {x.shape} and {y.shape}."
+
+        dtype = x.dtype
+
+        x = x.float()
+        y = y.float()
+
+        # Align length
+        x = x[..., : y.shape[-1]]
+        y = y[..., : x.shape[-1]]
+
+        losses = {}
+
+        for f in self.stft_losses:
+            d = f(x, y)
+            for k, v in d.items():
+                losses.setdefault(k, []).append(v)
+
+        for k, v in losses.items():
+            losses[k] = torch.stack(v, dim=0).mean().to(dtype)
+
+        return losses
diff --git a/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py b/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb20217f048f398236698f6a38927310d0c1ba9b
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py
@@ -0,0 +1,94 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.utils.parametrizations import weight_norm
+
+from ..hparams import HParams
+from .lvcnet import LVCBlock
+from .mrstft import MRSTFTLoss
+
+
+class UnivNet(nn.Module):
+    @property
+    def d_noise(self):
+        return 128
+
+    @property
+    def strides(self):
+        return [7, 5, 4, 3]
+
+    @property
+    def dilations(self):
+        return [1, 3, 9, 27]
+
+    @property
+    def nc(self):
+        return self.hp.univnet_nc
+
+    @property
+    def scale_factor(self) -> int:
+        return self.hp.hop_size
+
+    def __init__(self, hp: HParams, d_input):
+        super().__init__()
+        self.d_input = d_input
+
+        self.hp = hp
+
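+        # Conditioning hop lengths follow np.cumprod(strides): 7, 35, 140, 420 (= hop_size)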
+        self.blocks = nn.ModuleList(
+            [
+                LVCBlock(
+                    self.nc,
+                    d_input,
+                    stride=stride,
+                    dilations=self.dilations,
+                    cond_hop_length=hop_length,
+                    kpnet_conv_size=3,
+                )
+                for stride, hop_length in zip(self.strides, np.cumprod(self.strides))
+            ]
+        )
+
+        self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
+
+        self.conv_post = nn.Sequential(
+            nn.LeakyReLU(0.2),
+            weight_norm(nn.Conv1d(self.nc, 1, 7, padding=3, padding_mode="reflect")),
+            nn.Tanh(),
+        )
+
+        self.mrstft = MRSTFTLoss(hp)
+
+    @property
+    def eps(self):
+        return 1e-5
+
+    def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
+        """
+        Args:
+            x: (b c t), acoustic features
+            y: (b t), waveform
+        Returns:
+            z: (b t), waveform
+        """
+        assert x.ndim == 3, "x must be 3D tensor"
+        assert y is None or y.ndim == 2, "y must be 2D tensor"
+        assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
+        assert npad >= 0, "npad must be non-negative"
+
+        x = F.pad(x, (0, npad), "constant", 0)
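+        # Draw a noise latent with the same (padded) time length as the features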
+        z = torch.randn(x.shape[0], self.d_noise, x.shape[2]).to(x)
+        z = self.conv_pre(z)  # (b c t)
+
+        for block in self.blocks:
+            z = block(z, x)  # (b c t)
+
+        z = self.conv_post(z)  # (b 1 t)
+        z = z[..., : z.shape[-1] - self.scale_factor * npad]  # drop the samples produced by the npad frames (safe when npad == 0)
+        z = z.squeeze(1)  # (b t)
+
+        if y is not None:
+            self.losses = self.mrstft(z, y)
+
+        return z
diff --git a/modules/repos_static/resemble_enhance/hparams.py b/modules/repos_static/resemble_enhance/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8e716175fa962ada1d98cd755430e2ea770278c
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/hparams.py
@@ -0,0 +1,128 @@
+import logging
+from dataclasses import asdict, dataclass
+from pathlib import Path
+
+from omegaconf import OmegaConf
+from rich.console import Console
+from rich.panel import Panel
+from rich.table import Table
+
+logger = logging.getLogger(__name__)
+
+console = Console()
+
+
+def _make_stft_cfg(hop_length, win_length=None):
+    if win_length is None:
+        win_length = 4 * hop_length
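+    # Round n_fft up to the smallest power of two >= win_length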
+    n_fft = 2 ** (win_length - 1).bit_length()
+    return dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length)
+
+
+def _build_rich_table(rows, columns, title=None):
+    table = Table(title=title, header_style=None)
+    for column in columns:
+        table.add_column(column.capitalize(), justify="left")
+    for row in rows:
+        table.add_row(*map(str, row))
+    return Panel(table, expand=False)
+
+
+def _rich_print_dict(d, title="Config", key="Key", value="Value"):
+    console.print(_build_rich_table(d.items(), [key, value], title))
+
+
+@dataclass(frozen=True)
+class HParams:
+    # Dataset
+    fg_dir: Path = Path("data/fg")
+    bg_dir: Path = Path("data/bg")
+    rir_dir: Path = Path("data/rir")
+    load_fg_only: bool = False
+    praat_augment_prob: float = 0
+
+    # Audio settings
+    wav_rate: int = 44_100
+    n_fft: int = 2048
+    win_size: int = 2048
+    hop_size: int = 420  # 9.5ms
+    num_mels: int = 128
+    stft_magnitude_min: float = 1e-4
+    preemphasis: float = 0.97
+    mix_alpha_range: tuple[float, float] = (0.2, 0.8)
+
+    # Training
+    nj: int = 64
+    training_seconds: float = 1.0
+    batch_size_per_gpu: int = 16
+    min_lr: float = 1e-5
+    max_lr: float = 1e-4
+    warmup_steps: int = 1000
+    max_steps: int = 1_000_000
+    gradient_clipping: float = 1.0
+
+    @property
+    def deepspeed_config(self):
+        return {
+            "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
+            "optimizer": {
+                "type": "Adam",
+                "params": {"lr": float(self.min_lr)},
+            },
+            "scheduler": {
+                "type": "WarmupDecayLR",
+                "params": {
+                    "warmup_min_lr": float(self.min_lr),
+                    "warmup_max_lr": float(self.max_lr),
+                    "warmup_num_steps": self.warmup_steps,
+                    "total_num_steps": self.max_steps,
+                    "warmup_type": "linear",
+                },
+            },
+            "gradient_clipping": self.gradient_clipping,
+        }
+
+    @property
+    def stft_cfgs(self):
+        assert self.wav_rate == 44_100, f"wav_rate must be 44_100, got {self.wav_rate}"
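+        # Hop lengths of 100/256/512 samples (~2.3/5.8/11.6 ms at 44.1 kHz)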
+        return [_make_stft_cfg(h) for h in (100, 256, 512)]
+
+    @classmethod
+    def from_yaml(cls, path: Path) -> "HParams":
+        logger.info(f"Reading hparams from {path}")
+        # Merge the YAML into the defaults first so OmegaConf coerces types (e.g., str -> Path)
+        return cls(**dict(OmegaConf.merge(cls(), OmegaConf.load(path))))
+
+    def save_if_not_exists(self, run_dir: Path):
+        path = run_dir / "hparams.yaml"
+        if path.exists():
+            logger.info(f"{path} already exists, not saving")
+            return
+        path.parent.mkdir(parents=True, exist_ok=True)
+        OmegaConf.save(asdict(self), str(path))
+
+    @classmethod
+    def load(cls, run_dir, yaml: Path | None = None):
+        hps = []
+
+        if (run_dir / "hparams.yaml").exists():
+            hps.append(cls.from_yaml(run_dir / "hparams.yaml"))
+
+        if yaml is not None:
+            hps.append(cls.from_yaml(yaml))
+
+        if len(hps) == 0:
+            hps.append(cls())
+
+        for hp in hps[1:]:
+            if hp != hps[0]:
+                errors = {}
+                for k, v in asdict(hp).items():
+                    if getattr(hps[0], k) != v:
+                        errors[k] = f"{getattr(hps[0], k)} != {v}"
+                raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
+
+        return hps[0]
+
+    def print(self):
+        _rich_print_dict(asdict(self), title="HParams")
diff --git a/modules/repos_static/resemble_enhance/inference.py b/modules/repos_static/resemble_enhance/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e78a11fdf134bcc182e5c9ef0cf81e02c64850b
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/inference.py
@@ -0,0 +1,163 @@
+import logging
+import time
+
+import torch
+import torch.nn.functional as F
+from torch.nn.utils.parametrize import remove_parametrizations
+from torchaudio.functional import resample
+from torchaudio.transforms import MelSpectrogram
+from tqdm import trange
+
+from .hparams import HParams
+
+logger = logging.getLogger(__name__)
+
+
+@torch.inference_mode()
+def inference_chunk(model, dwav, sr, device, npad=441):
+    assert model.hp.wav_rate == sr, f"Expected {model.hp.wav_rate} Hz, got {sr} Hz"
+    del sr
+
+    length = dwav.shape[-1]
+    abs_max = dwav.abs().max().clamp(min=1e-7)
+
+    assert dwav.dim() == 1, f"Expected 1D waveform, got {dwav.dim()}D"
+    dwav = dwav.to(device)
+    dwav = dwav / abs_max  # Normalize
+    dwav = F.pad(dwav, (0, npad))
+    hwav = model(dwav[None])[0].cpu()  # (T,)
+    hwav = hwav[:length]  # Trim padding
+    hwav = hwav * abs_max  # Unnormalize
+
+    return hwav
+
+
+def compute_corr(x, y):
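+    # Circular cross-correlation via the FFT (correlation theorem)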
+    return torch.fft.ifft(torch.fft.fft(x) * torch.fft.fft(y).conj()).abs()
+
+
+def compute_offset(chunk1, chunk2, sr=44100):
+    """
+    Args:
+        chunk1: (T,)
+        chunk2: (T,)
+    Returns:
+        offset: int, offset in samples such that chunk1 ~= chunk2.roll(-offset)
+    """
+    hop_length = sr // 200  # 5 ms resolution
+    win_length = hop_length * 4
+    n_fft = 2 ** (win_length - 1).bit_length()
+
+    mel_fn = MelSpectrogram(
+        sample_rate=sr,
+        n_fft=n_fft,
+        win_length=win_length,
+        hop_length=hop_length,
+        n_mels=80,
+        f_min=0.0,
+        f_max=sr // 2,
+    )
+
+    spec1 = mel_fn(chunk1).log1p()
+    spec2 = mel_fn(chunk2).log1p()
+
+    corr = compute_corr(spec1, spec2)  # (F, T)
+    corr = corr.mean(dim=0)  # (T,)
+
+    argmax = corr.argmax().item()
+
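+    # A peak in the second half of the circular correlation corresponds to a negative lag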
+    if argmax > len(corr) // 2:
+        argmax -= len(corr)
+
+    offset = -argmax * hop_length
+
+    return offset
+
+
+def merge_chunks(chunks, chunk_length, hop_length, sr=44100, length=None):
+    signal_length = (len(chunks) - 1) * hop_length + chunk_length
+    overlap_length = chunk_length - hop_length
+    signal = torch.zeros(signal_length, device=chunks[0].device)
+
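+    # Cross-fade windows: linear ramps over the overlap, constant 1.0 over the hop region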
+    fadein = torch.linspace(0, 1, overlap_length, device=chunks[0].device)
+    fadein = torch.cat([fadein, torch.ones(hop_length, device=chunks[0].device)])
+    fadeout = torch.linspace(1, 0, overlap_length, device=chunks[0].device)
+    fadeout = torch.cat([torch.ones(hop_length, device=chunks[0].device), fadeout])
+
+    for i, chunk in enumerate(chunks):
+        start = i * hop_length
+        end = start + chunk_length
+
+        if len(chunk) < chunk_length:
+            chunk = F.pad(chunk, (0, chunk_length - len(chunk)))
+
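+        # Shift this chunk to best align with the previous one (mel cross-correlation)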
+        if i > 0:
+            pre_region = chunks[i - 1][-overlap_length:]
+            cur_region = chunk[:overlap_length]
+            offset = compute_offset(pre_region, cur_region, sr=sr)
+            start -= offset
+            end -= offset
+
+        if i == 0:
+            chunk = chunk * fadeout
+        elif i == len(chunks) - 1:
+            chunk = chunk * fadein
+        else:
+            chunk = chunk * fadein * fadeout
+
+        signal[start:end] += chunk[: len(signal[start:end])]
+
+    signal = signal[:length]
+
+    return signal
+
+
+def remove_weight_norm_recursively(module):
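+    # Strip weight-norm parametrizations before inference; modules without one raise and are skipped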
+    for _, m in module.named_modules():
+        try:
+            remove_parametrizations(m, "weight")
+        except Exception:
+            pass
+
+
+def inference(model, dwav, sr, device, chunk_seconds: float = 30.0, overlap_seconds: float = 1.0):
+    remove_weight_norm_recursively(model)
+
+    hp: HParams = model.hp
+
+    dwav = resample(
+        dwav,
+        orig_freq=sr,
+        new_freq=hp.wav_rate,
+        lowpass_filter_width=64,
+        rolloff=0.9475937167399596,
+        resampling_method="sinc_interp_kaiser",
+        beta=14.769656459379492,
+    )
+
+    # From here on, everything operates at the model's sample rate
+    sr = hp.wav_rate
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+    start_time = time.perf_counter()
+
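+    # Enhance in overlapping chunks, then merge them back with cross-fades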
+    chunk_length = int(sr * chunk_seconds)
+    overlap_length = int(sr * overlap_seconds)
+    hop_length = chunk_length - overlap_length
+
+    chunks = []
+    for start in trange(0, dwav.shape[-1], hop_length):
+        chunks.append(inference_chunk(model, dwav[start : start + chunk_length], sr, device))
+
+    hwav = merge_chunks(chunks, chunk_length, hop_length, sr=sr, length=dwav.shape[-1])
+
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+
+    elapsed_time = time.perf_counter() - start_time
+    logger.info(f"Elapsed time: {elapsed_time:.3f} s, {hwav.shape[-1] / elapsed_time / 1000:.3f} kHz")
+
+    return hwav, sr
diff --git a/modules/repos_static/resemble_enhance/melspec.py b/modules/repos_static/resemble_enhance/melspec.py
new file mode 100644
index 0000000000000000000000000000000000000000..dce1f8bfb95b9a1814db8c7305c07ccf2bfa9111
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/melspec.py
@@ -0,0 +1,61 @@
+import numpy as np
+import torch
+from torch import nn
+from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram
+
+from .hparams import HParams
+
+
+class MelSpectrogram(nn.Module):
+    def __init__(self, hp: HParams):
+        """
+        Torch implementation of Resemble's mel extraction.
+        Note that the values are NOT identical to librosa's implementation
+        due to floating-point precision.
+        """
+        super().__init__()
+        self.hp = hp
+        self.melspec = TorchMelSpectrogram(
+            hp.wav_rate,
+            n_fft=hp.n_fft,
+            win_length=hp.win_size,
+            hop_length=hp.hop_size,
+            f_min=0,
+            f_max=hp.wav_rate // 2,
+            n_mels=hp.num_mels,
+            power=1,
+            normalized=False,
+            # NOTE: Following librosa's default.
+            pad_mode="constant",
+            norm="slaney",
+            mel_scale="slaney",
+        )
+        self.register_buffer("stft_magnitude_min", torch.FloatTensor([hp.stft_magnitude_min]))
+        self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
+        self.preemphasis = hp.preemphasis
+        self.hop_size = hp.hop_size
+
+    def forward(self, wav, pad=True):
+        """
+        Args:
+            wav: [B, T]
+        """
+        device = wav.device
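+        # MPS fallback: compute the mel on CPU (assumed unsupported ops on MPS)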
+        if wav.is_mps:
+            wav = wav.cpu()
+            self.to(wav.device)
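+        # Pre-emphasis filter: y[t] = x[t] - preemphasis * x[t-1]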
+        if self.preemphasis > 0:
+            wav = torch.nn.functional.pad(wav, [1, 0], value=0)
+            wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
+        mel = self.melspec(wav)
+        mel = self._amp_to_db(mel)
+        mel_normed = self._normalize(mel)
+        assert not pad or mel_normed.shape[-1] == 1 + wav.shape[-1] // self.hop_size  # Sanity check
+        mel_normed = mel_normed.to(device)
+        return mel_normed  # (B, M, T)
+
+    def _normalize(self, s, headroom_db=15):
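+        # Map dB values from [min_level_db, headroom_db] to roughly [0, 1]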
+        return (s - self.min_level_db) / (-self.min_level_db + headroom_db)
+
+    def _amp_to_db(self, x):
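+        # Amplitude to dB (20 * log10), floored at stft_magnitude_min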
+        return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20
diff --git a/modules/repos_static/resemble_enhance/utils/__init__.py b/modules/repos_static/resemble_enhance/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..07f6303742506be443f477c40a42a9551b4e8af4
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/utils/__init__.py
@@ -0,0 +1,2 @@
+from .logging import setup_logging
+from .utils import save_mels, tree_map
diff --git a/modules/repos_static/resemble_enhance/utils/control.py b/modules/repos_static/resemble_enhance/utils/control.py
new file mode 100644
index 0000000000000000000000000000000000000000..56b74b46d73b0c3757849dad310ca0899bb5f5a4
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/utils/control.py
@@ -0,0 +1,26 @@
+import logging
+import selectors
+import sys
+from functools import cache
+
+from .distributed import global_leader_only
+
+_logger = logging.getLogger(__name__)
+
+
+@cache
+def _get_stdin_selector():
+    selector = selectors.DefaultSelector()
+    selector.register(fileobj=sys.stdin, events=selectors.EVENT_READ)
+    return selector
+
+
+@global_leader_only(boardcast_return=True)
+def non_blocking_input():
+    s = ""
+    selector = _get_stdin_selector()
+    events = selector.select(timeout=0)
+    for key, _ in events:
+        s: str = key.fileobj.readline().strip()
+        _logger.info(f'Got stdin "{s}".')
+    return s
diff --git a/modules/repos_static/resemble_enhance/utils/logging.py b/modules/repos_static/resemble_enhance/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..26c43b6dc785ff6547478cb04833dd92b5df7311
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/utils/logging.py
@@ -0,0 +1,38 @@
+import logging
+from pathlib import Path
+
+from rich.logging import RichHandler
+
+from .distributed import global_leader_only
+
+
+@global_leader_only
+def setup_logging(run_dir):
+    handlers = []
+    stdout_handler = RichHandler()
+    stdout_handler.setLevel(logging.INFO)
+    handlers.append(stdout_handler)
+
+    if run_dir is not None:
+        filename = Path(run_dir) / f"log.txt"
+        filename.parent.mkdir(parents=True, exist_ok=True)
+        file_handler = logging.FileHandler(filename, mode="a")
+        file_handler.setLevel(logging.DEBUG)
+        handlers.append(file_handler)
+
+    # Replace handlers on known third-party loggers (currently just DeepSpeed)
+    for name in ["DeepSpeed"]:
+        logger = logging.getLogger(name)
+        if isinstance(logger, logging.Logger):
+            for handler in list(logger.handlers):
+                logger.removeHandler(handler)
+            for handler in handlers:
+                logger.addHandler(handler)
+
+    # Set the default logger
+    logging.basicConfig(
+        level=logging.getLevelName("INFO"),
+        format="%(message)s",
+        datefmt="[%X]",
+        handlers=handlers,
+    )
diff --git a/modules/repos_static/resemble_enhance/utils/utils.py b/modules/repos_static/resemble_enhance/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c402c9ae2bd634e903d2a9861243005e6a8c9147
--- /dev/null
+++ b/modules/repos_static/resemble_enhance/utils/utils.py
@@ -0,0 +1,73 @@
+from typing import Callable, TypeVar, overload
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def save_mels(path, *, targ_mel, pred_mel, cond_mel):
+    n = 3 if cond_mel is None else 4
+
+    plt.figure(figsize=(10, n * 4))
+
+    i = 1
+
+    plt.subplot(n, 1, i)
+    plt.imshow(pred_mel, origin="lower", interpolation="none")
+    plt.title(f"Pred mel {pred_mel.shape}")
+    i += 1
+
+    plt.subplot(n, 1, i)
+    plt.imshow(targ_mel, origin="lower", interpolation="none")
+    plt.title(f"GT mel {targ_mel.shape}")
+    i += 1
+
+    plt.subplot(n, 1, i)
+    pred_mel = pred_mel[:, : targ_mel.shape[1]]
+    targ_mel = targ_mel[:, : pred_mel.shape[1]]
+    plt.imshow(np.abs(pred_mel - targ_mel), origin="lower", interpolation="none")
+    plt.title(f"Diff mel {pred_mel.shape}, mse={np.mean((pred_mel - targ_mel)**2):.4f}")
+    i += 1
+
+    if cond_mel is not None:
+        plt.subplot(n, 1, i)
+        plt.imshow(cond_mel, origin="lower", interpolation="none")
+        plt.title(f"Cond mel {cond_mel.shape}")
+        i += 1
+
+    plt.savefig(path, dpi=480)
+    plt.close()
+
+
+T = TypeVar("T")
+
+
+@overload
+def tree_map(fn: Callable, x: list[T]) -> list[T]:
+    ...
+
+
+@overload
+def tree_map(fn: Callable, x: tuple[T, ...]) -> tuple[T, ...]:
+    ...
+
+
+@overload
+def tree_map(fn: Callable, x: dict[str, T]) -> dict[str, T]:
+    ...
+
+
+@overload
+def tree_map(fn: Callable, x: T) -> T:
+    ...
+
+
+def tree_map(fn: Callable, x):
+    if isinstance(x, list):
+        x = [tree_map(fn, xi) for xi in x]
+    elif isinstance(x, tuple):
+        x = tuple(tree_map(fn, xi) for xi in x)
+    elif isinstance(x, dict):
+        x = {k: tree_map(fn, v) for k, v in x.items()}
+    else:
+        x = fn(x)
+    return x
diff --git a/modules/speaker.py b/modules/speaker.py
index d066f2b20d3cd0d54331eba6c7b905db695bc794..2fcbc5a4ac99a89f382fd4f9988757d4c8f71470 100644
--- a/modules/speaker.py
+++ b/modules/speaker.py
@@ -99,6 +99,10 @@ class SpeakerManager:
                 self.speakers[speaker_file] = Speaker.from_file(
                     self.speaker_dir + speaker_file
                 )
+        # Drop speakers whose files have been deleted so the cache stays in sync
+        for fname in list(self.speakers.keys()):
+            if not os.path.exists(self.speaker_dir + fname):
+                del self.speakers[fname]
 
     def list_speakers(self):
         return list(self.speakers.values())
diff --git a/modules/webui/speaker/__init__.py b/modules/webui/speaker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/webui/speaker/speaker_creator.py b/modules/webui/speaker/speaker_creator.py
new file mode 100644
index 0000000000000000000000000000000000000000..344ab1f08317be3f7741d0aa3d3c705b64e34aef
--- /dev/null
+++ b/modules/webui/speaker/speaker_creator.py
@@ -0,0 +1,171 @@
+import gradio as gr
+import torch
+from modules.speaker import Speaker
+from modules.utils.SeedContext import SeedContext
+from modules.hf import spaces
+from modules.models import load_chat_tts
+from modules.utils.rng import np_rng
+from modules.webui.webui_utils import tts_generate
+
+import tempfile
+
+names_list = [
+    "Alice",
+    "Bob",
+    "Carol",
+    "Carlos",
+    "Charlie",
+    "Chuck",
+    "Chad",
+    "Craig",
+    "Dan",
+    "Dave",
+    "David",
+    "Erin",
+    "Eve",
+    "Yves",
+    "Faythe",
+    "Frank",
+    "Grace",
+    "Heidi",
+    "Ivan",
+    "Judy",
+    "Mallory",
+    "Mallet",
+    "Darth",
+    "Michael",
+    "Mike",
+    "Niaj",
+    "Olivia",
+    "Oscar",
+    "Peggy",
+    "Pat",
+    "Rupert",
+    "Sybil",
+    "Trent",
+    "Ted",
+    "Trudy",
+    "Victor",
+    "Vanna",
+    "Walter",
+    "Wendy",
+]
+
+
+@torch.inference_mode()
+@spaces.GPU
+def create_spk_from_seed(
+    seed: int,
+    name: str,
+    gender: str,
+    desc: str,
+):
+    chat_tts = load_chat_tts()
+    with SeedContext(seed):
+        emb = chat_tts.sample_random_speaker()
+    spk = Speaker(seed=-2, name=name, gender=gender, describe=desc)
+    spk.emb = emb
+
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
+        torch.save(spk, tmp_file)
+        tmp_file_path = tmp_file.name
+
+    return tmp_file_path
+
+
+@torch.inference_mode()
+@spaces.GPU
+def test_spk_voice(seed: int, text: str):
+    return tts_generate(
+        spk=seed,
+        text=text,
+    )
+
+
+def random_speaker():
+    seed = np_rng()
+    name = names_list[seed % len(names_list)]
+    return seed, name
+
+
+creator_ui_desc = """
+## Speaker Creator
+Use this panel to quickly generate speaker.pt files by drawing random seeds.
+
+1. **Generate speaker**: enter a seed, name, gender, and description, then click the "Generate speaker.pt" button; the speaker configuration is saved as a .pt file.
+2. **Test speaker voice**: enter test text and click the "Test Voice" button; the generated audio plays under "Output Audio".
+3. **Random speaker**: click the "Random Speaker" button to generate a random seed and name, then edit the remaining fields and test.
+"""
+
+
+def speaker_creator_ui():
+    def on_generate(seed, name, gender, desc):
+        file_path = create_spk_from_seed(seed, name, gender, desc)
+        return file_path
+
+    def create_test_voice_card(seed_input):
+        with gr.Group():
+            gr.Markdown("🎤Test voice")
+            with gr.Row():
+                test_voice_btn = gr.Button("Test Voice", variant="secondary")
+
+                with gr.Column(scale=4):
+                    test_text = gr.Textbox(
+                        label="Test Text",
+                        placeholder="Please input test text",
+                        value="说话人测试 123456789 [uv_break] ok, test done [lbreak]",
+                    )
+                    with gr.Row():
+                        current_seed = gr.Label(label="Current Seed", value=-1)
+                        with gr.Column(scale=4):
+                            output_audio = gr.Audio(label="Output Audio")
+
+        test_voice_btn.click(
+            fn=test_spk_voice,
+            inputs=[seed_input, test_text],
+            outputs=[output_audio],
+        )
+        test_voice_btn.click(
+            fn=lambda x: x,
+            inputs=[seed_input],
+            outputs=[current_seed],
+        )
+
+    gr.Markdown(creator_ui_desc)
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            with gr.Group():
+                gr.Markdown("ℹ️Speaker info")
+                seed_input = gr.Number(label="Seed", value=2)
+                name_input = gr.Textbox(
+                    label="Name", placeholder="Enter speaker name", value="Bob"
+                )
+                gender_input = gr.Textbox(
+                    label="Gender", placeholder="Enter gender", value="*"
+                )
+                desc_input = gr.Textbox(
+                    label="Description",
+                    placeholder="Enter description",
+                )
+                random_button = gr.Button("Random Speaker")
+            with gr.Group():
+                gr.Markdown("🔊Generate speaker.pt")
+                generate_button = gr.Button("Save .pt file")
+                output_file = gr.File(label="Save to File")
+        with gr.Column(scale=5):
+            for _ in range(4):
+                create_test_voice_card(seed_input=seed_input)
+
+    random_button.click(
+        random_speaker,
+        outputs=[seed_input, name_input],
+    )
+
+    generate_button.click(
+        fn=on_generate,
+        inputs=[seed_input, name_input, gender_input, desc_input],
+        outputs=[output_file],
+    )
diff --git a/modules/webui/speaker/speaker_merger.py b/modules/webui/speaker/speaker_merger.py
new file mode 100644
index 0000000000000000000000000000000000000000..66a0854790d258f3e0cb3476efc995aac862364f
--- /dev/null
+++ b/modules/webui/speaker/speaker_merger.py
@@ -0,0 +1,255 @@
+import gradio as gr
+import torch
+
+from modules.hf import spaces
+from modules.webui.webui_utils import get_speakers, tts_generate
+from modules.speaker import speaker_mgr, Speaker
+
+import tempfile
+
+
+def spk_to_tensor(spk):
+    spk = spk.split(" : ")[1].strip() if " : " in spk else spk
+    if spk == "None" or spk == "":
+        return None
+    return speaker_mgr.get_speaker(spk).emb
+
+
+def get_speaker_show_name(spk):
+    if spk.gender == "*" or spk.gender == "":
+        return spk.name
+    return f"{spk.gender} : {spk.name}"
+
+
+def merge_spk(
+    spk_a,
+    spk_a_w,
+    spk_b,
+    spk_b_w,
+    spk_c,
+    spk_c_w,
+    spk_d,
+    spk_d_w,
+):
+    tensor_a = spk_to_tensor(spk_a)
+    tensor_b = spk_to_tensor(spk_b)
+    tensor_c = spk_to_tensor(spk_c)
+    tensor_d = spk_to_tensor(spk_d)
+
+    assert (
+        tensor_a is not None
+        or tensor_b is not None
+        or tensor_c is not None
+        or tensor_d is not None
+    ), "At least one speaker should be selected"
+
+    merge_tensor = torch.zeros_like(
+        next(t for t in (tensor_a, tensor_b, tensor_c, tensor_d) if t is not None)
+    )
+
+    total_weight = 0
+    if tensor_a is not None:
+        merge_tensor += spk_a_w * tensor_a
+        total_weight += spk_a_w
+    if tensor_b is not None:
+        merge_tensor += spk_b_w * tensor_b
+        total_weight += spk_b_w
+    if tensor_c is not None:
+        merge_tensor += spk_c_w * tensor_c
+        total_weight += spk_c_w
+    if tensor_d is not None:
+        merge_tensor += spk_d_w * tensor_d
+        total_weight += spk_d_w
+
+    if total_weight > 0:
+        merge_tensor /= total_weight
+
+    merged_spk = Speaker.from_tensor(merge_tensor)
+    merged_spk.name = "<MIX>"
+
+    return merged_spk
+
+
+@torch.inference_mode()
+@spaces.GPU
+def merge_and_test_spk_voice(
+    spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w, test_text
+):
+    merged_spk = merge_spk(
+        spk_a,
+        spk_a_w,
+        spk_b,
+        spk_b_w,
+        spk_c,
+        spk_c_w,
+        spk_d,
+        spk_d_w,
+    )
+    return tts_generate(
+        spk=merged_spk,
+        text=test_text,
+    )
+
+
+@torch.inference_mode()
+@spaces.GPU
+def merge_spk_to_file(
+    spk_a,
+    spk_a_w,
+    spk_b,
+    spk_b_w,
+    spk_c,
+    spk_c_w,
+    spk_d,
+    spk_d_w,
+    speaker_name,
+    speaker_gender,
+    speaker_desc,
+):
+    merged_spk = merge_spk(
+        spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w
+    )
+    merged_spk.name = speaker_name
+    merged_spk.gender = speaker_gender
+    merged_spk.desc = speaker_desc
+
+    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
+        torch.save(merged_spk, tmp_file)
+        tmp_file_path = tmp_file.name
+
+    return tmp_file_path
+
+
+merge_desc = """
+## Speaker Merger
+
+In this panel you can select multiple speakers, assign them weights, and synthesize and test a new merged voice:
+
+1. Select speakers: choose up to four speakers (A, B, C, D) from the dropdowns; each has a weight slider from 0 to 10 that controls its contribution to the merged voice.
+2. Synthesize speech: after picking speakers and setting weights, enter text in the "Test Text" box and click "Test Voice" to generate and play the merged voice.
+3. Save speaker: fill in the new speaker's name, gender, and description on the right and click "Save Speaker"; the saved file appears under "Merged Speaker" for download.
+"""
+
+
+def get_spk_choices():
+    speakers = get_speakers()
+
+    speaker_names = ["None"] + [get_speaker_show_name(speaker) for speaker in speakers]
+    return speaker_names
+
+
+# Show four speaker pickers (A/B/C/D); select one or more, then test the voice and export
+def create_speaker_merger():
+    speaker_names = get_spk_choices()
+
+    gr.Markdown(merge_desc)
+
+    def spk_picker(label_tail: str):
+        with gr.Row():
+            spk_drop = gr.Dropdown(
+                choices=speaker_names, value="None", label=f"Speaker {label_tail}"
+            )
+            refresh_btn = gr.Button("🔄", variant="secondary")
+
+        def refresh_choices():
+            speaker_mgr.refresh_speakers()
+            return gr.update(choices=get_spk_choices())
+
+        refresh_btn.click(refresh_choices, outputs=[spk_drop])
+        spk_w = gr.Slider(
+            value=1,
+            minimum=0,
+            maximum=10,
+            step=0.1,
+            label=f"Weight {label_tail}",
+        )
+        return spk_drop, spk_w
+
+    with gr.Row():
+        with gr.Column(scale=5):
+            with gr.Row():
+                with gr.Group():
+                    spk_a, spk_a_w = spk_picker("A")
+
+                with gr.Group():
+                    spk_b, spk_b_w = spk_picker("B")
+
+                with gr.Group():
+                    spk_c, spk_c_w = spk_picker("C")
+
+                with gr.Group():
+                    spk_d, spk_d_w = spk_picker("D")
+
+            with gr.Row():
+                with gr.Column(scale=3):
+                    with gr.Group():
+                        gr.Markdown("🎤Test voice")
+                        with gr.Row():
+                            test_voice_btn = gr.Button(
+                                "Test Voice", variant="secondary"
+                            )
+
+                            with gr.Column(scale=4):
+                                test_text = gr.Textbox(
+                                    label="Test Text",
+                                    placeholder="Please input test text",
+                                    value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]",
+                                )
+
+                                output_audio = gr.Audio(label="Output Audio")
+
+        with gr.Column(scale=1):
+            with gr.Group():
+                gr.Markdown("🗃️Save to file")
+
+                speaker_name = gr.Textbox(label="Name", value="forge_speaker_merged")
+                speaker_gender = gr.Textbox(label="Gender", value="*")
+                speaker_desc = gr.Textbox(label="Description", value="merged speaker")
+
+                save_btn = gr.Button("Save Speaker", variant="primary")
+
+                merged_spker = gr.File(
+                    label="Merged Speaker", interactive=False, type="binary"
+                )
+
+    test_voice_btn.click(
+        merge_and_test_spk_voice,
+        inputs=[
+            spk_a,
+            spk_a_w,
+            spk_b,
+            spk_b_w,
+            spk_c,
+            spk_c_w,
+            spk_d,
+            spk_d_w,
+            test_text,
+        ],
+        outputs=[output_audio],
+    )
+
+    save_btn.click(
+        merge_spk_to_file,
+        inputs=[
+            spk_a,
+            spk_a_w,
+            spk_b,
+            spk_b_w,
+            spk_c,
+            spk_c_w,
+            spk_d,
+            spk_d_w,
+            speaker_name,
+            speaker_gender,
+            speaker_desc,
+        ],
+        outputs=[merged_spker],
+    )
diff --git a/modules/webui/speaker_tab.py b/modules/webui/speaker_tab.py
index 31abf96c4b2acc213a674603ee0e44400add3e4b..4021bc6646a9877dcd29284f49e4a95cab3e6531 100644
--- a/modules/webui/speaker_tab.py
+++ b/modules/webui/speaker_tab.py
@@ -1,259 +1,13 @@
-import io
 import gradio as gr
-import torch
 
-from modules.hf import spaces
-from modules.webui.webui_utils import get_speakers, tts_generate
-from modules.speaker import speaker_mgr, Speaker
+from modules.webui.speaker.speaker_merger import create_speaker_merger
+from modules.webui.speaker.speaker_creator import speaker_creator_ui
 
-import tempfile
 
-
-def spk_to_tensor(spk):
-    spk = spk.split(" : ")[1].strip() if " : " in spk else spk
-    if spk == "None" or spk == "":
-        return None
-    return speaker_mgr.get_speaker(spk).emb
-
-
-def get_speaker_show_name(spk):
-    if spk.gender == "*" or spk.gender == "":
-        return spk.name
-    return f"{spk.gender} : {spk.name}"
-
-
-def merge_spk(
-    spk_a,
-    spk_a_w,
-    spk_b,
-    spk_b_w,
-    spk_c,
-    spk_c_w,
-    spk_d,
-    spk_d_w,
-):
-    tensor_a = spk_to_tensor(spk_a)
-    tensor_b = spk_to_tensor(spk_b)
-    tensor_c = spk_to_tensor(spk_c)
-    tensor_d = spk_to_tensor(spk_d)
-
-    assert (
-        tensor_a is not None
-        or tensor_b is not None
-        or tensor_c is not None
-        or tensor_d is not None
-    ), "At least one speaker should be selected"
-
-    merge_tensor = torch.zeros_like(
-        tensor_a
-        if tensor_a is not None
-        else (
-            tensor_b
-            if tensor_b is not None
-            else tensor_c if tensor_c is not None else tensor_d
-        )
-    )
-
-    total_weight = 0
-    if tensor_a is not None:
-        merge_tensor += spk_a_w * tensor_a
-        total_weight += spk_a_w
-    if tensor_b is not None:
-        merge_tensor += spk_b_w * tensor_b
-        total_weight += spk_b_w
-    if tensor_c is not None:
-        merge_tensor += spk_c_w * tensor_c
-        total_weight += spk_c_w
-    if tensor_d is not None:
-        merge_tensor += spk_d_w * tensor_d
-        total_weight += spk_d_w
-
-    if total_weight > 0:
-        merge_tensor /= total_weight
-
-    merged_spk = Speaker.from_tensor(merge_tensor)
-    merged_spk.name = "<MIX>"
-
-    return merged_spk
-
-
-@torch.inference_mode()
-@spaces.GPU
-def merge_and_test_spk_voice(
-    spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w, test_text
-):
-    merged_spk = merge_spk(
-        spk_a,
-        spk_a_w,
-        spk_b,
-        spk_b_w,
-        spk_c,
-        spk_c_w,
-        spk_d,
-        spk_d_w,
-    )
-    return tts_generate(
-        spk=merged_spk,
-        text=test_text,
-    )
-
-
-@torch.inference_mode()
-@spaces.GPU
-def merge_spk_to_file(
-    spk_a,
-    spk_a_w,
-    spk_b,
-    spk_b_w,
-    spk_c,
-    spk_c_w,
-    spk_d,
-    spk_d_w,
-    speaker_name,
-    speaker_gender,
-    speaker_desc,
-):
-    merged_spk = merge_spk(
-        spk_a, spk_a_w, spk_b, spk_b_w, spk_c, spk_c_w, spk_d, spk_d_w
-    )
-    merged_spk.name = speaker_name
-    merged_spk.gender = speaker_gender
-    merged_spk.desc = speaker_desc
-
-    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
-        torch.save(merged_spk, tmp_file)
-        tmp_file_path = tmp_file.name
-
-    return tmp_file_path
-
-
-merge_desc = """
-## Speaker Merger
-
-在本面板中,您可以选择多个说话人并指定他们的权重,合成新的语音并进行测试。以下是各个功能的详细说明:
-
-### 1. 选择说话人
-您可以从下拉菜单中选择最多四个说话人(A、B、C、D),每个说话人都有一个对应的权重滑块,范围从0到10。权重决定了每个说话人在合成语音中的影响程度。
-
-### 2. 合成语音
-在选择好说话人和设置好权重后,您可以在“测试文本”框中输入要测试的文本,然后点击“测试语音”按钮来生成并播放合成的语音。
-
-### 3. 保存说话人
-您还可以在右侧的“说话人信息”部分填写新的说话人的名称、性别和描述,并点击“保存说话人”按钮来保存合成的说话人。保存后的说话人文件将显示在“合成说话人”栏中,供下载使用。
-"""
-
-
-# 显示 a b c d 四个选择框,选择一个或多个,然后可以试音,并导出
 def create_speaker_panel():
-    speakers = get_speakers()
-
-    speaker_names = ["None"] + [get_speaker_show_name(speaker) for speaker in speakers]
 
     with gr.Tabs():
+        with gr.TabItem("Creator"):
+            speaker_creator_ui()
         with gr.TabItem("Merger"):
-            gr.Markdown(merge_desc)
-
-            with gr.Row():
-                with gr.Column(scale=5):
-                    with gr.Row():
-                        with gr.Group():
-                            spk_a = gr.Dropdown(
-                                choices=speaker_names, value="None", label="Speaker A"
-                            )
-                            spk_a_w = gr.Slider(
-                                value=1, minimum=0, maximum=10, step=1, label="Weight A"
-                            )
-
-                        with gr.Group():
-                            spk_b = gr.Dropdown(
-                                choices=speaker_names, value="None", label="Speaker B"
-                            )
-                            spk_b_w = gr.Slider(
-                                value=1, minimum=0, maximum=10, step=1, label="Weight B"
-                            )
-
-                        with gr.Group():
-                            spk_c = gr.Dropdown(
-                                choices=speaker_names, value="None", label="Speaker C"
-                            )
-                            spk_c_w = gr.Slider(
-                                value=1, minimum=0, maximum=10, step=1, label="Weight C"
-                            )
-
-                        with gr.Group():
-                            spk_d = gr.Dropdown(
-                                choices=speaker_names, value="None", label="Speaker D"
-                            )
-                            spk_d_w = gr.Slider(
-                                value=1, minimum=0, maximum=10, step=1, label="Weight D"
-                            )
-
-                    with gr.Row():
-                        with gr.Column(scale=3):
-                            with gr.Group():
-                                gr.Markdown("🎤Test voice")
-                                with gr.Row():
-                                    test_voice_btn = gr.Button(
-                                        "Test Voice", variant="secondary"
-                                    )
-
-                                    with gr.Column(scale=4):
-                                        test_text = gr.Textbox(
-                                            label="Test Text",
-                                            placeholder="Please input test text",
-                                            value="说话人合并测试 123456789 [uv_break] ok, test done [lbreak]",
-                                        )
-
-                                        output_audio = gr.Audio(label="Output Audio")
-
-                with gr.Column(scale=1):
-                    with gr.Group():
-                        gr.Markdown("🗃️Save to file")
-
-                        speaker_name = gr.Textbox(
-                            label="Name", value="forge_speaker_merged"
-                        )
-                        speaker_gender = gr.Textbox(label="Gender", value="*")
-                        speaker_desc = gr.Textbox(
-                            label="Description", value="merged speaker"
-                        )
-
-                        save_btn = gr.Button("Save Speaker", variant="primary")
-
-                        merged_spker = gr.File(
-                            label="Merged Speaker", interactive=False, type="binary"
-                        )
-
-            test_voice_btn.click(
-                merge_and_test_spk_voice,
-                inputs=[
-                    spk_a,
-                    spk_a_w,
-                    spk_b,
-                    spk_b_w,
-                    spk_c,
-                    spk_c_w,
-                    spk_d,
-                    spk_d_w,
-                    test_text,
-                ],
-                outputs=[output_audio],
-            )
-
-            save_btn.click(
-                merge_spk_to_file,
-                inputs=[
-                    spk_a,
-                    spk_a_w,
-                    spk_b,
-                    spk_b_w,
-                    spk_c,
-                    spk_c_w,
-                    spk_d,
-                    spk_d_w,
-                    speaker_name,
-                    speaker_gender,
-                    speaker_desc,
-                ],
-                outputs=[merged_spker],
-            )
+            create_speaker_merger()
diff --git a/modules/webui/tts_tab.py b/modules/webui/tts_tab.py
index 0c807d5e2fe6e514b44917c891db16c26557eaca..d51cb2eea4c646590c5b3f63b7ec266ade221a44 100644
--- a/modules/webui/tts_tab.py
+++ b/modules/webui/tts_tab.py
@@ -13,10 +13,7 @@ from modules import config
 
 default_text_content = """
 chat T T S 是一款强大的对话式文本转语音模型。它有中英混读和多说话人的能力。
-chat T T S 不仅能够生成自然流畅的语音,还能控制[laugh]笑声啊[laugh],
-停顿啊[uv_break]语气词啊等副语言现象[uv_break]。这个韵律超越了许多开源模型[uv_break]。
-请注意,chat T T S 的使用应遵守法律和伦理准则,避免滥用的安全风险。[uv_break]
-"""
+""".strip()
 
 
 def create_tts_interface():