metrics:
https://mvsep.com/quality_checker/entry/8921
Metric sdr for instrum: 17.5961
Metric si_sdr for instrum: 17.5021
Metric l1_freq for instrum: 40.2225
Metric log_wmse for instrum: 14.3112
Metric aura_stft for instrum: 15.7642
Metric aura_mrstft for instrum: 20.1560
Metric bleedless for instrum: 42.8670
Metric fullness for instrum: 32.0372
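For reference, the SDR and SI-SDR figures above can be sketched with their standard definitions. SDR compares the energy of the reference stem with the energy of the residual. SI-SDR first rescales the reference by the least-squares gain, so the estimate's overall level does not affect the score. This is a minimal NumPy sketch under those textbook definitions. MVSEP's checker may differ in details such as epsilon handling or per-channel averaging, so it will not necessarily reproduce the numbers above exactly.

```python
import numpy as np


def sdr(reference: np.ndarray, estimate: np.ndarray, eps: float = 1e-8) -> float:
    """Signal-to-distortion ratio in dB: reference energy over residual energy."""
    num = np.sum(reference ** 2)
    den = np.sum((reference - estimate) ** 2)
    return float(10 * np.log10((num + eps) / (den + eps)))


def si_sdr(reference: np.ndarray, estimate: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SDR in dB: optimally rescale the reference before comparing."""
    reference = reference.ravel()
    estimate = estimate.ravel()
    # Least-squares gain that best fits the reference to the estimate
    alpha = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = alpha * reference
    noise = estimate - target
    return float(10 * np.log10((np.sum(target ** 2) + eps) / (np.sum(noise ** 2) + eps)))
```

Doubling the estimate leaves `si_sdr` unchanged but lowers `sdr`, which is the point of the scale-invariant variant.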

In this model, the mask estimator has been replaced with FNO1d, a one-dimensional Fourier Neural Operator from the neuraloperator library.
To run this model, you currently need to use the Music Source Separation Training repository and modify parts of its code.
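FNO1d replaces the MLP of the original mask estimator with Fourier layers. The core of such a layer is a spectral convolution. It takes the FFT along the time axis and keeps only the lowest modes, which is what `n_modes_height=64` sets in the FNO1d call in the code below. It then mixes channels with learned complex weights per mode and transforms back. The NumPy sketch shows a single spectral convolution only. neuraloperator's FNO1d also adds lifting and projection layers, a pointwise skip path, and nonlinearities. `spectral_conv1d` is a hypothetical name, not part of that library.

```python
import numpy as np


def spectral_conv1d(x: np.ndarray, weights: np.ndarray, n_modes: int) -> np.ndarray:
    """One simplified Fourier layer (hypothetical helper, not neuralop's implementation).

    x:       (channels_in, time) real signal
    weights: (channels_in, channels_out, >= n_modes) complex per-mode mixing weights
    """
    _, t = x.shape
    x_ft = np.fft.rfft(x, axis=-1)                      # (c_in, t // 2 + 1)
    n_modes = min(n_modes, x_ft.shape[-1])              # cannot keep more modes than exist
    out_ft = np.zeros((weights.shape[1], x_ft.shape[-1]), dtype=complex)
    # Channel mixing per frequency, applied only to the lowest n_modes; higher modes are dropped
    out_ft[:, :n_modes] = np.einsum("im,iom->om", x_ft[:, :n_modes], weights[:, :, :n_modes])
    return np.fft.irfft(out_ft, n=t, axis=-1)           # back to (c_out, time)
```

With identity weights, a signal whose frequencies all lie below `n_modes` passes through unchanged, while components above the cutoff are removed. This truncation is what makes the layer resolution-independent and cheap.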

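Replace the MaskEstimator class in models/bs_roformer/bs_roformer.py with the following. Apart from FNO1d, it relies on names already imported at the top of bs_roformer.py (torch, nn, Module, ModuleList, beartype, Tuple, rearrange).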


from neuralop.models import FNO1d
class MaskEstimator(Module):
    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...],
            depth,
            mlp_expansion_factor=4
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            net = []

            mlp = nn.Sequential(
                FNO1d(n_modes_height=64, hidden_channels=dim, in_channels=dim, out_channels=dim_in*2, lifting_channels=dim, projection_channels=dim, n_layers=3, separable=True),
                nn.GLU(dim=-2)
            )

            self.to_freqs.append(mlp)

    def forward(self, x):
        x = x.unbind(dim=-2)

        outs = []

        for band_features, mlp in zip(x, self.to_freqs):
            band_features = rearrange(band_features, 'b t c -> b c t')
            with torch.autocast(device_type='cuda', enabled=False, dtype=torch.float32):
                freq_out = mlp(band_features).float()
            freq_out = rearrange(freq_out, 'b c t -> b t c')
            outs.append(freq_out)

        return torch.cat(outs, dim=-1)
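To make the tensor shapes in forward easier to follow, here is a NumPy sketch of the same data flow. It splits the band axis and moves channels first. It then applies a per-band network that doubles the channel count, halves it again with a GLU over the channel axis, and concatenates the per-band masks. The per-band FNO1d is replaced by a hypothetical random linear channel mix (`make_linear_stand_in`), so this only checks shapes and is not the model itself.

```python
import numpy as np


def glu(x: np.ndarray, axis: int) -> np.ndarray:
    """Gated linear unit, as nn.GLU: first half * sigmoid(second half) along `axis`."""
    a, b = np.split(x, 2, axis=axis)
    return a * (1.0 / (1.0 + np.exp(-b)))


def make_linear_stand_in(dim: int, dim_in: int, rng: np.random.Generator):
    """Hypothetical stand-in for FNO1d: (b, dim, t) -> (b, 2 * dim_in, t)."""
    w = rng.standard_normal((2 * dim_in, dim))
    return lambda z: np.einsum("oc,bct->bot", w, z)


def mask_estimator_forward(x: np.ndarray, dim_inputs, band_nets) -> np.ndarray:
    """x: (batch, time, bands, dim) -> (batch, time, sum(dim_inputs))."""
    outs = []
    for i, net in enumerate(band_nets):
        band = x[:, :, i, :]                       # unbind(dim=-2): (b, t, c)
        band = np.transpose(band, (0, 2, 1))       # 'b t c -> b c t'
        out = glu(net(band), axis=-2)              # GLU over channels: (b, dim_in, t)
        assert out.shape[-2] == dim_inputs[i]
        outs.append(np.transpose(out, (0, 2, 1)))  # 'b c t -> b t c'
    return np.concatenate(outs, axis=-1)
```

With `dim=8` and two bands of widths 4 and 6, an input of shape (2, 10, 2, 8) yields a mask of shape (2, 10, 10).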


In addition, install the neuraloperator library:

```shell
pip install neuraloperator
```

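Also, PyTorch 2.6 changed the default of torch.load to weights_only=True, so with PyTorch 2.6 or later, loading this checkpoint through torch.load raises an error.
If that happens, wrap the torch.load call in utils/model_utils.py (around line 531) in the following context manager, indenting the torch.load line beneath it:

```python
with torch.serialization.safe_globals([torch._C._nn.gelu]):
```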
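Because the weights_only default only changed in PyTorch 2.6, the workaround is needed only from that version on. If the patched code must also run on older PyTorch installs, one option is to apply the context manager conditionally on the version string. `needs_safe_globals` is a hypothetical helper. It is a sketch assuming `torch.__version__`-style strings such as `2.6.0+cu124`.

```python
import re


def needs_safe_globals(torch_version: str) -> bool:
    """Hypothetical helper: True if torch.load defaults to weights_only=True (PyTorch 2.6+)."""
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        raise ValueError(f"unrecognized version string: {torch_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    # Compare (major, minor) as a tuple so that e.g. 2.10 sorts after 2.6
    return (major, minor) >= (2, 6)
```

For example, `needs_safe_globals(torch.__version__)` returns True on PyTorch 2.6 and later. Comparing integer tuples rather than raw strings avoids treating "2.10" as older than "2.6".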

Errors may also occur in load_state_dict. In that case, pass strict=False to it (utils/model_utils.py, around line 532). Note that strict=False silently skips missing and unexpected keys instead of raising an error.
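Because strict=False suppresses the error rather than fixing it, it is worth checking which keys were skipped. PyTorch's `load_state_dict(..., strict=False)` returns a named tuple with `missing_keys` and `unexpected_keys`, which you can print directly. The hypothetical helper below computes the same two lists from plain key collections, for example `model.state_dict().keys()` and the loaded checkpoint's keys. It is a sketch, not part of the repository.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Hypothetical helper: return (missing, unexpected), the lists strict=False stops raising on."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # expected by the model, absent from the checkpoint
    unexpected = sorted(checkpoint_keys - model_keys)   # present in the checkpoint, unknown to the model
    return missing, unexpected
```

If `missing` contains any of the mask estimator's FNO1d weights, those layers keep their random initialization and the separation output will be wrong.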