File size: 1,591 Bytes
6ade039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd5680b
 
6ade039
 
 
 
 
 
6d38642
 
 
bd5680b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ade039
bd5680b
6ade039
da2515f
6ade039
bd5680b
6ade039
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import pip
# Install runtime dependencies at startup. pip has no supported in-process
# API (``pip.main`` is an internal/deprecated shim), so run pip as a
# subprocess using the same interpreter that executes this script.
# check_call raises CalledProcessError if the install fails, instead of
# failing later with a less obvious ImportError.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "transformers"])

import functools
import re

import torch
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification

@functools.lru_cache(maxsize=1)
def load_model(model_name):
    """Load a sequence-classification model and its tokenizer.

    The result is memoized per *model_name* (at most one entry is kept), so
    repeated calls reuse the already-loaded objects instead of re-reading the
    weights on every request. Because the same objects are returned on each
    call, callers must not mutate them.

    Args:
        model_name: Hugging Face Hub model id or local checkpoint path, passed
            to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    return model, tokenizer


# Hugging Face Hub id of the classifier served by this demo.
MODEL_NAME = "Unggi/feedback_prize_kor"

# A "sentence" is a run of non-terminator characters followed by any trailing
# terminators (. ! ?), so punctuation stays attached to the sentence it ends.
_SENTENCE_RE = re.compile(r"[^.!?]+[.!?]*")


def _split_sentences(text):
    """Split *text* into stripped, non-empty sentences on ``.``, ``!``, ``?``."""
    # Newlines are treated as ordinary spaces before splitting.
    text = text.replace("\n", " ")
    sentences = (chunk.strip() for chunk in _SENTENCE_RE.findall(text))
    return [s for s in sentences if s]


def inference(prompt_inputs):
    """Classify each sentence of *prompt_inputs* and report the labels.

    Args:
        prompt_inputs: Free text entered in the Gradio textbox.

    Returns:
        One line per sentence, formatted ``"<sentence>\\t<label>"`` and joined
        with ``"\\n"``. Returns an empty string when the input contains no
        sentence, without loading the model.
    """
    sentences = _split_sentences(prompt_inputs)
    if not sentences:
        return ""

    model, tokenizer = load_model(
        model_name=MODEL_NAME
    )

    # Tokenize all sentences as one padded batch. truncation=True keeps
    # over-long sentences within the model's maximum input length instead of
    # raising an error.
    inputs = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # argmax over the last (class) dimension -> one predicted id per sentence.
    predicted_ids = logits.argmax(dim=-1).tolist()
    labels = [model.config.id2label[class_id] for class_id in predicted_ids]

    return "\n".join(
        sentence + "\t" + label for sentence, label in zip(sentences, labels)
    )

# Gradio UI: one textbox in, the per-sentence classification text out.
demo = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",  # inference() returns a plain string
    examples=[
        "민주주의 국가에서 국민은 주인이다.",
    ],
)

# Launch exactly once. launch() does not return the Interface, so the
# previous `gr.Interface(...).launch()` followed by `demo.launch()` called
# launch() a second time on its return value.
# Use demo.launch(share=True) to create a publicly accessible link.
demo.launch()