zjowowen's picture
init space
079c32c
raw
history blame
32.7 kB
from typing import TYPE_CHECKING, Optional, Callable, Dict, List, Union
from ditk import logging
from easydict import EasyDict
from matplotlib import pyplot as plt
from matplotlib import animation
import os
import numpy as np
import torch
import wandb
import pickle
import treetensor.numpy as tnp
from ding.framework import task
from ding.envs import BaseEnvManagerV2
from ding.utils import DistributedWriter
from ding.torch_utils import to_ndarray
from ding.utils.default_helper import one_time_warning
if TYPE_CHECKING:
from ding.framework import OnlineRLContext, OfflineRLContext
def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable:
    """
    Overview:
        Create an online RL tensorboard logger for recording training and evaluation metrics.
    Arguments:
        - record_train_iter (:obj:`bool`): Whether to record training iteration. Default is False.
        - train_show_freq (:obj:`int`): Frequency of showing training logs. Default is 100.
    Returns:
        - _logger (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
    Raises:
        - RuntimeError: If writer is None.
        - NotImplementedError: If the key of train_output is not supported, such as "scalars".
    Examples:
        >>> task.use(online_logger(record_train_iter=False, train_show_freq=1000))
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    writer = DistributedWriter.get_instance()
    if writer is None:
        raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
    last_train_show_iter = -1

    def _logger(ctx: "OnlineRLContext"):
        nonlocal last_train_show_iter
        if task.finish:
            writer.close()

        # Evaluation return is only written when ``ctx.eval_value`` is finite.
        if not np.isinf(ctx.eval_value):
            if record_train_iter:
                writer.add_scalar('basic/eval_episode_return_mean-env_step', ctx.eval_value, ctx.env_step)
                writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)
            else:
                writer.add_scalar('basic/eval_episode_return_mean', ctx.eval_value, ctx.env_step)

        # Training metrics are throttled: at most one dump every ``train_show_freq`` iterations.
        if ctx.train_output is None or ctx.train_iter - last_train_show_iter < train_show_freq:
            return

        if isinstance(ctx.train_output, List):
            if len(ctx.train_output) == 0:
                # Nothing to show yet; do not consume the throttle window.
                return
            # Only use latest output for some algorithms, like PPO. The entry is read without being
            # removed, so the list stored on the context stays intact for any later consumer.
            output = ctx.train_output[-1]
        else:
            output = ctx.train_output
        last_train_show_iter = ctx.train_iter

        for k, v in output.items():
            if k in ['priority', 'td_error_priority']:
                continue
            if "[scalars]" in k:
                raise NotImplementedError
            elif "[histogram]" in k:
                # The "[histogram]" prefix is stripped to obtain the tensorboard tag.
                new_k = k.split(']')[-1]
                writer.add_histogram(new_k, v, ctx.env_step)
                if record_train_iter:
                    writer.add_histogram(new_k, v, ctx.train_iter)
            else:
                if record_train_iter:
                    writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter)
                    writer.add_scalar('basic/train_{}-env_step'.format(k), v, ctx.env_step)
                else:
                    writer.add_scalar('basic/train_{}'.format(k), v, ctx.env_step)

    return _logger
def offline_logger(train_show_freq: int = 100) -> Callable:
    """
    Overview:
        Build a tensorboard logging middleware for offline RL. Evaluation returns and training outputs \
        are both indexed by the training iteration.
    Arguments:
        - train_show_freq (:obj:`int`): Minimum number of training iterations between two dumps of the \
            training outputs. Defaults to 100.
    Returns:
        - _logger (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
    Raises:
        - RuntimeError: If writer is None.
        - NotImplementedError: If the key of train_output is not supported, such as "scalars".
    Examples:
        >>> task.use(offline_logger(train_show_freq=1000))
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    writer = DistributedWriter.get_instance()
    if writer is None:
        raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")

    last_train_show_iter = -1

    def _logger(ctx: "OfflineRLContext"):
        nonlocal last_train_show_iter

        if task.finish:
            writer.close()

        eval_is_finite = not np.isinf(ctx.eval_value)
        if eval_is_finite:
            writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)

        # Guard clause: skip the training dump when there is no output or the interval has not elapsed.
        if ctx.train_output is None or ctx.train_iter - last_train_show_iter < train_show_freq:
            return

        last_train_show_iter = ctx.train_iter
        for name, value in ctx.train_output.items():
            if name in ['priority']:
                continue
            if "[scalars]" in name or "[histogram]" in name:
                # Everything after the last ']' is the tag used for the writer.
                tag = name.split(']')[-1]
                if "[scalars]" in name:
                    raise NotImplementedError
                writer.add_histogram(tag, value, ctx.train_iter)
            else:
                writer.add_scalar('basic/train_{}-train_iter'.format(name), value, ctx.train_iter)

    return _logger
