Spaces:
Running
on
T4
Running
on
T4
from __future__ import absolute_import | |
import sys | |
import numpy as np | |
import torch | |
from torch import nn | |
import os | |
from collections import OrderedDict | |
from torch.autograd import Variable | |
import itertools | |
from models.stylegan2.lpips.base_model import BaseModel | |
from scipy.ndimage import zoom | |
import fractions | |
import functools | |
import skimage.transform | |
from tqdm import tqdm | |
from IPython import embed | |
from models.stylegan2.lpips import networks_basic as networks | |
import models.stylegan2.lpips as util | |
class DistModel(BaseModel):
    ''' Wrapper around a distance network (from the `networks` module) that exposes
    distance evaluation plus the training helpers built around a ranking loss.

    NOTE(review): `self.use_gpu` and `self.save_network` are read below but never
    assigned in this class - presumably provided by BaseModel; confirm there.
    '''

    def name(self):
        ''' Return the display name assigned in initialize(). '''
        return self.model_name

    def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
            use_gpu=True, printNet=False, spatial=False,
            is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=None):
        '''
        INPUTS
            model - ['net-lin'] for linearly calibrated network
                    ['net'] for off-the-shelf network
                    ['L2'] for L2 distance in Lab colorspace
                    ['SSIM'] for ssim in RGB colorspace
            net - ['squeeze','alex','vgg'], forwarded to networks.PNetLin as pnet_type
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            pnet_rand, pnet_tune - forwarded unchanged to networks.PNetLin
                    (pnet_tune only for model='net-lin')
            model_path - only used for model='net-lin'; if None, resolves to
                    weights/v[VERSION]/[NET].pth next to the file defining this method.
                    Weights are loaded only when is_train is False.
            use_gpu - bool - whether or not to use a GPU
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances across spatial dimensions
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original (with a bug)
            gpu_ids - int array - gpus to use; None (the default) means [0]
        RAISES
            ValueError - if model is not one of the recognized names
        '''
        # A list literal in the signature would be a single object shared by every
        # call, so the default is materialized per call instead.
        if gpu_ids is None:
            gpu_ids = [0]

        BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)

        self.model = model
        self.net = net
        self.is_train = is_train
        self.spatial = spatial
        self.gpu_ids = gpu_ids
        self.model_name = '%s [%s]'%(model,net)

        if(self.model == 'net-lin'): # pretrained net + linear layer
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
                use_dropout=True, spatial=spatial, version=version, lpips=True)
            # Without a GPU the checkpoint tensors have to be remapped onto the CPU.
            kw = {}
            if not use_gpu:
                kw['map_location'] = 'cpu'
            if(model_path is None):
                import inspect
                model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))

            if(not is_train):
                print('Loading model from: %s'%model_path)
                # NOTE: torch.load unpickles the file - only point model_path at trusted checkpoints.
                self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
        elif(self.model=='net'): # pretrained network
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
        elif(self.model in ['L2','l2']):
            self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
            self.model_name = 'L2'
        elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
            self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train: # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = networks.BCERankingLoss()
            self.parameters += list(self.rankLoss.net.parameters())
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
        else: # test mode
            self.net.eval()

        if(use_gpu):
            self.net.to(gpu_ids[0])
            # From here on self.net is the DataParallel wrapper; the bare network is self.net.module.
            self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
            if(self.is_train):
                self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0

        if(printNet):
            print('---------- Networks initialized -------------')
            networks.print_network(self.net)
            print('-----------------------------------------------')

    def forward(self, in0, in1, retPerLayer=False):
        ''' Function computes the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
            retPerLayer - forwarded unchanged to the underlying network
        OUTPUT
            computed distances between in0 and in1
        '''
        return self.net.forward(in0, in1, retPerLayer=retPerLayer)

    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        ''' Run one training step: forward, zero grads, backward, optimizer step, clamp. '''
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        ''' Clamp the weights of every 1x1-kernel module in self.net to be non-negative. '''
        for module in self.net.modules():
            # getattr with a default: modules that own a weight but have no
            # kernel_size attribute are skipped instead of raising AttributeError.
            if(hasattr(module, 'weight') and getattr(module, 'kernel_size', None)==(1,1)):
                module.weight.data = torch.clamp(module.weight.data,min=0)

    def set_input(self, data):
        ''' Store one training batch.
        INPUTS
            data - mapping with keys 'ref', 'p0', 'p1', 'judge'; each value is moved to
                   gpu_ids[0] when self.use_gpu is set
        '''
        self.input_ref = data['ref']
        self.input_p0 = data['p0']
        self.input_p1 = data['p1']
        self.input_judge = data['judge']

        if(self.use_gpu):
            device = self.gpu_ids[0]
            self.input_ref = self.input_ref.to(device=device)
            self.input_p0 = self.input_p0.to(device=device)
            self.input_p1 = self.input_p1.to(device=device)
            self.input_judge = self.input_judge.to(device=device)

        self.var_ref = Variable(self.input_ref,requires_grad=True)
        self.var_p0 = Variable(self.input_p0,requires_grad=True)
        self.var_p1 = Variable(self.input_p1,requires_grad=True)

    def forward_train(self): # run forward pass
        ''' Compute d0, d1, the per-sample accuracy and the ranking loss for the stored batch.
        OUTPUT
            self.loss_total, as returned by self.rankLoss
        '''
        self.d0 = self.forward(self.var_ref, self.var_p0)
        self.d1 = self.forward(self.var_ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)

        self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())

        # The judgment is rescaled as judge*2-1 before it is handed to the ranking loss.
        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)

        return self.loss_total

    def backward_train(self):
        ''' Backpropagate the mean of the loss computed by forward_train(). '''
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self,d0,d1,judge):
        ''' d0, d1 are Variables, judge is a Tensor
        OUTPUT
            flat numpy array: judge where d1<d0, (1-judge) elsewhere
        '''
        d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)

    def get_current_errors(self):
        ''' Return an OrderedDict with the means of the current loss ('loss_total') and accuracy ('acc_r'). '''
        return OrderedDict([('loss_total', np.mean(self.loss_total.data.cpu().numpy())),
                            ('acc_r', np.mean(self.acc_r))])

    def get_current_visuals(self):
        ''' Return an OrderedDict ('ref', 'p0', 'p1') of the current batch converted with
        util.tensor2im and zoomed (order 0) by 256 / size of tensor dimension 2.
        '''
        zoom_factor = 256/self.var_ref.data.size()[2]

        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)

        ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
        p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
        p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)

        return OrderedDict([('ref', ref_img_vis),
                            ('p0', p0_img_vis),
                            ('p1', p1_img_vis)])

    def save(self, path, label):
        ''' Save the distance network and the ranking-loss network via self.save_network. '''
        if(self.use_gpu):
            # initialize() wraps the network in DataParallel on GPU; save the wrapped module.
            self.save_network(self.net.module, path, '', label)
        else:
            self.save_network(self.net, path, '', label)
        self.save_network(self.rankLoss.net, path, 'rank', label)

    def update_learning_rate(self,nepoch_decay):
        ''' Lower the learning rate by self.lr / nepoch_decay and apply it to every
        parameter group of the optimizer.
        '''
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr

        # The label slot used to receive the builtin `type`, which printed "<class 'type'>".
        print('update lr [%s] decay: %f -> %f' % (self.name(),self.old_lr, lr))
        self.old_lr = lr
