File size: 4,332 Bytes
b896c7a
f8fa551
b896c7a
b169a72
f8fa551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b169a72
46b5046
f8fa551
b169a72
 
f8fa551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b169a72
46b5046
f8fa551
46b5046
 
f8fa551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46b5046
b169a72
f8fa551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b169a72
 
f8fa551
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
from optimum.pipelines import pipeline

# モデルのロード
MODEL_NAME = "p1atdev/zenz-v1-onnx"
pipe = pipeline("text-generation", MODEL_NAME)


# ひらがなをカタカナに変換する関数
def hiragana_to_katakana(hiragana: str):
    katakana = ""
    for char in hiragana:
        # ひらがなの文字コードの範囲はU+3041からU+309F
        if 0x3041 <= ord(char) <= 0x309F:
            katakana += chr(ord(char) + 0x60)
        else:
            katakana += char
    return katakana


# 入力を調整する関数
def preprocess_input(user_input: str):
    prefix = "\uEE00"  # 前に付与する文字列
    suffix = "\uEE01"  # 後ろに付与する文字列
    processed_input = prefix + hiragana_to_katakana(user_input) + suffix

    return processed_input


# 出力を生成する関数
def generate_output(input_text: str, num_beams: int = 4):
    generated_outputs = pipe(
        input_text,
        max_new_tokens=len(input_text) * 2,
        num_beams=num_beams,
        num_return_sequences=num_beams,
    )
    generated_texts = [output["generated_text"] for output in generated_outputs]  # type: ignore
    return generated_texts


# 出力を調整する関数
def postprocess_output(model_outputs: list[str]):
    suffix = "\uEE01"
    # \uEE01の後の部分を抽出
    for i, model_output in enumerate(model_outputs):
        if suffix in model_output:
            model_outputs[i] = model_output.split(suffix)[1]
    return "\n".join(
        [f"{i+1}: {model_output}" for i, model_output in enumerate(model_outputs)]
    )


# 変換処理をまとめる関数
def process_conversion(user_input: str, num_beams: int = 4):
    processed_input = preprocess_input(user_input)
    generated_outputs = generate_output(processed_input, num_beams)
    postprocessed_output = postprocess_output(generated_outputs)
    return postprocessed_output


# インターフェースを定義
def interface():
    with gr.Blocks() as ui:
        gr.Markdown(
            """## ニューラルかな漢字変換モデルzenz-v1のデモ (ONNX版)
            変換したい文字列をひらがな・カタカナを入力してください"""
        )

        with gr.Row():
            with gr.Column():
                input_text = gr.TextArea(
                    label="変換する文字列(ひらがな・カタカナ)",
                    info="変換したいテキストをひらがなかカタカナで入力します。入力すると右に反映されます。",
                )
                num_beams = gr.Slider(
                    label="候補数",
                    info="多くするとより変換に時間がかかります",
                    minimum=1,
                    maximum=20,
                    step=1,
                    value=4,
                )

            with gr.Column():
                output_text = gr.TextArea(
                    label="変換結果 (リアルタイム反映)",
                    info="変換候補が出力されます。上の候補ほど確信度が高いです。",
                )

        gr.Examples(
            examples=[
                ["きめつのえいがをみました"],
                ["はがいたいのでしかいにみてもらった"],
                ["くつろぐにふといでかんたといいます"],
                ["けんかをかった"],
                ["けんかにかった"],
                ["こうえんをおねがいする"],
                ["こうえんでおねがいする"],
                ["つきむらてまり"],
            ],
            inputs=[input_text],
        )

        gr.Markdown(
            """\
- 使用しているモデル (ONNX): [p1atdev/zenz-v1-onnx](https://huggingface.co/p1atdev/zenz-v1-onnx)
- オリジナル(変換元)のモデル: [Miwa-Keita/zenz-v1-checkpoints](https://huggingface.co/Miwa-Keita/zenz-v1-checkpoints)
"""
        )

        input_text.change(
            fn=process_conversion,
            inputs=[input_text, num_beams],
            outputs=output_text,
        )
        num_beams.change(
            fn=process_conversion,
            inputs=[input_text, num_beams],
            outputs=output_text,
        )

    ui.launch()


# ローンチ
if __name__ == "__main__":
    interface()