import time
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
import torch
import spaces
# Hub checkpoint loaded by this Space.
model_id = "DeepMount00/Llama-3-8b-Ita"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load weights in bfloat16; device_map="auto" lets transformers/accelerate
# decide device placement. eval() flips the module to inference mode in place
# (disabling e.g. dropout), so calling it after loading is equivalent to chaining.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()
# Markdown rendered at the top of the page; currently only a blank line.
DESCRIPTION = "\n"
# Text shown inside the chatbot before any message has been sent.
PLACEHOLDER = (
    "\n"
    "DeepMount00 llama3\n"
    "Chiedimi qualsiasi cosa...\n"
)
# Page-level CSS passed to gr.Blocks: center <h1> headings.
css = "\n".join(["", "h1 {", "text-align: center;", "display: block;", "}", ""])
@spaces.GPU(duration=120)
def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int) -> str:
    """Generate one (non-streamed) assistant reply for the chat UI.

    Args:
        message: The latest user message.
        history: Earlier turns as ``(user, assistant)`` pairs, in the tuple
            format passed by ``gr.ChatInterface``. ``None`` halves are skipped.
        temperature: Sampling temperature from the UI slider. ``0`` (or any
            non-positive value) selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens. It is
            cast to ``int`` because slider values may arrive as floats.

    Returns:
        Only the newly generated reply text, with special tokens removed and
        surrounding whitespace stripped.
    """
    system_prompt = "Sei un assistente specializzato nella lingua italiana. Rispondi in modo preciso e dettagliato."
    conversation = [{"role": "system", "content": system_prompt}]
    # Replay earlier turns so the model sees the whole dialogue. Missing
    # (None) halves are skipped rather than sent to the chat template.
    for user, assistant in history:
        if user is not None:
            conversation.append({"role": "user", "content": user})
        if assistant is not None:
            conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant header to the prompt,
    # so the model starts directly with the reply body instead of emitting
    # the role header itself.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Llama-3 chat models close a turn with <|eot_id|>, which can differ from
    # the tokenizer's eos token. Stop on either so generation does not run on.
    stop_ids = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != tokenizer.unk_token_id and eot_id not in stop_ids:
        stop_ids.append(eot_id)

    do_sample = temperature > 0
    generate_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "eos_token_id": stop_ids,
    }
    if do_sample:
        # Temperature only matters when sampling. The floor keeps very small
        # slider values from collapsing the distribution.
        generate_kwargs["temperature"] = max(temperature, 0.001)
    generated_ids = model.generate(**generate_kwargs)

    # Decode only the tokens generated after the prompt. Searching the decoded
    # text for `message` could match an earlier occurrence. Character-set
    # stripping (str.strip("assistant")) would also eat legitimate letters
    # from the reply.
    prompt_length = input_ids.shape[-1]
    reply = tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
    return reply.strip()
# ---- Gradio UI ----
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

    # Generation controls, created with render=False so that ChatInterface
    # renders them itself inside the collapsible parameters accordion.
    _params_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    _temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.001,
        label="Temperature",
        render=False,
    )
    _max_tokens_slider = gr.Slider(
        minimum=128,
        maximum=4096,
        step=1,
        value=512,
        label="Max new tokens",
        render=False,
    )

    # Example prompts offered by ChatInterface. They are not cached, so no
    # generation runs for them at startup.
    _example_prompts = [
        'Quanto è alta la torre di Pisa?',
        "Se un mattone pesa 1kg più mezzo mattone, quanto pesa il mattone? rispondi impostando l'equazione",
        'Quanto fa 9.000 * 9.000?',
        'Scrivi una funzione python che calcola i primi n numeri di fibonacci',
        'Inventa tre indovinelli tutti diversi con le relative risposte in formato json',
    ]

    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=_params_accordion,
        additional_inputs=[_temperature_slider, _max_tokens_slider],
        examples=[[prompt] for prompt in _example_prompts],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()