viz-gpt2-stockfish-debug / src /play_interface.py
Xmaster6y's picture
attention interface
55ecc31 unverified
raw
history blame
6.12 kB
"""Interface to play against the model.
"""
import os
import random
import uuid
from typing import Optional

import chess
import chess.svg
import gradio as gr
import huggingface_hub
import wandb

from . import constants
# Hugging Face Hub repository id of the model the user plays against.
model_name = "yp-edu/gpt2-stockfish-debug"
# Extra HTTP headers sent with every inference request. Going by their names,
# they ask the Inference API to wait for the model to load rather than fail,
# and to skip cached responses so repeated prompts are re-sampled
# (TODO confirm against the Inference API documentation).
headers = {
    "X-Wait-For-Model": "true",
    "X-Use-Cache": "false",
}
# Module-level client, created once at import time and shared by every caller
# of `inference_fn` below.
client = huggingface_hub.InferenceClient(model=model_name, headers=headers)
# Text-generation callable used by `play_ai_move` to ask the model for a move.
inference_fn = client.text_generation
def plot_board(
    board: chess.Board,
    orientation: Optional[bool] = None,
) -> str:
    """Render ``board`` to a new SVG file and return the file's path.

    The last move (if any) is drawn as an arrow, and the king of the side to
    move is highlighted when it is in check.

    Args:
        board: Position to render.
        orientation: Side shown at the bottom of the image (``True``/white or
            ``False``/black). Defaults to the side to move.

    Returns:
        Path of the written SVG file inside ``constants.FIGURE_DIRECTORY``.
    """
    if orientation is None:
        orientation = board.turn

    try:
        last_move = board.peek()
    except IndexError:  # empty move stack: no last move to highlight
        arrows = []
    else:
        arrows = [(last_move.from_square, last_move.to_square)]

    check = board.king(board.turn) if board.is_check() else None

    svg_board = chess.svg.board(
        board,
        orientation=orientation,
        check=check,
        size=350,
        arrows=arrows,
    )

    # A fresh UUID per render means one render never overwrites another.
    # NOTE(review): files are never deleted, so the directory grows by one
    # file per rendered position.
    os.makedirs(constants.FIGURE_DIRECTORY, exist_ok=True)
    path = os.path.join(constants.FIGURE_DIRECTORY, f"board_{uuid.uuid4()}.svg")
    with open(path, "w", encoding="utf-8") as f:
        f.write(svg_board)
    return path
def render_board(
    current_board: chess.Board,
    orientation: Optional[bool] = None,
):
    """Compute the UI values that describe ``current_board``.

    Returns a 4-tuple, in this order:
    - the position's FEN;
    - the SAN move list played from the root position;
    - an empty string, which clears the move input box;
    - the path of a freshly rendered board image (see ``plot_board``).
    """
    move_list = current_board.root().variation_san(current_board.move_stack)
    return (
        current_board.fen(),
        move_list,
        "",
        plot_board(current_board, orientation=orientation),
    )
def play_user_move(
    uci_move: str,
    current_board: chess.Board,
):
    """Play the user's move, given in UCI notation, on ``current_board``.

    The board is mutated in place and returned for convenience. A malformed
    or illegal move raises ``ValueError`` (or a python-chess subclass of it)
    while parsing, before anything is pushed.
    """
    move = current_board.parse_uci(uci_move)
    current_board.push(move)
    return current_board
def play_ai_move(
    current_board: chess.Board,
    temperature: float = 0.1,
):
    """Ask the model for a move in ``current_board`` and play it in place.

    The prompt is ``FEN: <fen>`` followed by a newline and ``MOVE:``. The
    generated text, stripped of surrounding whitespace, is parsed as a UCI
    move.

    Args:
        current_board: Position to move in; mutated by pushing the move.
        temperature: Sampling temperature forwarded to text generation.

    Returns:
        The same board object, now including the model's move.

    Raises:
        ValueError: If the generated text is not a legal UCI move here.
        Exception: Any error from the inference client propagates unchanged.
    """
    position = current_board.fen()
    completion = inference_fn(
        prompt=f"FEN: {position}\nMOVE:",
        temperature=temperature,
    )
    move = current_board.parse_uci(completion.strip())
    current_board.push(move)
    return current_board
def _record_win(username: str, board: chess.Board) -> None:
    """Log a game the user has just won to the W&B leaderboard project.

    Logging is best-effort. Any failure is shown to the user as a warning
    instead of being raised, so a W&B problem cannot break the game or be
    misreported as an invalid move.
    """
    try:
        # Leaving the ``with`` block finishes the run; no explicit
        # ``run.finish()`` is needed.
        with wandb.init(project="gpt2-stockfish-debug", entity="yp-edu") as run:
            run.log(
                {
                    "username": username,
                    "winin": board.fullmove_number,
                    "pgn": board.root().variation_san(board.move_stack),
                }
            )
    except Exception as err:  # best-effort logging, see docstring
        gr.Warning(f"Could not record the win on the leaderboard: {err}")


