# chessllm / app.py
# Source: Hugging Face file viewer (last commit "Update app.py", 5091a94, by mlabonne; 16.6 kB).
import io
import os
import re
import time
from collections import defaultdict
from datetime import datetime
import cairosvg
import chess
import chess.svg
import gistyc
import numpy as np
import outlines.models as models
from outlines import generate
import pandas as pd
import requests
from tqdm.auto import tqdm
from PIL import Image as PILImage
import gradio as gr
def generate_regex(board):
    """
    Build a regex alternation that matches every legal move on ``board``.

    Each legal move is rendered in SAN, its check ("+") and mate ("#")
    suffixes are dropped, and the remaining text is regex-escaped before
    the alternatives are joined with "|".
    Based on https://gist.github.com/903124/cfbefa24da95e2316e0d5e8ef8ed360d by @903124S.
    """
    # Materialise the move list before asking the board for SAN strings.
    candidates = list(board.legal_moves)
    alternatives = []
    for candidate in candidates:
        san = board.san(candidate)
        bare = san.replace("+", "").replace("#", "")
        alternatives.append(re.escape(bare))
    return "|".join(alternatives)
def write_pgn(
    pgn_moves, model_id_white, model_id_black, result, time_budget, termination
):
    """
    Assemble the PGN text for a finished game.

    :param pgn_moves: Movetext accumulated during the game
    :param model_id_white: Identifier written to the White tag
    :param model_id_black: Identifier written to the Black tag
    :param result: Result string written to the Result tag (e.g. "1-0")
    :param time_budget: Per-side budget; written as "<time_budget>+0"
    :param termination: Human-readable reason written to the Termination tag
    :return: The PGN document as a single string ending in a newline
    """

    def tag(name, value):
        # PGN tag values are double-quoted strings; embedded backslashes and
        # quotes must be backslash-escaped.
        text = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'[{name} "{text}"]'

    # Take one UTC snapshot so the date and time tags cannot disagree.
    now_utc = time.gmtime()
    utc_date = time.strftime("%Y.%m.%d", now_utc)
    utc_time = time.strftime("%H:%M:%S", now_utc)

    tags = [
        tag("Event", "Chess LLM Arena"),
        tag("Site", "https://github.com/mlabonne/chessllm"),
        tag("Date", utc_date),
        tag("White", model_id_white),
        tag("Black", model_id_black),
        tag("Result", result),
        tag("Time", utc_time),
        tag("TimeControl", f"{time_budget}+0"),
        tag("Termination", termination),
    ]
    # An empty line separates the tag section from the movetext.
    return "\n".join(tags) + "\n\n" + f"{pgn_moves}\n"
def determine_termination(board, time_budget_white, time_budget_black):
    """
    Describe why the game stopped.

    Board conditions are probed in a fixed order and the first one that holds
    wins; only if none applies are the remaining clock budgets consulted.

    :param board: Board object exposing the predicate methods listed below
    :param time_budget_white: Seconds left for White
    :param time_budget_black: Seconds left for Black
    :return: A termination label, "Timeout", or "Unknown"
    """
    board_checks = (
        ("is_checkmate", "Checkmate"),
        ("is_stalemate", "Stalemate"),
        ("is_insufficient_material", "Draw due to insufficient material"),
        ("can_claim_threefold_repetition", "Draw by threefold repetition"),
        ("can_claim_fifty_moves", "Draw by fifty-move rule"),
    )
    for predicate_name, label in board_checks:
        # Resolve and call lazily so later predicates are never evaluated
        # once an earlier one matches.
        if getattr(board, predicate_name)():
            return label

    out_of_time = time_budget_white <= 0 or time_budget_black <= 0
    return "Timeout" if out_of_time else "Unknown"
def format_elapsed(seconds):
    """Render a duration as hh:mm:ss, mm:ss or ss, omitting leading zero fields."""
    whole_seconds = int(seconds)
    hrs, leftover = divmod(whole_seconds, 3600)
    mins, secs = divmod(leftover, 60)

    # Keep only the fields from the first non-zero unit onwards.
    if hrs:
        fields = (hrs, mins, secs)
    elif mins:
        fields = (mins, secs)
    else:
        fields = (secs,)
    return ":".join(f"{field:02d}" for field in fields)
