msarmi9 commited on
Commit
e40e595
·
1 Parent(s): 3815353

[refactor]: replace pyplot usage with matplotlib Figure

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -2,13 +2,13 @@ import random
2
  from typing import *
3
 
4
  import gradio as gr
5
- import matplotlib.pyplot as plt
6
  import numpy as np
7
  import seaborn as sns
8
  import sentencepiece as sp
9
  import torch
10
 
11
  from huggingface_hub import hf_hub_download
 
12
  from torchtext.datasets import Multi30k
13
 
14
  from models import Seq2Seq
@@ -31,19 +31,18 @@ normalize = lambda sample: (sample[0].lower().strip(), sample[1].lower().strip()
31
  test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))
32
 
33
 
34
- def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) -> plt.Figure:
35
- figure = plt.figure(dpi=800, tight_layout=True)
36
- axes = sns.heatmap(weights, cmap="gray", cbar=False)
37
- axes.set_xticklabels(input_tokens, rotation=90)
38
- axes.set_yticklabels(output_tokens, rotation=0)
39
- axes.tick_params(axis="both", length=0)
40
  axes.xaxis.tick_top()
41
- plt.close()
42
  return figure
43
 
44
 
45
  @torch.inference_mode()
46
- def run(input: str) -> Tuple[str, plt.Figure]:
47
  """Run inference on a single sentence. Returns prediction and attention heatmap."""""
48
  input = input.lower().strip().rstrip(".") + "."
49
  input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)
 
2
  from typing import *
3
 
4
  import gradio as gr
 
5
  import numpy as np
6
  import seaborn as sns
7
  import sentencepiece as sp
8
  import torch
9
 
10
  from huggingface_hub import hf_hub_download
11
+ from matplotlib.figure import Figure
12
  from torchtext.datasets import Multi30k
13
 
14
  from models import Seq2Seq
 
31
  test_source, _ = zip(*map(normalize, Multi30k(split="test", language_pair=("de", "en"))))
32
 
33
 
34
+ def attention_heatmap(input_tokens: List[str], output_tokens: List[str], weights: np.ndarray) -> Figure:
35
+ figure = Figure(dpi=800, tight_layout=True)
36
+ axes = figure.add_subplot()
37
+ axes = sns.heatmap(weights, ax=axes, xticklabels=input_tokens, yticklabels=output_tokens, cmap="gray", cbar=False)
38
+ axes.tick_params(axis="x", rotation=90, length=0)
39
+ axes.tick_params(axis="y", rotation=0, length=0)
40
  axes.xaxis.tick_top()
 
41
  return figure
42
 
43
 
44
  @torch.inference_mode()
45
+ def run(input: str) -> Tuple[str, Figure]:
46
  """Run inference on a single sentence. Returns prediction and attention heatmap."""""
47
  input = input.lower().strip().rstrip(".") + "."
48
  input_tensor = torch.tensor(source_spm.encode(input), dtype=torch.int64)