import pickle
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
from cleverhans.future.tf2.attacks import fast_gradient_method
from sklearn.model_selection import KFold

from _utility import print_test, get_adversarial_examples

# Directory where the per-fold models are saved.
folder_name = "./adversarial_examples_parseval_net/src/logs/saved_models/"


def train(
    instance,
    X_train,
    Y_train,
    X_test,
    y_test,
    epochs,
    BS,
    sgd,
    generator,
    callbacks_list,
    model_name="ResNet",
):
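    """Train one model per fold of 10-fold cross-validation.

    For every fold a fresh network is built via
    ``instance.create_wide_residual_network()``, trained on the augmented
    stream produced by ``generator.flow`` (an ImageDataGenerator-style
    object), and its training history and weights are written to disk.
    ``X_test`` and ``y_test`` are accepted but not used inside this function.
    """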

    # 10-fold cross-validation. shuffle must be True for random_state to take
    # effect (recent scikit-learn versions reject random_state with shuffle=False).
    kfold = KFold(n_splits=10, random_state=42, shuffle=True)

    for j, (train_idx, val_idx) in enumerate(kfold.split(X_train)):

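        # Build and compile a fresh, untrained model for this fold.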
        model = instance.create_wide_residual_network()
        model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=["acc"])

        print("Finished compiling")

        # Index the full training arrays with this fold's train/validation indices.
        x_train, y_train = X_train[train_idx], Y_train[train_idx]
        x_val, y_val = X_train[val_idx], Y_train[val_idx]

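        # Train on the augmented generator stream, validating on the held-out fold.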
        hist = model.fit(
            generator.flow(x_train, y_train, batch_size=BS),
            steps_per_epoch=len(x_train) // BS,
            epochs=epochs,
            callbacks=callbacks_list,
            validation_data=(x_val, y_val),
            validation_steps=x_val.shape[0] // BS,
        )
        # Write this fold's training history to disk.
        with open("history_" + model_name + str(j), "wb") as file_pi:
            pickle.dump(hist.history, file_pi)

        # Save the fold's weights under a fresh path; do not overwrite
        # model_name itself, or later folds would append to the mangled path.
        model_path = folder_name + model_name + "_" + str(j) + ".h5"
        model.save(model_path)
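

# Example invocation (a sketch, not part of the original script): assumes a
# CIFAR-10-style dataset and a hypothetical ``WideResNet`` factory exposing
# ``create_wide_residual_network()``; replace these with the project's real
# builder, optimizer, augmentation, and callbacks.
#
#     from tensorflow.keras.optimizers import SGD
#     from tensorflow.keras.preprocessing.image import ImageDataGenerator
#
#     (X_train, Y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
#     Y_train = tf.keras.utils.to_categorical(Y_train, 10)
#     y_test = tf.keras.utils.to_categorical(y_test, 10)
#
#     train(
#         WideResNet(),
#         X_train / 255.0, Y_train,
#         X_test / 255.0, y_test,
#         epochs=100, BS=128,
#         sgd=SGD(learning_rate=0.1, momentum=0.9),
#         generator=ImageDataGenerator(horizontal_flip=True),
#         callbacks_list=[],
#     )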