import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import itertools
from WAR.Experiment_functions import display_phi
from WAR.dataset_handler import myData
class WAR:
def __init__(self,X_train,y_train,X_test,y_test,idx_lb,total_epoch_h,total_epoch_phi,batch_size_train,num_elem_queried
,phi,h,opti_phi,opti_h,second_query_strategy=None):
"""
        device: device on which to train (automatically set to CUDA if available).
        X_train: training inputs.
        y_train: training labels.
        X_test: test inputs.
        y_test: test labels.
        idx_lb: boolean mask over the trainset marking which points are considered labelled.
        n_pool: length of the trainset (derived from y_train).
        total_epoch_h: number of epochs to train h.
        total_epoch_phi: number of epochs to train phi.
        batch_size_train: size of the batch in the training process.
        num_elem_queried: number of elements queried each round.
        phi: phi neural network.
        h: h neural network.
        opti_phi: phi optimizer.
        opti_h: h optimizer.
        cost: cost function used for both neural networks, "MSE" or "MAE" (set to "MSE" below).
        second_query_strategy: optional second strategy to assist the distribution-matching criterion ("loss_approximation" or None).
"""
self.device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.X_train = X_train.to(self.device)
self.y_train = y_train.to(self.device)
self.X_test=X_test.to(self.device)
self.y_test=y_test.to(self.device)
self.idx_lb = idx_lb
self.n_pool = len(y_train)
self.total_epoch_h=total_epoch_h
self.total_epoch_phi=total_epoch_phi
self.batch_size_train=batch_size_train
self.num_elem_queried=num_elem_queried
self.phi=phi.to(self.device)
self.h=h.to(self.device)
self.opti_phi=opti_phi
self.opti_h=opti_h
self.cost="MSE"
self.second_query_strategy=second_query_strategy
#cost function used to train both phi and h
def cost_func(self,predicted,true):
if self.cost=="MSE":
return (predicted-true)**2
        elif self.cost=="MAE":
            return torch.abs(predicted-true)
        else:
            raise ValueError("invalid cost function")
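    # Illustrative values (an example, not from the original code): with cost "MSE",
    # cost_func(torch.tensor([1.0]),torch.tensor([3.0])) returns tensor([4.]);
    # with "MAE" it would return tensor([2.]).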
def train(self,only_train=False,reduced=True,eta=3):# train function for one round
"""
        only_train: set to True when there is no more unlabelled data in the trainset; only h is trained, and phi is neither trained nor used to query data.
        reduced: divide each query criterion by its standard deviation. When the criteria do not have the same amplitude, this gives them equal weight in the querying process. Irrelevant if there is only one query criterion (self.second_query_strategy=None).
        eta: factor used to rebalance the criteria. If >1, the distribution-matching criterion gets more weight than the other(s). Irrelevant if there is only one query criterion.
"""
        # loss histories for h (descent) and phi (ascent)
t1_descend=[]
t2_ascend=[]
# separating labelled and unlabelled data respectively
idx_lb_train = np.arange(self.n_pool)[self.idx_lb]
idx_ulb_train = np.arange(self.n_pool)[~self.idx_lb]
trainset_labelled=myData(self.X_train[idx_lb_train],self.y_train[idx_lb_train])
trainloader_labelled= DataLoader(trainset_labelled,shuffle=True,batch_size=self.batch_size_train)
for epoch in range(self.total_epoch_h):
for i,data in enumerate(trainloader_labelled,0):
label_x, label_y=data
self.opti_h.zero_grad()
# T1 (train h)
lb_out = self.h(label_x)
h_descent=torch.mean(self.cost_func(lb_out,label_y))
t1_descend.append(h_descent.detach().cpu())
h_descent.backward()
self.opti_h.step()
b_idxs=[]# batch of queried points
if not only_train:
#T2 (train phi)
            # temporary copy of the labelled mask, used only to retrain phi while the
            # oracle has not yet been called; h is not retrained during this time.
idxs_temp=self.idx_lb.copy()
for elem_queried in range(self.num_elem_queried):
trainset_total=myData(self.X_train,self.y_train)
trainloader_total= DataLoader(trainset_total,shuffle=True,batch_size=len(trainset_total))
trainset_labelled=myData(self.X_train[idx_lb_train],self.y_train[idx_lb_train])
trainloader_labelled= DataLoader(trainset_labelled,shuffle=True,batch_size=self.batch_size_train)
for epoch in range(self.total_epoch_phi):
iterator_total_phi=itertools.cycle(trainloader_total)
iterator_labelled_phi=itertools.cycle(trainloader_labelled)
for i in range(len(trainloader_labelled)):
label_x,label_y = next(iterator_labelled_phi)
total_x,total_y = next(iterator_total_phi)
#display_phi(self.X_train,self.phi)
self.opti_phi.zero_grad()
phi_ascent = (torch.mean(self.phi(total_x))-torch.mean(self.phi(label_x)))
t2_ascend.append(phi_ascent.detach().cpu())
phi_ascent.backward()
self.opti_phi.step()
# Query process
b_queried=self.query(reduced,eta,idx_ulb_train)# query one element
                idxs_temp[b_queried]=True # add it to the temporary labelled mask
idx_ulb_train = np.arange(self.n_pool)[~idxs_temp] #update the set of unlabeled point indices
idx_lb_train = np.arange(self.n_pool)[idxs_temp] #update the set of labeled point indices
                b_idxs.append(b_queried) # add the chosen point to the batch
            self.idx_lb=idxs_temp # end of the query process: update the true labelled mask
return t1_descend,t2_ascend,b_idxs
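    # The returned histories can be used to monitor training: t1_descend holds the
    # per-batch h losses, t2_ascend the phi objective values, and b_idxs the trainset
    # indices of the queried points to send to the oracle for labelling.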
def query(self,reduced,eta,idx_ulb_train):# computing T3: query one point according to the chosen query criteria
"""
        reduced: same as for function "train"
        eta: same as for function "train"
        idx_ulb_train: indices of unlabelled points
"""
        if self.second_query_strategy=="loss_approximation":
            second_query_criterion = self.predict_loss(self.X_train[idx_ulb_train])
        elif self.second_query_strategy is not None:
            raise ValueError("unknown second_query_strategy")
        with torch.no_grad():
            phi_scores = self.phi(self.X_train[idx_ulb_train]).view(-1)
            if reduced and self.second_query_strategy is not None:
                # give each criterion unit variance so they carry equal weight
                phi_scores_reduced=phi_scores/torch.std(phi_scores)
                second_query_criterion_reduced=second_query_criterion/torch.std(second_query_criterion)
                total_scores = -(eta*phi_scores_reduced+second_query_criterion_reduced)
            elif self.second_query_strategy is not None:
                total_scores = -(eta*phi_scores+second_query_criterion)
            else:
                total_scores = -eta*phi_scores
b=torch.argmin(total_scores)
return idx_ulb_train[b]
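    # Note: total_scores is the negative of the combined criterion, so torch.argmin
    # selects the unlabelled point with the highest combined score, and query()
    # returns its global index into the trainset.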
    def predict_loss(self,X):# second query criterion, which acts as a loss estimator (uncertainty- and diversity-based sampling)
"""
X: set of unlabeled elements of the trainset
"""
idxs_lb=np.arange(self.n_pool)[self.idx_lb]#get labeled data indices
losses=[]
with torch.no_grad():
for i in X:
idx_nearest_Xk,dist=self.Idx_NearestP(i,idxs_lb)
losses.append(self.Max_cost_B(idx_nearest_Xk,dist,i))
return torch.Tensor(losses).to(self.device)
    def Idx_NearestP(self,Xu,idxs_lb):# return the index of the labelled point closest to the unlabelled point, and their distance
"""
Xu:unlabeled point
idxs_lb: indices of labeled points
"""
distances=[]
for i in idxs_lb:
distances.append(torch.norm(Xu-self.X_train[i]))
return idxs_lb[distances.index(min(distances))],float(min(distances))
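    # A vectorized sketch of the same nearest-neighbour search (an assumption, not
    # the original implementation), which avoids the Python loop:
    #   dists=torch.cdist(Xu.unsqueeze(0),self.X_train[idxs_lb]).view(-1)
    #   return idxs_lb[int(dists.argmin())],float(dists.min())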
def Max_cost_B(self,idx_Xk,distance,i):#return the "maximum loss" of the unlabeled point
"""
        idx_Xk: index of the labelled point nearest to the unlabelled point
        distance: distance between them
        i: unlabelled point
"""
est_h_unl_X=self.h(i)
true_value_labelled_X=self.y_train[idx_Xk]
bound_min= true_value_labelled_X-distance
bound_max= true_value_labelled_X+distance
return max(self.cost_func(est_h_unl_X,bound_min),self.cost_func(est_h_unl_X,bound_max))[0]
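
# Minimal smoke-test sketch (assumptions: myData wraps (X,y) tensors like a standard
# Dataset, and the small MLPs, learning rates, and epoch counts below are arbitrary
# illustrative choices, not the original experimental setup).
if __name__ == "__main__":
    import torch.nn as nn

    torch.manual_seed(0)
    n,d=100,2
    X=torch.rand(n,d)
    y=X.sum(dim=1,keepdim=True) # toy regression target
    X_te=torch.rand(20,d)
    y_te=X_te.sum(dim=1,keepdim=True)

    idx_lb=np.zeros(n,dtype=bool) # boolean labelled mask
    idx_lb[:10]=True # start with 10 labelled points

    h=nn.Sequential(nn.Linear(d,32),nn.ReLU(),nn.Linear(32,1))
    phi=nn.Sequential(nn.Linear(d,32),nn.ReLU(),nn.Linear(32,1))
    opti_h=optim.Adam(h.parameters(),lr=1e-2)
    # train() calls backward() on the quantity to ascend, so phi's optimizer is
    # assumed to be configured to maximize
    opti_phi=optim.Adam(phi.parameters(),lr=1e-2,maximize=True)

    strategy=WAR(X,y,X_te,y_te,idx_lb,total_epoch_h=5,total_epoch_phi=5,
                 batch_size_train=16,num_elem_queried=3,
                 phi=phi,h=h,opti_phi=opti_phi,opti_h=opti_h)
    t1,t2,queried=strategy.train()
    print("queried indices:",queried)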