import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import spaces
import subprocess

# Token limits for the "Max new tokens" slider, and the per-call GPU
# allocation (in seconds) requested via @spaces.GPU below
MIN_TOKENS = 128
MAX_TOKENS = 8192
DEFAULT_TOKENS = 2048
DURATION = 60
# Install flash attention (the env var skips the CUDA build so the install is fast)
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
# Load model and tokenizer once when the app starts
model_token = os.environ["HF_TOKEN"]
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    token=model_token,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", token=model_token)

# Set device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
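# Note: flash-attn is installed above but not explicitly enabled in from_pretrained.
# One way to use it (an assumption, not part of the original code) would be to pass
# attn_implementation="flash_attention_2" together with torch_dtype=torch.bfloat16
# to from_pretrained; the call above is left unchanged to preserve the original behavior.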
# Format an exception as a plain-text message the chat UI can display
def handle_error(error):
    return f"Error: {error}"
# Define chat function with input validation and error handling
@spaces.GPU(duration=DURATION)
def chat(message, history, temperature, do_sample, max_tokens):
    try:
        # Validate input
        if not message:
            raise ValueError("Please enter a message")
        if temperature < 0 or temperature > 1:
            raise ValueError("Temperature must be between 0 and 1")
        if max_tokens < MIN_TOKENS or max_tokens > MAX_TOKENS:
            raise ValueError(f"Max tokens must be between {MIN_TOKENS} and {MAX_TOKENS}")

        # Rebuild the conversation from the Gradio history, then append the new message
        conversation = []
        for item in history:
            conversation.append({"role": "user", "content": item[0]})
            if item[1] is not None:
                conversation.append({"role": "assistant", "content": item[1]})
        conversation.append({"role": "user", "content": message})

        # Apply the model's chat template and tokenize the prompt
        messages = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        model_inputs = tokenizer([messages], return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            model_inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            temperature=temperature,
            eos_token_id=[tokenizer.eos_token_id],
        )

        # Run generation in a background thread so tokens can be streamed as they arrive
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        # Yield partial responses as the streamer produces new text
        partial_text = ""
        for new_text in streamer:
            partial_text += new_text
            yield partial_text

        # Yield the final response once streaming is complete
        yield partial_text
    except Exception as e:
        yield handle_error(e)
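# Local sanity check (a sketch, assuming this file is run directly with a GPU
# available and HF_TOKEN set; not part of the Space's normal flow):
#   for chunk in chat("Hello!", [], temperature=0.7, do_sample=True, max_tokens=256):
#       print(chunk)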
# Create Gradio interface
demo = gr.ChatInterface(
    fn=chat,
    examples=[["Write me a poem about Machine Learning."]],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
        ),
        gr.Checkbox(label="Sampling", value=True, render=False),
        gr.Slider(
            minimum=MIN_TOKENS,
            maximum=MAX_TOKENS,
            step=1,
            value=DEFAULT_TOKENS,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [microsoft/Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct)",
)
# Launch Gradio app
demo.launch()