import os
import numpy as np
from omegaconf import OmegaConf
import omegaconf
from ray.tune import Trainable
from ganime.model.base import load_model
from ganime.data.base import load_dataset
import tensorflow as tf

from ganime.utils.callbacks import TensorboardImage


class TrainableGANime(Trainable):
    def setup(self, config):
        """Build the distributed model, datasets, and callbacks for one trial."""
        strategy = tf.distribute.MirroredStrategy()

        tune_config = self.load_config_file_and_replace(config)
        self.batch_size = tune_config["trainer"]["batch_size"]

        # Scale the batch size by the number of replicas so that each device
        # receives `batch_size` samples per step.
        self.n_devices = strategy.num_replicas_in_sync
        self.global_batch_size = self.batch_size * self.n_devices
        self.train_dataset, self.validation_dataset, self.test_dataset = load_dataset(
            dataset_name=config["dataset_name"],
            dataset_path=config["dataset_path"],
            batch_size=self.global_batch_size,
        )

        self.model = load_model(config["model"], config=tune_config, strategy=strategy)

        # Keep one batch from each split so the TensorboardImage callback can
        # log example outputs during training.
        for data in self.train_dataset.take(1):
            train_sample_data = data
        for data in self.validation_dataset.take(1):
            validation_sample_data = data

        tensorboard_image_callback = TensorboardImage(
            self.logdir, train_sample_data, validation_sample_data
        )
        checkpointing = tf.keras.callbacks.ModelCheckpoint(
            os.path.join(self.logdir, "checkpoint", "checkpoint"),
            monitor="total_loss",
            save_best_only=True,
            save_weights_only=True,
        )
        self.callbacks = [tensorboard_image_callback, checkpointing]

    def load_config_file_and_replace(self, config):
        """Load the base OmegaConf config file and override it with the
        hyperparameters sampled by Ray Tune for this trial."""
        cfg = OmegaConf.load(config["config_file"])
        hyperparameters = config["hyperparameters"]

        for hp_key, hp_value in hyperparameters.items():
            cfg = self.replace_item(cfg, hp_key, hp_value)
        return cfg

    def replace_item(self, obj, key, replace_value):
        """Recursively replace every occurrence of `key` in a nested
        dict/DictConfig with `replace_value`."""
        for k, v in obj.items():
            if isinstance(v, (dict, omegaconf.DictConfig)):
                obj[k] = self.replace_item(v, key, replace_value)
        if key in obj:
            obj[key] = replace_value
        return obj

    def step(self):
        """Train for one epoch and report the validation metrics to Tune."""
        self.model.fit(
            self.train_dataset,
            # Run exactly one epoch per Tune iteration so the Keras epoch
            # counter stays in sync with `training_iteration`.
            initial_epoch=self.training_iteration,
            epochs=self.training_iteration + 1,
            callbacks=self.callbacks,
            verbose=0,
        )
        scores = self.model.evaluate(self.validation_dataset, verbose=0)
        results = dict(zip(self.model.metrics_names, scores))
        # Mark the trial as done if training diverged (any metric is NaN).
        if np.any(np.isnan(scores)):
            results["done"] = True
        return results

    def save_checkpoint(self, tmp_checkpoint_dir):
        # Save the model weights so Ray Tune can restore this trial later.
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model")
        self.model.save_weights(checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        # Restore the weights written by `save_checkpoint`.
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model")
        self.model.load_weights(checkpoint_path)
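

# Illustrative usage sketch (not part of the original module): shows how this
# Trainable might be launched with Ray Tune's `tune.run` API. The config keys
# mirror those read in `setup`; the dataset/model names, config file path, and
# the searched hyperparameter below are placeholder assumptions.
if __name__ == "__main__":
    from ray import tune

    analysis = tune.run(
        TrainableGANime,
        config={
            "dataset_name": "example_dataset",      # placeholder dataset name
            "dataset_path": "/path/to/data",        # placeholder path
            "model": "example_model",               # placeholder model name
            "config_file": "configs/example.yaml",  # placeholder config file
            "hyperparameters": {
                # Any key present in the config file can be searched over.
                "learning_rate": tune.loguniform(1e-5, 1e-3),
            },
        },
        metric="total_loss",
        mode="min",
        num_samples=4,
        stop={"training_iteration": 10},
        resources_per_trial={"cpu": 4, "gpu": 1},
    )
    print("Best config:", analysis.best_config)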