# four utility functions for wandb logger
def softmax(logit: np.ndarray) -> np.ndarray:
    """
    Overview:
        Numerically stable softmax over the last axis.
    Arguments:
        - logit (:obj:`np.ndarray`): Unnormalised scores; anything ``np.asarray`` accepts.
    Returns:
        - prob (:obj:`np.ndarray`): Array of the same shape whose last axis sums to 1.
    """
    logit = np.asarray(logit)
    if logit.size == 0:
        # ``max`` is undefined on an empty axis; an empty input simply maps to an empty output.
        return np.exp(logit)
    # Subtracting the per-row maximum leaves the result unchanged mathematically but keeps
    # ``np.exp`` from overflowing to inf (which would turn the division below into NaN).
    shifted = logit - logit.max(axis=-1, keepdims=True)
    v = np.exp(shifted)
    return v / v.sum(axis=-1, keepdims=True)
def action_prob(num, action_prob, ln):
    """
    Overview:
        Animation frame callback: resize every bar in ``ln`` to the values stored in frame ``num`` \
        of ``action_prob``.
    Arguments:
        - num: Index of the frame to draw.
        - action_prob: Indexable collection; ``action_prob[num]`` yields one height per bar.
        - ln: Iterable of bar artists exposing ``set_height``.
    Returns:
        - ln: The same artists that were passed in, after their heights were updated.
    """
    # The y-axis of the current axes is pinned to [0, 1] on every frame.
    plt.gca().set_ylim([0, 1])
    frame_heights = action_prob[num]
    for bar, height in zip(ln, frame_heights):
        bar.set_height(height)
    return ln
def return_prob(num, return_prob, ln):
    """
    Overview:
        Animation frame callback that leaves the artists untouched and hands them straight back.
    Arguments:
        - num: Frame index (unused).
        - return_prob: Data associated with the animation (unused).
        - ln: The artists of the plot.
    Returns:
        - ln: The very same object that was passed in.
    """
    artists = ln
    return artists
def return_distribution(episode_return):
    """
    Overview:
        Bucket episode returns into 5 equal-width bins that span from 50 below the smallest return \
        to 50 above the largest one.
    Arguments:
        - episode_return: Non-empty sequence of scalar returns.
    Returns:
        - freq: Fraction of episodes that fall into each of the 5 bins.
        - x_dim: Left edge of every bin, formatted with one decimal place.
    """
    lowest, highest = min(episode_return), max(episode_return)
    total = len(episode_return)
    left_edge = lowest - 50
    # 6 edges -> 5 bins
    edges = np.linspace(left_edge, highest + 50, 6)
    counts, _ = np.histogram(episode_return, bins=edges)
    bin_width = (highest - lowest + 100) / 5
    labels = []
    for idx in range(5):
        labels.append('{:.1f}'.format(left_edge + bin_width * idx))
    return counts / total, labels
