|
import gymnasium
|
|
from gymnasium import spaces
|
|
import numpy as np
|
|
|
|
|
|
|
|
class GridWorldEnv(gymnasium.Env):
    """Grid world in which an agent collects items and clears them in sets.

    Every cell of the grid starts with an item whose category is an integer
    in ``[0, num_categories)``; a cell that has been emptied holds ``-1``.
    The agent occupies one cell (``loc``) and carries a ``bag`` that counts
    the items it holds per category. As soon as the bag contains four items
    of one category, those four are removed and the agent is rewarded.

    Actions (``Discrete(5)``):
        0: increase ``loc[0]`` by one (clamped to the grid)
        1: increase ``loc[1]`` by one (clamped to the grid)
        2: decrease ``loc[0]`` by one (clamped to the grid)
        3: decrease ``loc[1]`` by one (clamped to the grid)
        4: pick up the item in the current cell

    Observation: a flat ``float64`` vector made of the flattened grid, the
    two location indices and the per-category bag counts.

    The episode terminates when every cell is empty and is truncated once
    ``max_steps`` steps have been taken.
    """

    # 'render_modes' is the key Gymnasium reads; the legacy 'render.modes'
    # key is kept so existing lookups of it keep working.
    metadata = {'render_modes': ['human'], 'render.modes': ['human']}

    # Marker stored in a grid cell once its item has been picked up.
    _EMPTY_CELL = -1
    # Number of same-category items that form a set and leave the bag.
    _SET_SIZE = 4

    def __init__(self):
        """Create the spaces and put the environment into an initial state."""
        super().__init__()
        self.grid_size = (12, 12)
        self.num_categories = 21
        self.action_space = spaces.Discrete(5)

        total_cells = self.grid_size[0] * self.grid_size[1]
        # Layout: flattened grid, then the 2 location indices, then the bag.
        self.observation_space = spaces.Box(
            low=self._EMPTY_CELL,
            high=self.num_categories - 1,
            shape=(total_cells + 2 + self.num_categories,),
            dtype=np.float64,
        )
        self.max_steps = 4 * total_cells
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed for the environment's random generator.
                Passing the same seed reproduces the same grid and start
                location.
            options: Accepted for compatibility with the Gymnasium ``reset``
                signature; not used.

        Returns:
            A tuple ``(observation, info)``.
        """
        # Seeds self.np_random; all randomness below is drawn from it so
        # that a seeded reset is reproducible.
        super().reset(seed=seed)
        self.grid = self.initialize_grid()
        # One random index per axis, each bounded by that axis' length.
        self.loc = self.np_random.integers(0, np.asarray(self.grid_size))
        self.bag = np.zeros((self.num_categories, 1), dtype=int)
        self.steps = 0
        return self._get_obs(), self._get_info()

    def initialize_grid(self):
        """Return a new grid with category labels placed at random cells.

        Each category except the last receives one or two sets' worth of
        items (capped by the number of cells still free); the last category
        receives whatever cells remain, so the grid is always full.

        Returns:
            An integer array of shape ``grid_size`` holding category ids.
        """
        total_cells = self.grid_size[0] * self.grid_size[1]
        remaining_cells = total_cells
        category_counts = []
        for _ in range(self.num_categories - 1):
            sets = int(self.np_random.integers(1, 3))  # 1 or 2
            count = min(sets * self._SET_SIZE, remaining_cells)
            category_counts.append(count)
            remaining_cells -= count
        category_counts.append(remaining_cells)

        # The counts sum to total_cells, so the labels cover every cell once.
        labels = np.repeat(np.arange(self.num_categories), category_counts)
        grid = np.zeros(self.grid_size, dtype=int)
        # Row-major flat index p corresponds to cell divmod(p, n_columns).
        grid.flat[self.np_random.permutation(total_cells)] = labels
        return grid

    def step(self, action):
        """Apply ``action`` and advance the episode by one step.

        Reward:
            * every step costs ``0.1`` plus the number of items in the bag
              divided by the number of grid cells;
            * picking from an already empty cell costs a further ``2``;
            * completing a set of four gives ``+1``;
            * emptying the whole grid gives ``+100``;
            * when the step limit is reached without the grid being empty,
              the reward is replaced by ``-3`` times the number of items
              still in the bag or on the grid.

        Args:
            action: Integer in ``[0, 5)``; other values leave the state
                unchanged apart from the step counter.

        Returns:
            A tuple ``(observation, reward, terminated, truncated, info)``.
        """
        action = int(action)
        self.steps += 1
        reward = -0.1 - np.sum(self.bag) / self.grid.size

        if action == 0:
            self.loc[0] = min(self.loc[0] + 1, self.grid_size[0] - 1)
        elif action == 1:
            self.loc[1] = min(self.loc[1] + 1, self.grid_size[1] - 1)
        elif action == 2:
            self.loc[0] = max(self.loc[0] - 1, 0)
        elif action == 3:
            self.loc[1] = max(self.loc[1] - 1, 0)
        elif action == 4:
            x, y = self.loc
            if self.grid[x, y] == self._EMPTY_CELL:
                reward -= 2
            else:
                self.bag[self.grid[x, y]] += 1
                self.grid[x, y] = self._EMPTY_CELL

        # A completed set leaves the bag immediately, so at most one
        # category can be at the set size here.
        if np.any(self.bag == self._SET_SIZE):
            category = np.argmax(self.bag == self._SET_SIZE)
            self.bag[category] -= self._SET_SIZE
            reward += 1

        done = bool(np.all(self.grid == self._EMPTY_CELL))
        if done:
            reward += 100

        truncated = bool(self.steps >= self.max_steps)
        # Only penalise an unfinished episode: clearing the grid on the very
        # last allowed step must keep its completion bonus.
        if truncated and not done:
            leftover = np.sum(self.bag) + np.sum(self.grid != self._EMPTY_CELL)
            reward = -3 * leftover

        return self._get_obs(), float(reward), done, truncated, self._get_info()

    def _get_obs(self):
        """Return the flat observation vector in the declared space dtype."""
        parts = (self.grid.ravel(), self.loc, self.bag.ravel())
        return np.concatenate(parts).astype(self.observation_space.dtype)

    def _get_info(self):
        """Return scaled copies of the grid, location and bag."""
        return {
            'grid': self.grid / (self.num_categories - 1),
            'loc': self.loc / np.asarray(self.grid_size),
            'bag': self.bag / self._SET_SIZE,
        }

    def render(self, mode='rgb_array'):
        """Print the grid, location and bag when ``mode`` is ``'human'``.

        Any other mode, including the default, produces no output.
        """
        if mode == 'human':
            print("Grid:")
            print(self.grid)
            print("Loc:", self.loc)
            print("Bag:", self.bag)

    def close(self):
        """Release resources; this environment holds none."""
        pass
|
|
|
|
if __name__ == "__main__":
    # Smoke run: reset the environment, take two moves and one pick-up, and
    # print the state after each.
    env = GridWorldEnv()

    # reset() returns (observation, info).
    obs, info = env.reset()
    # The mode is passed explicitly because render()'s default prints nothing.
    env.render(mode='human')

    for action in (1, 0, 4):
        # step() returns five values: (obs, reward, terminated, truncated, info).
        obs, reward, done, truncated, info = env.step(action)
        env.render(mode='human')
|
|
|