from typing import Union, Optional, Tuple
import os
from functools import partial
from copy import deepcopy

import easydict
import torch
import numpy as np
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader, Dataset

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from dizoo.maze.envs.maze_env import Maze
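
# Behavior cloning (BC) on the dizoo Maze environment: expert (observation, action) pairs
# are generated by a backward breadth-first search from the maze target, a policy is
# trained to imitate them, and the policy is periodically evaluated in the environment.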


def get_vi_sequence(env, observation):
    """Run a backward breadth-first search from the maze target and return a tuple:
    the [L, W, W] array of partial action maps (one snapshot per BFS wave) and the expert
    (observation, action) trajectory from the start cell to the target."""
    xy = np.where(observation[Ellipsis, -1] == 1)
    start_x, start_y = xy[0][0], xy[1][0]
    target_location = env.target_location
    nav_map = env.nav_map
    current_points = [target_location]
    chosen_actions = {target_location: 0}
    visited_points = {target_location: True}
    vi_sequence = []

    vi_map = np.full((env.size, env.size), fill_value=env.n_action, dtype=np.int32)
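
    # Backward breadth-first search from the target: each newly reached cell stores the
    # action that leads one step back toward the target, and a snapshot of the (partial)
    # action map is appended after every BFS wave.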
    found_start = False
    while current_points and not found_start:
        next_points = []
        for point_x, point_y in current_points:
            for (action, (next_point_x, next_point_y)) in [(0, (point_x - 1, point_y)), (1, (point_x, point_y - 1)),
                                                           (2, (point_x + 1, point_y)), (3, (point_x, point_y + 1))]:
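                # Each candidate neighbour is paired with the action index that, when taken
                # at that neighbour, moves one step back to the current cell (closer to the target).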

                if (next_point_x, next_point_y) in visited_points:
                    continue

                if not (0 <= next_point_x < len(nav_map) and 0 <= next_point_y < len(nav_map[next_point_x])):
                    continue

                if nav_map[next_point_x][next_point_y] == 'x':
                    continue

                next_points.append((next_point_x, next_point_y))
                visited_points[(next_point_x, next_point_y)] = True
                chosen_actions[(next_point_x, next_point_y)] = action
                vi_map[next_point_x, next_point_y] = action

                if next_point_x == start_x and next_point_y == start_y:
                    found_start = True
        vi_sequence.append(vi_map.copy())
        current_points = next_points
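    # Walk forward from the start cell using the stored actions to recover the expert
    # trajectory as (processed observation, action) pairs.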
    track_back = []
    if found_start:
        cur_x, cur_y = start_x, start_y
        while cur_x != target_location[0] or cur_y != target_location[1]:
            act = vi_sequence[-1][cur_x, cur_y]
            track_back.append((torch.FloatTensor(env.process_states([cur_x, cur_y], env.get_maze_map())), act))
            if act == 0:
                cur_x += 1
            elif act == 1:
                cur_y += 1
            elif act == 2:
                cur_x -= 1
            elif act == 3:
                cur_y -= 1

    return np.array(vi_sequence), track_back


class BCDataset(Dataset):
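    """Torch ``Dataset`` over expert (observation, action) pairs produced by ``get_vi_sequence``."""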

    def __init__(self, all_data):
        self._data = all_data

    def __getitem__(self, item):
        return {'obs': self._data[item][0], 'action': self._data[item][1]}

    def __len__(self):
        return len(self._data)


def load_bc_dataset(train_seeds=1, test_seeds=1, batch_size=32):

    def load_env(seed):
        ccc = easydict.EasyDict({'size': 16})
        e = Maze(ccc)
        e.seed(seed)
        e.reset()
        return e

    envs = [load_env(i) for i in range(train_seeds + test_seeds)]
    data_train = []
    data_test = []
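
    # Each seeded maze contributes one expert trajectory (start -> target); the first
    # ``train_seeds`` mazes go to the training split and the remaining ones to the test split.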
    for idx, env in enumerate(envs):
        if idx < train_seeds:
            data = data_train
        else:
            data = data_test

        start_obs = env.process_states(env._get_obs(), env.get_maze_map())
        _, track_back = get_vi_sequence(env, start_obs)

        data += track_back

    train_data = BCDataset(data_train)
    test_data = BCDataset(data_test)

    train_dataset = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_dataset = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_dataset, test_dataset


def serial_pipeline_bc(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        model: Optional[torch.nn.Module] = None,
        max_iter=int(1e6),
) -> Union['Policy', bool]:
    r"""
    Overview:
        Serial pipeline entry of imitation learning.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_iter (:obj:`int`): Maximum number of training iterations.
    Returns:
        - policy (:obj:`Policy`): Converged policy.
        - convergence (:obj:`bool`): Whether the IL training converged.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)
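
    # Build the evaluation environment manager, seed everything, and create the BC policy
    # with learn/eval modes enabled.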
    env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])
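
    # Main training components: TensorBoard logger, expert dataloaders, learner, and evaluator.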
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    dataloader, test_dataloader = load_bc_dataset()
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    learner.call_hook('before_run')
    stop = False
    iter_cnt = 0
    reward = None
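
    # Each epoch: first report validation accuracy on the held-out mazes, then run BC
    # training with interleaved environment evaluation.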
    for epoch in range(cfg.policy.learn.train_epoch):

        acc_list = []
        for _, bat in enumerate(test_dataloader):
            bat['action'] = bat['action'].long()
            res = policy._forward_eval(bat['obs'])
            res = torch.argmax(res['logit'], dim=1)
            acc_list.append(torch.sum(res == bat['action'].squeeze(-1)).item() / bat['action'].shape[0])
        tb_logger.add_scalar('validation_acc', sum(acc_list) / len(acc_list), iter_cnt)
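
        # BC training: periodically evaluate in the environment and stop early on success
        # or once ``max_iter`` updates have been performed.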
        for i, train_data in enumerate(dataloader):
            if evaluator.should_eval(learner.train_iter):
                stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)
                if stop:
                    break
            train_data['action'] = train_data['action'].long()
            learner.train(train_data)
            iter_cnt += 1
            if iter_cnt >= max_iter:
                stop = True
                break
        if stop:
            break

    learner.call_hook('after_run')
    print('final reward is: {}'.format(reward))
    return policy, stop


if __name__ == '__main__':
    from dizoo.maze.config.maze_bc_config import main_config, create_config
    serial_pipeline_bc([main_config, create_config], seed=0)
|