|
from easydict import EasyDict |
|
import copy |
|
|
|
|
|
class Node():
    """
    Overview:
        A node in an alpha-beta-pruning game-search tree.
        Reference: https://mathspp.com/blog/minimax-algorithm-and-alpha-beta-pruning
    Arguments:
        env: Class Env, such as
             zoo.board_games.tictactoe.envs.tictactoe_env.TicTacToeEnv,
             zoo.board_games.gomoku.envs.gomoku_env.GomokuEnv
    """

    def __init__(self, board, legal_actions, start_player_index=0, parent=None, prev_action=None, env=None):
        """Create a search node for ``board`` reached from ``parent`` via ``prev_action``."""
        super().__init__()
        # Shared environment instance used to simulate moves and score positions.
        self.env = env
        self.board = board
        # Deep-copied because expand() destructively pops actions off this list.
        self.legal_actions = copy.deepcopy(legal_actions)
        self.children = []
        self.parent = parent
        # Action taken in the parent that produced this node (None at the root).
        self.prev_action = prev_action
        # Index (0 or 1) of the player to move in this position.
        self.start_player_index = start_player_index
        # Flipped to True once expand() has generated all child nodes.
        self.tree_expanded = False

    def __str__(self):
        """Return a recursive textual rendering of the subtree rooted at this node."""
        return f"Tree({', '.join(str(child) for child in self.children)})"

    def expand(self):
        """
        Overview:
            Generate one child node per legal action, unless this node is terminal.
        NOTE:
            Consumes ``self.legal_actions`` via ``pop``; after expansion the list
            is empty (see the caveat on ``is_fully_expanded``).
        """
        # The opponent moves next in every child position.
        if self.start_player_index == 0:
            next_start_player_index = 1
        else:
            next_start_player_index = 0
        if self.is_terminal_node is False:
            while len(self.legal_actions) > 0:
                action = self.legal_actions.pop(0)
                # simulate_action_v2 returns the successor board and its legal actions.
                board, legal_actions = self.env.simulate_action_v2(self.board, self.start_player_index, action)
                child_node = Node(
                    board,
                    legal_actions,
                    start_player_index=next_start_player_index,
                    parent=self,
                    prev_action=action,
                    env=self.env
                )

                self.children.append(child_node)
            self.tree_expanded = True

    @property
    def expanded(self):
        # True once expand() has run to completion on a non-terminal node.
        return self.tree_expanded

    def is_fully_expanded(self):
        # NOTE(review): expand() empties self.legal_actions, so after expansion
        # this compares len(children) against 0 and returns False for any node
        # with children — looks inconsistent with the method name; confirm
        # intended usage before relying on it.
        return len(self.children) == len(self.legal_actions)

    @property
    def is_terminal_node(self):
        # Side effect: resets the shared env to this node's position before querying.
        self.env.reset_v2(self.start_player_index, init_state=self.board)
        return self.env.get_done_reward()[0]

    @property
    def value(self):
        """
        Overview:
            Exact outcome of this position as reported by ``env.get_done_reward()[1]``:
                +1 if player 1 has won,
                -1 if player 2 has won,
                 0 for a draw (winner == -1).
        """
        # Side effect: resets the shared env to this node's position before querying.
        self.env.reset_v2(self.start_player_index, init_state=self.board)
        return self.env.get_done_reward()[1]

    @property
    def estimated_value(self):
        # Heuristic value for non-terminal cutoff nodes; a neutral constant here,
        # i.e. depth-limited searches treat unresolved positions as a draw.
        return 0

    @property
    def state(self):
        # Alias for the raw board representation.
        return self.board
