modify app
Files changed:
- app.py (+5 -5)
- modules/loss.py (+62 -0)
app.py
CHANGED
@@ -158,21 +158,21 @@ with gr.Blocks() as demo:
     with gr.Row():
         gr.Markdown("Interactive demo of Inference Time Optimization (ITO) for Music Mastering Style Transfer. \
                     The mastering style transfer is performed by a differentiable audio processing model, and the predicted parameters are shown as the output. \
-                    Perform mastering style transfer with an input source audio and a reference mastering style audio. On top of this result, you can perform ITO to optimize the reference embedding to further gain control over the output mastering style.")
+                    Perform mastering style transfer with an input source audio and a reference mastering style audio. On top of this result, you can perform ITO to optimize the reference embedding $z_{ref}$ to further gain control over the output mastering style.")
         gr.Image("ito_snow.png", width=300)
 
     gr.Markdown("## Step 1: Mastering Style Transfer")
 
     with gr.Tab("Upload Audio"):
         with gr.Row():
-            input_audio = gr.Audio(label="Source Audio")
-            reference_audio = gr.Audio(label="Reference Style Audio")
+            input_audio = gr.Audio(label="Source Audio $x_{in}$")
+            reference_audio = gr.Audio(label="Reference Style Audio $x_{ref}$")
 
         process_button = gr.Button("Process Mastering Style Transfer")
 
         with gr.Row():
             with gr.Column():
-                output_audio = gr.Audio(label="Output Audio", type='numpy')
+                output_audio = gr.Audio(label="Output Audio y'", type='numpy')
                 normalized_input = gr.Audio(label="Normalized Source Audio", type='numpy')
                 param_output = gr.Textbox(label="Predicted Parameters", lines=5)
 
@@ -213,7 +213,7 @@ with gr.Blocks() as demo:
     gr.Markdown("## Step 2: Inference Time Optimization (ITO)")
 
     with gr.Row():
-        ito_reference_audio = gr.Audio(label="ITO Reference Style Audio (optional)")
+        ito_reference_audio = gr.Audio(label="ITO Reference Style Audio $x'_{ref}$ (optional)")
         with gr.Column():
             num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps")
             optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
modules/loss.py
CHANGED
@@ -176,6 +176,68 @@ class Loss:
         )
 
 
+import laion_clap
+import torchaudio
+# CLAP feature loss
+class CLAPFeatureLoss(nn.Module):
+    def __init__(self, distance_fn='mse'):
+        super(CLAPFeatureLoss, self).__init__()
+        self.target_sample_rate = 48000  # CLAP expects 48kHz audio
+        self.model = laion_clap.CLAP_Module(enable_fusion=False)
+        self.model.load_ckpt()  # download the default pretrained checkpoint
+
+        self.distance_fn = distance_fn
+        if distance_fn == 'mse':
+            self.compute_distance = F.mse_loss
+        elif distance_fn == 'l1':
+            self.compute_distance = F.l1_loss
+        elif distance_fn == 'cosine':
+            self.compute_distance = lambda x, y: 1 - F.cosine_similarity(x, y).mean()
+        else:
+            raise ValueError(f"Unsupported distance function: {distance_fn}")
+
+    def forward(self, input_audio, target_audio, sample_rate):
+        # Ensure input is in the correct shape (N, C, T)
+        if input_audio.dim() == 2:
+            input_audio = input_audio.unsqueeze(1)
+        if target_audio.dim() == 2:
+            target_audio = target_audio.unsqueeze(1)
+
+        # Convert to mono if stereo
+        if input_audio.shape[1] > 1:
+            input_audio = input_audio.mean(dim=1, keepdim=True)
+        if target_audio.shape[1] > 1:
+            target_audio = target_audio.mean(dim=1, keepdim=True)
+
+        # Resample if necessary
+        if sample_rate != self.target_sample_rate:
+            input_audio = self.resample(input_audio, sample_rate)
+            target_audio = self.resample(target_audio, sample_rate)
+
+        # Quantize audio data
+        input_audio = self.quantize(input_audio)
+        target_audio = self.quantize(target_audio)
+
+        # Get CLAP embeddings
+        input_embed = self.model.get_audio_embedding_from_data(x=input_audio, use_tensor=True)
+        target_embed = self.model.get_audio_embedding_from_data(x=target_audio, use_tensor=True)
+
+        # Compute loss using the specified distance function
+        loss = self.compute_distance(input_embed, target_embed)
+
+        return loss
+
+    def quantize(self, audio):
+        audio = audio.squeeze(1)  # Remove channel dimension
+        audio = torch.clamp(audio, -1.0, 1.0)
+        audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
+        return audio
+
+    def resample(self, audio, sample_rate):
+        resampler = torchaudio.transforms.Resample(
+            orig_freq=sample_rate, new_freq=self.target_sample_rate
+        ).to(audio.device)
+        return resampler(audio)
 
 
 """
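
For context, here is a minimal sketch of how `CLAPFeatureLoss` could drive the ITO loop suggested by the app.py changes. `mastering_model`, `z_ref`, `input_audio`, `ito_ref_audio`, and `num_steps` are hypothetical stand-ins, not names from this commit. One caveat: the int16 round-trip inside `quantize` is not differentiable, so a practical gradient-based loop would need to bypass that cast or use a straight-through estimator.

```python
import torch

# A minimal usage sketch, assuming a differentiable mastering model and a
# learnable reference embedding z_ref; none of these names are in the commit.
clap_loss = CLAPFeatureLoss(distance_fn='cosine')

z_ref = z_ref.detach().clone().requires_grad_(True)  # embedding to optimize
opt = torch.optim.RAdam([z_ref], lr=1e-3)            # "RAdam" is the UI default

for step in range(num_steps):                        # "Number of Steps" slider value
    opt.zero_grad()
    output = mastering_model(input_audio, z_ref)     # hypothetical processing chain
    # Caveat: the int16 cast in quantize() blocks gradients; skip or replace
    # it with a straight-through estimator when backpropagating into z_ref.
    loss = clap_loss(output, ito_ref_audio, sample_rate=44100)
    loss.backward()
    opt.step()
```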