import time
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
import torch
model_id = r"/home/michele/PycharmProjects/mistral_finetuning/llama_ita_complete_v2"
# model_id = r"/home/michele/PycharmProjects/mistral_finetuning/mistral_ita_complete_v5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto").eval() # to("cuda:0")
DESCRIPTION = '''
'''
PLACEHOLDER = """
DeepMount00 llama3
Chiedimi qualsiasi cosa...
"""
css = """
h1 {
text-align: center;
display: block;
}
"""
@spaces.GPU(duration=120)
def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int) -> str:
# Creare la struttura della conversazione
conversation = []
for user, assistant in history:
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": message})
# Preparare gli input per il modello
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
# Parametri per la generazione del testo
do_sample = True if temperature > 0 else False # Usa il campionamento a meno che la temperatura non sia 0
real_temperature = max(temperature, 0.001) # Evita temperatura 0 che disabilita il campionamento
# Generare una risposta dal modello
generated_ids = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=real_temperature,
eos_token_id=tokenizer.eos_token_id
)
# Decodificare i token generati
decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
prompt_end_index = decoded[0].find(message) + len(message)
final_response = decoded[0][prompt_end_index:] if prompt_end_index != -1 else decoded[0]
return final_response.strip("assistant")
# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
with gr.Blocks(fill_height=True, css=css) as demo:
gr.Markdown(DESCRIPTION)
gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
gr.ChatInterface(
fn=chat_llama3_8b,
chatbot=chatbot,
fill_height=True,
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
additional_inputs=[
gr.Slider(minimum=0,
maximum=1,
step=0.1,
value=0.001,
label="Temperature",
render=False),
gr.Slider(minimum=128,
maximum=4096,
step=1,
value=512,
label="Max new tokens",
render=False),
],
examples=[
['Quanto è alta la torre di Pisa?'],
["Se un mattone pesa 1kg più mezzo mattone, quanto pesa il mattone? rispondi impostando l'equazione"],
['Quanto fa 9.000 * 9.000?'],
['Scrivi una funzione python che calcola i primi n numeri di fibonacci'],
['Inventa tre indovinelli tutti diversi con le relative risposte in formato json']
],
cache_examples=False,
)
if __name__ == "__main__":
demo.launch()