File size: 584 Bytes
67cda2a
4c39b84
 
 
 
67cda2a
 
4c39b84
 
 
 
 
 
 
67cda2a
4c39b84
67cda2a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from functools import lru_cache
from platform import python_version

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn import TunedLens
from tuned_lens.plotting import plot_lens

# Location of a trained tuned lens. This is a placeholder value: it must be
# replaced before the commented-out `TunedLens.load(LENS_PATH)` call in
# plot_lens_outputs is enabled. No active code reads it.
LENS_PATH = '<PATH TO LENS>'

@lru_cache(maxsize=1)
def _load_model_and_tokenizer(model_name='gpt2'):
    """Load a causal LM and its tokenizer once and reuse them.

    Loading pretrained weights is expensive. Without caching, every Gradio
    request would re-read the model from disk.

    Args:
        model_name: Hugging Face model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def plot_lens_outputs(text):
    """Compute the lens plot for ``text`` with GPT-2.

    Args:
        text: Prompt entered in the Gradio text box.

    Returns:
        The figure object produced by ``plot_lens``, returned as-is. Gradio
        renders a handler's return value through the output component
        configured on the Interface. Wrapping the figure in another Plot
        component (as before) passed it as that component's ``type``
        argument instead of as a value.
    """
    model, tokenizer = _load_model_and_tokenizer('gpt2')
    # TODO: load and pass a tuned lens once LENS_PATH points at a real lens:
    # lens = TunedLens.load(LENS_PATH)
    return plot_lens(model, tokenizer, text=text)

# Gradio app: one text input, rendered through a Plot output using the
# figure returned by plot_lens_outputs. `gr.Plot` replaces the deprecated
# `gr.outputs.Plot(type="auto")` namespace.
iface = gr.Interface(fn=plot_lens_outputs, inputs="text", outputs=gr.Plot())

# Start the web server only when executed as a script, so importing this
# module does not launch a server as a side effect.
if __name__ == "__main__":
    iface.launch()