# Web file-viewer header captured along with the source (not code):
# zjowowen's picture | init space | 079c32c | raw | history blame | 2.65 kB
import random
import numpy as np
from dizoo.gfootball.envs.obs.gfootball_obs import PlayerObs, MatchObs
from ding.utils.data import default_collate
def generate_data(player_obs: dict) -> np.array:
    """Create one random feature vector from a single observation template entry.

    The entry is read through ``player_obs['dim']`` and the ``'min'``, ``'max'``
    and ``'dinfo'`` fields of ``player_obs['value']``.

    Returns:
        - ``'one-hot'`` / ``'boolean vector'``: a float32 vector of length
          ``dim`` that is all zeros except for a single 1 at a random index.
        - ``'float'``: a vector of length ``dim`` (NumPy's default float dtype)
          whose i-th element is drawn uniformly between ``min[i]`` and ``max[i]``.
        - any other ``dinfo``: ``None``.
    """
    size = player_obs['dim']
    spec = player_obs['value']
    lower = spec['min']
    upper = spec['max']
    kind = spec['dinfo']

    if kind in ('one-hot', 'boolean vector'):
        encoded = np.zeros((size, ), dtype=np.float32)
        # random.randint includes both endpoints, hence ``size - 1``.
        hot_index = random.randint(0, size - 1)
        encoded[hot_index] = 1
        return encoded

    if kind == 'float':
        samples = np.random.rand(size)
        # Rescale every unit-interval draw into that element's own range.
        for idx in range(size):
            low = lower[idx]
            samples[idx] = low + (upper[idx] - low) * samples[idx]
        return samples

    # Unrecognised descriptors produce no data.
    return None
class FakeGfootballDataset:
    """Produce random gfootball-style observations, actions, rewards and done flags.

    The observation layout comes from the ``template`` attributes of
    ``MatchObs`` and ``PlayerObs``: every template entry is passed to
    ``generate_data`` and the result is stored under the entry's ``'ret_key'``.
    """

    # Number of per-player observation dicts generated for each observation.
    _NUM_PLAYERS = 22

    def __init__(self):
        """Read the observation templates and set the default sizes."""
        # Only the templates are kept; the observation helpers are temporaries.
        self.match_obs_info = MatchObs({}).template
        self.player_obs_info = PlayerObs({}).template
        self.action_dim = 19
        self.batch_size = 4

    def __len__(self) -> int:
        """Return ``self.batch_size``."""
        return self.batch_size

    def get_random_action(self) -> np.ndarray:
        """Return a shape-``(1,)`` integer array holding one action in ``[0, action_dim)``."""
        # np.random.randint excludes ``high`` (unlike random.randint), so the
        # upper bound must be ``action_dim`` itself for the last action to be
        # reachable.
        return np.random.randint(0, self.action_dim, size=(1, ))

    def get_random_obs(self) -> dict:
        """Return one random observation.

        The dict maps each match-level ``'ret_key'`` to generated data and
        holds, under ``'players'``, a list of per-player dicts built the same
        way from the player template.
        """
        inputs = {info['ret_key']: generate_data(info) for info in self.match_obs_info}
        inputs['players'] = [
            {info['ret_key']: generate_data(info) for info in self.player_obs_info}
            for _ in range(self._NUM_PLAYERS)
        ]
        return inputs

    def get_batched_obs(self, bs: int) -> dict:
        """Generate ``bs`` random observations and merge them with ``default_collate``."""
        return default_collate([self.get_random_obs() for _ in range(bs)])

    def get_random_reward(self) -> np.ndarray:
        """Return a one-element array holding a reward drawn from ``[-0.5, 0.5)``."""
        return np.array([random.random() - 0.5])

    def get_random_terminals(self) -> int:
        """Return 1 when a uniform draw from ``[0, 1)`` exceeds 0.99, otherwise 0."""
        return 1 if random.random() > 0.99 else 0

    def get_batch_sample(self, bs: int) -> list:
        """Return ``bs`` transition dicts.

        Each dict has the keys ``'obs'``, ``'next_obs'``, ``'action'``,
        ``'done'`` and ``'reward'``, all filled with random values.
        """
        return [
            {
                'obs': self.get_random_obs(),
                'next_obs': self.get_random_obs(),
                'action': self.get_random_action(),
                'done': self.get_random_terminals(),
                'reward': self.get_random_reward(),
            }
            for _ in range(bs)
        ]