import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
class BaseFlow(nn.Module):
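    """Base class for flow modules: standard-normal sampling plus CUDA bookkeeping."""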
def __init__(self):
super().__init__()
def sample(self, n=1, context=None, **kwargs):
dim = self.dim
if isinstance(self.dim, int):
dim = [dim, ]
        spl = Variable(torch.FloatTensor(n, *dim).normal_())  # base noise z ~ N(0, I)
        lgd = Variable(torch.zeros(n))  # running log-determinant accumulator
        if context is None:
            context = Variable(torch.ones(n, self.context_dim))
        if hasattr(self, 'gpu') and self.gpu:
            spl = spl.cuda()
            lgd = lgd.cuda()
            context = context.cuda()
return self.forward((spl, lgd, context))
def cuda(self):
self.gpu = True
return super(BaseFlow, self).cuda()
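
# Illustrative sketch (not in the original file): `sample` expects subclasses to
# define `dim`, `context_dim`, and a `forward` that consumes and returns a
# (sample, logdet, context) triple, e.g.:
#
#     class IdentityFlow(BaseFlow):
#         def __init__(self, dim, context_dim=1):
#             super().__init__()
#             self.dim = dim
#             self.context_dim = context_dim
#
#         def forward(self, inputs):
#             spl, lgd, context = inputs
#             return spl, lgd, context  # identity transform: log-det stays zero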
def varify(x):
return torch.autograd.Variable(torch.from_numpy(x))
def oper(array, func, axis=-1, keepdims=False):
    # Apply a reduction `func` along `axis`; with keepdims, restore the reduced
    # axis as a singleton so the result broadcasts against the input.
    # (The parameter was renamed from `oper` to avoid shadowing this function.)
    a_oper = func(array)
    if keepdims:
        shape = list(array.size())
        shape[axis] = -1  # the reduced axis (size 1) is inferred by view
        a_oper = a_oper.view(*shape)
    return a_oper
def log_sum_exp(A, axis=-1, sum_op=torch.sum):
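    # Numerically stable log-sum-exp: subtract the per-axis max before exponentiating.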
maximum = lambda x: x.max(axis)[0]
A_max = oper(A,maximum,axis,True)
summation = lambda x: sum_op(torch.exp(x-A_max), axis)
B = torch.log(oper(A,summation,axis,True)) + A_max
return B
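
# Numerically guarded helpers: `delta` keeps softplus strictly positive, `log`
# rescales its argument before taking the log to avoid underflow near zero, and
# `logsigmoid` uses the identity log(sigmoid(x)) = -softplus(-x).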
delta = 1e-6
logsigmoid = lambda x: -F.softplus(-x)
log = lambda x: torch.log(x*1e2)-np.log(1e2)
softplus_ = nn.Softplus()
softplus = lambda x: softplus_(x) + delta
def softmax(x, dim=-1):
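    # Stable softmax: shift by the per-dim max before exponentiating.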
e_x = torch.exp(x - x.max(dim=dim, keepdim=True)[0])
out = e_x / e_x.sum(dim=dim, keepdim=True)
return out
class DenseSigmoidFlow(nn.Module):
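    """A dense sigmoid flow block in the style of Neural Autoregressive Flows
    (Huang et al., 2018): the per-element parameters a, b, w, u arrive via
    `dsparams`, on top of small learned offsets `u_` and `w_`."""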
def __init__(self, hidden_dim, in_dim=1, out_dim=1):
super().__init__()
self.in_dim = in_dim
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.act_a = lambda x: F.softplus(x)
self.act_b = lambda x: x
self.act_w = lambda x: torch.softmax(x, dim=3)
self.act_u = lambda x: torch.softmax(x, dim=3)
self.u_ = torch.nn.Parameter(torch.Tensor(hidden_dim, in_dim))
self.w_ = torch.nn.Parameter(torch.Tensor(out_dim, hidden_dim))
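        # Parameter layout along the last axis of `dsparams`:
        # [a | b | w | u] with sizes [hidden_dim, hidden_dim, hidden_dim, in_dim].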
        self.num_params = 3 * hidden_dim + in_dim
self.reset_parameters()
def reset_parameters(self):
self.u_.data.uniform_(-0.001, 0.001)
self.w_.data.uniform_(-0.001, 0.001)
    def forward(self, x, dsparams):
        delta = 1e-7
        inv = np.log(np.exp(1 - delta) - 1)  # softplus^{-1}(1 - delta): biases a toward 1 at init
        ndim = self.hidden_dim
        # Slice the conditioner output into the per-block parameters.
        pre_u = self.u_[None, None, :, :] + dsparams[:, :, -self.in_dim:][:, :, None, :]
        pre_w = self.w_[None, None, :, :] + dsparams[:, :, 2 * ndim:3 * ndim][:, :, None, :]
        a = self.act_a(dsparams[:, :, 0 * ndim:1 * ndim] + inv)
        b = self.act_b(dsparams[:, :, 1 * ndim:2 * ndim])
        w = self.act_w(pre_w)
        u = self.act_u(pre_u)
        pre_sigm = torch.sum(u * a[:, :, :, None] * x[:, :, None, :], 3) + b
        # Note: despite the variable name, this applies SELU rather than the
        # sigmoid of the original dense sigmoid flow.
        sigm = torch.selu(pre_sigm)
        x_pre = torch.sum(w * sigm[:, :, None, :], dim=3)
        # x_ = torch.special.logit(x_pre, eps=1e-5)
        # xnew = x_
        xnew = x_pre
        return xnew
class DDSF(nn.Module):
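    """Deep dense sigmoid flow: a stack of DenseSigmoidFlow blocks that widens
    to `hidden_dim` after the first block and narrows back to 1 at the last;
    each block consumes its own slice of a shared `dsparams` vector."""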
def __init__(self, n_blocks=1, hidden_dim=16):
super().__init__()
self.num_params = 0
if n_blocks == 1:
model = [DenseSigmoidFlow(hidden_dim, in_dim=1, out_dim=1)]
else:
model = [DenseSigmoidFlow(hidden_dim=hidden_dim, in_dim=1, out_dim=hidden_dim)]
for _ in range(n_blocks-2):
model += [DenseSigmoidFlow(hidden_dim=hidden_dim, in_dim=hidden_dim, out_dim=hidden_dim)]
model += [DenseSigmoidFlow(hidden_dim=hidden_dim, in_dim=hidden_dim, out_dim=1)]
        self.model = nn.ModuleList(model)  # iterated manually in forward; blocks take (x, dsparams)
for block in self.model:
self.num_params += block.num_params
def forward(self, x, dsparams):
x = x.unsqueeze(2)
start = 0
for block in self.model:
block_dsparams = dsparams[:,:,start:start+block.num_params]
x = block(x, block_dsparams)
start += block.num_params
return x.squeeze(2)
def compute_jacobian(inputs, outputs):
    # Per-example Jacobians via autograd. Summing outputs over the batch is
    # valid here because each example's output depends only on its own input;
    # assumes a square Jacobian (input dim == output dim).
    batch_size = outputs.size(0)
    out_vector = torch.sum(outputs, 0).view(-1)
    outdim = out_vector.size(0)
    jac = torch.stack([
        torch.autograd.grad(out_vector[i], inputs, retain_graph=True,
                            create_graph=True)[0].view(batch_size, outdim)
        for i in range(outdim)
    ], dim=1)
    return [jac[i, :, :] for i in range(batch_size)]
if __name__ == '__main__':
    flow = DDSF(n_blocks=10, hidden_dim=50)
    x = torch.arange(20).view(10, 2) / 10. - 1.
    x = Variable(x, requires_grad=True)
    # Exactly flow.num_params parameters per dimension, shared across the batch.
    dsparams = torch.randn(1, 2, flow.num_params).repeat(10, 1, 1)
    y = flow(x, dsparams)
    print(x, y)
    print(compute_jacobian(x, y)[1])
"""
flow = ConvDenseSigmoidFlow(1,256,1)
dsparams = torch.randn(1, 2, 1000).repeat(10,1,1)
x = torch.arange(20).view(10,2,1).repeat(1,1,4).view(10,2,2,2)/10.
print(x.size(), dsparams.size())
out = flow(x, dsparams)
print(x, out.flatten(2), out.size())
flow = ConvDDSF(n_blocks=3)
dsparams = torch.randn(1, 2, flow.num_params).repeat(10,1,1)
x = torch.arange(80).view(10,2,4).view(10,2,2,2)/10
print(x.size(), dsparams.size())
out = flow(x, dsparams)
print(x, out.flatten(2), out.size())
"""