Upload make_env.py with huggingface_hub
Browse files- make_env.py +119 -0
make_env.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gymnasium
|
2 |
+
from gymnasium import spaces
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
# Define the environment class. There is a (12*12) grid, and the agent can move in four directions and collect randomly initialized items. The items are divided into 21 categories represented by int in range(21), and the number of each category is a multiple of 4.
|
6 |
+
|
7 |
+
class GridWorldEnv(gymnasium.Env):
|
8 |
+
metadata = {'render.modes': ['human']}
|
9 |
+
|
10 |
+
def __init__(self):
|
11 |
+
super(GridWorldEnv, self).__init__()
|
12 |
+
self.grid_size = (12, 12)
|
13 |
+
self.num_categories = 21
|
14 |
+
self.action_space = spaces.Discrete(5) # 0: down, 1: right, 2: up, 3: left, 4: collect
|
15 |
+
# self.action_space = spaces.Box(low=0, high=5, shape=(1,), dtype=np.int64)
|
16 |
+
# self.observation_space = spaces.Dict({
|
17 |
+
# 'grid': spaces.Box(low=-1 / 20, high=20 / 20, shape=self.grid_size, dtype=np.float64),
|
18 |
+
# 'loc': spaces.Box(low=0 / 12, high=11 / 12, shape=(2,), dtype=np.float64),
|
19 |
+
# 'bag': spaces.Box(low=0 / 4, high=4 / 4, shape=(self.num_categories, 1), dtype=np.float64)})
|
20 |
+
self.observation_space = spaces.Box(low=-1, high=20, shape=(144+2+self.num_categories, ), dtype=np.float64)
|
21 |
+
self.max_steps = 4 * 12 * 12
|
22 |
+
self.reset()
|
23 |
+
|
24 |
+
def reset(self, seed=None):
|
25 |
+
self.grid = self.initialize_grid()
|
26 |
+
self.loc = np.random.randint(0, self.grid_size[0], 2)
|
27 |
+
self.bag = np.zeros((self.num_categories, 1), dtype=int)
|
28 |
+
self.steps = 0
|
29 |
+
return self._get_obs(), self._get_info()
|
30 |
+
|
31 |
+
def initialize_grid(self): # Randomly initialize the grid
|
32 |
+
grid = np.zeros(self.grid_size, dtype=int)
|
33 |
+
total_cells = self.grid_size[0] * self.grid_size[1]
|
34 |
+
remaining_cells = total_cells
|
35 |
+
category_counts = []
|
36 |
+
# Ensure each category has at least 4 items and is a multiple of 4
|
37 |
+
for _ in range(self.num_categories - 1):
|
38 |
+
count = np.random.randint(1, 3) * 4
|
39 |
+
if count > remaining_cells:
|
40 |
+
count = remaining_cells
|
41 |
+
category_counts.append(count)
|
42 |
+
remaining_cells -= count
|
43 |
+
# Assign the remaining cells to the last category
|
44 |
+
category_counts.append(remaining_cells)
|
45 |
+
# Shuffle the positions and assign categories
|
46 |
+
positions = np.random.permutation(total_cells)
|
47 |
+
index = 0
|
48 |
+
for category, count in enumerate(category_counts):
|
49 |
+
for _ in range(count):
|
50 |
+
x, y = divmod(positions[index], self.grid_size[1])
|
51 |
+
grid[x, y] = category
|
52 |
+
index += 1
|
53 |
+
|
54 |
+
return grid
|
55 |
+
|
56 |
+
def step(self, action):
|
57 |
+
action = int(action)
|
58 |
+
self.steps += 1
|
59 |
+
reward = -0.1 - np.sum(self.bag) / (12 * 12)
|
60 |
+
if action == 0: # down
|
61 |
+
self.loc[0] = min(self.loc[0] + 1, self.grid_size[0] - 1)
|
62 |
+
elif action == 1: # right
|
63 |
+
self.loc[1] = min(self.loc[1] + 1, self.grid_size[1] - 1)
|
64 |
+
elif action == 2: # up
|
65 |
+
self.loc[0] = max(self.loc[0] - 1, 0)
|
66 |
+
elif action == 3: # left
|
67 |
+
self.loc[1] = max(self.loc[1] - 1, 0)
|
68 |
+
elif action == 4: # collect
|
69 |
+
x, y = self.loc
|
70 |
+
if self.grid[x, y] == -1:
|
71 |
+
reward -= 2 # Penalty for collecting an empty cell
|
72 |
+
else:
|
73 |
+
self.bag[self.grid[x, y]] += 1
|
74 |
+
self.grid[x, y] = -1 # Remove the item from the grid
|
75 |
+
|
76 |
+
# 如果bag中有4个相同的物品,消除这4个物品,并获得奖励=+1
|
77 |
+
if np.any(self.bag == 4):
|
78 |
+
category = np.argmax(self.bag == 4)
|
79 |
+
self.bag[category] -= 4
|
80 |
+
reward += 1
|
81 |
+
|
82 |
+
done = bool(np.sum(self.grid) == -12 * 12) # Done when all items are collected
|
83 |
+
if done:
|
84 |
+
reward += 100
|
85 |
+
truncated = bool(self.steps >= self.max_steps) # Truncate the episode after a certain number of steps
|
86 |
+
if truncated:
|
87 |
+
reward = -3 * (np.sum(self.bag) + np.sum(self.grid != -1)) # Penalty for not collecting enough items
|
88 |
+
|
89 |
+
# print(f'grid: {self.grid}')
|
90 |
+
return self._get_obs(), reward, done, truncated, self._get_info()
|
91 |
+
|
92 |
+
def _get_obs(self):
|
93 |
+
# return {'grid': np.copy(self.grid) / 20, 'loc': np.copy(self.loc) / 12, 'bag': np.copy(self.bag) / 4}
|
94 |
+
return np.concatenate((self.grid.flatten(), self.loc, self.bag.flatten()))
|
95 |
+
|
96 |
+
def _get_info(self):
|
97 |
+
return {'grid': np.copy(self.grid) / 20, 'loc': np.copy(self.loc) / 12, 'bag': np.copy(self.bag) / 4}
|
98 |
+
|
99 |
+
def render(self, mode='rgb_array'):
|
100 |
+
if mode == 'human':
|
101 |
+
print("Grid:")
|
102 |
+
print(self.grid)
|
103 |
+
print("Loc:", self.loc)
|
104 |
+
print("Bag:", self.bag)
|
105 |
+
|
106 |
+
def close(self):
|
107 |
+
pass
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
# Example usage
|
111 |
+
env = GridWorldEnv()
|
112 |
+
obs = env.reset()
|
113 |
+
env.render()
|
114 |
+
obs, reward, done, info = env.step(1) # Move right
|
115 |
+
env.render()
|
116 |
+
obs, reward, done, info = env.step(0) # Move down
|
117 |
+
env.render()
|
118 |
+
obs, reward, done, info = env.step(4) # Collect item
|
119 |
+
env.render()
|