import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import spaces
import time
import subprocess
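# Note: `spaces` is the Hugging Face ZeroGPU helper package; its @spaces.GPU decorator
# (used on the chat function below) requests GPU time for each decorated call.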
MIN_TOKENS = 128
MAX_TOKENS = 8192
DEFAULT_TOKENS = 2048
DURATION = 60
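# MIN_TOKENS/MAX_TOKENS bound the "Max new tokens" slider, DEFAULT_TOKENS is its initial
# value, and DURATION is the GPU allocation window in seconds passed to @spaces.GPU.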
# Install flash-attn at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD avoids compiling the CUDA
# extension from source, and the current environment is merged in so pip stays on PATH.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
# Load model and tokenizer once when the app starts
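# HF_TOKEN is read from the environment (set it as a Space secret); a missing token
# raises KeyError at startup rather than failing later during the download.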
model_token = os.environ["HF_TOKEN"]
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    token=model_token,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", token=model_token)
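# The tokenizer comes from the same checkpoint so that apply_chat_template below matches
# the prompt format the model was trained with.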
# Set device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define an error-handling helper that turns exceptions into a user-facing message string
def handle_error(error):
    return f"Error: {error}"
# Define the chat function with input validation and error handling
@spaces.GPU(duration=DURATION)
def chat(message, history, temperature, do_sample, max_tokens):
    try:
        # Validate input
        if not message:
            raise ValueError("Please enter a message")
        if temperature < 0 or temperature > 1:
            raise ValueError("Temperature must be between 0 and 1")
        if max_tokens < MIN_TOKENS or max_tokens > MAX_TOKENS:
            raise ValueError(f"Max tokens must be between {MIN_TOKENS} and {MAX_TOKENS}")
        # Rebuild the conversation from the chat history
        conversation = []
        for item in history:
            conversation.append({"role": "user", "content": item[0]})
            if item[1] is not None:
                conversation.append({"role": "assistant", "content": item[1]})
        conversation.append({"role": "user", "content": message})
        # Format the conversation with the chat template and tokenize it
        messages = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        model_inputs = tokenizer([messages], return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            model_inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=do_sample,
            temperature=temperature,
            eos_token_id=[tokenizer.eos_token_id],
        )
        # Run generation in a background thread so tokens can be streamed as they arrive
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
        # Yield partial responses as the streamer produces new text
        partial_text = ""
        for new_text in streamer:
            partial_text += new_text
            yield partial_text
        # Yield the final response
        yield partial_text
    except Exception as e:
        yield handle_error(e)
# Create Gradio interface
demo = gr.ChatInterface(
    fn=chat,
    examples=[["Write me a poem about Machine Learning."]],
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
        ),
        gr.Checkbox(label="Sampling", value=True, render=False),
        gr.Slider(
            minimum=MIN_TOKENS,
            maximum=MAX_TOKENS,
            step=1,
            value=DEFAULT_TOKENS,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [microsoft/Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct)",
)
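# Streaming (generator) responses rely on Gradio's request queue, which recent Gradio
# versions enable by default when the app is launched.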
# Launch Gradio app
demo.launch()