"""Global state of the app. """ import re from transformers import AutoConfig import torch from nnsight import LanguageModel conf = AutoConfig.from_pretrained("yp-edu/gpt2-stockfish-debug") model = LanguageModel("yp-edu/gpt2-stockfish-debug") model.eval() def make_prompt(fen): board, player, castling, *fen_remaining = fen.split() board = re.sub(r"(\d)", lambda m: "0" * int(m.group(1)), board) spaced_board = " ".join(board) spaced_castling = " ".join(castling) full_fen = f"{spaced_board} {player} {spaced_castling} {' '.join(fen_remaining)}" return f"FEN: {full_fen} \nMOVE:" def model_cache(fen): global model prompt = f"FEN: {fen}\nMOVE:" attentions = {i: [] for i in range(12)} with model.generate(prompt, max_new_tokens=10, output_attentions=True) as tracer: out = model.generator.output.save() for i in range(10): for i in range(12): attentions[i].append(model.transformer.h[i].attn.output[2].save()) tracer.next() real_attentions = {} for i in range(12): real_attentions[i] = [] for a in attentions[i]: try: _ = a.shape real_attentions[i].append(a) except ValueError: break return out, real_attentions def attribute_seqence(fen, out, attn_tensor): global model out_str = model.tokenizer.batch_decode(out)[0]