[refactor]: replace pyplot usage with matplotlib Figure
app.py CHANGED
@@ -2,13 +2,13 @@ import random
 from typing import *
 
 import gradio as gr
-import matplotlib.pyplot as plt
 import numpy as np
 import seaborn as sns
 import sentencepiece as sp
 import torch
 
 from huggingface_hub import hf_hub_download
+from matplotlib.figure import Figure
 from torchtext.datasets import Multi30k
 
 from models import Seq2Seq

@@ -31,19 +31,18 @@ normalize = lambda sample: (sample[0].lower().strip(), sample[1].lower().strip()
 test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))
 
 
-def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) ->
-    figure =
-    axes =
-    axes.
-    axes.
-    axes.tick_params(axis="
+def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) -> Figure:
+    figure = Figure(dpi=800, tight_layout=True)
+    axes = figure.add_subplot()
+    axes = sns.heatmap(weights, ax=axes, xticklabels=input_tokens, yticklabels=output_tokens, cmap="gray", cbar=False)
+    axes.tick_params(axis="x", rotation=90, length=0)
+    axes.tick_params(axis="y", rotation=0, length=0)
     axes.xaxis.tick_top()
-    plt.close()
     return figure
 
 
 @torch.inference_mode()
-def run(input: str) -> Tuple[str,
+def run(input: str) -> Tuple[str, Figure]:
     """Run inference on a single sentence. Returns prediction and attention heatmap."""
     input = input.lower().strip().rstrip(".") + "."
     input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
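For reference, a minimal sketch of the pattern this refactor adopts (not taken from the app; the heatmap helper, token lists, and weights below are made up for illustration): a matplotlib.figure.Figure built through the object-oriented API is never registered with pyplot's global figure manager, so there is nothing to plt.close() and the plot can be created safely from a Gradio worker thread. The returned figure can then be saved to disk or, presumably, handed straight to a gr.Plot output in the app.

# Sketch of the Figure-based approach, assuming seaborn and matplotlib >= 3.1.
# No pyplot import: the figure lives outside pyplot's global state entirely.
import numpy as np
import seaborn as sns
from matplotlib.figure import Figure


def heatmap(xticklabels, yticklabels, weights):
    figure = Figure(tight_layout=True)   # plain Figure, no figure manager
    axes = figure.add_subplot()          # single axes on the figure
    sns.heatmap(weights, ax=axes, xticklabels=xticklabels,
                yticklabels=yticklabels, cmap="gray", cbar=False)
    axes.xaxis.tick_top()                # source tokens along the top edge
    return figure


# Hypothetical tokens and random weights, purely for demonstration.
fig = heatmap(["ein", "hund", "."], ["a", "dog", "."], np.random.rand(3, 3))
fig.savefig("attention.png")             # or pass the figure to gr.Plot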