Unggi's picture
Update app.py
6d38642
raw
history blame
1.59 kB
import pip
import importlib
import importlib.util
import subprocess
import sys


def _ensure_installed(*packages):
    """Install each of *packages* with pip unless it is already importable.

    pip is run as ``python -m pip`` in a subprocess because pip does not
    support being called in-process (``pip.main``). ``check_call`` raises
    ``subprocess.CalledProcessError`` on a failed install instead of silently
    continuing to an ``import`` that will fail anyway.

    Assumes each package's import name equals its pip distribution name
    (true for ``torch`` and ``transformers``).
    """
    for package in packages:
        if importlib.util.find_spec(package) is None:
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
            # Make the freshly installed package visible to this interpreter's
            # import system before the imports that follow.
            importlib.invalidate_caches()


_ensure_installed("torch", "transformers")
import torch
import gradio as gr
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import functools


@functools.lru_cache(maxsize=None)
def load_model(model_name):
    """Load a sequence-classification model and its tokenizer.

    The result is memoized per ``model_name`` so repeated calls (this app's
    ``inference`` calls it on every request) reuse the already-loaded objects
    instead of loading the weights again.

    Args:
        model_name: Hugging Face Hub model id or local path accepted by
            ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. The same objects are returned for every
        call with the same name, so callers must treat them as read-only.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
import re

# One sentence: begins at a character that is neither whitespace nor terminal
# punctuation, and runs through any trailing run of '.', '!' or '?'.
_SENTENCE_RE = re.compile(r"[^.!?\s][^.!?]*[.!?]*")


def split_sentences(text):
    """Split *text* into sentences at '.', '!' and '?'.

    Runs of whitespace (including newlines) are collapsed to a single space so
    every sentence fits on one output line. Terminal punctuation stays attached
    to its sentence; fragments consisting only of punctuation or whitespace are
    dropped.

    >>> split_sentences("First one. Second!  Third?")
    ['First one.', 'Second!', 'Third?']
    """
    flattened = re.sub(r"\s+", " ", text)
    return [match.strip() for match in _SENTENCE_RE.findall(flattened)]


def inference(prompt_inputs):
    """Classify every sentence of *prompt_inputs* with ``Unggi/feedback_prize_kor``.

    Args:
        prompt_inputs: Free text from the Gradio textbox; may contain several
            sentences and line breaks.

    Returns:
        One line per sentence: the sentence, a tab, and the label looked up in
        ``model.config.id2label``. Returns an empty string when the input
        contains no sentence.
    """
    model_name = "Unggi/feedback_prize_kor"
    model, tokenizer = load_model(model_name=model_name)

    lines = []
    for sentence in split_sentences(prompt_inputs):
        # truncation=True: an over-long sentence would otherwise exceed the
        # model's maximum input length and raise.
        inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_class_id = logits.argmax(dim=-1).item()
        label = model.config.id2label[predicted_class_id]
        lines.append(f"{sentence}\t{label}")
    return "\n".join(lines)
# Gradio UI: a single text input mapped through `inference` to a text output.
# `demo` is kept bound to the Interface itself (not to launch()'s return value)
# so it can be launched exactly once below, or picked up by `gradio app.py`.
demo = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",  # the string returned by `inference`
    examples=[
        "민주주의 국가에서 국민은 주인이다."
    ],
)

if __name__ == "__main__":
    # Passing share=True to launch() would also create a publicly accessible link.
    demo.launch()