import warnings

import numpy as np
import torch
import torch.optim as optim

from WAR.Models import NN_phi, NN_h_RELU
# display and metric helpers (display_loss_t1, display_loss_t2, RMSE, MAE, MAPE,
# display_chosen_labelled_datas, ...) come from here
from WAR.Experiment_functions import *


def full_training(strategy, num_round, show_losses, show_chosen_each_round,
                  reset_phi, reset_h, weight_decay, lr_h=None, lr_phi=None,
                  reduced=False, eta=1):
""" |
|
strategy: an object of class WAR |
|
num_round: total number of query rounds |
|
show_losses: display graphs showing the loss of h and phi each rounds |
|
show_chosen_each_round:display a graph showing the data queried each round |
|
reset_phi: if True, the phi neural network is reset after each round. can avoir overfitting but increase the number of epochs required to train the model |
|
reset_h:if True, the h neural network is reset after each round. can avoir overfitting but increase the number of epochs required to train the model |
|
lr_h: learning rate of h |
|
lr_phi: learning rate of phi |
|
reduced: will divide each query criterion by their standard deviation. In the case where they don't have the same amplitude, This will give them the same weight in the querying process. Irrelevant parameter if there is only one query criterion |
|
eta:factor used to rebalance the criteria. If >1, distribution matching criterion gets more weight than the other(s). Irrelevant parameter if there is only one query criterion. |
|
|
|
""" |

    t1_descend_list = []
    t2_ascend_list = []
    acc = []
    acc_percentage = []
    acc_rmse = []

    only_train = False

    for rd in range(1, num_round + 1):

        print('\n================Round {:d}==============='.format(rd))

        # if the remaining unlabelled pool is no larger than the query budget,
        # label everything and skip the querying step
        if len(np.arange(strategy.n_pool)[~strategy.idx_lb]) <= strategy.num_elem_queried:
            only_train = True

        if reset_phi:
            strategy.phi = NN_phi(dim_input=strategy.X_train.shape[1])
            strategy.opti_phi = optim.Adam(strategy.phi.parameters(), lr=lr_phi, maximize=True)

        if reset_h:
            strategy.h = NN_h_RELU(dim_input=strategy.X_train.shape[1])
            strategy.opti_h = optim.Adam(strategy.h.parameters(), lr=lr_h, weight_decay=weight_decay)

        t1, t2, b_idxs = strategy.train(only_train, reduced, eta)

        t1_descend_list.append(t1)
        t2_ascend_list.append(t2)

        # mark the points queried this round as labelled
        if only_train:
            strategy.idx_lb[:] = True
        else:
            strategy.idx_lb[b_idxs] = True

        with torch.no_grad():
            if show_losses:
                display_loss_t1(t1, rd)
                display_loss_t2(t2, rd)

            if show_chosen_each_round:
                if strategy.X_train.shape[1] == 1:
                    display_chosen_labelled_datas(strategy.X_train.cpu(), strategy.idx_lb,
                                                  strategy.y_train.cpu(), b_idxs, rd)
                else:
                    # multivariate data are projected with PCA for display
                    display_chosen_labelled_datas_PCA(strategy.X_train.cpu(), strategy.idx_lb,
                                                      strategy.y_train.cpu(), b_idxs, rd)

            # evaluate the current regressor h on the held-out test set
            acc_rmse.append(RMSE(strategy.X_test, strategy.y_test, strategy.h).cpu())
            acc.append(MAE(strategy.X_test, strategy.y_test, strategy.h).cpu())
            acc_percentage.append(MAPE(strategy.X_test, strategy.y_test, strategy.h).cpu())

    print('\n================Final training===============')

    # one final training pass after the last query round; queried indices, if any,
    # are discarded
    t1, t2, _ = strategy.train(only_train, reduced, eta)

    t1_descend_list.append(t1)
    t2_ascend_list.append(t2)

    with torch.no_grad():
        acc.append(MAE(strategy.X_test, strategy.y_test, strategy.h).cpu())
        acc_percentage.append(MAPE(strategy.X_test, strategy.y_test, strategy.h).cpu())
        acc_rmse.append(RMSE(strategy.X_test, strategy.y_test, strategy.h).cpu())

    return acc, acc_percentage, acc_rmse, t1_descend_list, t2_ascend_list
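

# Example usage (a sketch, not part of the library): assumes a `strategy` object of
# class WAR already built from a dataset, as described in the docstring above; the
# constructor arguments shown here are hypothetical, not a confirmed API.
#
#     strategy = WAR(X_train, y_train, X_test, y_test, idx_lb, num_elem_queried=5)
#     num_round = check_num_round(num_round=10, len_dataset=len(X_train),
#                                 nb_initial_labelled_datas=int(idx_lb.sum()),
#                                 num_elem_queried=5)
#     acc, acc_pct, acc_rmse, t1s, t2s = full_training(
#         strategy, num_round, show_losses=False, show_chosen_each_round=False,
#         reset_phi=True, reset_h=True, weight_decay=1e-4, lr_h=1e-3, lr_phi=1e-3,
#         reduced=False, eta=1)
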
def check_num_round(num_round, len_dataset, nb_initial_labelled_datas, num_elem_queried):
    # maximum number of rounds before the whole pool is labelled
    max_round = int(np.ceil((len_dataset - nb_initial_labelled_datas) / num_elem_queried))
    if num_round > max_round:
        warnings.warn(f"when querying {num_elem_queried} data per round, num_round={num_round} exceeds"
                      f" the maximum number of rounds (total data queried would be greater than the"
                      f" number of initially unlabelled data).\nnum_round set to {max_round}")
        num_round = max_round
    return num_round
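

# Example (verifiable from the function above): with 1000 samples, 100 initially
# labelled and 50 queried per round, at most ceil(900 / 50) = 18 rounds are possible,
# so requesting 20 rounds emits a warning and returns 18:
#
#     num_round = check_num_round(num_round=20, len_dataset=1000,
#                                 nb_initial_labelled_datas=100, num_elem_queried=50)
#     # UserWarning raised; num_round == 18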