Spaces:
Running
Running
import argparse | |
import random | |
import warnings | |
import numpy as np | |
import time | |
import datetime | |
import torch | |
import gym_minigrid.social_ai_envs | |
import torch_ac | |
import sys | |
import json | |
import utils | |
from pathlib import Path | |
from distutils.dir_util import copy_tree | |
from utils.env import env_args_str_to_dict | |
from models import * | |
# Parse arguments
#
# Command-line interface of the training script.  Only ``--algo`` and
# ``--env`` are required.  Arguments declared with ``nargs`` are returned by
# argparse as lists when given on the command line, while their defaults keep
# the type written here (a tuple, a string or None).
parser = argparse.ArgumentParser()

## General parameters
parser.add_argument("--algo", required=True,
                    help="algorithm to use: ppo (REQUIRED)")
parser.add_argument("--env", required=True,
                    help="name of the environment to train on (REQUIRED)")
parser.add_argument("--model", default=None,
                    help="name of the model (default: {ENV}_{ALGO}_{TIME})")
parser.add_argument("--seed", type=int, default=1,
                    help="random seed (default: 1)")
parser.add_argument("--log-interval", type=int, default=10,
                    help="number of updates between two logs (default: 10)")
parser.add_argument("--save-interval", type=int, default=10,
                    help="number of updates between two saves (default: 10, 0 means no saving)")
parser.add_argument("--procs", type=int, default=16,
                    help="number of processes (default: 16)")
parser.add_argument("--frames", type=int, default=10**7,
                    help="number of frames of training (default: 1e7)")
parser.add_argument("--continue-train", default=None,
                    help="path to the model to continue training from (a status file or an experiment dir), "
                         "or 'auto' to resume from this run's own model dir",
                    type=str)
parser.add_argument("--finetune-train", default=None,
                    help="path to the model to finetune", type=str)
parser.add_argument("--compact-save", "-cs", action="store_true", default=False,
                    help="Keep only last model save")
parser.add_argument("--lr-schedule-end-frames", type=int, default=0,
                    help="Learning rate will be diminished from --lr to --lr-end linearly over the period of --lr-schedule-end-frames (default: 0 - no diminish)")
parser.add_argument("--lr-end", type=float, default=0,
                    help="the final lr that will be reached at 'lr-schedule-end-frames' (default = 0)")

## Periodic test parameters
parser.add_argument("--test-set-name", required=False,
                    help="name of the test set to evaluate on (default: SocialAITestSet); an empty value tests on the train env",
                    default="SocialAITestSet")
# parser.add_argument("--test-env", required=False,
#                     help="name of the environment to test on, default use the train env")
# parser.add_argument("--no-test", "-nt", action="store_true", default=False,
#                     help="don't perform periodic testing")
parser.add_argument("--test-seed", type=int, default=0,
                    help="random seed for testing (default: 0)")
parser.add_argument("--test-episodes", type=int, default=50,
                    help="number of episodes to test")
parser.add_argument("--test-interval", type=int, default=-1,
                    help="number of updates between two tests (default: -1, no testing)")
# NOTE: with nargs='*' a value given on the command line arrives as a list,
# whereas the default stays a plain string.
parser.add_argument("--test-env-args", nargs='*', default="like_train_no_acl",
                    help="test env arguments: 'like_train_no_acl' (default), 'like_train', or explicit arguments in the --env-args format")

## Parameters for main algorithm
parser.add_argument("--acl", action="store_true", default=False,
                    help="use acl")
parser.add_argument("--acl-type", type=str, default=None,
                    help="acl type")
parser.add_argument("--acl-thresholds", nargs="+", type=float, default=(0.75, 0.75),
                    help="per phase thresholds for expert CL")
parser.add_argument("--acl-minimum-episodes", type=int, default=1000,
                    help="Never go to second phase before this.")
parser.add_argument("--acl-average-interval", type=int, default=500,
                    help="Average the performance estimate over this many last episodes")
parser.add_argument("--epochs", type=int, default=4,
                    help="number of epochs for PPO (default: 4)")
parser.add_argument("--exploration-bonus", action="store_true", default=False,
                    help="Use a count based exploration bonus")
parser.add_argument("--exploration-bonus-type", nargs="+", default=["lang"],
                    help="modality on which to use the bonus (lang/grid)")
parser.add_argument("--exploration-bonus-params", nargs="+", type=float, default=(30., 50.),
                    help="parameters for a count based exploration bonus (C, M)")
parser.add_argument("--exploration-bonus-tanh", nargs="+", type=float, default=None,
                    help="tanh expl bonus scale, None means no tanh")
parser.add_argument("--expert-exploration-bonus", action="store_true", default=False,
                    help="Use an expert exploration bonus")
parser.add_argument("--episodic-exploration-bonus", action="store_true", default=False,
                    help="Use the exploration bonus in a episodic setting")
parser.add_argument("--batch-size", type=int, default=256,
                    help="batch size for PPO (default: 256)")
parser.add_argument("--frames-per-proc", type=int, default=None,
                    help="number of frames per process before update (default: 5 for A2C and 128 for PPO)")
