# -*- coding: utf-8 -*-
"""Interactive Korean sentence proofreader.

Reads a sentence from stdin, greedily decodes a corrected version from a
fine-tuned KoGPT2 model ('Moo/kogpt2-proofreader') one token at a time,
and prints the result. Type 'quit' to exit.
"""
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

# NOTE(review): every special-token constant below is an empty string in
# this copy of the file. They look like stripped-out markup (the usual
# KoGPT2 chatbot template uses tokens such as '<usr>'/'<sys>'/'</s>') --
# confirm against the original source before relying on this script.
O_TKN = ''   # marks the start of the original (input) sentence in the prompt
C_TKN = ''   # marks the start of the corrected (output) sentence
BOS = ""
EOS = ""     # generation stops when the decoded token equals this
PAD = ""
MASK = ''
SENT = ''

# Hard cap on tokens generated per answer. GPT-2's context window is 1024
# positions; without a cap the greedy loop below never terminates if the
# model fails to emit EOS, and would eventually index past the position
# embeddings anyway.
MAX_GEN_TOKENS = 1024


def chat():
    """Run the interactive proofreading REPL until the user types 'quit'."""
    tokenizer = AutoTokenizer.from_pretrained(
        'skt/kogpt2-base-v2',
        eos_token=EOS, unk_token='', pad_token=PAD, mask_token=MASK)
    # from_pretrained returns the model already in eval mode; no_grad below
    # disables autograd bookkeeping for inference.
    model = GPT2LMHeadModel.from_pretrained('Moo/kogpt2-proofreader')

    with torch.no_grad():
        while True:
            q = input('원래문장: ').strip()
            if q == 'quit':
                break
            a = ''
            # Greedy decoding: re-encode prompt + partial answer each step
            # and run a full forward pass. O(n^2) in output length, but
            # simple and deterministic. Bounded so a missing EOS cannot
            # hang the program (the original looped unconditionally).
            for _ in range(MAX_GEN_TOKENS):
                input_ids = torch.LongTensor(
                    tokenizer.encode(O_TKN + q + C_TKN + a)).unsqueeze(dim=0)
                pred = model(input_ids)
                # pred[0] is the logits tensor (batch, seq, vocab); the
                # argmax at the last position is the next greedy token.
                gen = tokenizer.convert_ids_to_tokens(
                    torch.argmax(pred[0], dim=-1).squeeze().numpy().tolist())[-1]
                if gen == EOS:
                    break
                # '▁' is SentencePiece's word-start marker; map it back to
                # an ordinary space when accumulating the answer text.
                a += gen.replace('▁', ' ')
            print(f"교정: {a.strip()}")


if __name__ == "__main__":
    chat()