def wandb_online_logger(
    record_path: str = None,
    cfg: Union[dict, EasyDict] = None,
    exp_config: Union[dict, EasyDict] = None,
    metric_list: Optional[List[str]] = None,
    env: Optional[BaseEnvManagerV2] = None,
    model: Optional[torch.nn.Module] = None,
    anonymous: bool = False,
    project_name: str = 'default-project',
    run_name: str = None,
    wandb_sweep: bool = False,
) -> Callable:
    """
    Overview:
        Wandb visualizer to track the experiment.
    Arguments:
        - record_path (:obj:`str`): The path to save the replay of simulation.
        - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings:
            - gradient_logger: boolean. Whether to track the gradient.
            - plot_logger: boolean. Whether to track the metrics like reward and loss.
            - video_logger: boolean. Whether to upload the rendering video replay.
            - action_logger: boolean. `q_value` or `action probability`.
            - return_logger: boolean. Whether to track the return value.
        - exp_config (:obj:`Union[dict, EasyDict]`): Experiment config; when non-empty it is passed to \
            ``wandb.init`` as ``config``.
        - metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
        - env (:obj:`BaseEnvManagerV2`): Evaluator environment.
        - model (:obj:`nn.Module`): Policy neural network model.
        - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \
            of data without wandb account.
        - project_name (:obj:`str`): The name of wandb project.
        - run_name (:obj:`str`): The name of wandb run.
        - wandb_sweep (:obj:`bool`): Whether to use wandb sweep. When False, ``reinit=True`` is passed to \
            ``wandb.init``.
    Returns:
        - _plot (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
    if metric_list is None:
        metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]

    # Initialize wandb with default settings
    # Settings can be covered by calling wandb.init() at the top of the script
    # Each optional keyword is only forwarded when the corresponding argument asks for it.
    init_kwargs = dict(project=project_name)
    if exp_config:
        init_kwargs['config'] = exp_config
    if not wandb_sweep:
        init_kwargs['reinit'] = True
    if run_name is not None:
        init_kwargs['name'] = run_name
    if anonymous:
        init_kwargs['anonymous'] = "must"
    wandb.init(**init_kwargs)

    plt.switch_backend('agg')
    if cfg is None:
        cfg = EasyDict(
            dict(
                gradient_logger=False,
                plot_logger=True,
                video_logger=False,
                action_logger=False,
                return_logger=False,
            )
        )
    else:
        if not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)
        # Any switch that the user did not mention is treated as disabled.
        for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
            if key not in cfg.keys():
                cfg[key] = False

    # The visualizer is called to save the replay of the simulation
    # which will be uploaded to wandb later
    if env is not None and cfg.video_logger is True and record_path is not None:
        env.enable_save_replay(replay_path=record_path)
    if cfg.gradient_logger:
        wandb.watch(model, log="all", log_freq=100, log_graph=True)
    else:
        one_time_warning(
            "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
        )

    first_plot = True

    def _to_loggable(value):
        # Tensors are detached and converted to numpy; every other value is logged as is.
        if isinstance(value, torch.Tensor):
            return value.cpu().detach().numpy()
        return value

    def _plot(ctx: "OnlineRLContext"):
        nonlocal first_plot
        if first_plot:
            first_plot = False
            ctx.wandb_url = wandb.run.get_project_url()

        info_for_logging = {}

        if cfg.plot_logger:
            train_output = ctx.train_output
            for metric in metric_list:
                if isinstance(train_output, Dict) and metric in train_output:
                    info_for_logging[metric] = _to_loggable(train_output[metric])
                elif isinstance(train_output, List) and len(train_output) > 0 and metric in train_output[0]:
                    # A list of outputs is reduced to the mean of the metric over all entries.
                    info_for_logging[metric] = np.mean([_to_loggable(item[metric]) for item in train_output])
        else:
            one_time_warning(
                "If you want to use wandb to visualize the result, please set plot_logger = True in the config."
            )

        if ctx.eval_value != -np.inf:
            eval_fields = (
                ("eval_value_min", "episode return min"),
                ("eval_value_max", "episode return max"),
                ("eval_value_std", "episode return std"),
                ("eval_value", "episode return mean"),
                ("train_iter", "train iter"),
                ("env_step", "env step"),
            )
            for attr, label in eval_fields:
                if hasattr(ctx, attr):
                    info_for_logging[label] = getattr(ctx, attr)

            eval_output = ctx.eval_output['output']
            episode_return = np.array(ctx.eval_output['episode_return'])
            if len(episode_return.shape) == 2:
                episode_return = episode_return.squeeze(1)

            if cfg.video_logger:
                if 'replay_video' in ctx.eval_output:
                    # save numpy array "images" of shape (N,1212,3,224,320) to N video files in mp4 format
                    # The numpy tensor must be either 4 dimensional or 5 dimensional.
                    # Channels should be (time, channel, height, width) or (batch, time, channel, height width)
                    video_images = ctx.eval_output['replay_video']
                    video_images = video_images.astype(np.uint8)
                    info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)})
                elif record_path is not None:
                    file_list = [p for p in os.listdir(record_path) if os.path.splitext(p)[-1] == ".mp4"]
                    file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
                    # The second most recently modified mp4 is uploaded.
                    # NOTE(review): presumably the newest file may still be in the process of being
                    # written - confirm against the env manager's replay saving.
                    if len(file_list) >= 2:
                        video_path = os.path.join(record_path, file_list[-2])
                        info_for_logging.update({"video": wandb.Video(video_path, format="mp4")})
                    else:
                        one_time_warning(
                            "Fewer than two mp4 replay files were found in record_path, skip uploading the video."
                        )

            if cfg.action_logger:
                action_path = os.path.join(record_path, (str(ctx.env_step) + "_action.gif"))
                if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
                    # The probabilities live in ``action_probs`` so that the module-level frame callback
                    # ``action_prob`` is not shadowed and can be handed to ``FuncAnimation``.
                    if isinstance(eval_output, tnp.ndarray):
                        action_probs = softmax(eval_output.logit)
                    else:
                        action_probs = [softmax(to_ndarray(v['logit'])) for v in eval_output]
                    fig, ax = plt.subplots()
                    plt.ylim([-1, 1])
                    # Index 0 always exists for a non-empty rollout, even one with a single step.
                    action_dim = len(action_probs[0])
                    x_range = [str(x + 1) for x in range(action_dim)]
                    ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
                    ani = animation.FuncAnimation(
                        fig, action_prob, fargs=(action_probs, ln), blit=True, save_count=len(action_probs)
                    )
                    ani.save(action_path, writer='pillow')
                    info_for_logging.update({"action": wandb.Video(action_path, format="gif")})

                elif all(['action' in v for v in eval_output[0]]):
                    for i, action_trajectory in enumerate(eval_output):
                        fig, ax = plt.subplots()
                        # One row per step: [1-based step index, action components...]
                        fig_data = np.array([[step + 1, *v['action']] for step, v in enumerate(action_trajectory)])
                        steps = fig_data[:, 0]
                        actions = fig_data[:, 1:]
                        plt.ylim([-1, 1])
                        for j in range(actions.shape[1]):
                            ax.scatter(steps, actions[:, j])
                        info_for_logging.update({"actions_of_trajectory_{}".format(i): fig})

            if cfg.return_logger:
                return_path = os.path.join(record_path, (str(ctx.env_step) + "_return.gif"))
                fig, ax = plt.subplots()
                ax.set_ylim([0, 1])
                hist, x_dim = return_distribution(episode_return)
                assert len(hist) == len(x_dim)
                ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
                ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
                ani.save(return_path, writer='pillow')
                info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")})

        if bool(info_for_logging):
            wandb.log(data=info_for_logging, step=ctx.env_step)
        plt.clf()

    return _plot
def wandb_offline_logger(
    record_path: str = None,
    cfg: Union[dict, EasyDict] = None,
    exp_config: Union[dict, EasyDict] = None,
    metric_list: Optional[List[str]] = None,
    env: Optional[BaseEnvManagerV2] = None,
    model: Optional[torch.nn.Module] = None,
    anonymous: bool = False,
    project_name: str = 'default-project',
    run_name: str = None,
    wandb_sweep: bool = False,
) -> Callable:
    """
    Overview:
        Wandb visualizer to track the experiment.
    Arguments:
        - record_path (:obj:`str`): The path to save the replay of simulation.
        - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings:
            - gradient_logger: boolean. Whether to track the gradient.
            - plot_logger: boolean. Whether to track the metrics like reward and loss.
            - video_logger: boolean. Whether to upload the rendering video replay.
            - action_logger: boolean. `q_value` or `action probability`.
            - return_logger: boolean. Whether to track the return value.
            - vis_dataset: boolean. Whether to visualize the dataset.
        - exp_config (:obj:`Union[dict, EasyDict]`): Experiment config; when non-empty it is passed to \
            ``wandb.init`` as ``config``. Its ``dataset_path`` entry is read when ``vis_dataset`` is enabled.
        - metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
        - env (:obj:`BaseEnvManagerV2`): Evaluator environment.
        - model (:obj:`nn.Module`): Policy neural network model.
        - anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \
            of data without wandb account.
        - project_name (:obj:`str`): The name of wandb project.
        - run_name (:obj:`str`): The name of wandb run.
        - wandb_sweep (:obj:`bool`): Whether to use wandb sweep. When False, ``reinit=True`` is passed to \
            ``wandb.init``.
    Returns:
        - _plot (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
    if metric_list is None:
        metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]

    # Initialize wandb with default settings
    # Settings can be covered by calling wandb.init() at the top of the script
    # Each optional keyword is only forwarded when the corresponding argument asks for it.
    init_kwargs = dict(project=project_name)
    if exp_config:
        init_kwargs['config'] = exp_config
    if not wandb_sweep:
        init_kwargs['reinit'] = True
    if run_name is not None:
        init_kwargs['name'] = run_name
    if anonymous:
        init_kwargs['anonymous'] = "must"
    wandb.init(**init_kwargs)

    plt.switch_backend('agg')
    if cfg is None:
        cfg = EasyDict(
            dict(
                gradient_logger=False,
                plot_logger=True,
                video_logger=False,
                action_logger=False,
                return_logger=False,
                vis_dataset=True,
            )
        )
    else:
        if not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)
        # Any switch that the user did not mention is treated as disabled.
        for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
            if key not in cfg.keys():
                cfg[key] = False

    # The visualizer is called to save the replay of the simulation
    # which will be uploaded to wandb later
    if env is not None and cfg.video_logger is True and record_path is not None:
        env.enable_save_replay(replay_path=record_path)
    if cfg.gradient_logger:
        wandb.watch(model, log="all", log_freq=100, log_graph=True)
    else:
        one_time_warning(
            "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
        )

    first_plot = True

    def _vis_dataset(datasetpath: str):
        # Embed the dataset with t-SNE, draw three scatter plots and upload the image to wandb.
        try:
            from sklearn.manifold import TSNE
        except ImportError:
            import sys
            logging.warning("Please install sklearn first, such as `pip3 install scikit-learn`.")
            sys.exit(1)
        try:
            import h5py
        except ImportError:
            import sys
            logging.warning("Please install h5py first, such as `pip3 install h5py`.")
            sys.exit(1)
        suffix = os.path.splitext(datasetpath)[-1]
        assert suffix in ['.pkl', '.h5', '.hdf5']
        if suffix == '.pkl':
            # SECURITY: unpickling executes arbitrary code; only point this at dataset files you trust.
            with open(datasetpath, 'rb') as f:
                data = pickle.load(f)
            obs = []
            action = []
            reward = []
            # Entries are addressed by integer index, exactly as the stored container is laid out.
            for i in range(len(data)):
                obs.extend(data[i]['observations'])
                action.extend(data[i]['actions'])
                reward.extend(data[i]['rewards'])
        elif suffix in ['.h5', '.hdf5']:
            with h5py.File(datasetpath, 'r') as f:
                obs = f['obs'][()]
                action = f['action'][()]
                reward = f['reward'][()]

        cmap = plt.cm.hsv
        obs = np.array(obs)
        reward = np.array(reward)
        obs_action = np.hstack((obs, np.array(action)))
        reward = reward / (max(reward) - min(reward))

        embedded_obs = TSNE(n_components=2).fit_transform(obs)
        embedded_obs_action = TSNE(n_components=2).fit_transform(obs_action)
        # Both embeddings are divided by their per-axis extent.
        x_min, x_max = np.min(embedded_obs, 0), np.max(embedded_obs, 0)
        embedded_obs = embedded_obs / (x_max - x_min)

        x_min, x_max = np.min(embedded_obs_action, 0), np.max(embedded_obs_action, 0)
        embedded_obs_action = embedded_obs_action / (x_max - x_min)

        # ``plt.subplots`` already creates the figure that ``plt.savefig`` writes out.
        f, axes = plt.subplots(nrows=1, ncols=3)

        axes[0].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(reward))
        axes[1].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(action))
        axes[2].scatter(embedded_obs_action[:, 0], embedded_obs_action[:, 1], c=cmap(reward))
        axes[0].set_title('state-reward')
        axes[1].set_title('state-action')
        axes[2].set_title('stateAction-reward')
        plt.savefig('dataset.png')

        wandb.log({"dataset": wandb.Image("dataset.png")})

    if cfg.vis_dataset is True:
        # ``exp_config`` may be None, a plain dict or an EasyDict (a dict subclass), so the path is
        # looked up defensively instead of via attribute access.
        if isinstance(exp_config, dict):
            dataset_path = exp_config.get('dataset_path')
        else:
            dataset_path = getattr(exp_config, 'dataset_path', None)
        if dataset_path is None:
            logging.warning("vis_dataset is enabled but exp_config has no `dataset_path`, skip dataset visualization.")
        else:
            _vis_dataset(dataset_path)

    def _to_loggable(value):
        # Tensors are detached and converted to numpy; every other value is logged as is.
        if isinstance(value, torch.Tensor):
            return value.cpu().detach().numpy()
        return value

    def _plot(ctx: "OfflineRLContext"):
        nonlocal first_plot
        if first_plot:
            first_plot = False
            ctx.wandb_url = wandb.run.get_project_url()

        info_for_logging = {}

        if cfg.plot_logger:
            train_output = ctx.train_output
            for metric in metric_list:
                if isinstance(train_output, Dict) and metric in train_output:
                    info_for_logging[metric] = _to_loggable(train_output[metric])
                elif isinstance(train_output, List) and len(train_output) > 0 and metric in train_output[0]:
                    # A list of outputs is reduced to the mean of the metric over all entries.
                    info_for_logging[metric] = np.mean([_to_loggable(item[metric]) for item in train_output])
        else:
            one_time_warning(
                "If you want to use wandb to visualize the result, please set plot_logger = True in the config."
            )

        if ctx.eval_value != -np.inf:
            eval_fields = (
                ("eval_value_min", "episode return min"),
                ("eval_value_max", "episode return max"),
                ("eval_value_std", "episode return std"),
                ("eval_value", "episode return mean"),
                ("train_iter", "train iter"),
                ("train_epoch", "train_epoch"),
            )
            for attr, label in eval_fields:
                if hasattr(ctx, attr):
                    info_for_logging[label] = getattr(ctx, attr)

            eval_output = ctx.eval_output['output']
            episode_return = np.array(ctx.eval_output['episode_return'])
            if len(episode_return.shape) == 2:
                episode_return = episode_return.squeeze(1)

            if cfg.video_logger:
                if 'replay_video' in ctx.eval_output:
                    # save numpy array "images" of shape (N,1212,3,224,320) to N video files in mp4 format
                    # The numpy tensor must be either 4 dimensional or 5 dimensional.
                    # Channels should be (time, channel, height, width) or (batch, time, channel, height width)
                    video_images = ctx.eval_output['replay_video']
                    video_images = video_images.astype(np.uint8)
                    info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)})
                elif record_path is not None:
                    file_list = [p for p in os.listdir(record_path) if os.path.splitext(p)[-1] == ".mp4"]
                    file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
                    # The second most recently modified mp4 is uploaded.
                    # NOTE(review): presumably the newest file may still be in the process of being
                    # written - confirm against the env manager's replay saving.
                    if len(file_list) >= 2:
                        video_path = os.path.join(record_path, file_list[-2])
                        info_for_logging.update({"video": wandb.Video(video_path, format="mp4")})
                    else:
                        one_time_warning(
                            "Fewer than two mp4 replay files were found in record_path, skip uploading the video."
                        )

            if cfg.action_logger:
                action_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_action.gif"))
                if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
                    # The probabilities live in ``action_probs`` so that the module-level frame callback
                    # ``action_prob`` is not shadowed and can be handed to ``FuncAnimation``.
                    if isinstance(eval_output, tnp.ndarray):
                        action_probs = softmax(eval_output.logit)
                    else:
                        action_probs = [softmax(to_ndarray(v['logit'])) for v in eval_output]
                    fig, ax = plt.subplots()
                    plt.ylim([-1, 1])
                    # Index 0 always exists for a non-empty rollout, even one with a single step.
                    action_dim = len(action_probs[0])
                    x_range = [str(x + 1) for x in range(action_dim)]
                    ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
                    ani = animation.FuncAnimation(
                        fig, action_prob, fargs=(action_probs, ln), blit=True, save_count=len(action_probs)
                    )
                    ani.save(action_path, writer='pillow')
                    info_for_logging.update({"action": wandb.Video(action_path, format="gif")})

                elif all(['action' in v for v in eval_output[0]]):
                    for i, action_trajectory in enumerate(eval_output):
                        fig, ax = plt.subplots()
                        # One row per step: [1-based step index, action components...]
                        fig_data = np.array([[step + 1, *v['action']] for step, v in enumerate(action_trajectory)])
                        steps = fig_data[:, 0]
                        actions = fig_data[:, 1:]
                        plt.ylim([-1, 1])
                        for j in range(actions.shape[1]):
                            ax.scatter(steps, actions[:, j])
                        info_for_logging.update({"actions_of_trajectory_{}".format(i): fig})

            if cfg.return_logger:
                return_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_return.gif"))
                fig, ax = plt.subplots()
                ax.set_ylim([0, 1])
                hist, x_dim = return_distribution(episode_return)
                assert len(hist) == len(x_dim)
                ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
                ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
                ani.save(return_path, writer='pillow')
                info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")})

        if bool(info_for_logging):
            wandb.log(data=info_for_logging, step=ctx.trained_env_step)
        plt.clf()

    return _plot