"""Global state of the app. | |
""" | |
import re

import torch
from nnsight import LanguageModel
from transformers import AutoConfig

conf = AutoConfig.from_pretrained("yp-edu/gpt2-stockfish-debug")
model = LanguageModel("yp-edu/gpt2-stockfish-debug")
model.eval()

def make_prompt(fen):
    """Convert a standard FEN into the spaced format the model expects."""
    board, player, castling, *fen_remaining = fen.split()
    # Expand each digit into that many empty squares, e.g. "8" -> "00000000".
    board = re.sub(r"(\d)", lambda m: "0" * int(m.group(1)), board)
    # Separate every character with a space so each square is its own token.
    spaced_board = " ".join(board)
    spaced_castling = " ".join(castling)
    full_fen = f"{spaced_board} {player} {spaced_castling} {' '.join(fen_remaining)}"
    return f"FEN: {full_fen} \nMOVE:"

def model_cache(fen):
    global model
    prompt = f"FEN: {fen}\nMOVE:"
    attentions = {layer: [] for layer in range(12)}
    with model.generate(prompt, max_new_tokens=10, output_attentions=True) as tracer:
        out = model.generator.output.save()
        for _ in range(10):  # one iteration per generated token
            for layer in range(12):
                # attn.output[2] holds the attention weights when
                # output_attentions=True is passed to generate.
                attentions[layer].append(model.transformer.h[layer].attn.output[2].save())
            tracer.next()
    # Generation may stop before max_new_tokens, leaving later saved proxies
    # unset; accessing those raises ValueError, so keep only the steps that ran.
    real_attentions = {}
    for layer in range(12):
        real_attentions[layer] = []
        for a in attentions[layer]:
            try:
                _ = a.shape
                real_attentions[layer].append(a)
            except ValueError:
                break
    return out, real_attentions
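
# Sketch of the intended call (an assumption based on the prompt format above:
# `fen` should already be in the spaced form built inside make_prompt, since
# model_cache does not re-space it):
#   out, attn = model_cache("r n b q k b n r / ... w K Q k q - 0 1")
# `out` holds the generated token ids; `attn` maps layer index -> list of
# per-step attention weights, each roughly (batch, heads, query_len, key_len).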

def attribute_seqence(fen, out, attn_tensor):
    global model
    out_str = model.tokenizer.batch_decode(out)[0]