vasevooo commited on
Commit
5d53558
1 Parent(s): e83b857

Update pages/gpt.py

Browse files
Files changed (1) hide show
  1. pages/gpt.py +40 -38
pages/gpt.py CHANGED
@@ -1,46 +1,48 @@
1
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
- import torch
3
  import streamlit as st
4
-
 
 
 
 
 
 
 
5
  model = GPT2LMHeadModel.from_pretrained(
6
  'sberbank-ai/rugpt3small_based_on_gpt2',
7
  output_attentions = False,
8
  output_hidden_states = False,
9
  )
10
-
11
- tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
12
-
13
-
14
  # Вешаем сохраненные веса на нашу модель
15
- model.load_state_dict(torch.load('models/model.pt', map_location=torch.device('cpu')))
16
-
17
- prompt = st.text_input('Введите текст prompt:')
18
- length = st.slider('Длина генерируемой последовательности:', 10, 1000, 50)
19
- num_samples = st.slider('Число генераций:', 1, 10, 1)
20
- temperature = st.slider('Температура:', 0.1, 1.0, 0.5)
21
-
22
-
23
-
24
- def generate_text(model, tokenizer, prompt, length, num_samples, temperature):
25
- input_ids = tokenizer.encode(prompt, return_tensors='pt')
26
- output_sequences = model.generate(
27
- input_ids=input_ids,
28
- max_length=length,
29
- num_return_sequences=num_samples,
30
- temperature=temperature
31
- )
32
-
33
- generated_texts = []
34
- for output_sequence in output_sequences:
35
- generated_text = tokenizer.decode(output_sequence, clean_up_tokenization_spaces=True)
36
- generated_texts.append(generated_text)
37
-
38
- return generated_texts
39
-
40
-
41
- if st.button('Сгенерировать текст'):
42
- generated_texts = generate_text(model, tokenizer, prompt, length, num_samples, temperature)
43
- for i, text in enumerate(generated_texts):
44
- st.write(f'Текст {i+1}:')
45
- st.write(text)
46
-
 
1
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
2
  import streamlit as st
3
+ import torch
4
+ import textwrap
5
+ import plotly.express as px
6
+
7
+
8
+
9
+
10
+ tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
11
  model = GPT2LMHeadModel.from_pretrained(
12
  'sberbank-ai/rugpt3small_based_on_gpt2',
13
  output_attentions = False,
14
  output_hidden_states = False,
15
  )
 
 
 
 
16
  # Вешаем сохраненные веса на нашу модель
17
+ model.load_state_dict(torch.load('models/modelgpt.pt', map_location=torch.device('cpu')))
18
+
19
+
20
+ length = st.sidebar.slider('**Длина генерируемой последовательности:**', 8, 256, 15)
21
+ num_samples = st.sidebar.slider('**Число генераций:**', 1, 10, 1)
22
+ temperature = st.sidebar.slider('**Температура:**', 1.0, 10.0, 2.0)
23
+ top_k = st.sidebar.slider('**Количество наиболее вероятных слов генерации:**', 10, 200, 50)
24
+ top_p = st.sidebar.slider('**Минимальная суммарная вероятность топовых слов:**', 0.4, 1.0, 0.9)
25
+
26
+
27
+ prompt = st.text_input('**Введите текст 👇:**')
28
+ if st.button('**Сгенерировать текст**'):
29
+
30
+ with torch.inference_mode():
31
+ prompt = tokenizer.encode(prompt, return_tensors='pt')
32
+ out = model.generate(
33
+ input_ids=prompt,
34
+ max_length=length,
35
+ num_beams=8,
36
+ do_sample=True,
37
+ temperature=temperature,
38
+ top_k=top_k,
39
+ top_p=top_p,
40
+ no_repeat_ngram_size=3,
41
+ num_return_sequences=num_samples,
42
+ ).cpu().numpy()
43
+ st.write('**_Результат_** 👇')
44
+ for i, out_ in enumerate(out):
45
+
46
+ with st.expander(f'Текст {i+1}:'):
47
+ st.write(textwrap.fill(tokenizer.decode(out_), 100))
48
+ st.image("pict/wow.png")