|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
|
|
from inference import RelationsInference |
|
from utils import KGType,Model_Type |
|
|
|
|
|
|
|
|
|
|
|
# Sample inputs as [context, task_type, decoding_type] triples, in the same
# order as the parameters of infer_bart.
# NOTE(review): this list is not passed to any UI component in this file, and
# "commonsense_qa" is not one of the task radio's choices — confirm intent.
examples = [
    ["What's the meaning of life?", "eli5", "constraint"],
    ["boat, water, bird", "commongen", "constraint"],
    ["What flows under a bridge?", "commonsense_qa", "constraint"],
]
|
|
|
# Single model instance, built once when the module is imported and shared by
# every call to infer_bart.
bart = RelationsInference(
    model_path="MrVicente/commonsense_bart_commongen",
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    # presumably a cap on generated sequence length — confirm in RelationsInference
    max_length=32,
)
|
|
|
|
|
|
|
|
|
|
|
def infer_bart(context, task_type, decoding_type_str):
    """Generate text for ``context`` with the shared model and return the first output.

    This is the handler for the "See Model Results" button. ``task_type`` and
    ``decoding_type_str`` come from the UI radio buttons. They are accepted so
    the handler signature matches the wired inputs, but they are currently
    ignored: generation always runs with ``use_kg=False``.

    Args:
        context: Input text from the context textbox.
        task_type: Selected task (unused).
        decoding_type_str: Selected decoding strategy (unused).

    Returns:
        The first element of the response returned by
        ``bart.generate_based_on_context``.
    """
    outputs = bart.generate_based_on_context(context, use_kg=False)
    # The call yields (response, encoder_attentions, model_input); only the
    # response is shown in the UI.
    response, _encoder_attentions, _model_input = outputs
    return response[0]
|
|
|
|
|
def plot_attention(layer, head):
    """Build the figure shown in the attention-visualisation panel.

    NOTE(review): this is still a placeholder. ``layer`` and ``head`` come
    from the UI sliders but are ignored, and a fixed demo line is drawn.

    The figure is built with the object-oriented ``Figure`` API rather than
    ``pyplot``. ``plt.figure()`` registers every new figure in pyplot's global
    figure manager, and nothing here ever closes it, so each button click
    leaked a figure for the lifetime of the server. Drawing through explicit
    ``Axes`` methods also avoids pyplot's implicit "current figure" state,
    which could mix plots if handlers ever run concurrently.

    Args:
        layer: Selected layer index from the "Layer" slider (unused).
        head: Selected head index from the "Head" slider (unused).

    Returns:
        A ``matplotlib.figure.Figure`` for the ``gr.Plot`` output.
    """
    # Imported here so this function is self-contained; matplotlib is already
    # a dependency of this module.
    from matplotlib.figure import Figure

    fig = Figure()
    ax = fig.add_subplot()
    ax.plot([1, 2, 3], [2, 4, 6])
    ax.set_title("Things")
    ax.set_ylabel("Cases")
    ax.set_xlabel("Days since Day 0")
    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: an input/result text row, task and decoding selectors, a run
# button, and an attention section with layer/head sliders and a plot.
# Blocks.__enter__ returns the Blocks instance, so binding it with `as app`
# is equivalent to creating it first and then entering it.
with gr.Blocks() as app:
    gr.Markdown(
        """
# Demo
### Test Commonsense Relation-Aware BART (BART-RA) model

Tutorial: <br>
1) Select the possible model variations and tasks;<br>
2) Change the inputs and Click the buttons to produce results;<br>
3) See attention visualisations, by choosing a specific layer and head;<br>
"""
    )

    # Text input and model output side by side.
    with gr.Row():
        context_input = gr.Textbox(
            lines=2,
            value="What's the meaning of life?",
            label='Input:',
        )
        model_result_output = gr.Textbox(lines=2, label='Model result:')

    # Task / decoding selectors, passed to infer_bart as its 2nd and 3rd args.
    with gr.Column():
        task_type_choice = gr.Radio(
            ["eli5", "commongen"],
            value="eli5",
            label="What task do you want to try?",
        )
        decoding_type_choice = gr.Radio(
            ["default", "constraint"],
            value="default",
            label="What decoding strategy do you want to use?",
        )

    with gr.Row():
        model_btn = gr.Button(value="See Model Results")

    gr.Markdown(
        """
---
Observe Attention
"""
    )

    # Attention controls on the left, plot on the right.
    with gr.Row():
        with gr.Column():
            layer = gr.Slider(0, 11, 0, step=1, label="Layer")
            head = gr.Slider(0, 15, 0, step=1, label="Head")
        with gr.Column():
            plot_output = gr.Plot()

    with gr.Row():
        vis_btn = gr.Button(value="See Attention Scores")

    # Event wiring must happen inside the Blocks context.
    model_btn.click(
        fn=infer_bart,
        inputs=[context_input, task_type_choice, decoding_type_choice],
        outputs=[model_result_output],
    )
    vis_btn.click(
        fn=plot_attention,
        inputs=[layer, head],
        outputs=[plot_output],
    )


if __name__ == '__main__':
    app.launch()