# ChessPT / app.py
# Author: philipp-zettl (commit 961e9fd, "update app")
import gradio as gr
from io import StringIO
from model import DecoderTransformer, Tokenizer
from huggingface_hub import hf_hub_download
import torch
import chess
import chess.svg
import chess.pgn
from svglib.svglib import svg2rlg
from reportlab.graphics import renderPM
from PIL import Image
import os
from uuid import uuid4
# Architecture hyper-parameters passed to DecoderTransformer; presumably these
# must mirror the configuration the v0.5 checkpoint below was trained with.
vocab_size = 36
n_embed = 384
context_size = 256
n_layer = 6
n_head = 6
dropout = 0.2
device = 'cpu'

model_id = "philipp-zettl/chessPT"
model_path = hf_hub_download(repo_id=model_id, filename="chessPT-v0.5.pth")
tokenizer_path = hf_hub_download(repo_id=model_id, filename="tokenizer-v0.5.json")

model = DecoderTransformer(vocab_size, n_embed, context_size, n_layer, n_head, dropout)
# NOTE(review): torch.load unpickles the downloaded checkpoint; only load from a
# trusted repo (consider weights_only=True on torch versions that support it).
# map_location follows `device` so the two settings cannot drift apart.
model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))
model.to(device)
# This app only does inference: switch the model to eval mode so training-time
# behaviour (presumably nn.Dropout layers, configured with dropout=0.2 above) is
# disabled while generating moves.
model.eval()

tokenizer = Tokenizer.from_pretrained(tokenizer_path)
# Image shown in the UI when a submitted move cannot be applied.
invalid_move_plot = Image.open('./invalid_move.png')
def gen_image_from_svg(img, filename):
    """Rasterise an SVG string into a PIL image using temporary files.

    Writes ``img`` to ``<filename>.svg``, converts it to ``<filename>.png`` with
    svglib/reportlab and returns the PNG as an in-memory PIL image. Both
    temporary files are removed afterwards, even if rendering fails.

    Args:
        img: SVG markup (e.g. the string returned by ``chess.svg.board``).
        filename: Path prefix for the temporary files, without extension.
            Concurrent callers must use distinct prefixes.

    Returns:
        PIL.Image.Image: the rendered image, fully loaded into memory.
    """
    svg_path = f'{filename}.svg'
    png_path = f'{filename}.png'
    try:
        with open(svg_path, 'w') as f:
            f.write(img)
        drawing = svg2rlg(svg_path)
        renderPM.drawToFile(drawing, png_path, fmt="PNG")
        # Image.open is lazy: copy the pixel data into memory before the backing
        # file is deleted below, so the returned image does not depend on it.
        with Image.open(png_path) as rendered:
            plot = rendered.copy()
    finally:
        # Always clean up, including when writing or rendering raised part-way.
        for path in (svg_path, png_path):
            if os.path.exists(path):
                os.remove(path)
    return plot
def get_board(pgn):
    """Replay a PGN string and return the resulting position.

    Args:
        pgn: PGN movetext, e.g. ``"1. e4 e5 2. Nf3"``.

    Returns:
        chess.Board | None: the position after all mainline moves, or ``None``
        if the PGN is empty or contains a move that could not be parsed/played.
    """
    try:
        game = chess.pgn.read_game(StringIO(pgn))
    except ValueError:
        # Older python-chess versions raise on a bad SAN token
        # (e.g. "illegal san: ...").
        return None
    if game is None:
        # Nothing parseable in the input (e.g. an empty string); previously this
        # surfaced as an UnboundLocalError.
        return None
    # Newer python-chess versions do not raise on a bad move: they record the
    # error in ``game.errors`` and truncate the mainline instead.
    if getattr(game, 'errors', None):
        return None
    board = game.board()
    for move in game.mainline_moves():
        board.push(move)
    return board
def gen_board_image(pgn):
    """Render the position reached by replaying ``pgn`` and return it as SVG markup."""
    return chess.svg.board(get_board(pgn))
def gen_move(pgn, max_new_tokens=8, temperature=0.2, max_attempts=50):
    """Sample the model until it proposes a legal next move, and append it.

    Args:
        pgn: PGN movetext of the game so far. When it is White's turn it is
            expected to end with the move number, e.g. ``"1. e4 e5 2. "``.
        max_new_tokens: tokens sampled per attempt. It must be long enough to
            hold a complete SAN move (``"O-O-O"``, ``"exd8=Q+"``); otherwise such
            moves are truncated and can never be accepted.
        temperature: sampling temperature forwarded to ``model.generate``.
        max_attempts: number of samples to try before giving up.

    Returns:
        str: ``pgn`` with trailing whitespace removed, a space, and the move.

    Raises:
        ValueError: ``pgn`` itself cannot be replayed.
        RuntimeError: no legal move was sampled within ``max_attempts`` tries
            (e.g. the game is already over).
    """
    base_board = get_board(pgn)
    if base_board is None:
        raise ValueError(f'Cannot replay PGN {pgn!r}')
    base_plies = len(base_board.move_stack)

    # unsqueeze(0) builds a (1, n_tokens) batch from the real token count rather
    # than assuming one token per character.
    model_input = torch.tensor(tokenizer.encode(pgn), dtype=torch.long, device=device).unsqueeze(0)
    for _ in range(max_attempts):
        with torch.no_grad():
            output = model.generate(model_input, max_new_tokens=max_new_tokens,
                                    context_size=context_size, temperature=temperature)
        generated = tokenizer.decode(output[0].tolist())
        print(f'checking {generated}')
        # First whitespace-separated token after the prompt is the candidate.
        tokens = generated[len(pgn):].split()
        mv = tokens[0] if tokens else ''
        new_pgn = pgn.rstrip() + f' {mv}'
        try:
            new_board = get_board(new_pgn)
        except Exception:
            new_board = None
        # Accept only if exactly one half-move was added. An empty token or a
        # move number such as "2." replays fine but adds no move.
        if mv and new_board is not None and len(new_board.move_stack) == base_plies + 1:
            return new_pgn
        print(f'For {pgn} invalid "{mv}" {new_pgn}')
    raise RuntimeError(f'No legal move generated for {pgn!r} after {max_attempts} attempts')
