File size: 4,796 Bytes
ffd9d26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import warnings
import torch.optim as optim
import torch
from WAR.Models import NN_phi,NN_h_RELU
from WAR.Experiment_functions import *



def _module_device(module):
    """Return the device of ``module``'s first parameter, or None if it cannot be determined."""
    try:
        return next(module.parameters()).device
    except (AttributeError, StopIteration):
        # no module (e.g. None) or a module without parameters
        return None


def _reset_networks(strategy, reset_phi, reset_h, weight_decay, lr_h, lr_phi):
    """Replace strategy.phi and/or strategy.h by freshly initialised networks with new Adam optimisers.

    Each new network is moved to the device of the network it replaces (when that
    can be determined), so a strategy whose networks live on a GPU keeps working
    after a reset. phi is reset before h, as in the original loop.

    Raises:
        ValueError: if a network is reset while its learning rate is None
            (Adam needs a numeric learning rate).
    """
    dim_input = strategy.X_train.shape[1]

    if reset_phi:
        if lr_phi is None:
            raise ValueError("lr_phi must be given when reset_phi is True")
        device = _module_device(getattr(strategy, "phi", None))
        phi = NN_phi(dim_input=dim_input)
        if device is not None:
            phi = phi.to(device)
        strategy.phi = phi
        # phi is optimised by gradient ascent, hence maximize=True
        strategy.opti_phi = optim.Adam(phi.parameters(), lr=lr_phi, maximize=True)

    if reset_h:
        if lr_h is None:
            raise ValueError("lr_h must be given when reset_h is True")
        device = _module_device(getattr(strategy, "h", None))
        h = NN_h_RELU(dim_input=dim_input)
        if device is not None:
            h = h.to(device)
        strategy.h = h
        strategy.opti_h = optim.Adam(h.parameters(), lr=lr_h, weight_decay=weight_decay)


def _append_test_metrics(strategy, acc, acc_percentage, acc_rmse):
    """Append the current test MAE, MAPE and RMSE of strategy.h (moved to CPU) to the given lists."""
    acc.append(MAE(strategy.X_test, strategy.y_test, strategy.h).cpu())
    acc_percentage.append(MAPE(strategy.X_test, strategy.y_test, strategy.h).cpu())
    acc_rmse.append(RMSE(strategy.X_test, strategy.y_test, strategy.h).cpu())


