# Artic-Intell / app.py
import gradio as gr
import plotly.express as px
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Configure CUDA memory behaviour: allow expandable allocator segments and cap
# this process at 80% of GPU memory (adjust the fraction as needed)
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.8)
# Device to load the model onto (device_map="auto" below handles actual placement)
device = "cuda"
# System message (placeholder, adjust as needed)
system_message = ""
# Load the model and tokenizer
def hermes_model():
    tokenizer = AutoTokenizer.from_pretrained("TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ")
    model = AutoModelForCausalLM.from_pretrained(
        "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ", low_cpu_mem_usage=True, device_map="auto"
    )
    return model, tokenizer
model, tokenizer = hermes_model()
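# Note (assumption about the runtime environment): AWQ checkpoints run on CUDA
# only and, with recent transformers releases, generally need the autoawq
# package installed; neither requirement is pinned in this file.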
# Function to generate a response from the model
def chat_response(msg_prompt: str) -> str:
    """
    Generates a response from the model given a prompt.

    Args:
        msg_prompt (str): The user's message prompt.

    Returns:
        str: The model's response.
    """
    generation_params = {
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 40,
        "max_new_tokens": 512,
        "repetition_penalty": 1.1,
    }
    # The pipeline is rebuilt on every call; it could be created once at module
    # level to avoid the per-request overhead.
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **generation_params)
    try:
        # Simplified role-tagged prompt; the upstream model card documents a
        # ChatML-style template, which this plain-text variant approximates.
        prompt_template = f'''system
{system_message}
user
{msg_prompt}
assistant
'''
        pipe_output = pipe(prompt_template)[0]['generated_text']
        # Separate the assistant's response from the echoed prompt
        response_lines = pipe_output.split('assistant')
        assistant_response = response_lines[-1].strip() if len(response_lines) > 1 else pipe_output.strip()
        return assistant_response
    except Exception as e:
        return str(e)
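# Minimal sketch of calling the generator outside the Gradio UI (hypothetical
# prompt text, useful for quick debugging):
#   print(chat_response("Summarize the Iris dataset in one sentence."))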
# Function to generate a random plot
def random_plot():
    df = px.data.iris()
    fig = px.scatter(df, x="sepal_width", y="sepal_length", color="species",
                     size='petal_length', hover_data=['petal_width'])
    return fig
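# Note: the figure built by random_plot() (and the module-level `fig` below) is
# not attached to any component; displaying it would require something like a
# gr.Plot(fig) inside the Blocks layout.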
# Function to handle likes/dislikes (for demonstration purposes)
def print_like_dislike(x: gr.LikeData):
    print(x.index, x.value, x.liked)
# Function to add messages to the chat history
def add_message(history, message):
    for x in message["files"]:
        history.append(((x,), None))
    if message["text"] is not None:
        history.append((message["text"], None))
    return history, gr.update(value=None, interactive=True)
# Function to generate the bot's reply to the latest user message
def bot(history):
    user_message = history[-1][0]
    bot_response = chat_response(user_message)
    history[-1][1] = bot_response
    return history
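# Caveat: bot() assumes the last history entry is a text message; a file-only
# upload leaves a file tuple in that slot, which would be passed to
# chat_response() verbatim.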
fig = random_plot()
# Gradio interface setup
with gr.Blocks(fill_height=True) as demo:
    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, scale=1)
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_count="multiple",
        placeholder="Enter message or upload file...",
        show_label=False
    )
    chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
    bot_msg = chat_msg.then(bot, chatbot, chatbot)
    bot_msg.then(lambda: gr.update(interactive=True), None, [chat_input])
    chatbot.like(print_like_dislike, None, None)
demo.queue()  # Enable request queuing before launching the app
demo.launch()