def score_2afc_dataset(data_loader, func, name=''):
    ''' Compute the Two Alternative Forced Choice (2AFC) score of distance function
        'func' over the dataset served by 'data_loader'
    INPUTS
        data_loader - object whose load_data() yields batches with keys 'ref', 'p0', 'p1', 'judge'
        func - callable distance function - d=func(in0,in1) is called with two batch entries;
            the result is converted with .data.cpu().numpy() and flattened
        name - description shown on the progress bar
    OUTPUTS
        [0] - mean of 'scores' (2AFC score), fraction of time func agrees with human evaluators
        [1] - dictionary with following elements
            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
            gts - N array of the 'judge' values, preferred patch selected by human evaluators
                (closer to "0" for left patch p0, "1" for right patch p1,
                "0.6" means 60pct people preferred right patch, 40pct preferred left)
            scores - N array, per-triplet agreement of func with humans (ties count 0.5)
    CONSTS
        N - number of test triplets in data_loader
    '''
    dist_ref_p0 = []
    dist_ref_p1 = []
    judgments = []

    for batch in tqdm(data_loader.load_data(), desc=name):
        dist_ref_p0.extend(func(batch['ref'], batch['p0']).data.cpu().numpy().flatten().tolist())
        dist_ref_p1.extend(func(batch['ref'], batch['p1']).data.cpu().numpy().flatten().tolist())
        judgments.extend(batch['judge'].cpu().numpy().flatten().tolist())

    d0s = np.array(dist_ref_p0)
    d1s = np.array(dist_ref_p1)
    gts = np.array(judgments)

    # Agreement per triplet: credit (1-gt) when p0 is closer, gt when p1 is closer, half on a tie.
    p0_closer = d0s<d1s
    p1_closer = d1s<d0s
    tied = d1s==d0s
    scores = p0_closer*(1.-gts) + p1_closer*gts + tied*.5

    details = dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)
    return (np.mean(scores), details)
def score_jnd_dataset(data_loader, func, name=''):
    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - object whose load_data() yields batches with keys 'p0', 'p1', 'same'
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return a pytorch tensor holding N
            distances (any shape; it is flattened)
        name - description shown on the progress bar
    OUTPUTS
        [0] - JND score as returned by util.voc_ap on the recall/precision curve
        [1] - dictionary with following elements
            ds - N array containing distances between two patches shown to human evaluator
            sames - N array containing fraction of people who thought the two patches were identical
    CONSTS
        N - number of test pairs in data_loader
    '''
    ds = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        # Flatten before tolist(), as done for 'same' below: a result with trailing
        # singleton dimensions (e.g. Nx1x1x1) would otherwise build a 4-D 'ds', and
        # argsort would then sort along the last length-1 axis instead of ranking pairs.
        ds+=func(data['p0'],data['p1']).data.cpu().numpy().flatten().tolist()
        gts+=data['same'].cpu().numpy().flatten().tolist()

    sames = np.array(gts)
    ds = np.array(ds)

    # Rank pairs from smallest to largest distance.
    sorted_inds = np.argsort(ds)
    ds_sorted = ds[sorted_inds]
    sames_sorted = sames[sorted_inds]

    # Cumulative counts along the ranking, one entry per cut-off position.
    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1-sames_sorted)
    FNs = np.sum(sames_sorted)-TPs

    precs = TPs/(TPs+FPs)
    recs = TPs/(TPs+FNs)
    score = util.voc_ap(recs,precs)

    return(score, dict(ds=ds,sames=sames))