import torch
import streamlit as st
from transformers import GPT2LMHeadModel, PreTrainedTokenizerFast


# Legacy Streamlit cache; allow_output_mutation skips re-hashing the torch
# model on every rerun.
@st.cache(allow_output_mutation=True)
def get_model():
    model = GPT2LMHeadModel.from_pretrained('skt/kogpt2-base-v2')
    model.eval()
    return model


# Special tokens follow the skt/kogpt2-base-v2 model card. The previous
# `tokenizers.Tokenizer.from_file(...)` call was broken: from_file expects a
# local tokenizer JSON file, not a Hugging Face model id.
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    'skt/kogpt2-base-v2',
    bos_token='</s>', eos_token='</s>', unk_token='<unk>',
    pad_token='<pad>', mask_token='<mask>')

default_text = "현대인들은 왜 항상 불안해 할까?"  # "Why are modern people always anxious?"

model = get_model()

st.title("KoGPT2 Demo Page (ver 2.0)")
st.markdown("""
### Model

| Model    | # of params | Type    | # of layers | # of heads | ffn_dim | hidden_dims |
|----------|:-----------:|:-------:|------------:|-----------:|--------:|------------:|
| `KoGPT2` | 125M        | Decoder | 12          | 12         | 3072    | 768         |

### Sampling
- greedy sampling
- maximum output length: 128 of the 1,024-token context

## Conditional Generation
""")

text = st.text_area("Input Text:", value=default_text)
st.write(text)
st.markdown("""
> *Prediction currently runs on a 2-core instance, so it may be somewhat slow.*
""")

punct = ('!', '?', '.')

if text:
    st.markdown("## Predict")
    with st.spinner('processing..'):
        print(f'input > {text}')
        input_ids = tokenizer.encode(text)
        gen_ids = model.generate(torch.tensor([input_ids]),
                                 max_length=128,
                                 repetition_penalty=2.0,
                                 # num_beams=2,
                                 # length_penalty=1.0,
                                 use_cache=True,
                                 pad_token_id=tokenizer.pad_token_id,
                                 eos_token_id=tokenizer.eos_token_id,
                                 bos_token_id=tokenizer.bos_token_id,
                                 # keep <unk> out of the generated text
                                 bad_words_ids=[[tokenizer.unk_token_id]])
        generated = tokenizer.decode(gen_ids[0].tolist()).strip()
        # If generation stopped mid-sentence, trim back to the last
        # sentence-ending punctuation mark (if there is one); the old loop
        # truncated to a single character when no punctuation was found.
        if generated and generated[-1] not in punct:
            last = max(generated.rfind(p) for p in punct)
            if last != -1:
                generated = generated[:last + 1]
    print(f'KoGPT > {generated}')
    st.write(generated)
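# Usage: save this script and launch it with the Streamlit CLI. The file name
# below is an assumption (the original does not fix one), and dependency
# versions are not pinned here:
#
#     pip install torch transformers streamlit
#     streamlit run demo.py
#
# Streamlit re-executes the whole script on every interaction, which is why
# the model is wrapped in a cached getter above while the (cheap) tokenizer
# load is left at module level.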