Spaces:
Running
Running
modify loss
Browse files- __pycache__/inference.cpython-311.pyc +0 -0
- inference.py +4 -4
- modules/__pycache__/loss.cpython-311.pyc +0 -0
- modules/loss.py +26 -102
__pycache__/inference.cpython-311.pyc
CHANGED
Binary files a/__pycache__/inference.cpython-311.pyc and b/__pycache__/inference.cpython-311.pyc differ
|
|
inference.py
CHANGED
@@ -68,7 +68,7 @@ class MasteringStyleTransfer:
|
|
68 |
return output_audio, predicted_params
|
69 |
|
70 |
def inference_time_optimization(self, input_tensor, reference_tensor, ito_config, initial_reference_feature):
|
71 |
-
fit_embedding = torch.nn.Parameter(initial_reference_feature)
|
72 |
optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])
|
73 |
|
74 |
min_loss = float('inf')
|
@@ -97,9 +97,7 @@ class MasteringStyleTransfer:
|
|
97 |
target = reference_tensor
|
98 |
else:
|
99 |
target = ito_config['clap_text_prompt']
|
100 |
-
print(f'ito_config clap_distance_fn: {ito_config["clap_distance_fn"]}')
|
101 |
total_loss = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
|
102 |
-
print(f'total_loss: {total_loss}')
|
103 |
|
104 |
if total_loss < min_loss:
|
105 |
min_loss = total_loss.item()
|
@@ -122,6 +120,9 @@ class MasteringStyleTransfer:
|
|
122 |
total_loss.backward()
|
123 |
optimizer.step()
|
124 |
|
|
|
|
|
|
|
125 |
return all_results, min_loss_step
|
126 |
|
127 |
def preprocess_audio(self, audio, target_sample_rate=44100, normalize=False):
|
@@ -290,7 +291,6 @@ class MasteringStyleTransfer:
|
|
290 |
|
291 |
return "\n".join(output)
|
292 |
|
293 |
-
|
294 |
def reload_weights(model, ckpt_path, device):
|
295 |
checkpoint = torch.load(ckpt_path, map_location=device)
|
296 |
|
|
|
68 |
return output_audio, predicted_params
|
69 |
|
70 |
def inference_time_optimization(self, input_tensor, reference_tensor, ito_config, initial_reference_feature):
|
71 |
+
fit_embedding = torch.nn.Parameter(initial_reference_feature, requires_grad=True)
|
72 |
optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])
|
73 |
|
74 |
min_loss = float('inf')
|
|
|
97 |
target = reference_tensor
|
98 |
else:
|
99 |
target = ito_config['clap_text_prompt']
|
|
|
100 |
total_loss = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
|
|
|
101 |
|
102 |
if total_loss < min_loss:
|
103 |
min_loss = total_loss.item()
|
|
|
120 |
total_loss.backward()
|
121 |
optimizer.step()
|
122 |
|
123 |
+
gc.collect()
|
124 |
+
torch.cuda.empty_cache()
|
125 |
+
|
126 |
return all_results, min_loss_step
|
127 |
|
128 |
def preprocess_audio(self, audio, target_sample_rate=44100, normalize=False):
|
|
|
291 |
|
292 |
return "\n".join(output)
|
293 |
|
|
|
294 |
def reload_weights(model, ckpt_path, device):
|
295 |
checkpoint = torch.load(ckpt_path, map_location=device)
|
296 |
|
modules/__pycache__/loss.cpython-311.pyc
CHANGED
Binary files a/modules/__pycache__/loss.cpython-311.pyc and b/modules/__pycache__/loss.cpython-311.pyc differ
|
|
modules/loss.py
CHANGED
@@ -185,35 +185,26 @@ class CLAPFeatureLoss(nn.Module):
|
|
185 |
self.target_sample_rate = 48000 # CLAP expects 48kHz audio
|
186 |
self.model = laion_clap.CLAP_Module(enable_fusion=False)
|
187 |
self.model.load_ckpt() # download the default pretrained checkpoint
|
188 |
-
|
189 |
-
# Freeze the CLAP model parameters
|
190 |
-
for param in self.model.parameters():
|
191 |
-
param.requires_grad = False
|
192 |
|
193 |
-
def forward(self, input_audio, target, sample_rate, distance_fn='
|
194 |
# Process input audio
|
195 |
-
|
196 |
-
input_audio = self.preprocess_audio(input_audio, sample_rate)
|
197 |
-
|
198 |
-
with torch.enable_grad():
|
199 |
-
input_embed = self.model.get_audio_embedding_from_data(x=input_audio, use_tensor=True)
|
200 |
|
201 |
# Process target (audio or text)
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
else:
|
209 |
-
raise ValueError("Target must be either audio tensor or text (string or list of strings)")
|
210 |
|
211 |
# Compute loss using the specified distance function
|
212 |
loss = self.compute_distance(input_embed, target_embed, distance_fn)
|
213 |
|
214 |
return loss
|
215 |
|
216 |
-
def
|
217 |
# Ensure input is in the correct shape (N, C, T)
|
218 |
if audio.dim() == 2:
|
219 |
audio = audio.unsqueeze(1)
|
@@ -221,15 +212,22 @@ class CLAPFeatureLoss(nn.Module):
|
|
221 |
# Convert to mono if stereo
|
222 |
if audio.shape[1] > 1:
|
223 |
audio = audio.mean(dim=1, keepdim=True)
|
224 |
-
|
225 |
# Resample if necessary
|
226 |
if sample_rate != self.target_sample_rate:
|
227 |
audio = self.resample(audio, sample_rate)
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
def compute_distance(self, x, y, distance_fn):
|
235 |
if distance_fn == 'mse':
|
@@ -241,86 +239,12 @@ class CLAPFeatureLoss(nn.Module):
|
|
241 |
else:
|
242 |
raise ValueError(f"Unsupported distance function: {distance_fn}")
|
243 |
|
244 |
-
def
|
245 |
-
audio = audio.squeeze(1) # Remove channel dimension
|
246 |
-
audio = torch.clamp(audio, -1.0, 1.0)
|
247 |
-
audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
|
248 |
-
return audio
|
249 |
-
|
250 |
-
def resample(self, audio, orig_sample_rate):
|
251 |
resampler = torchaudio.transforms.Resample(
|
252 |
-
orig_freq=
|
253 |
).to(audio.device)
|
254 |
return resampler(audio)
|
255 |
-
|
256 |
-
# def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
|
257 |
-
# # Process input audio
|
258 |
-
# input_embed = self.process_audio(input_audio, sample_rate)
|
259 |
-
|
260 |
-
# # Process target (audio or text)
|
261 |
-
# if isinstance(target, torch.Tensor):
|
262 |
-
# target_embed = self.process_audio(target, sample_rate)
|
263 |
-
# elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
|
264 |
-
# target_embed = self.process_text(target)
|
265 |
-
# else:
|
266 |
-
# raise ValueError("Target must be either audio tensor or text (string or list of strings)")
|
267 |
-
|
268 |
-
# # Compute loss using the specified distance function
|
269 |
-
# loss = self.compute_distance(input_embed, target_embed, distance_fn)
|
270 |
-
|
271 |
-
# return loss
|
272 |
-
|
273 |
-
# def process_audio(self, audio, sample_rate):
|
274 |
-
# # Ensure input is in the correct shape (N, C, T)
|
275 |
-
# if audio.dim() == 2:
|
276 |
-
# audio = audio.unsqueeze(1)
|
277 |
-
|
278 |
-
# # Convert to mono if stereo
|
279 |
-
# if audio.shape[1] > 1:
|
280 |
-
# audio = audio.mean(dim=1, keepdim=True)
|
281 |
-
|
282 |
-
# # Resample if necessary
|
283 |
-
# if sample_rate != self.target_sample_rate:
|
284 |
-
# audio = self.resample(audio, sample_rate)
|
285 |
-
|
286 |
-
# # Quantize audio data
|
287 |
-
# audio = self.quantize(audio)
|
288 |
-
|
289 |
-
# # Get CLAP embeddings
|
290 |
-
# with torch.no_grad():
|
291 |
-
# embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
|
292 |
-
# return embed
|
293 |
-
|
294 |
-
# def process_text(self, text):
|
295 |
-
# # Get CLAP embeddings for text
|
296 |
-
# # ensure input is a list of strings
|
297 |
-
# if not isinstance(text, list):
|
298 |
-
# text = [text]
|
299 |
-
# with torch.no_grad():
|
300 |
-
# embed = self.model.get_text_embedding(text, use_tensor=True)
|
301 |
-
# return embed
|
302 |
-
|
303 |
-
# def compute_distance(self, x, y, distance_fn):
|
304 |
-
# if distance_fn == 'mse':
|
305 |
-
# return F.mse_loss(x, y)
|
306 |
-
# elif distance_fn == 'l1':
|
307 |
-
# return F.l1_loss(x, y)
|
308 |
-
# elif distance_fn == 'cosine':
|
309 |
-
# return 1 - F.cosine_similarity(x, y).mean()
|
310 |
-
# else:
|
311 |
-
# raise ValueError(f"Unsupported distance function: {distance_fn}")
|
312 |
-
|
313 |
-
# def quantize(self, audio):
|
314 |
-
# audio = audio.squeeze(1) # Remove channel dimension
|
315 |
-
# audio = torch.clamp(audio, -1.0, 1.0)
|
316 |
-
# audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
|
317 |
-
# return audio
|
318 |
-
|
319 |
-
# def resample(self, audio, input_sample_rate):
|
320 |
-
# resampler = torchaudio.transforms.Resample(
|
321 |
-
# orig_freq=input_sample_rate, new_freq=self.target_sample_rate
|
322 |
-
# ).to(audio.device)
|
323 |
-
# return resampler(audio)
|
324 |
|
325 |
|
326 |
"""
|
|
|
185 |
self.target_sample_rate = 48000 # CLAP expects 48kHz audio
|
186 |
self.model = laion_clap.CLAP_Module(enable_fusion=False)
|
187 |
self.model.load_ckpt() # download the default pretrained checkpoint
|
188 |
+
self.model.eval()
|
|
|
|
|
|
|
189 |
|
190 |
+
def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
|
191 |
# Process input audio
|
192 |
+
input_embed = self.process_audio(input_audio, sample_rate)
|
|
|
|
|
|
|
|
|
193 |
|
194 |
# Process target (audio or text)
|
195 |
+
if isinstance(target, torch.Tensor):
|
196 |
+
target_embed = self.process_audio(target, sample_rate)
|
197 |
+
elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
|
198 |
+
target_embed = self.process_text(target)
|
199 |
+
else:
|
200 |
+
raise ValueError("Target must be either audio tensor or text (string or list of strings)")
|
|
|
|
|
201 |
|
202 |
# Compute loss using the specified distance function
|
203 |
loss = self.compute_distance(input_embed, target_embed, distance_fn)
|
204 |
|
205 |
return loss
|
206 |
|
207 |
+
def process_audio(self, audio, sample_rate):
|
208 |
# Ensure input is in the correct shape (N, C, T)
|
209 |
if audio.dim() == 2:
|
210 |
audio = audio.unsqueeze(1)
|
|
|
212 |
# Convert to mono if stereo
|
213 |
if audio.shape[1] > 1:
|
214 |
audio = audio.mean(dim=1, keepdim=True)
|
|
|
215 |
# Resample if necessary
|
216 |
if sample_rate != self.target_sample_rate:
|
217 |
audio = self.resample(audio, sample_rate)
|
218 |
+
audio = audio.squeeze(1)
|
219 |
+
|
220 |
+
# Get CLAP embeddings
|
221 |
+
embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
|
222 |
+
return embed
|
223 |
+
|
224 |
+
def process_text(self, text):
|
225 |
+
# Get CLAP embeddings for text
|
226 |
+
# ensure input is a list of strings
|
227 |
+
if not isinstance(text, list):
|
228 |
+
text = [text]
|
229 |
+
embed = self.model.get_text_embedding(text, use_tensor=True)
|
230 |
+
return embed
|
231 |
|
232 |
def compute_distance(self, x, y, distance_fn):
|
233 |
if distance_fn == 'mse':
|
|
|
239 |
else:
|
240 |
raise ValueError(f"Unsupported distance function: {distance_fn}")
|
241 |
|
242 |
+
def resample(self, audio, input_sample_rate):
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
resampler = torchaudio.transforms.Resample(
|
244 |
+
orig_freq=input_sample_rate, new_freq=self.target_sample_rate
|
245 |
).to(audio.device)
|
246 |
return resampler(audio)
|
247 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
|
249 |
|
250 |
"""
|