Moo commited on
Commit
df3a61e
โ€ข
1 Parent(s): 8a1e4a5

Update correct.py

Browse files
Files changed (1) hide show
  1. correct.py +19 -19
correct.py CHANGED
@@ -12,27 +12,27 @@ SENT = '<unused1>'
12
 
13
 
14
  def chat():
15
- tokenizer = AutoTokenizer.from_pretrained('skt/kogpt2-base-v2',
16
- eos_token=EOS, unk_token='<unk>',
17
- pad_token=PAD, mask_token=MASK)
18
- model = GPT2LMHeadModel.from_pretrained('Moo/kogpt2-proofreader')
19
- with torch.no_grad():
 
 
 
 
 
20
  while True:
21
- q = input('์›๋ž˜๋ฌธ์žฅ: ').strip()
22
- if q == 'quit':
 
 
 
 
 
23
  break
24
- a = ''
25
- while True:
26
- input_ids = torch.LongTensor(tokenizer.encode(O_TKN + q + C_TKN + a)).unsqueeze(dim=0)
27
- pred = model(input_ids)
28
- gen = tokenizer.convert_ids_to_tokens(
29
- torch.argmax(
30
- pred[0],
31
- dim=-1).squeeze().numpy().tolist())[-1]
32
- if gen == EOS:
33
- break
34
- a += gen.replace('โ–', ' ')
35
- print(f"๊ต์ •: {a.strip()}")
36
 
37
 
38
  if __name__ == "__main__":
 
12
 
13
 
14
  def chat():
15
+ tokenizer = AutoTokenizer.from_pretrained('skt/kogpt2-base-v2',
16
+ eos_token=EOS, unk_token='<unk>',
17
+ pad_token=PAD, mask_token=MASK)
18
+ model = GPT2LMHeadModel.from_pretrained('Moo/kogpt2-proofreader')
19
+ with torch.no_grad():
20
+ while True:
21
+ q = input('์›๋ž˜๋ฌธ์žฅ: ').strip()
22
+ if q == 'quit':
23
+ break
24
+ a = ''
25
  while True:
26
+ input_ids = torch.LongTensor(tokenizer.encode(O_TKN + q + C_TKN + a)).unsqueeze(dim=0)
27
+ pred = model(input_ids)
28
+ gen = tokenizer.convert_ids_to_tokens(
29
+ torch.argmax(
30
+ pred[0],
31
+ dim=-1).squeeze().numpy().tolist())[-1]
32
+ if gen == EOS:
33
  break
34
+ a += gen.replace('โ–', ' ')
35
+ print(f"๊ต์ •: {a.strip()}")
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  if __name__ == "__main__":