import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the tokenizer and model
model_name = "GoidaAlignment/GOIDA-0.5B"  # Replace with your model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
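
# On memory-constrained hardware, the model could instead be loaded in half
# precision, roughly halving its footprint (a sketch, assuming a CUDA device;
# torch_dtype is a standard from_pretrained argument):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, torch_dtype=torch.float16
#   ).to("cuda")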

# Template function for formatting the dialogue as a plain-text prompt
def apply_chat_template(chat, add_generation_prompt=True):
    formatted_chat = ""
    for message in chat:
        role = message["role"]
        content = message["content"]
        if role == "user":
            formatted_chat += f"User: {content}\n"
        elif role == "assistant":
            formatted_chat += f"Assistant: {content}\n"
    if add_generation_prompt:
        formatted_chat += "Assistant: "
    return formatted_chat
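
# If the checkpoint ships its own chat template (an assumption about
# GOIDA-0.5B; many instruction-tuned models do), the manual formatter above
# could be replaced with the tokenizer's built-in helper:
#
#   formatted_chat = tokenizer.apply_chat_template(
#       chat, tokenize=False, add_generation_prompt=True
#   )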

# Response generation function
def generate_response(user_input, chat_history):
    chat_history.append({"role": "user", "content": user_input})
    formatted_chat = apply_chat_template(chat_history, add_generation_prompt=True)

    # Tokenize the full prompt
    inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False)
    inputs = {key: tensor.to(model.device) for key, tensor in inputs.items()}

    # Generate a continuation
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # avoids a warning on models without a pad token
    )

    # Decode only the newly generated tokens, not the prompt
    decoded_output = tokenizer.decode(outputs[0][inputs["input_ids"].size(1):], skip_special_tokens=True)
    chat_history.append({"role": "assistant", "content": decoded_output})
    # Return the history twice: once for the Chatbot display, once for the State
    return chat_history, chat_history
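
# A streaming variant is possible with transformers' TextIteratorStreamer:
# generate() runs in a background thread while the streamer yields text chunks
# as they are produced (a sketch; `inputs` as prepared above):
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate, kwargs={**inputs, "max_new_tokens": 64, "streamer": streamer}).start()
#   for chunk in streamer:
#       print(chunk, end="", flush=True)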

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Chatbot powered by the GOIDA-0.5B model\nInteract with the language model.")
    chatbot = gr.Chatbot(type="messages")  # accepts {"role", "content"} dicts; needs a recent Gradio
    user_input = gr.Textbox(placeholder="Type your message...")
    clear = gr.Button("Clear chat")
    chat_history = gr.State([])  # State holding the chat history

    user_input.submit(
        generate_response,
        [user_input, chat_history],
        [chatbot, chat_history]
    ).then(lambda: "", None, user_input)  # clear the textbox after sending
    clear.click(lambda: ([], []), None, [chatbot, chat_history])

if __name__ == "__main__":
    demo.launch()
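
# When running locally, demo.launch(share=True) would also expose a temporary
# public URL; on Hugging Face Spaces the plain launch() above is sufficient.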