Spaces:
Running
Running
import gradio as gr | |
from optimum.pipelines import pipeline | |
# モデルのロード | |
MODEL_NAME = "p1atdev/zenz-v1-onnx" | |
pipe = pipeline("text-generation", MODEL_NAME) | |
# ひらがなをカタカナに変換する関数 | |
def hiragana_to_katakana(hiragana: str): | |
katakana = "" | |
for char in hiragana: | |
# ひらがなの文字コードの範囲はU+3041からU+309F | |
if 0x3041 <= ord(char) <= 0x309F: | |
katakana += chr(ord(char) + 0x60) | |
else: | |
katakana += char | |
return katakana | |
# 入力を調整する関数 | |
def preprocess_input(user_input: str): | |
prefix = "\uEE00" # 前に付与する文字列 | |
suffix = "\uEE01" # 後ろに付与する文字列 | |
processed_input = prefix + hiragana_to_katakana(user_input) + suffix | |
return processed_input | |
# 出力を生成する関数 | |
def generate_output(input_text: str, num_beams: int = 4): | |
generated_outputs = pipe( | |
input_text, | |
max_new_tokens=len(input_text) * 2, | |
num_beams=num_beams, | |
num_return_sequences=num_beams, | |
) | |
generated_texts = [output["generated_text"] for output in generated_outputs] # type: ignore | |
return generated_texts | |
# 出力を調整する関数 | |
def postprocess_output(model_outputs: list[str]): | |
suffix = "\uEE01" | |
# \uEE01の後の部分を抽出 | |
for i, model_output in enumerate(model_outputs): | |
if suffix in model_output: | |
model_outputs[i] = model_output.split(suffix)[1] | |
return "\n".join( | |
[f"{i+1}: {model_output}" for i, model_output in enumerate(model_outputs)] | |
) | |
# 変換処理をまとめる関数 | |
def process_conversion(user_input: str, num_beams: int = 4): | |
processed_input = preprocess_input(user_input) | |
generated_outputs = generate_output(processed_input, num_beams) | |
postprocessed_output = postprocess_output(generated_outputs) | |
return postprocessed_output | |
# インターフェースを定義 | |
def interface(): | |
with gr.Blocks() as ui: | |
gr.Markdown( | |
"""## ニューラルかな漢字変換モデルzenz-v1のデモ (ONNX版) | |
変換したい文字列をひらがな・カタカナを入力してください""" | |
) | |
with gr.Row(): | |
with gr.Column(): | |
input_text = gr.TextArea( | |
label="変換する文字列(ひらがな・カタカナ)", | |
info="変換したいテキストをひらがなかカタカナで入力します。入力すると右に反映されます。", | |
) | |
num_beams = gr.Slider( | |
label="候補数", | |
info="多くするとより変換に時間がかかります", | |
minimum=1, | |
maximum=20, | |
step=1, | |
value=4, | |
) | |
with gr.Column(): | |
output_text = gr.TextArea( | |
label="変換結果 (リアルタイム反映)", | |
info="変換候補が出力されます。上の候補ほど確信度が高いです。", | |
) | |
gr.Examples( | |
examples=[ | |
["きめつのえいがをみました"], | |
["はがいたいのでしかいにみてもらった"], | |
["くつろぐにふといでかんたといいます"], | |
["けんかをかった"], | |
["けんかにかった"], | |
["こうえんをおねがいする"], | |
["こうえんでおねがいする"], | |
["つきむらてまり"], | |
], | |
inputs=[input_text], | |
) | |
gr.Markdown( | |
"""\ | |
- 使用しているモデル (ONNX): [p1atdev/zenz-v1-onnx](https://huggingface.co/p1atdev/zenz-v1-onnx) | |
- オリジナル(変換元)のモデル: [Miwa-Keita/zenz-v1-checkpoints](https://huggingface.co/Miwa-Keita/zenz-v1-checkpoints) | |
""" | |
) | |
input_text.change( | |
fn=process_conversion, | |
inputs=[input_text, num_beams], | |
outputs=output_text, | |
) | |
num_beams.change( | |
fn=process_conversion, | |
inputs=[input_text, num_beams], | |
outputs=output_text, | |
) | |
ui.launch() | |
# ローンチ | |
if __name__ == "__main__": | |
interface() | |