zjowowen's picture
init space
079c32c
import os
import numpy as np
import pytest
from easydict import EasyDict
from ding.league.player import Player, HistoricalPlayer, ActivePlayer, create_player
from ding.league.shared_payoff import create_payoff
from ding.league.starcraft_player import MainPlayer, MainExploiter, LeagueExploiter
from ding.league.tests.league_test_default_config import league_test_config
from ding.league.metric import LeagueMetricEnv
ONE_PHASE_STEP = 2000
env = LeagueMetricEnv()
@pytest.fixture(scope='function')
def setup_payoff():
cfg = EasyDict({'type': 'battle', 'decay': 0.99})
return create_payoff(cfg)
@pytest.fixture(scope='function')
def setup_league(setup_payoff):
players = []
for category in ['zerg', 'terran', 'protoss']:
# main_player
main_player_name = '{}_{}'.format('MainPlayer', category)
players.append(
create_player(
league_test_config.league, 'main_player', league_test_config.league.main_player, category, setup_payoff,
'ckpt_{}.pth'.format(main_player_name), main_player_name, 0, env.create_rating()
)
)
# main_exloiter
main_exploiter_name = '{}_{}'.format('MainExploiter', category)
players.append(
create_player(
league_test_config.league, 'main_exploiter', league_test_config.league.main_exploiter, category,
setup_payoff, 'ckpt_{}.pth'.format(main_exploiter_name), main_exploiter_name, 0, env.create_rating()
)
)
# league_exploiter
league_exploiter_name = '{}_{}'.format('LeagueExploiter', category)
for i in range(2):
players.append(
create_player(
league_test_config.league,
'league_exploiter',
league_test_config.league.league_exploiter,
category,
setup_payoff,
'ckpt_{}.pth'.format(league_exploiter_name),
league_exploiter_name,
0,
env.create_rating(),
)
)
# historical player: sl player is used as initial HistoricalPlayer
sl_hp_name = '{}_{}_sl'.format('MainPlayer', category)
players.append(
create_player(
league_test_config.league,
'historical_player',
EasyDict(),
category,
setup_payoff,
'ckpt_sl_{}'.format(sl_hp_name),
sl_hp_name,
0,
env.create_rating(),
parent_id=main_player_name,
)
)
for p in players:
setup_payoff.add_player(p)
return players
@pytest.mark.unittest
class TestMainPlayer:
def test_get_job(self, setup_league, setup_payoff):
N = 10
# no indicated p
# test get_job
for p in setup_league:
if isinstance(p, MainPlayer):
for i in range(N):
job_dict = p.get_job()
assert isinstance(job_dict, dict)
opponent = job_dict['opponent']
assert isinstance(opponent, Player)
assert opponent in setup_league
# payoff = setup_league[np.random.randint(0, len(setup_league))].payoff # random select reference
hp_list = []
for p in setup_league:
if isinstance(p, ActivePlayer):
p.total_agent_step = 2 * ONE_PHASE_STEP
hp = p.snapshot(env)
hp_list.append(hp)
setup_payoff.add_player(hp)
setup_league += hp_list # 12+3 + 12
# test get_job with branch prob
pfsp, sp, veri = False, False, False
for p in setup_league:
if isinstance(p, MainPlayer):
while True:
job_dict = p.get_job()
opponent = job_dict['opponent']
if isinstance(opponent, HistoricalPlayer) and 'MainPlayer' in opponent.parent_id:
veri = True
elif isinstance(opponent, HistoricalPlayer):
pfsp = True
elif isinstance(opponent, MainPlayer):
sp = True
else:
raise Exception("Main Player selects a wrong opponent {}", type(opponent))
if veri and pfsp and sp:
break
def test_snapshot(self, setup_league, setup_payoff):
N = 10
for p in setup_league:
for i in range(N):
if isinstance(p, ActivePlayer):
hp = p.snapshot(env)
assert isinstance(hp, HistoricalPlayer)
assert id(hp.payoff) == id(p.payoff)
assert hp.parent_id == p.player_id
def test_is_trained_enough(self, setup_league, setup_payoff):
for p in setup_league:
if isinstance(p, ActivePlayer):
assert not p.is_trained_enough()
assert p._last_enough_step == 0
# step_passed < ONE_PHASE_STEP
p.total_agent_step = ONE_PHASE_STEP * 0.99
assert not p.is_trained_enough()
assert p._last_enough_step == 0
# ONE_PHASE_STEP < step_passed < 2*ONE_PHASE_STEP, but low win rate
p.total_agent_step = ONE_PHASE_STEP + 1
assert not p.is_trained_enough()
assert p._last_enough_step == 0
# prepare HistoricalPlayer
# payoff = setup_league[np.random.randint(0, len(setup_league))].payoff # random select reference
hp_list = []
for p in setup_league:
if isinstance(p, MainPlayer):
hp = p.snapshot(env)
setup_payoff.add_player(hp)
hp_list.append(hp)
setup_league += hp_list
# update 10 wins against all historical players, should be trained enough
N = 10
assert isinstance(setup_league[0], MainPlayer)
for n in range(N):
for hp in [p for p in setup_league if isinstance(p, HistoricalPlayer)]:
match_info = {
'player_id': [setup_league[0].player_id, hp.player_id],
'result': [['wins']],
}
result = setup_payoff.update(match_info)
assert result
assert setup_league[0]._total_agent_step > ONE_PHASE_STEP
assert setup_league[0]._last_enough_step == 0
assert setup_league[0]._last_enough_step != setup_league[0]._total_agent_step
assert setup_league[0].is_trained_enough()
assert setup_league[0]._last_enough_step == setup_league[0]._total_agent_step
# update 10 draws against all historical players, should be not trained enough;
# then update ``total_agent_step`` to 2*ONE_PHASE_STEP, should be trained enough
assert isinstance(setup_league[5], MainPlayer)
for n in range(N):
for hp in hp_list:
match_info = {
'player_id': [setup_league[5].player_id, hp.player_id],
'result': [['draws']],
}
result = setup_payoff.update(match_info)
assert result
assert setup_league[5]._total_agent_step > ONE_PHASE_STEP
assert not setup_league[5].is_trained_enough()
setup_league[5].total_agent_step = 2 * ONE_PHASE_STEP
assert setup_league[5].is_trained_enough()
def test_mutate(self, setup_league, setup_payoff):
# main players do not mutate
assert isinstance(setup_league[0], MainPlayer)
for _ in range(10):
assert setup_league[0].mutate({}) is None
def test_sp_historical(self, setup_league, setup_payoff):
N = 10
main1 = setup_league[0] # 'zerg'
main2 = setup_league[5] # 'terran'
assert isinstance(main1, MainPlayer)
assert isinstance(main2, MainPlayer)
for n in range(N):
match_info = {
'player_id': [main1.player_id, main2.player_id],
'result': [['wins']],
}
result = setup_payoff.update(match_info)
assert result
for _ in range(200):
opponent = main2._sp_branch()
condition1 = opponent.category == 'terran' or opponent.category == 'protoss'
# condition2 means: zerg_main_opponent is too strong, so that must choose a historical weaker one
condition2 = opponent.category == 'zerg' and isinstance(
opponent, HistoricalPlayer
) and opponent.parent_id == main1.player_id
assert condition1 or condition2, (condition1, condition2)
@pytest.mark.unittest
class TestMainExploiter:
def test_get_job(self, setup_league, random_job_result, setup_payoff):
assert isinstance(setup_league[1], MainExploiter)
job_dict = setup_league[1].get_job()
opponent = job_dict['opponent']
assert isinstance(opponent, MainPlayer)
N = 10
# payoff = setup_league[np.random.randint(0, len(setup_league))].payoff # random select reference
for n in range(N):
for p in setup_league:
if isinstance(p, MainPlayer):
match_info = {
'player_id': [setup_league[1].player_id, p.player_id],
'result': [['losses']],
}
assert setup_payoff.update(match_info)
job_dict = setup_league[1].get_job()
opponent = job_dict['opponent']
# as long as main player, both active and historical are ok
assert (isinstance(opponent, HistoricalPlayer)
and 'MainPlayer' in opponent.parent_id) or isinstance(opponent, MainPlayer)
hp_list = []
for i in range(3):
for p in setup_league:
if isinstance(p, MainPlayer):
p.total_agent_step = (i + 1) * 2 * ONE_PHASE_STEP
hp = p.snapshot(env)
setup_payoff.add_player(hp)
hp_list.append(hp)
setup_league += hp_list
no_main_player_league = [p for p in setup_league if not isinstance(p, MainPlayer)]
for i in range(10000):
home = np.random.choice(no_main_player_league)
away = np.random.choice(no_main_player_league)
result = random_job_result()
match_info = {
'player_id': [home.player_id, away.player_id],
'result': [[result]],
}
assert setup_payoff.update(match_info)
for i in range(10):
job_dict = setup_league[1].get_job()
opponent = job_dict['opponent']
# as long as main player, both active and historical are ok
assert (isinstance(opponent, HistoricalPlayer)
and 'MainPlayer' in opponent.parent_id) or isinstance(opponent, MainPlayer)
def test_is_trained_enough(self, setup_league):
# only a few differences from `is_trained_enough` of MainPlayer
pass
def test_mutate(self, setup_league):
assert isinstance(setup_league[1], MainExploiter)
info = {'reset_checkpoint_path': 'pretrain_checkpoint.pth'}
for _ in range(10):
assert setup_league[1].mutate(info) == info['reset_checkpoint_path']
@pytest.mark.unittest
class TestLeagueExploiter:
def test_get_job(self, setup_league):
assert isinstance(setup_league[2], LeagueExploiter)
job_dict = setup_league[2].get_job()
opponent = job_dict['opponent']
assert isinstance(opponent, HistoricalPlayer)
assert isinstance(setup_league[3], LeagueExploiter)
job_dict = setup_league[3].get_job()
opponent = job_dict['opponent']
assert isinstance(opponent, HistoricalPlayer)
def test_is_trained_enough(self, setup_league):
# this function is the same as `is_trained_enough` of MainPlayer
pass
def test_mutate(self, setup_league):
assert isinstance(setup_league[2], LeagueExploiter)
info = {'reset_checkpoint_path': 'pretrain_checkpoint.pth'}
results = []
for _ in range(1000):
results.append(setup_league[2].mutate(info))
freq = len([t for t in results if t]) * 1.0 / len(results)
assert 0.2 <= freq <= 0.3 # approximate
if __name__ == '__main__':
pytest.main(["-sv", os.path.basename(__file__)])