|
|
|
|
|
def pruning(tree, maximising_player, alpha=float("-inf"), beta=float("+inf"), depth=999, first_level=True):
    """
    Overview:
        Minimax search with alpha-beta pruning over a tree of ``Node``-like objects.
    Arguments:
        - tree: root node of the (sub)tree to evaluate.
        - maximising_player (bool): True if the player to move at ``tree`` maximises the value.
        - alpha (float): best value the maximising player can guarantee so far.
        - beta (float): best value the minimising player can guarantee so far.
        - depth (int): remaining search depth; at 0 the node's ``estimated_value`` is used.
        - first_level (bool): True only for the root call.
    Returns:
        - ``(value, best_subtree)`` when ``first_level`` is True, otherwise just ``value``.
          ``best_subtree`` is ``None`` when ``tree`` has no children to search.
    """
    # Terminal positions are scored exactly by the environment.
    # (Truthiness test instead of ``is True`` so non-bool truthy values also work.)
    if tree.is_terminal_node:
        return tree.value

    # Depth budget exhausted: fall back to the heuristic estimate.
    if depth == 0:
        return tree.estimated_value

    if not tree.expanded:
        tree.expand()

    # Initialised to None so the first-level return never raises
    # UnboundLocalError when the node has no children (previously a latent bug).
    best_subtree = None
    val = float("-inf") if maximising_player else float("+inf")
    for subtree in tree.children:
        sub_val = pruning(subtree, not maximising_player, alpha, beta, depth - 1, first_level=False)
        if maximising_player:
            val = max(sub_val, val)
            if val > alpha:
                best_subtree = subtree
                alpha = val
        else:
            val = min(sub_val, val)
            if val < beta:
                best_subtree = subtree
                beta = val
        # Alpha-beta cutoff: the opponent would never allow this line of play.
        if beta <= alpha:
            break

    if first_level:
        return val, best_subtree
    return val
|
|
|
|
|
class AlphaBetaPruningBot:
    """
    Overview:
        A bot that selects moves by exhaustive minimax search with alpha-beta
        pruning, built on ``Node`` and ``pruning``.
    Arguments:
        - ENV: environment class (e.g. TicTacToeEnv) or an already-constructed env instance.
        - cfg (dict): config used to instantiate ``ENV`` when it is a class.
        - bot_name (str): human-readable name for this bot.
    """

    def __init__(self, ENV, cfg, bot_name):
        self.name = bot_name
        self.ENV = ENV
        self.cfg = cfg

    def get_best_action(self, board, player_index, depth=999):
        """
        Overview:
            Search the game tree rooted at ``board`` and return the best action
            for the player with index ``player_index``.
        Arguments:
            - board: current board state, accepted by ``env.reset``.
            - player_index (int): 0 if player 1 moves next, 1 if player 2 does.
            - depth (int): maximum search depth; 999 is effectively unbounded.
        Returns:
            - action: the ``prev_action`` of the minimax-optimal child node.
        """
        try:
            # ``ENV`` is normally a class: build a fresh simulator from the config.
            simulator_env = copy.deepcopy(self.ENV(EasyDict(self.cfg)))
        except Exception:
            # Fall back to treating ``ENV`` as an already-built env instance.
            # (Narrowed from a bare ``except:``, which would also swallow
            # KeyboardInterrupt/SystemExit.)
            simulator_env = copy.deepcopy(self.ENV)
        simulator_env.reset(start_player_index=player_index, init_state=board)
        root = Node(board, simulator_env.legal_actions, start_player_index=player_index, env=simulator_env)
        # Player 1 (index 0) maximises the reward; player 2 (index 1) minimises it.
        if player_index == 0:
            val, best_subtree = pruning(root, True, depth=depth, first_level=True)
        else:
            val, best_subtree = pruning(root, False, depth=depth, first_level=True)

        return best_subtree.prev_action
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: two alpha-beta-pruning bots play a full game of TicTacToe
    # against each other, printing each board state and the per-move search time.
    import time

    from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
    cfg = dict(
        prob_random_agent=0,
        prob_expert_agent=0,
        battle_mode='self_play_mode',
        agent_vs_human=False,
        bot_action_type='alpha_beta_pruning',
        channel_last=True,
        scale=True,
    )
    env = TicTacToeEnv(EasyDict(cfg))
    player_0 = AlphaBetaPruningBot(TicTacToeEnv, cfg, 'player 1')  # player_index = 0
    player_1 = AlphaBetaPruningBot(TicTacToeEnv, cfg, 'player 2')  # player_index = 1

    # Player 1 (index 0) moves first from the initial empty board.
    player_index = 0
    env.reset()

    state = env.board
    print('-' * 15)
    print(state)

    # Alternate moves until the environment reports the game is over.
    while not env.get_done_reward()[0]:
        if player_index == 0:
            start = time.time()
            action = player_0.get_best_action(state, player_index=player_index)
            print('player 1 action time: ', time.time() - start)
            player_index = 1
        else:
            start = time.time()
            action = player_1.get_best_action(state, player_index=player_index)
            print('player 2 action time: ', time.time() - start)
            player_index = 0
        env.step(action)
        state = env.board
        print('-' * 15)
        print(state)
        row, col = env.action_to_coord(action)  # NOTE(review): row/col are unused — debugging leftover?

    # NOTE(review): the expression after the comma is the assert *message*, so
    # ``env.get_done_winner()[1] == -1`` is never actually checked. The apparent
    # intent is "perfect play ends in a draw" — confirm whether
    # get_done_winner()[0] is really False after a finished (drawn) game, since
    # the loop above only exits once get_done_reward()[0] is truthy.
    assert env.get_done_winner()[0] is False, env.get_done_winner()[1] == -1