def create_gif(image_list, gif_path, duration):
    """
    Write the frames in ``image_list`` to ``gif_path`` as a looping GIF.

    :param image_list: Arrays accepted by ``PILImage.fromarray``, one per frame
    :param gif_path: Destination path of the GIF
    :param duration: Value forwarded as the ``duration`` argument of ``save``
    """
    # Turn every array back into a PIL image before saving.
    frames = list(map(PILImage.fromarray, image_list))
    first_frame = frames[0]
    remaining_frames = frames[1:]
    # The first frame drives the save; the rest are appended. loop=0 is
    # forwarded unchanged to PIL.
    first_frame.save(
        gif_path,
        save_all=True,
        append_images=remaining_frames,
        duration=duration,
        loop=0,
    )
def render_init(board):
    """
    Render ``board`` as an RGB PIL image on a white background.

    The board is drawn to SVG, rasterised to PNG with cairosvg, and pasted at
    the top-left corner of a fresh white canvas of the same size.
    """
    svg_bytes = chess.svg.board(board=board).encode("utf-8")
    png_bytes = cairosvg.svg2png(bytestring=svg_bytes)
    rendered = PILImage.open(io.BytesIO(png_bytes))

    # The canvas has exactly the rendered size, so the board fills it.
    canvas = PILImage.new('RGB', rendered.size, 'white')
    canvas.paste(rendered, (0, 0))
    return canvas
def render_new(board):
    """
    Render ``board`` with an arrow marking the most recent move.

    The last move is taken from ``board.peek()``, so the board must have at
    least one move on its stack. The output is an RGB PIL image on a white
    canvas of the same size as the rendered board.
    """
    previous = board.peek()
    arrow = (previous.from_square, previous.to_square)
    svg_bytes = chess.svg.board(board=board, arrows=[arrow]).encode("utf-8")
    png_bytes = cairosvg.svg2png(bytestring=svg_bytes)
    rendered = PILImage.open(io.BytesIO(png_bytes))

    # The canvas has exactly the rendered size, so the board fills it.
    canvas = PILImage.new('RGB', rendered.size, 'white')
    canvas.paste(rendered, (0, 0))
    return canvas
def save_result_file(
    pgn_id, model_id_white, model_id_black, termination, result, auth_token, gist_id
):
    """
    Append one game record to the local results CSV and push it to a gist.

    :param pgn_id: Identifier of the gist that holds the game's PGN
    :param model_id_white: Model that played White
    :param model_id_black: Model that played Black
    :param termination: Human-readable reason the game ended
    :param result: Result string (e.g. "1-0")
    :param auth_token: GitHub token used to authenticate the gist update
    :param gist_id: Gist that stores ``chessllm_results.csv``
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # One comma-separated row per game, in the column order of the results file.
    data_str = f"{pgn_id},{timestamp},{model_id_white},{model_id_black},{termination},{result}\n"
    with open("chessllm_results.csv", "a") as file:
        file.write(data_str)

    # Authenticate with the token the caller supplied rather than reaching for
    # the module-level global, so the parameter is actually honoured.
    gist_api = gistyc.GISTyc(auth_token=auth_token)
    gist_api.update_gist(file_name="chessllm_results.csv", gist_id=gist_id)
def save_pgn(final_pgn, file_name, auth_token):
    """
    Write the PGN to ``<file_name>.pgn`` and publish it as a new gist.

    :param final_pgn: Complete PGN text
    :param file_name: Base name of the local file, without the ".pgn" suffix
    :param auth_token: GitHub token used to authenticate the gist creation
    :return: The ``"id"`` field of the gist-creation response
    """
    pgn_path = file_name + ".pgn"
    with open(pgn_path, "w") as file:
        file.write(final_pgn)

    # Authenticate with the token the caller supplied rather than reaching for
    # the module-level global, so the parameter is actually honoured.
    gist_api = gistyc.GISTyc(auth_token=auth_token)
    response_data = gist_api.create_gist(file_name=pgn_path)
    print(response_data)
    return response_data["id"]
def download_file(base_url, file_name, timeout=30):
    """
    Download ``base_url`` to ``file_name``, bypassing caches.

    A timestamp query parameter and no-cache headers are added to the request.
    On any status other than 200 a message is printed and no file is written.

    :param base_url: URL to fetch (a ``?ts=...`` suffix is appended)
    :param file_name: Local path the response body is written to
    :param timeout: Seconds to wait for the server before giving up; without
        it a stalled connection would block the caller indefinitely
    :raises requests.exceptions.RequestException: on connection errors or timeout
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    url = f"{base_url}?ts={timestamp}"
    headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }
    response = requests.get(url, headers=headers, timeout=timeout)
    if response.status_code != 200:
        print(f"Failed to download file. HTTP status code: {response.status_code}")
        return
    with open(file_name, "wb") as file:
        file.write(response.content)