parser.add_argument("--discount", type=float, default=0.99,
                    help="discount factor (default: 0.99)")
parser.add_argument("--lr", type=float, default=0.001,
                    help="learning rate (default: 0.001)")
parser.add_argument("--gae-lambda", type=float, default=0.99,
                    help="lambda coefficient in GAE formula (default: 0.99, 1 means no gae)")
parser.add_argument("--entropy-coef", type=float, default=0.01,
                    help="entropy term coefficient (default: 0.01)")
parser.add_argument("--value-loss-coef", type=float, default=0.5,
                    help="value loss term coefficient (default: 0.5)")
parser.add_argument("--max-grad-norm", type=float, default=0.5,
                    help="maximum norm of gradient (default: 0.5)")
parser.add_argument("--optim-eps", type=float, default=1e-8,
                    help="Adam and RMSprop optimizer epsilon (default: 1e-8)")
parser.add_argument("--optim-alpha", type=float, default=0.99,
                    help="RMSprop optimizer alpha (default: 0.99)")
parser.add_argument("--clip-eps", type=float, default=0.2,
                    help="clipping epsilon for PPO (default: 0.2)")
parser.add_argument("--recurrence", type=int, default=1,
                    help="number of time-steps gradient is backpropagated (default: 1). If > 1, a LSTM is added to the model to have memory.")
parser.add_argument("--text", action="store_true", default=False,
                    help="add a GRU to the model to handle text input")
parser.add_argument("--dialogue", action="store_true", default=False,
                    help="add a GRU to the model to handle the history of dialogue input")
parser.add_argument("--current-dialogue-only", action="store_true", default=False,
                    help="add a GRU to the model to handle only the current dialogue input")
parser.add_argument("--multi-headed-agent", action="store_true", default=False,
                    help="add a talking head")
parser.add_argument("--babyai11_agent", action="store_true", default=False,
                    help="use the babyAI 1.1 agent architecture")
parser.add_argument("--multi-headed-babyai11-agent", action="store_true", default=False,
                    help="use the multi headed babyAI 1.1 agent architecture")
parser.add_argument("--custom-ppo", action="store_true", default=False,
                    help="use BabyAI original PPO hyperparameters")
parser.add_argument("--custom-ppo-2", action="store_true", default=False,
                    help="use BabyAI original PPO hyperparameters but with smaller memory")
parser.add_argument("--custom-ppo-3", action="store_true", default=False,
                    help="use BabyAI original PPO hyperparameters but with no memory")
parser.add_argument("--custom-ppo-rnd", action="store_true", default=False,
                    help="use the RND reconstruct PPO hyperparameters")
parser.add_argument("--custom-ppo-rnd-reference", action="store_true", default=False,
                    help="use the RND reference PPO hyperparameters")
parser.add_argument("--custom-ppo-ride", action="store_true", default=False,
                    help="use the RIDE reconstruct PPO hyperparameters")
parser.add_argument("--custom-ppo-ride-reference", action="store_true", default=False,
                    help="use the RIDE reference PPO hyperparameters")
parser.add_argument("--ppo-hp-tuning", action="store_true", default=False,
                    help="use PPO hyperparameters selected from our HP tuning")
parser.add_argument("--multi-modal-babyai11-agent", action="store_true", default=False,
                    help="use the multi modal babyAI 1.1 agent architecture")

# ride ref
parser.add_argument("--ride-ref-agent", action="store_true", default=False,
                    help="Model from the ride paper")
parser.add_argument("--ride-ref-preprocessor", action="store_true", default=False,
                    help="use ride reference preprocessor (3D images)")
parser.add_argument("--bAI-lang-model", help="lang model type for babyAI models", default="gru")
parser.add_argument("--memory-dim", type=int, help="memory dim (128 is small, 2048 is big)", default=128)
parser.add_argument("--clipped-rewards", action="store_true", default=False,
                    help="use clipped rewards")
parser.add_argument("--intrinsic-reward-epochs", type=int, default=0,
                    help="")
parser.add_argument("--balance-moa-training", action="store_true", default=False,
                    help="balance moa training to handle class imbalance.")
parser.add_argument("--moa-memory-dim", type=int, help="memory dim (default=128)", default=128)

# rnd + ride
parser.add_argument("--intrinsic-reward-coef", type=float, default=0.1,
                    help="")
parser.add_argument("--intrinsic-reward-learning-rate", type=float, default=0.0001,
                    help="")
parser.add_argument("--intrinsic-reward-momentum", type=float, default=0,
                    help="")
parser.add_argument("--intrinsic-reward-epsilon", type=float, default=0.01,
                    help="")
parser.add_argument("--intrinsic-reward-alpha", type=float, default=0.99,
                    help="")
parser.add_argument("--intrinsic-reward-max-grad-norm", type=float, default=40,
                    help="")

# rnd + soc_inf
parser.add_argument("--intrinsic-reward-loss-coef", type=float, default=0.1,
                    help="")

