Spaces:
Sleeping
Sleeping
import argparse | |
import time | |
import torch | |
from torch_ac.utils.penv import ParallelEnv | |
import utils | |
from models import ACModel, RandomTalkingMultiHeadedACModel | |
# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser()


def _add_flag(name, description):
    """Register a boolean switch that is False unless given on the command line."""
    parser.add_argument(name, action="store_true", default=False, help=description)


# Mandatory identifiers.
parser.add_argument(
    "--env",
    required=True,
    help="name of the environment (REQUIRED)",
)
parser.add_argument(
    "--model",
    required=True,
    help="name of the trained model (REQUIRED)",
)

# Integer settings of the evaluation run.
parser.add_argument(
    "--episodes",
    type=int,
    default=100,
    help="number of episodes of evaluation (default: 100)",
)
parser.add_argument(
    "--seed",
    type=int,
    default=0,
    help="random seed (default: 0)",
)
parser.add_argument(
    "--procs",
    type=int,
    default=16,
    help="number of processes (default: 16)",
)

_add_flag("--argmax", "action with highest probability is selected")

parser.add_argument(
    "--worst-episodes-to-show",
    type=int,
    default=10,
    help="how many worst episodes to show",
)

# Switches forwarded to the agent constructor further down the script.
_add_flag("--memory", "add a LSTM to the model")
_add_flag("--text", "add a GRU to the model")
_add_flag("--dialogue", "add a GRU to the model")
_add_flag("--multi-headed-agent", "add a talking head")

args = parser.parse_args()
# Hand the requested seed to utils.seed, which is expected to seed all
# randomness sources (helper defined outside this file).
utils.seed(args.seed)

# Run on the GPU when PyTorch reports one, otherwise on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Device: {device}\n")
# Build args.procs environments. The second argument passed to make_env is
# the base seed offset by 10000 per environment index, so every environment
# receives a distinct value.
envs = [
    utils.make_env(args.env, args.seed + 10000 * proc_index)
    for proc_index in range(args.procs)
]

# Wrap them so they can be reset and stepped together through one object.
env = ParallelEnv(envs)
print("Environments loaded\n")
# Model class handed to the agent: RandomTalkingMultiHeadedACModel when
# --multi-headed-agent is set, ACModel otherwise.
if args.multi_headed_agent:
    agent_class = RandomTalkingMultiHeadedACModel
else:
    agent_class = ACModel

# Resolve the model directory from its name and build the agent around it.
model_dir = utils.get_model_dir(args.model)
agent = utils.Agent(
    env.observation_space,
    env.action_space,
    model_dir,
    device=device,
    argmax=args.argmax,
    num_envs=args.procs,
    use_memory=args.memory,
    use_text=args.text,
    use_dialogue=args.dialogue,
    agent_class=agent_class,
)
print("Agent loaded\n")
# Per-episode statistics gathered over the whole evaluation run.
logs = {"num_frames_per_episode": [], "return_per_episode": []}

# Roll the agent out in all environments at once until at least
# args.episodes episodes have finished.
start_time = time.time()

obss = env.reset()

episodes_done = 0
# Return and frame count accumulated so far by the episode currently running
# in each environment slot.
running_return = torch.zeros(args.procs, device=device)
running_frames = torch.zeros(args.procs, device=device)

while episodes_done < args.episodes:
    actions = agent.get_actions(obss)
    obss, rewards, dones, _ = env.step(actions)
    agent.analyze_feedbacks(rewards, dones)

    running_return += torch.tensor(rewards, device=device, dtype=torch.float)
    running_frames += 1

    # Record every episode that ended on this step. Several slots may finish
    # together, so the final count can exceed args.episodes.
    for slot, finished in enumerate(dones):
        if not finished:
            continue
        episodes_done += 1
        logs["return_per_episode"].append(running_return[slot].item())
        logs["num_frames_per_episode"].append(running_frames[slot].item())

    # Zero the accumulators of the slots whose episode just ended
    # (multiplier 0 where done, 1 elsewhere).
    keep = 1 - torch.tensor(dones, device=device, dtype=torch.float)
    running_return *= keep
    running_frames *= keep

end_time = time.time()
# One-line summary of the run: total frames (F), frames per second (FPS),
# whole seconds elapsed (D), then whatever utils.synthesize returns for the
# episode returns (R) and episode lengths (F). The "μσmM" labels suggest
# mean / std / min / max -- confirm against utils.synthesize.
elapsed = end_time - start_time
num_frames = sum(logs["num_frames_per_episode"])
fps = num_frames / elapsed
duration = int(elapsed)

return_per_episode = utils.synthesize(logs["return_per_episode"])
num_frames_per_episode = utils.synthesize(logs["num_frames_per_episode"])

summary_template = (
    "F {} | FPS {:.0f} | D {} | R:μσmM {:.2f} {:.2f} {:.2f} {:.2f} | F:μσmM {:.1f} {:.1f} {} {}"
)
summary_values = [
    num_frames,
    fps,
    duration,
    *return_per_episode.values(),
    *num_frames_per_episode.values(),
]
print(summary_template.format(*summary_values))
# List the episodes with the lowest return, worst first.
n = args.worst_episodes_to_show
if n > 0:
    episode_returns = logs["return_per_episode"]
    episode_frames = logs["num_frames_per_episode"]

    # Episode indexes ordered by ascending return. sorted() is stable, so
    # episodes with equal return keep the order in which they were recorded.
    indexes = sorted(range(len(episode_returns)), key=episode_returns.__getitem__)
    worst = indexes[:n]

    # Announce how many episodes are actually listed: when fewer than n
    # episodes were recorded, printing n would overstate the list below.
    print("\n{} worst episodes:".format(len(worst)))
    for i in worst:
        print("- episode {}: R={}, F={}".format(i, episode_returns[i], episode_frames[i]))