Makaria commited on
Commit
5138869
·
1 Parent(s): c5a72a1
Files changed (1) hide show
  1. app.py +32 -11
app.py CHANGED
@@ -1,29 +1,50 @@
1
  import os
2
  import gradio as gr
3
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
4
  import torch
5
 
6
  # Импортируем токены из переменных окружения
7
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
8
 
9
- # Загрузка модели и токенизатора с использованием токена
10
- model_name = "gpt2" # Или другой, если нужно
11
- tokenizer = GPT2Tokenizer.from_pretrained(model_name, use_auth_token=HUGGINGFACE_TOKEN)
12
- model = GPT2LMHeadModel.from_pretrained(model_name, use_auth_token=HUGGINGFACE_TOKEN)
13
 
14
  # Функция для ведения диалога
15
- def chat_with_model(user_input):
16
- input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
17
- chat_history_ids = model.generate(input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
18
- bot_response = tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
19
- return bot_response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  # Создание интерфейса Gradio
22
  iface = gr.Interface(
23
  fn=chat_with_model,
24
  inputs="text",
25
  outputs="text",
26
- title="Чатбот на GPT-2",
27
  description="Поболтай со своим чатботом!"
28
  )
29
 
 
1
  import os
2
  import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
6
  # Импортируем токены из переменных окружения
7
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
8
 
9
+ # Загрузка модели и токенизатора DialoGPT
10
+ model_name = "microsoft/DialoGPT-medium" # Можно использовать small или large версии
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_TOKEN)
12
+ model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_TOKEN)
13
 
14
  # Функция для ведения диалога
15
+ def chat_with_model(user_input, chat_history=[]):
16
+ # Кодируем входное сообщение и добавляем к истории
17
+ new_user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')
18
+
19
+ # Объединяем новую информацию с историей
20
+ bot_input_ids = torch.cat([torch.tensor(chat_history), new_user_input_ids], dim=-1) if chat_history else new_user_input_ids
21
+
22
+ # Генерируем ответ
23
+ chat_history_ids = model.generate(
24
+ bot_input_ids,
25
+ max_length=100,
26
+ num_return_sequences=1,
27
+ pad_token_id=tokenizer.eos_token_id,
28
+ temperature=0.7,
29
+ top_k=50,
30
+ top_p=0.95
31
+ )
32
+
33
+ # Декодируем ответ
34
+ bot_response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
35
+
36
+ # Обновляем историю
37
+ chat_history.append(new_user_input_ids)
38
+ chat_history.append(chat_history_ids[:, bot_input_ids.shape[-1]:])
39
+
40
+ return bot_response, chat_history
41
 
42
  # Создание интерфейса Gradio
43
  iface = gr.Interface(
44
  fn=chat_with_model,
45
  inputs="text",
46
  outputs="text",
47
+ title="Чатбот на DialoGPT",
48
  description="Поболтай со своим чатботом!"
49
  )
50