def get_leaderboard():
    """
    Fetch the leaderboard gist and return it as a DataFrame.

    The raw gist is saved to ``chessllm_leaderboard.csv`` first and then
    parsed from that local file.
    """
    local_csv = "chessllm_leaderboard.csv"
    raw_url = f"https://gist.githubusercontent.com/chessllm/{LEAD_GIST_ID}/raw"
    download_file(raw_url, local_csv)
    leaderboard_frame = pd.read_csv(local_csv)
    return leaderboard_frame
def calculate_elo(rank1, rank2, result):
    """
    Compute player 1's rating after a single game (K-factor 32).

    :param rank1: Player 1's rating before the game
    :param rank2: Player 2's rating before the game
    :param result: Player 1's score: 1 for a win, 0 for a loss, 0.5 for a draw
    :return: Player 1's new rating, rounded with ``round``
    """
    k_factor = 32
    rating_gap = rank2 - rank1
    # Logistic expectation of player 1's score given the rating gap.
    expected = 1 / (1 + 10 ** (rating_gap / 400))
    adjustment = k_factor * (result - expected)
    return round(rank1 + adjustment)
def update_elo_ratings(chess_data):
    """
    Replay every recorded game and return the resulting ELO table.

    :param chess_data: DataFrame with "Model1", "Model2" and "Result" columns
    :return: A defaultdict mapping each model to its rating (1000 for new players)
    """
    # (score for Model1, score for Model2) for each decisive result string.
    scores_by_result = {
        "1-0": (1, 0),
        "0-1": (0, 1),
        "1/2-1/2": (0.5, 0.5),
    }
    ratings = defaultdict(lambda: 1000)

    for _, game in chess_data.iterrows():
        outcome = game["Result"]
        if outcome == "*":
            # Unfinished games do not touch the table at all.
            continue

        first, second = game["Model1"], game["Model2"]
        # Reading through the defaultdict registers both players at 1000,
        # even when the result string turns out to be unrecognised.
        first_before = ratings[first]
        second_before = ratings[second]

        scores = scores_by_result.get(outcome)
        if scores is None:
            continue
        first_score, second_score = scores
        # Both updates use the pre-game ratings.
        ratings[first] = calculate_elo(first_before, second_before, first_score)
        ratings[second] = calculate_elo(second_before, first_before, second_score)

    return ratings
def _lookup_elo(leaderboard, model_id, default=1000):
    """Return the leaderboard rating stored for ``model_id``, or ``default`` if it is not listed."""
    ratings = leaderboard.loc[leaderboard["Model"] == model_id, "ELO Rating"]
    # Select by position: the filtered Series keeps the leaderboard's row
    # labels, so a label lookup of 0 only matches the model in the first row.
    return default if ratings.empty else ratings.iloc[0]


def _elo_change_message(
    result, termination, model_id_white, model_id_black, elo_white, elo_black
):
    """Build the result/ELO summary for a decided game, or return None for any other result."""
    outcomes = {
        "1-0": (f"{model_id_white} wins!", 1, 0),
        "0-1": (f"{model_id_black} wins!", 0, 1),
        "1/2-1/2": ("Draw!", 0.5, 0.5),
    }
    if result not in outcomes:
        return None
    headline, score_white, score_black = outcomes[result]
    new_elo_white = calculate_elo(elo_white, elo_black, score_white)
    new_elo_black = calculate_elo(elo_black, elo_white, score_black)
    return (
        f"{headline} ({termination})\n"
        "ELO change:\n"
        f"* {model_id_white}: {elo_white} -> {new_elo_white} ({new_elo_white - elo_white:+})\n"
        f"* {model_id_black}: {elo_black} -> {new_elo_black} ({new_elo_black - elo_black:+})\n"
    )


