# VPTQ_demo/app.py
import spaces  # ZeroGPU helper; imported first, before any CUDA-touching libraries

import os
import threading
from collections import deque

import gradio as gr
import plotly.graph_objs as go
import pynvml
from huggingface_hub import snapshot_download

from vptq.app_utils import get_chat_loop_generator
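
# VPTQ-quantized Llama 3.1 8B Instruct checkpoints served by this demo,
# each tagged with its approximate effective bit-width.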
models = [
{
"name": "VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-65536-woft",
"bits": "4 bits"
},
{
"name": "VPTQ-community/Meta-Llama-3.1-8B-Instruct-v8-k65536-256-woft",
"bits": "3 bits"
},
]
# Queues for storing historical data (saving the last 100 GPU utilization and memory usage values)
gpu_util_history = deque(maxlen=100)
mem_usage_history = deque(maxlen=100)
def initialize_nvml():
    """
    Initialize NVML (NVIDIA Management Library); call this before querying
    GPU stats with get_gpu_info().
    """
    pynvml.nvmlInit()
def get_gpu_info():
"""
Get GPU utilization and memory usage information.
Returns:
dict: A dictionary containing GPU utilization and memory usage information.
"""
handle = pynvml.nvmlDeviceGetHandleByIndex(0) # Assuming a single GPU setup
utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
gpu_info = {
'gpu_util': utilization.gpu,
'mem_used': memory.used / 1024**2, # Convert bytes to MiB
'mem_total': memory.total / 1024**2, # Convert bytes to MiB
'mem_percent': (memory.used / memory.total) * 100
}
return gpu_info
def _update_charts(chart_height: int = 200) -> go.Figure:
"""
Update the GPU utilization and memory usage charts.
Args:
chart_height (int, optional): used to set the height of the chart. Defaults to 200.
Returns:
plotly.graph_objs.Figure: The updated figure containing the GPU and memory usage charts.
"""
# obtain GPU information
gpu_info = get_gpu_info()
# records the latest GPU utilization and memory usage values
gpu_util = round(gpu_info.get('gpu_util', 0), 1)
mem_used = round(gpu_info.get('mem_used', 0) / 1024, 2) # Convert MiB to GiB
gpu_util_history.append(gpu_util)
mem_usage_history.append(mem_used)
# create GPU utilization line chart
gpu_trace = go.Scatter(
y=list(gpu_util_history),
mode='lines+markers',
text=list(gpu_util_history),
line=dict(shape='spline', color='blue'), # Make the line smooth and set color
yaxis='y1' # Link to y-axis 1
)
# create memory usage line chart
mem_trace = go.Scatter(
y=list(mem_usage_history),
mode='lines+markers',
text=list(mem_usage_history),
line=dict(shape='spline', color='red'), # Make the line smooth and set color
yaxis='y2' # Link to y-axis 2
)
# set the layout of the chart
layout = go.Layout(
xaxis=dict(title=None, showticklabels=False, ticks=''),
yaxis=dict(
title='GPU Utilization (%)',
range=[-5, 110],
titlefont=dict(color='blue'),
tickfont=dict(color='blue'),
),
        yaxis2=dict(
            title='Memory Usage (GiB)',
            range=[0, max(24, max(mem_usage_history) + 1)],
            titlefont=dict(color='red'),
            tickfont=dict(color='red'),
            overlaying='y',  # Overlay on the primary y-axis
            side='right',
        ),
height=chart_height, # set the height of the chart
margin=dict(l=10, r=10, t=0, b=0), # set the margin of the chart
showlegend=False # disable the legend
)
fig = go.Figure(data=[gpu_trace, mem_trace], layout=layout)
return fig
def initialize_history():
    """
    Pre-fill the GPU utilization and memory usage histories with current readings.
    """
    for _ in range(100):
        gpu_info = get_gpu_info()
        gpu_util_history.append(round(gpu_info.get('gpu_util', 0), 1))
        # Record memory in GiB so the units match what _update_charts appends.
        mem_usage_history.append(round(gpu_info.get('mem_used', 0) / 1024, 2))
def enable_gpu_info():
    """Initialize NVML so GPU stats can be queried."""
    pynvml.nvmlInit()


def disable_gpu_info():
    """Shut down NVML and release its resources."""
    pynvml.nvmlShutdown()
model_choices = [f"{model['name']} ({model['bits']})" for model in models]
display_to_model = {f"{model['name']} ({model['bits']})": model['name'] for model in models}
def download_model(model):
print(f"Downloading {model['name']}...")
snapshot_download(repo_id=model['name'])
def download_models_in_background():
print('Downloading models for the first time...')
for model in models:
download_model(model)
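
# Start downloading all model snapshots in a background thread at import time,
# so the Gradio UI can come up while the weights are still being fetched.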
download_thread = threading.Thread(target=download_models_in_background)
download_thread.start()
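
# Cache of chat generators keyed by model name, so each model's weights are
# loaded only once and reused across requests.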
loaded_models = {}
@spaces.GPU
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
selected_model_display_label,
):
model_name = display_to_model[selected_model_display_label]
# Check if the model is already loaded
if model_name not in loaded_models:
# Load and store the model in the cache
loaded_models[model_name] = get_chat_loop_generator(model_name)
chat_completion = loaded_models[model_name]
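
    # Rebuild the transcript as OpenAI-style message dicts from Gradio's
    # history of (user, assistant) tuples.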
messages = [{"role": "system", "content": system_message}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0]})
if val[1]:
messages.append({"role": "assistant", "content": val[1]})
messages.append({"role": "user", "content": message})
    response = ""
    # Stream tokens as they arrive, yielding the accumulated response so the
    # Gradio UI updates incrementally.
    for token in chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        response += token
        yield response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# GPU monitoring is disabled along with the live chart below: to restore it,
# uncomment enable_gpu_info() here, the gr.Plot block inside the Blocks
# context, and disable_gpu_info() at the bottom of the file.
# enable_gpu_info()
with gr.Blocks(fill_height=True) as demo:
# with gr.Row():
# def update_chart():
# return _update_charts(chart_height=200)
# gpu_chart = gr.Plot(update_chart, every=0.1) # update every 0.1 seconds
with gr.Column():
chat_interface = gr.ChatInterface(
respond,
additional_inputs=[
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
gr.Dropdown(
choices=model_choices,
value=model_choices[0],
label="Select Model",
),
],
)
if __name__ == "__main__":
share = os.getenv("SHARE_LINK", None) in ["1", "true", "True"]
demo.launch(share=share)
# disable_gpu_info()