def full_training(strategy, num_round, show_losses, show_chosen_each_round,
                  reset_phi, reset_h, weight_decay, lr_h=None, lr_phi=None,
                  reduced=False, eta=1):
    """
    Run the WAR active-learning loop for ``num_round`` rounds, then train once more.

    Each round: optionally re-initialise phi and/or h, call ``strategy.train``,
    mark the queried samples as labelled (the oracle is simulated), optionally
    plot, and record the test metrics of h.

    strategy: an object of class WAR
    num_round: total number of query rounds
    show_losses: display graphs showing the loss of h and phi each round
    show_chosen_each_round: display a graph showing the data queried each round
    reset_phi: if True, the phi neural network is re-initialised at the start of each round.
        Can avoid overfitting but increases the number of epochs required to train the model
    reset_h: if True, the h neural network is re-initialised at the start of each round.
        Can avoid overfitting but increases the number of epochs required to train the model
    weight_decay: weight decay of the Adam optimiser created for h when it is reset
    lr_h: learning rate of h (required when reset_h is True)
    lr_phi: learning rate of phi (required when reset_phi is True)
    reduced: will divide each query criterion by its standard deviation. In the case where
        they don't have the same amplitude, this gives them the same weight in the querying
        process. Irrelevant parameter if there is only one query criterion
    eta: factor used to rebalance the criteria. If >1, the distribution matching criterion
        gets more weight than the other(s). Irrelevant parameter if there is only one query criterion.

    Returns:
        (acc, acc_percentage, acc_rmse, t1_descend_list, t2_ascend_list) where
        acc, acc_percentage, acc_rmse are the test MAE, MAPE and RMSE (on CPU)
        recorded after every round and once after the final training
        (num_round + 1 values each), and t1_descend_list / t2_ascend_list hold
        the t1 / t2 values returned by every call to strategy.train.

    Raises:
        ValueError: if a network is reset while its learning rate is None.
    """
    t1_descend_list = []
    t2_ascend_list = []
    acc = []             # MAE
    acc_percentage = []  # MAPE
    acc_rmse = []        # RMSE

    only_train = False

    for rd in range(1, num_round + 1):
        print('\n================Round {:d}==============='.format(rd))

        # When no more than one query batch of unlabelled data is left, switch to
        # only_train: the flag is passed to strategy.train and, after training,
        # the whole remaining pool is marked as labelled.
        n_unlabelled = len(np.arange(strategy.n_pool)[~strategy.idx_lb])
        if n_unlabelled <= strategy.num_elem_queried:
            only_train = True

        if reset_phi or reset_h:
            _reset_networks(strategy, reset_phi, reset_h, weight_decay, lr_h, lr_phi)

        t1, t2, b_idxs = strategy.train(only_train, reduced, eta)
        t1_descend_list.append(t1)
        t2_ascend_list.append(t2)

        if only_train:
            strategy.idx_lb[:] = True
        else:
            # "simulation" of the oracle who labels the queried samples
            strategy.idx_lb[b_idxs] = True

        with torch.no_grad():
            if show_losses:
                display_loss_t1(t1, rd)
                display_loss_t2(t2, rd)

            if show_chosen_each_round:
                X_train_cpu = strategy.X_train.cpu()
                y_train_cpu = strategy.y_train.cpu()
                if strategy.X_train.shape[1] == 1:
                    display_chosen_labelled_datas(X_train_cpu, strategy.idx_lb, y_train_cpu, b_idxs, rd)
                else:
                    # multi-feature inputs go through the _PCA plotting variant
                    display_chosen_labelled_datas_PCA(X_train_cpu, strategy.idx_lb, y_train_cpu, b_idxs, rd)

            _append_test_metrics(strategy, acc, acc_percentage, acc_rmse)

    print('\n================Final training===============')

    # NOTE(review): only_train keeps its value from the last round, so if the pool
    # was never exhausted strategy.train is called with only_train=False here and
    # any batch it selects is discarded — confirm this is intended.
    t1, t2, _ = strategy.train(only_train, reduced, eta)
    t1_descend_list.append(t1)
    t2_ascend_list.append(t2)

    with torch.no_grad():
        _append_test_metrics(strategy, acc, acc_percentage, acc_rmse)

    return acc, acc_percentage, acc_rmse, t1_descend_list, t2_ascend_list






def check_num_round(num_round, len_dataset, nb_initial_labelled_datas, num_elem_queried):
    """
    Clamp ``num_round`` to the number of query rounds the unlabelled pool can support.

    num_round: requested number of query rounds
    len_dataset: total number of samples in the training pool
    nb_initial_labelled_datas: number of samples labelled before the first round
    num_elem_queried: number of samples queried per round (must be > 0)

    Returns num_round unchanged if it is feasible, otherwise the maximum number of
    rounds, ceil((len_dataset - nb_initial_labelled_datas) / num_elem_queried),
    never below 0. A warning is emitted when the value is clamped.

    Raises ValueError if num_elem_queried is not positive.
    """
    if num_elem_queried <= 0:
        raise ValueError(f"num_elem_queried must be positive, got {num_elem_queried}")

    n_unlabelled = len_dataset - nb_initial_labelled_datas
    # exact integer ceiling division (no float rounding); clamp at 0 in case the
    # initial labelled set already covers the whole dataset
    max_round = max(0, int(-(-n_unlabelled // num_elem_queried)))

    if num_round > max_round:
        warnings.warn(
            f"when querying {num_elem_queried} data per round, num_round={num_round} is exceeding"
            f" the maximum number of rounds (total data queried superior to number of initial unlabelled data).\nnum_round set to {max_round}",
            stacklevel=2,  # point the warning at the caller, not at this helper
        )
        num_round = max_round
    return num_round