|
|
|
|
|
|
|
|
|
""" |
|
|
|
##### Gomoku ##### |
|
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv |
|
cfg = dict( |
|
board_size=5, |
|
prob_random_agent=0, |
|
prob_expert_agent=0, |
|
battle_mode='self_play_mode', |
|
scale=True, |
|
channel_last=True, |
|
agent_vs_human=False, |
|
bot_action_type='alpha_beta_pruning', # {'v0', 'alpha_beta_pruning'} |
|
prob_random_action_in_bot=0., |
|
check_action_to_connect4_in_bot_v0=False, |
|
) |
|
env = GomokuEnv(EasyDict(cfg)) |
|
player_0 = AlphaBetaPruningBot(GomokuEnv, cfg, 'player 1') # player_index = 0, player = 1 |
|
player_1 = AlphaBetaPruningBot(GomokuEnv, cfg, 'player 2') # player_index = 1, player = 2 |
|
|
|
### test from the init empty board ### |
|
player_index = 0 # player 1 fist |
|
env.reset() |
|
|
|
### test from the init specified board ### |
|
# player_index = 1 # player 2 fist |
|
# init_state = [[1, 1, 1, 1, 0], |
|
# [1, 0, 0, 0, 2], |
|
# [0, 0, 2, 0, 2], |
|
# [0, 2, 0, 0, 2], |
|
# [2, 1, 1, 0, 0], ] |
|
# # init_state = [[1, 1, 1, 1, 2], |
|
# # [1, 1, 2, 1, 2], |
|
# # [2, 1, 2, 2, 2], |
|
# # [0, 0, 0, 2, 2], |
|
# # [2, 1, 1, 1, 0], ] |
|
# env.reset(player_index, init_state) |
|
|
|
state = env.board |
|
print('-' * 15) |
|
print(state) |
|
|
|
while not env.get_done_reward()[0]: |
|
if player_index == 0: |
|
start = time.time() |
|
action = player_0.get_best_action(state, player_index=player_index) |
|
print('player 1 action time: ', time.time() - start) |
|
player_index = 1 |
|
else: |
|
start = time.time() |
|
action = player_1.get_best_action(state, player_index=player_index) |
|
print('player 2 action time: ', time.time() - start) |
|
player_index = 0 |
|
env.step(action) |
|
state = env.board |
|
print('-' * 15) |
|
print(state) |
|
|
|
assert env.get_done_winner()[0] is False, env.get_done_winner()[1] == -1 |
|
# assert env.get_done_winner()[0] is True, env.get_done_winner()[1] == 2 |
|
""" |
|
|
|
|