# ride
parser.add_argument("--intrinsic-reward-forward-loss-coef", type=float, default=10,
                    help="")
parser.add_argument("--intrinsic-reward-inverse-loss-coef", type=float, default=0.1,
                    help="")
parser.add_argument("--reset-rnd-ride-at-phase", action="store_true", default=False,
                    help="expert knowledge resets rnd ride at acl phase change")

# babyAI1.1 related
parser.add_argument("--arch", default="original_endpool_res",
                    help="image embedding architecture")
parser.add_argument("--num-films", type=int, default=2,
                    help="")

# Put all env related arguments after --env_args, e.g. --env_args nb_foo 1 is_bar True
parser.add_argument("--env-args", nargs='*', default=None)
# Parse the command line, then report advisory notes about the parsed values.
args = parser.parse_args()

# --compact-save is still accepted but has no effect any more.
if args.compact_save:
    print("Compact save is deprecated. Don't use it. It doesn't do anything now.")

# Differing intervals are allowed; the mismatch is only reported.
if not args.save_interval == args.log_interval:
    print(f"save_interval ({args.save_interval}) and log_interval ({args.log_interval}) are not the same. This is not ideal for train continuation.")

# A seed of -1 asks for a randomly drawn seed.
if args.seed == -1:
    args.seed = np.random.randint(424242)
# Hyperparameter presets.
#
# Each preset flag overwrites a fixed set of parsed arguments.  The presets
# are listed in precedence order: when several flags are given, the first one
# in this list is applied and the others are ignored.
#
# Shared values of the three babyAI presets (they differ only in recurrence).
babyai_ppo_base = {
    "frames_per_proc": 40,
    "lr": 1e-4,
    "gae_lambda": 0.99,
    "optim_eps": 1e-05,
    "clip_eps": 0.2,
    "batch_size": 1280,
}
# Shared values of the RND / RIDE "reconstruct" presets.
# (recurrence 1 here; the original note suggests using 5 for SocialAI envs)
reconstruct_base = {
    "frames_per_proc": 40,
    "lr": 1e-4,
    "recurrence": 1,
    "batch_size": 640,
    "epochs": 4,
    "clipped_rewards": True,
}
# Shared values of the RND / RIDE "reference" presets (they differ only in
# the entropy coefficient).
reference_base = {
    "frames_per_proc": 128,  # 128 for PPO
    "lr": 1e-4,
    "recurrence": 64,
    "gae_lambda": 0.99,
    "batch_size": 1280,
    "epochs": 4,
    "optim_eps": 1e-05,
    "clip_eps": 0.2,
    "clipped_rewards": True,
}
# (flag attribute on args, message printed when applied or None, overrides)
ppo_presets = [
    ("custom_ppo", "babyAI's ppo config", {**babyai_ppo_base, "recurrence": 20}),
    ("custom_ppo_2", "babyAI's ppo config with smaller memory", {**babyai_ppo_base, "recurrence": 10}),
    ("custom_ppo_3", "babyAI's ppo config with no memory", {**babyai_ppo_base, "recurrence": 1}),
    ("custom_ppo_rnd", "RND reconstruct", dict(reconstruct_base)),
    ("custom_ppo_ride", "RIDE reconstruct", dict(reconstruct_base)),
    ("custom_ppo_rnd_reference", "RND reference", {**reference_base, "entropy_coef": 0.0001}),
    ("custom_ppo_ride_reference", "RIDE reference", {**reference_base, "entropy_coef": 0.0005}),
    ("ppo_hp_tuning", None, {
        "frames_per_proc": 40,
        "lr": 1e-4,
        "recurrence": 5,
        "batch_size": 640,
        "epochs": 4,
    }),
]

requested_presets = [flag for flag, _, _ in ppo_presets if getattr(args, flag)]

# Combining two babyAI presets was always rejected; raise explicitly instead
# of relying on `assert`, which is stripped when Python runs with -O.
babyai_conflicts = [
    flag for flag in requested_presets
    if flag in ("custom_ppo", "custom_ppo_2", "custom_ppo_3")
]
if len(babyai_conflicts) > 1:
    raise ValueError(
        "These PPO presets are mutually exclusive: {}".format(", ".join(babyai_conflicts))
    )

# Other combinations keep working as before (first preset wins), but the
# ignored flags are now reported instead of being dropped silently.
if len(requested_presets) > 1:
    print("Several PPO presets were requested ({}); applying '{}' and ignoring the others.".format(
        ", ".join(requested_presets), requested_presets[0]))

for preset_flag, preset_label, preset_overrides in ppo_presets:
    if not getattr(args, preset_flag):
        continue
    if preset_label is not None:
        print(preset_label)
    for arg_name, arg_value in preset_overrides.items():
        setattr(args, arg_name, arg_value)
    break  # only the highest-precedence preset is applied
# Environments for which running without recurrence is not reported.
envs_without_recurrence_note = (
    "MiniGrid-KeyCorridorS3R3-v0",
    "MiniGrid-MultiRoom-N2-S4-v0",
    "MiniGrid-MultiRoom-N4-S5-v0",
    "MiniGrid-MultiRoom-N7-S4-v0",
    "MiniGrid-MultiRoomNoisyTV-N7-S4-v0",
)

