File size: 2,113 Bytes
a9b202e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gym
import numpy as np

from collections import deque
from stable_baselines3.common.vec_env.base_vec_env import (
    VecEnvStepReturn,
    VecEnvObs,
)
from torch.utils.tensorboard.writer import SummaryWriter

from shared.stats import Episode, EpisodesStats


class EpisodeStatsWriter(gym.Wrapper):
    """Wrapper that reports finished-episode statistics to TensorBoard.

    On every ``step`` the wrapper looks for an ``"episode"`` entry in the
    info dict(s) returned by the wrapped environment. Each one found is
    turned into an ``Episode`` and written to TensorBoard twice: once as
    stats over the episodes that finished on this step, and once as stats
    over a rolling window of the most recent episodes. A summary line is
    printed roughly every ``rolling_length`` completed episodes.

    Args:
        env: Environment to wrap. Vectorized environments (exposing
            ``num_envs`` and returning a sequence of info dicts) and plain
            single environments (returning one info dict) are both handled.
        tb_writer: TensorBoard writer handed to
            ``EpisodesStats.write_to_tensorboard``.
        training: Selects the tag prefix: ``"train"`` when true, ``"eval"``
            otherwise.
        rolling_length: Maximum number of recent episodes kept for the
            rolling statistics, and the episode interval between printed
            summaries.
    """

    def __init__(
        self,
        env,
        tb_writer: SummaryWriter,
        training: bool = True,
        rolling_length=100,
    ):
        super().__init__(env)
        self.training = training
        self.tb_writer = tb_writer
        self.rolling_length = rolling_length
        # Bounded window: older episodes fall off once rolling_length is hit.
        self.episodes = deque(maxlen=rolling_length)
        self.total_steps = 0
        self.episode_cnt = 0
        # Episode count threshold bookkeeping for the periodic console print.
        self.last_episode_cnt_print = 0

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step the wrapped env and log stats for any episodes that ended.

        Args:
            actions: Actions forwarded unchanged to the wrapped environment.

        Returns:
            The ``(obs, rews, dones, infos)`` tuple from the wrapped
            environment, unmodified.
        """
        obs, rews, dones, infos = self.env.step(actions)

        num_envs = getattr(self.env, "num_envs", None)
        # One call advances every sub-environment by one step; a
        # non-vectorized env counts as a single step.
        self.total_steps += 1 if num_envs is None else num_envs

        # A non-vectorized env returns a single info dict rather than a
        # sequence of them. Iterating that dict directly would walk its keys
        # (strings) and fail on ``.get``, so wrap it for uniform handling.
        if num_envs is None and isinstance(infos, dict):
            per_env_infos = [infos]
        else:
            per_env_infos = infos

        step_episodes = []
        for info in per_env_infos:
            # The "episode" entry is expected only when an episode has just
            # finished (presumably added by a monitor-style wrapper upstream;
            # "r"/"l" look like episode return and length -- confirm there).
            ep_info = info.get("episode")
            if ep_info:
                episode = Episode(ep_info["r"], ep_info["l"])
                step_episodes.append(episode)
                self.episodes.append(episode)

        if step_episodes:
            tag = "train" if self.training else "eval"
            step_stats = EpisodesStats(step_episodes, simple=True)
            step_stats.write_to_tensorboard(self.tb_writer, tag, self.total_steps)
            rolling_stats = EpisodesStats(self.episodes)
            rolling_stats.write_to_tensorboard(
                self.tb_writer, f"{tag}_rolling", self.total_steps
            )
            self.episode_cnt += len(step_episodes)
            # Print once per rolling_length completed episodes rather than on
            # every episode end, to keep console output manageable.
            if self.episode_cnt >= self.last_episode_cnt_print + self.rolling_length:
                print(
                    f"Episode: {self.episode_cnt} | "
                    f"Steps: {self.total_steps} | "
                    f"{rolling_stats}"
                )
                self.last_episode_cnt_print += self.rolling_length
        return obs, rews, dones, infos

    def reset(self, **kwargs) -> VecEnvObs:
        """Reset the wrapped environment and return its observation(s).

        Keyword arguments are forwarded unchanged to the wrapped
        environment's ``reset``; calling with no arguments behaves as before.
        Accumulated step/episode counters are intentionally not cleared.
        """
        return self.env.reset(**kwargs)