# kogpt2-proofreader / correct.py
# Last commit: df3a61e "Update correct.py" (by Moo)
# Page metadata at capture time: raw | history | blame | "No virus" | 1.22 kB
# -*- coding: utf-8 -*-
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel
# Prompt markers: chat() builds "<origin>{sentence}<correct>{correction}".
O_TKN, C_TKN = '<origin>', '<correct>'
# Beginning/end-of-sequence tokens share the same literal "</s>" string;
# chat() stops generating when the model emits EOS.
BOS = EOS = "</s>"
# Padding token passed to the tokenizer in chat().
PAD = "<pad>"
# '<unusedN>' vocabulary entries: MASK is passed to the tokenizer as its
# mask token; SENT is not referenced in this file.
MASK, SENT = '<unused0>', '<unused1>'
def chat(max_new_tokens=256):
    """Run an interactive console loop that proofreads Korean sentences.

    Loads the ``skt/kogpt2-base-v2`` tokenizer and the ``Moo/kogpt2-proofreader``
    model, then repeatedly reads a sentence from stdin, greedily decodes a
    correction one token at a time, and prints it. Enter ``quit`` (or send
    EOF / Ctrl-C) to exit; empty lines are ignored.

    Args:
        max_new_tokens: Maximum number of tokens generated per sentence, so a
            model that never emits EOS cannot loop forever. Generation also
            stops before the prompt exceeds the model's context window.
    """
    # SentencePiece word-boundary marker (U+2581). Written as an escape so the
    # literal cannot be corrupted by re-encoding the source file again.
    space_marker = '\u2581'

    tokenizer = AutoTokenizer.from_pretrained('skt/kogpt2-base-v2',
                                              eos_token=EOS, unk_token='<unk>',
                                              pad_token=PAD, mask_token=MASK)
    model = GPT2LMHeadModel.from_pretrained('Moo/kogpt2-proofreader')
    model.eval()
    # Maximum sequence length supported by the model's position embeddings.
    context_size = model.config.n_positions

    with torch.no_grad():
        while True:
            try:
                q = input('원래문장: ').strip()  # "original sentence: "
            except (EOFError, KeyboardInterrupt):
                print()
                break
            if q == 'quit':
                break
            if not q:
                continue

            a = ''
            for _ in range(max_new_tokens):
                # The whole prompt is re-encoded every step, so the model always
                # sees "<origin>{sentence}<correct>{correction so far}".
                ids = tokenizer.encode(O_TKN + q + C_TKN + a)
                if len(ids) >= context_size:
                    break
                input_ids = torch.LongTensor(ids).unsqueeze(dim=0)
                logits = model(input_ids)[0]
                # Greedy decoding: only the last position's prediction matters.
                next_id = int(torch.argmax(logits[0, -1], dim=-1))
                gen = tokenizer.convert_ids_to_tokens(next_id)
                if gen == EOS:
                    break
                a += gen.replace(space_marker, ' ')
            print(f"교정: {a.strip()}")  # "correction: "
if '__main__' == __name__:
    # Start the interactive loop only when executed as a script, not on import.
    chat()