Xmaster6y's picture
attention interface
55ecc31 unverified
"""Global state of the app.
"""
import re
from transformers import AutoConfig
import torch
from nnsight import LanguageModel
# Hugging Face Hub repo id shared by the config and the traced model, kept in
# one place so the two cannot drift apart.
MODEL_NAME = "yp-edu/gpt2-stockfish-debug"

# Configuration of the model (loaded separately from the weights).
conf = AutoConfig.from_pretrained(MODEL_NAME)
# nnsight wrapper around the language model; module-level state shared by the
# functions in this file.
model = LanguageModel(MODEL_NAME)
# Put the model in evaluation mode before any tracing/generation.
model.eval()
def make_prompt(fen):
    """Build the model prompt for ``fen`` in the spaced-character format.

    Each digit in the piece-placement field is replaced by that many ``"0"``
    characters. The placement and castling fields are then written one
    character at a time, separated by spaces. Any fields after castling are
    appended, joined by single spaces.

    Args:
        fen: Whitespace-separated FEN string. At least three fields are
            required, otherwise the unpacking raises ``ValueError``.

    Returns:
        ``"FEN: <formatted fields> \\nMOVE:"`` (note the space before the
        newline, and an extra trailing space when only three fields are given).
    """
    placement, side_to_move, castling_rights, *tail = fen.split()

    def _zeros(match):
        # One "0" per empty square encoded by the digit.
        return "0" * int(match.group(1))

    expanded = re.sub(r"(\d)", _zeros, placement)
    fields = [
        " ".join(expanded),
        side_to_move,
        " ".join(castling_rights),
        " ".join(tail),
    ]
    return "FEN: " + " ".join(fields) + " \nMOVE:"
def _materialized_prefix(saved_proxies):
    """Return the leading run of saved proxies whose values were populated.

    The scan stops at the first proxy whose ``.shape`` access raises
    ``ValueError``. That proxy and every proxy after it are dropped.

    Args:
        saved_proxies: Proxies returned by ``.save()`` inside an nnsight
            trace, in generation-step order.

    Returns:
        A new list holding the proxies that precede the first unpopulated one.
    """
    ready = []
    for proxy in saved_proxies:
        try:
            # NOTE(review): a ValueError here is treated as "this generation
            # step never executed" (presumably generation ended early, e.g. on
            # EOS). Confirm against the installed nnsight version.
            _ = proxy.shape
        except ValueError:
            break
        ready.append(proxy)
    return ready


def model_cache(fen, max_new_tokens=10, n_layers=12):
    """Generate a continuation for ``fen`` and capture per-step attention.

    Runs ``model.generate`` with ``output_attentions=True``. At each generation
    step it saves element ``[2]`` of the attention module's output for the
    first ``n_layers`` transformer blocks.

    Args:
        fen: Text inserted verbatim after ``"FEN: "`` in the prompt.
            NOTE(review): unlike ``make_prompt``, no digit expansion or
            character spacing is applied here. Confirm which format the
            model expects.
        max_new_tokens: Number of tokens to generate. This is also the number
            of generation steps for which attention is captured.
        n_layers: Number of transformer blocks (starting at index 0) whose
            attention is captured.

    Returns:
        ``(out, real_attentions)``:

        - ``out`` is the saved generator output.
        - ``real_attentions`` maps each layer index to the list of saved
          attention proxies for the steps that actually ran.
    """
    prompt = f"FEN: {fen}\nMOVE:"
    attentions = {layer: [] for layer in range(n_layers)}
    with model.generate(
        prompt, max_new_tokens=max_new_tokens, output_attentions=True
    ) as tracer:
        out = model.generator.output.save()
        for _ in range(max_new_tokens):
            for layer in range(n_layers):
                # Assumption: index 2 of the attention module's output holds
                # the attention weights when output_attentions=True (GPT-2
                # attention return tuple). TODO confirm for the installed
                # transformers version.
                attentions[layer].append(
                    model.transformer.h[layer].attn.output[2].save()
                )
            tracer.next()
    # Drop trailing steps that were never populated (generation stopped early).
    real_attentions = {
        layer: _materialized_prefix(saved) for layer, saved in attentions.items()
    }
    return out, real_attentions
def attribute_seqence(fen, out, attn_tensor):
global model
out_str = model.tokenizer.batch_decode(out)[0]