from ding.framework import Context, OnlineRLContext, OfflineRLContext
import random
import numpy as np
import treetensor.torch as ttorch
import torch
# Shared sizes used by the fake-data builders below.
batch_size = 64  # number of fake transitions placed in train_data
n_sample = 8  # number of trajectories per fake collect step
action_dim = 1  # length of a single discrete-action tensor
obs_dim = 4  # length of a single observation vector
logit_dim = 2  # length of the policy-logit vector
n_episodes = 2  # number of fake episodes
n_episode_length = 16  # transitions per fake episode
update_per_collect = 4  # number of train_output dicts per collect
collector_env_num = 8  # number of parallel collector environments
# the range here is meaningless and just for test
def fake_train_data():
    """Build one fake transition (obs/action/reward/...) as a treetensor.

    All values are randomly generated and carry no semantic meaning;
    only the field names and tensor shapes matter for tests.
    """
    transition = {
        'action': torch.randint(0, 2, size=(action_dim, )),
        'collect_train_iter': torch.randint(0, 100, size=(1, )),
        'done': torch.tensor(False),
        'env_data_id': torch.tensor([2]),
        'next_obs': torch.randn(obs_dim),
        'obs': torch.randn(obs_dim),
        'reward': torch.randint(0, 2, size=(1, )),
    }
    return ttorch.as_tensor(transition)
def fake_online_rl_context():
    """Return an ``OnlineRLContext`` populated with random fake data.

    The fields mirror what a real online RL pipeline would produce
    (collector obs/actions, trajectories, episodes, eval results);
    the values themselves are random and meaningless — for tests only.
    """
    ctx = OnlineRLContext(
        env_step=random.randint(0, 100),
        env_episode=random.randint(0, 100),
        train_iter=random.randint(0, 100),
        train_data=[fake_train_data() for _ in range(batch_size)],
        train_output=[{
            'cur_lr': 0.001,
            'total_loss': random.uniform(0, 2)
        } for _ in range(update_per_collect)],
        obs=torch.randn(collector_env_num, obs_dim),
        # size must be the tuple (action_dim, ) — bare (action_dim) is just an int.
        # high=2 (exclusive bound) so actions span {0, 1}, consistent with the
        # torch.randint(0, 2, ...) used for 'action' elsewhere in this file;
        # the original high=1 could only ever produce 0.
        action=[np.random.randint(low=0, high=2, size=(action_dim, ), dtype=np.int64) for _ in range(collector_env_num)],
        inference_output={
            env_id: {
                'logit': torch.randn(logit_dim),
                'action': torch.randint(0, 2, size=(action_dim, ))
            }
            for env_id in range(collector_env_num)
        },
        collect_kwargs={'eps': random.uniform(0, 1)},
        trajectories=[fake_train_data() for _ in range(n_sample)],
        episodes=[[fake_train_data() for _ in range(n_episode_length)] for _ in range(n_episodes)],
        trajectory_end_idx=list(range(n_sample)),
        eval_value=random.uniform(-1.0, 1.0),
        last_eval_iter=random.randint(0, 100),
    )
    return ctx
def fake_offline_rl_context():
    """Return an ``OfflineRLContext`` filled with random fake data for tests."""
    fake_outputs = [{
        'cur_lr': 0.001,
        'total_loss': random.uniform(0, 2)
    } for _ in range(update_per_collect)]
    return OfflineRLContext(
        train_epoch=random.randint(0, 100),
        train_iter=random.randint(0, 100),
        train_data=[fake_train_data() for _ in range(batch_size)],
        train_output=fake_outputs,
        eval_value=random.uniform(-1.0, 1.0),
        last_eval_iter=random.randint(0, 100),
    )