import gradio as gr import torch from transformers import GPT2TokenizerFast, GPT2LMHeadModel from gpt2_knn_attention import GPT2KNNAttention from knn_memory import KNNLayer, ClearMemoryLayer def inject_knn_in_gpt2(model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8): layer = model.transformer.h[layer_ind].attn state = layer.state_dict() knn_layer = GPT2KNNAttention( config, knn_memory, device, is_cross_attention=False, layer_idx=layer.layer_idx) knn_state = knn_layer.state_dict() for k, v in state.items(): knn_state[k] = v knn_layer.load_state_dict(knn_state) model.transformer.h[8].attn = knn_layer model.transformer = ClearMemoryLayer( knn_memory, bos_token_id, eos_token_id, model.transformer) model.eval() model_name = "gpt2" tokenizer = GPT2TokenizerFast.from_pretrained(model_name) model = GPT2LMHeadModel.from_pretrained(model_name) config = model.config model.eval() knn_memory = KNNLayer(config, share_memory=False, batch_size=1) bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id bos_token, eos_token = tokenizer.bos_token, tokenizer.eos_token device = torch.device("cuda" if torch.cuda.is_available() else "cpu") inject_knn_in_gpt2( model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8) model.load_state_dict(torch.load('gpt2_knn_attention.pt')) def generate(text, temperature, max_new_tokens, top_p): encoded_input = tokenizer(text, return_tensors='pt') output = model.generate(**encoded_input, do_sample=True, max_new_tokens=int(max_new_tokens), temperature=temperature, top_p=top_p) return tokenizer.decode(output[0]) desc = "Попытка повторить статью от Google [Memorizing Transformers](https://arxiv.org/abs/2203.08913). "\ "В ней вводиться новый слой **KNNAttention**, который использует approximate kNN в базе с (key, value), чтобы делать attention по большому контексту. Это позволяет расширить контекст трансформера до размера книг и статей, несильно замедляя его.\n\n"\ "Я написал свои **KNNAttention**, переписал слой **GPT2Attention**, чтобы он использовал KNNAttention, а также написал несколько вспомогательный классов для всего этого.\n\n"\ "Я сел писать это за **3 недели** до дедлайна, но все равно не довел до результата, которого изначально хотел. Но я доволен проделанной работой :)" demo = gr.Interface( fn=generate, inputs=[gr.inputs.Textbox(lines=5, label="Input Text"), gr.Slider(0.001, 2.0, step=0.05, value=0.8, label='temperature'), gr.Slider(1, 512, step=1, value=32, label='max_new_tokens'), gr.Slider(0.1, 1.0, step=0.02, value=0.92, label='top_p')], outputs=gr.outputs.Textbox(label="Generated Text"), description=desc, title="Memorizing Transformers", examples=[ ["My name is Lewis and I like to", 0.8, 32, 0.92] ] ) demo.launch()