Spaces:
Runtime error
Runtime error
File size: 2,871 Bytes
7d6f77f 518485c 833aa0a da9ee26 7d6f77f 47d0a4a 7d6f77f 724876e 1bc822d 724876e 7d6f77f 724876e 7d6f77f d2e6254 bb72c45 7d6f77f b86439f 7d6f77f d2e6254 1bc822d 7d6f77f 652c2db 7d6f77f 833aa0a 7d6f77f da9ee26 d2e6254 da9ee26 cc7b3e8 7d6f77f cc7b3e8 652c2db 833aa0a 7d6f77f b71c3b1 cc7b3e8 7d6f77f cc7b3e8 652c2db cc7b3e8 652c2db cc7b3e8 652c2db cc7b3e8 652c2db 4bd4566 cc7b3e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import transformers
import torch
import tokenizers
import streamlit as st
import re
from PIL import Image
# ``st.cache`` (and its ``hash_funcs``/``allow_output_mutation``/``suppress_st_warning``
# options) is deprecated since Streamlit 1.18; ``st.cache_resource`` is the
# replacement for long-lived, unhashable objects such as models and tokenizers.
# Prefer it when available and fall back to the original legacy decorator on
# older Streamlit releases.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(
        hash_funcs={
            tokenizers.Tokenizer: lambda _: None,
            tokenizers.AddedToken: lambda _: None,
            re.Pattern: lambda _: None,
        },
        allow_output_mutation=True,
        suppress_st_warning=True,
    )


@_cache_model
def get_model(model_name, model_path):
    """Load a GPT-2-style base model plus tokenizer and apply fine-tuned weights.

    The tokenizer for ``model_name`` gets ``'[EOS]'`` registered as its EOS
    token. The model's embedding matrix is then resized to the enlarged
    vocabulary before the weights from ``model_path`` are loaded. Presumably
    the checkpoint was saved from a model resized the same way (TODO confirm);
    otherwise ``load_state_dict`` would fail on a shape mismatch.

    Args:
        model_name: Hugging Face hub id (or local path) of the base model and
            tokenizer.
        model_path: file readable by ``torch.load`` that holds a state dict
            for that model.

    Returns:
        ``(model, tokenizer)``; the loaded weights are mapped to CPU and the
        model is switched to eval mode.
    """
    tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
    tokenizer.add_special_tokens({
        'eos_token': '[EOS]'
    })
    model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
    # The new special token enlarged the vocabulary; resize the embeddings so
    # the fine-tuned state dict's shapes match.
    model.resize_token_embeddings(len(tokenizer))
    # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model, tokenizer
def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, length_of_generated=300):
    """Continue ``text`` with ``model`` and return prompt plus continuation.

    A newline is appended to the prompt before encoding. Generation uses
    sampling combined with beam search and stops at the tokenizer's EOS token
    or after ``length_of_generated`` new tokens, whichever comes first.

    Args:
        text: the prompt (start of the song).
        model: a causal LM exposing ``generate`` (e.g. from ``get_model``).
        tokenizer: the matching tokenizer; its ``eos_token_id`` ends generation.
        n_beams: number of beams.
        temperature: sampling temperature.
        top_p: nucleus-sampling probability mass.
        length_of_generated: maximum number of tokens to add after the prompt.

    Returns:
        The decoded first generated sequence, including the prompt and the
        appended newline, with special tokens (including the EOS marker)
        removed.
    """
    text += '\n'
    input_ids = tokenizer.encode(text, return_tensors="pt")
    length_of_prompt = len(input_ids[0])
    with torch.no_grad():
        out = model.generate(
            input_ids,
            do_sample=True,
            num_beams=n_beams,
            temperature=temperature,
            top_p=top_p,
            # max_length counts the prompt too, so add it to the budget.
            max_length=length_of_prompt + length_of_generated,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Only one sequence is requested, so decode just the first row. Generation
    # halts right after EOS, so the marker ends the output. A plain
    # '\n[EOS]\n' replace never matches it, but skipping special tokens does.
    return tokenizer.decode(out[0], skip_special_tokens=True)
import logging

# Both fine-tuned checkpoints are loaded at start-up; get_model is cached, so
# Streamlit reruns reuse the loaded objects.
medium_model, medium_tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_best_eval_loss.bin')
large_model, large_tokenizer = get_model('sberbank-ai/rugpt3large_based_on_gpt2', 'korzh-large_best_eval_loss.bin')

# Selectbox label -> (model, tokenizer). "Быстрый" ("fast") is the medium model
# and "Глубокий" ("deep") the large one. The keys also serve as the selectbox
# options, in this order.
MODELS = {
    'Быстрый': (medium_model, medium_tokenizer),
    'Глубокий': (large_model, large_tokenizer),
}

image = Image.open('korzh.jpg')
st.image(image, caption='НейроКорж')
option = st.selectbox('Выберите своего Коржа', tuple(MODELS))
craziness = st.slider(label='Абсурдность', min_value=0, max_value=100, value=50, step=5)
# Map the 0..100 "absurdity" slider onto a sampling temperature in [2.0, 4.0].
temperature = 2 + craziness / 50.
st.markdown("\n")
text = st.text_area(label='Напишите начало песни', value='Что делать, Макс?', height=70)
button = st.button('Старт')
if button:
    if option not in MODELS:
        # Previously this fell through to use an undefined ``result``.
        st.error('Error in selectbox')
    else:
        model, tokenizer = MODELS[option]
        try:
            with st.spinner("Пушечка пишется"):
                result = predict(text, model, tokenizer, temperature=temperature)
        except Exception:
            # Top-level UI boundary: keep the app alive, but record the
            # traceback in the server log instead of discarding it.
            logging.getLogger(__name__).exception("Song generation failed")
            st.error("Ooooops, something went wrong. Please try again and report to me, tg: @vladyur")
        else:
            st.text_area(label='', value=result, height=1000)