# Any other environment with recurrence <= 1 gets a printed note.
if args.env not in envs_without_recurrence_note and args.recurrence <= 1:
    print("You are using recurrence {} with {} env. This is probably unintentional.".format(args.recurrence, args.env))
    # warnings.warn("You are using recurrence {} with {} env. This is probably unintentional.".format(args.recurrence, args.env))

# The model gets a memory exactly when gradients are backpropagated through
# more than one time-step.
args.mem = args.recurrence > 1
# Set run dir
# The default model name encodes env, algo, seed and a launch timestamp;
# --model overrides it.
date = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
default_model_name = "{}_{}_seed{}_{}".format(args.env, args.algo, args.seed, date)
model_name = args.model if args.model else default_model_name
model_dir = utils.get_model_dir(model_name)

# An existing run directory is only accepted when continuing training.
if args.continue_train is None and Path(model_dir).exists():
    raise ValueError(f"Dir {model_dir} already exists and continue train is None.")

# Load loggers and Tensorboard writer
txt_logger = utils.get_txt_logger(model_dir)
csv_file, csv_logger = utils.get_csv_logger(model_dir)

# Log command and all script arguments
command_line = " ".join(sys.argv)
txt_logger.info(f"{command_line}\n")
txt_logger.info(f"{args}\n")

# Set seed for all randomness sources
utils.seed(args.seed)

# Set device
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
txt_logger.info(f"Device: {device}\n")
# Create env_args dict
env_args = env_args_str_to_dict(args.env_args)

if args.acl:
    # expert_acl = "three_stage_expert"
    expert_acl = args.acl_type
    print(f"Using curriculum: {expert_acl}.")
else:
    expert_acl = None

# Snapshot taken before the curriculum-related keys are added below.
env_args_no_acl = env_args.copy()
env_args["curriculum"] = expert_acl
env_args["expert_curriculum_thresholds"] = args.acl_thresholds
env_args["expert_curriculum_average_interval"] = args.acl_average_interval
env_args["expert_curriculum_minimum_episodes"] = args.acl_minimum_episodes
env_args["egocentric_observation"] = True

# test env args
# --test-env-args is declared with nargs='*', so a keyword given on the
# command line arrives as a one-element list (e.g. ["like_train"]) while the
# default is a plain string.  Unwrap the keyword so both forms are recognised.
test_env_args_spec = args.test_env_args
if (
    isinstance(test_env_args_spec, (list, tuple))
    and len(test_env_args_spec) == 1
    and test_env_args_spec[0] in ("like_train_no_acl", "like_train")
):
    test_env_args_spec = test_env_args_spec[0]

if not test_env_args_spec:
    test_env_args = {}
elif test_env_args_spec == "like_train_no_acl":
    test_env_args = env_args_no_acl
elif test_env_args_spec == "like_train":
    test_env_args = env_args
else:
    test_env_args = env_args_str_to_dict(test_env_args_spec)

# Environments whose name does not contain "SocialAI-" receive no env arguments.
if "SocialAI-" not in args.env:
    env_args = {}
    test_env_args = {}

print("train_env_args:", env_args)
print("test_env_args:", test_env_args)
# Load train environments
# One environment per process; the seeds of consecutive processes are spaced
# 10000 apart.
envs = [
    utils.make_env(args.env, args.seed + 10000 * proc_index, env_args=env_args)
    for proc_index in range(args.procs)
]
txt_logger.info("Environments loaded\n")
# Continuing and fine-tuning are mutually exclusive ways to start from a
# saved status.
if args.continue_train and args.finetune_train:
    raise ValueError(f"Continue path ({args.continue_train}) and finetune path ({args.finetune_train}) can't both be set.")

# Load training status
if args.continue_train:
    if args.continue_train == "auto":
        # "auto" resumes from this run's own model dir.
        status_continue_path = Path(model_dir)
        args.continue_train = status_continue_path  # just in case
    else:
        status_continue_path = Path(args.continue_train)

    if status_continue_path.is_dir():
        # if dir, assume experiment dir so append the seed
        # status_continue_path = Path(status_continue_path) / str(args.seed)
        status_continue_path = utils.get_status_path(status_continue_path)
    else:
        if not status_continue_path.is_file():
            raise ValueError(f"{status_continue_path} is not a file")
        if "status" not in status_continue_path.name:
            raise UserWarning(f"{status_continue_path} is does not contain status, is this the correct file? ")

    status = utils.load_status(status_continue_path)
    txt_logger.info("Training status loaded\n")
    txt_logger.info(f"{model_name} continued from {status_continue_path}")

    assert Path(status_continue_path).is_file()

