import torch
import numpy as np
import torch.multiprocessing as mp

from deepafx_st.processors.dsp.peq import ParametricEQ
from deepafx_st.processors.dsp.compressor import Compressor
from deepafx_st.processors.spsa.spsa_func import SPSAFunction
from deepafx_st.utils import rademacher

def dsp_func(x, p, dsp, sample_rate=24000):
    """Apply the serial EQ -> compressor chain.

    `dsp` is a tuple of ((peq, comp), num_peq_params), so the flat
    parameter vector `p` can be split between the two processors.
    """
    (peq, comp), meta = dsp
    p_peq = p[:meta]  # first `meta` values drive the parametric EQ
    p_comp = p[meta:]  # remaining values drive the compressor
    y = peq(x, p_peq, sample_rate)
    y = comp(y, p_comp, sample_rate)
    return y
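
# Illustrative sketch (not from the original source): how dsp_func is
# typically driven. The shapes and NumPy dtype are assumptions based on
# how static_forward/static_backward call it below.
#
#   peq = ParametricEQ(24000)
#   comp = Compressor(24000)
#   dsp = ((peq, comp), peq.num_control_params)
#   x = np.random.randn(24000).astype(np.float32)  # mono audio
#   p = np.random.rand(peq.num_control_params
#                      + comp.num_control_params)  # normalized params in [0, 1]
#   y = dsp_func(x, p, dsp)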

class SPSAChannel(torch.nn.Module):
    """Differentiable EQ + compressor channel with SPSA gradient estimation.

    Args:
        sample_rate (float): Sample rate of the plugin instance.
        parallel (bool, optional): Use parallel workers for the DSP.
            When enabled, one worker process is spawned per batch item,
            so the number of workers equals the batch size.
        batch_size (int, optional): Number of parallel workers to spawn.
    """

    def __init__(
        self,
        sample_rate: int,
        parallel: bool = False,
        batch_size: int = 8,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.parallel = parallel

        if self.parallel:
            self.apply_func = SPSAFunction.apply

            # spawn one worker process (with its own DSP chain) per batch item
            procs = {}
            for b in range(self.batch_size):
                peq = ParametricEQ(sample_rate)
                comp = Compressor(sample_rate)
                dsp = ((peq, comp), peq.num_control_params)

                parent_conn, child_conn = mp.Pipe()
                p = mp.Process(target=SPSAChannel.worker_pipe, args=(child_conn, dsp))
                p.start()
                procs[b] = [p, parent_conn, child_conn]

            # update public members for external access  TODO: fix
            self.ports = [peq.ports, comp.ports]
            self.num_control_params = (
                comp.num_control_params + peq.num_control_params
            )
            self.procs = procs
        else:
            self.peq = ParametricEQ(sample_rate)
            self.comp = Compressor(sample_rate)
            self.apply_func = SPSAFunction.apply
            self.ports = [self.peq.ports, self.comp.ports]
            self.num_control_params = (
                self.comp.num_control_params + self.peq.num_control_params
            )
            self.dsp = ((self.peq, self.comp), self.peq.num_control_params)

        # add one param for wet/dry mix
        # self.num_control_params += 1

    def __del__(self):
        # terminate any worker processes left over from parallel mode
        if hasattr(self, "procs"):
            for proc_idx, proc in self.procs.items():
                proc[0].terminate()

    def forward(self, x, p, epsilon=0.001, sample_rate=24000, **kwargs):
        """
        Args:
            x (Tensor): Input signal with shape: [batch x channels x samples]
            p (Tensor): Audio effect control parameters with shape: [batch x parameters]
            epsilon (float, optional): Twiddle parameter range for the SPSA gradient estimate.

        Returns:
            y (Tensor): Processed audio signal.
        """
        # in both modes the DSP runs on the CPU in NumPy; parallel mode
        # simply distributes the batch items across the worker processes
        y = self.apply_func(x, p, None, epsilon, self, sample_rate)
        return y.type_as(x)

    @staticmethod
    def static_backward(dsp, value):
        (
            batch_index,
            x,
            params,
            needs_input_grad,
            needs_param_grad,
            grad_output,
            epsilon,
        ) = value

        grads_input = None
        grads_params = None
        ps = params.shape[-1]
        factors = [1.0]

        # estimate gradient w.r.t. the input
        if needs_input_grad:
            delta_k = rademacher(x.shape).numpy()
            J_plus = dsp_func(x + epsilon * delta_k, params, dsp)
            J_minus = dsp_func(x - epsilon * delta_k, params, dsp)
            grads_input = (J_plus - J_minus) / (2.0 * epsilon)

        # estimate gradient w.r.t. the parameters
        grads_params_runs = []
        if needs_param_grad:
            for factor in factors:
                params_sublist = []
                delta_k = rademacher(params.shape).numpy()

                # compute the output in two random directions of the parameter
                # space, clipping the perturbed parameters to the normalized
                # [0, 1] range
                params_plus = np.clip(params + (factor * epsilon * delta_k), 0, 1)
                J_plus = dsp_func(x, params_plus, dsp)

                params_minus = np.clip(params - (factor * epsilon * delta_k), 0, 1)
                J_minus = dsp_func(x, params_minus, dsp)
                grad_param = J_plus - J_minus

                # compute the gradient for each parameter as a function of
                # epsilon and the random direction
                for sub_p_idx in range(ps):
                    grad_p = grad_param / (2 * epsilon * delta_k[sub_p_idx])
                    params_sublist.append(np.sum(grad_output * grad_p))

                grads_params = np.array(params_sublist)
                grads_params_runs.append(grads_params)

            # average the gradient estimates across runs
            grads_params = np.mean(grads_params_runs, axis=0)

        return grads_input, grads_params
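
    # Note: for a Rademacher direction delta_k, the two-sided SPSA estimate
    #   g_i ~= (J(p + eps * delta_k) - J(p - eps * delta_k)) / (2 * eps * delta_k[i])
    # used above recovers a gradient for every parameter from only two extra
    # DSP evaluations, regardless of the number of parameters.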

    @staticmethod
    def static_forward(dsp, value):
        batch_index, x, p, sample_rate = value
        y = dsp_func(x, p, dsp, sample_rate)
        return y

    @staticmethod
    def worker_pipe(child_conn, dsp):
        # run the DSP chain in a worker process, serving requests over a pipe
        while True:
            msg, value = child_conn.recv()
            if msg == "forward":
                child_conn.send(SPSAChannel.static_forward(dsp, value))
            elif msg == "backward":
                child_conn.send(SPSAChannel.static_backward(dsp, value))
            elif msg == "shutdown":
                break
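

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; assumes the deepafx_st package is
    # installed and that SPSAFunction handles the tensor <-> NumPy conversion,
    # as implied by forward() above).
    channel = SPSAChannel(sample_rate=24000, parallel=False)
    x = torch.randn(4, 1, 24000)  # [batch, channels, samples]
    p = torch.rand(4, channel.num_control_params, requires_grad=True)
    y = channel(x, p)  # processed audio, same shape as x
    y.abs().mean().backward()  # parameter gradients via the SPSA estimate
    print(y.shape, p.grad.shape)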