Facepalm0 commited on
Commit
5e6fb83
·
verified ·
1 Parent(s): 0b3548a

Upload make_env.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. make_env.py +119 -0
make_env.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium
2
+ from gymnasium import spaces
3
+ import numpy as np
4
+
5
+ # Define the environment class. There is a (12*12) grid, and the agent can move in four directions and collect randomly initialized items. The items are divided into 21 categories represented by int in range(21), and the number of each category is a multiple of 4.
6
+
7
+ class GridWorldEnv(gymnasium.Env):
8
+ metadata = {'render.modes': ['human']}
9
+
10
+ def __init__(self):
11
+ super(GridWorldEnv, self).__init__()
12
+ self.grid_size = (12, 12)
13
+ self.num_categories = 21
14
+ self.action_space = spaces.Discrete(5) # 0: down, 1: right, 2: up, 3: left, 4: collect
15
+ # self.action_space = spaces.Box(low=0, high=5, shape=(1,), dtype=np.int64)
16
+ # self.observation_space = spaces.Dict({
17
+ # 'grid': spaces.Box(low=-1 / 20, high=20 / 20, shape=self.grid_size, dtype=np.float64),
18
+ # 'loc': spaces.Box(low=0 / 12, high=11 / 12, shape=(2,), dtype=np.float64),
19
+ # 'bag': spaces.Box(low=0 / 4, high=4 / 4, shape=(self.num_categories, 1), dtype=np.float64)})
20
+ self.observation_space = spaces.Box(low=-1, high=20, shape=(144+2+self.num_categories, ), dtype=np.float64)
21
+ self.max_steps = 4 * 12 * 12
22
+ self.reset()
23
+
24
+ def reset(self, seed=None):
25
+ self.grid = self.initialize_grid()
26
+ self.loc = np.random.randint(0, self.grid_size[0], 2)
27
+ self.bag = np.zeros((self.num_categories, 1), dtype=int)
28
+ self.steps = 0
29
+ return self._get_obs(), self._get_info()
30
+
31
+ def initialize_grid(self): # Randomly initialize the grid
32
+ grid = np.zeros(self.grid_size, dtype=int)
33
+ total_cells = self.grid_size[0] * self.grid_size[1]
34
+ remaining_cells = total_cells
35
+ category_counts = []
36
+ # Ensure each category has at least 4 items and is a multiple of 4
37
+ for _ in range(self.num_categories - 1):
38
+ count = np.random.randint(1, 3) * 4
39
+ if count > remaining_cells:
40
+ count = remaining_cells
41
+ category_counts.append(count)
42
+ remaining_cells -= count
43
+ # Assign the remaining cells to the last category
44
+ category_counts.append(remaining_cells)
45
+ # Shuffle the positions and assign categories
46
+ positions = np.random.permutation(total_cells)
47
+ index = 0
48
+ for category, count in enumerate(category_counts):
49
+ for _ in range(count):
50
+ x, y = divmod(positions[index], self.grid_size[1])
51
+ grid[x, y] = category
52
+ index += 1
53
+
54
+ return grid
55
+
56
+ def step(self, action):
57
+ action = int(action)
58
+ self.steps += 1
59
+ reward = -0.1 - np.sum(self.bag) / (12 * 12)
60
+ if action == 0: # down
61
+ self.loc[0] = min(self.loc[0] + 1, self.grid_size[0] - 1)
62
+ elif action == 1: # right
63
+ self.loc[1] = min(self.loc[1] + 1, self.grid_size[1] - 1)
64
+ elif action == 2: # up
65
+ self.loc[0] = max(self.loc[0] - 1, 0)
66
+ elif action == 3: # left
67
+ self.loc[1] = max(self.loc[1] - 1, 0)
68
+ elif action == 4: # collect
69
+ x, y = self.loc
70
+ if self.grid[x, y] == -1:
71
+ reward -= 2 # Penalty for collecting an empty cell
72
+ else:
73
+ self.bag[self.grid[x, y]] += 1
74
+ self.grid[x, y] = -1 # Remove the item from the grid
75
+
76
+ # 如果bag中有4个相同的物品,消除这4个物品,并获得奖励=+1
77
+ if np.any(self.bag == 4):
78
+ category = np.argmax(self.bag == 4)
79
+ self.bag[category] -= 4
80
+ reward += 1
81
+
82
+ done = bool(np.sum(self.grid) == -12 * 12) # Done when all items are collected
83
+ if done:
84
+ reward += 100
85
+ truncated = bool(self.steps >= self.max_steps) # Truncate the episode after a certain number of steps
86
+ if truncated:
87
+ reward = -3 * (np.sum(self.bag) + np.sum(self.grid != -1)) # Penalty for not collecting enough items
88
+
89
+ # print(f'grid: {self.grid}')
90
+ return self._get_obs(), reward, done, truncated, self._get_info()
91
+
92
+ def _get_obs(self):
93
+ # return {'grid': np.copy(self.grid) / 20, 'loc': np.copy(self.loc) / 12, 'bag': np.copy(self.bag) / 4}
94
+ return np.concatenate((self.grid.flatten(), self.loc, self.bag.flatten()))
95
+
96
+ def _get_info(self):
97
+ return {'grid': np.copy(self.grid) / 20, 'loc': np.copy(self.loc) / 12, 'bag': np.copy(self.bag) / 4}
98
+
99
+ def render(self, mode='rgb_array'):
100
+ if mode == 'human':
101
+ print("Grid:")
102
+ print(self.grid)
103
+ print("Loc:", self.loc)
104
+ print("Bag:", self.bag)
105
+
106
+ def close(self):
107
+ pass
108
+
109
+ if __name__ == "__main__":
110
+ # Example usage
111
+ env = GridWorldEnv()
112
+ obs = env.reset()
113
+ env.render()
114
+ obs, reward, done, info = env.step(1) # Move right
115
+ env.render()
116
+ obs, reward, done, info = env.step(0) # Move down
117
+ env.render()
118
+ obs, reward, done, info = env.step(4) # Collect item
119
+ env.render()