elif args.finetune_train:
    status_finetune_path = Path(args.finetune_train)

    if status_finetune_path.is_dir():
        # if dir, assume experiment dir so append the seed
        status_finetune_seed_path = Path(status_finetune_path) / str(args.seed)
        if status_finetune_seed_path.exists():
            # if a seed folder exists assume that you use that one
            status_finetune_path = utils.get_status_path(status_finetune_seed_path)
        else:
            # if not assume that no seed folder exists
            status_finetune_path = utils.get_status_path(status_finetune_path)
    else:
        if not status_finetune_path.is_file():
            raise ValueError(f"{status_finetune_path} is not dir or a file")
        if "status" not in status_finetune_path.name:
            raise UserWarning(f"{status_finetune_path} is does not contain status, is this the correct file? ")

    status = utils.load_status(status_finetune_path)
    txt_logger.info("Training status loaded\n")
    txt_logger.info(f"{model_name} finetuning from {status_finetune_path}")

    assert Path(status_finetune_path).is_file()

    # reset parameters for finetuning
    status["num_frames"] = 0
    status["update"] = 0
    # Drop the state that must not carry over into the fine-tuning run.
    # pop() with a default tolerates checkpoints that lack one of these keys,
    # where `del` would raise KeyError.
    for stale_key in ("optimizer_state", "lr_scheduler_state", "env_args"):
        status.pop(stale_key, None)

else:
    # Fresh run: start counting from zero.
    status = {"num_frames": 0, "update": 0}
# Parameter sanity checks
# The two dialogue modes are mutually exclusive; using neither only warns.
if args.dialogue and args.current_dialogue_only:
    raise ValueError("Either use dialogue or current-dialogue-only")

if not (args.dialogue or args.current_dialogue_only):
    warnings.warn("Not using dialogue")

if args.text:
    raise ValueError("Text should not be used. Use dialogue instead.")

# Load observations preprocessor
# The ride reference image preprocessors are used only on request.
if args.ride_ref_preprocessor:
    image_preprocessor = utils.ride_ref_image_preprocessor
    image_space_preprocessor = utils.ride_ref_image_space_preprocessor
else:
    image_preprocessor = None
    image_space_preprocessor = None

obs_space, preprocess_obss = utils.get_obss_preprocessor(
    obs_space=envs[0].observation_space,
    text=args.text,
    dialogue_current=args.current_dialogue_only,
    dialogue_history=args.dialogue,
    custom_image_preprocessor=image_preprocessor,
    custom_image_space_preprocessor=image_space_preprocessor,
)

# When starting from a saved status, reuse the vocabulary stored in it.
if not (args.continue_train is None and args.finetune_train is None):
    assert "vocab" in status
    preprocess_obss.vocab.load_vocab(status["vocab"])
txt_logger.info("Observations preprocessor loaded")

if args.exploration_bonus and args.expert_exploration_bonus:
    warnings.warn("You are using expert exploration bonus.")
# Load model
# At most one of the agent-architecture flags may be given.
agent_architecture_flags = (
    args.multi_modal_babyai11_agent,
    args.multi_headed_babyai11_agent,
    args.babyai11_agent,
    args.multi_headed_agent,
)
assert sum(int(flag) for flag in agent_architecture_flags) <= 1

if args.multi_modal_babyai11_agent:
    acmodel = MultiModalBaby11ACModel(
        obs_space=obs_space,
        action_space=envs[0].action_space,
        arch=args.arch,
        use_text=args.text,
        use_dialogue=args.dialogue,
        use_current_dialogue_only=args.current_dialogue_only,
        use_memory=args.mem,
        lang_model=args.bAI_lang_model,
        memory_dim=args.memory_dim,
        num_films=args.num_films
    )
else:
    # RefACModel and ACModel are built with the same keyword arguments, so
    # only the class differs between the two remaining cases.
    if args.ride_ref_agent:
        assert args.mem
        assert not args.text
        assert not args.dialogue
        model_cls = RefACModel
    else:
        model_cls = ACModel

    acmodel = model_cls(
        obs_space=obs_space,
        action_space=envs[0].action_space,
        use_memory=args.mem,
        use_text=args.text,
        use_dialogue=args.dialogue,
        input_size=obs_space['image'][-1],
    )
    # Neither of these two models handles the current-dialogue-only mode.
    if args.current_dialogue_only:
        raise NotImplementedError("current dialogue only")

# if args.continue_train is not None:
#     assert "model_state" in status
#     acmodel.load_state_dict(status["model_state"])

acmodel.to(device)
txt_logger.info("Model loaded\n")
txt_logger.info(f"{acmodel}\n")
# Load algo
assert args.algo == "ppo"

# Parsed arguments handed to PPOAlgo under their own name.
ppo_passthrough_names = (
    "discount",
    "lr",
    "gae_lambda",
    "entropy_coef",
    "value_loss_coef",
    "max_grad_norm",
    "recurrence",
    "clip_eps",
    "epochs",
    "batch_size",
    "exploration_bonus",
    "exploration_bonus_tanh",
    "exploration_bonus_type",
    "exploration_bonus_params",
    "expert_exploration_bonus",
    "episodic_exploration_bonus",
    "clipped_rewards",
    # for rnd, ride, and social influence
    "intrinsic_reward_coef",
    # for rnd and ride
    "intrinsic_reward_epochs",
    "intrinsic_reward_learning_rate",
    "intrinsic_reward_momentum",
    "intrinsic_reward_epsilon",
    "intrinsic_reward_alpha",
    "intrinsic_reward_max_grad_norm",
    # for rnd and social influence
    "intrinsic_reward_loss_coef",
    # for ride
    "intrinsic_reward_forward_loss_coef",
    "intrinsic_reward_inverse_loss_coef",
    # for social influence
    "balance_moa_training",
    "moa_memory_dim",
    "lr_schedule_end_frames",
    "reset_rnd_ride_at_phase",
)
ppo_kwargs = {name: getattr(args, name) for name in ppo_passthrough_names}

