import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable


class BaseFlow(nn.Module):

    def __init__(self):
        super().__init__()

    def sample(self, n=1, context=None, **kwargs):
        dim = self.dim
        if isinstance(dim, int):
            dim = [dim]

        spl = Variable(torch.FloatTensor(n, *dim).normal_())
        lgd = Variable(torch.from_numpy(np.zeros(n).astype('float32')))
        if context is None:
            context = Variable(torch.from_numpy(
                np.ones((n, self.context_dim)).astype('float32')))

        if hasattr(self, 'gpu') and self.gpu:
            spl = spl.cuda()
            lgd = lgd.cuda()
            context = context.cuda()  # was context.gpu(), which is not a Tensor method

        return self.forward((spl, lgd, context))

    def cuda(self):
        self.gpu = True
        return super(BaseFlow, self).cuda()


def varify(x):
    """Wrap a numpy array as a torch Variable."""
    return torch.autograd.Variable(torch.from_numpy(x))


def oper(array, op, axis=-1, keepdims=False):
    # Apply a reducing op and optionally restore the reduced axis
    # (second argument renamed from `oper`, which shadowed this function).
    a_oper = op(array)
    if keepdims:
        shape = list(array.size())
        shape[axis] = -1
        a_oper = a_oper.view(*shape)
    return a_oper


def log_sum_exp(A, axis=-1, sum_op=torch.sum):
    """Numerically stable log-sum-exp: subtract the max before exponentiating."""
    maximum = lambda x: x.max(axis)[0]
    A_max = oper(A, maximum, axis, True)
    summation = lambda x: sum_op(torch.exp(x - A_max), axis)
    B = torch.log(oper(A, summation, axis, True)) + A_max
    return B


delta = 1e-6
logsigmoid = lambda x: -F.softplus(-x)
# log with a multiplicative shift to avoid log(0) near zero
log = lambda x: torch.log(x * 1e2) - np.log(1e2)
softplus_ = nn.Softplus()
softplus = lambda x: softplus_(x) + delta


def softmax(x, dim=-1):
    e_x = torch.exp(x - x.max(dim=dim, keepdim=True)[0])
    out = e_x / e_x.sum(dim=dim, keepdim=True)
    return out


class DenseSigmoidFlow(nn.Module):

    def __init__(self, hidden_dim, in_dim=1, out_dim=1):
        super().__init__()
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim

        # a > 0 (softplus) and w, u on the simplex (softmax) keep the map monotone.
        self.act_a = lambda x: F.softplus(x)
        self.act_b = lambda x: x
        self.act_w = lambda x: torch.softmax(x, dim=3)
        self.act_u = lambda x: torch.softmax(x, dim=3)

        self.u_ = nn.Parameter(torch.Tensor(hidden_dim, in_dim))
        self.w_ = nn.Parameter(torch.Tensor(out_dim, hidden_dim))
        # Conditioner parameters per dimension: a, b, w (hidden_dim each) plus u (in_dim).
        self.num_params = 3 * hidden_dim + in_dim
        self.reset_parameters()

    def reset_parameters(self):
        self.u_.data.uniform_(-0.001, 0.001)
        self.w_.data.uniform_(-0.001, 0.001)

    def forward(self, x, dsparams):
        delta = 1e-7
        # Inverse softplus of (1 - delta), so that a starts close to 1.
        inv = np.log(np.exp(1 - delta) - 1)
        ndim = self.hidden_dim

        pre_u = self.u_[None, None, :, :] + dsparams[:, :, -self.in_dim:][:, :, None, :]
        pre_w = self.w_[None, None, :, :] + dsparams[:, :, 2 * ndim:3 * ndim][:, :, None, :]
        a = self.act_a(dsparams[:, :, 0 * ndim:1 * ndim] + inv)
        b = self.act_b(dsparams[:, :, 1 * ndim:2 * ndim])
        w = self.act_w(pre_w)
        u = self.act_u(pre_u)

        pre_sigm = torch.sum(u * a[:, :, :, None] * x[:, :, None, :], 3) + b
        # Note: the canonical DSF applies a sigmoid here followed by a logit
        # (commented out below); this variant uses SELU and skips the logit.
        sigm = torch.selu(pre_sigm)
        x_pre = torch.sum(w * sigm[:, :, None, :], dim=3)
        # x_ = torch.special.logit(x_pre, eps=1e-5)
        # xnew = x_
        xnew = x_pre
        return xnew


class DDSF(nn.Module):

    def __init__(self, n_blocks=1, hidden_dim=16):
        super().__init__()
        self.num_params = 0
        if n_blocks == 1:
            model = [DenseSigmoidFlow(hidden_dim, in_dim=1, out_dim=1)]
        else:
            model = [DenseSigmoidFlow(hidden_dim=hidden_dim, in_dim=1, out_dim=hidden_dim)]
            for _ in range(n_blocks - 2):
                model += [DenseSigmoidFlow(hidden_dim=hidden_dim, in_dim=hidden_dim, out_dim=hidden_dim)]
            model += [DenseSigmoidFlow(hidden_dim=hidden_dim, in_dim=hidden_dim, out_dim=1)]
        self.model = nn.Sequential(*model)
        for block in self.model:
            self.num_params += block.num_params

    def forward(self, x, dsparams):
        x = x.unsqueeze(2)
        start = 0
        # Each block consumes its own slice of the conditioner parameters.
        for block in self.model:
            block_dsparams = dsparams[:, :, start:start + block.num_params]
            x = block(x, block_dsparams)
            start += block.num_params
        return x.squeeze(2)
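# A minimal shape/parameter-count sanity sketch (added; not from the original
# file, and only assumes DDSF and torch as defined above). With n_blocks=2 and
# hidden_dim=4, the first block needs 3*4 + 1 = 13 parameters per dimension
# and the second 3*4 + 4 = 16, so flow.num_params should be 29.
def _check_num_params():
    flow = DDSF(n_blocks=2, hidden_dim=4)
    assert flow.num_params == (3 * 4 + 1) + (3 * 4 + 4) == 29
    x = torch.randn(5, 3)                          # (batch, dim)
    dsparams = torch.randn(5, 3, flow.num_params)  # per-dimension parameters
    assert flow(x, dsparams).shape == (5, 3)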
def compute_jacobian(inputs, outputs):
    """Build the full Jacobian of outputs w.r.t. inputs, one batch element at a time."""
    batch_size = outputs.size(0)
    out_vector = torch.sum(outputs, 0).view(-1)
    outdim = out_vector.size(0)
    jac = torch.stack(
        [torch.autograd.grad(out_vector[i], inputs,
                             retain_graph=True,
                             create_graph=True)[0].view(batch_size, -1)
         for i in range(outdim)], dim=1)
    jacs = [jac[i, :, :] for i in range(batch_size)]
    print(jacs[1])
    return jacs


if __name__ == '__main__':
    flow = DDSF(n_blocks=10, hidden_dim=50)
    x = (torch.arange(20, dtype=torch.float32).view(10, 2) / 10. - 1.).requires_grad_(True)
    # One parameter vector per dimension, broadcast across the batch
    # (the original allocated 2 * flow.num_params; the extra half was never used).
    dsparams = torch.randn(1, 2, flow.num_params).repeat(10, 1, 1)
    y = flow(x, dsparams)
    print(x, y)
    compute_jacobian(x, y)

# Scratch tests for convolutional variants (ConvDenseSigmoidFlow / ConvDDSF)
# that are not defined in this file:
"""
flow = ConvDenseSigmoidFlow(1, 256, 1)
dsparams = torch.randn(1, 2, 1000).repeat(10, 1, 1)
x = torch.arange(20).view(10, 2, 1).repeat(1, 1, 4).view(10, 2, 2, 2) / 10.
print(x.size(), dsparams.size())
out = flow(x, dsparams)
print(x, out.flatten(2), out.size())

flow = ConvDDSF(n_blocks=3)
dsparams = torch.randn(1, 2, flow.num_params).repeat(10, 1, 1)
x = torch.arange(80).view(10, 2, 4).view(10, 2, 2, 2) / 10
print(x.size(), dsparams.size())
out = flow(x, dsparams)
print(x, out.flatten(2), out.size())
"""
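# Hedged usage sketch (added; not from the original file): in a NAF-style
# setup, dsparams would typically come from a conditioner network rather than
# torch.randn. The linear conditioner below is a hypothetical stand-in; only
# DDSF, nn, and torch from above are assumed.
def example_conditioned_flow(batch_size=8, dim=2, context_dim=5):
    flow = DDSF(n_blocks=2, hidden_dim=8)
    # Simplest possible conditioner: one linear map emitting all per-dimension
    # parameters at once, reshaped to (batch, dim, num_params).
    conditioner = nn.Linear(context_dim, dim * flow.num_params)
    context = torch.randn(batch_size, context_dim)
    dsparams = conditioner(context).view(batch_size, dim, flow.num_params)
    x = torch.randn(batch_size, dim)
    return flow(x, dsparams)  # (batch_size, dim), monotone in each dimension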