jhtonyKoo commited on
Commit
d09ad44
·
1 Parent(s): c9034be

modify app

Browse files
Files changed (2) hide show
  1. app.py +5 -5
  2. modules/loss.py +62 -0
app.py CHANGED
@@ -158,21 +158,21 @@ with gr.Blocks() as demo:
158
  with gr.Row():
159
  gr.Markdown("Interactive demo of Inference Time Optimization (ITO) for Music Mastering Style Transfer. \
160
  The mastering style transfer is performed by a differentiable audio processing model, and the predicted parameters are shown as the output. \
161
- Perform mastering style transfer with an input source audio and a reference mastering style audio. On top of this result, you can perform ITO to optimize the reference embedding z~ref~ $z_{ref}$ to further gain control over the output mastering style.")
162
  gr.Image("ito_snow.png", width=300)
163
 
164
  gr.Markdown("## Step 1: Mastering Style Transfer")
165
 
166
  with gr.Tab("Upload Audio"):
167
  with gr.Row():
168
- input_audio = gr.Audio(label="Source Audio (x~in~ $x_{in}$)")
169
- reference_audio = gr.Audio(label="Reference Style Audio (x~ref~ $x_{ref}$)")
170
 
171
  process_button = gr.Button("Process Mastering Style Transfer")
172
 
173
  with gr.Row():
174
  with gr.Column():
175
- output_audio = gr.Audio(label="Output Audio (y')", type='numpy')
176
  normalized_input = gr.Audio(label="Normalized Source Audio", type='numpy')
177
  param_output = gr.Textbox(label="Predicted Parameters", lines=5)
178
 
@@ -213,7 +213,7 @@ with gr.Blocks() as demo:
213
  gr.Markdown("## Step 2: Inference Time Optimization (ITO)")
214
 
215
  with gr.Row():
216
- ito_reference_audio = gr.Audio(label="ITO Reference Style Audio (optional)")
217
  with gr.Column():
218
  num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps")
219
  optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
 
158
  with gr.Row():
159
  gr.Markdown("Interactive demo of Inference Time Optimization (ITO) for Music Mastering Style Transfer. \
160
  The mastering style transfer is performed by a differentiable audio processing model, and the predicted parameters are shown as the output. \
161
+ Perform mastering style transfer with an input source audio and a reference mastering style audio. On top of this result, you can perform ITO to optimize the reference embedding $z_{ref}$ to further gain control over the output mastering style.")
162
  gr.Image("ito_snow.png", width=300)
163
 
164
  gr.Markdown("## Step 1: Mastering Style Transfer")
165
 
166
  with gr.Tab("Upload Audio"):
167
  with gr.Row():
168
+ input_audio = gr.Audio(label="Source Audio $x_{in}$")
169
+ reference_audio = gr.Audio(label="Reference Style Audio $x_{ref}$")
170
 
171
  process_button = gr.Button("Process Mastering Style Transfer")
172
 
173
  with gr.Row():
174
  with gr.Column():
175
+ output_audio = gr.Audio(label="Output Audio y'", type='numpy')
176
  normalized_input = gr.Audio(label="Normalized Source Audio", type='numpy')
177
  param_output = gr.Textbox(label="Predicted Parameters", lines=5)
178
 
 
213
  gr.Markdown("## Step 2: Inference Time Optimization (ITO)")
214
 
215
  with gr.Row():
216
+ ito_reference_audio = gr.Audio(label="ITO Reference Style Audio $x'_{ref}$ (optional)")
217
  with gr.Column():
218
  num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps")
219
  optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
modules/loss.py CHANGED
@@ -176,6 +176,68 @@ class Loss:
176
  )
177
 
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
 
181
  """
 
176
  )
177
 
178
 
179
+ import laion_clap
180
+ import torchaudio
181
+ # CLAP feature loss
182
+ class CLAPFeatureLoss(nn.Module):
183
+ def __init__(self, distance_fn='mse'):
184
+ super(CLAPFeatureLoss, self).__init__()
185
+ self.target_sample_rate = 48000 # CLAP expects 48kHz audio
186
+ self.model = laion_clap.CLAP_Module(enable_fusion=False)
187
+ self.model.load_ckpt() # download the default pretrained checkpoint
188
+
189
+ self.distance_fn = distance_fn
190
+ if distance_fn == 'mse':
191
+ self.compute_distance = F.mse_loss
192
+ elif distance_fn == 'l1':
193
+ self.compute_distance = F.l1_loss
194
+ elif distance_fn == 'cosine':
195
+ self.compute_distance = lambda x, y: 1 - F.cosine_similarity(x, y).mean()
196
+ else:
197
+ raise ValueError(f"Unsupported distance function: {distance_fn}")
198
+
199
+ def forward(self, input_audio, target_audio, sample_rate):
200
+ # Ensure input is in the correct shape (N, C, T)
201
+ if input_audio.dim() == 2:
202
+ input_audio = input_audio.unsqueeze(1)
203
+ if target_audio.dim() == 2:
204
+ target_audio = target_audio.unsqueeze(1)
205
+
206
+ # Convert to mono if stereo
207
+ if input_audio.shape[1] > 1:
208
+ input_audio = input_audio.mean(dim=1, keepdim=True)
209
+ if target_audio.shape[1] > 1:
210
+ target_audio = target_audio.mean(dim=1, keepdim=True)
211
+
212
+ # Resample if necessary
213
+ if sample_rate != self.target_sample_rate:
214
+ input_audio = self.resample(input_audio, sample_rate)
215
+ target_audio = self.resample(target_audio, sample_rate)
216
+
217
+ # Quantize audio data
218
+ input_audio = self.quantize(input_audio)
219
+ target_audio = self.quantize(target_audio)
220
+
221
+ # Get CLAP embeddings
222
+ input_embed = self.model.get_audio_embedding_from_data(x=input_audio, use_tensor=True)
223
+ target_embed = self.model.get_audio_embedding_from_data(x=target_audio, use_tensor=True)
224
+
225
+ # Compute loss using the specified distance function
226
+ loss = self.compute_distance(input_embed, target_embed)
227
+
228
+ return loss
229
+
230
+ def quantize(self, audio):
231
+ audio = audio.squeeze(1) # Remove channel dimension
232
+ audio = torch.clamp(audio, -1.0, 1.0)
233
+ audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
234
+ return audio
235
+
236
+ def resample(self, audio, sample_rate):
237
+ resampler = torchaudio.transforms.Resample(
238
+ orig_freq=sample_rate, new_freq=self.target_sample_rate
239
+ ).to(audio.device)
240
+ return resampler(audio)
241
 
242
 
243
  """