import os
import json
import numpy
import re
import torch
import torch_ac
import gym

import utils


def get_obss_preprocessor(obs_space, text=None, dialogue_current=None, dialogue_history=None, custom_image_preprocessor=None, custom_image_space_preprocessor=None):
    # Check if obs_space is an image space
    if isinstance(obs_space, gym.spaces.Box):
        obs_space = {"image": obs_space.shape}

        def preprocess_obss(obss, device=None):
            assert custom_image_preprocessor is None
            return torch_ac.DictList({
                "image": preprocess_images(obss, device=device)
            })

    # Check if it is a MiniGrid observation space
    elif isinstance(obs_space, gym.spaces.Dict) and list(obs_space.spaces.keys()) == ["image"]:

        # a custom image preprocessor must be paired with a custom image space preprocessor
        assert (custom_image_preprocessor is None) == (custom_image_space_preprocessor is None)

        image_obs_space = obs_space.spaces["image"].shape

        if custom_image_preprocessor:
            image_obs_space = custom_image_space_preprocessor(image_obs_space)

        obs_space = {"image": image_obs_space, "text": 100}

        # the text/dialogue flags must be explicitly specified for this observation space
        if text is None:
            raise ValueError("text argument must be specified.")
        if dialogue_current is None:
            raise ValueError("dialogue_current argument must be specified.")
        if dialogue_history is None:
            raise ValueError("dialogue_history argument must be specified.")

        vocab = Vocabulary(obs_space["text"])

        def preprocess_obss(obss, device=None):
            if custom_image_preprocessor is None:
                D = {
                    "image": preprocess_images([obs["image"] for obs in obss], device=device)
                }
            else:
                D = {
                    "image": custom_image_preprocessor([obs["image"] for obs in obss], device=device)
                }

            if dialogue_current:
                D["utterance"] = preprocess_texts([obs["utterance"] for obs in obss], vocab, device=device)

            if dialogue_history:
                D["utterance_history"] = preprocess_texts([obs["utterance_history"] for obs in obss], vocab, device=device)

            if text:
                D["text"] = preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)

            return torch_ac.DictList(D)

        preprocess_obss.vocab = vocab

    else:
        raise ValueError("Unknown observation space: " + str(obs_space))

    return obs_space, preprocess_obss
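
# A minimal usage sketch (the env id and flag values below are assumptions, not part of
# this module): for a Dict observation space with a single "image" key, and raw
# observations that also carry "mission", "utterance" and "utterance_history" fields,
# the returned function batches a list of observations into a torch_ac.DictList.
#
#   env = gym.make("SomeDialogueEnv-v0")  # hypothetical env id
#   obs_space, preprocess_obss = get_obss_preprocessor(
#       env.observation_space, text=True, dialogue_current=True, dialogue_history=True)
#   obss = [env.reset()]
#   batch = preprocess_obss(obss, device="cpu")
#   # batch.image, batch.text, batch.utterance, batch.utterance_history are tensors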

def ride_ref_image_space_preprocessor(image_space):
    # the RIDE-style image preprocessing below only zeroes channels, so the shape is unchanged
    return image_space

def ride_ref_image_preprocessor(images, device=None):
    # building a tensor directly from a list of arrays is very slow in PyTorch, so convert to a numpy array first

    images = numpy.array(images)

    # grid dimensions
    size = images.shape[1]
    assert size == images.shape[2]

    # channels 1 and 2 hold absolute coordinates
    # assert images[:,:,:,1].max() <= size
    # assert images[:,:,:,2].max() <= size
    assert images[:,:,:,1].min() >= 0
    assert images[:,:,:,2].min() >= 0

    # channel 4 is the door state (0, 1, 2)
    # assert all([e in set([0, 1, 2]) for e in numpy.unique(images[:, :, :, 4].reshape(-1))])

    # only keep (object id, color, state): zero out the absolute coordinate channels
    images[:, :, :, 1] *= 0
    images[:, :, :, 2] *= 0

    assert images.shape[-1] == 5

    return torch.tensor(images, device=device, dtype=torch.float)

def preprocess_images(images, device=None):
    # building a tensor directly from a list of arrays is very slow in PyTorch, so convert to a numpy array first
    images = numpy.array(images)
    return torch.tensor(images, device=device, dtype=torch.float)


def preprocess_texts(texts, vocab, device=None):
    var_indexed_texts = []
    max_text_len = 0

    for text in texts:
        tokens = re.findall("([a-z]+)", text.lower())
        var_indexed_text = numpy.array([vocab[token] for token in tokens])
        var_indexed_texts.append(var_indexed_text)
        max_text_len = max(len(var_indexed_text), max_text_len)

    indexed_texts = numpy.zeros((len(texts), max_text_len))

    for i, indexed_text in enumerate(var_indexed_texts):
        indexed_texts[i, :len(indexed_text)] = indexed_text

    return torch.tensor(indexed_texts, device=device, dtype=torch.long)
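
# For example (a small sketch, not executed anywhere in this module):
#
#   vocab = Vocabulary(100)
#   preprocess_texts(["go to the door", "open door"], vocab)
#
# tokenizes each string into lowercase word ids and right-pads the shorter one with
# zeros, yielding a LongTensor of shape (2, 4).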


class Vocabulary:
    """A mapping from tokens to ids with a capacity of `max_size` words.
    It can be saved in a `vocab.json` file."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.vocab = {}

    def load_vocab(self, vocab):
        self.vocab = vocab

    def __getitem__(self, token):
        if token not in self.vocab:
            if len(self.vocab) >= self.max_size:
                raise ValueError("Maximum vocabulary capacity reached")
            # ids start at 1 so that 0 can be used as padding in preprocess_texts
            self.vocab[token] = len(self.vocab) + 1
        return self.vocab[token]
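
# The docstring above mentions saving to `vocab.json`; no save method is defined here,
# but a minimal sketch using the stdlib `json` module (hypothetical file path) could be:
#
#   vocab = Vocabulary(max_size=100)
#   vocab["open"]; vocab["door"]          # assigns ids 1 and 2
#   with open("vocab.json", "w") as f:
#       json.dump(vocab.vocab, f)
#
#   restored = Vocabulary(max_size=100)
#   with open("vocab.json") as f:
#       restored.load_vocab(json.load(f))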