import gym
import torch
from easydict import EasyDict

from ding.config import compile_config
from ding.envs import DingEnvWrapper
from ding.model import C51DQN
from ding.policy import C51Policy, single_env_forward_wrapper
from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config


def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str):
    """Load a trained C51 checkpoint and roll out one CartPole-v0 episode with it.

    Args:
        main_config: C51 main config. Its ``exp_name`` is overwritten in place
            with ``'cartpole_c51_deploy'``.
        create_config: Create config passed to ``compile_config``.
        ckpt_path: Path to a checkpoint loaded with ``torch.load``. It must be a
            mapping holding the model weights under the ``'model'`` key.

    Prints the undiscounted return of the episode when it finishes.
    """
    main_config.exp_name = 'cartpole_c51_deploy'
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    env = DingEnvWrapper(gym.make('CartPole-v0'), EasyDict(env_wrapper='default'))
    # Always release the environment, even if loading the checkpoint or stepping fails.
    try:
        model = C51DQN(**cfg.policy.model)
        # map_location='cpu' lets a checkpoint saved on GPU be deployed without CUDA.
        state_dict = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(state_dict['model'])
        policy = C51Policy(cfg.policy, model=model).eval_mode
        # Wrap the eval-mode forward so it can be called with one observation at a time.
        forward_fn = single_env_forward_wrapper(policy.forward)

        obs = env.reset()
        returns = 0.
        while True:
            action = forward_fn(obs)
            obs, rew, done, info = env.step(action)
            returns += rew
            if done:
                break
    finally:
        env.close()
    print(f'Deploy is finished, final episode return is: {returns}')


if __name__ == "__main__":
    main(cartpole_c51_config, cartpole_c51_create_config, 'cartpole_c51_seed0/ckpt/ckpt_best.pth.tar')