File size: 4,796 Bytes
ffd9d26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import numpy as np
import warnings
import torch.optim as optim
import torch
from WAR.Models import NN_phi,NN_h_RELU
from WAR.Experiment_functions import *
def full_training(strategy, num_round, show_losses, show_chosen_each_round,
                  reset_phi, reset_h, weight_decay, lr_h=None, lr_phi=None, reduced=False, eta=1
                  ):
    """
    Run the full active-learning loop: `num_round` query rounds, then one final training pass.

    strategy: an object of class WAR; provides the data pools, the networks h/phi and train()
    num_round: total number of query rounds
    show_losses: display graphs showing the loss of h and phi each round
    show_chosen_each_round: display a graph showing the data queried each round
    reset_phi: if True, the phi neural network is reset before each round. Can avoid
        overfitting but increases the number of epochs required to train the model
    reset_h: if True, the h neural network is reset before each round. Can avoid
        overfitting but increases the number of epochs required to train the model
    weight_decay: L2 regularisation coefficient passed to the Adam optimiser of h
        (only used when reset_h is True)
    lr_h: learning rate of h (only used when reset_h is True)
    lr_phi: learning rate of phi (only used when reset_phi is True)
    reduced: divide each query criterion by its standard deviation. When the criteria
        do not have the same amplitude, this gives them the same weight in the querying
        process. Irrelevant if there is only one query criterion
    eta: factor used to rebalance the criteria. If >1, the distribution-matching
        criterion gets more weight than the other(s). Irrelevant if there is only one
        query criterion

    Returns:
        acc, acc_percentage, acc_rmse: per-round test MAE / MAPE / RMSE (plus one
            extra entry each for the final training pass)
        t1_descend_list, t2_ascend_list: per-round loss histories of h and phi
    """
    t1_descend_list = []
    t2_ascend_list = []
    acc = []             # MAE
    acc_percentage = []  # MAPE
    acc_rmse = []        # RMSE
    only_train = False
    for rd in range(1, num_round + 1):
        print('\n================Round {:d}==============='.format(rd))
        # if not enough unlabelled data to query a full batch, we will query the remaining data
        if len(np.arange(strategy.n_pool)[~strategy.idx_lb]) <= strategy.num_elem_queried:
            only_train = True
        # reset neural networks (fresh weights AND fresh optimisers)
        if reset_phi:
            strategy.phi = NN_phi(dim_input=strategy.X_train.shape[1])
            # phi is trained by gradient ascent, hence maximize=True
            strategy.opti_phi = optim.Adam(strategy.phi.parameters(), lr=lr_phi, maximize=True)
        if reset_h:
            strategy.h = NN_h_RELU(dim_input=strategy.X_train.shape[1])
            strategy.opti_h = optim.Adam(strategy.h.parameters(), lr=lr_h, weight_decay=weight_decay)
        t1, t2, b_idxs = strategy.train(only_train, reduced, eta)
        t1_descend_list.append(t1)
        t2_ascend_list.append(t2)
        if only_train:
            # the last partial batch was queried: every sample is now labelled
            strategy.idx_lb[:] = True
        else:
            # "simulation" of the oracle who labels the queried samples
            strategy.idx_lb[b_idxs] = True
        with torch.no_grad():
            if show_losses:
                display_loss_t1(t1, rd)
                display_loss_t2(t2, rd)
            if show_chosen_each_round:
                # 1-D inputs can be plotted directly; higher dimensions go through PCA
                if strategy.X_train.shape[1] == 1:
                    #display_phi(strategy.X_train,strategy.phi,rd)
                    display_chosen_labelled_datas(strategy.X_train.cpu(), strategy.idx_lb, strategy.y_train.cpu(), b_idxs, rd)
                    #display_prediction(strategy.X_test,strategy.h,strategy.y_test,rd)
                else:
                    display_chosen_labelled_datas_PCA(strategy.X_train.cpu(), strategy.idx_lb, strategy.y_train.cpu(), b_idxs, rd)
            acc_rmse.append(RMSE(strategy.X_test, strategy.y_test, strategy.h).cpu())
            acc.append(MAE(strategy.X_test, strategy.y_test, strategy.h).cpu())
            acc_percentage.append(MAPE(strategy.X_test, strategy.y_test, strategy.h).cpu())
    print('\n================Final training===============')
    t1, t2, _ = strategy.train(only_train, reduced, eta)
    t1_descend_list.append(t1)
    t2_ascend_list.append(t2)
    with torch.no_grad():
        #display_loss_t1(t1,rd)
        #display_prediction(strategy.X_test,strategy.h,strategy.y_test,"final")
        acc.append(MAE(strategy.X_test, strategy.y_test, strategy.h).cpu())
        acc_percentage.append(MAPE(strategy.X_test, strategy.y_test, strategy.h).cpu())
        acc_rmse.append(RMSE(strategy.X_test, strategy.y_test, strategy.h).cpu())
    return acc, acc_percentage, acc_rmse, t1_descend_list, t2_ascend_list
def check_num_round(num_round, len_dataset, nb_initial_labelled_datas, num_elem_queried):
    """
    Clamp the requested number of query rounds to what the unlabelled pool allows.

    num_round: requested number of query rounds
    len_dataset: total number of samples in the dataset
    nb_initial_labelled_datas: number of samples labelled before the first round
    num_elem_queried: number of samples queried per round

    Returns num_round unchanged when it is feasible; otherwise warns and returns
    the maximum feasible number of rounds.
    """
    nb_unlabelled = len_dataset - nb_initial_labelled_datas
    max_round = int(np.ceil(nb_unlabelled / num_elem_queried))
    # feasible request: nothing to do
    if num_round <= max_round:
        return num_round
    warnings.warn(f"when querying {num_elem_queried} data per round, num_rounds={num_round} is exceeding"+
                  f" the maximum number of rounds (total data queried superior to number of initial unlabelled data).\nnum_round set to {max_round}")
    return max_round