jhtonyKoo committed
Commit ef49e48
1 Parent(s): 78bac9e

modify loss

__pycache__/inference.cpython-311.pyc CHANGED
Binary files a/__pycache__/inference.cpython-311.pyc and b/__pycache__/inference.cpython-311.pyc differ
 
inference.py CHANGED
@@ -68,7 +68,7 @@ class MasteringStyleTransfer:
         return output_audio, predicted_params
 
     def inference_time_optimization(self, input_tensor, reference_tensor, ito_config, initial_reference_feature):
-        fit_embedding = torch.nn.Parameter(initial_reference_feature)
+        fit_embedding = torch.nn.Parameter(initial_reference_feature, requires_grad=True)
         optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])
 
         min_loss = float('inf')
@@ -97,9 +97,7 @@ class MasteringStyleTransfer:
                 target = reference_tensor
             else:
                 target = ito_config['clap_text_prompt']
-            print(f'ito_config clap_distance_fn: {ito_config["clap_distance_fn"]}')
             total_loss = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
-            print(f'total_loss: {total_loss}')
 
             if total_loss < min_loss:
                 min_loss = total_loss.item()
@@ -122,6 +120,9 @@ class MasteringStyleTransfer:
             total_loss.backward()
             optimizer.step()
 
+            gc.collect()
+            torch.cuda.empty_cache()
+
         return all_results, min_loss_step
 
     def preprocess_audio(self, audio, target_sample_rate=44100, normalize=False):
@@ -290,7 +291,6 @@ class MasteringStyleTransfer:
 
         return "\n".join(output)
 
-
 def reload_weights(model, ckpt_path, device):
     checkpoint = torch.load(ckpt_path, map_location=device)
 
modules/__pycache__/loss.cpython-311.pyc CHANGED
Binary files a/modules/__pycache__/loss.cpython-311.pyc and b/modules/__pycache__/loss.cpython-311.pyc differ
 
modules/loss.py CHANGED
@@ -185,35 +185,26 @@ class CLAPFeatureLoss(nn.Module):
         self.target_sample_rate = 48000 # CLAP expects 48kHz audio
         self.model = laion_clap.CLAP_Module(enable_fusion=False)
         self.model.load_ckpt() # download the default pretrained checkpoint
-
-        # Freeze the CLAP model parameters
-        for param in self.model.parameters():
-            param.requires_grad = False
+        self.model.eval()
 
-    def forward(self, input_audio, target, sample_rate, distance_fn='mse'):
+    def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
         # Process input audio
-        with torch.no_grad():
-            input_audio = self.preprocess_audio(input_audio, sample_rate)
-
-        with torch.enable_grad():
-            input_embed = self.model.get_audio_embedding_from_data(x=input_audio, use_tensor=True)
+        input_embed = self.process_audio(input_audio, sample_rate)
 
         # Process target (audio or text)
-        with torch.no_grad():
-            if isinstance(target, torch.Tensor):
-                target_audio = self.preprocess_audio(target, sample_rate)
-                target_embed = self.model.get_audio_embedding_from_data(x=target_audio, use_tensor=True)
-            elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
-                target_embed = self.model.get_text_embedding(target, use_tensor=True)
-            else:
-                raise ValueError("Target must be either audio tensor or text (string or list of strings)")
+        if isinstance(target, torch.Tensor):
+            target_embed = self.process_audio(target, sample_rate)
+        elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
+            target_embed = self.process_text(target)
+        else:
+            raise ValueError("Target must be either audio tensor or text (string or list of strings)")
 
         # Compute loss using the specified distance function
         loss = self.compute_distance(input_embed, target_embed, distance_fn)
 
         return loss
 
-    def preprocess_audio(self, audio, sample_rate):
+    def process_audio(self, audio, sample_rate):
         # Ensure input is in the correct shape (N, C, T)
         if audio.dim() == 2:
             audio = audio.unsqueeze(1)
@@ -221,15 +212,22 @@ class CLAPFeatureLoss(nn.Module):
         # Convert to mono if stereo
         if audio.shape[1] > 1:
             audio = audio.mean(dim=1, keepdim=True)
