|
|
|
import torch |
|
from transformers import AutoTokenizer, GPT2LMHeadModel |
|
|
|
# Special tokens used to frame the proofreading task for the KoGPT2 model.
O_TKN = "<origin>"   # marks the start of the original (input) sentence
C_TKN = "<correct>"  # marks the start of the corrected sentence
BOS = "</s>"         # beginning-of-sequence token
EOS = "</s>"         # end-of-sequence token (same symbol as BOS for this model)
PAD = "<pad>"        # padding token
MASK = "<unused0>"   # mask token
SENT = "<unused1>"   # sentence-separator token
|
|
|
|
|
def chat():
    """Interactively proofread Korean sentences with a fine-tuned KoGPT2 model.

    Reads a sentence from stdin, greedily decodes a corrected sentence one
    token at a time (re-encoding the growing answer each step), and prints
    the result.  Type 'quit' at the prompt to exit.
    """
    tokenizer = AutoTokenizer.from_pretrained('skt/kogpt2-base-v2',
                                              eos_token=EOS, unk_token='<unk>',
                                              pad_token=PAD, mask_token=MASK)
    model = GPT2LMHeadModel.from_pretrained('Moo/kogpt2-proofreader')
    model.eval()  # inference only; disable dropout etc.

    with torch.no_grad():
        while True:
            # NOTE(review): this prompt string appears mojibake'd (likely
            # Korean UTF-8 misdecoded); kept byte-identical — confirm intent.
            q = input('์๋๋ฌธ์ฅ: ').strip()
            if q == 'quit':
                break
            a = ''
            # Greedy decoding loop: feed "<origin>{q}<correct>{a}" and append
            # the argmax next token until the model emits EOS.
            while True:
                input_ids = torch.LongTensor(
                    tokenizer.encode(O_TKN + q + C_TKN + a)).unsqueeze(dim=0)
                # Safety stop: do not exceed the model's position-embedding
                # limit, which would otherwise raise an indexing error.
                if input_ids.size(1) >= model.config.n_positions:
                    break
                pred = model(input_ids)
                # Take the token predicted at the last position.
                gen = tokenizer.convert_ids_to_tokens(
                    torch.argmax(pred[0], dim=-1).squeeze().numpy().tolist())[-1]
                if gen == EOS:
                    break
                # Bug fix: '\u2581' is SentencePiece's word-boundary marker.
                # The original replaced a mojibake'd Thai character instead,
                # so spaces were never restored in the decoded output.
                a += gen.replace('\u2581', ' ')
            # NOTE(review): this label also looks mojibake'd; kept as-is.
            print(f"๊ต์ : {a.strip()}")
|
|
|
|
|
if __name__ == "__main__": |
|
chat() |
|
|