import random
from typing import List, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sentencepiece as sp
import torch

from huggingface_hub import hf_hub_download
from torchtext.datasets import Multi30k

from models import Seq2Seq
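
# Download the trained de-en translation checkpoint from the Hugging Face Hub and put it in eval mode.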
model_path = hf_hub_download("msarmi9/multi30k", "models/de-en/model.bin")
model = Seq2Seq(vocab_size=8000, hidden_dim=512, bos_idx=1, eos_idx=2, pad_idx=3, temperature=2)
model.load_state_dict(torch.load(model_path))
model.eval()
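
# Load the SentencePiece tokenizers for the source (German) and target (English) vocabularies.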
source_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/de8000.model")
target_spm_path = hf_hub_download("msarmi9/multi30k", "models/de-en/en8000.model")
source_spm = sp.SentencePieceProcessor(model_file=source_spm_path, add_eos=True)
target_spm = sp.SentencePieceProcessor(model_file=target_spm_path, add_eos=True)
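
# Lowercase and strip the Multi30k test split; the German side is used for the demo's example inputs.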
normalize = lambda sample: (sample[0].lower().strip(), sample[1].lower().strip())
test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))


def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) -> plt.Figure:
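    """Plot attention weights as a heatmap, with source tokens on the x-axis and target tokens on the y-axis."""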
    figure = plt.figure(dpi=800, tight_layout=True)
    axes = sns.heatmap(weights, cmap="gray", cbar=False)
    axes.set_xticklabels(input_tokens, rotation=90)
    axes.set_yticklabels(output_tokens, rotation=0)
    axes.tick_params(axis="both", length=0)
    axes.xaxis.tick_top()
    plt.close()
    return figure


@torch.inference_mode()
def run(input: str) -> Tuple[str, plt.Figure]:
    """Run inference on a single sentence. Returns prediction and attention heatmap."""
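    # Encode the source text to subword ids; add_eos=True appends the end-of-sequence token.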
    input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
    output, weights = model.decode(input_tensor, max_decode_length=max(len(input_tensor), 80))
    output = target_spm.decode(output.detach().tolist())
    input_tokens = source_spm.encode(input, out_type=str)
    output_tokens = target_spm.encode(output, out_type=str)
    return output, attention_heatmap(input_tokens, output_tokens, weights.detach().numpy())
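
# Build the Gradio demo: German text in, the English translation and an attention heatmap out.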
if __name__ == "__main__":
    interface = gr.Interface(
        run,
        inputs=gr.inputs.Textbox(lines=4, label="German"),
        outputs=[
            gr.outputs.Textbox(label="English"),
            gr.outputs.Image(type="plot", label="Attention Heatmap"),
        ],
        title="Multi30k Translation Widget",
        examples=random.sample(test_source, k=30),
        examples_per_page=10,
        allow_flagging="never",
        theme="huggingface",
        live=True,
    )
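    # Queue incoming requests and cache the outputs for the example sentences.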
    interface.launch(
        enable_queue=True,
        cache_examples=True,
    )