Spaces:
Running
Running
#!/usr/bin/env python3 | |
import time | |
import argparse | |
import numpy as np | |
import gym | |
import gym_minigrid | |
from gym_minigrid.wrappers import * | |
from gym_minigrid.window import Window | |
from utils import * | |
from models import MultiModalBaby11ACModel | |
from collections import Counter | |
import torch_ac | |
import json | |
from termcolor import colored, COLORS | |
from functools import partial | |
from tkinter import * | |
from torch.distributions import Categorical | |
# Module-level feature flags.
# NOTE(review): ``inter_acl`` is not read anywhere in the visible part of this
# file — confirm it is still needed.
# ``draw_tree``: when true, the environment's parameter tree is drawn at startup.
inter_acl, draw_tree = False, True
def redraw(img):
    """Display a frame in the manual-control window.

    Unless ``--agent_view`` was given, ``img`` is ignored and a full render of
    the environment (honouring ``--tile_size`` and ``--mask-unobserved``) is
    shown instead.

    NOTE(review): with ``--agent_view``, ``img`` is shown as-is. The callers in
    this file pass the raw observation from ``env.reset()`` / ``env.step()``,
    which ``step`` indexes like a dict. Confirm ``show_img`` accepts that.
    """
    if args.agent_view:
        frame = img
    else:
        frame = env.render(
            'rgb_array',
            tile_size=args.tile_size,
            mask_unobserved=args.mask_unobserved,
        )
    window.show_img(frame)
def reset():
    """Start a new episode, announce its mission (if any) and draw it.

    Seeding is not done here; the script seeds the environment once at startup
    when ``--seed`` is non-negative.
    """
    first_obs = env.reset()
    if hasattr(env, 'mission'):
        mission = env.mission
        print('Mission: %s' % mission)
        window.set_caption(mission)
    redraw(first_obs)
# Exploration bonus accumulated in the current episode. It is kept in a
# one-element list so functions can reset it in place without ``global``.
tot_bonus = [0]

# Slots for the previous observation / info.
# NOTE(review): not read in the visible part of this file.
prev = {"prev_obs": None, "prev_info": {}}

# Compact display names for objects with long names.
shortened_obj_names = {
    'lockablebox': 'loc_box',
    'applegenerator': 'app_gen',
    'generatorplatform': 'gen_pl',
    'marbletee': 'tee',
    'remotedoor': 'rem_door',
}

# Reverse lookups: encoded index -> (possibly shortened) object name, and
# encoded index -> color name. OBJECT_TO_IDX / COLOR_TO_IDX come from the
# star imports at the top of the file (presumably gym_minigrid — confirm).
IDX_TO_OBJECT = {
    idx: shortened_obj_names.get(name, name)
    for name, idx in OBJECT_TO_IDX.items()
}
# The inversion must be lossless: no two object names may share an index.
assert len(OBJECT_TO_IDX) == len(IDX_TO_OBJECT)

IDX_TO_COLOR = {idx: name for name, idx in COLOR_TO_IDX.items()}
assert len(COLOR_TO_IDX) == len(IDX_TO_COLOR)
# def to_string(enc): | |
# s = "{:<8} {} {} {} {} {:3} {:3} {}\t".format( | |
# IDX_TO_OBJECT.get(enc[0], enc[0]), # obj | |
# *enc[1:3], # x, y | |
# IDX_TO_COLOR.get(enc[3], enc[3])[:1].upper(), # color | |
# *enc[4:] # | |
# ) | |
# | |
# if IDX_TO_OBJECT.get(enc[0], enc[0]) == "unseen": | |
# pass | |
# # s = colored(s, "on_grey") | |
# | |
# elif IDX_TO_OBJECT.get(enc[0], enc[0]) != "empty": | |
# col = IDX_TO_COLOR.get(enc[3], enc[3]) | |
# if col in COLORS: | |
# s = colored(s, col) | |
# | |
# return s | |
def step(action):
    """Apply one action to the environment, redraw, and log the transition.

    ``action`` takes one of two forms:
    - a numpy array, passed to ``env.step`` unchanged (e.g. ``[nan, nan, nan]``
      or ``[nan, a, b]`` as sent by ``key_handler``);
    - a primitive action such as a member of ``env.actions``, converted with
      ``int`` and padded to ``[int(action), nan, nan]``.

    If the episode ends:
    - the reward and accumulated exploration bonus are printed;
    - the bonus counter is cleared;
    - "Success!\\n" (only when ``reward > 0``) and "New episode.\\n" are
      appended to ``output_file``;
    - the environment is reset.

    Otherwise:
    - the step count is printed;
    - the observation text (``info["descriptions"]`` plus any conversation
      history) is printed and appended to ``output_file``;
    - the output ends with an "Act : " prefix. ``key_handler`` later appends
      the chosen action label to the file after it.

    The raw reward is printed in both cases.
    """
    if not isinstance(action, np.ndarray):
        action = [int(action), np.nan, np.nan]
    obs, reward, done, info = env.step(action)
    redraw(obs)

    if done:
        print('done!')
        print('Reward=%.2f' % (reward))
        print('Exploration_bonus=%.2f' % (tot_bonus[0]))
        tot_bonus[0] = 0
        with open(output_file, "a") as f:
            if reward > 0:
                f.write("Success!\n")
            f.write("New episode.\n")
        reset()
    else:
        print('\nStep=%s' % (env.step_count))
        descriptions = "".join(info["descriptions"])
        history = obs["utterance_history"]
        # Presumably "Conversation: \n" is the history before anything has
        # been said; it is neither shown nor logged.
        has_history = history != "Conversation: \n"

        # Echo to screen ...
        print("Obs : ", end="")
        print(descriptions, end="")
        if has_history:
            print(history)
        print("Act : ", end="")

        # ... and append the same text (without print's newline) to the log.
        with open(output_file, "a") as f:
            f.write("Obs : " + descriptions + (history if has_history else "") + "Act : ")
    print('Full reward (undiminshed)=%.2f' % (reward))
