agiats commited on
Commit
3dc3966
1 Parent(s): 2284a2f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +213 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import List, Tuple
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ from sudachipy import dictionary
8
+ from sudachipy import tokenizer as sudachi_tokenizer
9
+ from transformers import AutoModelForCausalLM, PreTrainedTokenizer, T5Tokenizer
10
+
11
+ # trained model
12
+ model_dir = Path(__file__).parents[0] / "model"
13
+ device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
14
+ tokenizer = T5Tokenizer.from_pretrained(model_dir)
15
+ tokenizer.do_lower_case = True
16
+ trained_model = AutoModelForCausalLM.from_pretrained(model_dir)
17
+ trained_model.to(device)
18
+
19
+ # baseline model
20
+ baseline_model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
21
+ baseline_model.to(device)
22
+
23
+ sudachi_tokenizer_obj = dictionary.Dictionary().create()
24
+ mode = sudachi_tokenizer.Tokenizer.SplitMode.C
25
+
26
+
27
+ def sudachi_tokenize(input_text: str) -> List[str]:
28
+ morphemes = sudachi_tokenizer_obj.tokenize(input_text, mode)
29
+ return [morpheme.surface() for morpheme in morphemes]
30
+
31
+
32
+ def calc_offsets(tokens: List[str]) -> List[int]:
33
+ offsets = [0]
34
+ for token in tokens:
35
+ offsets.append(offsets[-1] + len(token))
36
+ return offsets
37
+
38
+
39
+ def distribute_surprisals_to_characters(
40
+ tokens2surprisal: List[Tuple[str, float]]
41
+ ) -> List[Tuple[str, float]]:
42
+ tokens2surprisal_by_character: List[Tuple[str, float]] = []
43
+ for token, surprisal in tokens2surprisal:
44
+ token_len = len(token)
45
+ for character in token:
46
+ tokens2surprisal_by_character.append((character, surprisal / token_len))
47
+ return tokens2surprisal_by_character
48
+
49
+
50
+ def calculate_surprisals_by_character(
51
+ input_text: str, model: AutoModelForCausalLM, tokenizer: PreTrainedTokenizer
52
+ ) -> Tuple[float, List[Tuple[str, float]]]:
53
+ input_tokens = [
54
+ token.replace("▁", "")
55
+ for token in tokenizer.tokenize(input_text)
56
+ if token != "▁"
57
+ ]
58
+ input_ids = tokenizer.encode(
59
+ "<s>" + input_text, add_special_tokens=False, return_tensors="pt"
60
+ ).to(device)
61
+
62
+ logits = model(input_ids)["logits"].squeeze(0)
63
+
64
+ surprisals = []
65
+ for i in range(logits.shape[0] - 1):
66
+ if input_ids[0][i + 1] == 9:
67
+ continue
68
+ logit = logits[i]
69
+ prob = torch.softmax(logit, dim=0)
70
+ neg_logprob = -torch.log(prob)
71
+ surprisals.append(neg_logprob[input_ids[0][i + 1]].item())
72
+ mean_surprisal = np.mean(surprisals)
73
+
74
+ tokens2surprisal: List[Tuple[str, float]] = []
75
+ for token, surprisal in zip(input_tokens, surprisals):
76
+ tokens2surprisal.append((token, surprisal))
77
+
78
+ char2surprisal = distribute_surprisals_to_characters(tokens2surprisal)
79
+
80
+ return mean_surprisal, char2surprisal
81
+
82
+
83
+ def aggregate_surprisals_by_offset(
84
+ char2surprisal: List[Tuple[str, float]], offsets: List[int]
85
+ ) -> List[Tuple[str, float]]:
86
+ tokens2surprisal = []
87
+ for i in range(len(offsets) - 1):
88
+ start = offsets[i]
89
+ end = offsets[i + 1]
90
+ surprisal = sum([surprisal for _, surprisal in char2surprisal[start:end]])
91
+ token = "".join([char for char, _ in char2surprisal[start:end]])
92
+ tokens2surprisal.append((token, surprisal))
93
+
94
+ return tokens2surprisal
95
+
96
+
97
+ def highlight_token(token: str, score: float):
98
+ if score > 0:
99
+ html_color = "#%02X%02X%02X" % (
100
+ 255,
101
+ int(255 * (1 - score)),
102
+ int(255 * (1 - score)),
103
+ )
104
+ else:
105
+ html_color = "#%02X%02X%02X" % (
106
+ int(255 * (1 + score)),
107
+ int(255 * (1 + score)),
108
+ 255,
109
+ )
110
+ return '<span style="background-color: {}; color: black">{}</span>'.format(
111
+ html_color, token
112
+ )
113
+
114
+
115
+ def create_highlighted_text(label: str, tokens2scores: List[Tuple[str, float]]):
116
+ highlighted_text: str = "<h2><b>" + label + "</b></h2>"
117
+ for token, score in tokens2scores:
118
+ highlighted_text += highlight_token(token, score)
119
+ return highlighted_text
120
+
121
+
122
+ def normalize_surprisals(
123
+ tokens2surprisal: List[Tuple[str, float]], log_scale: bool = False
124
+ ) -> List[Tuple[str, float]]:
125
+ if log_scale:
126
+ surprisals = [np.log(surprisal) for _, surprisal in tokens2surprisal]
127
+ else:
128
+ surprisals = [surprisal for _, surprisal in tokens2surprisal]
129
+ min_surprisal = np.min(surprisals)
130
+ max_surprisal = np.max(surprisals)
131
+ surprisals = [
132
+ (surprisal - min_surprisal) / (max_surprisal - min_surprisal)
133
+ for surprisal in surprisals
134
+ ]
135
+ assert min(surprisals) >= 0
136
+ assert max(surprisals) <= 1
137
+ return [
138
+ (token, surprisal)
139
+ for (token, _), surprisal in zip(tokens2surprisal, surprisals)
140
+ ]
141
+
142
+
143
+ def calculate_surprisal_diff(
144
+ tokens2surprisal: List[Tuple[str, float]],
145
+ baseline_tokens2surprisal: List[Tuple[str, float]],
146
+ scale: float = 100.0,
147
+ ):
148
+ diff_tokens2surprisal = [
149
+ (token, (surprisal - baseline_surprisal) * 100)
150
+ for (token, surprisal), (_, baseline_surprisal) in zip(
151
+ tokens2surprisal, baseline_tokens2surprisal
152
+ )
153
+ ]
154
+ return diff_tokens2surprisal
155
+
156
+
157
+ def main(input_text: str) -> Tuple[str, str, str]:
158
+ mean_surprisal, char2surprisal = calculate_surprisals_by_character(
159
+ input_text, trained_model, tokenizer
160
+ )
161
+ offsets = calc_offsets(sudachi_tokenize(input_text))
162
+ tokens2surprisal = aggregate_surprisals_by_offset(char2surprisal, offsets)
163
+ tokens2surprisal = normalize_surprisals(tokens2surprisal)
164
+ highlighted_text = create_highlighted_text("学習後モデル", tokens2surprisal)
165
+
166
+ (
167
+ baseline_mean_surprisal,
168
+ baseline_char2surprisal,
169
+ ) = calculate_surprisals_by_character(input_text, baseline_model, tokenizer)
170
+ baseline_tokens2surprisal = aggregate_surprisals_by_offset(
171
+ baseline_char2surprisal, offsets
172
+ )
173
+ baseline_tokens2surprisal = normalize_surprisals(baseline_tokens2surprisal)
174
+ baseline_highlighted_text = create_highlighted_text(
175
+ "学習前モデル", baseline_tokens2surprisal
176
+ )
177
+
178
+ diff_tokens2surprisal = calculate_surprisal_diff(
179
+ tokens2surprisal, baseline_tokens2surprisal, 100.0
180
+ )
181
+ diff_highlighted_text = create_highlighted_text("学習前後の差分", diff_tokens2surprisal)
182
+ return (
183
+ baseline_highlighted_text,
184
+ highlighted_text,
185
+ diff_highlighted_text,
186
+ )
187
+
188
+
189
+ if __name__ == "__main__":
190
+ demo = gr.Interface(
191
+ fn=main,
192
+ title="読みにくい箇所を検出するAI(デモ)",
193
+ description="テキストを入力すると、読みにくさに応じてハイライトされて出力されます。",
194
+ show_label=True,
195
+ inputs=gr.Textbox(
196
+ lines=5,
197
+ label="テキスト",
198
+ placeholder="ここにテキストを入力してください。",
199
+ ),
200
+ outputs=[
201
+ gr.HTML(label="学習前モデル", show_label=True),
202
+ gr.HTML(label="学習後モデル", show_label=True),
203
+ gr.HTML(label="学習前後の差分", show_label=True),
204
+ ],
205
+ examples=[
206
+ "太郎が二郎を殴った。",
207
+ "太郎が二郎に殴った。",
208
+ "サイエンスインパクトラボは、国立研究開発法人科学技術振興機構(JST)の「科学と社会」推進部が行う共創プログラムです。「先端の研究開発を行う研究者」と「社会課題解決に取り組むプレイヤー」が約3ヶ月に渡って共創活動を行います。",
209
+ "近年、ニューラル言語モデルが自然言語の統語知識をどれほど有しているかを、容認性判断課題を通して検証する研究が行われてきている。しかし、このような言語モデルの統語的評価を行うためのデータセットは、主に英語を中心とした欧米の諸言語を対象に構築されてきた。本研究では、既存のデータセットの問題点を克服しつつ、このようなデータセットが構築されてこなかった日本語を対象とした初めてのデータセットである JCoLA (JapaneseCorpus of Linguistic Acceptability) を構築した上で、それを用いた言語モデルの統語的評価を行った。",
210
+ ],
211
+ )
212
+
213
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch==1.12.1
2
+ transformers==4.20.0
3
+ sentencepiece==0.1.97
4
+ sudachipy
5
+ sudachidict_core