algo = torch_ac.PPOAlgo(
    envs=envs,
    acmodel=acmodel,
    device=device,
    preprocess_obss=preprocess_obss,
    # these three are named differently on the command line
    num_frames_per_proc=args.frames_per_proc,
    adam_eps=args.optim_eps,
    end_lr=args.lr_end,
    **ppo_kwargs,
)

# Restore the saved status when continuing or fine-tuning, and log its source.
if args.continue_train or args.finetune_train:
    algo.load_status_dict(status)
    # txt_logger.info(f"Model + Algo loaded from {args.continue_train or args.finetune_train}\n")
    loaded_status_path = status_continue_path if args.continue_train else status_finetune_path
    txt_logger.info(f"Model + Algo loaded from {loaded_status_path} \n")
# todo: make nicer
# Set and load test environment
#
# Test sets whose environment names are spelled out in this script.
inline_test_sets = {
    # "SocialAI-AskEyeContactLanguageBoxesInformationSeekingParamEnv-v1",
    # "SocialAI-NoIntroPointingBoxesInformationSeekingParamEnv-v1"
    "SocialAITestSet": [
        "SocialAI-TestLanguageColorBoxesInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackBoxesInformationSeekingEnv-v1",
        "SocialAI-TestPointingBoxesInformationSeekingEnv-v1",
        "SocialAI-TestEmulationBoxesInformationSeekingEnv-v1",
        "SocialAI-TestLanguageColorSwitchesInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackSwitchesInformationSeekingEnv-v1",
        "SocialAI-TestPointingSwitchesInformationSeekingEnv-v1",
        "SocialAI-TestEmulationSwitchesInformationSeekingEnv-v1",
        "SocialAI-TestLanguageColorMarbleInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackMarbleInformationSeekingEnv-v1",
        "SocialAI-TestPointingMarbleInformationSeekingEnv-v1",
        "SocialAI-TestEmulationMarbleInformationSeekingEnv-v1",
        "SocialAI-TestLanguageColorGeneratorsInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackGeneratorsInformationSeekingEnv-v1",
        "SocialAI-TestPointingGeneratorsInformationSeekingEnv-v1",
        "SocialAI-TestEmulationGeneratorsInformationSeekingEnv-v1",
        "SocialAI-TestLanguageColorLeversInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackLeversInformationSeekingEnv-v1",
        "SocialAI-TestPointingLeversInformationSeekingEnv-v1",
        "SocialAI-TestEmulationLeversInformationSeekingEnv-v1",
        "SocialAI-TestLanguageColorDoorsInformationSeekingEnv-v1",
        "SocialAI-TestLanguageFeedbackDoorsInformationSeekingEnv-v1",
        "SocialAI-TestPointingDoorsInformationSeekingEnv-v1",
        "SocialAI-TestEmulationDoorsInformationSeekingEnv-v1",
        "SocialAI-TestLeverDoorCollaborationEnv-v1",
        "SocialAI-TestMarblePushCollaborationEnv-v1",
        "SocialAI-TestMarblePassCollaborationEnv-v1",
        "SocialAI-TestAppleStealingPerspectiveTakingEnv-v1"
    ],
    "SocialAIGSTestSet": [
        "SocialAI-GridSearchParamEnv-v1",
        "SocialAI-GridSearchPointingParamEnv-v1",
        "SocialAI-GridSearchLangColorParamEnv-v1",
        "SocialAI-GridSearchLangFeedbackParamEnv-v1",
    ],
    "SocialAICuesGSTestSet": [
        "SocialAI-CuesGridSearchParamEnv-v1",
        "SocialAI-CuesGridSearchPointingParamEnv-v1",
        "SocialAI-CuesGridSearchLangColorParamEnv-v1",
        "SocialAI-CuesGridSearchLangFeedbackParamEnv-v1",
    ],
    "BoxesPointingTestSet": [
        "SocialAI-TestPointingBoxesInformationSeekingParamEnv-v1",
    ],
}