# Label appended to the episode log (after "Act : ") for keys that map to a
# loggable action.
_KEY_TO_ACTION_LABEL = {
    "up": "move forward",
    "left": "turn left",
    "right": "turn right",
    " ": "toggle",
    "shift": "no_op",
}

# Keys that step with a primitive action, given by attribute name on
# ``env.actions``. Resolved at key-press time, so an env lacking one of these
# actions only fails when that key is actually pressed.
_KEY_TO_PRIMITIVE_ACTION = {
    'left': 'left',
    'right': 'right',
    'up': 'forward',
    't': 'speak',
    ' ': 'toggle',
    '9': 'pickup',
    '0': 'drop',
    'enter': 'done',
}

# Keys that step with a raw action array. The all-NaN array is the no-op
# (bound to shift, whose log label is "no_op"). The other entries fill the
# last two slots. NOTE(review): presumably these index an utterance
# (template, word) — confirm against the env's action space.
_KEY_TO_ACTION_ARRAY = {
    'tab': (np.nan, np.nan, np.nan),
    'shift': (np.nan, np.nan, np.nan),
    '1': (np.nan, 0, 0),
    '2': (np.nan, 0, 1),
    '3': (np.nan, 1, 0),
    '4': (np.nan, 1, 1),
    '5': (np.nan, 2, 2),
    '6': (np.nan, 1, 2),
    '7': (np.nan, 2, 1),
    '8': (np.nan, 1, 3),
    'p': (np.nan, 3, 3),
}


def key_handler(event):
    """Handle a key press in the manual-control window.

    The key is echoed. If it has a log label, that label plus a newline is
    appended to ``output_file``. The key is then dispatched:
    - ``escape`` closes the window;
    - ``r`` resets the environment;
    - arrows / space / ``t`` / ``9`` / ``0`` / ``enter`` step with the
      matching ``env.actions`` member;
    - ``tab`` / ``shift`` / ``1``-``8`` / ``p`` step with a raw action array.

    Any other key is only echoed.
    """
    key = event.key
    print('pressed', key)

    label = _KEY_TO_ACTION_LABEL.get(key)
    if label is not None:
        with open(output_file, "a") as f:
            f.write("{}\n".format(label))

    if key == 'escape':
        window.close()
    elif key == 'r':
        reset()
    elif key in _KEY_TO_PRIMITIVE_ACTION:
        step(getattr(env.actions, _KEY_TO_PRIMITIVE_ACTION[key]))
    elif key in _KEY_TO_ACTION_ARRAY:
        step(np.array(_KEY_TO_ACTION_ARRAY[key]))
def _build_arg_parser():
    """Build the command-line interface of this manual-control script.

    Options are registered in a fixed order, which is also their order in
    ``--help``.
    """
    cli = argparse.ArgumentParser()
    # Alternative default: "SocialAI-AsocialBoxInformationSeekingParamEnv-v1"
    cli.add_argument("--env", default="SocialAI-ColorBoxesLLMCSParamEnv-v1",
                     help="gym environment to load")
    # A negative seed means the script does not seed the environment.
    cli.add_argument("--seed", type=int, default=-1,
                     help="random seed to generate the environment with")
    cli.add_argument("--tile_size", type=int, default=32,
                     help="size at which to render tiles")

    # Boolean switches, all off unless given on the command line.
    switches = (
        ('--agent_view', "draw the agent sees (partially observable view)"),
        ('--print_grid', "print the grid with symbols"),
        ('--calc-bonus', "calculate explo bonus"),
        ('--mask-unobserved', "mask cells that are not observed by the agent"),
    )
    for flag, text in switches:
        cli.add_argument(flag, action='store_true', default=False, help=text)

    cli.add_argument('--output-file', default="./llm_data/in_context_color_test.txt",
                     help="file where to save episodes")
    # Everything after --env-args is collected as a list of tokens that
    # env_args_str_to_dict turns into env kwargs,
    # e.g. --env-args nb_foo 1 is_bar True
    cli.add_argument("--env-args", nargs='*', default=None)
    return cli


parser = _build_arg_parser()
args = parser.parse_args()
# Script entry: build the environment, open the log, and run the GUI loop.
output_file = args.output_file
env = gym.make(args.env, **env_args_str_to_dict(args.env_args))

if draw_tree:
    # Draw the environment's parameter tree to viz/SocialAIParam/<env>_raw_tree.
    env.parameter_tree.draw_tree(
        filename="viz/SocialAIParam/{}_raw_tree".format(args.env),
        ignore_labels=["Num_of_colors"],
    )

if args.seed >= 0:
    env.seed(args.seed)

# Mark the start of the first episode in the log.
with open(output_file, "a") as f:
    f.write("New episode.\n")

window = Window('gym_minigrid - ' + args.env, figsize=(4, 4))
window.reg_key_handler(key_handler)
env.window = window

# Start the first episode explicitly. reset() is otherwise only called on 'r'
# or when an episode ends. Calling it here means the seed set above applies to
# this episode, and its mission and first frame appear before any key press.
reset()

# Blocking event loop
window.show(block=True)