# multi30k / app.py
# Revision e40e595 (msarmi9): [refactor]: replace pyplot usage with matplotlib Figure
import random
from typing import *
import gradio as gr
import numpy as np
import seaborn as sns
import sentencepiece as sp
import torch
from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure
from torchtext.datasets import Multi30k
from models import Seq2Seq
# Load model
model_path = hf_hub_download("msarmi9/multi30k", "models/de-en/model.bin")
model = Seq2Seq(vocab_size=8000, hidden_dim=512, bos_idx=1, eos_idx=2, pad_idx=3, temperature=2)
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on
# a CPU-only host. The model is built on the CPU and never moved, and `run`
# calls `.numpy()` on its outputs, so CPU is the intended device here.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# Load sentencepiece tokenizers (both configured with add_eos=True)
source_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/de8000.model")
target_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/en8000.model")
source_spm = sp.SentencePieceProcessor(model_file=source_spm_path, add_eos=True)
target_spm = sp.SentencePieceProcessor(model_file=target_spm_path, add_eos=True)


# Load test set for example inputs
def normalize(sample):
    """Return the (source, target) pair lowercased and stripped of outer whitespace."""
    return sample[0].lower().strip(), sample[1].lower().strip()


# Only the German (source) side is kept; it feeds the example inputs of the UI.
test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))
def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) -> Figure:
    """Draw attention weights as a grayscale heatmap on a standalone Figure.

    Args:
        input_tokens: Labels for the x axis (one per heatmap column).
        output_tokens: Labels for the y axis (one per heatmap row).
        weights: Attention weights to plot; presumably shaped
            (len(output_tokens), len(input_tokens)) -- TODO confirm against
            the model's decode output.

    Returns:
        The matplotlib Figure holding the heatmap.
    """
    fig = Figure(dpi=800, tight_layout=True)
    heatmap_axes = sns.heatmap(
        weights,
        ax=fig.add_subplot(),
        xticklabels=input_tokens,
        yticklabels=output_tokens,
        cmap="gray",
        cbar=False,
    )
    # Hide the tick marks; x labels are vertical, y labels horizontal.
    for axis_name, angle in (("x", 90), ("y", 0)):
        heatmap_axes.tick_params(axis=axis_name, rotation=angle, length=0)
    # Put the input-token labels along the top edge of the plot.
    heatmap_axes.xaxis.tick_top()
    return fig
@torch.inference_mode()
def run(input: str) -> Tuple[str, Figure]:
    """Translate a single sentence and plot the attention it produced.

    Args:
        input: The raw source sentence as typed by the user.

    Returns:
        A pair of (decoded translation text, attention heatmap figure).
    """
    # Normalise: lowercase, trim whitespace, and end with exactly one period.
    sentence = input.lower().strip().rstrip(".") + "."

    source_ids = torch.tensor(source_spm.encode(sentence), dtype=torch.int64)
    # Decode limit passed to the model: the larger of the source length and 80.
    decode_limit = max(len(source_ids), 80)
    predicted_ids, attention = model.decode(source_ids, max_decode_length=decode_limit)
    translation = target_spm.decode(predicted_ids.detach().tolist())

    # Token strings used as tick labels. The output labels come from
    # re-encoding the decoded text, not from the predicted ids themselves.
    source_pieces = source_spm.encode(sentence, out_type=str)
    target_pieces = target_spm.encode(translation, out_type=str)

    heatmap = attention_heatmap(source_pieces, target_pieces, attention.detach().numpy())
    return translation, heatmap
if __name__ == "__main__":
    # Build the UI pieces first, then assemble and launch the interface.
    # NOTE(review): gr.inputs / gr.outputs and launch(cache_examples=...) look
    # like the Gradio 2.x API -- confirm against the pinned gradio version.
    german_box = gr.inputs.Textbox(lines=4, label="German")
    english_box = gr.outputs.Textbox(label="English")
    heatmap_panel = gr.outputs.Image(type="plot", label="Attention Heatmap")

    # 30 random source sentences from the test split, shown 10 per page.
    example_sentences = random.sample(test_source, k=30)

    interface = gr.Interface(
        run,
        inputs=german_box,
        outputs=[english_box, heatmap_panel],
        title="Multi30k Translation Widget",
        examples=example_sentences,
        examples_per_page=10,
        allow_flagging="never",
        theme="huggingface",
        live=True,
    )
    interface.launch(enable_queue=True, cache_examples=True)