from math import pi
from functools import wraps

import math
import copy
import pdb

import numpy as np
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

import models.util_funcs as util_funcs
from options import get_parser_main_model

opts = get_parser_main_model().parse_args()


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        :param x: [x_len, batch_size, emb_size]
        :return: [x_len, batch_size, emb_size]
        """
        x = x + self.pe[:x.size(0), :].to(x.device)
        return self.dropout(x)


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cache_fn(f):
    cache = dict()

    @wraps(f)
    def cached_fn(*args, _cache=True, key=None, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if key in cache:
            return cache[key]
        result = f(*args, **kwargs)
        cache[key] = result
        return result
    return cached_fn


def fourier_encode(x, max_freq, num_bands=4):
    '''
    x: coordinates in [-1, 1], e.g. of shape [64, 64, 2, 1]
    max_freq is 10 and num_bands is 6 in the default configuration
    '''
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x

    scales = torch.linspace(1., max_freq / 2, num_bands, device=device, dtype=dtype)
    # e.g. tensor([1.0000, 1.8000, 2.6000, 3.4000, 4.2000, 5.0000])
    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
    # broadcast to [[[[1.0000, 1.8000, 2.6000, 3.4000, 4.2000, 5.0000]]]]

    x = x * scales * pi
    x = torch.cat([x.sin(), x.cos()], dim=-1)
    x = torch.cat((x, orig_x), dim=-1)
    return x
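
# A minimal usage sketch (illustration only, never called): shows the feature count
# fourier_encode produces for a hypothetical 64x64 grid of 2-D coordinates in [-1, 1].
# Each scalar coordinate expands to 2 * num_bands + 1 features (sin + cos per band,
# plus the raw value), i.e. 13 features for num_bands = 6.
def _demo_fourier_encode():
    coords = torch.rand(64, 64, 2) * 2 - 1   # hypothetical normalized (y, x) grid
    enc = fourier_encode(coords, max_freq=10., num_bands=6)
    assert enc.shape == (64, 64, 2, 13)      # 2 * 6 bands + 1 raw coordinate
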
class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)
        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context=normed_context)
        return self.fn(x, **kwargs)


class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., cls_conv_dim=None):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, query_dim)
        # self.cls_dim_adjust = nn.Linear(context_dim, cls_conv_dim)

    def forward(self, x, context=None, mask=None, ref_cls_onehot=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = repeat(mask, 'b j k -> (b h) k j', h=h)
            sim = sim.masked_fill(mask == 0, -1e9)  # masked_fill is out-of-place, so reassign

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)
        attn = self.dropout(attn)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out), attn
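
# A minimal usage sketch (illustration only, never called): cross-attention from a small
# latent array to a longer context sequence, mirroring how Attention is wired inside the
# Perceiver-style blocks below. All sizes here are hypothetical.
def _demo_cross_attention():
    attn = Attention(query_dim=512, context_dim=27, heads=1, dim_head=64)
    latents = torch.randn(2, 256, 512)      # (batch, num_latents, latent_dim)
    byte_array = torch.randn(2, 4096, 27)   # (batch, sequence length, input_dim)
    out, weights = attn(latents, context=byte_array)
    assert out.shape == (2, 256, 512)       # queries keep their dimensionality
    assert weights.shape == (2, 256, 4096)  # one attention row per latent over the context
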
class SVGEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.command_embed = nn.Embedding(4, 512)
        self.arg_embed = nn.Embedding(128, 128, padding_idx=0)
        self.embed_fcn = nn.Linear(128 * 8, 512)
        self.pos_encoding = PositionalEncoding(d_model=opts.hidden_size, max_len=opts.max_seq_len + 1)
        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.command_embed.weight, mode="fan_in")
        nn.init.kaiming_normal_(self.arg_embed.weight, mode="fan_in")
        nn.init.kaiming_normal_(self.embed_fcn.weight, mode="fan_in")

    def forward(self, commands, args, groups=None):
        S, GN, _ = commands.shape
        src = self.command_embed(commands.long()).squeeze() + \
              self.embed_fcn(self.arg_embed(args.long()).view(S, GN, -1))  # shift due to -1 PAD_VAL
        src = self.pos_encoding(src)
        return src


class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(F.relu(self.dropout(self.w_1(x))))


class Transformer_decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.SVG_embedding = SVGEmbedding()
        self.command_fcn = nn.Linear(512, 4)
        self.args_fcn = nn.Linear(512, 8 * 128)

        c = copy.deepcopy
        attn = MultiHeadedAttention(h=8, d_model=512, dropout=0.0)
        ff = PositionwiseFeedForward(d_model=512, d_ff=1024, dropout=0.0)
        self.decoder_layers = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 6)
        self.decoder_norm = nn.LayerNorm(512)

        self.decoder_layers_parallel = clones(DecoderLayer(512, c(attn), c(attn), c(ff), dropout=0.0), 1)
        self.decoder_norm_parallel = nn.LayerNorm(512)

        if opts.ref_nshot == 52:
            self.cls_embedding = nn.Embedding(96, 512)
        else:
            self.cls_embedding = nn.Embedding(52, 512)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))

    def forward(self, x, memory, trg_char, src_mask=None, tgt_mask=None):
        memory = memory.unsqueeze(1)
        commands = x[:, :, :1]
        args = x[:, :, 1:]
        x = self.SVG_embedding(commands, args).transpose(0, 1)

        trg_char = trg_char.long()
        trg_char = self.cls_embedding(trg_char)
        x[:, 0:1, :] = trg_char

        tgt_mask = tgt_mask.squeeze()
        for layer in self.decoder_layers:
            x, attn = layer(x, memory, src_mask, tgt_mask)
        out = self.decoder_norm(x)

        N, S, _ = out.shape
        cmd_logits = self.command_fcn(out)
        args_logits = self.args_fcn(out)  # shape: [bs, max_len, 8 * 128]
        args_logits = args_logits.reshape(N, S, 8, 128)

        return cmd_logits, args_logits, attn

    def parallel_decoder(self, cmd_logits, args_logits, memory, trg_char):
        memory = memory.unsqueeze(1)
        cmd_args_mask = torch.Tensor([[0, 0, 0., 0., 0., 0., 0., 0.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 1., 1., 1., 1., 1., 1.]]).to(cmd_logits.device)

        if opts.mode == 'train':
            cmd2 = torch.argmax(cmd_logits, -1).unsqueeze(-1).transpose(0, 1)
            arg2 = torch.argmax(args_logits, -1).transpose(0, 1)
            cmd2paddingmask = _get_key_padding_mask(cmd2).transpose(0, 1).unsqueeze(-1).to(cmd2.device)
            cmd2 = cmd2 * cmd2paddingmask
            args_mask = torch.matmul(F.one_hot(cmd2.long(), 4).float(), cmd_args_mask).transpose(-1, -2).squeeze(-1)
            arg2 = arg2 * args_mask
            x = self.SVG_embedding(cmd2, arg2).transpose(0, 1)
        else:
            cmd2 = cmd_logits
            arg2 = args_logits
            cmd2paddingmask = _get_key_padding_mask(cmd2).transpose(0, 1).unsqueeze(-1).to(cmd2.device)
            cmd2 = cmd2 * cmd2paddingmask
            args_mask = torch.matmul(F.one_hot(cmd2.long(), 4).float(), cmd_args_mask).transpose(-1, -2).squeeze(-1)
            arg2 = arg2 * args_mask
            x = self.SVG_embedding(cmd2, arg2).transpose(0, 1)

        S = x.size(1)
        B = x.size(0)
        tgt_mask = torch.ones(S, S).to(x.device).unsqueeze(0).repeat(B, 1, 1)
        cmd2paddingmask = cmd2paddingmask.transpose(0, 1).transpose(-1, -2)
        tgt_mask = tgt_mask * cmd2paddingmask

        trg_char = trg_char.long()
        trg_char = self.cls_embedding(trg_char)
        x = torch.cat([trg_char, x], 1)
        x[:, 0:1, :] = trg_char
        x = x[:, :opts.max_seq_len, :]

        for layer in self.decoder_layers_parallel:
            x, attn = layer(x, memory, src_mask=None, tgt_mask=tgt_mask)
        out = self.decoder_norm_parallel(x)

        N, S, _ = out.shape
        cmd_logits = self.command_fcn(out)
        args_logits = self.args_fcn(out)
        args_logits = args_logits.reshape(N, S, 8, 128)

        return cmd_logits, args_logits


def _get_key_padding_mask(commands, seq_dim=0):
    """
    Args:
        commands: Shape [S, ...]
    """
    lens = []
    with torch.no_grad():
        key_padding_mask = (commands == 0).cumsum(dim=seq_dim) > 0
        commands = commands.transpose(0, 1).squeeze(-1)  # [bs, opts.max_seq_len]
        for i in range(commands.size(0)):
            try:
                seqi = commands[i]  # [opts.max_seq_len]
                index = torch.where(seqi == 0)[0][0]
            except IndexError:
                index = opts.max_seq_len
            lens.append(index)
        lens = torch.tensor(lens) + 1  # [b]
        seqlen_mask = util_funcs.sequence_mask(lens, opts.max_seq_len)  # [b, opts.max_seq_len]
    return seqlen_mask
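
# A minimal sketch (illustration only, never called) of the command-conditioned argument
# mask used in parallel_decoder() and loss(): one-hot command ids select which of the
# 8 coordinate arguments (p0, p1, p2, p3 as x/y pairs) are meaningful for each token.
# The command ids below are hypothetical.
def _demo_cmd_args_mask():
    cmd_args_mask = torch.tensor([[0., 0., 0., 0., 0., 0., 0., 0.],   # command 0: padding, no args
                                  [1., 1., 0., 0., 0., 0., 1., 1.],   # command 1: endpoints only
                                  [1., 1., 0., 0., 0., 0., 1., 1.],   # command 2: endpoints only
                                  [1., 1., 1., 1., 1., 1., 1., 1.]])  # command 3: cubic curve, all args
    cmds = torch.tensor([[3, 1, 2, 0]])
    mask = torch.matmul(F.one_hot(cmds, 4).float(), cmd_args_mask)
    assert mask.shape == (1, 4, 8)
    assert mask[0, 0].sum() == 8 and mask[0, 3].sum() == 0  # curve keeps all args, padding keeps none
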
class Transformer(nn.Module):
    def __init__(
        self,
        *,
        num_freq_bands,
        depth,
        max_freq,
        input_channels=1,
        input_axis=2,
        num_latents=512,
        latent_dim=512,
        cross_heads=1,
        latent_heads=8,
        cross_dim_head=64,
        latent_dim_head=64,
        num_classes=1000,
        attn_dropout=0.,
        ff_dropout=0.,
        weight_tie_layers=False,
        fourier_encode_data=True,
        self_per_cross_attn=2,
        final_classifier_head=True
    ):
        """The shape of the final attention mechanism will be:
        depth * (cross attention -> self_per_cross_attn * self attention)

        Args:
            num_freq_bands: Number of frequency bands; each axis yields (2 * num_freq_bands + 1) features.
            depth: Depth of net.
            max_freq: Maximum frequency, hyperparameter depending on how fine the data is.
            input_channels: Number of channels for each token of the input.
            input_axis: Number of axes for input data (2 for images, 3 for video).
            num_latents: Number of latents, or induced set points, or centroids. Different papers give it different names.
            latent_dim: Latent dimension.
            cross_heads: Number of heads for cross attention. Paper said 1.
            latent_heads: Number of heads for latent self attention, 8.
            cross_dim_head: Number of dimensions per cross attention head.
            latent_dim_head: Number of dimensions per latent self attention head.
            num_classes: Output number of classes.
            attn_dropout: Attention dropout.
            ff_dropout: Feedforward dropout.
            weight_tie_layers: Whether to weight tie layers (optional).
            fourier_encode_data: Whether to auto-fourier encode the data, using the input_axis given.
                Defaults to True, but can be turned off if you are fourier encoding the data yourself.
            self_per_cross_attn: Number of self attention blocks per cross attention.
            final_classifier_head: Mean pool and project embeddings to number of classes (num_classes) at the end.
        """
        super().__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        self.fourier_encode_data = fourier_encode_data

        fourier_channels = (input_axis * ((num_freq_bands * 2) + 1)) if fourier_encode_data else 0  # 26
        input_dim = fourier_channels + input_channels

        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

        get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads=cross_heads, dim_head=cross_dim_head, dropout=attn_dropout), context_dim=input_dim)
        get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head, dropout=attn_dropout))
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))

        get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        for i in range(depth):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {'_cache': should_cache}

            self_attns = nn.ModuleList([])
            for block_ind in range(self_per_cross_attn):  # BUG: this used to be hard-coded to 2 instead of self_per_cross_attn
                self_attns.append(nn.ModuleList([
                    get_latent_attn(**cache_args, key=block_ind),
                    get_latent_ff(**cache_args, key=block_ind)
                ]))

            self.layers.append(nn.ModuleList([
                get_cross_attn(**cache_args),
                get_cross_ff(**cache_args),
                self_attns
            ]))

        get_cross_attn2 = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads=cross_heads, dim_head=cross_dim_head, dropout=attn_dropout), context_dim=input_dim)
        get_cross_ff2 = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))
        get_latent_attn2 = lambda: PreNorm(latent_dim, Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head, dropout=attn_dropout))
        get_latent_ff2 = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout=ff_dropout))

        get_cross_attn2, get_cross_ff2, get_latent_attn2, get_latent_ff2 = map(cache_fn, (get_cross_attn2, get_cross_ff2, get_latent_attn2, get_latent_ff2))

        self.layers_cnnsvg = nn.ModuleList([])
        for i in range(1):
            should_cache = i > 0 and weight_tie_layers
            cache_args = {'_cache': should_cache}

            self_attns2 = nn.ModuleList([])
            for block_ind in range(self_per_cross_attn):
                self_attns2.append(nn.ModuleList([
                    get_latent_attn2(**cache_args, key=block_ind),
                    get_latent_ff2(**cache_args, key=block_ind)
                ]))

            self.layers_cnnsvg.append(nn.ModuleList([
                get_cross_attn2(**cache_args),
                get_cross_ff2(**cache_args),
                self_attns2
            ]))

        self.to_logits = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.LayerNorm(latent_dim),
            nn.Linear(latent_dim, num_classes)
        ) if final_classifier_head else nn.Identity()

        self.pre_lstm_fc = nn.Linear(10, opts.hidden_size)
        self.posr = PositionalEncoding(d_model=opts.hidden_size, max_len=opts.max_seq_len)

        patch_height = 2
        patch_width = 2
        patch_dim = 1 * patch_height * patch_width
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, 16),
        )

        self.SVG_embedding = SVGEmbedding()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 512))

    def forward(self, data, seq, ref_cls_onehot=None, mask=None, return_embeddings=True):
        b, *axis, _, device, dtype = *data.shape, data.device, data.dtype
        assert len(axis) == self.input_axis, 'input data must have the right number of axes'  # 2 for images

        x = seq
        commands = x[:, :, :1]
        args = x[:, :, 1:]
        x = self.SVG_embedding(commands, args).transpose(0, 1)

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=x.size(0))
        x = torch.cat([cls_tokens, x], dim=1)
        cls_one_pad = torch.ones((1, 1, 1)).to(x.device).repeat(x.size(0), 1, 1)
        mask = torch.cat([cls_one_pad, mask], dim=-1)

        self_atten = []
        for cross_attn, cross_ff, self_attns in self.layers:
            for self_attn, self_ff in self_attns:
                x_, atten = self_attn(x, mask=mask)
                x = x_ + x
                self_atten.append(atten)
                x = self_ff(x) + x

        x = x + torch.randn_like(x)  # add a perturbation
        return x, self_atten

    def att_residual(self, x, mask=None):
        for cross_attn, cross_ff, self_attns in self.layers_cnnsvg:
            for self_attn, self_ff in self_attns:
                x_, atten = self_attn(x)
                x = x_ + x
                x = self_ff(x) + x
        return x

    def loss(self, cmd_logits, args_logits, trg_seq, trg_seqlen, trg_pts_aux):
        '''
        Inputs:
            cmd_logits: [b, max_seq_len, 4]
            args_logits: [b, max_seq_len, 8, 128]
        '''
        cmd_args_mask = torch.Tensor([[0, 0, 0., 0., 0., 0., 0., 0.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 0., 0., 0., 0., 1., 1.],
                                      [1, 1, 1., 1., 1., 1., 1., 1.]]).to(cmd_logits.device)

        tgt_commands = trg_seq[:, :, :1].transpose(0, 1)
        tgt_args = trg_seq[:, :, 1:].transpose(0, 1)

        seqlen_mask = util_funcs.sequence_mask(trg_seqlen, opts.max_seq_len).unsqueeze(-1)
        seqlen_mask2 = seqlen_mask.repeat(1, 1, 4)  # [b, max_seq_len, 4]
        seqlen_mask4 = seqlen_mask.repeat(1, 1, 8)
        seqlen_mask3 = seqlen_mask.unsqueeze(-1).repeat(1, 1, 8, 128)

        tgt_commands_onehot = F.one_hot(tgt_commands, 4)
        tgt_args_onehot = F.one_hot(tgt_args, 128)
        args_mask = torch.matmul(tgt_commands_onehot.float(), cmd_args_mask).squeeze()

        loss_cmd = torch.sum(-tgt_commands_onehot.squeeze() * F.log_softmax(cmd_logits, -1), -1)
        loss_cmd = torch.mul(loss_cmd, seqlen_mask.squeeze())
        loss_cmd = torch.mean(torch.sum(loss_cmd / trg_seqlen.unsqueeze(-1), -1))

        loss_args = (torch.sum(-tgt_args_onehot * F.log_softmax(args_logits, -1), -1) * seqlen_mask4 * args_mask)
        loss_args = torch.mean(loss_args, dim=-1, keepdim=False)
        loss_args = torch.mean(torch.sum(loss_args / trg_seqlen.unsqueeze(-1), -1))

        SE_mask = torch.Tensor([[1, 1],
                                [0, 0],
                                [1, 1],
                                [1, 1]]).to(cmd_logits.device)
        SE_args_mask = torch.matmul(tgt_commands_onehot.float(), SE_mask).squeeze().unsqueeze(-1)

        args_prob = F.softmax(args_logits, -1)
        args_end = args_prob[:, :, 6:]
        args_end_shifted = torch.cat((torch.zeros(args_end.size(0), 1, args_end.size(2), args_end.size(3)).to(args_end.device), args_end), 1)
        args_end_shifted = args_end_shifted[:, :opts.max_seq_len, :, :]
        args_end_shifted = args_end_shifted * SE_args_mask + args_end * (1 - SE_args_mask)
        args_start = args_prob[:, :, :2]

        seqlen_mask5 = util_funcs.sequence_mask(trg_seqlen - 1, opts.max_seq_len).unsqueeze(-1)
        seqlen_mask5 = seqlen_mask5.repeat(1, 1, 2)
        smooth_constrained = torch.sum(torch.pow((args_end_shifted - args_start), 2), -1) * seqlen_mask5
        smooth_constrained = torch.mean(smooth_constrained, dim=-1, keepdim=False)
        smooth_constrained = torch.mean(torch.sum(smooth_constrained / (trg_seqlen - 1).unsqueeze(-1), -1))

        # straight-through argmax: the forward value is the argmax bin index,
        # while gradients flow through the (temperature-sharpened) softmax
        args_prob2 = F.softmax(args_logits / 0.1, -1)
        c = torch.argmax(args_prob2, -1).unsqueeze(-1).float() - args_prob2.detach()
        p_argmax = args_prob2 + c
        p_argmax = torch.mean(p_argmax, -1)

        control_pts = denumericalize(p_argmax)
        p0 = control_pts[:, :, :2]
        p1 = control_pts[:, :, 2:4]
        p2 = control_pts[:, :, 4:6]
        p3 = control_pts[:, :, 6:8]

        line_mask = (tgt_commands == 2).float() + (tgt_commands == 1).float()
        curve_mask = (tgt_commands == 3).float()

        t = 0.25
        aux_pts_line = p0 + t * (p3 - p0)
        for t in [0.5, 0.75]:
            coord_t = p0 + t * (p3 - p0)
            aux_pts_line = torch.cat((aux_pts_line, coord_t), -1)
        aux_pts_line = aux_pts_line * line_mask

        t = 0.25
        aux_pts_curve = (1 - t) * (1 - t) * (1 - t) * p0 + 3 * t * (1 - t) * (1 - t) * p1 + 3 * t * t * (1 - t) * p2 + t * t * t * p3
        for t in [0.5, 0.75]:
            coord_t = (1 - t) * (1 - t) * (1 - t) * p0 + 3 * t * (1 - t) * (1 - t) * p1 + 3 * t * t * (1 - t) * p2 + t * t * t * p3
            aux_pts_curve = torch.cat((aux_pts_curve, coord_t), -1)
        aux_pts_curve = aux_pts_curve * curve_mask

        aux_pts_predict = aux_pts_curve + aux_pts_line
        seqlen_mask_aux = util_funcs.sequence_mask(trg_seqlen - 1, opts.max_seq_len).unsqueeze(-1)
        aux_pts_loss = torch.pow((aux_pts_predict - trg_pts_aux), 2) * seqlen_mask_aux
        loss_aux = torch.mean(aux_pts_loss, dim=-1, keepdim=False)
        loss_aux = torch.mean(torch.sum(loss_aux / trg_seqlen.unsqueeze(-1), -1))

        loss = opts.loss_w_cmd * loss_cmd + opts.loss_w_args * loss_args + opts.loss_w_aux * loss_aux + opts.loss_w_smt * smooth_constrained

        svg_losses = {}
        svg_losses['loss_total'] = loss
        svg_losses['loss_cmd'] = loss_cmd
        svg_losses['loss_args'] = loss_args
        svg_losses['loss_smt'] = smooth_constrained
        svg_losses['loss_aux'] = loss_aux
        return svg_losses
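
# A minimal sketch (illustration only, never called) of the cubic Bezier evaluation used
# for the auxiliary points in Transformer.loss(): points are sampled at t = 0.25 / 0.5 / 0.75
# between the control points p0..p3. The control points below are hypothetical.
def _demo_bezier_point():
    p0 = torch.tensor([0., 0.])
    p1 = torch.tensor([1., 2.])
    p2 = torch.tensor([3., 2.])
    p3 = torch.tensor([4., 0.])
    t = 0.5
    point = (1 - t) ** 3 * p0 + 3 * t * (1 - t) ** 2 * p1 + 3 * t ** 2 * (1 - t) * p2 + t ** 3 * p3
    assert torch.allclose(point, torch.tensor([2.0, 1.5]))
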
class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)."
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        attn = self.self_attn.attn
        return self.sublayer[2](x, self.feed_forward), attn


def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0


def numericalize(cmd, n=128):
    """NOTE: shall only be called after normalization"""
    # assert np.max(cmd.origin) <= 1.0 and np.min(cmd.origin) >= -1.0
    cmd = (cmd / 30 * n).round().clip(min=0, max=n - 1).int()
    return cmd


def denumericalize(cmd, n=128):
    cmd = cmd / n * 30
    return cmd
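
# A minimal sketch (illustration only, never called) of the coordinate quantization round
# trip: numericalize() maps coordinates in the 0..30 drawing range onto n = 128 integer
# bins and denumericalize() maps them back, so the round-trip error stays around half a
# bin (30 / 128 / 2) for in-range values. The sample coordinates below are hypothetical.
def _demo_quantization_round_trip():
    coords = torch.tensor([0.0, 7.3, 15.0, 29.0])
    recovered = denumericalize(numericalize(coords))
    assert torch.all((coords - recovered).abs() <= 30. / 128. / 2. + 1e-6)
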
def attention(query, key, value, mask=None, trg_tri_mask=None, dropout=None, posr=None):
    "Compute 'Scaled Dot Product Attention'."
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if posr is not None:
        posr = posr.unsqueeze(1)
        scores = scores + posr
    if mask is not None:
        try:
            scores = scores.masked_fill(mask == 0, -1e9)  # mask: [b, 1, S, S], scores: [b, heads, S, S]
        except Exception as e:
            print("Shape: ", scores.shape)
            print("Error: ", e)
            pdb.set_trace()
    if trg_tri_mask is not None:
        scores = scores.masked_fill(trg_tri_mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h  # per-head dimension, e.g. 512 // 8 = 64
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None, trg_tri_mask=None, posr=None):
        "Implements Figure 2."
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, trg_tri_mask=trg_tri_mask,
                                 dropout=self.dropout, posr=posr)

        # 3) "Concat" using a view and apply a final linear
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
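
# A minimal sketch (illustration only, never called): the scaled dot-product attention
# helper above combined with the causal mask from subsequent_mask(), so that position i
# can only attend to positions j <= i. All tensor sizes here are hypothetical.
def _demo_causal_attention():
    q = torch.randn(2, 8, 5, 64)            # (batch, heads, sequence length, d_k)
    k = torch.randn(2, 8, 5, 64)
    v = torch.randn(2, 8, 5, 64)
    mask = subsequent_mask(5).unsqueeze(1)  # (1, 1, 5, 5), broadcast over batch and heads
    out, weights = attention(q, k, v, mask=mask)
    assert out.shape == (2, 8, 5, 64)
    assert torch.allclose(weights[..., 0, 1:], torch.zeros(2, 8, 4))  # first token attends only to itself
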
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        x_norm = self.norm(x)
        return x + self.dropout(sublayer(x_norm))


if __name__ == '__main__':
    model = Transformer(
        input_channels=1,          # number of channels for each token of the input
        input_axis=2,              # number of axes for input data (2 for images, 3 for video)
        num_freq_bands=6,          # number of freq bands; each axis yields (2 * 6 + 1) features
        max_freq=10.,              # maximum frequency, hyperparameter depending on how fine the data is
        depth=6,                   # depth of net: depth * (cross attention -> self_per_cross_attn * self attention)
        num_latents=256,           # number of latents, or induced set points, or centroids
        latent_dim=512,            # latent dimension
        cross_heads=1,             # number of heads for cross attention; paper said 1
        latent_heads=8,            # number of heads for latent self attention, 8
        cross_dim_head=64,         # number of dimensions per cross attention head
        latent_dim_head=64,        # number of dimensions per latent self attention head
        num_classes=1000,          # output number of classes
        attn_dropout=0.,
        ff_dropout=0.,
        weight_tie_layers=False,   # whether to weight tie layers (optional)
        fourier_encode_data=True,  # whether to auto-fourier encode the data, using the input_axis given
        self_per_cross_attn=2      # number of self attention blocks per cross attention
    )

    img = torch.randn(1, 224, 224, 3)  # 1 ImageNet-sized image (smoke test inherited from the Perceiver reference)
    # NOTE: the modified forward() above also expects an SVG command/argument sequence and
    # a padding mask, so the original single-argument call no longer works:
    # model(img)  # would need something like model(img, seq, mask=mask) here