Update fredalpaca.py
Browse files- fredalpaca.py +16 -19
fredalpaca.py
CHANGED
@@ -7,35 +7,32 @@ Original file is located at
|
|
7 |
https://colab.research.google.com/drive/1W6DsQPLinVnuJKqhVASYpuVwuHhhtGLc
|
8 |
"""
|
9 |
|
10 |
-
|
11 |
|
12 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
13 |
|
14 |
model_name = "IlyaGusev/fred_t5_ru_turbo_alpaca"
|
15 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
16 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda")
|
17 |
model.eval()
|
18 |
|
19 |
-
inputs = [
|
20 |
-
"Вопрос: Почему трава зеленая?",
|
21 |
-
"Задание: Сочини длинный рассказ, обязательно упоминая следующие объекты.\nДано: Таня, мяч",
|
22 |
-
"Могут ли в природе встретиться в одном месте белый медведь и пингвин? Если нет, то почему?",
|
23 |
-
"Задание: Заполни пропуски в предложении. Дано: Я пытался ____ от маньяка, но он меня настиг",
|
24 |
-
"Как приготовить лазанью?"
|
25 |
-
]
|
26 |
-
|
27 |
-
from transformers import GenerationConfig
|
28 |
-
|
29 |
generation_config = GenerationConfig.from_pretrained(model_name)
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
)[0]
|
38 |
print(tokenizer.decode(data["input_ids"][0].tolist()))
|
39 |
print(tokenizer.decode(output_ids.tolist()))
|
40 |
print("====================")
|
41 |
-
|
|
|
7 |
https://colab.research.google.com/drive/1W6DsQPLinVnuJKqhVASYpuVwuHhhtGLc
|
8 |
"""
|
9 |
|
10 |
+
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import GenerationConfig

# Pick the GPU only when one is available instead of hard-coding "cuda",
# so the app also starts on CPU-only hosts.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Instruction-tuned FRED-T5 model for Russian (seq2seq generation).
model_name = "IlyaGusev/fred_t5_ru_turbo_alpaca"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
model.eval()

# Reuse the generation parameters published with the model checkpoint.
generation_config = GenerationConfig.from_pretrained(model_name)

st.title('Модель')
st.write('')
# Input text
text_in = st.text_input('Введите текст:')
# Button that triggers generation of the result.
start = st.button('Start:')

# BUG FIX: the original `for sample in text_in:` iterated over the
# *characters* of the input string (a leftover from looping over a list of
# prompts), running generation once per character. Process the whole prompt
# once instead. Also require a non-empty prompt: otherwise `data` and
# `output_ids` would be unbound at the st.write() below (NameError).
if start and text_in:
    data = tokenizer(text_in, return_tensors="pt")
    # Move every tensor in the encoding onto the model's device.
    data = {k: v.to(model.device) for k, v in data.items()}
    with torch.inference_mode():  # generation needs no gradients
        output_ids = model.generate(
            **data,
            generation_config=generation_config
        )[0]
    print(tokenizer.decode(data["input_ids"][0].tolist()))
    print(tokenizer.decode(output_ids.tolist()))
    print("====================")
    st.write("Результат:", tokenizer.decode(data["input_ids"][0].tolist()), tokenizer.decode(output_ids.tolist()))