-
         # Resample if necessary
         if sample_rate != self.target_sample_rate:
             audio = self.resample(audio, sample_rate)
-
-        # Quantize audio data
-        audio = self.quantize(audio)
-
-        return audio
+        audio = audio.squeeze(1)
+
+        # Get CLAP embeddings
+        embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
+        return embed
+
+    def process_text(self, text):
+        # Get CLAP embeddings for text
+        # ensure input is a list of strings
+        if not isinstance(text, list):
+            text = [text]
+        embed = self.model.get_text_embedding(text, use_tensor=True)
+        return embed
 
     def compute_distance(self, x, y, distance_fn):
         if distance_fn == 'mse':
@@ -241,86 +239,12 @@ class CLAPFeatureLoss(nn.Module):
         else:
             raise ValueError(f"Unsupported distance function: {distance_fn}")
 
-    def quantize(self, audio):
-        audio = audio.squeeze(1) # Remove channel dimension
-        audio = torch.clamp(audio, -1.0, 1.0)
-        audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
-        return audio
-
-    def resample(self, audio, orig_sample_rate):
+    def resample(self, audio, input_sample_rate):
         resampler = torchaudio.transforms.Resample(
-            orig_freq=orig_sample_rate, new_freq=self.target_sample_rate
+            orig_freq=input_sample_rate, new_freq=self.target_sample_rate
         ).to(audio.device)
         return resampler(audio)
-
-    # def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
-    #     # Process input audio
-    #     input_embed = self.process_audio(input_audio, sample_rate)
-
-    #     # Process target (audio or text)
-    #     if isinstance(target, torch.Tensor):
-    #         target_embed = self.process_audio(target, sample_rate)
-    #     elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
-    #         target_embed = self.process_text(target)
-    #     else:
-    #         raise ValueError("Target must be either audio tensor or text (string or list of strings)")
-
-    #     # Compute loss using the specified distance function
-    #     loss = self.compute_distance(input_embed, target_embed, distance_fn)
-
-    #     return loss
-
-    # def process_audio(self, audio, sample_rate):
-    #     # Ensure input is in the correct shape (N, C, T)
-    #     if audio.dim() == 2:
-    #         audio = audio.unsqueeze(1)
-
-    #     # Convert to mono if stereo
-    #     if audio.shape[1] > 1:
-    #         audio = audio.mean(dim=1, keepdim=True)
-
-    #     # Resample if necessary
-    #     if sample_rate != self.target_sample_rate:
-    #         audio = self.resample(audio, sample_rate)
-
-    #     # Quantize audio data
-    #     audio = self.quantize(audio)
-
-    #     # Get CLAP embeddings
-    #     with torch.no_grad():
-    #         embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
-    #     return embed
-
-    # def process_text(self, text):
-    #     # Get CLAP embeddings for text
-    #     # ensure input is a list of strings
-    #     if not isinstance(text, list):
-    #         text = [text]
-    #     with torch.no_grad():
-    #         embed = self.model.get_text_embedding(text, use_tensor=True)
-    #     return embed
-
-    # def compute_distance(self, x, y, distance_fn):
-    #     if distance_fn == 'mse':
-    #         return F.mse_loss(x, y)
-    #     elif distance_fn == 'l1':
-    #         return F.l1_loss(x, y)
-    #     elif distance_fn == 'cosine':
-    #         return 1 - F.cosine_similarity(x, y).mean()
-    #     else:
-    #         raise ValueError(f"Unsupported distance function: {distance_fn}")
-
-    # def quantize(self, audio):
-    #     audio = audio.squeeze(1) # Remove channel dimension
-    #     audio = torch.clamp(audio, -1.0, 1.0)
-    #     audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
-    #     return audio
-
-    # def resample(self, audio, input_sample_rate):
-    #     resampler = torchaudio.transforms.Resample(
-    #         orig_freq=input_sample_rate, new_freq=self.target_sample_rate
-    #     ).to(audio.device)
-    #     return resampler(audio)
+
 
 
 """