Spaces:
Sleeping
Sleeping
File size: 1,432 Bytes
55ecc31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
"""Global state of the app.
"""
import re
from transformers import AutoConfig
import torch
from nnsight import LanguageModel
# Hugging Face Hub id shared by the config and the wrapped language model.
_CHECKPOINT = "yp-edu/gpt2-stockfish-debug"

# Model configuration for the checkpoint (loaded from the Hub / local cache).
conf = AutoConfig.from_pretrained(_CHECKPOINT)

# nnsight wrapper around the same checkpoint; switched to eval mode for inference.
model = LanguageModel(_CHECKPOINT)
model.eval()
def make_prompt(fen):
    """Build a character-spaced model prompt from a FEN string.

    The board field has every digit (a count of consecutive empty squares)
    replaced by that many ``"0"`` characters. Then each character of the board
    field and of the castling field is separated by a single space. The side
    to move is kept as-is, and any remaining FEN fields are appended
    space-joined.

    Args:
        fen: Whitespace-separated FEN string with at least three fields.

    Returns:
        ``"FEN: <spaced fields> \\nMOVE:"``. Note the space before the newline.

    Raises:
        ValueError: If ``fen`` has fewer than three whitespace-separated fields.
    """
    def _expand_empty_run(match):
        # Spell a digit run-length of empty squares as that many "0" characters.
        return "0" * int(match.group(1))

    board_field, side_to_move, castling_field, *tail = fen.split()
    expanded_board = re.sub(r"(\d)", _expand_empty_run, board_field)
    pieces = [
        " ".join(expanded_board),
        side_to_move,
        " ".join(castling_field),
        " ".join(tail),
    ]
    return "FEN: " + " ".join(pieces) + " \nMOVE:"
def model_cache(fen, max_new_tokens=10, n_layers=12):
    """Run generation on a FEN prompt and capture per-layer attention outputs.

    The prompt is ``"FEN: <fen>\\nMOVE:"`` built from the raw ``fen`` string;
    ``make_prompt`` is not applied here. Inside the nnsight generation trace,
    each of the ``max_new_tokens`` steps does two things:

    1. Saves the element at index 2 of every recorded layer's ``attn`` output.
    2. Advances the tracer with ``tracer.next()``.

    Args:
        fen: FEN string inserted verbatim into the prompt.
        max_new_tokens: Number of tokens to generate. It is also the number of
            steps for which attention outputs are collected; using a single
            value keeps the two counts from drifting apart.
        n_layers: Number of blocks of ``model.transformer.h`` to record,
            starting at index 0. NOTE(review): presumably equals
            ``conf.n_layer`` for this checkpoint; confirm before changing the
            default.

    Returns:
        A tuple ``(out, real_attentions)``:

        - ``out`` is the saved ``model.generator.output`` proxy.
        - ``real_attentions`` maps each layer index to its list of saved
          attention proxies. Each list is truncated at the first proxy whose
          ``.shape`` access raises ``ValueError``.
    """
    prompt = f"FEN: {fen}\nMOVE:"
    attentions = {layer: [] for layer in range(n_layers)}
    with model.generate(prompt, max_new_tokens=max_new_tokens, output_attentions=True) as tracer:
        out = model.generator.output.save()
        # One iteration per generation step; the step index itself is unused.
        for _step in range(max_new_tokens):
            for layer in range(n_layers):
                # attn.output[2]: presumably the attention weights returned
                # because output_attentions=True. TODO confirm against the
                # installed transformers version.
                attentions[layer].append(model.transformer.h[layer].attn.output[2].save())
            tracer.next()
    real_attentions = {layer: _populated_prefix(saved) for layer, saved in attentions.items()}
    return out, real_attentions


def _populated_prefix(proxies):
    """Return the leading proxies whose ``.shape`` can be read.

    Iteration stops at the first proxy whose ``.shape`` access raises
    ``ValueError``. That proxy and every later one are dropped.
    NOTE(review): presumably nnsight raises this for values never populated,
    e.g. when generation ended before ``max_new_tokens`` steps. Confirm.
    """
    populated = []
    for proxy in proxies:
        try:
            _ = proxy.shape
        except ValueError:
            break
        populated.append(proxy)
    return populated
def attribute_seqence(fen, out, attn_tensor):
global model
out_str = model.tokenizer.batch_decode(out)[0]
|