zenz-v1-space / app.py
Plat
chore: onnx model
f8fa551
import gradio as gr
from optimum.pipelines import pipeline
# モデルのロード
MODEL_NAME = "p1atdev/zenz-v1-onnx"
pipe = pipeline("text-generation", MODEL_NAME)
# ひらがなをカタカナに変換する関数
def hiragana_to_katakana(hiragana: str):
katakana = ""
for char in hiragana:
# ひらがなの文字コードの範囲はU+3041からU+309F
if 0x3041 <= ord(char) <= 0x309F:
katakana += chr(ord(char) + 0x60)
else:
katakana += char
return katakana
# 入力を調整する関数
def preprocess_input(user_input: str):
prefix = "\uEE00" # 前に付与する文字列
suffix = "\uEE01" # 後ろに付与する文字列
processed_input = prefix + hiragana_to_katakana(user_input) + suffix
return processed_input
# 出力を生成する関数
def generate_output(input_text: str, num_beams: int = 4):
generated_outputs = pipe(
input_text,
max_new_tokens=len(input_text) * 2,
num_beams=num_beams,
num_return_sequences=num_beams,
)
generated_texts = [output["generated_text"] for output in generated_outputs] # type: ignore
return generated_texts
# 出力を調整する関数
def postprocess_output(model_outputs: list[str]):
suffix = "\uEE01"
# \uEE01の後の部分を抽出
for i, model_output in enumerate(model_outputs):
if suffix in model_output:
model_outputs[i] = model_output.split(suffix)[1]
return "\n".join(
[f"{i+1}: {model_output}" for i, model_output in enumerate(model_outputs)]
)
# 変換処理をまとめる関数
def process_conversion(user_input: str, num_beams: int = 4):
processed_input = preprocess_input(user_input)
generated_outputs = generate_output(processed_input, num_beams)
postprocessed_output = postprocess_output(generated_outputs)
return postprocessed_output
# インターフェースを定義
def interface():
with gr.Blocks() as ui:
gr.Markdown(
"""## ニューラルかな漢字変換モデルzenz-v1のデモ (ONNX版)
変換したい文字列をひらがな・カタカナを入力してください"""
)
with gr.Row():
with gr.Column():
input_text = gr.TextArea(
label="変換する文字列(ひらがな・カタカナ)",
info="変換したいテキストをひらがなかカタカナで入力します。入力すると右に反映されます。",
)
num_beams = gr.Slider(
label="候補数",
info="多くするとより変換に時間がかかります",
minimum=1,
maximum=20,
step=1,
value=4,
)
with gr.Column():
output_text = gr.TextArea(
label="変換結果 (リアルタイム反映)",
info="変換候補が出力されます。上の候補ほど確信度が高いです。",
)
gr.Examples(
examples=[
["きめつのえいがをみました"],
["はがいたいのでしかいにみてもらった"],
["くつろぐにふといでかんたといいます"],
["けんかをかった"],
["けんかにかった"],
["こうえんをおねがいする"],
["こうえんでおねがいする"],
["つきむらてまり"],
],
inputs=[input_text],
)
gr.Markdown(
"""\
- 使用しているモデル (ONNX): [p1atdev/zenz-v1-onnx](https://huggingface.co/p1atdev/zenz-v1-onnx)
- オリジナル(変換元)のモデル: [Miwa-Keita/zenz-v1-checkpoints](https://huggingface.co/Miwa-Keita/zenz-v1-checkpoints)
"""
)
input_text.change(
fn=process_conversion,
inputs=[input_text, num_beams],
outputs=output_text,
)
num_beams.change(
fn=process_conversion,
inputs=[input_text, num_beams],
outputs=output_text,
)
ui.launch()
# ローンチ
if __name__ == "__main__":
interface()