# NOTE: Hugging Face Spaces page chrome ("Spaces:" / "Build error" status
# lines) was scraped into this file; replaced with this comment so the
# module remains valid Python.
import torch | |
import string | |
import streamlit as st | |
from transformers import GPT2LMHeadModel | |
from tokenizers import Tokenizer | |
@st.cache(allow_output_mutation=True)
def get_model():
    """Load the KoGPT2 language model once and reuse it across reruns.

    Streamlit re-executes the whole script on every user interaction;
    without caching, the ~125M-parameter model was re-loaded from
    disk/network on each rerun.  ``allow_output_mutation=True`` returns the
    cached ``nn.Module`` as-is instead of trying to hash it.

    NOTE(review): on Streamlit >= 1.x prefer ``@st.cache_resource`` —
    confirm the deployed Streamlit version before upgrading.

    Returns:
        GPT2LMHeadModel: pretrained ``skt/kogpt2-base-v2`` in eval mode.
    """
    model = GPT2LMHeadModel.from_pretrained('skt/kogpt2-base-v2')
    model.eval()  # inference only: disable dropout / training-mode layers
    return model
# NOTE(review): ``Tokenizer.from_file`` expects a path to a tokenizer.json
# file; 'skt/kogpt2-base-v2' looks like a Hugging Face model id, not a local
# file — verify this path actually resolves to a tokenizer JSON on disk.
tokenizer = Tokenizer.from_file('skt/kogpt2-base-v2')
# Default prompt pre-filled in the text area.  The literal appears
# mojibake-garbled in this copy of the file — presumably Korean; confirm the
# file's encoding before shipping.
default_text = "νλμΈλ€μ μ νμ λΆμν΄ ν κΉ?"
# Unused in the visible code — TODO confirm whether it can be removed.
N_SENT = 3
# Materialize the model at import time so the first request is fast.
model = get_model()
# --- Page header, model card, and prompt input -----------------------------
# Markdown bodies are named up front; the render calls below keep the exact
# original top-to-bottom order (title, model card, input, echo, notice).
PAGE_TITLE = "KoGPT2 Demo Page(ver 2.0)"

MODEL_INFO_MD = """
### λͺ¨λΈ
| Model | # of params | Type | # of layers | # of heads | ffn_dim | hidden_dims |
|--------------|:----:|:-------:|--------:|--------:|--------:|--------------:|
| `KoGPT2` | 125M | Decoder | 12 | 12 | 3072 | 768 |
### μνλ§ λ°©λ²
- greedy sampling
- μ΅λ μΆλ ₯ κΈΈμ΄ : 128/1,024
## Conditional Generation
"""

NOTICE_MD = """
> *νμ¬ 2core μΈμ€ν΄μ€μμ μμΈ‘μ΄ μ§νλμ΄ λ€μ λ릴 μ μμ*
"""

st.title(PAGE_TITLE)
st.markdown(MODEL_INFO_MD)

# Prompt input, pre-filled with the module-level default; echo it back so the
# user sees exactly what will be fed to the model.
text = st.text_area("Input Text:", value=default_text)
st.write(text)
st.markdown(NOTICE_MD)
# Sentence-final punctuation used to trim a trailing, incomplete sentence.
punct = ('!', '?', '.')

if text:
    st.markdown("## Predict")
    with st.spinner('processing..'):
        print(f'input > {text}')
        input_ids = tokenizer.encode(text).ids
        # Greedy decoding, capped at 128 tokens; <unk> is banned outright and
        # repetition is penalized to curb degenerate loops.
        gen_ids = model.generate(torch.tensor([input_ids]),
                                 max_length=128,
                                 repetition_penalty=2.0,
                                 use_cache=True,
                                 pad_token_id=tokenizer.token_to_id('<pad>'),
                                 eos_token_id=tokenizer.token_to_id('</s>'),
                                 bos_token_id=tokenizer.token_to_id('</s>'),
                                 bad_words_ids=[[tokenizer.token_to_id('<unk>')]])
        generated = tokenizer.decode(gen_ids[0, :].tolist()).strip()
        # If generation stopped mid-sentence, cut back to the last
        # sentence-ending mark.  BUG FIX: the original reversed index scan
        # fell through with i == 0 when NO punctuation was present, so
        # generated[:i+1] truncated the whole output to its first
        # character; now the text is left untouched in that case.
        if generated and generated[-1] not in punct:
            cut = max(generated.rfind(p) for p in punct)
            if cut != -1:
                generated = generated[:cut + 1]
        print(f'KoGPT > {generated}')
        st.write(generated)