fist commit
Browse files- app.py +208 -0
- requirements.txt +67 -0
app.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, AutoModelForSeq2SeqLM, RobertaForQuestionAnswering
|
4 |
+
|
5 |
+
# 0.モデルのロード, Examplesの準備
|
6 |
+
tokenizer_sum = AutoTokenizer.from_pretrained("tsmatz/mt5_summarize_japanese")
|
7 |
+
model_sum = AutoModelForSeq2SeqLM.from_pretrained("tsmatz/mt5_summarize_japanese")
|
8 |
+
|
9 |
+
# 質問文の生成
|
10 |
+
tokenizer_gen_q = T5Tokenizer.from_pretrained("sonoisa/t5-base-japanese-question-generation")
|
11 |
+
model_gen_q = T5ForConditionalGeneration.from_pretrained("sonoisa/t5-base-japanese-question-generation")
|
12 |
+
|
13 |
+
# 回答の生成
|
14 |
+
tokenizer_qa = AutoTokenizer.from_pretrained("tsmatz/roberta_qa_japanese")
|
15 |
+
model_qa = RobertaForQuestionAnswering.from_pretrained("tsmatz/roberta_qa_japanese")
|
16 |
+
|
17 |
+
# Example 1
|
18 |
+
eg_text_1 = """
|
19 |
+
ポケットモンスターの原点は、1996年2月27日に発売されたゲームボーイ用ソフト『ポケットモンスター 赤・緑』である。
|
20 |
+
開発元はゲームフリーク。コンセプトメーカーにしてディレクターを務めたのは、同社代表取締役でもある田尻智。
|
21 |
+
この作品が小学生を中心に、口コミから火が点き大ヒットとなり、以降も多くの続編が発売されている(詳しくは「ポケットモンスター(ゲーム)」を参照)。
|
22 |
+
ゲーム本編作品だけでなく、派生作品や関連作品が数多く発売されている(詳しくはポケットモンスターの関連ゲームを参照)。
|
23 |
+
|
24 |
+
ポケモンはゲームのみならず、アニメ化、キャラクター商品化、カードゲーム、アーケードゲームと様々なメディアミックス展開がなされ、日本国外でも人気を獲得している。
|
25 |
+
|
26 |
+
ポケモン関連ゲームソフトの累計出荷数は、全世界で2017年11月時点で3億本以上[1]、2022年3月時点で4億4000万本以上に達している[2]。
|
27 |
+
その中で、メインシリーズの累計販売本数は2016年2月時点での最新作、ニンテンドー3DS『オメガルビー・アルファサファイア』までの25作品で2億100万本となる[3]。
|
28 |
+
"""
|
29 |
+
eg_ans_1_1 = "2月27日"
|
30 |
+
eg_ans_1_2 = "ポケットモンスター 赤・緑"
|
31 |
+
|
32 |
+
# Example 2
|
33 |
+
eg_text_2 = """
|
34 |
+
アンパンマンの生みの親であるやなせたかしの作品で1968年に「バラの花とジョー」、
|
35 |
+
「チリンの鈴」の絵本や映画にいち早くアンパンマンが登場しているが、この時はまだ人間の姿。
|
36 |
+
この童話は一年間連載された。[5]アンパンマン、やなせたかしの作品としての、「アンパンマン」は、
|
37 |
+
PHP研究所が発行する青年向け雑誌『PHP』の通巻第257号に当たる、『こどものえほん』の1969年10月号[6](同年10月1日刊行)に掲載された青年向け読物、
|
38 |
+
やなせたかし(絵と文)「アンパンマン」という形が初出である[7][8][9]。
|
39 |
+
この時期、やなせが『こどものえほん』のために執筆した読物は連載12本の短編で、「アンパンマン」はその6本目の作品であった。
|
40 |
+
これら12篇は、株式会社山梨シルクセンター(※3年後、株式会社サンリオへ社名変更)より単行本『十二の真珠』名義で1970年に刊行された。
|
41 |
+
|
42 |
+
空腹に喘ぐ人の所へ駆け付けて、自らの大事な持ち物であるパンを差し出して食べるよう勧めるという、のちのアンパンマンに通じる物語の骨組みが、
|
43 |
+
この作品のおいて早くも整えられている[10][6]。
|
44 |
+
絵本・漫画・アニメなど、のちに描かれるアンパンマンとの大きな違いと言えば、第一に主人公のアンパンマンが普通の人間のおじさんであり[10][6]、
|
45 |
+
パンは所有物に過ぎなかったことである。
|
46 |
+
"""
|
47 |
+
eg_ans_2_1 = "アンパンマン"
|
48 |
+
eg_ans_2_2 = "やなせたかし"
|
49 |
+
|
50 |
+
# 1. イベント用の関数
|
51 |
+
def summy(text):
|
52 |
+
"""要約
|
53 |
+
|
54 |
+
Args
|
55 |
+
text: str
|
56 |
+
要約対象のテキスト
|
57 |
+
|
58 |
+
Returns
|
59 |
+
summarize_text: str
|
60 |
+
要約結果のテキスト
|
61 |
+
|
62 |
+
TODO
|
63 |
+
処理の実装
|
64 |
+
"""
|
65 |
+
inputs = tokenizer_sum("summarize: " + text, return_tensors="pt")
|
66 |
+
outputs = model_sum.generate(
|
67 |
+
inputs["input_ids"],
|
68 |
+
max_new_tokens=300, # 生成数の上限
|
69 |
+
min_length=150, # 生成数の下限
|
70 |
+
num_beams=5 # ビームサーチの設定
|
71 |
+
)
|
72 |
+
summarize_text = tokenizer_sum.decode(output[0], skip_special_tokens=True)
|
73 |
+
|
74 |
+
return summarize_text
|
75 |
+
|
76 |
+
def generate_questions(answer_1, answer_2, text):
|
77 |
+
"""質問生成
|
78 |
+
|
79 |
+
Args
|
80 |
+
answers: list[str]
|
81 |
+
質問生成のための正解単語のリスト
|
82 |
+
text: str
|
83 |
+
質問文を生成する際に参照するテキスト
|
84 |
+
|
85 |
+
Returns
|
86 |
+
generated_questions: list[str]
|
87 |
+
生成された質問文のリスト
|
88 |
+
|
89 |
+
TODO
|
90 |
+
処理の実装
|
91 |
+
"""
|
92 |
+
answer_context_list = [(answer, text) for answer in [answer_1, answer_2]] # 解答を質問生成する元となる文(要約結果)とセットにする。
|
93 |
+
generated_questions = []
|
94 |
+
|
95 |
+
for answer, context in answer_context_list:
|
96 |
+
# モデルに入力可能な形式に変換する
|
97 |
+
# 「answer: 」と「context: 」を使った形式に変換にする
|
98 |
+
input = tokenizer_gen_q(f"answer: {answer} context: {context}", return_tensors="pt")
|
99 |
+
|
100 |
+
# 質問文を生成する
|
101 |
+
output = model_gen_q.generate(
|
102 |
+
input['input_ids'],
|
103 |
+
max_new_tokens=100,
|
104 |
+
num_beams=4 # ビームサーチの設定
|
105 |
+
)
|
106 |
+
|
107 |
+
# 生成された問題文のトークン列を文字列に変換する。
|
108 |
+
output = tokenizer_gen_q.decode(output[0], skip_special_tokens=True)
|
109 |
+
generated_questions.append(output)
|
110 |
+
|
111 |
+
return generated_questions
|
112 |
+
|
113 |
+
def extract_answer(question, text):
|
114 |
+
"""質問応答
|
115 |
+
|
116 |
+
Args
|
117 |
+
question: str
|
118 |
+
質問文のテキスト
|
119 |
+
text: str
|
120 |
+
質問に回答するために参照するテキスト
|
121 |
+
|
122 |
+
Returns
|
123 |
+
answer: str
|
124 |
+
回答のテキスト
|
125 |
+
|
126 |
+
TODO
|
127 |
+
処理の実装
|
128 |
+
"""
|
129 |
+
inputs = tokenizer_qa(question, text, return_tensors="pt") # tokenizerには複数のテキストを与える
|
130 |
+
|
131 |
+
# 正解箇所の予測
|
132 |
+
outputs = model_qa(**inputs)
|
133 |
+
answer_start_scores = outputs.start_logits
|
134 |
+
answer_end_scores = outputs.end_logits
|
135 |
+
|
136 |
+
# 予測結果の開始と終了のインデックスを取得
|
137 |
+
answer_start = torch.argmax(answer_start_scores)
|
138 |
+
answer_end = torch.argmax(answer_end_scores) + 1
|
139 |
+
|
140 |
+
# tokenizerの結果から正解を抽出する
|
141 |
+
input_ids = inputs["input_ids"].tolist()[0]
|
142 |
+
|
143 |
+
answer = tokenizer_qa.decode(input_ids[answer_start:answer_end])
|
144 |
+
|
145 |
+
return answer
|
146 |
+
|
147 |
+
def extract_answer_all(gen_q_1, gen_q_2, source_text, sum_text):
|
148 |
+
"""extract_answer()をまとめて実行する
|
149 |
+
TODO
|
150 |
+
処理の実装
|
151 |
+
"""
|
152 |
+
a_source_1 = extract_answer(gen_q_1, source_text)
|
153 |
+
a_sum_1 = extract_answer(gen_q_1, sum_text)
|
154 |
+
a_source_2 = extract_answer(gen_q_2, source_text)
|
155 |
+
a_sum_2 = extract_answer(gen_q_2, sum_text)
|
156 |
+
|
157 |
+
return a_source_1, a_sum_1, a_source_2, a_sum_2
|
158 |
+
|
159 |
+
# 2. UIの定義
|
160 |
+
with gr.Blocks() as demo:
|
161 |
+
gr.Markdown("### 1. 要約生成")
|
162 |
+
# TODO 要約のための入出力UIの作成
|
163 |
+
source_text = gr.Textbox(label="要約対象")
|
164 |
+
btn_summy = gr.Button("要約生成")
|
165 |
+
sum_text = gr.Textbox(label="要約結果")
|
166 |
+
|
167 |
+
gr.Markdown("### 2. 質問生成")
|
168 |
+
# TODO 質問文生成のための入力UIの作成
|
169 |
+
|
170 |
+
with gr.Row():
|
171 |
+
with gr.Column():
|
172 |
+
answer_1 = gr.Text(label="正解1")
|
173 |
+
with gr.Column():
|
174 |
+
answer_2 = gr.Text(label="正解2")
|
175 |
+
btn_generate_questions = gr.Button("質問生成")
|
176 |
+
|
177 |
+
gr.Markdown("### 3. 回答生成")
|
178 |
+
with gr.Row():
|
179 |
+
with gr.Column():
|
180 |
+
gen_q_1 = gr.Text(label="1番目の質問")
|
181 |
+
with gr.Column():
|
182 |
+
gen_q_2 = gr.Text(label="2番目の質問")
|
183 |
+
btn_extract_answer = gr.Button("回答生成")
|
184 |
+
with gr.Row():
|
185 |
+
with gr.Column():
|
186 |
+
a_source_1 = gr.Text(label="sourceからの答え")
|
187 |
+
a_sum_1 = gr.Text(label="sumからの答え")
|
188 |
+
with gr.Column():
|
189 |
+
a_source_2 = gr.Text(label="sourceからの答え")
|
190 |
+
a_sum_2 = gr.Text(label="sumからの答え")
|
191 |
+
|
192 |
+
# 2. イベント発火
|
193 |
+
btn_summy.click(summy, inputs=[source_text], outputs=[sum_text])
|
194 |
+
btn_generate_questions.click(generate_questions, inputs=[answer_1, answer_2, sum_text], outputs=[gen_q_1, gen_q_2])
|
195 |
+
btn_extract_answer.click(extract_answer_all,
|
196 |
+
inputs=[gen_q_1, gen_q_2, source_text, sum_text],
|
197 |
+
outputs=[a_source_1, a_sum_1, a_source_2, a_sum_2]
|
198 |
+
)
|
199 |
+
|
200 |
+
# Examplesの定義
|
201 |
+
gr.Markdown("## Examples")
|
202 |
+
gr.Examples(
|
203 |
+
[[eg_text_1, eg_ans_1_1, eg_ans_1_2], [eg_text_2, eg_ans_2_1, eg_ans_2_2]],
|
204 |
+
[source_text, answer_1, answer_2],
|
205 |
+
)
|
206 |
+
|
207 |
+
if __name__ == "__main__":
|
208 |
+
demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==22.1.0
|
2 |
+
aiohttp==3.8.3
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==4.2.2
|
5 |
+
anyio==3.6.2
|
6 |
+
async-timeout==4.0.2
|
7 |
+
attrs==22.2.0
|
8 |
+
certifi==2022.12.7
|
9 |
+
charset-normalizer==2.1.1
|
10 |
+
click==8.1.3
|
11 |
+
contourpy==1.0.7
|
12 |
+
cycler==0.11.0
|
13 |
+
entrypoints==0.4
|
14 |
+
fastapi==0.89.1
|
15 |
+
ffmpy==0.3.0
|
16 |
+
filelock==3.9.0
|
17 |
+
fonttools==4.38.0
|
18 |
+
frozenlist==1.3.3
|
19 |
+
fsspec==2023.1.0
|
20 |
+
gradio==3.17.1
|
21 |
+
h11==0.14.0
|
22 |
+
httpcore==0.16.3
|
23 |
+
httpx==0.23.3
|
24 |
+
huggingface-hub==0.12.0
|
25 |
+
idna==3.4
|
26 |
+
Jinja2==3.1.2
|
27 |
+
jsonschema==4.17.3
|
28 |
+
kiwisolver==1.4.4
|
29 |
+
linkify-it-py==1.0.3
|
30 |
+
markdown-it-py==2.1.0
|
31 |
+
MarkupSafe==2.1.2
|
32 |
+
matplotlib==3.6.3
|
33 |
+
mdit-py-plugins==0.3.3
|
34 |
+
mdurl==0.1.2
|
35 |
+
multidict==6.0.4
|
36 |
+
numpy==1.24.2
|
37 |
+
orjson==3.8.5
|
38 |
+
packaging==23.0
|
39 |
+
pandas==1.5.3
|
40 |
+
Pillow==9.4.0
|
41 |
+
pycryptodome==3.17
|
42 |
+
pydantic==1.10.4
|
43 |
+
pydub==0.25.1
|
44 |
+
pyparsing==3.0.9
|
45 |
+
pyrsistent==0.19.3
|
46 |
+
python-dateutil==2.8.2
|
47 |
+
python-multipart==0.0.5
|
48 |
+
pytz==2022.7.1
|
49 |
+
PyYAML==6.0
|
50 |
+
regex==2022.10.31
|
51 |
+
requests==2.28.2
|
52 |
+
rfc3986==1.5.0
|
53 |
+
sentencepiece==0.1.97
|
54 |
+
six==1.16.0
|
55 |
+
sniffio==1.3.0
|
56 |
+
starlette==0.22.0
|
57 |
+
tokenizers==0.13.2
|
58 |
+
toolz==0.12.0
|
59 |
+
torch==1.13.1
|
60 |
+
tqdm==4.64.1
|
61 |
+
transformers==4.26.0
|
62 |
+
typing_extensions==4.4.0
|
63 |
+
uc-micro-py==1.0.1
|
64 |
+
urllib3==1.26.14
|
65 |
+
uvicorn==0.20.0
|
66 |
+
websockets==10.4
|
67 |
+
yarl==1.8.2
|