import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import itertools

from WAR.Experiment_functions import display_phi
from WAR.dataset_handler import myData


class WAR:

    def __init__(self, X_train, y_train, X_test, y_test, idx_lb, total_epoch_h,
                 total_epoch_phi, batch_size_train, num_elem_queried,
                 phi, h, opti_phi, opti_h, second_query_strategy=None):

        """
        X_train: training set.
        y_train: labels of the training set.
        X_test: test set.
        y_test: labels of the test set.
        idx_lb: boolean mask over the training set; True marks the points
            treated as labelled.
        total_epoch_h: number of epochs to train h.
        total_epoch_phi: number of epochs to train phi.
        batch_size_train: batch size used in the training process.
        num_elem_queried: number of elements queried each round.
        phi: phi neural network.
        h: h neural network.
        opti_phi: phi optimizer.
        opti_h: h optimizer.
        second_query_strategy: optional second strategy assisting the
            distribution-matching criterion ("loss_approximation" or None).

        Attributes set internally:
        device: device on which to train the models.
        n_pool: length of the training set.
        cost: cost function used for both neural networks, "MSE" or "MAE".
        """

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.X_train = X_train.to(self.device)
        self.y_train = y_train.to(self.device)
        self.X_test = X_test.to(self.device)
        self.y_test = y_test.to(self.device)
        self.idx_lb = idx_lb
        self.n_pool = len(y_train)
        self.total_epoch_h = total_epoch_h
        self.total_epoch_phi = total_epoch_phi
        self.batch_size_train = batch_size_train
        self.num_elem_queried = num_elem_queried
        self.phi = phi.to(self.device)
        self.h = h.to(self.device)
        self.opti_phi = opti_phi
        self.opti_h = opti_h
        self.cost = "MSE"
        self.second_query_strategy = second_query_strategy
    def cost_func(self, predicted, true):
        # Element-wise cost; the caller takes the mean over the batch.
        if self.cost == "MSE":
            return (predicted - true) ** 2
        elif self.cost == "MAE":
            return torch.abs(predicted - true)
        else:
            raise ValueError("invalid cost function")
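    # Example: with predicted = tensor([1.0]) and true = tensor([3.0]),
    # "MSE" returns tensor([4.0]) and "MAE" returns tensor([2.0]).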

    def train(self, only_train=False, reduced=True, eta=3):

        """
        only_train: set to True when there is no more unlabelled data in the
            trainset; h is then trained without training phi or querying data.
        reduced: divide each query criterion by its standard deviation. When
            the criteria do not have the same amplitude, this gives them the
            same weight in the querying process. Irrelevant if there is only
            one query criterion (second_query_strategy=None).
        eta: factor used to rebalance the criteria. If > 1, the
            distribution-matching criterion gets more weight than the
            other(s). Irrelevant if there is only one query criterion.
        """

        t1_descend = []
        t2_ascend = []

        idx_lb_train = np.arange(self.n_pool)[self.idx_lb]
        idx_ulb_train = np.arange(self.n_pool)[~self.idx_lb]

        # Train h on the currently labelled points.
        trainset_labelled = myData(self.X_train[idx_lb_train], self.y_train[idx_lb_train])
        trainloader_labelled = DataLoader(trainset_labelled, shuffle=True,
                                          batch_size=self.batch_size_train)

        for epoch in range(self.total_epoch_h):
            for i, data in enumerate(trainloader_labelled, 0):
                label_x, label_y = data
                self.opti_h.zero_grad()

                lb_out = self.h(label_x)
                h_descent = torch.mean(self.cost_func(lb_out, label_y))
                t1_descend.append(h_descent.detach().cpu())
                h_descent.backward()
                self.opti_h.step()

        b_idxs = []
        if not only_train:
            idxs_temp = self.idx_lb.copy()

            for elem_queried in range(self.num_elem_queried):
                # Retrain phi on the current labelled/unlabelled split before
                # each query.
                trainset_total = myData(self.X_train, self.y_train)
                trainloader_total = DataLoader(trainset_total, shuffle=True,
                                               batch_size=len(trainset_total))
                trainset_labelled = myData(self.X_train[idx_lb_train], self.y_train[idx_lb_train])
                trainloader_labelled = DataLoader(trainset_labelled, shuffle=True,
                                                  batch_size=self.batch_size_train)
                for epoch in range(self.total_epoch_phi):
                    iterator_total_phi = itertools.cycle(trainloader_total)
                    iterator_labelled_phi = itertools.cycle(trainloader_labelled)
                    for i in range(len(trainloader_labelled)):
                        label_x, label_y = next(iterator_labelled_phi)
                        total_x, total_y = next(iterator_total_phi)

                        self.opti_phi.zero_grad()
                        # Gap between the mean of phi over the whole pool and
                        # over the labelled points; opti_phi is assumed to be
                        # configured to maximise it (e.g. maximize=True).
                        phi_ascent = (torch.mean(self.phi(total_x))
                                      - torch.mean(self.phi(label_x)))
                        t2_ascend.append(phi_ascent.detach().cpu())
                        phi_ascent.backward()
                        self.opti_phi.step()

                b_queried = self.query(reduced, eta, idx_ulb_train)
                idxs_temp[b_queried] = True
                idx_ulb_train = np.arange(self.n_pool)[~idxs_temp]
                idx_lb_train = np.arange(self.n_pool)[idxs_temp]
                b_idxs.append(b_queried)
            self.idx_lb = idxs_temp
        return t1_descend, t2_ascend, b_idxs
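    # The two loss traces returned above can be used to monitor convergence,
    # e.g. (illustrative only, assuming matplotlib is available):
    #   t1, t2, queried = strategy.train()
    #   plt.plot(t1); plt.plot(t2)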

    def query(self, reduced, eta, idx_ulb_train):

        """
        reduced: same as for train().
        eta: same as for train().
        idx_ulb_train: indices of unlabelled points.
        """

        if self.second_query_strategy == "loss_approximation":
            second_query_criterion = self.predict_loss(self.X_train[idx_ulb_train])

        with torch.no_grad():
            phi_scores = self.phi(self.X_train[idx_ulb_train]).view(-1)

        if reduced and self.second_query_strategy is not None:
            phi_scores_reduced = phi_scores / torch.std(phi_scores)
            second_query_criterion_reduced = second_query_criterion / torch.std(second_query_criterion)
            total_scores = -(eta * phi_scores_reduced + second_query_criterion_reduced)
        elif self.second_query_strategy is not None:
            total_scores = -(eta * phi_scores + second_query_criterion)
        else:
            total_scores = -eta * phi_scores

        # argmin of the negated scores, i.e. the point maximising the
        # (weighted) sum of the criteria.
        b = torch.argmin(total_scores)
        return idx_ulb_train[b]
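    # Equivalently, with reduced=True the queried point is
    #   argmax_i  eta * phi(x_i)/std(phi) + L(x_i)/std(L),
    # where L is the approximated loss from predict_loss().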

    def predict_loss(self, X):

        """
        X: set of unlabelled elements of the trainset.
        """

        idxs_lb = np.arange(self.n_pool)[self.idx_lb]
        losses = []
        with torch.no_grad():
            for i in X:
                # Approximate the loss at each unlabelled point from its
                # nearest labelled neighbour.
                idx_nearest_Xk, dist = self.Idx_NearestP(i, idxs_lb)
                losses.append(self.Max_cost_B(idx_nearest_Xk, dist, i))

        return torch.Tensor(losses).to(self.device)

    def Idx_NearestP(self, Xu, idxs_lb):

        """
        Xu: unlabelled point.
        idxs_lb: indices of labelled points.
        """

        distances = []
        for i in idxs_lb:
            distances.append(torch.norm(Xu - self.X_train[i]))

        return idxs_lb[distances.index(min(distances))], float(min(distances))
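    # A vectorised equivalent (same result, assuming idxs_lb is a numpy array):
    #   d = torch.norm(self.X_train[idxs_lb] - Xu, dim=1)
    #   j = int(torch.argmin(d))
    #   return idxs_lb[j], float(d[j])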

    def Max_cost_B(self, idx_Xk, distance, i):

        """
        idx_Xk: index of the labelled point nearest to the unlabelled point.
        distance: distance between them.
        i: unlabelled point.
        """

        est_h_unl_X = self.h(i)
        true_value_labelled_X = self.y_train[idx_Xk]
        # Worst-case cost, assuming the true label of the unlabelled point
        # lies within `distance` of its nearest labelled point's label.
        bound_min = true_value_labelled_X - distance
        bound_max = true_value_labelled_X + distance
        return max(self.cost_func(est_h_unl_X, bound_min),
                   self.cost_func(est_h_unl_X, bound_max))[0]
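    # Example: if the nearest labelled point has label 2.0 at distance 0.5
    # and h predicts 1.0, the candidate bounds are 1.5 and 2.5, so the
    # returned "MSE" upper estimate is max(0.25, 2.25) = 2.25.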
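

# A minimal usage sketch, not part of the training pipeline itself: toy 1-D
# regression data and two small networks, run for one active-learning round.
# Network sizes, learning rates, and epoch counts are illustrative
# assumptions, not tuned values. opti_phi is built with maximize=True so that
# the phi_ascent objective in train() is actually maximised.
if __name__ == "__main__":
    import torch.nn as nn

    torch.manual_seed(0)
    X = torch.rand(100, 1)
    y = torch.sin(3 * X)
    X_tr, y_tr, X_te, y_te = X[:80], y[:80], X[80:], y[80:]

    # Start from 10 randomly labelled points.
    idx_lb = np.zeros(len(y_tr), dtype=bool)
    idx_lb[np.random.default_rng(0).choice(len(y_tr), 10, replace=False)] = True

    h = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))
    phi = nn.Sequential(nn.Linear(1, 32), nn.Tanh(), nn.Linear(32, 1))
    opti_h = optim.Adam(h.parameters(), lr=1e-2)
    opti_phi = optim.Adam(phi.parameters(), lr=1e-2, maximize=True)

    strategy = WAR(X_tr, y_tr, X_te, y_te, idx_lb,
                   total_epoch_h=5, total_epoch_phi=5, batch_size_train=16,
                   num_elem_queried=3, phi=phi, h=h,
                   opti_phi=opti_phi, opti_h=opti_h,
                   second_query_strategy="loss_approximation")
    t1, t2, queried = strategy.train()
    print("queried indices:", queried)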