|
import random |
|
import numpy as np |
|
|
|
from dizoo.gfootball.envs.obs.gfootball_obs import PlayerObs, MatchObs |
|
from ding.utils.data import default_collate |
|
|
|
|
|
def generate_data(player_obs: dict) -> np.ndarray:
    """Generate one random feature vector that matches an observation template entry.

    Args:
        player_obs: A template entry (an element of ``MatchObs.template`` or
            ``PlayerObs.template`` as used by this module) holding ``'dim'`` and
            ``'value'``. ``'value'`` holds ``'min'``, ``'max'`` and ``'dinfo'``.

    Returns:
        A 1-D array of length ``dim``:

        * ``'one-hot'`` / ``'boolean vector'``: float32 zeros with exactly one
          randomly chosen element set to 1 (``min``/``max`` are not used).
        * ``'float'``: float64 values ``min[i] + (max[i] - min[i]) * u`` with
          ``u`` drawn from ``np.random.rand``, i.e. uniform in ``[min[i], max[i])``.

    Raises:
        ValueError: If ``dinfo`` is not one of the supported kinds (previously
            such entries silently produced ``None``).
    """
    dim = player_obs['dim']
    value = player_obs['value']
    dinfo = value['dinfo']

    if dinfo in ('one-hot', 'boolean vector'):
        # A single set bit is a valid sample for both kinds; fake data does not
        # need multi-bit boolean vectors.
        data = np.zeros((dim, ), dtype=np.float32)
        data[random.randint(0, dim - 1)] = 1
        return data

    if dinfo == 'float':
        # 'min'/'max' give per-dimension bounds; scale U[0, 1) samples into them.
        low = np.asarray(value['min'], dtype=np.float64)
        high = np.asarray(value['max'], dtype=np.float64)
        return low + (high - low) * np.random.rand(dim)

    raise ValueError(
        "generate_data: unsupported dinfo {!r} for template entry {!r}".format(dinfo, player_obs.get('ret_key'))
    )
|
|
|
|
|
class FakeGfootballDataset:
    """Produce random fake gfootball observations and transitions for tests.

    Observations are built from the ``MatchObs`` / ``PlayerObs`` templates via
    ``generate_data``. Actions, rewards and done flags are random.
    """

    # Number of per-player observation dicts placed in ``obs['players']``.
    player_num = 22

    def __init__(self):
        # Only the ``template`` lists are kept; the observation objects are
        # constructed with empty configs just to read them.
        self.match_obs_info = MatchObs({}).template
        self.player_obs_info = PlayerObs({}).template
        # Size of the discrete action space sampled by ``get_random_action``.
        self.action_dim = 19
        self.batch_size = 4

    def __len__(self) -> int:
        return self.batch_size

    def get_random_action(self) -> np.ndarray:
        """Return a shape-``(1,)`` int array with an action id uniform in ``[0, action_dim)``."""
        # ``np.random.randint``'s upper bound is exclusive, so pass ``action_dim``
        # itself; ``action_dim - 1`` would make the last action unreachable.
        return np.random.randint(0, self.action_dim, size=(1, ))

    def get_random_obs(self) -> dict:
        """Return one random observation.

        Match-level features are keyed by each template entry's ``'ret_key'``.
        ``'players'`` holds a list of ``player_num`` dicts of per-player features.
        """
        inputs = {info['ret_key']: generate_data(info) for info in self.match_obs_info}
        inputs['players'] = [
            {info['ret_key']: generate_data(info) for info in self.player_obs_info}
            for _ in range(self.player_num)
        ]
        return inputs

    def get_batched_obs(self, bs: int) -> dict:
        """Return ``bs`` random observations merged by ``default_collate``."""
        return default_collate([self.get_random_obs() for _ in range(bs)])

    def get_random_reward(self) -> np.ndarray:
        """Return a shape-``(1,)`` float array uniform in ``[-0.5, 0.5)``."""
        return np.array([random.random() - 0.5])

    def get_random_terminals(self) -> int:
        """Return ``1`` with roughly 1% probability, otherwise ``0``."""
        return 1 if random.random() > 0.99 else 0

    def get_batch_sample(self, bs: int) -> list:
        """Return ``bs`` random transition dicts.

        Each dict has the keys ``obs``, ``next_obs``, ``action``, ``done`` and ``reward``.
        """
        # The dict literal evaluates its values in key order, so the random
        # draws happen in the same sequence as before.
        return [
            {
                'obs': self.get_random_obs(),
                'next_obs': self.get_random_obs(),
                'action': self.get_random_action(),
                'done': self.get_random_terminals(),
                'reward': self.get_random_reward(),
            } for _ in range(bs)
        ]
|
|