def update(model_id_white, model_id_black):
    """
    Play one game between two models and stream the board to the UI.

    This is a generator: it yields the starting position, then one board image
    after every move, and finally the path of a GIF replay of the game. Once
    the game stops it also saves the PGN as a gist, appends the outcome to the
    results CSV, rebuilds the leaderboard from that CSV and uploads it.

    :param model_id_white: Model identifier for White, loaded with ``models.transformers``
    :param model_id_black: Model identifier for Black, loaded with ``models.transformers``
    """
    model_white = models.transformers(model_id_white)
    model_black = models.transformers(model_id_black)
    TIME_BUDGET = 180
    # NOTE(review): the prompt is never extended with the moves played so far,
    # so every generation is conditioned on "1." only — confirm this is intended.
    prompt = '1.'

    # Initialize the chess board
    board = chess.Board()
    board_images = []
    pgn_moves = ""
    move_number = 1
    result = None

    # Render first image
    image = render_init(board)
    yield image

    # Each side starts with the same time budget (in seconds).
    time_budget_white = TIME_BUDGET
    time_budget_black = TIME_BUDGET
    bar_format = "{desc} {n:.0f} seconds left | Elapsed: {elapsed}"
    white_bar = tqdm(total=time_budget_white, desc=f"{model_id_white.split('/')[-1]}:", bar_format=bar_format, colour='white')
    black_bar = tqdm(total=time_budget_black, desc=f"{model_id_black.split('/')[-1]}:", bar_format=bar_format, colour='black')

    # Snapshot of the leaderboard before the game, used for the ELO summary.
    elo_ratings_df = pd.read_csv("chessllm_leaderboard.csv")

    # Game loop
    while not board.is_game_over():
        current_model = model_white if board.turn == chess.WHITE else model_black

        # Constrain generation to the legal moves of the current position.
        regex_pattern = generate_regex(board)

        # Only the generation call is charged against the mover's clock.
        start_time = time.time()
        guided = generate.regex(current_model, regex_pattern)(prompt)
        end_time = time.time()
        move_duration = end_time - start_time

        try:
            # Parse move
            move_san = guided.strip()
            move = board.parse_san(move_san)
            if move not in board.legal_moves:
                print(f"Illegal move: {move_san}")
                break
            board.push(move)

            # After the push the side to move has flipped: Black to move
            # means White just played, which opens a new move number.
            if board.turn == chess.BLACK:
                move_str = f"{move_number}. {move_san} "
                move_number += 1
            else:
                move_str = f"{move_san} "
            pgn_moves += move_str

            # Render the board to an image
            image = render_new(board)
            board_images.append(np.array(image))

            # Deduct the time taken from the side that just moved.
            if board.turn == chess.WHITE:
                time_budget_black -= move_duration
                black_bar.n = time_budget_black
                black_bar.set_postfix_str(f"{format_elapsed(black_bar.format_dict['elapsed'])} elapsed")
                black_bar.refresh()
                if time_budget_black <= 0:
                    result = "1-0"
                    break
            else:
                time_budget_white -= move_duration
                white_bar.n = time_budget_white
                white_bar.set_postfix_str(f"{format_elapsed(white_bar.format_dict['elapsed'])} elapsed")
                white_bar.refresh()
                if time_budget_white <= 0:
                    result = "0-1"
                    break

            # Display board
            yield image

        except ValueError:
            print(f"Invalid move: {guided}")
            break

    white_bar.close()
    black_bar.close()

    # No timeout was recorded, so take the result from the board.
    if result is None:
        result = board.result()

    # Create PGN
    termination = determine_termination(board, time_budget_white, time_budget_black)
    final_pgn = write_pgn(
        pgn_moves, model_id_white, model_id_black, result, TIME_BUDGET, termination
    )
    file_name = f"{model_id_white.split('/')[-1]}_vs_{model_id_black.split('/')[-1]}"
    pgn_id = save_pgn(final_pgn, file_name, GITHUB_TOKEN)

    # Save results
    save_result_file(
        pgn_id, model_id_white, model_id_black, termination, result, GITHUB_TOKEN, RESULT_GIST_ID
    )

    # Create the GIF replay
    create_gif(board_images, file_name + ".gif", duration=400)

    # Report the ELO change relative to the pre-game leaderboard.
    current_elo_white = _lookup_elo(elo_ratings_df, model_id_white)
    current_elo_black = _lookup_elo(elo_ratings_df, model_id_black)
    update_str = _elo_change_message(
        result, termination, model_id_white, model_id_black,
        current_elo_white, current_elo_black,
    )
    if update_str is not None:
        gr.Info(update_str)
        print(update_str)
    elif result == "*":
        print(f"Ongoing game! ({termination})")

    # Rebuild the whole leaderboard from the results file.
    chess_data = pd.read_csv('chessllm_results.csv')
    elo_ratings = update_elo_ratings(chess_data)

    # Convert the dictionary to a DataFrame for better display
    elo_ratings_df = pd.DataFrame(elo_ratings.items(), columns=['Model', 'ELO Rating'])
    elo_ratings_df['ELO Rating'] = elo_ratings_df['ELO Rating'].round().astype(int)
    elo_ratings_df.sort_values(by='ELO Rating', ascending=False, inplace=True)
    elo_ratings_df.reset_index(drop=True, inplace=True)
    elo_ratings_df.to_csv('chessllm_leaderboard.csv', index=False)

    # Upload chessllm_leaderboard.csv to the leaderboard gist. The ID is
    # passed as a plain string; wrapping it in braces would build a set.
    gist_api = gistyc.GISTyc(auth_token=GITHUB_TOKEN)
    gist_api.update_gist(file_name='chessllm_leaderboard.csv', gist_id=LEAD_GIST_ID)

    yield file_name + ".gif"
