File size: 3,362 Bytes
b18ddcc
 
 
 
 
 
 
3cc5c1d
b18ddcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import numpy as np
import optuna

from time import perf_counter
from torch.utils.tensorboard.writer import SummaryWriter
from typing import NamedTuple, Union

from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.shared.stats import EpisodesStats
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, find_wrapper


class Evaluation(NamedTuple):
    """Result of a single evaluation pass performed by ``evaluation``.

    Bundles the statistics of the freshly run evaluation episodes, the
    statistics rebuilt from the training env's recorded episodes, and the
    combined score derived from both.
    """

    # Stats of the evaluation episodes run during this pass.
    eval_stat: EpisodesStats
    # Stats built from the episodes held by the training env's
    # EpisodeStatsWriter wrapper.
    train_stat: EpisodesStats
    # Average of the eval and train mean scores.
    score: float


class OptimizeCallback(Callback):
    """Training callback that periodically evaluates a policy for an Optuna trial.

    Each time the elapsed timestep count reaches the next multiple of
    ``step_freq``, the policy is evaluated with ``evaluation``. The score is
    reported to ``trial`` as an intermediate value. Training is asked to stop
    (``on_step`` returns ``False``) once the trial is marked as pruned.
    """

    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        trial: optuna.Trial,
        tb_writer: SummaryWriter,
        step_freq: Union[int, float] = 50_000,
        n_episodes: int = 10,
        deterministic: bool = True,
    ) -> None:
        """
        Args:
            policy: Policy being trained. ``policy.env`` must be wrapped by an
                ``EpisodeStatsWriter`` (asserted below).
            env: Environment handed to ``evaluate`` for evaluation episodes.
            trial: Optuna trial that receives the intermediate scores.
            tb_writer: TensorBoard writer used for evaluation metrics.
            step_freq: Number of timesteps between consecutive evaluations.
            n_episodes: Episode count forwarded to ``evaluate``.
            deterministic: Flag forwarded to ``evaluate``.
        """
        super().__init__()

        # Collaborators used on every evaluation.
        self.policy = policy
        self.env = env
        self.trial = trial
        self.tb_writer = tb_writer

        # Evaluation schedule and settings.
        self.step_freq = step_freq
        self.n_episodes = n_episodes
        self.deterministic = deterministic

        # The training env must record episode stats; fail fast here rather
        # than at the first evaluation.
        writer = find_wrapper(policy.env, EpisodeStatsWriter)
        assert writer
        self.stats_writer = writer

        # 1-based index of the next evaluation. Its threshold is
        # eval_step * step_freq, and it is also the step number given to
        # trial.report().
        self.eval_step = 1
        self.is_pruned = False

        # Most recent evaluation results (placeholders until the first one).
        self.last_eval_stat = None
        self.last_train_stat = None
        self.last_score = -np.inf

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        """Record progress and run an evaluation when its threshold is reached.

        Returns ``not self.is_pruned`` on steps that trigger an evaluation and
        ``True`` on every other step.
        """
        super().on_step(timesteps_elapsed)
        due = self.timesteps_elapsed >= self.eval_step * self.step_freq
        if not due:
            return True
        self.evaluate()
        return not self.is_pruned

    def evaluate(self) -> None:
        """Run one evaluation, store its results and report the score to the trial."""
        result = evaluation(
            self.policy,
            self.env,
            self.tb_writer,
            self.n_episodes,
            self.deterministic,
            self.timesteps_elapsed,
        )
        self.last_eval_stat = result.eval_stat
        self.last_train_stat = result.train_stat
        self.last_score = result.score

        # Intermediate values are keyed by evaluation index so the pruner
        # can compare trials at the same point in the schedule.
        self.trial.report(result.score, self.eval_step)
        if self.trial.should_prune():
            self.is_pruned = True

        self.eval_step += 1


def evaluation(
    policy: Policy,
    env: VecEnv,
    tb_writer: SummaryWriter,
    n_episodes: int,
    deterministic: bool,
    timesteps_elapsed: int,
) -> Evaluation:
    """Evaluate ``policy`` and combine the result with its training-episode stats.

    The steps are:

    1. Run ``evaluate`` on ``env``.
    2. Log evaluation throughput and stats to TensorBoard.
    3. Call ``policy.train()``.
    4. Build stats from the episodes held by the ``EpisodeStatsWriter``
       wrapping ``policy.env``.
    5. Score the policy as the average of the eval and train mean scores.

    Args:
        policy: Policy to evaluate; ``policy.env`` must carry an
            ``EpisodeStatsWriter`` wrapper.
        env: Environment passed to ``evaluate``.
        tb_writer: TensorBoard writer receiving the ``eval/*`` scalars.
        n_episodes: Episode count forwarded to ``evaluate``.
        deterministic: Flag forwarded to ``evaluate``.
        timesteps_elapsed: Training timestep used as the TensorBoard x-axis
            and in the printed summary.

    Returns:
        ``Evaluation`` holding the eval stats, train stats and combined score.
    """
    tic = perf_counter()
    eval_stat = evaluate(
        env,
        policy,
        n_episodes,
        deterministic=deterministic,
        print_returns=False,
    )
    elapsed = perf_counter() - tic

    # Throughput: total evaluation episode length over wall-clock time.
    steps_per_second = eval_stat.length.sum() / elapsed
    tb_writer.add_scalar("eval/steps_per_second", steps_per_second, timesteps_elapsed)

    # Return the policy to training mode (presumably evaluate() switched it
    # out of it — confirm against evaluate()).
    policy.train()
    print(f"Eval Timesteps: {timesteps_elapsed} | {eval_stat}")
    eval_stat.write_to_tensorboard(tb_writer, "eval", timesteps_elapsed)

    writer = find_wrapper(policy.env, EpisodeStatsWriter)
    assert writer
    train_stat = EpisodesStats(writer.episodes)
    print(f"  Train Stat: {train_stat}")

    score = (eval_stat.score.mean + train_stat.score.mean) / 2
    print(f"  Score: {round(score, 2)}")
    tb_writer.add_scalar("eval/score", score, timesteps_elapsed)

    return Evaluation(eval_stat=eval_stat, train_stat=train_stat, score=score)