Spaces:
Runtime error
Runtime error
import os | |
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Hugging Face access token taken from the HF_TOKEN environment variable
# (None when the variable is unset).
access_token = os.environ.get("HF_TOKEN")

# Fine-tuned checkpoint served by this app.
repo_id = "Mikhil-jivus/Llama-32-3B-FineTuned"

# Tokenizer and weights are fetched from the same repository with the same token.
tokenizer = AutoTokenizer.from_pretrained(repo_id, token=access_token)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    device_map="auto",           # let transformers decide placement over the available GPU/CPU
    torch_dtype=torch.bfloat16,  # load weights in bfloat16
    token=access_token,
)
# Define a function to clean up any repeated segments in the generated response | |
def clean_response(response, history):
    """Remove the previous assistant reply if the model echoed it back.

    Args:
        response: Newly generated reply text.
        history: Conversation so far, either as ``(user, assistant)`` pairs or
            as Gradio ``{"role": ..., "content": ...}`` message dicts. May be
            empty or ``None``.

    Returns:
        ``response`` with every occurrence of the most recent assistant reply
        removed and surrounding whitespace stripped. It is returned unchanged
        when there is no usable previous reply or that reply does not occur in
        it.
    """
    if not history:
        return response

    last_turn = history[-1]
    if isinstance(last_turn, dict):
        # "messages" format: each entry is a single message, so look back for
        # the most recent assistant message instead of unpacking a pair
        # (unpacking a dict would yield its keys, e.g. "role"/"content").
        last_bot_response = next(
            (
                entry.get("content")
                for entry in reversed(history)
                if isinstance(entry, dict) and entry.get("role") == "assistant"
            ),
            None,
        )
    else:
        _, last_bot_response = last_turn

    # Guard against a pending/missing reply (None) or an empty string, which
    # would raise TypeError or trivially "match" every response.
    if (
        isinstance(last_bot_response, str)
        and last_bot_response
        and last_bot_response in response
    ):
        response = response.replace(last_bot_response, "").strip()
    return response
def _build_prompt(message, history, system_message):
    """Render the whole conversation as the plain-text prompt fed to the model.

    Layout: an optional ``system:`` line, prior turns oldest-first, the new
    ``user:`` turn, and a trailing ``assistant:`` cue for the model to continue.
    """
    lines = []
    if system_message:
        lines.append(f"system: {system_message}")

    for turn in history or []:
        if isinstance(turn, dict):
            # Gradio "messages" format: one message per entry.
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                lines.append(f"{role}: {content}")
            continue
        # Gradio "tuples" format: one (user, assistant) pair per entry; either
        # side may be None while a turn is pending.
        user_msg, bot_msg = turn
        if user_msg is not None:
            lines.append(f"user: {user_msg}")
        if bot_msg is not None:
            lines.append(f"assistant: {bot_msg}")

    lines.append(f"user: {message}")
    lines.append("assistant:")
    return "\n".join(lines)


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's reply to ``message`` for ``gr.ChatInterface``.

    The model keeps no state between calls. Every call therefore rebuilds the
    full conversation as a prompt and samples a continuation. The system
    prompt comes first, followed by the prior turns in order and the new
    message.

    Args:
        message: The new user message.
        history: Prior turns supplied by Gradio. It is left unmodified because
            ChatInterface records the new turn from the returned value, so
            appending here as well duplicates turns.
        system_message: System prompt placed at the top of every prompt.
        max_tokens: Maximum number of new tokens. It is cast to ``int`` because
            Gradio sliders may deliver floats.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded reply text.
    """
    prompt = _build_prompt(message, history, system_message)

    # Use the device the model actually lives on. device_map="auto" may keep it
    # on CPU, where a hard-coded "cuda" would fail.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    # A single unpadded sequence means every position is real. Llama tokenizers
    # define no pad token, so comparing against pad_token_id is not usable.
    attention_mask = torch.ones_like(input_ids)

    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(
        output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True
    )
    # The plain-text role format has no dedicated stop token, so the model may
    # go on to invent the next user turn. Cut the reply there.
    response = response.split("\nuser:", 1)[0].strip()
    # Drop an echoed copy of the previous reply, if any.
    return clean_response(response, history)
# Set up the Gradio app interface | |
# Extra controls shown under the chat box. Gradio passes their values to
# `respond` after (message, history), in the order listed here.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            label="System message",
            value="You are a helpful and friendly assistant.",
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=512,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True)