# supervised by a global average embedding, which is a biased estimation of the true embedding
# use projection to enable a complex decoding
# makes no big difference than mean so far, the decoding may not work 🤦‍
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
from tqdm import tqdm
import random


class Transform(nn.Module):
    """Learn per-sample scalar weights plus per-sample MLP decoders whose
    averaged output is trained to match a shared target embedding."""

    def __init__(self, n=2, token_size=32, input_dim=2048):
        super().__init__()
        self.n = n
        self.token_size = token_size
        self.input_dim = input_dim
        self.dim = input_dim * token_size  # flattened per-sample feature size
        # One learnable scalar per sample; initialized to 1 (identity scaling).
        self.weight = nn.Parameter(torch.ones(self.n, 1), requires_grad=True)
        # One small MLP per sample, projecting the flattened embedding back to
        # its own size ("projection to enable a complex decoding").
        self.projections = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(self.dim, 512),
                    nn.ReLU(),
                    nn.Linear(512, self.dim),
                )
                for _ in range(self.n)
            ]
        )

    def encode(self, x):
        """Flatten each sample to (n, token_size*input_dim) and scale it by
        its learned per-sample weight ((n, 1) broadcasts over (n, dim))."""
        x = x.view(-1, self.dim)
        x = self.weight * x
        return x

    def decode(self, x):
        """Run each weighted sample through its own projection MLP, reshape to
        (n, token_size, input_dim), and average over the n samples."""
        out = []
        for i in range(self.n):
            out.append(self.projections[i](x[i]))
        x = torch.stack(out, dim=0)
        x = x.view(self.n, self.token_size, self.input_dim)
        x = torch.mean(x, dim=0)
        return x

    def forward(self, x):
        x = self.encode(x)
        x = self.decode(x)
        return x


def online_train(cond, device="cuda:1", step=1000):
    """Fit per-sample weights so the model's output matches the plain mean of
    `cond` under random multiplicative perturbations, then return the
    re-weighted mean embedding.

    Args:
        cond: input tensor; assumed shape (n, token_size, input_dim) — the
            indexing below requires rank 3. TODO(review): confirm at callers.
        device: device used for the temporary training run.
        step: number of optimization steps.

    Returns:
        A (1, token_size, input_dim) tensor on `cond`'s original device/dtype.
    """
    old_device = cond.device
    dtype = cond.dtype
    cond = cond.clone().to(device, torch.float32)
    cond.requires_grad = False
    # Training needs autograd even if the caller runs under no_grad; remember
    # the previous state so we restore it instead of leaking a global change.
    prev_grad_state = torch.is_grad_enabled()
    torch.set_grad_enabled(True)
    try:
        print("online training, initializing model...")
        n = cond.shape[0]
        # Infer token_size/input_dim from the input instead of relying on the
        # class defaults (32 x 2048), so non-default shapes also work.
        model = Transform(n=n, token_size=cond.shape[1], input_dim=cond.shape[2])
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)
        criterion = nn.MSELoss()
        model.to(device)
        model.train()
        # Supervision target: the (biased) global average embedding.
        y = torch.mean(cond, dim=0)
        random.seed(42)  # deterministic perturbations across calls
        bar = tqdm(range(step))
        for s in bar:
            optimizer.zero_grad()
            # Simulate per-sample scaling "attacks" in [0.5, 1.5].
            attack_weight = [random.uniform(0.5, 1.5) for _ in range(n)]
            attack_weight = torch.tensor(attack_weight)[:, None, None].to(device)
            x = attack_weight * cond
            output = model(x)
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            bar.set_postfix(loss=loss.item())
        # Detach so the returned tensor is not tied to the autograd graph of
        # the model we are about to delete.
        weight = model.weight.detach()
        cond = weight[:, :, None] * cond
        print(weight)
        print("online training, ending...")
        del model
        del optimizer
    finally:
        torch.set_grad_enabled(prev_grad_state)
    cond = torch.mean(cond, dim=0).unsqueeze(0)
    return cond.to(old_device, dtype=dtype)