# File size: 4,879 Bytes (revision 05b4fca)
import functools
import numpy as np
import torch
import torch.nn as nn
from sgmse.util.registry import Registry
# Module-level registry created with the name "Backbone".
# NOTE(review): presumably backbone network classes register themselves here and are
# looked up by name elsewhere — confirm against sgmse.util.registry.Registry.
BackboneRegistry = Registry("Backbone")
class GaussianFourierProjection(nn.Module):
    """Encode scalar time steps with fixed, randomly drawn Fourier features.

    Args:
        embed_dim: Size of the output feature dimension. For real-valued output,
            ``embed_dim // 2`` frequencies are drawn and their sin and cos are
            concatenated. For complex-valued output, ``embed_dim`` frequencies
            are drawn and each yields one complex exponential.
        scale: Multiplier applied to standard-normal samples to form the
            frequencies (i.e. their standard deviation).
        complex_valued: Return ``exp(1j * phase)`` instead of ``[sin, cos]``.
    """

    def __init__(self, embed_dim, scale=16, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        # Real output spends two features (sin, cos) per frequency, so only half
        # as many frequencies are needed. A complex exponential already encodes
        # both in one value, so no halving is applied there.
        n_freqs = embed_dim if complex_valued else embed_dim // 2
        frequencies = torch.randn(n_freqs) * scale
        # Stored as a Parameter (so it travels with the state dict) but frozen:
        # these random frequencies are never trained.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, t):
        """Return Fourier features of ``t`` (indexed as ``t[:, None]``; presumably shape (batch,))."""
        phase = t.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        if not self.complex_valued:
            return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        return torch.exp(1j * phase)
class DiffusionStepEmbedding(nn.Module):
    """Diffusion-Step embedding as in DiffWave / Vaswani et al. 2017.

    Each step ``t`` is multiplied by frequencies ``10**(4*j/(K-1))`` for
    ``j = 0..K-1`` (log-spaced from 1 to 1e4 when K > 1), then passed through
    sin/cos or a complex exponential.

    Args:
        embed_dim: Size of the output feature dimension. For real-valued output
            only ``embed_dim // 2`` frequencies are used (sin and cos are
            concatenated), so an odd ``embed_dim`` yields ``embed_dim - 1`` features.
        complex_valued: If True, use ``embed_dim`` frequencies and return
            ``exp(1j * t * f)`` instead of concatenated sin/cos.
    """

    def __init__(self, embed_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        if not complex_valued:
            # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
            # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
            # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
            # and this halving is not necessary.
            embed_dim = embed_dim // 2
        # Number of frequencies (after the optional halving above).
        self.embed_dim = embed_dim

    def forward(self, t):
        """Embed the diffusion steps ``t``.

        Args:
            t: Tensor indexed as ``t[:, None]``; presumably shape (batch,).

        Returns:
            For 1-D ``t`` of shape (B,): a real tensor (B, 2 * self.embed_dim),
            or a complex tensor (B, self.embed_dim) if ``complex_valued``.
        """
        # With a single frequency, (embed_dim - 1) == 0 and the old 0/0 made the
        # whole embedding NaN. Clamp the denominator so the lone frequency is 10**0 = 1.
        # Behavior for embed_dim >= 2 (and the empty embed_dim == 0 case) is unchanged.
        denom = max(self.embed_dim - 1, 1)
        fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / denom)
        inner = t[:, None] * fac[None, :]
        if self.complex_valued:
            return torch.exp(1j * inner)
        else:
            return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
class ComplexLinear(nn.Module):
    """Linear layer that optionally acts on complex inputs.

    With ``complex_valued=False`` this is just ``nn.Linear`` (stored as ``lin``).
    With ``complex_valued=True`` two real layers ``re`` and ``im`` act as the
    real and imaginary parts of a complex weight, combined by the rule
    ``(Wr + i Wi)(a + i b) = (Wr a - Wi b) + i (Wr b + Wi a)``.
    """

    def __init__(self, input_dim, output_dim, complex_valued):
        super().__init__()
        self.complex_valued = complex_valued
        if not self.complex_valued:
            self.lin = nn.Linear(input_dim, output_dim)
            return
        # Creation order (re, then im) is kept so parameter init draws the same random numbers.
        self.re = nn.Linear(input_dim, output_dim)
        self.im = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        if not self.complex_valued:
            return self.lin(x)
        a, b = x.real, x.imag
        real_part = self.re(a) - self.im(b)
        imag_part = self.re(b) + self.im(a)
        return real_part + 1j * imag_part
class FeatureMapDense(nn.Module):
    """Dense (optionally complex) projection whose output gains two trailing singleton axes.

    The result has shape ``(..., output_dim, 1, 1)``, which lets it broadcast
    against 2-D feature maps.
    """

    def __init__(self, input_dim, output_dim, complex_valued=False):
        super().__init__()
        self.complex_valued = complex_valued
        self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)

    def forward(self, x):
        projected = self.dense(x)
        # Append two size-1 axes (equivalent to indexing with [..., None, None]).
        return projected.unsqueeze(-1).unsqueeze(-1)
def torch_complex_from_reim(re, im):
    """Combine real and imaginary part tensors into a single complex tensor.

    ``re`` and ``im`` are stacked along a new trailing axis of size 2, and that
    (real, imag) pair axis is then reinterpreted as one complex value.
    """
    pairs = torch.stack((re, im), dim=-1)
    return torch.view_as_complex(pairs)
class ArgsComplexMultiplicationWrapper(nn.Module):
    """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().

    Make a complex-valued module `F` from a real-valued module `f` by applying
    complex multiplication rules:

        F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))

    where `f1` (``self.re_module``) and `f2` (``self.im_module``) are instances
    of `f` that do *not* share weights. For a bias-free linear `f`, this is
    exactly multiplication by the complex weight ``f1 + i f2``.

    Args:
        module_cls (callable): A class or function that returns a Torch module/functional.
            Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
            to construct the real and imaginary component modules.
    """

    def __init__(self, module_cls, *args, **kwargs):
        super().__init__()
        self.re_module = module_cls(*args, **kwargs)
        self.im_module = module_cls(*args, **kwargs)

    def forward(self, x, *args, **kwargs):
        """Apply the complex-valued module to ``x``.

        ``x`` is expected to be a complex tensor, since both ``x.real`` and
        ``x.imag`` are read. Extra ``*args``/``**kwargs`` are forwarded unchanged
        to every call of both component modules.
        """
        real, imag = x.real, x.imag
        return torch_complex_from_reim(
            self.re_module(real, *args, **kwargs) - self.im_module(imag, *args, **kwargs),
            self.re_module(imag, *args, **kwargs) + self.im_module(real, *args, **kwargs),
        )
# Complex-valued 2-D convolution / transposed convolution. Calling either one
# builds an ArgsComplexMultiplicationWrapper that constructs two independent
# real modules (nn.Conv2d or nn.ConvTranspose2d) from the same constructor
# arguments, then combines them using complex multiplication rules.
ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
|