zjowowen's picture
init space
079c32c
raw
history blame
2.61 kB
from typing import Union, Optional, List, Any, Callable, Tuple
import pickle
import torch
from functools import partial
from ding.config import compile_config, read_config
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.utils import set_pkg_seed
def eval(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        state_dict: Optional[dict] = None,
) -> float:
    r"""
    Overview:
        Pure evaluation entry. Runs a single episode of the policy in one evaluator env, \
        saves a replay to the current directory and reports the episode return.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg]. The tuple is deep-copied, \
            so the caller's configs are not modified.
        - seed (:obj:`int`): Random seed.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config. \
            If ``None``, it is derived from ``cfg.env`` via ``get_vec_env_setting``.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model. If ``None``, it is \
            loaded from ``cfg.learner.load_path``.
    Returns:
        - episode_return (:obj:`float`): Sum of ``timestep.reward`` over the evaluated episode. \
            It starts from ``0.`` and is accumulated with ``+=``, so its exact type follows the \
            env's reward type (TODO confirm whether the env returns arrays).
    """
    import copy  # local import keeps this fix self-contained

    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        # Work on copies. Without this, the '_command' suffix below would be appended to the
        # caller's create_cfg on every call (e.g. 'qmix_command_command').
        cfg, create_cfg = copy.deepcopy(input_cfg)
    create_cfg.policy.type += '_command'
    cfg = compile_config(cfg, auto=True, create_cfg=create_cfg)

    # Honour an explicitly supplied env setting; otherwise build it from the config.
    if env_setting is None:
        env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    else:
        env_fn, _, evaluator_env_cfg = env_setting
    env = env_fn(evaluator_env_cfg[0])
    try:
        env.seed(seed, dynamic_seed=False)
        set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
        policy = create_policy(cfg.policy, model=model, enable_field=['eval']).eval_mode
        if state_dict is None:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(cfg.learner.load_path, map_location='cpu')
        policy.load_state_dict(state_dict)

        obs = env.reset()
        episode_return = 0.
        while True:
            # The policy consumes/produces dicts keyed by env id; a single env uses id 0.
            policy_output = policy.forward({0: obs})
            action = policy_output[0]['action']
            print(action)
            timestep = env.step(action)
            episode_return += timestep.reward
            obs = timestep.obs
            if timestep.done:
                print(timestep.info)
                break
        # `_map_name` is a private attribute of the env; presumably SMAC-specific — verify.
        env.save_replay(replay_dir='.', prefix=env._map_name)
    finally:
        # Always release the env, even if setup, rollout or replay saving raises.
        env.close()
    print('Eval is over! The performance of your RL policy is {}'.format(episode_return))
    return episode_return
if __name__ == "__main__":
    # Example run: evaluate a checkpoint stored under ../exp/MMM/qmix with the
    # matching smac_MMM_qmix config. Both paths resolve against the current
    # working directory.
    ckpt_file = '../exp/MMM/qmix/1/ckpt_BaseLearner_Wed_Jul_14_22_16_56_2021/iteration_9900.pth.tar'
    config_file = '../config/smac_MMM_qmix_config.py'
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    ckpt_state = torch.load(ckpt_file, map_location='cpu')
    eval(config_file, seed=0, state_dict=ckpt_state)