# Hugging Face Space: running on ZeroGPU ("Running on Zero").
import os
import subprocess

# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the
# CUDA extension; the existing environment is merged in so pip stays on PATH.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
import torch
import spaces
import gradio as gr
import flash_attn  # imported to confirm the install above succeeded
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList,
)
MODEL_ID = "unsloth/QwQ-32B-unsloth-bnb-4bit"

DEFAULT_SYSTEM_PROMPT = """
Think step by step and explain your reasoning clearly. Break down the problem into logical
components, verify each step, and ensure consistency before arriving at the final answer.
If there are multiple possible solutions, consider each one before selecting the best answer.
Use intermediate calculations and justify each step before proceeding.
If relevant, include real-world analogies to improve clarity.
"""
CSS = """
.gr-chatbot { min-height: 500px; border-radius: 15px; }
.special-tag { color: #2ecc71; font-weight: 600; }
footer { display: none !important; }
"""
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop when the EOS token is generated (uses the globally initialized tokenizer).
        return bool(input_ids[0][-1] == tokenizer.eos_token_id)
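

# If the model emits several possible end-of-turn markers, the same idea extends to a list
# of stop token ids. This variant is only a sketch and is not wired into the app; the ids
# would have to be looked up with tokenizer.convert_tokens_to_ids first.
class StopOnAnyToken(StoppingCriteria):
    def __init__(self, stop_token_ids):
        self.stop_token_ids = set(stop_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the last generated token matches any configured stop id.
        return int(input_ids[0][-1]) in self.stop_token_ids
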
def initialize_model():
    # 4-bit quantization settings. The checkpoint below is already a bitsandbytes 4-bit
    # quantization, so this config is kept for reference but not passed to from_pretrained.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",  # device_map already places the weights on the GPU
        # quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
    )
    model.eval()  # evaluation mode (disables dropout); inference_mode() handles gradients later
    return model, tokenizer
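

# For comparison: if a full-precision checkpoint were used instead of the pre-quantized one,
# the BitsAndBytesConfig would be passed explicitly, roughly as sketched below. The model_id
# argument is a placeholder; this helper is illustrative and never called by the app.
def load_quantized_from_full_precision(model_id):
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    # bitsandbytes quantizes the weights on the fly while loading.
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cuda",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
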
def format_response(text):
    # Replace the model's bracketed tags with styled HTML for the chatbot.
    replacements = [
        ("[Understand]", '\n<strong class="special-tag">[Understand]</strong>\n'),
        ("[think]", '\n<strong class="special-tag">[think]</strong>\n'),
        ("[/think]", '\n<strong class="special-tag">[/think]</strong>\n'),
        ("[Answer]", '\n<strong class="special-tag">[Answer]</strong>\n'),
        ("[/Answer]", '\n<strong class="special-tag">[/Answer]</strong>\n'),
    ]
    for old, new in replacements:
        text = text.replace(old, new)
    return text
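

# Quick illustration (the sample string is made up, not actual model output):
#   format_response("[think]2 + 2 = 4[/think][Answer]4[/Answer]")
# returns the same text with every tag wrapped in <strong class="special-tag">...</strong>,
# which the CSS above renders in green inside the chatbot.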
# --- New helper: Llama-3-style conversation template ---
def apply_llama3_chat_template(conversation, add_generation_prompt=True):
    """
    Convert the conversation (a list of dicts with 'role' and 'content')
    into a single prompt string using Llama-3-style role tags.
    """
    prompt = ""
    for msg in conversation:
        role = msg["role"].upper()
        if role == "SYSTEM":
            prompt += "<|SYSTEM|>\n" + msg["content"].strip() + "\n"
        elif role == "USER":
            prompt += "<|USER|>\n" + msg["content"].strip() + "\n"
        elif role == "ASSISTANT":
            prompt += "<|ASSISTANT|>\n" + msg["content"].strip() + "<think>\n"
    if add_generation_prompt:
        prompt += "<|ASSISTANT|>\n"
    return prompt
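

# Note: recent transformers tokenizers ship a chat template with the checkpoint, so the
# prompt could also be built with the built-in helper below. This sketch is not used by
# the app; it only shows the library call for comparison.
def apply_builtin_chat_template(tokenizer, conversation, add_generation_prompt=True):
    # tokenize=False returns the formatted prompt string instead of token ids.
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )
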
@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of each generation call
def generate_response(message, chat_history, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty):
    # Build the conversation history.
    conversation = [{"role": "system", "content": system_prompt}]
    for user_msg, bot_msg in chat_history:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": bot_msg})
    conversation.append({"role": "user", "content": message})

    # Build the prompt with the Llama-3-style template and tokenize it.
    prompt = apply_llama3_chat_template(conversation, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Streamer yields new tokens as they are generated; skip_prompt keeps the
    # input prompt out of the streamed output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Generation parameters, including the sampling options exposed in the UI.
    # The sliders may return floats, so integer-valued settings are cast explicitly.
    generate_kwargs = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "streamer": streamer,
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": int(top_k),
        "repetition_penalty": repetition_penalty,
        "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
    }

    # Run generation in a background thread inside inference_mode for speed.
    def generate_inference():
        with torch.inference_mode():
            model.generate(**generate_kwargs)

    Thread(target=generate_inference, daemon=True).start()

    # Stream the output tokens into the chat history.
    partial_message = ""
    new_history = chat_history + [(message, "")]
    for new_token in streamer:
        partial_message += new_token
        formatted = format_response(partial_message)
        new_history[-1] = (message, formatted + "▌")
        yield new_history

    # Final update without the cursor.
    new_history[-1] = (message, format_response(partial_message))
    yield new_history
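

# For debugging outside the UI, the model can also be called without streaming. This helper
# is a sketch only (it is never called by the app) and reuses the globals set up below.
def generate_once(prompt_text, max_new_tokens=256):
    conversation = [
        {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
        {"role": "user", "content": prompt_text},
    ]
    prompt = apply_llama3_chat_template(conversation, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
        )
    # Decode only the newly generated tokens, not the prompt.
    return tokenizer.decode(output_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
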
# Initialize the model and tokenizer globally.
model, tokenizer = initialize_model()
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    <h1 align="center">🧠 AI Reasoning Assistant</h1>
    <p align="center">Ask me hard questions and see the reasoning unfold.</p>
    """)

    chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
    msg = gr.Textbox(label="Your Question", placeholder="Type your question...")

    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
        temperature = gr.Slider(0, 1, value=0.6, label="Creativity (Temperature)")
        max_tokens = gr.Slider(128, 32768, value=32768, label="Max Response Length")
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top P (Nucleus Sampling)")
        top_k = gr.Slider(0, 100, value=35, label="Top K")
        repetition_penalty = gr.Slider(0.5, 2.0, value=1.1, label="Repetition Penalty")

    clear = gr.Button("Clear History")
    # Link the input textbox with the generation function.
    msg.submit(
        generate_response,
        [msg, chatbot, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty],
        chatbot,
        show_progress=True,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    demo.queue().launch()