# ds_chat / app.py — Hugging Face Space by beyoru
# Last commit: "Update app.py" (be4a377, verified); file size 3.74 kB
import gradio as gr
from huggingface_hub import InferenceClient
import string
import numpy as np
from transformers import AutoTokenizer
import onnxruntime as ort
import os
# Initialize client and models
# Hugging Face Inference API client used by chatbot() to stream LLM replies.
# The token comes from the HF_TOKEN environment variable (None when unset).
client = InferenceClient(api_key=os.environ.get('HF_TOKEN'))
# Constants for EOU calculation
# Characters removed by normalize_text(): all ASCII punctuation except the
# apostrophe, which is kept so contractions such as "don't" stay intact.
PUNCS = string.punctuation.replace("'", "")
# Number of most recent messages that calculate_eou() scores.
MAX_HISTORY = 4
# Token cap for the turn-detector input; longer prompts are truncated.
MAX_HISTORY_TOKENS = 1024
# chatbot() only generates a reply when the end-of-utterance probability is
# at least this value; below it, it asks the user to keep typing.
EOU_THRESHOLD = 0.5
# Initialize tokenizer and ONNX session
# Model id the turn-detector tokenizer is loaded from.
HG_MODEL = "livekit/turn-detector"
# NOTE(review): opened as a relative path, so this file must be present in the
# working directory (presumably committed next to app.py) — confirm.
ONNX_FILENAME = "model_quantized.onnx"
tokenizer = AutoTokenizer.from_pretrained(HG_MODEL)
# Turn-detector inference session, restricted to the CPU execution provider.
onnx_session = ort.InferenceSession(ONNX_FILENAME, providers=["CPUExecutionProvider"])
# Helper functions for EOU
def softmax(logits, axis=None):
    """Return the softmax of *logits*.

    Args:
        logits: Array-like of raw scores.
        axis: Axis to normalize over. ``None`` (the default) normalizes over
            every element of the array, which is what the 1-D logit vector
            used in this module needs. Pass ``axis=-1`` to normalize each row
            of a batch independently.

    Returns:
        An ndarray with the same shape as *logits* whose entries along *axis*
        sum to 1.
    """
    logits = np.asarray(logits)
    # Shift by the maximum before exponentiating so large logits cannot
    # overflow; keepdims lets the shift broadcast for any axis.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp_logits = np.exp(shifted)
    return exp_logits / np.sum(exp_logits, axis=axis, keepdims=True)
def normalize_text(text):
    """Canonicalize *text* before it is fed to the turn detector.

    Every character listed in ``PUNCS`` is deleted, the remainder is
    lowercased, and all whitespace runs collapse to single spaces with no
    leading or trailing whitespace.
    """
    punctuation_free = text.translate(str.maketrans("", "", PUNCS))
    words = punctuation_free.lower().split()
    return " ".join(words)
def format_chat_ctx(chat_ctx):
    """Render a chat history as the text prompt the turn detector scores.

    Only ``user`` and ``assistant`` messages are kept. Each one's content is
    passed through ``normalize_text``, and messages that normalize to an
    empty string are dropped. The rest is rendered with the tokenizer's chat
    template, and everything from the last ``<|im_end|>`` onward is cut off,
    so the model sees the final utterance without its end-of-turn marker.

    Args:
        chat_ctx: Iterable of message dicts with ``"role"`` and ``"content"``
            keys. The dicts are not modified.

    Returns:
        The rendered conversation text, truncated before its last
        ``<|im_end|>``. If no such marker exists, the full text is returned.
    """
    new_chat_ctx = []
    for msg in chat_ctx:
        if msg["role"] not in ("user", "assistant"):
            continue
        content = normalize_text(msg["content"])
        if content:
            # Build a new dict. Writing into ``msg`` would lowercase and strip
            # punctuation from the caller's history, which is later sent to
            # the LLM as-is.
            new_chat_ctx.append({**msg, "content": content})
    convo_text = tokenizer.apply_chat_template(
        new_chat_ctx, add_generation_prompt=False, add_special_tokens=False, tokenize=False
    )
    # Remove the final end-of-turn marker (and anything after it) so the model
    # is asked whether the last utterance is complete.
    ix = convo_text.rfind("<|im_end|>")
    if ix == -1:
        # rfind() reports "not found" as -1; slicing with it would silently
        # drop the last character of the text.
        return convo_text
    return convo_text[:ix]