def _play_ai_reply(board: chess.Board) -> chess.Board:
    """Play the model's reply on ``board`` in place and return the board.

    The model is queried at temperatures 0.1, 0.2, ..., 1.0 in turn until it
    yields a legal move. If every attempt fails, a random legal move is played
    instead. Call this only when the game is not over, otherwise there is no
    legal move to fall back on.
    """
    for temperature in [(i + 1) / 10 for i in range(10)]:
        try:
            return play_ai_move(board, temperature=temperature)
        except Exception:  # inference errors or an invalid/illegal UCI string
            gr.Warning(f"AI move failed with temperature {temperature}")
    gr.Warning("AI move failed with all temperatures")
    random_move = random.choice(list(board.legal_moves))
    gr.Warning(f"Playing random move {random_move}")
    board.push(random_move)
    return board


def try_play_move(
    username: str,
    move_to_play: str,
    current_board: chess.Board,
):
    """Handle one user turn: play the user's move, then the AI's reply.

    Args:
        username: Name recorded on the leaderboard if the user wins.
        move_to_play: The user's move in UCI notation; surrounding whitespace
            is ignored.
        current_board: The session's board; mutated in place.

    Returns:
        The four values of ``render_board`` followed by the updated board,
        i.e. ``(fen, pgn, "", image_path, board)``.
    """
    if current_board.is_game_over():
        gr.Warning("The game is already over")
        return (
            *render_board(current_board, orientation=not current_board.turn),
            current_board,
        )

    # Guard only the move itself. python-chess raises ValueError (or a
    # subclass) for a malformed or illegal UCI move, before pushing anything.
    try:
        current_board = play_user_move((move_to_play or "").strip(), current_board)
    except ValueError:
        gr.Warning("Invalid move")
        return *render_board(current_board), current_board

    if current_board.is_game_over():
        # The user's move ended the game. Only checkmate is a win; stalemate
        # and the automatic draws must not reach the leaderboard.
        if current_board.is_checkmate():
            gr.Info(f"Congratulations, {username}!")
            _record_win(username, current_board)
        else:
            gr.Info("The game is over: it's a draw.")
        # It is now the AI's turn, so flip the board to keep the user's pieces
        # at the bottom.
        return (
            *render_board(current_board, orientation=not current_board.turn),
            current_board,
        )

    current_board = _play_ai_reply(current_board)
    return *render_board(current_board), current_board
def _play_ai_opening(board: chess.Board) -> chess.Board:
    """Play the AI's opening move (as white) on ``board``, never raising.

    A single model call is made at the default temperature. If it fails,
    either through an inference error or an illegal generated move, a random
    legal move is played instead. This keeps app startup and "Reset" working
    when the inference endpoint is unavailable.
    """
    try:
        return play_ai_move(board)
    except Exception as err:  # best-effort: fall back to a random opening
        random_move = random.choice(list(board.legal_moves))
        gr.Warning(f"AI move failed ({err}); playing random move {random_move}")
        board.push(random_move)
        return board


with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            username = gr.Textbox(
                label="Username to record on leaderboard (should you win)",
                lines=1,
                max_lines=1,
                value="",
            )
            leaderboard_md = gr.Markdown(
                label="Leaderboard",
                value="See the leaderboard [here](https://wandb.ai/yp-edu/gpt2-stockfish-debug/reports/Leaderboard--Vmlldzo2OTU0NDc2?accessToken=xito8t675j3e55owwer09hp3kk9emdg8620kesufhbng0ap4uodlulrny0t0o15n).",
            )
            current_fen = gr.Textbox(
                label="Board FEN",
                lines=1,
                max_lines=1,
                value=chess.STARTING_FEN,
            )
            current_pgn = gr.Textbox(
                label="Action sequence",
                lines=1,
                value="",
            )
            with gr.Row():
                move_to_play = gr.Textbox(
                    label="Move to play (UCI)",
                    lines=1,
                    max_lines=1,
                    value="",
                )
                play_button = gr.Button("Play")
                reset_button = gr.Button("Reset")
        with gr.Column():
            image_board = gr.Image(label="Board")

    static_inputs = [
        username,
        move_to_play,
    ]
    # Order must match the 4-tuple returned by render_board.
    static_outputs = [
        current_fen,
        current_pgn,
        move_to_play,
        image_board,
    ]

    # The AI plays white in a randomly chosen half of the games.
    # NOTE(review): this initial board is built once at import time, so every
    # new session starts from it until the user clicks "Reset".
    is_ai_white = random.choice([True, False])
    init_board = chess.Board()
    if is_ai_white:
        init_board = _play_ai_opening(init_board)
    state_board = gr.State(value=init_board)

    # Clicking "Play" and pressing Enter in the move box do the same thing.
    for trigger in (play_button.click, move_to_play.submit):
        trigger(
            try_play_move,
            inputs=[*static_inputs, state_board],
            outputs=[*static_outputs, state_board],
        )

    def reset_board():
        """Start a new game, re-drawing which side the AI plays."""
        board = chess.Board()
        is_ai_white = random.choice([True, False])
        if is_ai_white:
            board = _play_ai_opening(board)
        return *render_board(board), board

    reset_button.click(
        reset_board,
        outputs=[*static_outputs, state_board],
    )
    interface.load(render_board, inputs=[state_board], outputs=[*static_outputs])