from ding.framework import Context, OnlineRLContext, OfflineRLContext import random import numpy as np import treetensor.torch as ttorch import torch batch_size = 64 n_sample = 8 action_dim = 1 obs_dim = 4 logit_dim = 2 n_episodes = 2 n_episode_length = 16 update_per_collect = 4 collector_env_num = 8 # the range here is meaningless and just for test def fake_train_data(): train_data = ttorch.as_tensor( { 'action': torch.randint(0, 2, size=(action_dim, )), 'collect_train_iter': torch.randint(0, 100, size=(1, )), 'done': torch.tensor(False), 'env_data_id': torch.tensor([2]), 'next_obs': torch.randn(obs_dim), 'obs': torch.randn(obs_dim), 'reward': torch.randint(0, 2, size=(1, )), } ) return train_data def fake_online_rl_context(): ctx = OnlineRLContext( env_step=random.randint(0, 100), env_episode=random.randint(0, 100), train_iter=random.randint(0, 100), train_data=[fake_train_data() for _ in range(batch_size)], train_output=[{ 'cur_lr': 0.001, 'total_loss': random.uniform(0, 2) } for _ in range(update_per_collect)], obs=torch.randn(collector_env_num, obs_dim), action=[np.random.randint(low=0, high=1, size=(action_dim), dtype=np.int64) for _ in range(collector_env_num)], inference_output={ env_id: { 'logit': torch.randn(logit_dim), 'action': torch.randint(0, 2, size=(action_dim, )) } for env_id in range(collector_env_num) }, collect_kwargs={'eps': random.uniform(0, 1)}, trajectories=[fake_train_data() for _ in range(n_sample)], episodes=[[fake_train_data() for _ in range(n_episode_length)] for _ in range(n_episodes)], trajectory_end_idx=[i for i in range(n_sample)], eval_value=random.uniform(-1.0, 1.0), last_eval_iter=random.randint(0, 100), ) return ctx def fake_offline_rl_context(): ctx = OfflineRLContext( train_epoch=random.randint(0, 100), train_iter=random.randint(0, 100), train_data=[fake_train_data() for _ in range(batch_size)], train_output=[{ 'cur_lr': 0.001, 'total_loss': random.uniform(0, 2) } for _ in range(update_per_collect)], eval_value=random.uniform(-1.0, 1.0), last_eval_iter=random.randint(0, 100), ) return ctx