import os
import random
import copy

import pytest
import torch
from easydict import EasyDict

from ding.league import create_league

one_vs_one_league_default_config = dict(
    league=dict(
        league_type='one_vs_one',
        import_names=["ding.league"],
        # ---player----
        # "player_category" is just a name and depends on the env.
        # For example, in StarCraft this can be ['zerg', 'terran', 'protoss'].
        player_category=['default'],
        # Supports different types of active players for solo and battle leagues.
        # For a solo league, supports ['solo_active_player'].
        # For a battle league, supports ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter'].
        active_players=dict(
            naive_sp_player=1,  # {player_type: player_num}
        ),
        naive_sp_player=dict(
            # There should be keys ['one_phase_step', 'branch_probs', 'strong_win_rate'].
            # Specifically for the 'main_exploiter' of StarCraft, there should be an additional key ['min_valid_win_rate'].
            one_phase_step=10,
            branch_probs=dict(
                pfsp=0.5,
                sp=0.5,
            ),
            strong_win_rate=0.7,
        ),
        # "use_pretrain" means whether to use a pretrained model to initialize the active player.
        use_pretrain=False,
        # "use_pretrain_init_historical" means whether to use a pretrained model to initialize the historical player.
        # "pretrain_checkpoint_path" is the pretrain checkpoint path used by "use_pretrain" and
        # "use_pretrain_init_historical". If both are False, "pretrain_checkpoint_path" can be omitted;
        # otherwise, it should list the checkpoint paths of all player categories.
        use_pretrain_init_historical=False,
        pretrain_checkpoint_path=dict(default='default_cate_pretrain.pth', ),
        # ---payoff---
        payoff=dict(
            # Supports ['battle']
            type='battle',
            decay=0.99,
            min_win_rate_games=8,
        ),
        path_policy='./league',
    ),
)
one_vs_one_league_default_config = EasyDict(one_vs_one_league_default_config)


def get_random_result():
    # Random outcome with P(wins) = 1/3, P(losses) = 1/6, P(draws) = 1/2.
    ran = random.random()
    if ran < 1. / 3:
        return "wins"
    elif ran < 1. / 2:
        return "losses"
    else:
        return "draws"
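

# A minimal usage sketch of the config above (it mirrors what the tests below
# do; the player id printed is illustrative, not a guaranteed format):
#
#     league = create_league(one_vs_one_league_default_config.league)
#     player = league.active_players[0]
#     print(player.player_id, player.checkpoint_path)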


@pytest.mark.unittest
class TestOneVsOneLeague:

    def test_naive(self):
        league = create_league(one_vs_one_league_default_config.league)
        assert len(league.active_players) == 1
        assert len(league.historical_players) == 0
        active_player_ids = [p.player_id for p in league.active_players]
        assert set(active_player_ids) == set(league.active_players_ids)
        active_player_id = active_player_ids[0]

        active_player_ckpt = league.active_players[0].checkpoint_path
        tmp = torch.tensor([1, 2, 3])
        path_policy = one_vs_one_league_default_config.league.path_policy
        torch.save(tmp, active_player_ckpt)

        # judge_snapshot & update_active_player
        assert not league.judge_snapshot(active_player_id)
        player_update_dict = {
            'player_id': active_player_id,
            'train_iteration': one_vs_one_league_default_config.league.naive_sp_player.one_phase_step * 2,
        }
        league.update_active_player(player_update_dict)
        assert league.judge_snapshot(active_player_id)
        historical_player_ids = [p.player_id for p in league.historical_players]
        assert len(historical_player_ids) == 1
        historical_player_id = historical_player_ids[0]

        # get_job_info, eval_flag=False
        vs_active = False
        vs_historical = False
        while True:
            collect_job_info = league.get_job_info(active_player_id, eval_flag=False)
            assert collect_job_info['agent_num'] == 2
            assert len(collect_job_info['checkpoint_path']) == 2
            assert collect_job_info['launch_player'] == active_player_id
            assert collect_job_info['player_id'][0] == active_player_id
            if collect_job_info['player_active_flag'][1]:
                # Self-play: the opponent is the active player itself.
                assert collect_job_info['player_id'][1] == collect_job_info['player_id'][0]
                vs_active = True
            else:
                # Otherwise the opponent is the snapshotted historical player.
                assert collect_job_info['player_id'][1] == historical_player_id
                vs_historical = True
            if vs_active and vs_historical:
                break

        # get_job_info, eval_flag=True
        eval_job_info = league.get_job_info(active_player_id, eval_flag=True)
        assert eval_job_info['agent_num'] == 1
        assert len(eval_job_info['checkpoint_path']) == 1
        assert eval_job_info['launch_player'] == active_player_id
        assert eval_job_info['player_id'][0] == active_player_id
        assert len(eval_job_info['player_id']) == 1
        assert len(eval_job_info['player_active_flag']) == 1
        assert eval_job_info['eval_opponent'] in league.active_players[0]._eval_opponent_difficulty

        # finish_job
        episode_num = 5
        env_num = 8
        player_id = [active_player_id, historical_player_id]
        result = [[get_random_result() for __ in range(env_num)] for _ in range(episode_num)]
        payoff_update_info = {
            'launch_player': active_player_id,
            'player_id': player_id,
            'episode_num': episode_num,
            'env_num': env_num,
            'result': result,
        }
        league.finish_job(payoff_update_info)
        wins = 0
        games = episode_num * env_num
        for i in result:
            for j in i:
                if j == 'wins':
                    wins += 1
        # Not asserted: the stored payoff applies decay (0.99 in the config), so
        # it need not exactly equal the raw win fraction of this single batch.
        league.payoff[league.active_players[0], league.historical_players[0]] == wins / games

        os.popen("rm -rf {}".format(path_policy))
        print("Finish!")
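
    # ``test_league_info`` below exercises the rating flow: after each
    # non-self-play job, home/away ratings are updated through
    # ``league.metric_env.rate_1vs1``. The exact rating backend is an
    # implementation detail of ``ding.league``; this test only checks that the
    # calls run and that the payoff/rank summaries print.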
    def test_league_info(self):
        cfg = copy.deepcopy(one_vs_one_league_default_config.league)
        cfg.path_policy = 'test_league_info'
        league = create_league(cfg)
        active_player_id = [p.player_id for p in league.active_players][0]
        active_player_ckpt = [p.checkpoint_path for p in league.active_players][0]
        tmp = torch.tensor([1, 2, 3])
        torch.save(tmp, active_player_ckpt)
        assert len(league.active_players) == 1
        assert len(league.historical_players) == 0
        print('\n')
        print(repr(league.payoff))
        print(league.player_rank(string=True))
        league.judge_snapshot(active_player_id, force=True)
        for i in range(10):
            job = league.get_job_info(active_player_id, eval_flag=False)
            payoff_update_info = {
                'launch_player': active_player_id,
                'player_id': job['player_id'],
                'episode_num': 2,
                'env_num': 4,
                'result': [[get_random_result() for __ in range(4)] for _ in range(2)],
            }
            league.finish_job(payoff_update_info)
            # Update ratings only when the job is not self-play.
            if job['player_id'][0] != job['player_id'][1]:
                win_loss_result = sum(payoff_update_info['result'], [])
                home = league.get_player_by_id(job['player_id'][0])
                away = league.get_player_by_id(job['player_id'][1])
                home.rating, away.rating = league.metric_env.rate_1vs1(home.rating, away.rating, win_loss_result)
        print(repr(league.payoff))
        print(league.player_rank(string=True))
        os.popen("rm -rf {}".format(cfg.path_policy))


if __name__ == '__main__':
    pytest.main(["-sv", os.path.basename(__file__)])