# RA-BART / app.py
# Author: MrVicente — commit 6cf191b ("added demo base code"), 2.76 kB
# (header recovered from the hosting page; "raw" / "history" / "blame" were page links)
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from inference import RelationsInference
from utils import KGType, Model_Type
#############################
# Constants
#############################
# Sample inputs, each a [context, task, decoding strategy] triple.
# NOTE(review): not referenced anywhere else in this file (no gr.Examples uses
# it), and "commonsense_qa" is not one of the task choices offered by the UI.
examples = [
    ["What's the meaning of life?", "eli5", "constraint"],
    ["boat, water, bird", "commongen", "constraint"],
    ["What flows under a bridge?", "commonsense_qa", "constraint"],
]

# Single inference wrapper, constructed once at module import and reused by
# the request handler below (infer_bart).
bart = RelationsInference(
    model_path="MrVicente/commonsense_bart_commongen",
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=32,
)
#############################
# Helper
#############################
def infer_bart(context, task_type, decoding_type_str):
    """Run the model on *context* and return its first generated candidate.

    Handler for the "See Model Results" button.

    Args:
        context: Text entered in the "Input:" box.
        task_type: Value of the task radio button.
            NOTE(review): currently ignored; generation does not depend on it.
        decoding_type_str: Value of the decoding-strategy radio button.
            NOTE(review): currently ignored as well.

    Returns:
        The first element of the model's response, or "" when the response
        is empty.
    """
    # Only the generated text is shown in the UI; the other two returned
    # values are deliberately discarded.
    response, _encoder_attentions, _model_input = bart.generate_based_on_context(
        context, use_kg=False
    )
    # len() rather than truthiness: it behaves the same for lists, arrays and
    # tensors, whereas `not response` is ambiguous for multi-element arrays.
    if len(response) == 0:
        return ""
    return response[0]
def plot_attention(layer, head):
    """Build the figure displayed in the attention panel.

    Placeholder implementation: it always plots the same fixed dummy series;
    the selected layer and head are not used yet.

    Args:
        layer: Value of the "Layer" slider (0-11).
        head: Value of the "Head" slider (0-15).

    Returns:
        A matplotlib ``Figure`` for the ``gr.Plot`` output.
    """
    # Use a standalone Figure instead of pyplot: pyplot keeps every figure it
    # creates in a global registry until explicitly closed, so one figure per
    # click leaks memory in a long-running server, and its implicit "current
    # figure" state is shared between concurrent requests.
    fig = Figure()
    ax = fig.subplots()
    ax.plot([1, 2, 3], [2, 4, 6])
    ax.set_title("Things")
    ax.set_ylabel("Cases")
    ax.set_xlabel("Days since Day 0")
    return fig
#############################
# Interface
#############################
with gr.Blocks() as app:
    # Intro and usage instructions shown at the top of the page.
    gr.Markdown(
        "\n"
        "# Demo\n"
        "### Test Commonsense Relation-Aware BART (BART-RA) model\n"
        "Tutorial: <br>\n"
        "1) Select the possible model variations and tasks;<br>\n"
        "2) Change the inputs and Click the buttons to produce results;<br>\n"
        "3) See attention visualisations, by choosing a specific layer and head;<br>\n"
    )

    # --- Generation section: input/output text, options and run button.
    with gr.Row():
        context_input = gr.Textbox(
            label='Input:', value="What's the meaning of life?", lines=2
        )
        model_result_output = gr.Textbox(label='Model result:', lines=2)
    with gr.Column():
        task_type_choice = gr.Radio(
            choices=["eli5", "commongen"],
            value="eli5",
            label="What task do you want to try?",
        )
        decoding_type_choice = gr.Radio(
            choices=["default", "constraint"],
            value="default",
            label="What decoding strategy do you want to use?",
        )
    with gr.Row():
        model_btn = gr.Button(value="See Model Results")

    # --- Attention section: layer/head selectors next to the plot.
    gr.Markdown("\n---\nObserve Attention\n")
    with gr.Row():
        with gr.Column():
            layer = gr.Slider(minimum=0, maximum=11, value=0, step=1, label="Layer")
            head = gr.Slider(minimum=0, maximum=15, value=0, step=1, label="Head")
        with gr.Column():
            plot_output = gr.Plot()
    with gr.Row():
        vis_btn = gr.Button(value="See Attention Scores")

    # Event wiring: each button feeds its inputs to a handler defined above.
    model_btn.click(
        fn=infer_bart,
        inputs=[context_input, task_type_choice, decoding_type_choice],
        outputs=[model_result_output],
    )
    vis_btn.click(
        fn=plot_attention,
        inputs=[layer, head],
        outputs=[plot_output],
    )

if __name__ == '__main__':
    # Start the Gradio server only when executed as a script.
    app.launch()