# Test sets provided by gym_minigrid.social_ai_envs, keyed by test set name
# and mapped to the attribute that holds them.  The attribute is looked up
# only for the test set that was actually selected.
packaged_test_sets = {
    "PointingTestSet": "pointing_test_set",
    "LangColorTestSet": "language_color_test_set",
    "LangFeedbackTestSet": "language_feedback_test_set",
    # joint attention
    "JAPointingTestSet": "ja_pointing_test_set",
    "JALangColorTestSet": "ja_language_color_test_set",
    "JALangFeedbackTestSet": "ja_language_feedback_test_set",
    # emulation
    "DistrEmulationTestSet": "distr_emulation_test_set",
    "NoDistrEmulationTestSet": "no_distr_emulation_test_set",
    # formats
    "NFormatsTestSet": "N_formats_test_set",
    "EFormatsTestSet": "E_formats_test_set",
    "AFormatsTestSet": "A_formats_test_set",
    "AEFormatsTestSet": "AE_formats_test_set",
    "RoleReversalTestSet": "role_reversal_test_set",
}

if not args.test_set_name:
    # No test set requested: test on the training environment.
    test_env_names = [args.env]
elif args.test_set_name in inline_test_sets:
    test_env_names = inline_test_sets[args.test_set_name]
elif args.test_set_name in packaged_test_sets:
    test_env_names = getattr(gym_minigrid.social_ai_envs, packaged_test_sets[args.test_set_name])
else:
    raise ValueError("Undefined test set name.")
# One Tester per test environment.  Testers are created only when periodic
# testing is enabled (--test-interval > 0); otherwise the list stays empty.
testers = []
if args.test_interval > 0:
    testers = [
        utils.Tester(
            {
                "env_key": test_env_name,
                "seed": args.test_seed,
                "env_args": test_env_args,
            },
            args.test_seed, args.test_episodes, model_dir, acmodel, preprocess_obss, device,
        )
        for test_env_name in test_env_names
    ]

# When continuing a run, each tester loads its previously saved state.
if args.continue_train:
    for tester in testers:
        tester.load()
# Save config
# The "curriculum" entry is stored through its repr(); every other entry is
# written as-is.
env_args_ = {}
for key, value in env_args.items():
    env_args_[key] = value.__repr__() if key == "curriculum" else value

test_env_args_ = {}
for key, value in test_env_args.items():
    test_env_args_[key] = value.__repr__() if key == "curriculum" else value

config_dict = dict(
    seed=args.seed,
    env=args.env,
    env_args=env_args_,
    test_seed=args.test_seed,
    test_env=args.test_set_name,
    test_env_args=test_env_args_,
)
# Algorithm and model settings are merged on top of the run settings.
for extra_config in (algo.get_config_dict(), acmodel.get_config_dict()):
    config_dict.update(extra_config)

config_path = model_dir + '/config.json'
with open(config_path, 'w') as fp:
    json.dump(config_dict, fp)
# Train model
# Counters resume from the loaded status (both are zero for a fresh run).
num_frames, update = status["num_frames"], status["update"]
start_time = time.time()

# Headers are skipped only when continuing a run that already has frames.
log_add_headers = not (num_frames != 0 and args.continue_train)

