File size: 2,871 Bytes
7d6f77f
 
 
 
518485c
833aa0a
da9ee26
7d6f77f
 
47d0a4a
7d6f77f
 
724876e
1bc822d
724876e
7d6f77f
724876e
7d6f77f
 
 
 
 
d2e6254
bb72c45
7d6f77f
b86439f
7d6f77f
 
 
 
 
 
d2e6254
1bc822d
7d6f77f
652c2db
 
 
7d6f77f
 
833aa0a
 
7d6f77f
da9ee26
d2e6254
da9ee26
cc7b3e8
7d6f77f
cc7b3e8
652c2db
 
833aa0a
7d6f77f
 
b71c3b1
cc7b3e8
7d6f77f
 
cc7b3e8
 
 
652c2db
cc7b3e8
652c2db
cc7b3e8
652c2db
cc7b3e8
652c2db
4bd4566
cc7b3e8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import transformers
import torch
import tokenizers
import streamlit as st
import re

from PIL import Image


@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
def get_model(model_name, model_path):
    """Build a GPT-2 tokenizer/model pair and load fine-tuned weights.

    Args:
        model_name: Hugging Face hub identifier of the base GPT-2 model.
        model_path: Path to a ``state_dict`` checkpoint produced by fine-tuning.

    Returns:
        Tuple of ``(model, tokenizer)`` with the model in eval mode.
    """
    gpt2_tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
    # The fine-tuned checkpoint was trained with an extra [EOS] token,
    # so the tokenizer (and embedding matrix below) must match its vocab.
    gpt2_tokenizer.add_special_tokens({'eos_token': '[EOS]'})

    gpt2_model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
    gpt2_model.resize_token_embeddings(len(gpt2_tokenizer))

    # Checkpoint is loaded onto CPU so the app runs without a GPU.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    gpt2_model.load_state_dict(state_dict)
    gpt2_model.eval()

    return gpt2_model, gpt2_tokenizer


def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, length_of_generated=300):
    """Generate a continuation of ``text`` with the given model.

    Args:
        text: Prompt (start of a song) to continue.
        model: A causal LM exposing ``generate`` (GPT2LMHeadModel here).
        tokenizer: Matching tokenizer exposing ``encode``/``decode``.
        n_beams: Number of beams for beam-sampling.
        temperature: Sampling temperature; higher means more random output.
        top_p: Nucleus-sampling probability mass.
        length_of_generated: Maximum number of new tokens to generate.

    Returns:
        The decoded text (prompt included) with the '\\n[EOS]\\n' marker removed.
    """
    # The fine-tuned model was trained with a newline separating the
    # prompt from the generated continuation.
    text += '\n'
    input_ids = tokenizer.encode(text, return_tensors="pt")
    length_of_prompt = len(input_ids[0])
    with torch.no_grad():
        out = model.generate(input_ids,
                             do_sample=True,
                             num_beams=n_beams,
                             temperature=temperature,
                             top_p=top_p,
                             max_length=length_of_prompt + length_of_generated,
                             eos_token_id=tokenizer.eos_token_id
                             )

    # Decode only the first returned sequence; the original decoded every
    # sequence in the batch and discarded all but the first.
    generated = tokenizer.decode(out[0])
    return generated.replace('\n[EOS]\n', '')


# Load both fine-tuned checkpoints once; st.cache keeps them across reruns.
medium_model, medium_tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_best_eval_loss.bin')
large_model, large_tokenizer = get_model('sberbank-ai/rugpt3large_based_on_gpt2', 'korzh-large_best_eval_loss.bin')

# st.title("NeuroKorzh")

image = Image.open('korzh.jpg')
st.image(image, caption='НейроКорж')

option = st.selectbox('Выберите своего Коржа', ('Быстрый', 'Глубокий'))
craziness = st.slider(label='Абсурдность', min_value=0, max_value=100, value=50, step=5)
# Map the 0..100 slider onto a sampling temperature in [2.0, 4.0].
temperature = 2 + craziness / 50.

st.markdown("\n")

text = st.text_area(label='Напишите начало песни', value='Что делать, Макс?', height=70)
button = st.button('Старт')

if button:
    try:
        # Bug fix: `result` was previously left undefined when the selectbox
        # returned an unexpected value, so the st.text_area call below raised
        # a NameError that was silently swallowed by the broad handler.
        result = None
        with st.spinner("Пушечка пишется"):
            if option == 'Быстрый':
                result = predict(text, medium_model, medium_tokenizer, temperature=temperature)
            elif option == 'Глубокий':
                result = predict(text, large_model, large_tokenizer, temperature=temperature)
            else:
                st.error('Error in selectbox')

        # Render the generated song only when generation actually ran.
        if result is not None:
            st.text_area(label='', value=result, height=1000)

    except Exception:
        # Top-level boundary: show a friendly message instead of a traceback.
        st.error("Ooooops, something went wrong. Please try again and report to me, tg: @vladyur")