def calculate_eou(chat_ctx, session):
    """Estimate the probability that the latest message ends the user's turn.

    The last ``MAX_HISTORY`` messages are rendered with ``format_chat_ctx`` and
    tokenized (truncated to ``MAX_HISTORY_TOKENS`` tokens). They are then run
    through *session*, and the softmax over the logits at the final position
    is read at the token id of ``<|im_end|>``.

    Args:
        chat_ctx: List of message dicts with ``"role"`` and ``"content"`` keys.
        session: ONNX Runtime session with an ``"input_ids"`` input and a
            ``"logits"`` output.

    Returns:
        The end-of-utterance probability.
    """
    eou_id = tokenizer.encode("<|im_end|>")[-1]
    prompt = format_chat_ctx(chat_ctx[-MAX_HISTORY:])
    encoded = tokenizer(
        prompt,
        return_tensors="np",
        truncation=True,
        max_length=MAX_HISTORY_TOKENS,
    )
    token_ids = np.array(encoded["input_ids"], dtype=np.int64)
    logit_batch = session.run(["logits"], {"input_ids": token_ids})[0]
    final_position_logits = logit_batch[0, -1, :]
    return softmax(final_position_logits)[eou_id]
# Conversation history used by chatbot(). It is a single module-level list,
# so everyone using this app process shares the same history.
messages = []


def chatbot(user_input):
    """Stream a reply to *user_input*, or wait if the user seems mid-sentence.

    This is a generator; Gradio displays each yielded string in turn.

    - Typing ``exit`` (case- and surrounding-whitespace-insensitive) clears
      the history and yields a goodbye message.
    - Otherwise the message is appended to ``messages`` and its
      end-of-utterance probability is computed. Below ``EOU_THRESHOLD``, a
      waiting notice is yielded and no reply is generated. The message stays
      in the history, so it is still part of the context scored next time.
    - Otherwise the LLM reply is streamed, yielding the accumulated text
      after each chunk. The full reply is then appended to the history.

    Args:
        user_input: Text from the message box.

    Yields:
        Text to display in the response box.
    """
    global messages

    # Exit condition
    if user_input.strip().lower() == "exit":
        messages = []  # Reset conversation history
        # Inside a generator, ``return <value>`` only sets
        # StopIteration.value, which Gradio never shows, so yield the
        # message instead.
        yield "Chat ended. Refresh the page to start again."
        return

    # Add user message to conversation history
    messages.append({"role": "user", "content": user_input})

    # Decide whether the user has finished their turn before replying.
    eou_prob = calculate_eou(messages, onnx_session)
    if eou_prob < EOU_THRESHOLD:
        yield "[I'm waiting for you to complete the sentence...]"
        return

    # Stream the chatbot's response
    stream = client.chat.completions.create(
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        messages=messages,
        temperature=0.6,
        max_tokens=1024,
        top_p=0.95,
        stream=True
    )
    bot_response = ""
    for chunk in stream:
        # In the OpenAI-compatible streaming format, a chunk may carry no
        # choices and ``delta.content`` may be None. Skip those chunks rather
        # than failing on ``+=``.
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if not piece:
            continue
        bot_response += piece
        yield bot_response

    # Add final bot response to conversation history
    messages.append({"role": "assistant", "content": bot_response})
# Create Gradio interface
# NOTE(review): 'darkdefault' is a legacy Gradio theme name; newer Gradio
# releases may not recognize it and fall back to the default theme — confirm
# against the Gradio version this Space pins.
with gr.Blocks(theme='darkdefault') as demo:
    gr.Markdown("""# Chat with DeepSeek-R1
Type your message below to interact with the chatbot. Type "exit" to end the conversation.
""")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            submit_button = gr.Button("Send")
        with gr.Column():
            chat_output = gr.Textbox(label="Chatbot Response", interactive=False)
    # Define interactions: the Send button and pressing Enter in the message
    # box both run the same streaming handler.
    submit_button.click(chatbot, inputs=[user_input], outputs=[chat_output])
    user_input.submit(chatbot, inputs=[user_input], outputs=[chat_output])

# chatbot() is a generator (it streams partial replies). Gradio 3.x only runs
# generator handlers with the queue enabled; newer versions enable it by
# default, so calling queue() is harmless there.
demo.queue()

# Launch the app only when run as a script, so importing this module
# (e.g. under Gradio's reload mode) does not start a second server.
if __name__ == "__main__":
    demo.launch()