def generate(prompt):
    """Continue a PGN prompt with the model and render the resulting board.

    Used by the "Next turn prediction" tab. The continuation is not validated;
    it is rendered as-is.

    Args:
        prompt: PGN movetext to continue.

    Returns:
        tuple[str, PIL.Image.Image]: the prompt plus the sampled continuation,
        and a rendered image of the position it describes.
    """
    # unsqueeze(0) gives a (1, n_tokens) batch from the real token count rather
    # than assuming one token per character.
    model_input = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
    with torch.no_grad():
        output = model.generate(model_input, max_new_tokens=4, context_size=context_size, temperature=0.2)
    pgn = tokenizer.decode(output[0].tolist())
    img = gen_board_image(pgn)
    # A unique temp-file prefix keeps concurrent requests from clobbering each other.
    filename = f'./moves-{uuid4()}'
    plot = gen_image_from_svg(img, filename)
    return pgn, plot
def _ply_count(pgn):
    """Return the number of half-moves in ``pgn``.

    Returns 0 for a blank string and ``None`` if the PGN cannot be replayed.
    """
    if not pgn.strip():
        return 0
    try:
        board = get_board(pgn)
    except Exception:
        # Depending on the python-chess version, malformed input may raise here
        # instead of producing None.
        return None
    return None if board is None else len(board.move_stack)


with gr.Blocks() as demo:
    gr.Markdown("""
# ChessPT
Welcome to ChessPT.
The **C**hess-**P**re-trained-**T**ransformer.
The rules are simple:
- "Interactive": Play a game against the model
- "Next turn prediction": provide a PGN string of your current game, the model will predict the next token
""")

    def manual():
        """Build the "Next turn prediction" tab: free-form PGN in, model continuation out."""
        with gr.Tab("Next turn prediction"):
            prompt = gr.Text(label="PGN")
            output = gr.Text(label="Next turn", interactive=False)
            img = gr.Image()
            submit = gr.Button("Submit")
            submit.click(generate, [prompt], [output, img])
            gr.Examples(
                [
                    ["1. e4", ],
                    ["1. e4 g6 2."],
                ],
                inputs=[prompt],
                outputs=[output, img],
                fn=generate
            )

    def interactive():
        """Build the "Interactive" tab, where the user plays a game against the model."""
        with gr.Tab("Interactive"):
            color = gr.Dropdown(["white", "black"], value='white', label="Choose a color")
            start_button = gr.Button("Start Game")

            def start_game(c):
                """Start a new game; if the player chose black, the model opens.

                Returns the board image, the PGN so far and the move counter (1).
                """
                pgn = '1. '
                if c == 'black':
                    pgn = gen_move(pgn)
                img = gen_board_image(pgn)
                # Unique temp-file prefix: a fixed name would race between sessions.
                return gen_image_from_svg(img, f'./start-{uuid4()}'), pgn, 1

            state = gr.Text(label='PGN', value='', interactive=False)
            game = gr.Image()
            move_counter = gr.State(value=1)
            start_button.click(
                start_game,
                inputs=[color],
                outputs=[game, state, move_counter]
            )
            next_move = gr.Text(label='Next move')
            gen_next_move_button = gr.Button("Submit")

            def gen_next_move(pgn, new_move, move_ctr, c):
                """Apply the player's move, let the model reply, and render the board.

                Returns the board image, the updated PGN and the move counter. If
                the player's move cannot be applied, returns ``invalid_move_plot``
                and leaves the PGN and counter unchanged. Otherwise an illegal
                move would make the model loop looking for a reply.
                """
                previous_pgn = pgn
                # When playing black the PGN ends in White's move ("1. e4"), so a
                # separating space is needed; on White's turn it already ends in "N. ".
                pgn += (' ' if c == 'black' else '') + new_move.strip() + ' '
                plies_before = _ply_count(previous_pgn)
                if plies_before is None or _ply_count(pgn) != plies_before + 1:
                    return invalid_move_plot, previous_pgn, move_ctr
                if c == 'black':
                    move_ctr += 1
                    pgn = f'{pgn.rstrip()} {move_ctr}. '
                print(f'gen for {pgn}')
                pgn = gen_move(pgn)
                print(f'got {pgn}')
                img = gen_board_image(pgn)
                if c == 'white':
                    move_ctr += 1
                    pgn = f'{pgn.rstrip()} {move_ctr}. '
                # Unique temp-file prefix: a fixed name would race between sessions.
                return gen_image_from_svg(img, f'./move-{uuid4()}'), pgn, move_ctr

            gen_next_move_button.click(
                gen_next_move,
                inputs=[state, next_move, move_counter, color],
                outputs=[game, state, move_counter]
            )

    interactive()
    manual()

demo.launch()