import os
import json
import numpy
import re
import torch
import torch_ac
import gym
import utils
def get_obss_preprocessor(obs_space, text=None, dialogue_current=None, dialogue_history=None, custom_image_preprocessor=None, custom_image_space_preprocessor=None):
    # Check if obs_space is an image space
    if isinstance(obs_space, gym.spaces.Box):
        obs_space = {"image": obs_space.shape}

        def preprocess_obss(obss, device=None):
            assert custom_image_preprocessor is None
            return torch_ac.DictList({
                "image": preprocess_images(obss, device=device)
            })

    # Check if it is a MiniGrid observation space
    elif isinstance(obs_space, gym.spaces.Dict) and list(obs_space.spaces.keys()) == ["image"]:
        # a custom image preprocessor must come with a matching image space preprocessor
        assert (custom_image_preprocessor is None) == (custom_image_space_preprocessor is None)

        image_obs_space = obs_space.spaces["image"].shape
        if custom_image_preprocessor:
            image_obs_space = custom_image_space_preprocessor(image_obs_space)

        obs_space = {"image": image_obs_space, "text": 100}

        # text, dialogue_current and dialogue_history must be specified in this case
        if text is None:
            raise ValueError("text argument must be specified.")
        if dialogue_current is None:
            raise ValueError("dialogue_current argument must be specified.")
        if dialogue_history is None:
            raise ValueError("dialogue_history argument must be specified.")

        vocab = Vocabulary(obs_space["text"])

        def preprocess_obss(obss, device=None):
            if custom_image_preprocessor is None:
                D = {
                    "image": preprocess_images([obs["image"] for obs in obss], device=device)
                }
            else:
                D = {
                    "image": custom_image_preprocessor([obs["image"] for obs in obss], device=device)
                }

            if dialogue_current:
                D["utterance"] = preprocess_texts([obs["utterance"] for obs in obss], vocab, device=device)

            if dialogue_history:
                D["utterance_history"] = preprocess_texts([obs["utterance_history"] for obs in obss], vocab, device=device)

            if text:
                D["text"] = preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)

            return torch_ac.DictList(D)

        preprocess_obss.vocab = vocab

    else:
        raise ValueError("Unknown observation space: " + str(obs_space))

    return obs_space, preprocess_obss
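
# A minimal usage sketch for get_obss_preprocessor (illustrative only, not part of the
# original module). It assumes an already-built environment `env` whose observation space
# falls in the MiniGrid branch above and whose observations are dicts carrying "image",
# "mission" and "utterance" entries; the flag combination below is just one example, and
# the old gym API (env.reset() returning only the observation) is assumed.
def _example_get_obss_preprocessor(env, device=None):
    obs_space, preprocess_obss = get_obss_preprocessor(
        env.observation_space,
        text=True,               # encode obs["mission"] into batch.text
        dialogue_current=True,   # encode obs["utterance"] into batch.utterance
        dialogue_history=False,  # skip obs["utterance_history"]
    )
    obss = [env.reset() for _ in range(4)]  # a small batch of observations
    batch = preprocess_obss(obss, device=device)
    # batch.image is a float tensor; batch.text and batch.utterance are
    # zero-padded long tensors of token ids sharing the same vocabulary.
    return obs_space, batch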
def ride_ref_image_space_preprocessor(image_space):
    # the image space shape is left unchanged
    return image_space


def ride_ref_image_preprocessor(images, device=None):
    # Bug of Pytorch: very slow if not first converted to numpy array
    images = numpy.array(images)

    # grid dimensions
    size = images.shape[1]
    assert size == images.shape[2]

    # channels 1 and 2 hold absolute coordinates and must be non-negative
    # assert images[:, :, :, 1].max() <= size
    # assert images[:, :, :, 2].max() <= size
    assert images[:, :, :, 1].min() >= 0
    assert images[:, :, :, 2].min() >= 0

    # # channel 4 holds the door state and should only take the values 0, 1 or 2
    # assert all([e in set([0, 1, 2]) for e in numpy.unique(images[:, :, :, 4].reshape(-1))])

    # only keep (object id, color, state): zero out the absolute coordinate channels
    images[:, :, :, 1] *= 0
    images[:, :, :, 2] *= 0

    assert images.shape[-1] == 5
    return torch.tensor(images, device=device, dtype=torch.float)
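
# Small self-contained check of ride_ref_image_preprocessor on a dummy batch
# (an illustrative sketch; treating the five channels as object id, x, y, color,
# state is an assumption based on the comments above, not a documented layout).
def _example_ride_ref_image_preprocessor():
    dummy = numpy.zeros((2, 7, 7, 5), dtype=numpy.uint8)  # batch of two 7x7 grids
    dummy[:, :, :, 1] = 3  # fake absolute x coordinates
    dummy[:, :, :, 2] = 5  # fake absolute y coordinates
    out = ride_ref_image_preprocessor(dummy)
    assert out.shape == (2, 7, 7, 5)
    assert out[:, :, :, 1].sum() == 0  # coordinate channels have been zeroed out
    assert out[:, :, :, 2].sum() == 0
    return out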
def preprocess_images(images, device=None):
    # Bug of Pytorch: very slow if not first converted to numpy array
    images = numpy.array(images)
    return torch.tensor(images, device=device, dtype=torch.float)
def preprocess_texts(texts, vocab, device=None):
    var_indexed_texts = []
    max_text_len = 0

    for text in texts:
        # lowercase, tokenize and map each token to its vocabulary id
        tokens = re.findall("([a-z]+)", text.lower())
        var_indexed_text = numpy.array([vocab[token] for token in tokens])
        var_indexed_texts.append(var_indexed_text)
        max_text_len = max(len(var_indexed_text), max_text_len)

    # right-pad every text with zeros up to the length of the longest one
    indexed_texts = numpy.zeros((len(texts), max_text_len))

    for i, indexed_text in enumerate(var_indexed_texts):
        indexed_texts[i, :len(indexed_text)] = indexed_text

    return torch.tensor(indexed_texts, device=device, dtype=torch.long)
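
# Illustrative sketch of preprocess_texts (not part of the original module): two
# missions of different lengths are tokenized with the same Vocabulary and
# right-padded with zeros, which works because token ids start at 1.
def _example_preprocess_texts():
    vocab = Vocabulary(max_size=100)
    batch = preprocess_texts(["go to the red door", "open the door"], vocab)
    # batch has shape (2, 5): the shorter text ends with trailing zero padding
    assert batch.shape == (2, 5)
    assert batch[1, -1].item() == 0
    return batch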
class Vocabulary:
    """A mapping from tokens to ids with a capacity of `max_size` words.
    It can be saved in a `vocab.json` file."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.vocab = {}

    def load_vocab(self, vocab):
        self.vocab = vocab

    def __getitem__(self, token):
        # unseen tokens are assigned the next free id (ids start at 1)
        if token not in self.vocab.keys():
            if len(self.vocab) >= self.max_size:
                raise ValueError("Maximum vocabulary capacity reached")
            self.vocab[token] = len(self.vocab) + 1
        return self.vocab[token]
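
# Quick illustration of Vocabulary behaviour (a sketch, not part of the original
# module): unseen tokens get increasing ids starting at 1, repeated tokens reuse
# their id, and exceeding `max_size` distinct tokens raises a ValueError.
def _example_vocabulary():
    vocab = Vocabulary(max_size=2)
    assert vocab["red"] == 1
    assert vocab["door"] == 2
    assert vocab["red"] == 1  # already known, id is reused
    try:
        _ = vocab["blue"]     # a third distinct token exceeds max_size
    except ValueError:
        pass
    return vocab.vocab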