vladyur commited on
Commit
bb72c45
1 Parent(s): a16dba0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -15,7 +15,8 @@ def get_model(model_name, model_path):
15
 
16
 
17
  #@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
18
- def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, max_length=300):
 
19
  input_ids = tokenizer.encode(text, return_tensors="pt")
20
  with torch.no_grad():
21
  out = model.generate(input_ids,
@@ -32,8 +33,8 @@ def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, max_l
32
  model, tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_30epochs_1bs.bin')
33
 
34
  st.title("NeuroKorzh")
35
- # st.markdown("<img width=200px src='https://avatars.yandex.net/get-music-content/2399641/5d26d7e5.p.975699/m1000x1000'>",
36
- # unsafe_allow_html=True)
37
 
38
  st.markdown("\n")
39
 
 
15
 
16
 
17
  #@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
18
+ def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, max_length=200):
19
+ text += '\n'
20
  input_ids = tokenizer.encode(text, return_tensors="pt")
21
  with torch.no_grad():
22
  out = model.generate(input_ids,
 
33
  model, tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_30epochs_1bs.bin')
34
 
35
  st.title("NeuroKorzh")
36
+ st.markdown("<img width=200px src='https://avatars.yandex.net/get-music-content/2399641/5d26d7e5.p.975699/m1000x1000'>",
37
+ unsafe_allow_html=True)
38
 
39
  st.markdown("\n")
40