File size: 1,532 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import pytest
from lzero.entry import eval_muzero
from test_atari_sampled_efficientzero_config import create_config, main_config
from gym.wrappers import RecordVideo
@pytest.mark.envtest
class TestAtariLightZeroEnvVisualization:
    """Visualization tests for Atari Breakout.

    ``test_naive_env`` drives a raw gym environment with random actions,
    while ``test_lightzero_env`` runs ``eval_muzero`` with rendering and
    video saving switched on in the imported config objects.
    """

    def test_naive_env(self):
        """Play one random-action episode of Breakout and record a video.

        The environment is created with ``render_mode='human'`` and wrapped
        in ``RecordVideo`` writing to the current directory. The accumulated
        reward is printed once the episode reports ``done``.
        """
        import random

        import gym

        env = gym.make('BreakoutNoFrameskip-v4', render_mode='human')
        try:
            # Rebinding ``env`` inside the ``try`` means the ``finally`` below
            # closes whichever object we hold, even if wrapping itself fails.
            env = RecordVideo(env, video_folder='./', name_prefix='navie')
            env.reset()
            score = 0
            while True:
                # Sample uniformly from action ids 0-3.
                action = random.choice([0, 1, 2, 3])
                # NOTE(review): this unpacks a 4-tuple step result; newer gym
                # releases return 5 values -- confirm the pinned gym version.
                _, reward, done, _ = env.step(action)
                score += reward
                if done:
                    break
            print('Score:{}'.format(score))
        finally:
            # Always release the renderer / video recorder, including when
            # reset() or step() raises part-way through the episode.
            env.close()

    def test_lightzero_env(self):
        """Evaluate a checkpoint through ``eval_muzero`` with visualization on.

        NOTE: the settings below are written directly onto the imported
        ``main_config`` / ``create_config`` objects, so they are modified in
        place for the rest of the process.
        """
        # Visualization requires the 'type' to be set as base
        create_config.env_manager.type = 'base'
        # Visualization requires the 'env_num' to be set as 1
        main_config.env.evaluator_env_num = 1
        main_config.env.n_evaluator_episode = 2
        main_config.env.render_mode_human = True
        main_config.env.save_video = True
        main_config.env.save_path = './'
        main_config.env.eval_max_episode_steps = int(1e2)  # Set as needed

        # NOTE(review): looks like a placeholder path -- point it at a real
        # checkpoint file before running this test.
        model_path = "/path/ckpt/ckpt_best.pth.tar"

        returns_mean, returns = eval_muzero(
            [main_config, create_config],
            seed=0,
            num_episodes_each_seed=1,
            print_seed_details=False,
            model_path=model_path
        )
        print(returns_mean, returns)
|