# --- Module-level setup (runs once at import time) ---
# NOTE(review): presumably disables tokenizers parallelism to avoid
# fork-related warnings — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Indexing os.environ raises KeyError when GITHUB_TOKEN is not set.
GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
# Gist holding chessllm_results.csv.
RESULT_GIST_ID = "c491299e7b8a45a61ce5403a70cf8656"
# Gist holding chessllm_leaderboard.csv.
LEAD_GIST_ID = "696115fe2df47fb2350fcff2663678c9"
# Download the results CSV and the leaderboard CSV into the working directory
url1 = (f"https://gist.githubusercontent.com/chessllm/{RESULT_GIST_ID}/raw")
download_file(url1, "chessllm_results.csv")
elo_ratings_df = get_leaderboard()
# Render the starting position; `image` is the initial value of the UI image
board = chess.Board()
image = render_init(board)
# HTML header rendered at the top of the page via gr.Markdown.
title = """
<div align="center">
<p style="font-size: 36px;">⚔️ Chess LLM Arena (preview)</p>
<p style="font-size: 20px;">💻 <a href="https://github.com/mlabonne/chessllm">GitHub</a> • 💾 <a href="https://gist.github.com/chessllm/696115fe2df47fb2350fcff2663678c9">Gist Database</a> • 🤖 <a href="https://colab.research.google.com/drive/1e2PszrvaY4Lv5SiRXuBGb5R4GdZsm-H3">Trainer</a> • 📁 <a href="https://colab.research.google.com/drive/11UjbfajCzphe707_V7PD-2e5WIzyintf">Dataset</a></p>
<p><em>Pick two chess LLMs and make them compete in a chess match. When the game is over, it will automatically update the crowd-sourced leaderboard. Build a dataset and train your own small language model to compete in the arena.</em></p>
</div>
"""
# HTML credits rendered at the bottom of the page via gr.Markdown.
footer = """
<p><em>Made by Maxime Labonne, Kostis Gourgoulias, and Ruchi Bahl.</em></p>
"""
# Page layout: header, model pickers, board image, leaderboard, footer.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown(title)

    # Model selection
    with gr.Row():
        model_id_white = gr.Textbox(
            label="♘ White Model ID",
            value="mlabonne/chesspythia-70m",
        )
        model_id_black = gr.Textbox(
            label="♞ Black Model ID",
            value="EleutherAI/pythia-70m-deduped",
        )
    btn = gr.Button("Fight!")

    # Board image with an HTML block on either side
    with gr.Row():
        gr.HTML("<div id='chessboard' style='width: 50%; display: block;'></div>")
        out = gr.Image(
            value=image,
            show_label=False,
            show_share_button=False,
            elem_id="chessboard",
        )
        gr.HTML("<div id='rightblock' style='width: 50%; display: block;'></div>")

    # `update` is a generator, so each yielded value replaces the image.
    btn.click(
        fn=update,
        inputs=[model_id_white, model_id_black],
        outputs=out,
    )

    gr.Markdown('<div align="center"><p style="font-size: 30px;">🏆 Leaderboard</p></div>')
    # The callable is re-evaluated on the interval given by `every`.
    leaderboard = gr.Dataframe(value=get_leaderboard, every=60)
    gr.Markdown(footer)

app_queue = demo.queue(api_open=False)
app_queue.launch(show_api=False)