philipp-zettl commited on
Commit
49b8bf0
1 Parent(s): d41eec1

improve UI

Browse files
Files changed (2) hide show
  1. app.py +15 -2
  2. invalid_move.png +0 -0
app.py CHANGED
@@ -9,6 +9,7 @@ import chess.pgn
9
  from svglib.svglib import svg2rlg
10
  from reportlab.graphics import renderPM
11
  from PIL import Image
 
12
 
13
 
14
  vocab_size=33
@@ -30,18 +31,30 @@ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
30
  model.to(device)
31
  tokenizer = Tokenizer.from_pretrained(tokenizer_path)
32
 
 
 
33
  def generate(prompt):
34
  model_input = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).view((1, len(prompt)))
35
  pgn = tokenizer.decode(model.generate(model_input, max_new_tokens=4, context_size=context_size)[0].tolist())
36
  pgn_str = StringIO(pgn)
37
- game = chess.pgn.read_game(pgn_str)
38
- img = chess.svg.board(game.board())
 
 
 
 
 
 
 
39
  filename = f'./moves-{pgn}'
40
  with open(filename + '.svg', 'w') as f:
41
  f.write(img)
42
  drawing = svg2rlg(filename + '.svg')
43
  renderPM.drawToFile(drawing, f"{filename}.png", fmt="PNG")
44
  plot = Image.open(f'{filename}.png')
 
 
 
45
  return pgn, plot
46
 
47
 
 
9
  from svglib.svglib import svg2rlg
10
  from reportlab.graphics import renderPM
11
  from PIL import Image
12
+ import os
13
 
14
 
15
  vocab_size=33
 
31
  model.to(device)
32
  tokenizer = Tokenizer.from_pretrained(tokenizer_path)
33
 
34
+ invalid_move_plot = Image.open('./invalid_move.png')
35
+
36
  def generate(prompt):
37
  model_input = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).view((1, len(prompt)))
38
  pgn = tokenizer.decode(model.generate(model_input, max_new_tokens=4, context_size=context_size)[0].tolist())
39
  pgn_str = StringIO(pgn)
40
+ try:
41
+ game = chess.pgn.read_game(pgn_str)
42
+ board = game.board()
43
+ for move in game.mainline_moves():
44
+ board.push(move)
45
+ img = chess.svg.board(board)
46
+ except Exception as e:
47
+ if 'illegal san' in str(e):
48
+ return pgn, invalid_move_plot
49
  filename = f'./moves-{pgn}'
50
  with open(filename + '.svg', 'w') as f:
51
  f.write(img)
52
  drawing = svg2rlg(filename + '.svg')
53
  renderPM.drawToFile(drawing, f"{filename}.png", fmt="PNG")
54
  plot = Image.open(f'{filename}.png')
55
+
56
+ os.remove(f'{filename}.png')
57
+ os.remove(f'{filename}.svg')
58
  return pgn, plot
59
 
60
 
invalid_move.png ADDED