long_term_save_interval = 5000000
# When continuing, the next long term save is the first multiple of the
# interval strictly above the current frame count; otherwise it is 0.
next_long_term_save = (
    (1 + num_frames // long_term_save_interval) * long_term_save_interval
    if args.continue_train
    else 0  # for long term logging
)
# Main training loop: runs until the requested number of frames is reached.
# Each iteration performs one update, then optionally logs, saves and tests.
while num_frames < args.frames:
    # ---- Update model parameters (timed, so throughput can be reported) ----
    update_start_time = time.time()
    exps, logs1 = algo.collect_experiences()
    logs2 = algo.update_parameters(exps)
    logs = {**logs1, **logs2}  # on a key clash the update logs win
    update_end_time = time.time()

    num_frames += logs["num_frames"]
    update += 1
    NPC_intro = np.mean(logs["NPC_introduced_to"])

    # ---- Text + CSV logging, every args.log_interval updates ----
    if update % args.log_interval == 0:
        fps = logs["num_frames"] / (update_end_time - update_start_time)
        duration = int(time.time() - start_time)

        # Summarise every per-episode series with utils.synthesize. The text
        # log format below consumes four values per summary (labelled μσmM).
        # "num_frames_per_episode" is summarised but not written to any log.
        stats = {
            log_key: utils.synthesize(logs[log_key])
            for log_key in (
                "return_per_episode",
                "extrinsic_return_per_episode",
                "exploration_bonus_per_episode",
                "success_rate_per_episode",
                "curriculum_max_mean_perf_per_episode",
                "curriculum_param_per_episode",
                "reshaped_return_per_episode",
                "num_frames_per_episode",
            )
        }
        # These entries are logged as-is, without synthesize.
        intrinsic_reward_perf = logs["intr_reward_perf"]
        intrinsic_reward_perf_ = logs["intr_reward_perf_"]
        lr_ = logs["lr"]
        # Wall-clock stamp: the DDMMYYYYhhmmss digit string parsed as an int.
        time_now = int(datetime.datetime.now().strftime("%d%m%Y%H%M%S"))

        # `header`/`data` form the CSV row; `data_to_print` holds the subset
        # echoed to the text log and must stay aligned with `log_format`.
        header = ["update", "frames", "FPS", "duration", "time"]
        data = [update, num_frames, fps, duration, time_now]
        data_to_print = list(data)
        log_format = ["U {} | F {:06} | FPS {:04.0f} | D {} | T {} "]

        # (CSV column prefix, key into `stats`, echo to text log?, text segment)
        # The curriculum columns always go to the CSV but are only echoed to
        # the text log when --acl is set.
        stat_columns = (
            ("success_rate_", "success_rate_per_episode", True,
             "| SR:μσmM {:.2f} {:.1f} {:.1f} {:.1f} "),
            ("curriculum_max_success_rate_", "curriculum_max_mean_perf_per_episode", args.acl,
             "| CurMaxSR:μσmM {:.2f} {:.1f} {:.1f} {:.1f} "),
            ("curriculum_param_", "curriculum_param_per_episode", args.acl,
             "| CurPhase:μσmM {:.2f} {:.1f} {:.1f} {:.1f} "),
            ("extrinsic_return_", "extrinsic_return_per_episode", True,
             "| ExR:μσmM {:.2f} {:.1f} {:.1f} {:.1f} "),
            ("exploration_bonus_", "exploration_bonus_per_episode", True,
             "| InR:μσmM {:.2f} {:.1f} {:.1f} {:.1f} "),
            ("rreturn_", "reshaped_return_per_episode", True,
             "| rR:μσmM {:.6f} {:.1f} {:.1f} {:.1f} "),
        )
        for prefix, log_key, echoed, segment in stat_columns:
            column_stats = stats[log_key]
            header += [prefix + name for name in column_stats.keys()]
            data += column_stats.values()
            if echoed:
                data_to_print += column_stats.values()
                log_format.append(segment)

        # Scalar columns. The intrinsic-reward performance values are CSV-only.
        # Optimisation diagnostics (entropy, value, losses, grad norm) are
        # currently not written to either log.
        header += ["intrinsic_reward_perf_", "intrinsic_reward_perf2_", "NPC_intro", "lr"]
        data += [intrinsic_reward_perf, intrinsic_reward_perf_, NPC_intro, lr_]
        data_to_print += [NPC_intro, lr_]
        log_format += ["| NPC_intro: {:.3f}", "| lr: {:.5f}"]

        txt_logger.info("".join(log_format).format(*data_to_print))

        # The total-return summary goes last in the CSV and is not echoed.
        total_return = stats["return_per_episode"]
        header += ["return_" + name for name in total_return.keys()]
        data += total_return.values()

        if log_add_headers:
            csv_logger.writerow(header)
            log_add_headers = False
        csv_logger.writerow(data)
        csv_file.flush()

    # ---- Save status ----
    # A long-term save is due whenever the frame counter reaches the next
    # threshold; the threshold then advances by one interval.
    long_term_save = num_frames >= next_long_term_save
    if long_term_save:
        next_long_term_save += long_term_save_interval

    periodic_save = args.save_interval > 0 and update % args.save_interval == 0
    if periodic_save or long_term_save:
        # Continuing a training run works best when save_interval equals
        # log_interval: the CSV stays cleaner, without redundant rows.
        status = {"num_frames": num_frames, "update": update, **algo.get_status_dict()}
        if hasattr(preprocess_obss, "vocab"):
            status["vocab"] = preprocess_obss.vocab.vocab
        status["env_args"] = env_args

        # Long-term saves additionally pass the frame count to the save
        # helpers (presumably yielding a frame-stamped snapshot; confirm in
        # utils). Regular saves pass only the model directory.
        save_kwargs = {"num_frames": num_frames} if long_term_save else {}
        utils.save_status(status, model_dir, **save_kwargs)
        utils.save_model(acmodel, model_dir, **save_kwargs)
        if long_term_save:
            txt_logger.info("Status and Model saved for {} frames".format(num_frames))
        else:
            txt_logger.info("Status and Model saved")

    # ---- Periodic evaluation: on the first update, then every test_interval ----
    test_due = args.test_interval > 0 and (update % args.test_interval == 0 or update == 1)
    if test_due:
        txt_logger.info(f"Testing at update {update}.")
        test_success_rates = []
        for tester in testers:
            mean_success_rate, _ = tester.test_agent(num_frames)
            test_success_rates.append(mean_success_rate)
            txt_logger.info(f"\t{tester.envs[0].spec.id} -> {mean_success_rate} (SR)")
            tester.dump()
        # Skip the aggregate when there are no testers (mean of an empty array).
        if len(testers) > 0:
            txt_logger.info(f"Test set SR: {np.array(test_success_rates).mean()}")
# Final save once the training loop has finished: write the status and the
# model one last time, without a frame-count argument.
status = {"num_frames": num_frames, "update": update, **algo.get_status_dict()}
# The vocabulary is stored only when the observation preprocessor has one.
if hasattr(preprocess_obss, "vocab"):
    status["vocab"] = preprocess_obss.vocab.vocab
status["env_args"] = env_args

utils.save_status(status, model_dir)
utils.save_model(acmodel, model_dir)
txt_logger.info("Status and Model saved at the end")