# SPSA channel: differentiable (parametric EQ -> compressor) audio effects chain.
import torch
import numpy as np
import torch.multiprocessing as mp
from deepafx_st.processors.dsp.peq import ParametricEQ
from deepafx_st.processors.dsp.compressor import Compressor
from deepafx_st.processors.spsa.spsa_func import SPSAFunction
from deepafx_st.utils import rademacher
def dsp_func(x, p, dsp, sample_rate=24000):
    """Process ``x`` through the serial (PEQ -> compressor) DSP chain.

    Args:
        x: Input audio signal.
        p: Flat parameter vector; the leading ``meta`` entries drive the PEQ
            and the remainder drive the compressor.
        dsp: Tuple ``((peq, comp), meta)`` where ``meta`` is the number of
            PEQ control parameters (used to split ``p``).
        sample_rate (int, optional): Sample rate passed to both processors.

    Returns:
        The signal after the PEQ and then the compressor.
    """
    (peq, comp), num_peq_params = dsp
    peq_params = p[:num_peq_params]
    comp_params = p[num_peq_params:]
    return comp(peq(x, peq_params, sample_rate), comp_params, sample_rate)
class SPSAChannel(torch.nn.Module):
    """Differentiable audio effects channel: parametric EQ followed by a compressor.

    The DSP itself is not differentiable; gradients are estimated with SPSA
    (simultaneous perturbation stochastic approximation) through
    ``SPSAFunction`` (see :meth:`static_backward`).

    Args:
        sample_rate (float): Sample rate of the plugin instance
        parallel (bool, optional): Use parallel workers for DSP.
            By default, this utilizes parallelized instances of the plugin channel,
            where the number of workers is equal to the batch size.
    """

    def __init__(
        self,
        sample_rate: int,
        parallel: bool = False,
        batch_size: int = 8,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.parallel = parallel

        if self.parallel:
            self.apply_func = SPSAFunction.apply
            # One worker process (and one Pipe) per batch item, each holding
            # its own independent (PEQ, compressor) pair.
            procs = {}
            for b in range(self.batch_size):
                peq = ParametricEQ(sample_rate)
                comp = Compressor(sample_rate)
                # Second tuple element is the number of PEQ params, used by
                # dsp_func to split the flat parameter vector.
                dsp = ((peq, comp), peq.num_control_params)
                parent_conn, child_conn = mp.Pipe()
                p = mp.Process(target=SPSAChannel.worker_pipe, args=(child_conn, dsp))
                p.start()
                procs[b] = [p, parent_conn, child_conn]
                #print(b, p)
            # Update stuff for external public members TODO: fix
            # NOTE(review): `peq`/`comp` here are the instances from the *last*
            # loop iteration — harmless only if all workers share the same
            # configuration; confirm.
            self.ports = [peq.ports, comp.ports]
            self.num_control_params = (
                comp.num_control_params + peq.num_control_params
            )
            self.procs = procs
            #print(self.procs)
        else:
            # Serial mode: a single in-process (PEQ, compressor) pair.
            self.peq = ParametricEQ(sample_rate)
            self.comp = Compressor(sample_rate)
            self.apply_func = SPSAFunction.apply
            self.ports = [self.peq.ports, self.comp.ports]
            self.num_control_params = (
                self.comp.num_control_params + self.peq.num_control_params
            )
            self.dsp = ((self.peq, self.comp), self.peq.num_control_params)

        # add one param for wet/dry mix
        # self.num_control_params += 1

    def __del__(self):
        # Best-effort cleanup of worker processes in parallel mode.
        # NOTE(review): workers are terminated rather than sent the "shutdown"
        # message that worker_pipe handles, and are not joined — confirm this
        # is intentional.
        if hasattr(self, "procs"):
            for proc_idx, proc in self.procs.items():
                #print(f"Closing {proc_idx}...")
                proc[0].terminate()

    def forward(self, x, p, epsilon=0.001, sample_rate=24000, **kwargs):
        """
        Args:
            x (Tensor): Input signal with shape: [batch x channels x samples]
            p (Tensor): Audio effect control parameters with shape: [batch x parameters]
            epsilon (float, optional): Twiddle parameter range for SPSA gradient estimation.

        Returns:
            y (Tensor): Processed audio signal.
        """
        # NOTE(review): both branches are currently identical; SPSAFunction
        # presumably dispatches on `self.parallel` internally — confirm.
        if self.parallel:
            y = self.apply_func(x, p, None, epsilon, self, sample_rate)
        else:
            # this will process on CPU in NumPy
            y = self.apply_func(x, p, None, epsilon, self, sample_rate)

        # Cast back to the input dtype/device (the DSP ran outside autograd).
        return y.type_as(x)

    @staticmethod
    def static_backward(dsp, value):
        """Estimate SPSA gradients for one batch item.

        Args:
            dsp: ``((peq, comp), meta)`` tuple as used by ``dsp_func``.
            value: Packed tuple of
                (batch_index, x, params, needs_input_grad, needs_param_grad,
                grad_output, epsilon).

        Returns:
            tuple: ``(grads_input, grads_params)`` — either may be ``None``
            when the corresponding gradient was not requested.
        """
        (
            batch_index,
            x,
            params,
            needs_input_grad,
            needs_param_grad,
            grad_output,
            epsilon,
        ) = value

        grads_input = None
        grads_params = None
        ps = params.shape[-1]

        # Perturbation-magnitude scaling factors; a single run with factor 1.0.
        factors = [1.0]

        # estimate gradient w.r.t input
        # Central difference along a random Rademacher (+/-1) direction.
        if needs_input_grad:
            delta_k = rademacher(x.shape).numpy()
            J_plus = dsp_func(x + epsilon * delta_k, params, dsp)
            J_minus = dsp_func(x - epsilon * delta_k, params, dsp)
            grads_input = (J_plus - J_minus) / (2.0 * epsilon)

        # estimate gradient w.r.t params
        grads_params_runs = []
        if needs_param_grad:
            for factor in factors:
                params_sublist = []
                delta_k = rademacher(params.shape).numpy()

                # compute output in two random directions of the parameter space
                # (params are clipped to the normalized [0, 1] control range)
                params_plus = np.clip(params + (factor * epsilon * delta_k), 0, 1)
                J_plus = dsp_func(x, params_plus, dsp)
                params_minus = np.clip(params - (factor * epsilon * delta_k), 0, 1)
                J_minus = dsp_func(x, params_minus, dsp)
                grad_param = J_plus - J_minus

                # compute gradient for each parameter as a function of epsilon and random direction
                for sub_p_idx in range(ps):
                    grad_p = grad_param / (2 * epsilon * delta_k[sub_p_idx])
                    params_sublist.append(np.sum(grad_output * grad_p))

                grads_params = np.array(params_sublist)
                grads_params_runs.append(grads_params)

            # average gradients
            grads_params = np.mean(grads_params_runs, axis=0)

        return grads_input, grads_params

    @staticmethod
    def static_forward(dsp, value):
        """Run the DSP chain for one batch item (also used inside workers).

        Args:
            dsp: ``((peq, comp), meta)`` tuple as used by ``dsp_func``.
            value: Packed tuple of (batch_index, x, p, sample_rate).

        Returns:
            The processed signal for this batch item.
        """
        batch_index, x, p, sample_rate = value
        y = dsp_func(x, p, dsp, sample_rate)
        return y

    @staticmethod
    def worker_pipe(child_conn, dsp):
        """Worker-process loop: serve forward/backward requests over a Pipe.

        Blocks on ``child_conn.recv()`` for ``(msg, value)`` tuples and
        replies with the corresponding static_forward/static_backward result.
        Exits when a "shutdown" message is received.
        """
        while True:
            msg, value = child_conn.recv()
            if msg == "forward":
                child_conn.send(SPSAChannel.static_forward(dsp, value))
            elif msg == "backward":
                child_conn.send(SPSAChannel.static_backward(dsp, value))
            elif msg == "shutdown":
                break